From 8844691f4bf1e44304a3bbf1eac86cc4b11d0dbe Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Tue, 12 Sep 2023 15:14:24 +0800 Subject: [PATCH 01/58] [shardformer] update shardformer readme (#4689) * [shardformer] update shardformer readme * [shardformer] update shardformer readme * [shardformer] update shardformer readme * [shardformer] update shardformer readme * [shardformer] update shardformer readme --- colossalai/shardformer/README.md | 147 ++++++++++-------- .../examples/convergence_benchmark.py | 7 +- .../examples/convergence_benchmark.sh | 2 +- .../examples/performance_benchmark.py | 6 +- 4 files changed, 90 insertions(+), 72 deletions(-) diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index 2e48a79dc1d7..559f9a56f61e 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -30,27 +30,48 @@ ### Quick Start -The sample API usage is given below(If you enable the use of flash attention, please install `flash_attn`. In addition, xformers's `cutlass_op` provide a supplementary optimization, It requires that the sequence length be a multiple of 8.): +The sample API usage is given below(If you enable the use of flash attention, please install `flash_attn`. In addition, xformers's `cutlass_op` provide a supplementary optimization): ```python -from colossalai.shardformer import ShardConfig, Shard +from colossalai.shardformer import ShardConfig, ShardFormer from transformers import BertForMaskedLM +import colossalai # launch colossalai -colossalai.launch_from_torch() +colossalai.launch_from_torch(config={}) # create model config = BertConfig.from_pretrained('bert-base-uncased') model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config) # create huggingface model as normal -shard_config = ShardConfig() +shard_config = ShardConfig(tensor_parallel_process_group=tp_group, + pipeline_stage_manager=stage_manager, + enable_tensor_parallelism=True, + enable_fused_normalization=True, + enable_flash_attention=True, + enable_jit_fused=True, + enable_sequence_parallelism=True, + enable_sequence_overlap=True) + shard_former = ShardFormer(shard_config=shard_config) -sharded_model = shard_former.optimize(model).to('cuda') +sharded_model, shared_params = shard_former.optimize(model).to('cuda') # do everything like normal ... ``` +shardformer configuration + +`tensor_parallel_process_group`: the process group of tensor parallelism, it's necessary when using tensor parallel. +`pipeline_stage_manager`: If using pipeline parallelism, it's necessary to specify a pipeline stage manager for inter-process communication in pipeline parallelism. +{{ autodoc:colossalai.pipeline.stage_manager.PipelineStageManager }} +`enable_tensor_parallelism`: using tensor parallel, partition the model along the columns or along the rows +`enable_fused_normalization`: using apex fused layernorm +`enable_flash_attention`: using flash attention +`enable_jit_fused`: using jit fused operators +`enable_sequence_parallelism`: using sequence parallelism, partition these non-tensor parallel regions along the sequence dimension. +`enable_sequence_overlap`: overlap the computation and communication in the sequence parallelism, it's used with `enable_sequence_parallelism`. 
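For reference, a minimal end-to-end sketch that combines the options above (an illustrative example only: it assumes a `torchrun`-style launch, builds the tensor parallel group with `dist.new_group` in the same way as `examples/convergence_benchmark.py` further below, and unpacks the `(model, shared_params)` tuple returned by `optimize` before moving the model to the GPU):

```python
import colossalai
import torch.distributed as dist
from colossalai.shardformer import ShardConfig, ShardFormer
from transformers import BertConfig, BertForMaskedLM

# launch the distributed environment (e.g. started via torchrun)
colossalai.launch_from_torch(config={})

# create the huggingface model as normal
config = BertConfig.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)

# tensor-parallel process group; here simply a group over all ranks,
# as done in examples/convergence_benchmark.py
tp_group = dist.new_group(backend='nccl')

shard_config = ShardConfig(tensor_parallel_process_group=tp_group,
                           enable_tensor_parallelism=True)
shard_former = ShardFormer(shard_config=shard_config)

# optimize() returns the sharded model together with the shared parameters
sharded_model, shared_params = shard_former.optimize(model)
sharded_model = sharded_model.cuda()
```

Note that `optimize` returns a tuple, so the model is moved to the device only after unpacking it.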
+ ### Write your own policy @@ -82,44 +103,30 @@ We will follow this roadmap to develop Shardformer: - [x] API Implementation - [x] Unit Testing - [ ] Policy Implementation - - [ ] Hugging Face - - [ ] NLP - - [x] BERT - - [x] T5 - - [x] LlaMa - - [x] GPT2 - - [x] OPT - - [x] BLOOM - - [ ] GLM - - [ ] RoBERTa - - [ ] ALBERT - - [ ] ERNIE - - [ ] GPT Neo - - [ ] GPT-J - - [ ] CV - - [x] ViT - - [ ] BEiT - - [ ] SwinTransformer - - [ ] SwinTransformer V2 - - [ ] Audio - - [x] Whisper - - [ ] Multi-modal - - [x] SAM - - [x] BLIP-2 -- [ ] Flash Attention Support - - [ ] NLP - - [x] BERT - - [x] T5 - - [x] LlaMa - - [x] GPT2 - - [x] OPT - - [x] BLOOM - - [ ] GLM - - [ ] RoBERTa - - [ ] ALBERT - - [ ] ERNIE - - [ ] GPT Neo - - [ ] GPT-J + +| model | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap | +| :------: | :-----: | :-----: | :--------: | :---------: | :------: | :-----: | :-----: | :--------: | :---------: | +| bert | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | +| t5 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] | +| llama V1/V2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] | +| gpt2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | +| opt | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] | +| bloom | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | +| chatglm2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | +| vit | [x] | [x] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] | +| whisper | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] | +| sam | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] | +| blip2 | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] | +| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| gpt-neo | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| gpt-j | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| beit | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | + ## 💡 API Design @@ -286,41 +293,36 @@ class ShardFormer: Example: + org_model = BertForMaskedLM.from_pretrained('bert-base-uncased') + shard_config = ShardConfig() shard_former = ShardFormer(shard_config=shard_config) - shard_former.init_distributed() - model = shard_former.optimize(model, policy=policy) - dataloader = shard_former.shard_dataset(dataset) + model, shared_params = shard_former.optimize(org_model) """ def __init__(self, shard_config: ShardConfig): """ Do two things: - 1. Create a colossalai.cluster.process_group_manager to manage process groups for dp, tp and pp + 1. Create a distribute coordinator 2. serve as a store for shard config """ self.shard_config = shard_config - self.pg_manager = None + self.coordinator = DistCoordinator() - def init_distributed(self) -> colossalai.cluster.ProcessGroupManager: - """ - Initialize the distributed process group according to the - """ - pg_manager = ... - self.pg_manager = pg_manager - return pg_manager + def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]: + r""" + This method will optimize the model based on the given policy. 
- def shard_model(self, model: torch.nn.Module,policy: Policy) -> torch.nn.Module: - """ - Shard model for TP and PP - """ - ... + Args: + model (`torch.nn.Model`): the origin huggingface model + shard_config (`ShardConfig`): the config for distribute information + policy (`Policy`): the custom policy for sharding - def shard_dataset(self, dataset: Dataset) -> Dataloader: + Returns: the sharded model and the shared parameters """ - Shard dataset for DP - """ - ... + sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy) + shared_params = sharder.shard() + return model, shared_params ``` ## ⌨️ Development Notes @@ -429,13 +431,24 @@ As shown in the figures above, when the sequence length is around 1000 or greate ### Convergence -To validate that training the model using shardformers does not impact its convergence. We [fine-tuned the BERT model](../../examples/language/bert/finetune.py) using both shardformer and non-shardformer approaches. The example that utilizes Shardformer simultaneously with Pipeline Parallelism and Data Parallelism (Zero1). We then compared the accuracy, loss, and F1 score of the training results. +To validate that training the model using shardformers does not impact its convergence. We [fine-tuned the BERT model](./examples/convergence_benchmark.py) using both shardformer and non-shardformer approaches. The example that utilizes Shardformer simultaneously with Pipeline Parallelism and Data Parallelism (Zero1). We then compared the accuracy, loss, and F1 score of the training results. + +the configurations are as follows: +```python +batch_size = 2 +epoch = 3 +lr = 2.4e-5 +accumulation_steps = 8 +warmup_fraction = 0.03 +``` + | accuracy | f1 | loss | GPU number | model sharded | | :------: | :-----: | :-----: | :--------: | :---------: | -| 0.84589 | 0.88613 | 0.43414 | 4 | True | -| 0.83594 | 0.88064 | 0.43298 | 1 | False | +| 0.82971 | 0.87713 | 0.23194 | 4 | True | +| 0.83797 | 0.88006 | 0.22683 | 2 | True | +| 0.84521 | 0.88700 | 0.21822 | 1 | False | Overall, the results demonstrate that using shardformers during model training does not affect the convergence. 
diff --git a/colossalai/shardformer/examples/convergence_benchmark.py b/colossalai/shardformer/examples/convergence_benchmark.py index de82305b2547..81be2017855c 100644 --- a/colossalai/shardformer/examples/convergence_benchmark.py +++ b/colossalai/shardformer/examples/convergence_benchmark.py @@ -49,9 +49,12 @@ def train(args): # if multiple GPUs, shard the model if dist.get_world_size() > 1: - shard_config = ShardConfig(enable_fused_normalization=args.fused_layernorm) + tp_group = dist.new_group(backend='nccl') + shard_config = ShardConfig(tensor_parallel_process_group=tp_group, + enable_tensor_parallelism=True, + enable_all_optimization=True) shard_former = ShardFormer(shard_config=shard_config) - model = shard_former.optimize(model) + model, _ = shard_former.optimize(model) optim = Adam(model.parameters(), lr=args.lr) num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps diff --git a/colossalai/shardformer/examples/convergence_benchmark.sh b/colossalai/shardformer/examples/convergence_benchmark.sh index 1c281abcda6d..22f13a7cf827 100644 --- a/colossalai/shardformer/examples/convergence_benchmark.sh +++ b/colossalai/shardformer/examples/convergence_benchmark.sh @@ -1,7 +1,7 @@ torchrun --standalone --nproc_per_node=4 convergence_benchmark.py \ --model "bert" \ --pretrain "bert-base-uncased" \ - --max_epochs 1 \ + --max_epochs 3 \ --batch_size 2 \ --lr 2.4e-5 \ --fused_layernorm False \ diff --git a/colossalai/shardformer/examples/performance_benchmark.py b/colossalai/shardformer/examples/performance_benchmark.py index 9c7b76bcf0a6..2f186709d946 100644 --- a/colossalai/shardformer/examples/performance_benchmark.py +++ b/colossalai/shardformer/examples/performance_benchmark.py @@ -29,7 +29,8 @@ def data_gen_for_sequence_classification(batch_size, seq_length): intermediate_size=256, num_attention_heads=4, max_position_embeddings=128, - num_labels=16) + num_labels=16, + pad_token_id=2) BATCH, N_HEADS, N_CTX, D_HEAD = 4, 8, 4096, 64 model_func = lambda: transformers.LlamaForSequenceClassification(MODEL_CONFIG) @@ -73,7 +74,8 @@ def bench_shardformer(BATCH, N_CTX, provider, model_func, dtype=torch.float32, d if provider == "shard_model": shard_config = ShardConfig(enable_fused_normalization=True, enable_tensor_parallelism=True) shard_former = ShardFormer(shard_config=shard_config) - sharded_model = shard_former.optimize(model).cuda() + sharded_model, _ = shard_former.optimize(model) + sharded_model = sharded_model.cuda() fn = lambda: train(sharded_model, data) ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) return ms From d8ceeac14e54c5c568e916c061b86d9a53a54f30 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Tue, 12 Sep 2023 17:32:19 +0800 Subject: [PATCH 02/58] [hotfix] fix typo in hybrid parallel io (#4697) --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 4 ++-- colossalai/checkpoint_io/__init__.py | 2 +- colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py | 8 ++++---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 125a9ccca1b5..fc04f3ecd8e7 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -16,7 +16,7 @@ from torch.utils.data.distributed import DistributedSampler from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer -from colossalai.checkpoint_io import CheckpointIO, HypridParallelCheckpointIO +from 
colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO from colossalai.cluster import ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule @@ -513,7 +513,7 @@ def seed_worker(worker_id): **_kwargs) def get_checkpoint_io(self) -> CheckpointIO: - self.checkpoint_io = HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) + self.checkpoint_io = HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) return self.checkpoint_io def no_sync(self, model: Module) -> Iterator[None]: diff --git a/colossalai/checkpoint_io/__init__.py b/colossalai/checkpoint_io/__init__.py index 07b1f81dace6..e1aa6543ef39 100644 --- a/colossalai/checkpoint_io/__init__.py +++ b/colossalai/checkpoint_io/__init__.py @@ -1,6 +1,6 @@ from .checkpoint_io_base import CheckpointIO from .general_checkpoint_io import GeneralCheckpointIO -from .hybrid_parallel_checkpoint_io import HypridParallelCheckpointIO +from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO from .index_file import CheckpointIndexFile __all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO', 'HybridParallelCheckpointIO'] diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index fef5b0d16d60..6eee3ace0308 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -39,7 +39,7 @@ _EXTRA_STATE_KEY_SUFFIX = '_extra_state' -class HypridParallelCheckpointIO(GeneralCheckpointIO): +class HybridParallelCheckpointIO(GeneralCheckpointIO): """ CheckpointIO for Hybrid Parallel Training. @@ -136,7 +136,7 @@ def _optimizer_sharder(optimizer: OptimizerWrapper, param_id = param_info['param2id'][id(working_param)] original_shape = param_info['param2shape'][id(working_param)] - state_ = HypridParallelCheckpointIO.gather_from_sharded_optimizer_state(state, + state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(state, working_param, original_shape=original_shape, dp_group=dp_group, @@ -189,7 +189,7 @@ def save_sharded_model(self, # Then collect the sharded parameters & buffers along tp_group. # Only devices with tp_rank == 0 are responsible for model saving. - state_dict_shard = HypridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard) + state_dict_shard = HybridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard) weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) index_file = CheckpointIndexFile(checkpoint) control_saving = (self.tp_rank == 0) @@ -385,7 +385,7 @@ def save_sharded_optimizer(self, # Then collect the sharded states along dp_group(if using zero)/tp_group. # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving. - state_dict_shard = HypridParallelCheckpointIO._optimizer_sharder( + state_dict_shard = HybridParallelCheckpointIO._optimizer_sharder( optimizer, use_zero=self.use_zero, dp_group=self.dp_group, From 9c2feb2f0bb92cb5f1b4fb967447aa91a7c7beb5 Mon Sep 17 00:00:00 2001 From: digger yu Date: Tue, 12 Sep 2023 17:41:52 +0800 Subject: [PATCH 03/58] fix some typo with colossalai/device colossalai/tensor/ etc. 
(#4171) Co-authored-by: flybird11111 <1829166702@qq.com> --- colossalai/device/device_mesh.py | 12 ++++++------ colossalai/tensor/d_tensor/comm_spec.py | 2 +- colossalai/tensor/shape_consistency.py | 4 ++-- tests/kit/model_zoo/transformers/t5.py | 2 +- .../test_plugin/test_torch_ddp_plugin.py | 2 +- .../test_plugin/test_torch_fsdp_plugin.py | 2 +- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/colossalai/device/device_mesh.py b/colossalai/device/device_mesh.py index 267c4529eb95..f41af1161be1 100644 --- a/colossalai/device/device_mesh.py +++ b/colossalai/device/device_mesh.py @@ -59,7 +59,7 @@ def __init__(self, # 2. directly supply the logical mesh id assert mesh_shape is None or logical_mesh_id is None, \ "Only one of mesh_shape and logical_mesh_id can be specified." \ - "Logical mesh IDs are obtained from either mesh_shape + phyiscal_mesh_id or directly from the user-supplied logical_mesh_id" + "Logical mesh IDs are obtained from either mesh_shape + physical_mesh_id or directly from the user-supplied logical_mesh_id" if logical_mesh_id is None: self._mesh_shape = mesh_shape @@ -74,7 +74,7 @@ def __init__(self, assert torch.equal(torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)), \ "physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id." assert torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel(), \ - "Found duplicate IDs in the phyiscal_mesh_id and this is not allowed, please check your physical_mesh_id again." + "Found duplicate IDs in the physical_mesh_id and this is not allowed, please check your physical_mesh_id again." assert torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel(), \ "Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again." 
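The assertions above describe two mutually exclusive ways to build a mesh: derive the logical mesh from `physical_mesh_id` plus `mesh_shape`, or supply `logical_mesh_id` directly. A hedged sketch of both paths, assuming the constructor accepts `physical_mesh_id`, `mesh_shape` and `logical_mesh_id` keywords as suggested by these checks (the full signature is not shown in this hunk), and that no process groups need to be initialized:

```python
import torch
from colossalai.device.device_mesh import DeviceMesh  # import path assumed from the file touched above

physical_mesh_id = torch.arange(0, 16)

# Path 1: pass mesh_shape; the logical mesh ids are derived from physical_mesh_id + mesh_shape
mesh_from_shape = DeviceMesh(physical_mesh_id, mesh_shape=(4, 4))

# Path 2: pass logical_mesh_id directly and leave mesh_shape unset
mesh_from_logical_id = DeviceMesh(physical_mesh_id, logical_mesh_id=physical_mesh_id.reshape(4, 4))
```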
@@ -118,7 +118,7 @@ def __init__(self, self._global_rank_of_current_process = None self._is_initialized = False - # attribute used to inidicate whether this objectd + # attribute used to indicate whether this object # is created using DeviceMesh.from_process_group # this attribute can be used to do some check in methods # such get_process_group as no global rank information @@ -395,7 +395,7 @@ def _collate_global_ranks_in_same_process_group(self, global_rank): Example: ```python - sphysical_mesh_id = torch.arange(0, 16) + physical_mesh_id = torch.arange(0, 16) mesh_shape = (4, 4) # logical mesh will look like @@ -438,7 +438,7 @@ def _collate_global_ranks_in_same_process_group(self, global_rank): # the _local_rank refers to the local rank of the current process for _local_rank in range(self.logical_mesh_id.shape[dim]): - # if this dimension is not initailized yet, + # if this dimension is not initialized yet, # initialize it with an empty array if dim not in processes_in_the_same_process_group: processes_in_the_same_process_group[dim] = [] @@ -447,7 +447,7 @@ def _collate_global_ranks_in_same_process_group(self, global_rank): process_coordinates = self._global_to_local_rank_mapping[global_rank].copy() # replace the local rank in the given dimension with the - # lcoal rank of the current process iterated + # local rank of the current process iterated process_coordinates[dim] = _local_rank processes_in_the_same_process_group[dim].append(process_coordinates) diff --git a/colossalai/tensor/d_tensor/comm_spec.py b/colossalai/tensor/d_tensor/comm_spec.py index 79b2e3ef936a..6158d0bfe2ad 100644 --- a/colossalai/tensor/d_tensor/comm_spec.py +++ b/colossalai/tensor/d_tensor/comm_spec.py @@ -28,7 +28,7 @@ class CommSpec: to determine the buffer shape, and logical_process_axis Argument: - comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec. + comm_pattern(CollectiveCommPattern): describe the communication method used in this spec. process_group_dict(Dict): A dict which contains the process groups used to apply this CommSpec. gather_dim(int, Optional): The gather_dim of the tensor will be gathered. shard_dim(int, Optional): The shard_dim of the tensor will be sharded. 
diff --git a/colossalai/tensor/shape_consistency.py b/colossalai/tensor/shape_consistency.py index 99d782c3f6e8..b837333a2388 100644 --- a/colossalai/tensor/shape_consistency.py +++ b/colossalai/tensor/shape_consistency.py @@ -339,7 +339,7 @@ def get_all_mix_gather_spec(self, source_spec: ShardingSpec, RS01 -> RR ''' valid_spec_dict = {} - comm_pathern = CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD + comm_pattern = CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD tensor_dims = len(source_spec.entire_shape) for f_index in range(tensor_dims - 1): for b_index in range(f_index + 1, tensor_dims): @@ -362,7 +362,7 @@ def get_all_mix_gather_spec(self, source_spec: ShardingSpec, b_target_pair = (b_index, []) gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) - comm_spec = CommSpec(comm_pathern, + comm_spec = CommSpec(comm_pattern, sharding_spec=source_spec, gather_dim=gather_dim, logical_process_axis=logical_process_axes, diff --git a/tests/kit/model_zoo/transformers/t5.py b/tests/kit/model_zoo/transformers/t5.py index 175d48963480..16a594f3950a 100644 --- a/tests/kit/model_zoo/transformers/t5.py +++ b/tests/kit/model_zoo/transformers/t5.py @@ -43,7 +43,7 @@ def data_gen_for_t5_model(): # output transform function output_transform_fn = lambda x: x -# define loss funciton +# define loss function loss_fn_for_t5_model = lambda x: x.last_hidden_state.mean() loss_fn_for_encoder_only = lambda x: x.last_hidden_state.mean() loss_fn_for_conditional_generation = lambda x: x.loss diff --git a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py index 1484273973ae..23d743c924aa 100644 --- a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py +++ b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py @@ -64,7 +64,7 @@ def check_torch_ddp_no_sync(): model = DummyModel() criterion = lambda x: x.mean() optimizer = SGD(model.parameters(), lr=1e-3) - # create a custom dasetset with 0 to 10 + # create a custom dataset with 0 to 10 dataset = torch.arange(0, 10) train_dataloader = plugin.prepare_dataloader(dataset, batch_size=2) model, optimizer, criterion, train_dataloader, _ = booster.boost(model, diff --git a/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py b/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py index cbd5d57800db..e09ad766bb32 100644 --- a/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py +++ b/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py @@ -15,7 +15,7 @@ from tests.kit.model_zoo import model_zoo -# test baisc fsdp function +# test basic fsdp function def run_fn(model_fn, data_gen_fn, output_transform_fn): plugin = TorchFSDPPlugin() booster = Booster(plugin=plugin) From 068372a73889e5c872a8185d73302e9e3f77b750 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Wed, 13 Sep 2023 10:43:30 +0800 Subject: [PATCH 04/58] [doc] add potential solution for OOM in llama2 example (#4699) --- examples/language/llama2/README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/language/llama2/README.md b/examples/language/llama2/README.md index 483eae88ae32..16b263c1322e 100644 --- a/examples/language/llama2/README.md +++ b/examples/language/llama2/README.md @@ -149,6 +149,9 @@ Finally, run the following command to start training: ```bash bash gemini.sh ``` + +If you encounter out-of-memory(OOM) error during training with script `gemini.sh`, changing to script `gemini_auto.sh` might be a solution, since gemini_auto will set a upper limit on GPU memory usage through 
offloading part of the model parameters and optimizer states back to CPU memory. But there's a trade-off: `gemini_auto.sh` will be a bit slower, since more data are transmitted between CPU and GPU. + #### c. Results If you run the above command successfully, you will get the following results: `max memory usage: 55491.10 MB, throughput: 24.26 samples/s, TFLOPS/GPU: 167.43`. From c7d6975d2984825df40bb86ac39fc1c3d137fe96 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 13 Sep 2023 15:57:16 +0800 Subject: [PATCH 05/58] [shardformer] fix GPT2DoubleHeadsModel (#4703) --- colossalai/shardformer/modeling/gpt2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index bc99be4cc391..84deafefeadd 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -94,9 +94,9 @@ def gpt2_model_forward( if hidden_states is None: raise ValueError("hidden_states shouldn't be None for stages other than the first stage.") input_shape = hidden_states.size()[:-1] - batch_size = input_shape[0] device = hidden_states.device hidden_states = hidden_states.view((-1,) + hidden_states.shape[-2:]) + batch_size = hidden_states.shape[0] # GPT2Attention mask. if attention_mask is not None: From e2c0e7f92abf6b93ccb331298880a41553df0cc7 Mon Sep 17 00:00:00 2001 From: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Date: Thu, 14 Sep 2023 18:03:55 +0800 Subject: [PATCH 06/58] [hotfix] Fix import error: colossal.kernel without triton installed (#4722) * [hotfix] remove triton kernels from kernel init * revise bloom/llama kernel imports for infer --- .../tensor_parallel/modeling/bloom.py | 4 +-- .../tensor_parallel/modeling/llama.py | 10 ++++--- .../tensor_parallel/policies/bloom.py | 4 +-- .../tensor_parallel/policies/llama.py | 26 +++++++++---------- colossalai/kernel/__init__.py | 7 ----- colossalai/kernel/triton/__init__.py | 7 +++++ 6 files changed, 28 insertions(+), 30 deletions(-) diff --git a/colossalai/inference/tensor_parallel/modeling/bloom.py b/colossalai/inference/tensor_parallel/modeling/bloom.py index 9768fc425628..ba5eadc92be8 100644 --- a/colossalai/inference/tensor_parallel/modeling/bloom.py +++ b/colossalai/inference/tensor_parallel/modeling/bloom.py @@ -17,9 +17,7 @@ from transformers.utils import logging from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState -from colossalai.kernel.triton.context_attention import bloom_context_attn_fwd -from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest -from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd +from colossalai.kernel.triton import bloom_context_attn_fwd, copy_kv_cache_to_dest, token_attention_fwd def generate_alibi(n_head, dtype=torch.float16): diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index 219cd1ae0d0e..07b73a6f4ca6 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -6,10 +6,12 @@ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState -from colossalai.kernel.triton.context_attention import llama_context_attn_fwd -from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest -from 
colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd -from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd +from colossalai.kernel.triton import ( + copy_kv_cache_to_dest, + llama_context_attn_fwd, + rotary_embedding_fwd, + token_attention_fwd, +) try: from vllm import layernorm_ops, pos_encoding_ops diff --git a/colossalai/inference/tensor_parallel/policies/bloom.py b/colossalai/inference/tensor_parallel/policies/bloom.py index 63791fe27284..cae43aa20421 100644 --- a/colossalai/inference/tensor_parallel/policies/bloom.py +++ b/colossalai/inference/tensor_parallel/policies/bloom.py @@ -8,10 +8,10 @@ from ..modeling.bloom import BloomInferenceForwards try: - from colossalai.kernel.triton.fused_layernorm import layer_norm + from colossalai.kernel.triton import layer_norm HAS_TRITON_NORM = True except: - print("you should install triton from https://github.com/openai/triton") + print("Some of our kernels require triton. You might want to install triton from https://github.com/openai/triton") HAS_TRITON_NORM = False diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py index e819f2a8810c..4844415d612c 100644 --- a/colossalai/inference/tensor_parallel/policies/llama.py +++ b/colossalai/inference/tensor_parallel/policies/llama.py @@ -1,33 +1,32 @@ from functools import partial + import torch -from transformers.models.llama.modeling_llama import ( - LlamaAttention, - LlamaDecoderLayer, - LlamaModel, - LlamaRMSNorm -) +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm # import colossalai from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy + from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward try: - from colossalai.kernel.triton.rms_norm import rmsnorm_forward + from colossalai.kernel.triton import rmsnorm_forward HAS_TRITON_RMSNORM = True except: print("you should install triton from https://github.com/openai/triton") HAS_TRITON_RMSNORM = False - + def get_triton_rmsnorm_forward(): if HAS_TRITON_RMSNORM: + def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon) - + return _triton_rmsnorm_forward else: return None - + + class LlamaModelInferPolicy(LlamaForCausalLMPolicy): def __init__(self) -> None: @@ -59,12 +58,11 @@ def module_policy(self): else: # NOTE: adding rms_norm from cuda kernels caused precision issue, fix @tiandiao123 infer_forward = get_llama_vllm_rmsnorm_forward() - + if infer_forward is not None: method_replacement = {'forward': partial(infer_forward)} self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=LlamaRMSNorm) + policy=policy, + target_key=LlamaRMSNorm) return policy - diff --git a/colossalai/kernel/__init__.py b/colossalai/kernel/__init__.py index a99cb497c3e7..8933fc0a3c2f 100644 --- a/colossalai/kernel/__init__.py +++ b/colossalai/kernel/__init__.py @@ -1,14 +1,7 @@ from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention -from .triton import llama_context_attn_fwd, bloom_context_attn_fwd -from .triton import softmax -from .triton import copy_kv_cache_to_dest __all__ = [ "LayerNorm", "FusedScaleMaskSoftmax", "MultiHeadAttention", - "llama_context_attn_fwd", - "bloom_context_attn_fwd", - "softmax", - "copy_kv_cache_to_dest", ] diff --git 
a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py index eb0335c01ce2..5840ad2918be 100644 --- a/colossalai/kernel/triton/__init__.py +++ b/colossalai/kernel/triton/__init__.py @@ -2,4 +2,11 @@ from .copy_kv_cache_dest import copy_kv_cache_to_dest from .fused_layernorm import layer_norm from .rms_norm import rmsnorm_forward +from .rotary_embedding_kernel import rotary_embedding_fwd from .softmax import softmax +from .token_attention_kernel import token_attention_fwd + +__all__ = [ + "llama_context_attn_fwd", "bloom_context_attn_fwd", "softmax", "layer_norm", "rmsnorm_forward", + "copy_kv_cache_to_dest", "rotary_embedding_fwd", "token_attention_fwd" +] From 20190b49a5f79b065b820ef84b41da8044e76c39 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 14 Sep 2023 21:34:20 +0800 Subject: [PATCH 07/58] [shardformer] to fix whisper test failed due to significant accuracy differences. (#4710) * [shardformer] fix whisper test failed * [shardformer] fix whisper test failed * [shardformer] fix whisper test failed * [shardformer] fix whisper test failed --- colossalai/shardformer/README.md | 2 +- colossalai/shardformer/policies/whisper.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index 559f9a56f61e..b1573ae163a0 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -114,7 +114,7 @@ We will follow this roadmap to develop Shardformer: | bloom | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | | chatglm2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | | vit | [x] | [x] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] | -| whisper | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] | +| whisper | [x] | [x] | [x] | [x] | [x] | [ ] | [x] | [ ] | [ ] | | sam | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] | | blip2 | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] | | roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index 5d496f08e1db..31ba82166b31 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -57,6 +57,11 @@ def module_policy(self): warnings.warn( "Whisper dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + #TODO using the jit fused add_and_dropout affect the accuracy + if self.shard_config.enable_jit_fused: + self.shard_config.enable_jit_fused = False + warnings.warn("Whisper dosen't support jit fused operator now, will ignore the jit fused operator flag.") + if self.shard_config.enable_tensor_parallelism: policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={ "self_attn.embed_dim": From ce97790ed73f7962ab1ceae057a020168b45dda4 Mon Sep 17 00:00:00 2001 From: binmakeswell Date: Thu, 14 Sep 2023 23:19:25 +0800 Subject: [PATCH 08/58] [doc] fix llama2 code link (#4726) * [doc] fix llama2 code link * [doc] fix llama2 code link * [doc] fix llama2 code link --- README.md | 2 +- docs/README-zh-Hans.md | 2 +- examples/language/llama2/README.md | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 0ddcdab741a4..25d3b8f83f1e 100644 --- a/README.md +++ b/README.md @@ -224,7 +224,7 @@ Acceleration of [AlphaFold Protein Structure](https://alphafold.ebi.ac.uk/)
- 70 billion parameter LLaMA2 model training accelerated by 195% -[[code]](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama) +[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama2) [[blog]](https://www.hpc-ai.tech/blog/70b-llama2-training) ### LLaMA1 diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md index dda4f86a29a0..41eebc59c493 100644 --- a/docs/README-zh-Hans.md +++ b/docs/README-zh-Hans.md @@ -217,7 +217,7 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
- 700亿参数LLaMA2训练加速195% -[[code]](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama) +[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama2) [[blog]](https://www.hpc-ai.tech/blog/70b-llama2-training) ### LLaMA1 diff --git a/examples/language/llama2/README.md b/examples/language/llama2/README.md index 16b263c1322e..c8fc86d29d97 100644 --- a/examples/language/llama2/README.md +++ b/examples/language/llama2/README.md @@ -6,7 +6,7 @@
- 70 billion parameter LLaMA2 model training accelerated by 195% -[[code]](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama) +[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama2) [[blog]](https://www.hpc-ai.tech/blog/70b-llama2-training) ### LLaMA1 From f911d5b09dc8be6c444da1dbc1fd605aa69007b6 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Fri, 15 Sep 2023 10:56:39 +0800 Subject: [PATCH 09/58] [doc] Add user document for Shardformer (#4702) * create shardformer doc files * add docstring for seq-parallel * update ShardConfig docstring * add links to llama example * add outdated massage * finish introduction & supporting information * finish 'how shardformer works' * finish shardformer.md English doc * fix doctest fail * add Chinese document --- .../booster/plugin/hybrid_parallel_plugin.py | 8 +- colossalai/shardformer/README.md | 32 ++-- colossalai/shardformer/shard/shard_config.py | 26 ++-- docs/source/en/basics/booster_api.md | 3 +- docs/source/en/basics/booster_plugins.md | 2 +- docs/source/en/features/1D_tensor_parallel.md | 4 + docs/source/en/features/shardformer.md | 143 ++++++++++++++++++ docs/source/zh-Hans/basics/booster_api.md | 5 +- docs/source/zh-Hans/basics/booster_plugins.md | 2 +- .../zh-Hans/features/1D_tensor_parallel.md | 4 + docs/source/zh-Hans/features/shardformer.md | 121 +++++++++++++++ 11 files changed, 316 insertions(+), 34 deletions(-) create mode 100644 docs/source/en/features/shardformer.md create mode 100644 docs/source/zh-Hans/features/shardformer.md diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index fc04f3ecd8e7..3fbeebcc4110 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -243,9 +243,11 @@ class HybridParallelPlugin(PipelinePluginBase): enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer. Currently all the optimization methods include fused normalization, flash attention and JIT. Defaults to False. - enable_fused_normalization (bool, optional): Whether to switch on fused normalization. Defaults to False. - enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False. - enable_jit_fused (bool, optional): Whether to switch on JIT. Default to Falase. + enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False. + enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False. + enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False. + enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False. + enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False. num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None. microbatch_size (int, optional): Microbatch size when using pipeline parallelism. Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline. 
diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index b1573ae163a0..4bd7d5208a64 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -60,18 +60,28 @@ sharded_model, shared_params = shard_former.optimize(model).to('cuda') # do everything like normal ... ``` -shardformer configuration - -`tensor_parallel_process_group`: the process group of tensor parallelism, it's necessary when using tensor parallel. -`pipeline_stage_manager`: If using pipeline parallelism, it's necessary to specify a pipeline stage manager for inter-process communication in pipeline parallelism. -{{ autodoc:colossalai.pipeline.stage_manager.PipelineStageManager }} -`enable_tensor_parallelism`: using tensor parallel, partition the model along the columns or along the rows -`enable_fused_normalization`: using apex fused layernorm -`enable_flash_attention`: using flash attention -`enable_jit_fused`: using jit fused operators -`enable_sequence_parallelism`: using sequence parallelism, partition these non-tensor parallel regions along the sequence dimension. -`enable_sequence_overlap`: overlap the computation and communication in the sequence parallelism, it's used with `enable_sequence_parallelism`. +Following are the description `ShardConfig`'s arguments: + +- `tensor_parallel_process_group`: The process group of tensor parallelism, it's necessary when using tensor parallel. Defaults to None, which is the global process group. + +- `pipeline_stage_manager`: If using pipeline parallelism, it's necessary to specify a pipeline stage manager for inter-process communication in pipeline parallelism. Defaults to None, which means not using pipeline parallelism. + +- `enable_tensor_parallelism`: Whether to use tensor parallelism. Defaults to True. + +- `enable_fused_normalization`: Whether to use fused layernorm. Defaults to False. + +- `enable_flash_attention`: Whether to switch on flash attention. Defaults to False. + +- `enable_jit_fused`: Whether to switch on JIT fused operators. Defaults to False. + +- `enable_sequence_parallelism`: Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False. + +- `enable_sequence_overlap`: Whether to turn on sequence overlap, wheich overlap the computation and communication in sequence parallelism. It can only be used when `enable_sequence_parallelism` is True. Defaults to False. + +- `enable_all_optimization`: Whether to turn on all optimization tools including `fused normalizaion`, `flash attention`, `JIT fused operators`, `sequence parallelism` and `sequence overlap`. Defaults to False. + +- `inference_only`: Whether only doing forward passing. Defaults to False. ### Write your own policy diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 4380ac30814d..0b6e1640952b 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -15,32 +15,28 @@ class ShardConfig: The config for sharding the huggingface model Args: - tensor_parallel_process_group (Optional[ProcessGroup]): The process group for tensor parallelism, defaults to None, which is the global process group. - pipeline_stage_manager (Optional[PipelineStageManager]): The pipeline stage manager, defaults to None, which means no pipeline. - enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True. 
-        enable_fused_normalization (bool): Whether to use fused layernorm, default is False.
-        enable_all_optimization (bool): Whether to turn on all optimization, default is False.
-        enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, default is False.
-        enable_sequence_overlap (bool): Whether to turn on sequence overlap, default is False.
+        tensor_parallel_process_group (Optional[ProcessGroup]): The process group of tensor parallelism; it's necessary when using tensor parallelism. Defaults to None, which is the global process group.
+        pipeline_stage_manager (Optional[PipelineStageManager]): If using pipeline parallelism, it's necessary to specify a pipeline stage manager for inter-process communication in pipeline parallelism. Defaults to None, which means not using pipeline parallelism.
+        enable_tensor_parallelism (bool): Whether to use tensor parallelism. Defaults to True.
+        enable_fused_normalization (bool): Whether to use fused layernorm. Defaults to False.
+        enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False.
+        enable_jit_fused (bool, optional): Whether to switch on JIT fused operators. Defaults to False.
+        enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.
+        enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlaps the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
+        enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
+        inference_only (bool): Whether only doing forward passing. Defaults to False.
""" tensor_parallel_process_group: Optional[ProcessGroup] = None pipeline_stage_manager: Optional[PipelineStageManager] = None enable_tensor_parallelism: bool = True enable_fused_normalization: bool = False - enable_all_optimization: bool = False enable_flash_attention: bool = False enable_jit_fused: bool = False enable_sequence_parallelism: bool = False enable_sequence_overlap: bool = False + enable_all_optimization: bool = False inference_only: bool = False - enable_sequence_parallelism: bool = False - enable_sequence_overlap: bool = False - - # pipeline_parallel_size: int - # data_parallel_size: int # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] - # inference_only: bool = True - # gather_output: bool = True @property def tensor_parallel_size(self): diff --git a/docs/source/en/basics/booster_api.md b/docs/source/en/basics/booster_api.md index 7962707514de..392251ef06b2 100644 --- a/docs/source/en/basics/booster_api.md +++ b/docs/source/en/basics/booster_api.md @@ -9,7 +9,8 @@ Author: [Mingyan Jiang](https://github.com/jiangmingyan), [Jianghai Chen](https: **Example Code** -- [Train with Booster](https://github.com/hpcaitech/ColossalAI/blob/main/examples/tutorial/new_api/cifar_resnet) +- [Train ResNet on CIFAR-10 with Booster](https://github.com/hpcaitech/ColossalAI/blob/main/examples/tutorial/new_api/cifar_resnet) +- [Train LLaMA-1/2 on RedPajama with Booster](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama2) ## Introduction diff --git a/docs/source/en/basics/booster_plugins.md b/docs/source/en/basics/booster_plugins.md index 7a88dc1701ba..d7532b0ce39b 100644 --- a/docs/source/en/basics/booster_plugins.md +++ b/docs/source/en/basics/booster_plugins.md @@ -73,7 +73,7 @@ More details can be found in [Pytorch Docs](https://pytorch.org/docs/main/fsdp.h This plugin implements the combination of various parallel training strategies and optimization tools. The features of HybridParallelPlugin can be generally divided into four parts: -1. Shardformer: This plugin provides an entrance to Shardformer, which controls model sharding under tensor parallel and pipeline parallel setting. Shardformer also overloads the logic of model's forward/backward process to ensure the smooth working of tp/pp. Also, optimization tools including fused normalization, flash attention (xformers), JIT and sequence parallel are injected into the overloaded forward/backward method by Shardformer. +1. Shardformer: This plugin provides an entrance to Shardformer, which controls model sharding under tensor parallel and pipeline parallel setting. Shardformer also overloads the logic of model's forward/backward process to ensure the smooth working of tp/pp. Also, optimization tools including fused normalization, flash attention (xformers), JIT and sequence parallel are injected into the overloaded forward/backward method by Shardformer. More details can be found in chapter [Shardformer Doc](../features/shardformer.md). 2. Mixed Precision Training: Support for fp16/bf16 mixed precision training. More details about its arguments configuration can be found in [Mixed Precision Training Doc](../features/mixed_precision_training_with_booster.md). 
diff --git a/docs/source/en/features/1D_tensor_parallel.md b/docs/source/en/features/1D_tensor_parallel.md
index 7157af210bc5..79fe5ddea221 100644
--- a/docs/source/en/features/1D_tensor_parallel.md
+++ b/docs/source/en/features/1D_tensor_parallel.md
@@ -2,6 +2,8 @@

 Author: Zhengda Bian, Yongbin Li

+> ⚠️ The information on this page is outdated and will be deprecated. Please check [Shardformer](./shardformer.md) for more information.
+
 **Prerequisite**
 - [Define Your Configuration](../basics/define_your_config.md)
 - [Configure Parallelization](../basics/configure_parallelization.md)
@@ -116,3 +118,5 @@ Output of the first linear layer: torch.Size([16, 512])
 Output of the second linear layer: torch.Size([16, 256])
 ```
 The output of the first linear layer is split into 2 partitions (each has the shape `[16, 512]`), while the second layer has identical outputs across the GPUs.
+
+
diff --git a/docs/source/en/features/shardformer.md b/docs/source/en/features/shardformer.md
new file mode 100644
index 000000000000..872d00e4a073
--- /dev/null
+++ b/docs/source/en/features/shardformer.md
@@ -0,0 +1,143 @@
+# Shardformer
+
+Author: [Baizhou Zhang](https://github.com/Fridge003)
+
+**Prerequisite**
+- [Paradigms of Parallelism](../concepts/paradigms_of_parallelism.md)
+- [Booster API](../basics/booster_api.md)
+- [Booster Plugins](../basics/booster_plugins.md)
+
+**Example Code**
+- [Tensor Parallelism with Shardformer](https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/shardformer/examples)
+- [Enabling Shardformer using HybridParallelPlugin](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert)
+
+**Related Paper**
+- [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://arxiv.org/abs/2104.04473)
+- [GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism](https://arxiv.org/abs/1811.06965)
+- [FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning](https://arxiv.org/abs/2307.08691)
+- [Sequence Parallelism: Long Sequence Training from System Perspective](https://arxiv.org/abs/2105.13120)
+
+
+## Introduction
+
+When training large transformer models such as LLaMa-2 70B or OPT 175B, model parallelism methods that divide a huge model into smaller shards, including tensor parallelism and pipeline parallelism, are essential in order to fit the model into limited GPU memory.
+However, manually cutting the model and rewriting its forward/backward logic could be difficult for users who are not familiar with distributed training.
+Meanwhile, the Huggingface transformers library has gradually become users' first choice of model source, and most mainstream large models have been open-sourced in the Huggingface transformers model library.
+
+Out of this motivation, the ColossalAI team develops **Shardformer**, a feature that automatically handles the preparation of model parallelism (tensor parallelism/pipeline parallelism) for popular transformer models in HuggingFace.
+This module aims to make parallelization hassle-free for users who are not from the system background.
+Within a few lines of code, users can turn a model into a state ready for distributed training.
+Also, Shardformer contains various optimization tools for acceleration and memory saving during the forward/backward pass.
+
+
+## How Shardformer Works
+
+Generally, Shardformer works through the following four kinds of *replacements*:
+
+1. Replacing original PyTorch module (e.g. `nn.Linear`, `nn.Embedding`) with a crafted distributed module.
+The distributed module keeps the same attributes as the original module but replaces the original parameters with distributed parameters.
+Also, new `forward` methods will replace the original ones so as to execute distributed computation, such as the linear layers' split/gather operations executed under tensor parallelism.
+Each distributed module implements its `from_native_module` static method to convert the PyTorch module to its corresponding distributed module.
+
+2. Replacing attributes of original Huggingface Transformers layers with appropriate attributes for distributed training.
+For example, when training LlaMa-2 with a tensor parallel size of 2, the attribute `num_heads` of `LlamaDecoderLayer` (the number of attention heads in each layer) should be replaced with `model.config.num_attention_heads // 2`.
+
+3. Replacing the `forward` methods implemented by the original Huggingface
+Transformers libraries with our customized `forward` methods.
+This replacement is essential for pipeline parallelism, where a customized function is needed to pass intermediate hidden states between different pipeline stages.
+Also, optimization methods such as flash attention or sequence parallel can be injected into the `forward` process through our customized `forward` method.
+
+4. Replacing the whole copy of model parameters and optimizer states with incomplete ones controlled by the current device (this is why it's called Shardformer).
+By executing the `ModelSharder.shard` method, the current device will only keep the part of the model parameters it's supposed to take care of.
+To be specific, they should be the assigned parameter shards when using tensor parallelism, or the parameters belonging to the current pipeline stage when using pipeline parallelism, or both of them.
+All other parameters are released so as to free up memory.
+As a result, the optimizer will only compute the states corresponding to this part of the parameters, which further reduces memory usage.
+
+All of these replacements are implemented with manually written policies and forward functions.
+If you want to delve deeper into the design of Shardformer or customize your own Shardformer policies, please refer to our [Shardformer development document](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/README.md) and [pipeline parallelism design](https://github.com/hpcaitech/ColossalAI/discussions/4050) for more details.
+
+## Usage
+
+### Shardformer Configuration
+
+The configuration of Shardformer is controlled by class `ShardConfig`:
+
+{{ autodoc:colossalai.shardformer.ShardConfig }}
+
+If you want to enable Apex Fused Layernorm, please install `apex`.
+If you want to enable the usage of flash attention, please install `flash_attn`.
+In addition, xFormers's `cutlass_op` can serve as a backup for flash attention.
+
+### Enabling Shardformer
+
+#### 1. Enabling Shardformer Through Booster (Recommended)
+
+Enabling `Shardformer` through `Booster` initialized with `HybridParallelPlugin` is the recommended way to awaken the power of Shardformer.
+The main reason is that pipeline parallelism cannot work without calling the `execute_pipeline` method of `Booster`. Besides, `HybridParallelPlugin` provides the capacity to combine the features of `Shardformer` with other useful features, such as mixed precision training or Zero.
+
+More details about this usage can be found in chapter [Booster API](../basics/booster_api.md) and [Booster Plugins](../basics/booster_plugins.md).
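A minimal sketch of this recommended path (hedged: the `tp_size`/`pp_size` argument names and the two-argument `criterion(outputs, inputs)` convention are assumptions made for illustration and are not documented on this page):

```python
import colossalai
from torch.optim import Adam
from transformers import BertConfig, BertForMaskedLM

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})

# a small randomly initialized model just for illustration
model = BertForMaskedLM(BertConfig())
optimizer = Adam(model.parameters(), lr=2.4e-5)

# criterion takes (outputs, inputs) so it can also be reused by the pipeline schedule
def criterion(outputs, inputs):
    return outputs.loss

# tp_size / pp_size (assumed argument names) set the tensor-/pipeline-parallel degree;
# Shardformer-related switches such as enable_fused_normalization can be passed here as well
plugin = HybridParallelPlugin(tp_size=2, pp_size=1)
booster = Booster(plugin=plugin)

# boost wraps the model, optimizer and criterion for the chosen parallel setting
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion=criterion)
```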
+ +[Here](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert) is an example on how to trigger `Shardformer` through `HybridParallelPlugin`. Please be aware that there's a difference in the way of doing forward and backward between the situation of using pipeline and not using pipeline. + + +#### 2. Enabling Shardformer Through Shardformer APIs (Not Recommended) + +You can also use Shardformer through manually calling Shardformer APIs. However, this usage is not recommended since pipeline parallelism can't run without `Booster`. + +[Here](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/examples/convergence_benchmark.py) +is an example on how to trigger `Shardformer` through calling Shardformer APIs. + + +### Precautions + +1. When enabling pipeline parallel, please don't do the forward/backward pass in the conventional way (`model(input)`, `loss.backward()`), which will cause unexpected errors. Rather, please do forward/backward pass through calling `booster.execute_pipeline` method. + +2. When you use Shardformer to process classification models such as `GPT2ForSequenceClassification`, `ViTForImageClassification`, please ensure that the total number of labels should be integer multiple of tensor parallel size, otherwise Shardformer can't process the classifier layer correctly. A simple fix could be appending dummy labels in transformers config. This bug will be fixed in future version of Shardformer. + +3. The case of training ChatGLM-2 6B is a little special: since Huggingface transformers doesn't officially support ChatGLM at present, please import the configuration/model classes through + ```python + from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig + from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel + ``` + when training ChatGLM-2 with Shardformer, and initialize your model with these imported classes. + + +## Supporting Information + +List of Huggingface transformers model families currently supported by Shardformer: +- LlaMa-1/LlaMa-2 +- GPT2 +- BERT +- OPT +- BLOOM +- T5 +- ViT +- ChatGLM-2 6B +- Whisper + +List of optimization tools currently supported by Shardformer: +- Flash Attention 2 +- JIT Fused Operator +- xFormers +- Fused Layer Normalization +- Sequence Parallel +- Sequence Overlap + +List of model families we plan to support in the near future: +- SAM +- Blip2 +- RoBERTa +- ALBERT +- ERNIE +- GPT Neo +- GPT-J +- BEiT +- SwinTransformer V1/V2 +- qwen + +These lists will grow longer as more models and optimization tools emerge in the future. If you have any suggestions on the models/optimization we should support, please feel free to mention it in [Issues](https://github.com/hpcaitech/ColossalAI/issues) section of our project. + +For more details about compatibility between each optimization tool and each supported model, please refer to chapter Roadmap in our [develop document](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/README.md). 
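Tying precaution 1 above to code, a hedged sketch of a single training step when pipeline parallelism is turned on, continuing the setup from the sketch in the previous section (the `execute_pipeline(data_iter, model, criterion, optimizer, return_loss=True)` call shape and its dict-style return value are assumptions for illustration, not taken from this page):

```python
# One training step with pipeline parallelism enabled (pp_size > 1): do NOT call
# model(input) / loss.backward() directly; let the booster drive the pipeline schedule.
model.train()
optimizer.zero_grad()
data_iter = iter(train_dataloader)  # dataloader assumed to be prepared beforehand, e.g. via plugin.prepare_dataloader

outputs = booster.execute_pipeline(data_iter,
                                   model,
                                   criterion,
                                   optimizer,
                                   return_loss=True)
loss = outputs['loss']  # assumed return layout; only the last pipeline stage holds the real loss
optimizer.step()
```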
+ + + diff --git a/docs/source/zh-Hans/basics/booster_api.md b/docs/source/zh-Hans/basics/booster_api.md index 573aab1c8a07..c59d75d321c0 100644 --- a/docs/source/zh-Hans/basics/booster_api.md +++ b/docs/source/zh-Hans/basics/booster_api.md @@ -1,4 +1,4 @@ -# booster 使用 +# Booster API 作者: [Mingyan Jiang](https://github.com/jiangmingyan), [Jianghai Chen](https://github.com/CjhHa1), [Baizhou Zhang](https://github.com/Fridge003) @@ -11,7 +11,8 @@ -- [使用 booster 训练](https://github.com/hpcaitech/ColossalAI/blob/main/examples/tutorial/new_api/cifar_resnet) +- [使用Booster在CIFAR-10数据集上训练ResNet](https://github.com/hpcaitech/ColossalAI/blob/main/examples/tutorial/new_api/cifar_resnet) +- [使用Booster在RedPajama数据集上训练Llama-1/2](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama2) ## 简介 diff --git a/docs/source/zh-Hans/basics/booster_plugins.md b/docs/source/zh-Hans/basics/booster_plugins.md index 6f731bfac1fc..0ad1cacab151 100644 --- a/docs/source/zh-Hans/basics/booster_plugins.md +++ b/docs/source/zh-Hans/basics/booster_plugins.md @@ -74,7 +74,7 @@ Zero-2 不支持局部梯度累积。如果您坚持使用,虽然可以积累 这个插件实现了多种并行训练策略和优化工具的组合。Hybrid Parallel插件支持的功能大致可以被分为以下四个部分: -1. Shardformer: Shardformer负责在张量并行以及流水线并行下切分模型的逻辑,以及前向/后向方法的重载,这个插件为Shardformer功能提供了一个简单易用的接口。与此同时,Shardformer还负责将包括fused normalization, flash attention (xformers), JIT和序列并行在内的各类优化工具融入重载后的前向/后向方法。 +1. Shardformer: Shardformer负责在张量并行以及流水线并行下切分模型的逻辑,以及前向/后向方法的重载,这个插件为Shardformer功能提供了一个简单易用的接口。与此同时,Shardformer还负责将包括fused normalization, flash attention (xformers), JIT和序列并行在内的各类优化工具融入重载后的前向/后向方法。更多关于Shardformer的信息请参考 [Shardformer文档](../features/shardformer.md)。 2. 混合精度训练:插件支持fp16/bf16的混合精度训练。更多关于混合精度训练的参数配置的详细信息请参考 [混合精度训练文档](../features/mixed_precision_training_with_booster.md)。 diff --git a/docs/source/zh-Hans/features/1D_tensor_parallel.md b/docs/source/zh-Hans/features/1D_tensor_parallel.md index 4dd45e8783c3..61982cbb8be9 100644 --- a/docs/source/zh-Hans/features/1D_tensor_parallel.md +++ b/docs/source/zh-Hans/features/1D_tensor_parallel.md @@ -2,6 +2,8 @@ 作者: Zhengda Bian, Yongbin Li +> ⚠️ 此页面上的信息已经过时并将被废弃。请在[Shardformer](./shardformer.md)页面查阅更新。 + **前置教程** - [定义配置文件](../basics/define_your_config.md) - [并行配置](../basics/configure_parallelization.md) @@ -118,3 +120,5 @@ Output of the first linear layer: torch.Size([16, 512]) Output of the second linear layer: torch.Size([16, 256]) ``` 第一个线性层的输出被划分成2块 (每个形状为 `[16, 512]`), 而第二层在整个 GPU 上的输出是相同的。 + + diff --git a/docs/source/zh-Hans/features/shardformer.md b/docs/source/zh-Hans/features/shardformer.md new file mode 100644 index 000000000000..49aa23e2d06b --- /dev/null +++ b/docs/source/zh-Hans/features/shardformer.md @@ -0,0 +1,121 @@ +# Shardformer + +Author: [Baizhou Zhang](https://github.com/Fridge003) + +**预备知识** +- [并行技术](../concepts/paradigms_of_parallelism.md) +- [Booster API](../basics/booster_api.md) +- [Booster 插件](../basics/booster_plugins.md) + +**示例代码** +- [使用Shardformer进行张量并行训练](https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/shardformer/examples) +- [通过HybridParallelPlugin使用Shardformer](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert) + +**相关论文** +- [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://arxiv.org/abs/2104.04473) +- [GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism](https://arxiv.org/abs/1811.06965) +- [FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning](https://arxiv.org/abs/2307.08691) +- [Sequence Parallelism: Long Sequence Training 
from System Perspective](https://arxiv.org/abs/2105.13120) + + +## 简介 + +在训练LLaMa-2 70B或OPT 175B等大型Transformer模型时,为了满足GPU内存的限制,将大型模型划分为更小的分片的模型并行方法(包括张量并行以及流水线并行)是必不可少的。然而,对于不熟悉分布式训练的用户来说,手动剪切模型并重写其前向/反向逻辑可能很困难。与此同时,Huggingface transformers开源库正在逐渐成为用户模型来源的首选,大部分主流大型模型都已在Huggingface transformers模型库中开源。 + +出于这种动机,ColossalAI团队开发了**Shardformer**,该功能可以自动为HuggingFace中主流的Transformer模型进行封装,用于张量并行以及流水线并行的训练策略。如此一来,对系统了解不多的用户也可以轻松地在transformers模型上进行并行训练:只需几行代码,用户就可以将模型转变为并行训练的状态。此外,Shardformer也包括了多种优化工具,用于在前向/后向的传递过程中实现加速和节省内存。 + + +## Shardformer的工作原理 + +通常来说,Shardformer通过以下四种“替换”进行工作: + +1. 用我们设计的分布式模块替换原始的PyTorch模块(例如`nn.Linear`、`nn.Embedding`)。 +分布式模块保持与原始模块相同的属性,但分布式模块会用新的参数替换原始模块的参数。新的前向函数将取代原来的前向函数,用于执行分布式计算,例如在张量并行下执行线性层的split/gather操作。每个分布式模块都应当实现其`from_native_module`静态方法,以将PyTorch模块转换为其相应的分布式模块。 + +2. 将原始Huggingface Transformers中间层的属性为适用于并行训练的属性。例如,当使用并行度为2的张量并行训练LlaMa-2时,`LlamaDecoderLayer` 的属性`num_heads`(每一层注意力头的数量)应替换为`model.config.num_attention_heads // 2`。 + +3. 将原来Huggingface transformers库实现的前向函数替换为我们定制的前向函数。前向函数的替换对于流水线并行性至关重要,因为流水线并行需要特殊的前向函数去在不同的流水线阶段之间传递中间的隐藏状态。此外,可以通过我们定制的前向函数将例如`flash attention`或序列并行的优化方法注入到前向的过程中。 + +4. 将完整的模型参数和优化器状态替换为只由当前设备控制的部分模型参数和优化器状态。通过执行`ModelSharder.shard`方法,当前设备仅会保留它应该处理的那部分模型参数。具体来说,这部分参数可以是使用张量并行时分配到当前机器的参数分片,或者使用流水线并行时当前流水线阶段的模型参数,或者兼而有之。除此之外的所有其他参数都被释放,用于节省内存的空间。 +如此一来,优化器只会计算保留的部分参数对应的状态,从而进一步节省内存的使用。 + +所有这些替换都是通过手动编写的策略和前向函数来实现的。如果您想更深入地研究Shardformer的设计方案,或者定制您自己的Shardformer策略,请参考[Shardformer 开发者文档](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/README.md)和[流水并行设计方案](https://github.com/hpcaitech/ColossalAI/discussions/4050)以获得更多细节。 + +## 用法 + +### Shardformer的参数配置 + +Shardformer的配置由类`ShardConfig`的参数控制: + +{{ autodoc:colossalai.shardformer.ShardConfig }} + +如果您想启用 Apex Fused Layernorm,请安装 `apex`。如果您想启用 flash attention,请安装 `flash_attn`。此外,xFormers 的 `cutlass_op` 可以作为Flash Attention的补充优化方式。 + +### 启动Shardformer + +#### 1. 通过Booster启动Shardformer (推荐) + +通过用`HybridParallelPlugin`初始化的`Booster`来启动`Shardformer`是最推荐的用法。其主要原因是如果不调用`Booster`的`execute_pipeline`方法,流水线并行就无法正常工作。此外,`HybridParallelPlugin`提供了将`Shardformer`的功能与其他功能(例如混合精度训练或Zero)相结合的能力。 + +更多关于这一用法的细节可以参考 [Booster API 文档](../basics/booster_api.md)以及[Booster 插件文档](../basics/booster_plugins.md)。[这里](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert)是一个通过`HybridParallelPlugin`启动`Shardformer`的示例。 + + +#### 2. 通过Shardformer API启动Shardformer (不推荐) + +您还可以通过手动调用Shardformer API的方式启动Shardformer。然而我们并不推荐这种用法,因为流水线并行在没有`Booster`的情况下无法正常运行。 + +[这里](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/examples/convergence_benchmark.py) +是一个通过调用Shardformer的API启动`Shardformer`的示例。 + + +### 注意事项 + +1. 当启用流水线并行时,请不要用常规方式(`model(input)`、`loss.backward()`)进行前向/后向传递,这样会导致未知的错误。这种情形下请通过调用`booster.execute_pipeline`方法来进行前向/后向传递。 + +2. 当使用Shardformer处理`GPT2ForSequenceClassification`、`ViTForImageClassification`等分类模型时,请确保labels的总数为张量并行度的整数倍,否则Shardformer无法正确地处理classifier层。一个简单的修复方法就是在transformers的config中添加虚拟的标签。这一bug将在 Shardformer的未来版本中修复。 + +3. 
训练ChatGLM-2 6B的情况有点特殊:由于Huggingface Transformers 目前尚未正式支持ChatGLM。在使用Shardformer训练ChatGLM-2时,请通过以下方式导入config/model的类: + ```python + from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig + from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel + ``` + 并且使用这些导入的类初始化模型。 + +## 支持信息 + +Shardformer目前支持的Huggingface Transformer模型: +- LlaMa-1/LlaMa-2 +- GPT2 +- BERT +- OPT +- BLOOM +- T5 +- ViT +- ChatGLM-2 6B +- Whisper + +Shardformer目前支持的优化工具: +- Flash Attention 2 +- JIT Fused Operator +- xFormers +- Fused Layer Normalization +- Sequence Parallel +- Sequence Overlap + +我们计划在不久后为Shardformer支持的模型: +- SAM +- Blip2 +- RoBERTa +- ALBERT +- ERNIE +- GPT Neo +- GPT-J +- BEiT +- SwinTransformer V1/V2 +- qwen + +随着未来更多模型和优化工具的出现,这些列表将会变得越来越长。如果您对我们应该支持的模型/优化工具有任何建议,欢迎在项目的[Issues](https://github.com/hpcaitech/ColossalAI/issues)板块参与讨论。 + +更多关于不同优化工具和模型之间兼容性的细节,请参考[Shardformer开发者文档](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/README.md)中的Roadmap一节。 + + From 8c2dda74107631f959906fd3211e7c575ecd8540 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 15 Sep 2023 13:17:32 +0800 Subject: [PATCH 10/58] [format] applied code formatting on changed files in pull request 4726 (#4727) Co-authored-by: github-actions --- README.md | 2 +- docs/README-zh-Hans.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 25d3b8f83f1e..42549ac55873 100644 --- a/README.md +++ b/README.md @@ -472,7 +472,7 @@ To cite this project, you can use the following BibTeX citation. } ``` -Colossal-AI has been accepted as official tutorial by top conferences [NeurIPS](https://nips.cc/), [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), +Colossal-AI has been accepted as official tutorial by top conferences [NeurIPS](https://nips.cc/), [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), [NVIDIA GTC](https://www.nvidia.com/en-us/on-demand/session/gtcspring23-S51482/) ,etc.

(back to top)

diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md index 41eebc59c493..bb5f49bc546b 100644 --- a/docs/README-zh-Hans.md +++ b/docs/README-zh-Hans.md @@ -453,7 +453,7 @@ Colossal-AI项目受一些相关的项目启发而成立,一些项目是我们 } ``` -Colossal-AI 已被[NeurIPS](https://nips.cc/), [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), +Colossal-AI 已被[NeurIPS](https://nips.cc/), [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), [NVIDIA GTC](https://www.nvidia.com/en-us/on-demand/session/gtcspring23-S51482/) ,等顶级会议录取为官方教程。

(返回顶端)

From 50e5602c2d6c8e25ad544cbecc38649e5257e7b8 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Fri, 15 Sep 2023 13:52:30 +0800 Subject: [PATCH 11/58] [doc] add shardformer support matrix/update tensor parallel documents (#4728) * add compatibility matrix for shardformer doc * update tp doc --- docs/source/en/features/1D_tensor_parallel.md | 80 +----- docs/source/en/features/2D_tensor_parallel.md | 86 +------ .../en/features/2p5D_tensor_parallel.md | 89 +------ docs/source/en/features/3D_tensor_parallel.md | 88 +------ docs/source/en/features/shardformer.md | 228 ++++++++++++++---- .../zh-Hans/features/1D_tensor_parallel.md | 84 +------ .../zh-Hans/features/2D_tensor_parallel.md | 84 +------ .../zh-Hans/features/2p5D_tensor_parallel.md | 91 +------ .../zh-Hans/features/3D_tensor_parallel.md | 90 +------ docs/source/zh-Hans/features/shardformer.md | 210 +++++++++++++--- 10 files changed, 388 insertions(+), 742 deletions(-) diff --git a/docs/source/en/features/1D_tensor_parallel.md b/docs/source/en/features/1D_tensor_parallel.md index 79fe5ddea221..0f01cfd325e5 100644 --- a/docs/source/en/features/1D_tensor_parallel.md +++ b/docs/source/en/features/1D_tensor_parallel.md @@ -2,14 +2,12 @@ Author: Zhengda Bian, Yongbin Li -> ⚠️ The information on this page is outdated and will be deprecated. Please check [Shardformer](./shardformer.md) for more information. - **Prerequisite** - [Define Your Configuration](../basics/define_your_config.md) - [Configure Parallelization](../basics/configure_parallelization.md) **Example Code** -- [ColossalAI-Examples 1D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/blob/main/features/tensor_parallel/README.md) +- [Tensor Parallelism with Shardformer](https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/shardformer/examples) **Related Paper** - [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://deepakn94.github.io/assets/papers/megatron-sc21.pdf) @@ -44,79 +42,7 @@ Given $P$ processors, we present the theoretical computation and memory cost, as ## Usage -To enable 1D tensor parallelism for our model, e.g. on 2 GPUs, we need to configure the parallelism setting as below. -```python -CONFIG = dict(parallel=dict( - data=1, - pipeline=1, - tensor=dict(size=2, mode='1d'), -)) -``` -Then Colossal-AI will automatically apply 1D parallelism to all the layers from `colossalai.nn`. - -Let's define a model that consists of a two-layer multi-layer perceptron (MLP) as below. -```python -import colossalai -import colossalai.nn as col_nn -import torch -from colossalai.utils import print_rank_0 - -class MLP(torch.nn.Module): - def __init__(self, dim: int = 256): - super().__init__() - intermediate_dim = dim * 4 - self.dense_1 = col_nn.Linear(dim, intermediate_dim) - print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.transpose(0, 1).shape}') - self.activation = torch.nn.GELU() - self.dense_2 = col_nn.Linear(intermediate_dim, dim) - print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.transpose(0, 1).shape}') - self.dropout = col_nn.Dropout(0.1) - - def forward(self, x): - x = self.dense_1(x) - print_rank_0(f'Output of the first linear layer: {x.shape}') - x = self.activation(x) - x = self.dense_2(x) - print_rank_0(f'Output of the second linear layer: {x.shape}') - x = self.dropout(x) - return x -``` - -Launch Colossal-AI on 2 GPUs and build the model. 
- -```python -parser = colossalai.get_default_parser() -colossalai.launch(config=CONFIG, - rank=args.rank, - world_size=args.world_size, - local_rank=args.local_rank, - host=args.host, - port=args.port) - -m = MLP() -``` -We will see the shapes of partitioned parameters(e.g. weights) in the MLP model. -```shell -Weight of the first linear layer: torch.Size([256, 512]) -Weight of the second linear layer: torch.Size([512, 256]) -``` -The complete weight of the first linear layer is supposed to have the shape `[256, 1024]`. After the column-parallel partitioning, it becomes `[256, 512]`. -Similarly, the second row-parallel layer partitions the weight `[1024, 256]` into `[512, 256]`. - -We can run the model with some random inputs. -```python -from colossalai.utils import get_current_device - -x = torch.randn((16, 256), device=get_current_device()) -torch.distributed.broadcast(x, src=0) # synchronize input - -x = m(x) -``` -Then we can see the shapes of activation results. -```shell -Output of the first linear layer: torch.Size([16, 512]) -Output of the second linear layer: torch.Size([16, 256]) -``` -The output of the first linear layer is split into 2 partitions (each has the shape `[16, 512]`), while the second layer has identical outputs across the GPUs. +1D tensor parallelism is implemented by `Shardformer` feature in the newest version of ColossalAI. +For more details about ideas and usages of `Shardformer`, please refer to [Shardformer Doc](./shardformer.md). diff --git a/docs/source/en/features/2D_tensor_parallel.md b/docs/source/en/features/2D_tensor_parallel.md index aae8cc9eef97..c79e7d196f8b 100644 --- a/docs/source/en/features/2D_tensor_parallel.md +++ b/docs/source/en/features/2D_tensor_parallel.md @@ -60,83 +60,9 @@ Given $P=q\times q$ processors, we present the theoretical computation and memor ## Usage -To enable 2D tensor parallelism for our model, e.g. on 4 GPUs, we need to configure the parallelism setting as below. -```python -CONFIG = dict(parallel=dict( - data=1, - pipeline=1, - tensor=dict(size=4, mode='2d'), -)) -``` -Then Colossal-AI will automatically apply 2D parallelism to all the layers from `colossalai.nn`. - -Let's define a model that consists of a two-layer multi-layer perceptron (MLP) as below. -```python -import colossalai -import colossalai.nn as col_nn -import torch -from colossalai.utils import print_rank_0 - -class MLP(torch.nn.Module): - def __init__(self, dim: int = 256): - super().__init__() - intermediate_dim = dim * 4 - self.dense_1 = col_nn.Linear(dim, intermediate_dim) - print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}') - self.activation = torch.nn.GELU() - self.dense_2 = col_nn.Linear(intermediate_dim, dim) - print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}') - self.dropout = col_nn.Dropout(0.1) - - def forward(self, x): - x = self.dense_1(x) - print_rank_0(f'Output of the first linear layer: {x.shape}') - x = self.activation(x) - x = self.dense_2(x) - print_rank_0(f'Output of the second linear layer: {x.shape}') - x = self.dropout(x) - return x -``` -Launch Colossal-AI on 4 GPUs and build the model -```python -parser = colossalai.get_default_parser() -colossalai.launch(config=CONFIG, - rank=args.rank, - world_size=args.world_size, - local_rank=args.local_rank, - host=args.host, - port=args.port) - -m = MLP() -``` -We will see the shapes of partitioned parameters(e.g. weights) in the MLP model. 
-```shell -Weight of the first linear layer: torch.Size([128, 512]) -Weight of the second linear layer: torch.Size([512, 128]) -``` -The complete weight of the first linear layer is supposed to have the shape `[256, 1024]`. After the partitioning of 2D parallelism, it becomes `[128, 512]` on each GPU. -Similarly, the second layer partitions the weight `[1024, 256]` into `[512, 128]`. - -We can run the model with some random inputs. -```python -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.utils import get_current_device - -x = torch.randn((16, 256), device=get_current_device()) -# partition input -torch.distributed.broadcast(x, src=0) -x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)] -x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)] -print_rank_0(f'Input: {x.shape}') - -x = m(x) -``` -Then we can see the shapes of activation results. -```shell -Input: torch.Size([8, 128]) -Output of the first linear layer: torch.Size([8, 512]) -Output of the second linear layer: torch.Size([8, 128]) -``` -The activation tensors in 2D parallelism are all split in both row and column. -E.g. the output of the first linear layer has the shape `[8, 512]`, while the second layer has the output of `[8, 128]`. +Currently the newest version of ColossalAI doesn't support 2D tensor parallelism, but this feature will be integrated into `Shardformer` in future releases. +For more details about ideas and usages of `Shardformer`, please refer to [Shardformer Doc](./shardformer.md). + +For users of older version of ColossalAI, please refer to [ColossalAI-Examples - 2D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/blob/main/features/tensor_parallel/README.md). + + diff --git a/docs/source/en/features/2p5D_tensor_parallel.md b/docs/source/en/features/2p5D_tensor_parallel.md index a81d14f10627..b3cbd1c7c727 100644 --- a/docs/source/en/features/2p5D_tensor_parallel.md +++ b/docs/source/en/features/2p5D_tensor_parallel.md @@ -58,86 +58,9 @@ Given $P=q \times q \times d$ processors, we present the theoretical computation ## Usage -To enable 2.5D tensor parallelism for our model, e.g. on 8 GPUs, we need to configure the parallelism setting as below. -```python -CONFIG = dict(parallel=dict( - data=1, - pipeline=1, - tensor=dict(size=8, mode='2.5d', depth=2), -)) - -``` -Then Colossal-AI will automatically apply 2.5D parallelism to all the layers from `colossalai.nn`. - -Let's define a model that consists of a two-layer multi-layer perceptron (MLP) as below. 
-```python -import colossalai -import colossalai.nn as col_nn -import torch -from colossalai.utils import print_rank_0 - -class MLP(torch.nn.Module): - def __init__(self, dim: int = 256): - super().__init__() - intermediate_dim = dim * 4 - self.dense_1 = col_nn.Linear(dim, intermediate_dim) - print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}') - self.activation = torch.nn.GELU() - self.dense_2 = col_nn.Linear(intermediate_dim, dim) - print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}') - self.dropout = col_nn.Dropout(0.1) - - def forward(self, x): - x = self.dense_1(x) - print_rank_0(f'Output of the first linear layer: {x.shape}') - x = self.activation(x) - x = self.dense_2(x) - print_rank_0(f'Output of the second linear layer: {x.shape}') - x = self.dropout(x) - return x -``` -Launch Colossal-AI on 8 GPUs and build the model -```python -parser = colossalai.get_default_parser() -colossalai.launch(config=CONFIG, - rank=args.rank, - world_size=args.world_size, - local_rank=args.local_rank, - host=args.host, - port=args.port) - -m = MLP() -``` -We will see the shapes of partitioned parameters(e.g. weights) in the MLP model. -```shell -Weight of the first linear layer: torch.Size([128, 512]) -Weight of the second linear layer: torch.Size([512, 128]) -``` -The complete weight of the first linear layer is supposed to have the shape `[256, 1024]`. After the partitioning of 2.5D parallelism, it becomes `[128, 512]` on each GPU. -Similarly, the second layer partitions the weight `[1024, 256]` into `[512, 128]`. - -We can run the model with some random inputs. -```python -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.utils import get_current_device - -x = torch.randn((16, 256), device=get_current_device()) -# partition input -torch.distributed.broadcast(x, src=0) -x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)] -x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)] -x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)] -print_rank_0(f'Input: {x.shape}') - -x = m(x) -``` -Then we can see the shapes of activation results. -```shell -Input: torch.Size([4, 128]) -Output of the first linear layer: torch.Size([4, 512]) -Output of the second linear layer: torch.Size([4, 128]) -``` -The activation tensors in 2.5D parallelism are all split by $d \times q$ in the row and $q$ in the column. -E.g. the output of the first linear layer has the shape `[4, 512]`), while the second layer has the output of `[4, 128]`. -Note, 2.5D parallelism use the same partition method as 2D parallelism for weights, where the difference is the partition of input. +Currently the newest version of ColossalAI doesn't support 2.5D tensor parallelism, but this feature will be integrated into `Shardformer` in future releases. +For more details about ideas and usages of `Shardformer`, please refer to [Shardformer Doc](./shardformer.md). + +For users of older version of ColossalAI, please refer to [ColossalAI-Examples - 2.5D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/blob/main/features/tensor_parallel/README.md). 
+ + diff --git a/docs/source/en/features/3D_tensor_parallel.md b/docs/source/en/features/3D_tensor_parallel.md index 0e28f08b23c9..00e6c5fca40c 100644 --- a/docs/source/en/features/3D_tensor_parallel.md +++ b/docs/source/en/features/3D_tensor_parallel.md @@ -67,85 +67,9 @@ Given $P=q \times q \times q$ processors, we present the theoretical computation ## Usage -To enable 3D tensor parallelism for our model, e.g. on 8 GPUs, we need to configure the parallelism setting as below. -```python -CONFIG = dict(parallel=dict( - data=1, - pipeline=1, - tensor=dict(size=8, mode='3d'), -)) -``` -Then Colossal-AI will automatically apply 3D parallelism to all the layers from `colossalai.nn`. - -Let's define a model that consists of a two-layer multi-layer perceptron (MLP) as below. -```python -import colossalai -import colossalai.nn as col_nn -import torch -from colossalai.utils import print_rank_0 - -class MLP(torch.nn.Module): - def __init__(self, dim: int = 256): - super().__init__() - intermediate_dim = dim * 4 - self.dense_1 = col_nn.Linear(dim, intermediate_dim) - print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}') - self.activation = torch.nn.GELU() - self.dense_2 = col_nn.Linear(intermediate_dim, dim) - print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}') - self.dropout = col_nn.Dropout(0.1) - - def forward(self, x): - x = self.dense_1(x) - print_rank_0(f'Output of the first linear layer: {x.shape}') - x = self.activation(x) - x = self.dense_2(x) - print_rank_0(f'Output of the second linear layer: {x.shape}') - x = self.dropout(x) - return x -``` -Launch Colossal-AI on 8 GPUs and build the model -```python -parser = colossalai.get_default_parser() -colossalai.launch(config=CONFIG, - rank=args.rank, - world_size=args.world_size, - local_rank=args.local_rank, - host=args.host, - port=args.port) - -m = MLP() -``` -We will see the shapes of partitioned parameters(e.g. weights) in the MLP model. -```shell -Weight of the first linear layer: torch.Size([128, 256]) -Weight of the second linear layer: torch.Size([512, 64]) -``` -The complete weight of the first linear layer is supposed to have the shape `[256, 1024]`. After the partitioning of 3D parallelism, it becomes `[128, 256]` on each GPU. -Similarly, the second layer partitions the weight `[1024, 256]` into `[512, 64]`. - -We can run the model with some random inputs. -```python -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.utils import get_current_device - -x = torch.randn((16, 256), device=get_current_device()) -# partition input -torch.distributed.broadcast(x, src=0) -x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)] -x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)] -x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)] -print_rank_0(f'Input: {x.shape}') - -x = m(x) -``` -Then we can see the shapes of activation results. -```shell -Input: torch.Size([4, 128]) -Output of the first linear layer: torch.Size([4, 512]) -Output of the second linear layer: torch.Size([4, 128]) -``` -The activation tensors in 3D parallelism are all split by $q^2$ in the row and $q$ in the column. -E.g. the output of the first linear layer has the shape `[4, 512]`), while the second layer has the output of `[4, 128]`. 
-Note, although the results of 3D parallelism have the same shape as that of 2.5D parallelism for weights here, the content of each partition is different. +Currently the newest version of ColossalAI doesn't support 3D tensor parallelism, but this feature will be integrated into `Shardformer` in future releases. +For more details about ideas and usages of `Shardformer`, please refer to [Shardformer Doc](./shardformer.md). + +For users of older version of ColossalAI, please refer to [ColossalAI-Examples - 3D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/blob/main/features/tensor_parallel/README.md). + + diff --git a/docs/source/en/features/shardformer.md b/docs/source/en/features/shardformer.md index 872d00e4a073..10e03e963a95 100644 --- a/docs/source/en/features/shardformer.md +++ b/docs/source/en/features/shardformer.md @@ -29,33 +29,6 @@ This module aims to make parallelization hassle-free for users who are not from Within a few lines of codes, users can turn a model into a state ready for distributed training. Also, Shardformer contains various optimization tools for acceleration and memory saving during forward/backward pass. - -## How Shardformer Works - -Generally, Shardformer works through the following four kinds of *replacements*: - -1. Replacing original PyTorch module (e.g. `nn.Linear`, `nn.Embedding`) with a crafted distributed module. -The distributed module keeps the same attributes as the original module but replaces the original parameters with distributed parameters. -Also, new `forward` methods will replace original ones so as to execute distributed computation, such as linear layers' split /gather operations executed under tensor parallelism. -Each distributed module implements its `from_native_module` static method to convert the PyTorch module to its corresponding distributed module. - -2. Replacing attributes of original Huggingface Transformers layers with appropriate attributes for distributed training. -For example, when training LlaMa-2 with tensor parallel size as 2, the attribute `num_heads` of `LlamaDecoderLayer` (the number of attention heads in each layer) should be replaced with `model.config.num_attention_heads // 2`. - -3. Replacing the `forward` methods implemented by original Huggingface -Transformers libraries with our customized `forward` methods. -This replacement is essential for pipeline paralellism, where a customiozed function is needed to pass intermediate hidden states between different pipeline stages. -Also, optimization methods such as flash attention or sequence parallel can be injected into the `forward` process through our customized `forward` method. - -4. Replacing the whole copy of model parameters and optimizer states with incomplete ones controlled by current device (this is why it's called Shardformer). -By executing `ModelSharder.shard` method, current device will only keep the part of model parameters it's supposed to take care of. -To be specific, they should be the assigned parameter shards when using tensor parallelism, or the parameters belonging to current pipeline stage when using pipeline parallelism, or both of them. -All other parameters are released so as to liberate memory usage. -As a result, the optimizer will only compute the states corresponding to these part of parameters, causing the usage of memory to be further saved. - -All of these replacements are implemented with manually written policies and forward functions. 
-If you want to delve deeper into the design of Shardformer or customize your own Shardformer policies, please refer to our [Shardformer development document](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/README.md) and [pipeline parallelism design](https://github.com/hpcaitech/ColossalAI/discussions/4050) for more details. - ## Usage ### Shardformer Configuration @@ -101,31 +74,187 @@ is an example on how to trigger `Shardformer` through calling Shardformer APIs. ``` when training ChatGLM-2 with Shardformer, and initialize your model with these imported classes. +## How Shardformer Works + +Generally, Shardformer works through the following four kinds of *replacements*: + +1. Replacing original PyTorch module (e.g. `nn.Linear`, `nn.Embedding`) with a crafted distributed module. +The distributed module keeps the same attributes as the original module but replaces the original parameters with distributed parameters. +Also, new `forward` methods will replace original ones so as to execute distributed computation, such as linear layers' split /gather operations executed under tensor parallelism. +Each distributed module implements its `from_native_module` static method to convert the PyTorch module to its corresponding distributed module. + +2. Replacing attributes of original Huggingface Transformers layers with appropriate attributes for distributed training. +For example, when training LlaMa-2 with tensor parallel size as 2, the attribute `num_heads` of `LlamaDecoderLayer` (the number of attention heads in each layer) should be replaced with `model.config.num_attention_heads // 2`. + +3. Replacing the `forward` methods implemented by original Huggingface +Transformers libraries with our customized `forward` methods. +This replacement is essential for pipeline paralellism, where a customiozed function is needed to pass intermediate hidden states between different pipeline stages. +Also, optimization methods such as flash attention or sequence parallel can be injected into the `forward` process through our customized `forward` method. + +4. Replacing the whole copy of model parameters and optimizer states with incomplete ones controlled by current device (this is why it's called Shardformer). +By executing `ModelSharder.shard` method, current device will only keep the part of model parameters it's supposed to take care of. +To be specific, they should be the assigned parameter shards when using tensor parallelism, or the parameters belonging to current pipeline stage when using pipeline parallelism, or both of them. +All other parameters are released so as to liberate memory usage. +As a result, the optimizer will only compute the states corresponding to these part of parameters, causing the usage of memory to be further saved. + +All of these replacements are implemented with manually written policies and forward functions. +If you want to delve deeper into the design of Shardformer or customize your own Shardformer policies, please refer to our [Shardformer development document](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/README.md) and [pipeline parallelism design](https://github.com/hpcaitech/ColossalAI/discussions/4050) for more details. 
## Supporting Information -List of Huggingface transformers model families currently supported by Shardformer: -- LlaMa-1/LlaMa-2 -- GPT2 -- BERT -- OPT -- BLOOM -- T5 -- ViT -- ChatGLM-2 6B -- Whisper - -List of optimization tools currently supported by Shardformer: -- Flash Attention 2 -- JIT Fused Operator -- xFormers -- Fused Layer Normalization -- Sequence Parallel -- Sequence Overlap +Model/Feature Compatibility Matrix: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+| Model/Feature | Tensor Parallel | Pipeline Parallel | Lazy Initialization | xFormers | Flash Attention 2 | JIT Fused Operators | Fused LayerNorm | Sequence Parallel | Sequence Overlap |
+| :-----------: | :-------------: | :---------------: | :-----------------: | :------: | :---------------: | :-----------------: | :-------------: | :---------------: | :--------------: |
+| Llama V1/V2 | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ |
+| OPT | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ |
+| BLOOM | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
+| ChatGLM 2 | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
+| BERT | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
+| GPT 2 | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
+| T5 | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ |
+| ViT | ✔️ | ✔️ | ❌ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ |
+| Whisper | ✔️ | ✔️ | ❌ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ |
+| SAM | ✔️ | ❌ | ❌ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ |
+| Blip2 | ✔️ | ❌ | ❌ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ |
List of model families we plan to support in the near future: -- SAM -- Blip2 - RoBERTa - ALBERT - ERNIE @@ -135,9 +264,6 @@ List of model families we plan to support in the near future: - SwinTransformer V1/V2 - qwen -These lists will grow longer as more models and optimization tools emerge in the future. If you have any suggestions on the models/optimization we should support, please feel free to mention it in [Issues](https://github.com/hpcaitech/ColossalAI/issues) section of our project. - -For more details about compatibility between each optimization tool and each supported model, please refer to chapter Roadmap in our [develop document](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/README.md). - +The support matrix will grow larger as more models and optimization tools emerge in the future. If you have any suggestions on the models/optimization we should support, please feel free to mention it in [Issues](https://github.com/hpcaitech/ColossalAI/issues) section of our project. diff --git a/docs/source/zh-Hans/features/1D_tensor_parallel.md b/docs/source/zh-Hans/features/1D_tensor_parallel.md index 61982cbb8be9..93fe9ea99422 100644 --- a/docs/source/zh-Hans/features/1D_tensor_parallel.md +++ b/docs/source/zh-Hans/features/1D_tensor_parallel.md @@ -2,14 +2,12 @@ 作者: Zhengda Bian, Yongbin Li -> ⚠️ 此页面上的信息已经过时并将被废弃。请在[Shardformer](./shardformer.md)页面查阅更新。 - **前置教程** - [定义配置文件](../basics/define_your_config.md) - [并行配置](../basics/configure_parallelization.md) -**示例代码** -- [ColossalAI-Examples 1D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/blob/main/features/tensor_parallel/README.md) +**示例代码**xw +- [Tensor Parallelism with Shardformer](https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/shardformer/examples) **相关论文** - [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://deepakn94.github.io/assets/papers/megatron-sc21.pdf) @@ -43,82 +41,10 @@ $$ | :-: | :-: | :-: | :-: | :-: | | $O(1/P)$ | $O(1/P)$ | $O(1)$ | $O(2(P-1)/P)$ | $O(2(P-1))$ | -## 使用 - -为了使模型能够实现一维张量并行, 如在2个 GPU 上, 我们需要配置如下的并行设置。 -```python -CONFIG = dict(parallel=dict( - data=1, - pipeline=1, - tensor=dict(size=2, mode='1d'), -)) -``` - -然后 Colossal-AI 会自动对所有来自 `colossalai.nn` 的层应用1D张量并行。 - -让我们定义一个由两层多层感知器 (MLP) 组成的模型,如下所示。 -```python -import colossalai -import colossalai.nn as col_nn -import torch -from colossalai.utils import print_rank_0 - -class MLP(torch.nn.Module): - def __init__(self, dim: int = 256): - super().__init__() - intermediate_dim = dim * 4 - self.dense_1 = col_nn.Linear(dim, intermediate_dim) - print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.transpose(0, 1).shape}') - self.activation = torch.nn.GELU() - self.dense_2 = col_nn.Linear(intermediate_dim, dim) - print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.transpose(0, 1).shape}') - self.dropout = col_nn.Dropout(0.1) - - def forward(self, x): - x = self.dense_1(x) - print_rank_0(f'Output of the first linear layer: {x.shape}') - x = self.activation(x) - x = self.dense_2(x) - print_rank_0(f'Output of the second linear layer: {x.shape}') - x = self.dropout(x) - return x -``` - -在2个 GPU 上启动 Colossal-AI 并建立模型。 - -```python -parser = colossalai.get_default_parser() -colossalai.launch(config=CONFIG, - rank=args.rank, - world_size=args.world_size, - local_rank=args.local_rank, - host=args.host, - port=args.port) -m = MLP() -``` -我们将会看到 MLP 模型中被划分的参数(如权重)的形状。 -```shell -Weight of the first linear layer: torch.Size([256, 512]) 
-Weight of the second linear layer: torch.Size([512, 256]) -``` -第一个线性层的完整权重形状应该为 `[256, 1024]`. 经过列-并行分割,它变成了 `[256, 512]`。 -同样地,第二个行并行层将权重 `[1024, 256]` 划分为 `[512, 256]`。 - -我们可以用一些随机输入来运行这个模型。 -```python -from colossalai.utils import get_current_device - -x = torch.randn((16, 256), device=get_current_device()) -torch.distributed.broadcast(x, src=0) # synchronize input +## 使用 -x = m(x) -``` -然后我们可以看到 activation 结果的形状。 -```shell -Output of the first linear layer: torch.Size([16, 512]) -Output of the second linear layer: torch.Size([16, 256]) -``` -第一个线性层的输出被划分成2块 (每个形状为 `[16, 512]`), 而第二层在整个 GPU 上的输出是相同的。 +在ColossalAI最新的版本中,1D张量并行由`Shardformer`功能实现。 +关于`Shardformer`的原理和用法细节请参考当前目录下的Shardformer文档。 diff --git a/docs/source/zh-Hans/features/2D_tensor_parallel.md b/docs/source/zh-Hans/features/2D_tensor_parallel.md index f163432ecceb..a8e5cf4bfb47 100644 --- a/docs/source/zh-Hans/features/2D_tensor_parallel.md +++ b/docs/source/zh-Hans/features/2D_tensor_parallel.md @@ -60,82 +60,8 @@ $$ ## 使用 -为了使我们的模型能够实现二维张量并行,例如在4个 GPU 上,我们需要配置如下的并行设置。 -```python -CONFIG = dict(parallel=dict( - data=1, - pipeline=1, - tensor=dict(size=4, mode='2d'), -)) -``` -然后 Colossal-AI 会自动对所有来自 `colossalai.nn` 的层应用2D张量并行。 - -让我们定义一个由两层多层感知器 (MLP) 组成的模型,如下所示。 -```python -import colossalai -import colossalai.nn as col_nn -import torch -from colossalai.utils import print_rank_0 - -class MLP(torch.nn.Module): - def __init__(self, dim: int = 256): - super().__init__() - intermediate_dim = dim * 4 - self.dense_1 = col_nn.Linear(dim, intermediate_dim) - print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}') - self.activation = torch.nn.GELU() - self.dense_2 = col_nn.Linear(intermediate_dim, dim) - print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}') - self.dropout = col_nn.Dropout(0.1) - - def forward(self, x): - x = self.dense_1(x) - print_rank_0(f'Output of the first linear layer: {x.shape}') - x = self.activation(x) - x = self.dense_2(x) - print_rank_0(f'Output of the second linear layer: {x.shape}') - x = self.dropout(x) - return x -``` -在4个 GPU 上启动 Colossal-AI 并建立模型。 -```python -parser = colossalai.get_default_parser() -colossalai.launch(config=CONFIG, - rank=args.rank, - world_size=args.world_size, - local_rank=args.local_rank, - host=args.host, - port=args.port) - -m = MLP() -``` -我们将会看到 MLP 模型中被划分的参数(如权重)的形状。 -```shell -Weight of the first linear layer: torch.Size([128, 512]) -Weight of the second linear layer: torch.Size([512, 128]) -``` -第一个线性层的完整权重形状应该为 `[256, 1024]`. 经过2D并行划分后,它在每个 GPU 上变成了 `[128, 512]` 。 -同样地,第二层将权重 `[1024, 256]` 划分为 `[512, 128]`. 
- -我们可以用一些随机输入来运行这个模型。 -```python -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.utils import get_current_device - -x = torch.randn((16, 256), device=get_current_device()) -# partition input -torch.distributed.broadcast(x, src=0) -x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)] -x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)] -print_rank_0(f'Input: {x.shape}') - -x = m(x) -``` -然后我们可以看到 activation 结果的形状。 -```shell -Input: torch.Size([8, 128]) -Output of the first linear layer: torch.Size([8, 512]) -Output of the second linear layer: torch.Size([8, 128]) -``` -2D并行中的 activation 张量都是同时在行和列分割的。例如,第一个线性层的输出是 `[8, 512]`, 而第二层的输出为 `[8, 128]`。 +ColossalAI的最新版本还暂不支持2D张量并行,但2D张量并行的功能会在未来的版本被集成入`Shardformer`中。关于`Shardformer`的原理和用法细节请参考当前目录下的Shardformer文档。 + +对于老版本ColossalAI的用户,2D张量并行的用法请参考[ColossalAI-Examples - 2D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/blob/main/features/tensor_parallel/README.md)。 + + diff --git a/docs/source/zh-Hans/features/2p5D_tensor_parallel.md b/docs/source/zh-Hans/features/2p5D_tensor_parallel.md index 5f15202729a7..6b0f1a301804 100644 --- a/docs/source/zh-Hans/features/2p5D_tensor_parallel.md +++ b/docs/source/zh-Hans/features/2p5D_tensor_parallel.md @@ -57,89 +57,8 @@ $$ ## 使用 -为了使我们的模型能够实现2.5D张量并行,例如在8个 GPU 上,我们需要配置如下的并行设置。 - -```python -CONFIG = dict(parallel=dict( - data=1, - pipeline=1, - tensor=dict(size=8, mode='2.5d', depth=2), -)) - -``` - -然后 Colossal-AI 会自动对所有来自 `colossalai.nn` 的层应用2.5D张量并行。 - -让我们定义一个由两层多层感知器 (MLP) 组成的模型,如下所示。 - -```python -import colossalai -import colossalai.nn as col_nn -import torch -from colossalai.utils import print_rank_0 - -class MLP(torch.nn.Module): - def __init__(self, dim: int = 256): - super().__init__() - intermediate_dim = dim * 4 - self.dense_1 = col_nn.Linear(dim, intermediate_dim) - print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}') - self.activation = torch.nn.GELU() - self.dense_2 = col_nn.Linear(intermediate_dim, dim) - print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}') - self.dropout = col_nn.Dropout(0.1) - - def forward(self, x): - x = self.dense_1(x) - print_rank_0(f'Output of the first linear layer: {x.shape}') - x = self.activation(x) - x = self.dense_2(x) - print_rank_0(f'Output of the second linear layer: {x.shape}') - x = self.dropout(x) - return x -``` -在8个 GPU 上启动 Colossal-AI 并建立模型。 -```python -parser = colossalai.get_default_parser() -colossalai.launch(config=CONFIG, - rank=args.rank, - world_size=args.world_size, - local_rank=args.local_rank, - host=args.host, - port=args.port) - -m = MLP() -``` -我们将会看到 MLP 模型中被划分的参数(如权重)的形状。 -```shell -Weight of the first linear layer: torch.Size([128, 512]) -Weight of the second linear layer: torch.Size([512, 128]) -``` - -第一个线性层的完整权重形状应该为 `[256, 1024]`. 经过2.5D并行划分后,它在每个 GPU 上变成了 `[128, 512]` 。 -同样地,第二层将权重 `[1024, 256]` 划分为 `[512, 128]`. 
- -我们可以用一些随机输入来运行这个模型。 -```python -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.utils import get_current_device - -x = torch.randn((16, 256), device=get_current_device()) -# partition input -torch.distributed.broadcast(x, src=0) -x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)] -x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)] -x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)] -print_rank_0(f'Input: {x.shape}') - -x = m(x) -``` -然后我们可以看到 activation 结果的形状。 -```shell -Input: torch.Size([4, 128]) -Output of the first linear layer: torch.Size([4, 512]) -Output of the second linear layer: torch.Size([4, 128]) -``` -2.5D并行中的 activation 张量都是同时在$d \times q$行和$q$列分割的。例如,第一个线性层的输出是 `[4, 512]`, 而第二层的输出为 `[4, 128]`。 -注意,2.5D并行使用与2D并行相同的划分方法来处理权重,区别在于对输入的划分。 +ColossalAI的最新版本还暂不支持2.5D张量并行,但2.5D张量并行的功能会在未来的版本被集成入`Shardformer`中。关于`Shardformer`的原理和用法细节请参考当前目录下的Shardformer文档。 + +对于老版本ColossalAI的用户,2.5D张量并行的用法请参考[ColossalAI-Examples - 2.5D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/blob/main/features/tensor_parallel/README.md)。 + + diff --git a/docs/source/zh-Hans/features/3D_tensor_parallel.md b/docs/source/zh-Hans/features/3D_tensor_parallel.md index 5ce0cdf6c068..f6154559ec28 100644 --- a/docs/source/zh-Hans/features/3D_tensor_parallel.md +++ b/docs/source/zh-Hans/features/3D_tensor_parallel.md @@ -67,88 +67,8 @@ $$ ## 使用 -为了使我们的模型能够实现3D张量并行,例如在8个 GPU 上,我们需要配置如下的并行设置。 - -```python -CONFIG = dict(parallel=dict( - data=1, - pipeline=1, - tensor=dict(size=8, mode='3d'), -)) -``` -然后 Colossal-AI 会自动对所有来自 `colossalai.nn` 的层应用3D张量并行。 - -让我们定义一个由两层多层感知器 (MLP) 组成的模型,如下所示。 - -```python -import colossalai -import colossalai.nn as col_nn -import torch -from colossalai.utils import print_rank_0 - -class MLP(torch.nn.Module): - def __init__(self, dim: int = 256): - super().__init__() - intermediate_dim = dim * 4 - self.dense_1 = col_nn.Linear(dim, intermediate_dim) - print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}') - self.activation = torch.nn.GELU() - self.dense_2 = col_nn.Linear(intermediate_dim, dim) - print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}') - self.dropout = col_nn.Dropout(0.1) - - def forward(self, x): - x = self.dense_1(x) - print_rank_0(f'Output of the first linear layer: {x.shape}') - x = self.activation(x) - x = self.dense_2(x) - print_rank_0(f'Output of the second linear layer: {x.shape}') - x = self.dropout(x) - return x -``` -在8个 GPU 上启动 Colossal-AI 并建立模型。 -```python -parser = colossalai.get_default_parser() -colossalai.launch(config=CONFIG, - rank=args.rank, - world_size=args.world_size, - local_rank=args.local_rank, - host=args.host, - port=args.port) - -m = MLP() -``` -我们将会看到 MLP 模型中被划分的参数(如权重)的形状。 -```shell -Weight of the first linear layer: torch.Size([128, 256]) -Weight of the second linear layer: torch.Size([512, 64]) -``` - -第一个线性层的完整权重形状应该为 `[256, 1024]`. 经过3D并行划分后,它在每个 GPU 上变成了 `[128, 256]` 。 -同样地,第二层将权重 `[1024, 256]` 划分为 `[512, 64]`. 
- -我们可以用一些随机输入来运行这个模型。 - -```python -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.utils import get_current_device - -x = torch.randn((16, 256), device=get_current_device()) -# partition input -torch.distributed.broadcast(x, src=0) -x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)] -x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)] -x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)] -print_rank_0(f'Input: {x.shape}') - -x = m(x) -``` -然后我们可以看到 activation 结果的形状。 -```shell -Input: torch.Size([4, 128]) -Output of the first linear layer: torch.Size([4, 512]) -Output of the second linear layer: torch.Size([4, 128]) -``` -3D并行中的 activation 张量都是同时在$q^2$行和$q$列分割的。例如,第一个线性层的输出是 `[4, 512]`, 而第二层的输出为 `[4, 128]`。 -注意,虽然这里3D并行的结果与2.5D并行的结果形状相同,但每个划分的内容是不同的。 +ColossalAI的最新版本还暂不支持3D张量并行,但3D张量并行的功能会在未来的版本被集成入`Shardformer`中。关于`Shardformer`的原理和用法细节请参考当前目录下的Shardformer文档。 + +对于老版本ColossalAI的用户,3D张量并行的用法请参考[ColossalAI-Examples - 3D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/blob/main/features/tensor_parallel/README.md)。 + + diff --git a/docs/source/zh-Hans/features/shardformer.md b/docs/source/zh-Hans/features/shardformer.md index 49aa23e2d06b..e0d8df2c90c8 100644 --- a/docs/source/zh-Hans/features/shardformer.md +++ b/docs/source/zh-Hans/features/shardformer.md @@ -24,23 +24,6 @@ Author: [Baizhou Zhang](https://github.com/Fridge003) 出于这种动机,ColossalAI团队开发了**Shardformer**,该功能可以自动为HuggingFace中主流的Transformer模型进行封装,用于张量并行以及流水线并行的训练策略。如此一来,对系统了解不多的用户也可以轻松地在transformers模型上进行并行训练:只需几行代码,用户就可以将模型转变为并行训练的状态。此外,Shardformer也包括了多种优化工具,用于在前向/后向的传递过程中实现加速和节省内存。 - -## Shardformer的工作原理 - -通常来说,Shardformer通过以下四种“替换”进行工作: - -1. 用我们设计的分布式模块替换原始的PyTorch模块(例如`nn.Linear`、`nn.Embedding`)。 -分布式模块保持与原始模块相同的属性,但分布式模块会用新的参数替换原始模块的参数。新的前向函数将取代原来的前向函数,用于执行分布式计算,例如在张量并行下执行线性层的split/gather操作。每个分布式模块都应当实现其`from_native_module`静态方法,以将PyTorch模块转换为其相应的分布式模块。 - -2. 将原始Huggingface Transformers中间层的属性为适用于并行训练的属性。例如,当使用并行度为2的张量并行训练LlaMa-2时,`LlamaDecoderLayer` 的属性`num_heads`(每一层注意力头的数量)应替换为`model.config.num_attention_heads // 2`。 - -3. 将原来Huggingface transformers库实现的前向函数替换为我们定制的前向函数。前向函数的替换对于流水线并行性至关重要,因为流水线并行需要特殊的前向函数去在不同的流水线阶段之间传递中间的隐藏状态。此外,可以通过我们定制的前向函数将例如`flash attention`或序列并行的优化方法注入到前向的过程中。 - -4. 将完整的模型参数和优化器状态替换为只由当前设备控制的部分模型参数和优化器状态。通过执行`ModelSharder.shard`方法,当前设备仅会保留它应该处理的那部分模型参数。具体来说,这部分参数可以是使用张量并行时分配到当前机器的参数分片,或者使用流水线并行时当前流水线阶段的模型参数,或者兼而有之。除此之外的所有其他参数都被释放,用于节省内存的空间。 -如此一来,优化器只会计算保留的部分参数对应的状态,从而进一步节省内存的使用。 - -所有这些替换都是通过手动编写的策略和前向函数来实现的。如果您想更深入地研究Shardformer的设计方案,或者定制您自己的Shardformer策略,请参考[Shardformer 开发者文档](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/README.md)和[流水并行设计方案](https://github.com/hpcaitech/ColossalAI/discussions/4050)以获得更多细节。 - ## 用法 ### Shardformer的参数配置 @@ -81,30 +64,179 @@ Shardformer的配置由类`ShardConfig`的参数控制: ``` 并且使用这些导入的类初始化模型。 + +## Shardformer的工作原理 + +通常来说,Shardformer通过以下四种“替换”进行工作: + +1. 用我们设计的分布式模块替换原始的PyTorch模块(例如`nn.Linear`、`nn.Embedding`)。 +分布式模块保持与原始模块相同的属性,但分布式模块会用新的参数替换原始模块的参数。新的前向函数将取代原来的前向函数,用于执行分布式计算,例如在张量并行下执行线性层的split/gather操作。每个分布式模块都应当实现其`from_native_module`静态方法,以将PyTorch模块转换为其相应的分布式模块。 + +2. 将原始Huggingface Transformers中间层的属性为适用于并行训练的属性。例如,当使用并行度为2的张量并行训练LlaMa-2时,`LlamaDecoderLayer` 的属性`num_heads`(每一层注意力头的数量)应替换为`model.config.num_attention_heads // 2`。 + +3. 
将原来Huggingface transformers库实现的前向函数替换为我们定制的前向函数。前向函数的替换对于流水线并行性至关重要,因为流水线并行需要特殊的前向函数去在不同的流水线阶段之间传递中间的隐藏状态。此外,可以通过我们定制的前向函数将例如`flash attention`或序列并行的优化方法注入到前向的过程中。 + +4. 将完整的模型参数和优化器状态替换为只由当前设备控制的部分模型参数和优化器状态。通过执行`ModelSharder.shard`方法,当前设备仅会保留它应该处理的那部分模型参数。具体来说,这部分参数可以是使用张量并行时分配到当前机器的参数分片,或者使用流水线并行时当前流水线阶段的模型参数,或者兼而有之。除此之外的所有其他参数都被释放,用于节省内存的空间。 +如此一来,优化器只会计算保留的部分参数对应的状态,从而进一步节省内存的使用。 + +所有这些替换都是通过手动编写的策略和前向函数来实现的。如果您想更深入地研究Shardformer的设计方案,或者定制您自己的Shardformer策略,请参考[Shardformer 开发者文档](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/README.md)和[流水并行设计方案](https://github.com/hpcaitech/ColossalAI/discussions/4050)以获得更多细节。 + + ## 支持信息 -Shardformer目前支持的Huggingface Transformer模型: -- LlaMa-1/LlaMa-2 -- GPT2 -- BERT -- OPT -- BLOOM -- T5 -- ViT -- ChatGLM-2 6B -- Whisper - -Shardformer目前支持的优化工具: -- Flash Attention 2 -- JIT Fused Operator -- xFormers -- Fused Layer Normalization -- Sequence Parallel -- Sequence Overlap +模型/功能 兼容性矩阵: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+| Model/Feature | Tensor Parallel | Pipeline Parallel | Lazy Initialization | xFormers | Flash Attention 2 | JIT Fused Operators | Fused LayerNorm | Sequence Parallel | Sequence Overlap |
+| :-----------: | :-------------: | :---------------: | :-----------------: | :------: | :---------------: | :-----------------: | :-------------: | :---------------: | :--------------: |
+| Llama V1/V2 | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ |
+| OPT | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ |
+| BLOOM | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
+| ChatGLM 2 | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
+| BERT | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
+| GPT 2 | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
+| T5 | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ |
+| ViT | ✔️ | ✔️ | ❌ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ |
+| Whisper | ✔️ | ✔️ | ❌ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ |
+| SAM | ✔️ | ❌ | ❌ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ |
+| Blip2 | ✔️ | ❌ | ❌ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ |
我们计划在不久后为Shardformer支持的模型: -- SAM -- Blip2 - RoBERTa - ALBERT - ERNIE @@ -114,8 +246,6 @@ Shardformer目前支持的优化工具: - SwinTransformer V1/V2 - qwen -随着未来更多模型和优化工具的出现,这些列表将会变得越来越长。如果您对我们应该支持的模型/优化工具有任何建议,欢迎在项目的[Issues](https://github.com/hpcaitech/ColossalAI/issues)板块参与讨论。 - -更多关于不同优化工具和模型之间兼容性的细节,请参考[Shardformer开发者文档](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/README.md)中的Roadmap一节。 +随着未来更多模型和优化工具的出现,我们支持的模型/优化工具将会变得越来越多。如果您对我们应该支持的模型/优化工具有任何建议,欢迎在项目的[Issues](https://github.com/hpcaitech/ColossalAI/issues)板块参与讨论。 From e4fc57c3de6204ac66df75aa1752db1ec284f31f Mon Sep 17 00:00:00 2001 From: digger yu Date: Fri, 15 Sep 2023 14:18:22 +0800 Subject: [PATCH 12/58] Optimized some syntax errors in the documentation and code under applications/ (#4127) Co-authored-by: flybird11111 <1829166702@qq.com> --- applications/Chat/README.md | 6 ++---- applications/Chat/coati/experience_maker/base.py | 2 +- applications/Chat/coati/models/lora.py | 2 +- applications/Chat/coati/ray/detached_replay_buffer.py | 2 +- applications/Chat/coati/ray/utils.py | 2 +- applications/Chat/evaluate/README.md | 2 +- applications/Chat/evaluate/gpt_evaluate.py | 8 ++++---- applications/Chat/examples/community/peft/README.md | 2 +- 8 files changed, 12 insertions(+), 14 deletions(-) diff --git a/applications/Chat/README.md b/applications/Chat/README.md index 5a1187ab503d..59e2c4548365 100644 --- a/applications/Chat/README.md +++ b/applications/Chat/README.md @@ -200,7 +200,6 @@ We provide an online inference server and a benchmark. We aim to run inference o We support 8-bit quantization (RTN), 4-bit quantization (GPTQ), and FP16 inference. Online inference server scripts can help you deploy your own services. - For more details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference). ## Coati7B examples @@ -428,7 +427,7 @@ Thanks so much to all of our amazing contributors! -- An open-source low cost solution for cloning [ChatGPT](https://openai.com/blog/chatgpt/) with a complete RLHF pipeline. [[demo]](https://chat.colossalai.org) +- An open-source low-cost solution for cloning [ChatGPT](https://openai.com/blog/chatgpt/) with a complete RLHF pipeline. [[demo]](https://chat.colossalai.org)

@@ -469,8 +468,7 @@ Coati is developed by ColossalAI Team: - [ofey404](https://github.com/ofey404) - [Wenhao Chen](https://github.com/CWHer) -The Phd student from [(HPC-AI) Lab](https://ai.comp.nus.edu.sg/) also contributed a lot to this project. - +The PhD student from [(HPC-AI) Lab](https://ai.comp.nus.edu.sg/) also contributed a lot to this project. - [Zangwei Zheng](https://github.com/zhengzangw) - [Xue Fuzhao](https://github.com/XueFuzhao) diff --git a/applications/Chat/coati/experience_maker/base.py b/applications/Chat/coati/experience_maker/base.py index ff75852576c8..b4646f282f0c 100644 --- a/applications/Chat/coati/experience_maker/base.py +++ b/applications/Chat/coati/experience_maker/base.py @@ -10,7 +10,7 @@ @dataclass class Experience: """Experience is a batch of data. - These data should have the the sequence length and number of actions. + These data should have the sequence length and number of actions. Left padding for sequences is applied. Shapes of each tensor: diff --git a/applications/Chat/coati/models/lora.py b/applications/Chat/coati/models/lora.py index 546f675d7d37..f1597da540a7 100644 --- a/applications/Chat/coati/models/lora.py +++ b/applications/Chat/coati/models/lora.py @@ -48,7 +48,7 @@ def __init__( def reset_parameters(self): if hasattr(self, 'lora_A'): - # initialize A the same way as the default for nn.Linear and B to zero + # Initialize A with the default values for nn.Linear and set B to zero. nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) nn.init.zeros_(self.lora_B) diff --git a/applications/Chat/coati/ray/detached_replay_buffer.py b/applications/Chat/coati/ray/detached_replay_buffer.py index 7b9df2ee139b..e04bf5ccb881 100644 --- a/applications/Chat/coati/ray/detached_replay_buffer.py +++ b/applications/Chat/coati/ray/detached_replay_buffer.py @@ -16,7 +16,7 @@ class DetachedReplayBuffer: ''' Detached replay buffer. Share Experience across workers on the same node. - Therefore a trainer node is expected to have only one instance. + Therefore, a trainer node is expected to have only one instance. It is ExperienceMakerHolder's duty to call append(exp) method, remotely. Args: diff --git a/applications/Chat/coati/ray/utils.py b/applications/Chat/coati/ray/utils.py index 761186b95ee5..391ffe7a91a9 100644 --- a/applications/Chat/coati/ray/utils.py +++ b/applications/Chat/coati/ray/utils.py @@ -116,7 +116,7 @@ def get_model_numel(model: nn.Module) -> int: def get_receivers_per_sender(sender_idx: int, num_senders: int, num_receivers: int, allow_idle_sender: bool) -> list: target_receivers = [] if num_senders <= num_receivers or allow_idle_sender: - # a sender will send data to one or more than one receivers + # a sender will send data to one or more receivers # a receiver only has one sender for i in range(num_receivers): if i % num_senders == sender_idx: diff --git a/applications/Chat/evaluate/README.md b/applications/Chat/evaluate/README.md index 68b03be16a30..0a97ae72f9d0 100644 --- a/applications/Chat/evaluate/README.md +++ b/applications/Chat/evaluate/README.md @@ -348,7 +348,7 @@ For example, if you want to add a new metric `persuasiveness` into category `bra

How can I add a new UniEval evaluation metric? -For example, if you want to add a new metric `persuasiveness` into task `data2text`, you should add a Boolean QA question about the metric in function `add_question` in `unieval/utils.py`. Please do note that how effectively the model would evaluate this metric is unknown and you may need some experiments to test whether the model is capable of evaluating this metric. +For example, if you want to add a new metric `persuasiveness` into task `data2text`, you should add a Boolean QA question about the metric in function `add_question` in `unieval/utils.py`. Please do note that how effectively the model would evaluate this metric is unknown, and you may need some experiments to test whether the model is capable of evaluating this metric. ```python if task == 'data2text': diff --git a/applications/Chat/evaluate/gpt_evaluate.py b/applications/Chat/evaluate/gpt_evaluate.py index f8cfb8d0f7e5..6fcbe63d0253 100644 --- a/applications/Chat/evaluate/gpt_evaluate.py +++ b/applications/Chat/evaluate/gpt_evaluate.py @@ -576,7 +576,7 @@ def calculate_scores_form_logprobs(logprobs: Dict[str, Any]) -> float: for key, value in logprobs.items(): # Sometimes the key will be one byte of a unicode character which takes the form of "bytes:\\xe7". - # It is meaningless and thus we don't calculate probability. + # It is meaningless, and thus we don't calculate probability. if "bytes" in key: continue # results[0] is the score which corresponds to the key(predicted token). @@ -621,7 +621,7 @@ def save_gpt_evaluation_results(model_name: str, gpt_evaluation_results: Dict[st Args: model_name: name of the model for saving evaluation results. - gpt_evaluation_results: evaluations results for all of the model answers. + gpt_evaluation_results: evaluations results for all the model answers. save_path: path to save GPT evaluation statistics. """ @@ -641,7 +641,7 @@ def save_gpt_evaluation_statistics(model_name: str, evaluations: List[Dict], sav Args: model_name: name of the model for saving statistics. - evaluations: evaluations for all of the model answers. + evaluations: evaluations for all the model answers. save_path: path to save GPT evaluation statistics. """ @@ -663,7 +663,7 @@ def save_gpt_evaluation_statistics(model_name: str, evaluations: List[Dict], sav for evaluation in data: for metric in metrics: if evaluation["evaluation"][metric] == {}: - # This means after 3 retries, the server still returns an error and we set the score to 0. + # This means after 3 retries, the server still returns an error, and we set the score to 0. scores[metric].append(0) elif evaluation["evaluation"][metric]["logprobs"] is not None: scores[metric].append( diff --git a/applications/Chat/examples/community/peft/README.md b/applications/Chat/examples/community/peft/README.md index 8b2edc48cd99..ada3a16296af 100644 --- a/applications/Chat/examples/community/peft/README.md +++ b/applications/Chat/examples/community/peft/README.md @@ -20,7 +20,7 @@ pip install . For SFT training, just call train_peft_sft.py -Its arguments are almost identical to train_sft.py instead adding a new eval_dataset if you have a eval_dataset file. The data file is just a plain datafile, please check the format in the easy_dataset.py. +Its arguments are almost identical to train_sft.py instead adding a new eval_dataset if you have an eval_dataset file. The data file is just a plain datafile, please check the format in the easy_dataset.py. For stage-3 rlhf training, call train_peft_prompts.py. 
Its arguments are almost identical to train_prompts.py. The only difference is that I use text files to indicate the prompt and pretrained data file. The models are included in easy_models.py. Currently only bloom models are tested, but technically gpt2/opt/llama should be supported. From 46162632e5dc8c0d7f6928b85d55b4d557615a8e Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 15 Sep 2023 14:32:04 +0800 Subject: [PATCH 13/58] [shardformer] update pipeline parallel document (#4725) * [shardformer] update pipeline parallel document * [shardformer] update pipeline parallel document * [shardformer] update pipeline parallel document * [shardformer] update pipeline parallel document * [shardformer] update pipeline parallel document * [shardformer] update pipeline parallel document * [shardformer] update pipeline parallel document * [shardformer] update pipeline parallel document --- docs/source/en/features/pipeline_parallel.md | 222 +++++++++++------- .../zh-Hans/features/pipeline_parallel.md | 218 ++++++++++------- 2 files changed, 276 insertions(+), 164 deletions(-) diff --git a/docs/source/en/features/pipeline_parallel.md b/docs/source/en/features/pipeline_parallel.md index 8b5f228a9e5e..cb19f9815bf2 100644 --- a/docs/source/en/features/pipeline_parallel.md +++ b/docs/source/en/features/pipeline_parallel.md @@ -1,14 +1,15 @@ # Pipeline Parallel -Author: Guangyang Lu, Hongxin Liu, Yongbin Li +Author: Guangyang Lu, Hongxin Liu, Yongbin Li, Mingyan Jiang **Prerequisite** -- [Define Your Configuration](../basics/define_your_config.md) -- [Use Engine and Trainer in Training](../basics/engine_trainer.md) -- [Configure Parallelization](../basics/configure_parallelization.md) +- [Paradigms of Parallelism](../concepts/paradigms_of_parallelism.md) +- [Use Booster to Training](../basics/booster_api.md) +- [Shardformer](../features/shardformer.md) +- [Plugin of Booster](../basics/booster_plugins.md) **Example Code** -- [ColossalAI-Examples ResNet with pipeline](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/pipeline_parallel) +- [Fine-tune Bert with pipeline](https://github.com/hpcaitech/ColossalAI/blob/main/examples/language/bert/finetune.py) **Related Paper** - [Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training](https://arxiv.org/abs/2110.14883) @@ -17,7 +18,7 @@ Author: Guangyang Lu, Hongxin Liu, Yongbin Li ## Quick introduction -In this tutorial, you will learn how to use pipeline parallel. In Colossal-AI, we use 1F1B pipeline, introduced by Nvidia. In this case, ViT and Imagenet are too large to use. Therefore, here we use ResNet and Cifar as example. +In this tutorial, you will learn how to use pipeline parallel. In Colossal-AI, we use 1F1B pipeline, introduced by Nvidia. In this case, ViT and Imagenet are too large to use. Therefore, here we use bert model and glue dataset as example. ## Table Of Content @@ -25,7 +26,7 @@ In this tutorial we will cover: 1. Introduction of 1F1B pipeline. 2. Usage of non-interleaved and interleaved schedule. -3. Training ResNet with pipeline. +3. Finetune Bert with pipeline. ## Introduction of 1F1B pipeline @@ -60,101 +61,158 @@ In this schedule, each device can perform computation for multiple subsets of la This mode is both memory-efficient and time-efficient. -## Usage of non-interleaved and interleaved schedule +## Colossal-AI's Implementation -In Colossal-AI, we provided both non-interleaved(as `PipelineSchedule`) and interleaved schedule(as `InterleavedPipelineSchedule`). 
+In Colossal-AI, pipeline parallelism relies on the `scheduler` and [`Shardformer`](../features/shardformer.md). We provide both non-interleaved (`OneForwardOneBackwardSchedule`) and interleaved (`InterleavedSchedule`) schedules. While `Shardformer` implements layer splitting for models and replaces the `forward` function of the model to make it compatible with the scheduler. -You just need to set `NUM_MICRO_BATCHES` in config file and set `NUM_CHUNKS` in config file if you want to use Interleaved Pipeline Schedule. If you certainly know the shape of each pipeline stage's output tensor and the shapes are all the same, you can set `TENSOR_SHAPE` in config file to further reduce communication. Otherwise, you can just ignore `tensor_shape`, and the shape will be exchanged over pipeline stages automatically. Then we will generate an appropriate schedule for you. +In Colossal-AI, the `HybridParallelPlugin` encapsulates pipeline execution strategies. It manages pipeline parallel communication groups and a scheduler. When boosting the model with this plugin, the model's layers are split by calling the `shardformer.optimize` function, and then `execute_pipeline` is called to execute the model in segments using `OneForwardOneBackwardSchedule` which is default scheduler used in `HybridParallelPlugin`, and `InterleavedSchedule` will be integrated later. -## Training ResNet with pipeline +You can customize your parallel strategy by setting parameters for the `HybridParallelPlugin`. -Let's build the `ResNet` model first with Colossal PipelinableContext: +For more usage details, please refer to the [documentation](../basics/booster_plugins.md) for `HybridParallelPlugin`. + +## Fine-tune Bert with pipeline + +First, we define the necessary training components, including model, dataloader, optimizer, lr_scheduler, criterion: ```python -import os -from typing import Callable, List, Optional, Type, Union +import argparse +from typing import Callable, List, Union + import torch import torch.nn as nn +from data import GLUEDataBuilder +from torch.optim import Adam, Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import ( + AlbertForSequenceClassification, + AutoConfig, + BertForSequenceClassification, + get_linear_schedule_with_warmup, +) + import colossalai -import colossalai.nn as col_nn +from colossalai.booster import Booster +from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.cluster import DistCoordinator +from colossalai.nn.optimizer import HybridAdam -from colossalai.core import global_context as gpc -from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.legacy.trainer import Trainer, hooks -from colossalai.utils import MultiTimer, get_dataloader -from colossalai.context import ParallelMode -from colossalai.pipeline.pipelinable import PipelinableContext +# Define some config +NUM_EPOCHS = 3 +BATCH_SIZE = 32 +LEARNING_RATE = 2.4e-5 +WEIGHT_DECAY = 0.01 +WARMUP_FRACTION = 0.1 + +coordinator = DistCoordinator() + +def move_to_cuda(batch): + return {k: v.cuda() for k, v in batch.items()} + + +# Define 'criterion' function with two inputs, which will be passed to 'execute_pipeline'. 
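+# Here `outputs` is what the model returns for a micro-batch and `inputs` is that micro-batch itself;
+# the loss is already computed inside the model, so the criterion simply returns `outputs.loss`.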
+def _criterion(outputs, inputs): + return outputs.loss + +# Define optimizer +lr = LEARNING_RATE +no_decay = ["bias", "LayerNorm.weight"] +optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": WEIGHT_DECAY, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, +] -from titans.dataloader.cifar10 import build_cifar -from torchvision.models import resnet50 -from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1 +optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8) -# Define some config -BATCH_SIZE = 64 -NUM_EPOCHS = 2 -NUM_CHUNKS = 1 -CONFIG = dict(NUM_MICRO_BATCHES=4, parallel=dict(pipeline=2)) - -# Train -disable_existing_loggers() -parser = colossalai.get_default_parser() -args = parser.parse_args() -colossalai.launch_from_torch(backend=args.backend, config=CONFIG) -logger = get_dist_logger() -pipelinable = PipelinableContext() - -# build model -with pipelinable: - model = resnet50() -``` -Define an execution sequence. -```python -exec_seq = [ - 'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', - (lambda x: torch.flatten(x, 1), "behind"), 'fc' -] -pipelinable.to_layer_list(exec_seq) +# Define lr_scheduler +total_steps = len(train_dataloader) * NUM_EPOCHS +num_warmup_steps = int(WARMUP_FRACTION * total_steps) +lr_scheduler = get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=total_steps, +) + + +# Define Bert model +model = BertForSequenceClassification.from_pretrained("bert-base-uncased", config=cfg).cuda() + +# Define a dataloader +data_builder = GLUEDataBuilder(model_name, + plugin, + args.task, + train_batch_size=BATCH_SIZE, + eval_batch_size=BATCH_SIZE) +train_dataloader = data_builder.train_dataloader() ``` -Partition the model into pipeline. +Define a booster with the `HybridParallelPlugin`. ```python -model = pipelinable.partition(NUM_CHUNKS, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE)) +plugin = HybridParallelPlugin(tp_size=1, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_all_optimization=True, + zero_stage=1, + precision='fp16', + initial_scale=1) +booster = Booster(plugin=plugin) ``` -In this tutorial, we use `Trainer` to train `ResNet`: +Boost these train componts with the booster created. ```python -# build criterion -criterion = nn.CrossEntropyLoss() - -# optimizer -optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - -# build dataloader -root = os.environ.get('DATA', './data') -train_dataloader, test_dataloader = build_cifar(BATCH_SIZE, root, padding=4, crop=32, resize=32) - -lr_scheduler = col_nn.lr_scheduler.LinearWarmupLR(optimizer, NUM_EPOCHS, warmup_steps=1) -engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(model, optimizer, criterion, - train_dataloader, test_dataloader, - lr_scheduler) -timer = MultiTimer() +model, optimizer, _criterion, _, lr_scheduler = booster.boost(model, + optimizer, + criterion=_criterion, + lr_scheduler=lr_scheduler) +``` -trainer = Trainer(engine=engine, timer=timer, logger=logger) +Train the model at last. 
-hook_list = [ - hooks.LossHook(), - hooks.AccuracyHook(col_nn.metric.Accuracy()), - hooks.LogMetricByEpochHook(logger), - hooks.LRSchedulerHook(lr_scheduler, by_epoch=True) -] - -trainer.fit(train_dataloader=train_dataloader, - epochs=NUM_EPOCHS, - test_dataloader=test_dataloader, - test_interval=1, - hooks=hook_list, - display_progress=True) +```python +# Define a train function +def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: Callable, lr_scheduler: LRScheduler, + train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator): + + is_pp_last_stage = booster.plugin.stage_manager.is_last_stage() + total_step = len(train_dataloader) + + model.train() + optimizer.zero_grad() + # convert train_dataloader to a iterator + train_dataloader_iter = iter(train_dataloader) + with tqdm(range(total_step), + desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', + disable=not (is_pp_last_stage)) as pbar: + # Forward pass + for _ in pbar: + outputs = booster.execute_pipeline(train_dataloader_iter, + model, + _criterion, + optimizer, + return_loss=True, + return_outputs=True) + # Backward and optimize + if is_pp_last_stage: + loss = outputs['loss'] + pbar.set_postfix({'loss': loss.item()}) + + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + +# Train model +for epoch in range(NUM_EPOCHS): + train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator) ``` -We use `2` pipeline stages and the batch will be split into `4` micro batches. +We use `2` pipeline stages and the micro batches is 1. (these parameters can be configured to an appropriate value) diff --git a/docs/source/zh-Hans/features/pipeline_parallel.md b/docs/source/zh-Hans/features/pipeline_parallel.md index 1497dc399f6c..e688020556d8 100644 --- a/docs/source/zh-Hans/features/pipeline_parallel.md +++ b/docs/source/zh-Hans/features/pipeline_parallel.md @@ -1,14 +1,15 @@ # 流水并行 -作者: Guangyang Lu, Hongxin Liu, Yongbin Li +作者: Guangyang Lu, Hongxin Liu, Yongbin Li, Mingyan Jiang **前置教程** -- [定义配置文件](../basics/define_your_config.md) -- [在训练中使用Engine和Trainer](../basics/engine_trainer.md) -- [并行配置](../basics/configure_parallelization.md) +- [并行技术](../concepts/paradigms_of_parallelism.md) +- [Booster API](../basics/booster_api.md) +- [Shardformer](../features/shardformer.md) +- [Booster 插件](../basics/booster_plugins.md) **示例代码** -- [ColossalAI-Examples ResNet with pipeline](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/pipeline_parallel) +- [使用pipeline并行策略微调Bert](https://github.com/hpcaitech/ColossalAI/blob/main/examples/language/bert/finetune.py) **相关论文** - [Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training](https://arxiv.org/abs/2110.14883) @@ -17,7 +18,7 @@ ## 快速预览 -在本教程中,你将学习如何使用流水并行。在 Colossal-AI 中, 我们使用 NVIDIA 推出的 1F1B 流水线。由于在本例中, 使用 ViT 和 ImageNet 太过庞大,因此我们使用 ResNet 和 CIFAR 为例. +在本教程中,你将学习如何使用流水并行。在 Colossal-AI 中, 我们使用 NVIDIA 推出的 1F1B 流水线。由于在本例中, 使用 ViT 和 ImageNet 太过庞大,因此我们使用 Bert 和 Glue数据集 为例. ## 目录 @@ -25,7 +26,7 @@ 1. 介绍 1F1B 流水线; 2. 使用非交错和交错 schedule; -3. 使用流水线训练 ResNet。 +3. 
使用流水线微调 Bert ## 认识 1F1B 流水线 @@ -59,101 +60,154 @@ 这种模式既节省内存又节省时间。 -## 使用schedule +## Colossal-AI中的实现 -在 Colossal-AI 中, 我们提供非交错(`PipelineSchedule`) 和交错(`InterleavedPipelineSchedule`)schedule。 +在 Colossal-AI 中,流水线并行依赖于 `scheduler` 和 `Shardformer`。我们提供了非交错的(`OneForwardOneBackwardSchedule`)和交错的(`InterleavedSchedule`)两种调度方式。而 Shardformer 实现了对模型的层分割,并替换了模型的 `forward` 函数,使其与调度器兼容。 -你只需要在配置文件中,设置 `NUM_MICRO_BATCHES` 并在你想使用交错schedule的时候,设置 `NUM_CHUNKS`。 如果你确定性地知道每个管道阶段的输出张量的形状,而且形状都是一样的,你可以设置 `tensor_shape` 以进一步减少通信。否则,你可以忽略 `tensor_shape` , 形状将在管道阶段之间自动交换。 我们将会根据用户提供的配置文件,生成一个合适schedule来支持用户的流水并行训练。 +在 Colossal-AI 中,`HybridParallelPlugin` 封装了流水线执行策略。它管理流水线并行通信组和一个 `scheduler`。当使用此插件增强模型时,模型的层将通过调用 `shardformer.optimize` 函数进行分割,然后调用 `execute_pipeline` 使用 `scheduler` 来分别执行模型的各个部分。 `HybridParallelPlugin`暂时只支持`OneForwardOneBackwardSchedule`, `InterleavedSchedule`将会在不久后支持。 -## 使用流水线训练 ResNet +您可以通过设置 `HybridParallelPlugin` 的参数来自定义您的并行策略。更多使用细节请参考`HybridParallelPlugin`的[使用文档](../basics/booster_plugins.md)。 -我们首先用Colossal PipelinableContext方式建立 `ResNet` 模型: +## 使用流水线微调 Bert模型 + +首先我们定义好需要的训练组件,包括`model`, `dataloader`, `optimizer`, `lr_scheduler`, `criterion` 等: ```python -import os -from typing import Callable, List, Optional, Type, Union +import argparse +from typing import Callable, List, Union + import torch import torch.nn as nn +from data import GLUEDataBuilder +from torch.optim import Adam, Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import ( + AlbertForSequenceClassification, + AutoConfig, + BertForSequenceClassification, + get_linear_schedule_with_warmup, +) + import colossalai -import colossalai.nn as col_nn +from colossalai.booster import Booster +from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.cluster import DistCoordinator +from colossalai.nn.optimizer import HybridAdam -from colossalai.core import global_context as gpc -from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.legacy.trainer import Trainer, hooks -from colossalai.utils import MultiTimer, get_dataloader -from colossalai.context import ParallelMode -from colossalai.pipeline.pipelinable import PipelinableContext +# Define some config +NUM_EPOCHS = 3 +BATCH_SIZE = 32 +LEARNING_RATE = 2.4e-5 +WEIGHT_DECAY = 0.01 +WARMUP_FRACTION = 0.1 + +coordinator = DistCoordinator() + +def move_to_cuda(batch): + return {k: v.cuda() for k, v in batch.items()} + +# Define 'criterion' function with two inputs, which will be passed to 'execute_pipeline'. 
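+# `outputs` is the model output for a micro-batch and `inputs` is the micro-batch itself;
+# the model computes the loss internally, so we simply return `outputs.loss`.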
+def _criterion(outputs, inputs): + return outputs.loss + +# Define optimizer +lr = LEARNING_RATE +no_decay = ["bias", "LayerNorm.weight"] +optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": WEIGHT_DECAY, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, +] -from titans.dataloader.cifar10 import build_cifar -from torchvision.models import resnet50 -from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1 +optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8) -# Define some config -BATCH_SIZE = 64 -NUM_EPOCHS = 2 -NUM_CHUNKS = 1 -CONFIG = dict(NUM_MICRO_BATCHES=4, parallel=dict(pipeline=2)) - -# Train -disable_existing_loggers() -parser = colossalai.get_default_parser() -args = parser.parse_args() -colossalai.launch_from_torch(backend=args.backend, config=CONFIG) -logger = get_dist_logger() -pipelinable = PipelinableContext() - -# build model -with pipelinable: - model = resnet50() + +# Define lr_scheduler +total_steps = len(train_dataloader) * NUM_EPOCHS +num_warmup_steps = int(WARMUP_FRACTION * total_steps) +lr_scheduler = get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=total_steps, +) + + +# Define Bert model +model = BertForSequenceClassification.from_pretrained("bert-base-uncased", config=cfg).cuda() + +# Define a dataloader +data_builder = GLUEDataBuilder(model_name, + plugin, + args.task, + train_batch_size=BATCH_SIZE, + eval_batch_size=BATCH_SIZE) +train_dataloader = data_builder.train_dataloader() ``` -给定切分顺序,module直接给出name,部分函数需要手动添加。 +使用`HybridParallelPlugin`初始化一个booster. ```python -exec_seq = [ - 'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', - (lambda x: torch.flatten(x, 1), "behind"), 'fc' -] -pipelinable.to_layer_list(exec_seq) +plugin = HybridParallelPlugin(tp_size=1, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_all_optimization=True, + zero_stage=1, + precision='fp16', + initial_scale=1) +booster = Booster(plugin=plugin) ``` -将模型切分成流水线阶段。 +使用`booster`将优化特性注入到训练组件中。 ```python -model = pipelinable.partition(NUM_CHUNKS, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE)) +model, optimizer, _criterion, _, lr_scheduler = booster.boost(model, + optimizer, + criterion=_criterion, + lr_scheduler=lr_scheduler) ``` -我们使用`Trainer`训练`ResNet`: +最后训练模型 ```python -# build criterion -criterion = nn.CrossEntropyLoss() - -# optimizer -optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - -# build dataloader -root = os.environ.get('DATA', './data') -train_dataloader, test_dataloader = build_cifar(BATCH_SIZE, root, padding=4, crop=32, resize=32) - -lr_scheduler = col_nn.lr_scheduler.LinearWarmupLR(optimizer, NUM_EPOCHS, warmup_steps=1) -engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(model, optimizer, criterion, - train_dataloader, test_dataloader, - lr_scheduler) -timer = MultiTimer() - -trainer = Trainer(engine=engine, timer=timer, logger=logger) - -hook_list = [ - hooks.LossHook(), - hooks.AccuracyHook(col_nn.metric.Accuracy()), - hooks.LogMetricByEpochHook(logger), - hooks.LRSchedulerHook(lr_scheduler, by_epoch=True) -] - -trainer.fit(train_dataloader=train_dataloader, - epochs=NUM_EPOCHS, - test_dataloader=test_dataloader, - test_interval=1, - hooks=hook_list, - display_progress=True) +# Define a train function +def 
train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: Callable, lr_scheduler: LRScheduler, + train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator): + + is_pp_last_stage = booster.plugin.stage_manager.is_last_stage() + total_step = len(train_dataloader) + + model.train() + optimizer.zero_grad() + # convert train_dataloader to a iterator + train_dataloader_iter = iter(train_dataloader) + with tqdm(range(total_step), + desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', + disable=not (is_pp_last_stage)) as pbar: + # Forward pass + for _ in pbar: + outputs = booster.execute_pipeline(train_dataloader_iter, + model, + _criterion, + optimizer, + return_loss=True, + return_outputs=True) + # Backward and optimize + if is_pp_last_stage: + loss = outputs['loss'] + pbar.set_postfix({'loss': loss.item()}) + + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + +# Train model +for epoch in range(NUM_EPOCHS): + train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator) ``` -我们使用 `2` 个流水段,并且 batch 将被切分为 `4` 个 micro batches。 +我们使用 `2` 个流水段,并且 batch 将被切分为 `1` 个 micro batches。(这些参数都可根据实际情况设置为合适的值) From cd4e61d149db3b98435cf6c90a389d8d1dff21e6 Mon Sep 17 00:00:00 2001 From: Pengtai Xu Date: Fri, 15 Sep 2023 15:52:18 +0800 Subject: [PATCH 14/58] [legacy] remove deterministic data loader test --- .../test_deterministic_dataloader.py | 73 ------------------- 1 file changed, 73 deletions(-) delete mode 100644 tests/test_data/test_deterministic_dataloader.py diff --git a/tests/test_data/test_deterministic_dataloader.py b/tests/test_data/test_deterministic_dataloader.py deleted file mode 100644 index 283b5cc35279..000000000000 --- a/tests/test_data/test_deterministic_dataloader.py +++ /dev/null @@ -1,73 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import os -from pathlib import Path - -import pytest -import torch -import torch.distributed as dist -from torchvision import datasets, transforms - -import colossalai -from colossalai.context import Config, ParallelMode -from colossalai.core import global_context as gpc -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_dataloader - -CONFIG = Config( - dict( - train_data=dict( - dataset=dict( - type='CIFAR10', - root=Path(os.environ['DATA']), - train=True, - download=True, - ), - dataloader=dict(num_workers=2, batch_size=2, shuffle=True), - ), - parallel=dict( - pipeline=dict(size=1), - tensor=dict(size=1, mode=None), - ), - seed=1024, - )) - - -def run_data_sampler(rank, world_size, port): - dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port=port, host='localhost') - colossalai.launch(**dist_args) - - # build dataset - transform_pipeline = [transforms.ToTensor(), transforms.RandomCrop(size=32, padding=4)] - transform_pipeline = transforms.Compose(transform_pipeline) - dataset = datasets.CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_pipeline) - - # build dataloader - dataloader = get_dataloader(dataset, batch_size=8, add_sampler=False) - - data_iter = iter(dataloader) - img, label = data_iter.next() - img = img[0] - - if gpc.get_local_rank(ParallelMode.DATA) != 0: - img_to_compare = img.clone() - else: - img_to_compare = img - dist.broadcast(img_to_compare, src=0, group=gpc.get_group(ParallelMode.DATA)) - - if gpc.get_local_rank(ParallelMode.DATA) != 0: - # this is without sampler - # this should be false if data parallel sampler to given to 
the dataloader - assert torch.equal(img, - img_to_compare), 'Same image was distributed across ranks and expected it to be the same' - torch.cuda.empty_cache() - - -@rerun_if_address_is_in_use() -def test_data_sampler(): - spawn(run_data_sampler, 4) - - -if __name__ == '__main__': - test_data_sampler() From 6a03c933a0ef43090c8add9303898287b1482d74 Mon Sep 17 00:00:00 2001 From: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Date: Fri, 15 Sep 2023 16:09:32 +0800 Subject: [PATCH 15/58] [shardformer] update seq parallel document (#4730) * update doc of seq parallel * fix typo --- docs/source/en/features/shardformer.md | 17 +++++++++++++++-- docs/source/zh-Hans/features/shardformer.md | 16 +++++++++++++++- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/docs/source/en/features/shardformer.md b/docs/source/en/features/shardformer.md index 10e03e963a95..ca23f07421d1 100644 --- a/docs/source/en/features/shardformer.md +++ b/docs/source/en/features/shardformer.md @@ -1,6 +1,6 @@ # Shardformer -Author: [Baizhou Zhang](https://github.com/Fridge003) +Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github.com/FoolPlayer) **Prerequisite** - [Paradigms of Parallelism](../concepts/paradigms_of_parallelism.md) @@ -16,7 +16,7 @@ Author: [Baizhou Zhang](https://github.com/Fridge003) - [GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism](https://arxiv.org/abs/1811.06965) - [FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning](https://arxiv.org/abs/2307.08691) - [Sequence Parallelism: Long Sequence Training from System Perspective](https://arxiv.org/abs/2105.13120) - +- [Reducing Activation Recomputation in Large Transformer Models](https://arxiv.org/abs/2205.05198) ## Introduction @@ -74,6 +74,18 @@ is an example on how to trigger `Shardformer` through calling Shardformer APIs. ``` when training ChatGLM-2 with Shardformer, and initialize your model with these imported classes. +### Sequence Parallelism + +Sequence parallelism in `Shardformer` is a little different from [this one](https://colossalai.org/docs/basics/configure_parallelization/#sequence-parallel) which focuses on ring attention. In `Shardformer`, sequence parallelism just use with 1D tensor parallelism to to further reduce the memory occupation of activations in computations. + +1. In normal [1D tensor parallel](https://colossalai.org/docs/features/1D_tensor_parallel), there are 2 communication operations, $g$ and $\vec{g}$, $g$ will do one time All-Reduce in backward to get all gradient from all the devices and $\vec{g}$ will do one time All-Reduce in forward to get whole outputs from all the device. + +2. When using sequence parallelism, $\vec{g}$ needs to do All-Gather to gather the inputs in sequence dimension during forward and Reduce-Scatter to splite the gradient during backward. $\vec{g}$ needs to do Reduce-Scatter to splite the output of row linear layer of tensor parallel to all devices in sequence dimension, and All-Gather to get the whole gradient during backward. + +3. The implementation of All-Reduce using NCCL adopts the `Ring All-Reduce` approach, which consists of a Reduce-Scatter operation and an All-Gather operation with equal costs. Therefore, compared to sequence parallelism and tensor parallelism, it does not introduce additional communication overhead. + +4. 
One important thing to note is that when using sequence parallelism with 'Column Linear' of tensor parallelism,, during the backward computation of gradients, the complete input needs to be obtained. During the forward pass, only the portion of the input that is split along the sequence dimension is retained, shape like $(batch, sequence_len/k, hidden_states)$. Therefore, an additional All-Gather operation is required to obtain the complete input for gradient computation. However, in the implementation, it is possible to overlap the gradient computation with the All-Gather communication operation, which would not introduce additional communication overhead (corresponding to the `enable_sequence_overlap` parameter in `Shardformer`). + ## How Shardformer Works Generally, Shardformer works through the following four kinds of *replacements*: @@ -100,6 +112,7 @@ As a result, the optimizer will only compute the states corresponding to these p All of these replacements are implemented with manually written policies and forward functions. If you want to delve deeper into the design of Shardformer or customize your own Shardformer policies, please refer to our [Shardformer development document](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/README.md) and [pipeline parallelism design](https://github.com/hpcaitech/ColossalAI/discussions/4050) for more details. + ## Supporting Information Model/Feature Compatibility Matrix: diff --git a/docs/source/zh-Hans/features/shardformer.md b/docs/source/zh-Hans/features/shardformer.md index e0d8df2c90c8..7de0c41c10d7 100644 --- a/docs/source/zh-Hans/features/shardformer.md +++ b/docs/source/zh-Hans/features/shardformer.md @@ -1,6 +1,6 @@ # Shardformer -Author: [Baizhou Zhang](https://github.com/Fridge003) +Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github.com/FoolPlayer) **预备知识** - [并行技术](../concepts/paradigms_of_parallelism.md) @@ -16,6 +16,7 @@ Author: [Baizhou Zhang](https://github.com/Fridge003) - [GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism](https://arxiv.org/abs/1811.06965) - [FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning](https://arxiv.org/abs/2307.08691) - [Sequence Parallelism: Long Sequence Training from System Perspective](https://arxiv.org/abs/2105.13120) +- [Reducing Activation Recomputation in Large Transformer Models](https://arxiv.org/abs/2205.05198) ## 简介 @@ -65,6 +66,19 @@ Shardformer的配置由类`ShardConfig`的参数控制: 并且使用这些导入的类初始化模型。 +### 序列并行 Sequence Parallelism + +在`Shardformer`中,序列并行与[此处](https://colossalai.org/docs/basics/configure_parallelization/#sequence-parallel)稍有不同,后者侧重于ring attention。在`Shardformer`中,序列并行仅与1D张量并行一起使用,以进一步减少计算中activation的内存占用。 + +1. 在普通的[1D张量并行](https://colossalai.org/docs/features/1D_tensor_parallel)中,有两个通信操作$g$和$\vec{g}$,$g$在反向传播中进行一次全局归约以获取来自所有设备的梯度,而$\vec{g}$在正向传播中进行一次All-Reduce以获取来自所有设备的输出。 + +2. 当使用序列并行时,$\vec{g}$需要在正向传播过程中进行All-Gather以获取序列维度上的输入,并在反向传播过程中进行Reduce-Scatter以分割梯度。$\vec{g}$需要进行Reduce-Scatter以将序列维度上的行线性层输出分割到所有设备上,并进行All-Gather以获取完整的梯度。 + +3. 使用NCCL的All-reduce实现采用了`Ring All-Reduce`方法,由一次Reduce-Scatter和一次All-Gather组成,两者的开销相等。因此,与序列并行和张量并行相比,它并不会引入额外的通信开销。 + +4. 
需要注意的一点是,在张量并行的 “Column Linear” 中进行序列并行时,梯度的反向计算过程中需要获取完整的输入。在前向传播过程中,仅保留沿序列维度分割的输入部分,张量的形状例如$(batch, sequence\_len/k, hidden\_states)$。因此,需要进行额外的全局收集操作以获取完整的输入进行梯度计算。但是,在实现中,可以将梯度计算与全局收集通信操作重叠,这不会引入额外的通信开销(对应`Shardformer`中的`enable_sequence_overlap`参数)。 + + ## Shardformer的工作原理 通常来说,Shardformer通过以下四种“替换”进行工作: From 608cffaed3821bacdfce7c44cdf09e6cd38d32c2 Mon Sep 17 00:00:00 2001 From: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Date: Fri, 15 Sep 2023 17:12:46 +0800 Subject: [PATCH 16/58] [example] add gpt2 HybridParallelPlugin example (#4653) * add gpt2 HybridParallelPlugin example * update readme and testci * update test ci * fix test_ci bug * update requirements * add requirements * update requirements * add requirement * rename file --- examples/language/gpt/README.md | 10 + .../language/gpt/hybridparallelism/data.py | 127 ++++++++ .../gpt/hybridparallelism/finetune.py | 299 ++++++++++++++++++ .../language/gpt/hybridparallelism/run.sh | 5 + examples/language/gpt/requirements.txt | 5 + examples/language/gpt/test_ci.sh | 3 + 6 files changed, 449 insertions(+) create mode 100644 examples/language/gpt/hybridparallelism/data.py create mode 100644 examples/language/gpt/hybridparallelism/finetune.py create mode 100644 examples/language/gpt/hybridparallelism/run.sh diff --git a/examples/language/gpt/README.md b/examples/language/gpt/README.md index 47d24a4d69cb..03679e66404a 100644 --- a/examples/language/gpt/README.md +++ b/examples/language/gpt/README.md @@ -65,6 +65,16 @@ Titans provides a customized GPT model, which uses distributed operators as buil In [./titans/README.md], we provide a hybrid parallelism of ZeRO, TP and PP. You can switch parallel strategies using a config file. +### Hybridparallelism + +Hybridparallelism provides a user friendly plugin to set multiple parallelism method for training and inference. In [./hybridparallelism], we provide a n example to finetune gpt2 using Hybridparallelism. + +Quick run +```bash +cd ./hybridparallelism +bash run.sh +``` + ## Performance Testbed: a cluster of 8xA100 (80GB) and 1xAMD EPYC 7543 32-Core Processor (512 GB). GPUs are connected via PCI-e. 
diff --git a/examples/language/gpt/hybridparallelism/data.py b/examples/language/gpt/hybridparallelism/data.py new file mode 100644 index 000000000000..981cedcca8c2 --- /dev/null +++ b/examples/language/gpt/hybridparallelism/data.py @@ -0,0 +1,127 @@ +import datasets +from transformers import AutoTokenizer, PreTrainedTokenizer + +from colossalai.booster.plugin.dp_plugin_base import DPPluginBase + + +class GLUEDataBuilder: + + task_text_field_map = { + "cola": ["sentence"], + "sst2": ["sentence"], + "mrpc": ["sentence1", "sentence2"], + "qqp": ["question1", "question2"], + "stsb": ["sentence1", "sentence2"], + "mnli": ["premise", "hypothesis"], + "qnli": ["question", "sentence"], + "rte": ["sentence1", "sentence2"], + "wnli": ["sentence1", "sentence2"], + "ax": ["premise", "hypothesis"], + } + + glue_task_num_labels = { + "cola": 2, + "sst2": 2, + "mrpc": 2, + "qqp": 2, + "stsb": 1, + "mnli": 3, + "qnli": 2, + "rte": 2, + "wnli": 2, + "ax": 3, + } + + loader_columns = [ + "datasets_idx", + "input_ids", + "token_type_ids", + "attention_mask", + "start_positions", + "end_positions", + "labels", + ] + + def __init__( + self, + model_name_or_path: str, + plugin: DPPluginBase, + task_name: str = "mrpc", + max_seq_length: int = 128, + train_batch_size: int = 32, + eval_batch_size: int = 32, + **kwargs, + ): + super().__init__() + self.model_name_or_path = model_name_or_path + self.task_name = task_name + self.max_seq_length = max_seq_length + self.train_batch_size = train_batch_size + self.eval_batch_size = eval_batch_size + self.plugin = plugin + + self.text_fields = self.task_text_field_map[task_name] + self.num_labels = self.glue_task_num_labels[task_name] + self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) + self.setup() + + def setup(self): + self.dataset = datasets.load_dataset("glue", self.task_name) + + for split in self.dataset.keys(): + self.dataset[split] = self.dataset[split].map( + self.convert_to_features, + batched=True, + remove_columns=["label"], + ) + self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns] + self.dataset[split].set_format(type="torch", columns=self.columns) + + self.eval_splits = [x for x in self.dataset.keys() if "validation" in x] + + def prepare_data(self): + datasets.load_dataset("glue", self.task_name) + AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) + + def train_dataloader(self): + return self.plugin.prepare_dataloader(self.dataset["train"], + batch_size=self.train_batch_size, + shuffle=True, + drop_last=True) + + def val_dataloader(self): + if len(self.eval_splits) == 1: + return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size) + elif len(self.eval_splits) > 1: + return [ + self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size) + for x in self.eval_splits + ] + + def test_dataloader(self): + if len(self.eval_splits) == 1: + return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size) + elif len(self.eval_splits) > 1: + return [ + self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size) + for x in self.eval_splits + ] + + def convert_to_features(self, example_batch): + + # Either encode single sentence or sentence pairs + if len(self.text_fields) > 1: + texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]])) + else: + texts_or_text_pairs = example_batch[self.text_fields[0]] 
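+        # At this point `texts_or_text_pairs` holds either single sentences or (sentence1, sentence2) pairs, depending on the task.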
+ + # Tokenize the text/text pairs + features = self.tokenizer.batch_encode_plus(texts_or_text_pairs, + max_length=self.max_seq_length, + padding='max_length', + truncation=True) + + # Rename label to labels to make it easier to pass to model forward + features["labels"] = example_batch["label"] + + return features diff --git a/examples/language/gpt/hybridparallelism/finetune.py b/examples/language/gpt/hybridparallelism/finetune.py new file mode 100644 index 000000000000..03e5ec91b3fe --- /dev/null +++ b/examples/language/gpt/hybridparallelism/finetune.py @@ -0,0 +1,299 @@ +import argparse +from contextlib import nullcontext +from typing import Callable, List, Union + +import evaluate +import torch +import torch.distributed as dist +import torch.nn as nn +from data import GLUEDataBuilder +from torch.optim import Adam, Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import AutoConfig, GPT2ForSequenceClassification, get_linear_schedule_with_warmup + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.cluster import DistCoordinator +from colossalai.lazy import LazyInitContext +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device + +# ============================== +# Prepare Hyperparameters +# ============================== +NUM_EPOCHS = 3 +BATCH_SIZE = 32 +LEARNING_RATE = 2.4e-5 +WEIGHT_DECAY = 0.01 +WARMUP_FRACTION = 0.1 + +output_transform_fn = lambda x: x +criterion = lambda x: x.loss + + +def move_to_cuda(batch): + return {k: v.cuda() for k, v in batch.items()} + + +@torch.no_grad() +def evaluate_model( + model: nn.Module, + criterion, + test_dataloader: Union[DataLoader, List[DataLoader]], + num_labels: int, + task_name: str, + eval_splits: List[str], + booster: Booster, + coordinator: DistCoordinator, +): + metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size) + model.eval() + + def evaluate_subset(dataloader: DataLoader): + use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 + is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() + + accum_loss = torch.zeros(1, device=get_current_device()) + for batch in dataloader: + batch = move_to_cuda(batch) + labels = batch["labels"] + if use_pipeline: + pg_mesh = booster.plugin.pg_mesh + pp_group = booster.plugin.pp_group + current_pp_group_ranks = pg_mesh.get_ranks_in_group(pp_group) + current_rank = dist.get_rank() + batch = iter([batch]) + outputs = booster.execute_pipeline(batch, model, criterion, return_loss=True, return_outputs=True) + + if is_pp_last_stage: + logits = outputs["outputs"]["logits"] + val_loss = outputs["loss"] + accum_loss.add_(val_loss) + + if num_labels > 1: + preds = torch.argmax(logits, axis=1) + elif num_labels == 1: + preds = logits.squeeze() + + dist.broadcast_object_list([preds, val_loss], src=current_pp_group_ranks[-1], group=pp_group) + + metric.add_batch(predictions=preds, references=labels) + elif current_rank in current_pp_group_ranks: + object_list = [None, None] + dist.broadcast_object_list(object_list, src=current_pp_group_ranks[-1], group=pp_group) + + metric.add_batch(predictions=object_list[0].to(get_current_device()), references=labels) + accum_loss.add_(object_list[1].to(get_current_device())) + + 
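+            # For non-pipeline plugins, a single forward pass returns the loss and logits for this batch directly.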
else: + batch = move_to_cuda(batch) + outputs = model(**batch) + val_loss, logits = outputs[:2] + accum_loss.add_(val_loss) + + if num_labels > 1: + preds = torch.argmax(logits, axis=1) + elif num_labels == 1: + preds = logits.squeeze() + + metric.add_batch(predictions=preds, references=labels) + + results = metric.compute() + dist.all_reduce(accum_loss.div_(len(dataloader))) + if coordinator.is_master() and results is not None: + results['loss'] = accum_loss.item() / coordinator.world_size + + return results + + if isinstance(test_dataloader, DataLoader): + return evaluate_subset(test_dataloader) + else: + assert len(test_dataloader) == len(eval_splits) + final_results = {} + for split, sub_loader in zip(eval_splits, test_dataloader): + results = evaluate_subset(sub_loader) + final_results.update({f'{k}_{split}': v for k, v in results.items()}) + return final_results + + +def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: Callable, lr_scheduler: LRScheduler, + train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator): + + use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 + is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() + total_step = len(train_dataloader) + + model.train() + optimizer.zero_grad() + train_dataloader_iter = iter(train_dataloader) + with tqdm(range(total_step), + desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', + disable=not (coordinator.is_master() or is_pp_last_stage)) as pbar: + # Forward pass + for _ in pbar: + if use_pipeline: + outputs = booster.execute_pipeline(train_dataloader_iter, + model, + _criterion, + optimizer, + return_loss=True, + return_outputs=True) + # Backward and optimize + if is_pp_last_stage: + loss = outputs['loss'] + pbar.set_postfix({'loss': loss.item()}) + else: + data = next(train_dataloader_iter) + data = move_to_cuda(data) + outputs = model(**data) + loss = _criterion(outputs, None) + # Backward + booster.backward(loss, optimizer) + pbar.set_postfix({'loss': loss.item()}) + + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + +def main(): + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + parser.add_argument('-t', '--task', default='mrpc', help="GLUE task to run") + parser.add_argument('-p', + '--plugin', + type=str, + default='torch_ddp', + choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero', 'hybrid_parallel'], + help="plugin to use") + parser.add_argument( + "--model_type", + type=str, + default="gpt2", + help="only gpt2 now", + ) + parser.add_argument('--target_f1', type=float, default=None, help="target f1 score. 
Raise exception if not reached") + parser.add_argument('--use_lazy_init', type=bool, default=False, help="for initiating lazy init context") + args = parser.parse_args() + + if args.model_type == 'gpt2': + model_name = "gpt2" + else: + raise RuntimeError + # ============================== + # Launch Distributed Environment + # ============================== + colossalai.launch_from_torch(config={}, seed=42) + coordinator = DistCoordinator() + + # local_batch_size = BATCH_SIZE // coordinator.world_size + lr = LEARNING_RATE * coordinator.world_size + + # ============================== + # Instantiate Plugin and Booster + # ============================== + booster_kwargs = {} + if args.plugin == 'torch_ddp_fp16': + booster_kwargs['mixed_precision'] = 'fp16' + if args.plugin.startswith('torch_ddp'): + plugin = TorchDDPPlugin() + elif args.plugin == 'gemini': + plugin = GeminiPlugin(initial_scale=2**5) + elif args.plugin == 'low_level_zero': + plugin = LowLevelZeroPlugin(initial_scale=2**5) + elif args.plugin == 'hybrid_parallel': + + # modify the param accordingly for finetuning test cases + plugin = HybridParallelPlugin(tp_size=1, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_all_optimization=True, + zero_stage=1, + precision='fp16', + initial_scale=1) + + booster = Booster(plugin=plugin, **booster_kwargs) + + # ============================== + # Prepare Dataloader + # ============================== + data_builder = GLUEDataBuilder(model_name, + plugin, + args.task, + train_batch_size=BATCH_SIZE, + eval_batch_size=BATCH_SIZE) + train_dataloader = data_builder.train_dataloader() + test_dataloader = data_builder.test_dataloader() + + # ==================================== + # Prepare model, optimizer + # ==================================== + # gpt2 pretrained model + + cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels) + + if model_name == "gpt2": + model = GPT2ForSequenceClassification.from_pretrained(model_name, config=cfg).cuda() + else: + raise RuntimeError + + # optimizer + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": WEIGHT_DECAY, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8) + + # lr scheduler + total_steps = len(train_dataloader) * NUM_EPOCHS + num_warmup_steps = int(WARMUP_FRACTION * total_steps) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=total_steps, + ) + + def _criterion(outputs, inputs): + outputs = output_transform_fn(outputs) + loss = criterion(outputs) + return loss + + # ============================== + # Boost with ColossalAI + # ============================== + model, optimizer, _criterion, _, lr_scheduler = booster.boost(model, + optimizer, + criterion=_criterion, + lr_scheduler=lr_scheduler) + + # ============================== + # Train model + # ============================== + for epoch in range(NUM_EPOCHS): + train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator) + + results = evaluate_model(model, _criterion, test_dataloader, data_builder.num_labels, args.task, + data_builder.eval_splits, booster, coordinator) + + if coordinator.is_master(): + print(results) + if args.target_f1 is not None and 
'f1' in results: + assert results['f1'] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}' + + +if __name__ == '__main__': + main() diff --git a/examples/language/gpt/hybridparallelism/run.sh b/examples/language/gpt/hybridparallelism/run.sh new file mode 100644 index 000000000000..679cbbf9b1e2 --- /dev/null +++ b/examples/language/gpt/hybridparallelism/run.sh @@ -0,0 +1,5 @@ +# load via internet +torchrun --standalone --nproc_per_node 4 --master_port 29800 finetune.py --target_f1 0.6 --plugin hybrid_parallel --model_type "gpt2" + +# load from local +# torchrun --standalone --nproc_per_node 4 --master_port 29800 finetune.py --target_f1 0.6 --plugin hybrid_parallel --model_type "gpt2" --pretrained_path "your/path/to/pretrained_model" diff --git a/examples/language/gpt/requirements.txt b/examples/language/gpt/requirements.txt index ef58bb76bfc8..1a173f228aee 100644 --- a/examples/language/gpt/requirements.txt +++ b/examples/language/gpt/requirements.txt @@ -1,2 +1,7 @@ transformers >= 4.23 colossalai +evaluate +tqdm +scipy +scikit-learn +numpy diff --git a/examples/language/gpt/test_ci.sh b/examples/language/gpt/test_ci.sh index d67c17229e71..b9e4e43a8d35 100644 --- a/examples/language/gpt/test_ci.sh +++ b/examples/language/gpt/test_ci.sh @@ -1,2 +1,5 @@ set -x +pip install -r requirements.txt + cd gemini && bash test_ci.sh +cd ../hybridparallelism && bash run.sh From 451c3465fbde69695270bfd8f7ad26bebc079432 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Fri, 15 Sep 2023 17:39:10 +0800 Subject: [PATCH 17/58] [doc] polish shardformer doc (#4735) * arrange position of chapters * fix typos in seq parallel doc --- docs/source/en/features/shardformer.md | 167 ++++++++++---------- docs/source/zh-Hans/features/shardformer.md | 141 ++++++++--------- 2 files changed, 153 insertions(+), 155 deletions(-) diff --git a/docs/source/en/features/shardformer.md b/docs/source/en/features/shardformer.md index ca23f07421d1..4abfff8a3cfa 100644 --- a/docs/source/en/features/shardformer.md +++ b/docs/source/en/features/shardformer.md @@ -29,90 +29,6 @@ This module aims to make parallelization hassle-free for users who are not from Within a few lines of codes, users can turn a model into a state ready for distributed training. Also, Shardformer contains various optimization tools for acceleration and memory saving during forward/backward pass. -## Usage - -### Shardformer Configuration - -The configuration of Shardformer is controlled by class `ShardConfig`: - -{{ autodoc:colossalai.shardformer.ShardConfig }} - -If you want to enable Apex Fused Layernorm, please install `apex`. -If you want to enable the usage of flash attention, please install `flash_attn`. -In addition, xFormers's `cutlass_op` can serve as a backup for flash attention. - -### Enabling Shardformer - -#### 1. Enabling Shardformer Through Booster (Recommended) - -Enabling `Shardformer` through `Booster` initialized with `HybridParallelPlugin` is the recommended way to awaken the power of Shardformer. -The main reason is that pipeline parallelism cannot successfully work without the calling of `execute_pipeline` method of `Booster`. Besides, `HybridParallelPlugin` provides the capacity to combine the features of `Shardformer` with other useful features, such as mixed precision training or Zero. - -More details about this usage can be found in chapter [Booster API](../basics/booster_api.md) and [Booster Plugins](../basics/booster_plugins.md). 
- -[Here](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert) is an example on how to trigger `Shardformer` through `HybridParallelPlugin`. Please be aware that there's a difference in the way of doing forward and backward between the situation of using pipeline and not using pipeline. - - -#### 2. Enabling Shardformer Through Shardformer APIs (Not Recommended) - -You can also use Shardformer through manually calling Shardformer APIs. However, this usage is not recommended since pipeline parallelism can't run without `Booster`. - -[Here](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/examples/convergence_benchmark.py) -is an example on how to trigger `Shardformer` through calling Shardformer APIs. - - -### Precautions - -1. When enabling pipeline parallel, please don't do the forward/backward pass in the conventional way (`model(input)`, `loss.backward()`), which will cause unexpected errors. Rather, please do forward/backward pass through calling `booster.execute_pipeline` method. - -2. When you use Shardformer to process classification models such as `GPT2ForSequenceClassification`, `ViTForImageClassification`, please ensure that the total number of labels should be integer multiple of tensor parallel size, otherwise Shardformer can't process the classifier layer correctly. A simple fix could be appending dummy labels in transformers config. This bug will be fixed in future version of Shardformer. - -3. The case of training ChatGLM-2 6B is a little special: since Huggingface transformers doesn't officially support ChatGLM at present, please import the configuration/model classes through - ```python - from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig - from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel - ``` - when training ChatGLM-2 with Shardformer, and initialize your model with these imported classes. - -### Sequence Parallelism - -Sequence parallelism in `Shardformer` is a little different from [this one](https://colossalai.org/docs/basics/configure_parallelization/#sequence-parallel) which focuses on ring attention. In `Shardformer`, sequence parallelism just use with 1D tensor parallelism to to further reduce the memory occupation of activations in computations. - -1. In normal [1D tensor parallel](https://colossalai.org/docs/features/1D_tensor_parallel), there are 2 communication operations, $g$ and $\vec{g}$, $g$ will do one time All-Reduce in backward to get all gradient from all the devices and $\vec{g}$ will do one time All-Reduce in forward to get whole outputs from all the device. - -2. When using sequence parallelism, $\vec{g}$ needs to do All-Gather to gather the inputs in sequence dimension during forward and Reduce-Scatter to splite the gradient during backward. $\vec{g}$ needs to do Reduce-Scatter to splite the output of row linear layer of tensor parallel to all devices in sequence dimension, and All-Gather to get the whole gradient during backward. - -3. The implementation of All-Reduce using NCCL adopts the `Ring All-Reduce` approach, which consists of a Reduce-Scatter operation and an All-Gather operation with equal costs. Therefore, compared to sequence parallelism and tensor parallelism, it does not introduce additional communication overhead. - -4. 
One important thing to note is that when using sequence parallelism with 'Column Linear' of tensor parallelism,, during the backward computation of gradients, the complete input needs to be obtained. During the forward pass, only the portion of the input that is split along the sequence dimension is retained, shape like $(batch, sequence_len/k, hidden_states)$. Therefore, an additional All-Gather operation is required to obtain the complete input for gradient computation. However, in the implementation, it is possible to overlap the gradient computation with the All-Gather communication operation, which would not introduce additional communication overhead (corresponding to the `enable_sequence_overlap` parameter in `Shardformer`). - -## How Shardformer Works - -Generally, Shardformer works through the following four kinds of *replacements*: - -1. Replacing original PyTorch module (e.g. `nn.Linear`, `nn.Embedding`) with a crafted distributed module. -The distributed module keeps the same attributes as the original module but replaces the original parameters with distributed parameters. -Also, new `forward` methods will replace original ones so as to execute distributed computation, such as linear layers' split /gather operations executed under tensor parallelism. -Each distributed module implements its `from_native_module` static method to convert the PyTorch module to its corresponding distributed module. - -2. Replacing attributes of original Huggingface Transformers layers with appropriate attributes for distributed training. -For example, when training LlaMa-2 with tensor parallel size as 2, the attribute `num_heads` of `LlamaDecoderLayer` (the number of attention heads in each layer) should be replaced with `model.config.num_attention_heads // 2`. - -3. Replacing the `forward` methods implemented by original Huggingface -Transformers libraries with our customized `forward` methods. -This replacement is essential for pipeline paralellism, where a customiozed function is needed to pass intermediate hidden states between different pipeline stages. -Also, optimization methods such as flash attention or sequence parallel can be injected into the `forward` process through our customized `forward` method. - -4. Replacing the whole copy of model parameters and optimizer states with incomplete ones controlled by current device (this is why it's called Shardformer). -By executing `ModelSharder.shard` method, current device will only keep the part of model parameters it's supposed to take care of. -To be specific, they should be the assigned parameter shards when using tensor parallelism, or the parameters belonging to current pipeline stage when using pipeline parallelism, or both of them. -All other parameters are released so as to liberate memory usage. -As a result, the optimizer will only compute the states corresponding to these part of parameters, causing the usage of memory to be further saved. - -All of these replacements are implemented with manually written policies and forward functions. -If you want to delve deeper into the design of Shardformer or customize your own Shardformer policies, please refer to our [Shardformer development document](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/README.md) and [pipeline parallelism design](https://github.com/hpcaitech/ColossalAI/discussions/4050) for more details. 
- - ## Supporting Information Model/Feature Compatibility Matrix: @@ -279,4 +195,87 @@ List of model families we plan to support in the near future: The support matrix will grow larger as more models and optimization tools emerge in the future. If you have any suggestions on the models/optimization we should support, please feel free to mention it in [Issues](https://github.com/hpcaitech/ColossalAI/issues) section of our project. +## Usage + +### Shardformer Configuration + +The configuration of Shardformer is controlled by class `ShardConfig`: + +{{ autodoc:colossalai.shardformer.ShardConfig }} + +If you want to enable Apex Fused Layernorm, please install `apex`. +If you want to enable the usage of flash attention, please install `flash_attn`. +In addition, xFormers's `cutlass_op` can serve as a backup for flash attention. + +### Enabling Shardformer + +#### 1. Enabling Shardformer Through Booster (Recommended) + +Enabling `Shardformer` through `Booster` initialized with `HybridParallelPlugin` is the recommended way to awaken the power of Shardformer. +The main reason is that pipeline parallelism cannot successfully work without the calling of `execute_pipeline` method of `Booster`. Besides, `HybridParallelPlugin` provides the capacity to combine the features of `Shardformer` with other useful features, such as mixed precision training or Zero. + +More details about this usage can be found in chapter [Booster API](../basics/booster_api.md) and [Booster Plugins](../basics/booster_plugins.md). + +[Here](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert) is an example on how to trigger `Shardformer` through `HybridParallelPlugin`. Please be aware that there's a difference in the way of doing forward and backward between the situation of using pipeline and not using pipeline. + + +#### 2. Enabling Shardformer Through Shardformer APIs (Not Recommended) + +You can also use Shardformer through manually calling Shardformer APIs. However, this usage is not recommended since pipeline parallelism can't run without `Booster`. + +[Here](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/examples/convergence_benchmark.py) +is an example on how to trigger `Shardformer` through calling Shardformer APIs. + +### Precautions + +1. When enabling pipeline parallel, please don't do the forward/backward pass in the conventional way (`model(input)`, `loss.backward()`), which will cause unexpected errors. Rather, please do forward/backward pass through calling `booster.execute_pipeline` method. + +2. When you use Shardformer to process classification models such as `GPT2ForSequenceClassification`, `ViTForImageClassification`, please ensure that the total number of labels should be integer multiple of tensor parallel size, otherwise Shardformer can't process the classifier layer correctly. A simple fix could be appending dummy labels in transformers config. This bug will be fixed in future version of Shardformer. + +3. The case of training ChatGLM-2 6B is a little special: since Huggingface transformers doesn't officially support ChatGLM at present, please import the configuration/model classes through + ```python + from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig + from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel + ``` + when training ChatGLM-2 with Shardformer, and initialize your model with these imported classes. 
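+
+As a minimal sketch of the first precaution above (assuming `model`, `optimizer` and `_criterion` have already been boosted by a `Booster` configured with `HybridParallelPlugin`, and that `train_dataloader_iter` is an iterator over the training dataloader; the exact arguments may vary with your setup):
+
+```python
+# Conventional forward/backward pass -- only valid when pipeline parallelism is NOT enabled.
+outputs = model(**batch)
+loss = _criterion(outputs, None)
+booster.backward(loss, optimizer)
+
+# With pipeline parallelism enabled, let the booster schedule the micro-batches instead.
+outputs = booster.execute_pipeline(train_dataloader_iter, model, _criterion, optimizer,
+                                   return_loss=True, return_outputs=True)
+loss = outputs['loss']  # only meaningful on the last pipeline stage
+```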
+
+## How Shardformer Works
+
+Generally, Shardformer works through the following four kinds of *replacements*:
+
+1. Replacing the original PyTorch module (e.g. `nn.Linear`, `nn.Embedding`) with a crafted distributed module.
+The distributed module keeps the same attributes as the original module but replaces the original parameters with distributed parameters.
+Also, new `forward` methods will replace the original ones so as to execute distributed computation, such as the linear layers' split/gather operations executed under tensor parallelism.
+Each distributed module implements its `from_native_module` static method to convert the PyTorch module to its corresponding distributed module.
+
+2. Replacing attributes of the original Huggingface Transformers layers with appropriate attributes for distributed training.
+For example, when training LlaMa-2 with a tensor parallel size of 2, the attribute `num_heads` of `LlamaDecoderLayer` (the number of attention heads in each layer) should be replaced with `model.config.num_attention_heads // 2`.
+
+3. Replacing the `forward` methods implemented by the original Huggingface
+Transformers library with our customized `forward` methods.
+This replacement is essential for pipeline parallelism, where a customized function is needed to pass intermediate hidden states between different pipeline stages.
+Also, optimization methods such as flash attention or sequence parallelism can be injected into the `forward` process through our customized `forward` method.
+
+4. Replacing the whole copy of model parameters and optimizer states with incomplete ones controlled by the current device (this is why it's called Shardformer).
+By executing the `ModelSharder.shard` method, the current device will only keep the part of the model parameters it is supposed to take care of.
+To be specific, they should be the assigned parameter shards when using tensor parallelism, or the parameters belonging to the current pipeline stage when using pipeline parallelism, or both of them.
+All other parameters are released to free up memory.
+As a result, the optimizer will only compute the states corresponding to this part of the parameters, which further reduces memory usage.
+
+All of these replacements are implemented with manually written policies and forward functions.
+If you want to delve deeper into the design of Shardformer or customize your own Shardformer policies, please refer to our [Shardformer development document](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/README.md) and [pipeline parallelism design](https://github.com/hpcaitech/ColossalAI/discussions/4050) for more details.
+
+### Sequence Parallelism
+
+Sequence parallelism is a special optimization method supported by `Shardformer`. Sequence parallelism in `Shardformer` is a little different from [this one](https://colossalai.org/docs/basics/configure_parallelization/#sequence-parallel), which focuses on ring attention. In `Shardformer`, sequence parallelism is only used along with 1D tensor parallelism to further reduce the memory occupation of activation tensors during computation.
+
+1. In normal [1D tensor parallel](https://colossalai.org/docs/features/1D_tensor_parallel), there are 2 communication operations, $g$ and $\vec{g}$: $g$ performs one All-Reduce in the backward pass to collect the gradients from all the devices, and $\vec{g}$ performs one All-Reduce in the forward pass to collect the whole outputs from all the devices.
+
+2. 
When using sequence parallelism, $\vec{g}$ needs to do All-Gather to gather the inputs along sequence dimension during forward, and Reduce-Scatter to split the gradient during backward. $\vec{g}$ needs to do Reduce-Scatter to split the output of `Row Linear` layer of tensor parallel to all devices along sequence dimension, and All-Gather to get the whole gradient during backward. + +3. NCCL's implementation of All-Reduce adopts the `Ring All-Reduce` approach, which consists of a Reduce-Scatter operation and an All-Gather operation with equal costs. Therefore, compared with sequence parallelism and tensor parallelism, it does not introduce additional communication overhead. + +4. One important thing to note is that when using sequence parallelism along with `Column Linear` module of tensor parallelism, the complete input needs to be obtained during the backward computation of gradients. During the forward pass, only the portion of the input that is split along the sequence dimension is retained, in the shape of $(batch, sequence_len/k, hidden_states)$. Therefore, an additional All-Gather operation is required to obtain the complete input for gradient computation. However, it is possible to overlap the gradient computation with the All-Gather communication operation in our implementation, which would not introduce additional communication overhead (corresponding to the `enable_sequence_overlap` parameter in `Shardformer`). + + diff --git a/docs/source/zh-Hans/features/shardformer.md b/docs/source/zh-Hans/features/shardformer.md index 7de0c41c10d7..fe0e7a63ba44 100644 --- a/docs/source/zh-Hans/features/shardformer.md +++ b/docs/source/zh-Hans/features/shardformer.md @@ -25,77 +25,6 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github. 出于这种动机,ColossalAI团队开发了**Shardformer**,该功能可以自动为HuggingFace中主流的Transformer模型进行封装,用于张量并行以及流水线并行的训练策略。如此一来,对系统了解不多的用户也可以轻松地在transformers模型上进行并行训练:只需几行代码,用户就可以将模型转变为并行训练的状态。此外,Shardformer也包括了多种优化工具,用于在前向/后向的传递过程中实现加速和节省内存。 -## 用法 - -### Shardformer的参数配置 - -Shardformer的配置由类`ShardConfig`的参数控制: - -{{ autodoc:colossalai.shardformer.ShardConfig }} - -如果您想启用 Apex Fused Layernorm,请安装 `apex`。如果您想启用 flash attention,请安装 `flash_attn`。此外,xFormers 的 `cutlass_op` 可以作为Flash Attention的补充优化方式。 - -### 启动Shardformer - -#### 1. 通过Booster启动Shardformer (推荐) - -通过用`HybridParallelPlugin`初始化的`Booster`来启动`Shardformer`是最推荐的用法。其主要原因是如果不调用`Booster`的`execute_pipeline`方法,流水线并行就无法正常工作。此外,`HybridParallelPlugin`提供了将`Shardformer`的功能与其他功能(例如混合精度训练或Zero)相结合的能力。 - -更多关于这一用法的细节可以参考 [Booster API 文档](../basics/booster_api.md)以及[Booster 插件文档](../basics/booster_plugins.md)。[这里](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert)是一个通过`HybridParallelPlugin`启动`Shardformer`的示例。 - - -#### 2. 通过Shardformer API启动Shardformer (不推荐) - -您还可以通过手动调用Shardformer API的方式启动Shardformer。然而我们并不推荐这种用法,因为流水线并行在没有`Booster`的情况下无法正常运行。 - -[这里](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/examples/convergence_benchmark.py) -是一个通过调用Shardformer的API启动`Shardformer`的示例。 - - -### 注意事项 - -1. 当启用流水线并行时,请不要用常规方式(`model(input)`、`loss.backward()`)进行前向/后向传递,这样会导致未知的错误。这种情形下请通过调用`booster.execute_pipeline`方法来进行前向/后向传递。 - -2. 当使用Shardformer处理`GPT2ForSequenceClassification`、`ViTForImageClassification`等分类模型时,请确保labels的总数为张量并行度的整数倍,否则Shardformer无法正确地处理classifier层。一个简单的修复方法就是在transformers的config中添加虚拟的标签。这一bug将在 Shardformer的未来版本中修复。 - -3. 
训练ChatGLM-2 6B的情况有点特殊:由于Huggingface Transformers 目前尚未正式支持ChatGLM。在使用Shardformer训练ChatGLM-2时,请通过以下方式导入config/model的类: - ```python - from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig - from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel - ``` - 并且使用这些导入的类初始化模型。 - - -### 序列并行 Sequence Parallelism - -在`Shardformer`中,序列并行与[此处](https://colossalai.org/docs/basics/configure_parallelization/#sequence-parallel)稍有不同,后者侧重于ring attention。在`Shardformer`中,序列并行仅与1D张量并行一起使用,以进一步减少计算中activation的内存占用。 - -1. 在普通的[1D张量并行](https://colossalai.org/docs/features/1D_tensor_parallel)中,有两个通信操作$g$和$\vec{g}$,$g$在反向传播中进行一次全局归约以获取来自所有设备的梯度,而$\vec{g}$在正向传播中进行一次All-Reduce以获取来自所有设备的输出。 - -2. 当使用序列并行时,$\vec{g}$需要在正向传播过程中进行All-Gather以获取序列维度上的输入,并在反向传播过程中进行Reduce-Scatter以分割梯度。$\vec{g}$需要进行Reduce-Scatter以将序列维度上的行线性层输出分割到所有设备上,并进行All-Gather以获取完整的梯度。 - -3. 使用NCCL的All-reduce实现采用了`Ring All-Reduce`方法,由一次Reduce-Scatter和一次All-Gather组成,两者的开销相等。因此,与序列并行和张量并行相比,它并不会引入额外的通信开销。 - -4. 需要注意的一点是,在张量并行的 “Column Linear” 中进行序列并行时,梯度的反向计算过程中需要获取完整的输入。在前向传播过程中,仅保留沿序列维度分割的输入部分,张量的形状例如$(batch, sequence\_len/k, hidden\_states)$。因此,需要进行额外的全局收集操作以获取完整的输入进行梯度计算。但是,在实现中,可以将梯度计算与全局收集通信操作重叠,这不会引入额外的通信开销(对应`Shardformer`中的`enable_sequence_overlap`参数)。 - - -## Shardformer的工作原理 - -通常来说,Shardformer通过以下四种“替换”进行工作: - -1. 用我们设计的分布式模块替换原始的PyTorch模块(例如`nn.Linear`、`nn.Embedding`)。 -分布式模块保持与原始模块相同的属性,但分布式模块会用新的参数替换原始模块的参数。新的前向函数将取代原来的前向函数,用于执行分布式计算,例如在张量并行下执行线性层的split/gather操作。每个分布式模块都应当实现其`from_native_module`静态方法,以将PyTorch模块转换为其相应的分布式模块。 - -2. 将原始Huggingface Transformers中间层的属性为适用于并行训练的属性。例如,当使用并行度为2的张量并行训练LlaMa-2时,`LlamaDecoderLayer` 的属性`num_heads`(每一层注意力头的数量)应替换为`model.config.num_attention_heads // 2`。 - -3. 将原来Huggingface transformers库实现的前向函数替换为我们定制的前向函数。前向函数的替换对于流水线并行性至关重要,因为流水线并行需要特殊的前向函数去在不同的流水线阶段之间传递中间的隐藏状态。此外,可以通过我们定制的前向函数将例如`flash attention`或序列并行的优化方法注入到前向的过程中。 - -4. 将完整的模型参数和优化器状态替换为只由当前设备控制的部分模型参数和优化器状态。通过执行`ModelSharder.shard`方法,当前设备仅会保留它应该处理的那部分模型参数。具体来说,这部分参数可以是使用张量并行时分配到当前机器的参数分片,或者使用流水线并行时当前流水线阶段的模型参数,或者兼而有之。除此之外的所有其他参数都被释放,用于节省内存的空间。 -如此一来,优化器只会计算保留的部分参数对应的状态,从而进一步节省内存的使用。 - -所有这些替换都是通过手动编写的策略和前向函数来实现的。如果您想更深入地研究Shardformer的设计方案,或者定制您自己的Shardformer策略,请参考[Shardformer 开发者文档](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/README.md)和[流水并行设计方案](https://github.com/hpcaitech/ColossalAI/discussions/4050)以获得更多细节。 - - ## 支持信息 模型/功能 兼容性矩阵: @@ -262,4 +191,74 @@ Shardformer的配置由类`ShardConfig`的参数控制: 随着未来更多模型和优化工具的出现,我们支持的模型/优化工具将会变得越来越多。如果您对我们应该支持的模型/优化工具有任何建议,欢迎在项目的[Issues](https://github.com/hpcaitech/ColossalAI/issues)板块参与讨论。 +## 用法 + +### Shardformer的参数配置 + +Shardformer的配置由类`ShardConfig`的参数控制: + +{{ autodoc:colossalai.shardformer.ShardConfig }} + +如果您想启用 Apex Fused Layernorm,请安装 `apex`。如果您想启用 flash attention,请安装 `flash_attn`。此外,xFormers 的 `cutlass_op` 可以作为Flash Attention的补充优化方式。 + +### 启动Shardformer + +#### 1. 通过Booster启动Shardformer (推荐) + +通过用`HybridParallelPlugin`初始化的`Booster`来启动`Shardformer`是最推荐的用法。其主要原因是如果不调用`Booster`的`execute_pipeline`方法,流水线并行就无法正常工作。此外,`HybridParallelPlugin`提供了将`Shardformer`的功能与其他功能(例如混合精度训练或Zero)相结合的能力。 + +更多关于这一用法的细节可以参考 [Booster API 文档](../basics/booster_api.md)以及[Booster 插件文档](../basics/booster_plugins.md)。[这里](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert)是一个通过`HybridParallelPlugin`启动`Shardformer`的示例。 + + +#### 2. 
通过Shardformer API启动Shardformer (不推荐) + +您还可以通过手动调用Shardformer API的方式启动Shardformer。然而我们并不推荐这种用法,因为流水线并行在没有`Booster`的情况下无法正常运行。 + +[这里](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/examples/convergence_benchmark.py) +是一个通过调用Shardformer的API启动`Shardformer`的示例。 + + +### 注意事项 + +1. 当启用流水线并行时,请不要用常规方式(`model(input)`、`loss.backward()`)进行前向/后向传递,这样会导致未知的错误。这种情形下请通过调用`booster.execute_pipeline`方法来进行前向/后向传递。 + +2. 当使用Shardformer处理`GPT2ForSequenceClassification`、`ViTForImageClassification`等分类模型时,请确保labels的总数为张量并行度的整数倍,否则Shardformer无法正确地处理classifier层。一个简单的修复方法就是在transformers的config中添加虚拟的标签。这一bug将在 Shardformer的未来版本中修复。 + +3. 训练ChatGLM-2 6B的情况有点特殊:由于Huggingface Transformers 目前尚未正式支持ChatGLM。在使用Shardformer训练ChatGLM-2时,请通过以下方式导入config/model的类: + ```python + from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig + from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel + ``` + 并且使用这些导入的类初始化模型。 + + +## Shardformer的工作原理 + +通常来说,Shardformer通过以下四种“替换”进行工作: + +1. 用我们设计的分布式模块替换原始的PyTorch模块(例如`nn.Linear`、`nn.Embedding`)。 +分布式模块保持与原始模块相同的属性,但分布式模块会用新的参数替换原始模块的参数。新的前向函数将取代原来的前向函数,用于执行分布式计算,例如在张量并行下执行线性层的split/gather操作。每个分布式模块都应当实现其`from_native_module`静态方法,以将PyTorch模块转换为其相应的分布式模块。 + +2. 将原始Huggingface Transformers中间层的属性为适用于并行训练的属性。例如,当使用并行度为2的张量并行训练LlaMa-2时,`LlamaDecoderLayer` 的属性`num_heads`(每一层注意力头的数量)应替换为`model.config.num_attention_heads // 2`。 + +3. 将原来Huggingface transformers库实现的前向函数替换为我们定制的前向函数。前向函数的替换对于流水线并行性至关重要,因为流水线并行需要特殊的前向函数去在不同的流水线阶段之间传递中间的隐藏状态。此外,可以通过我们定制的前向函数将例如`flash attention`或序列并行的优化方法注入到前向的过程中。 + +4. 将完整的模型参数和优化器状态替换为只由当前设备控制的部分模型参数和优化器状态。通过执行`ModelSharder.shard`方法,当前设备仅会保留它应该处理的那部分模型参数。具体来说,这部分参数可以是使用张量并行时分配到当前机器的参数分片,或者使用流水线并行时当前流水线阶段的模型参数,或者兼而有之。除此之外的所有其他参数都被释放,用于节省内存的空间。 +如此一来,优化器只会计算保留的部分参数对应的状态,从而进一步节省内存的使用。 + +所有这些替换都是通过手动编写的策略和前向函数来实现的。如果您想更深入地研究Shardformer的设计方案,或者定制您自己的Shardformer策略,请参考[Shardformer 开发者文档](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/README.md)和[流水并行设计方案](https://github.com/hpcaitech/ColossalAI/discussions/4050)以获得更多细节。 + +### 序列并行 Sequence Parallelism + +序列并行是`Shardformer`支持的一种特殊的优化方法。在`Shardformer`中,序列并行与[此处](https://colossalai.org/docs/basics/configure_parallelization/#sequence-parallel)稍有不同,后者侧重于ring attention。在`Shardformer`中,序列并行仅与1D张量并行一起使用,以进一步减少计算中activation的内存占用。 + +1. 在普通的[1D张量并行](https://colossalai.org/docs/features/1D_tensor_parallel)中,有两个通信操作$g$和$\vec{g}$,$g$在反向传播中进行一次全局归约以获取来自所有设备的梯度,而$\vec{g}$在正向传播中进行一次All-Reduce以获取来自所有设备的输出。 + +2. 当使用序列并行时,$\vec{g}$需要在正向传播过程中进行All-Gather以获取序列维度上的输入,并在反向传播过程中进行Reduce-Scatter以分割梯度。$\vec{g}$需要进行Reduce-Scatter以将序列维度上的行线性层输出分割到所有设备上,并进行All-Gather以获取完整的梯度。 + +3. 使用NCCL的All-reduce实现采用了`Ring All-Reduce`方法,由一次Reduce-Scatter和一次All-Gather组成,两者的开销相等。因此,与序列并行和张量并行相比,它并不会引入额外的通信开销。 + +4. 
需要注意的一点是,在张量并行的 `Column Linear` 层中进行序列并行时,梯度的反向计算过程中需要获取完整的输入。在前向传播过程中,仅保留沿序列维度分割的输入部分,张量的形状例如$(batch, sequence\_len/k, hidden\_states)$。因此,需要进行额外的全局收集操作以获取完整的输入进行梯度计算。但是,在实现中,可以将梯度计算与全局收集通信操作重叠,这不会引入额外的通信开销(对应`Shardformer`中的`enable_sequence_overlap`参数)。 + + From ac2797996b362e5bded4d0eec18ef96efc12b086 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com> Date: Fri, 15 Sep 2023 17:53:13 +0800 Subject: [PATCH 18/58] [shardformer] add custom policy in hybrid parallel plugin (#4718) * add custom policy * update assert --- .../booster/plugin/hybrid_parallel_plugin.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 3fbeebcc4110..d15245523226 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -22,6 +22,7 @@ from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer.policies.base_policy import Policy from colossalai.zero.low_level import LowLevelZeroOptimizer from .pp_plugin_base import PipelinePluginBase @@ -38,13 +39,15 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16): class HybridParallelModule(ModelWrapper): def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool, - ddp_config: dict) -> None: + ddp_config: dict, custom_policy: Policy) -> None: self.stage_manager = shard_config.pipeline_stage_manager self.dp_group = dp_group shardformer = ShardFormer(shard_config) - module, self.shared_params = shardformer.optimize(module) + if custom_policy is not None: + assert isinstance(custom_policy, object) + module, self.shared_params = shardformer.optimize(module, policy=custom_policy) # setting process groups for shared parameters self.shared_param_process_groups = [] @@ -270,6 +273,7 @@ class HybridParallelPlugin(PipelinePluginBase): cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False. communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None. overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True. + custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None. 
""" def __init__(self, @@ -302,7 +306,8 @@ def __init__(self, zero_bucket_size_in_m: int = 12, cpu_offload: bool = False, communication_dtype: Optional[torch.dtype] = None, - overlap_communication: bool = True) -> None: + overlap_communication: bool = True, + custom_policy: Policy = None) -> None: super().__init__() assert dist.get_world_size() % ( @@ -326,6 +331,7 @@ def __init__(self, self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size) self.stage_manager = None self.schedule = None + self.custom_policy = custom_policy assert zero_stage in (0, 1, 2) if self.pp_size > 1: assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism' @@ -405,7 +411,7 @@ def configure( if not isinstance(model, ModelWrapper): use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp, - self.ddp_config) + self.ddp_config, self.custom_policy) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if self.zero_stage == 0: if self.precision in ['fp16', 'bf16']: From 4c4482f3adb56943a150b8b7ed886e2218fc98d5 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 15 Sep 2023 18:45:44 +0800 Subject: [PATCH 19/58] [example] llama2 add fine-tune example (#4673) * [shardformer] update shardformer readme [shardformer] update shardformer readme [shardformer] update shardformer readme * [shardformer] update llama2/opt finetune example and shardformer update to llama2 * [shardformer] update llama2/opt finetune example and shardformer update to llama2 * [shardformer] update llama2/opt finetune example and shardformer update to llama2 * [shardformer] change dataset * [shardformer] change dataset * [shardformer] fix CI * [shardformer] fix * [shardformer] fix * [shardformer] fix * [shardformer] fix * [shardformer] fix [example] update opt example [example] resolve comments fix fix * [example] llama2 add finetune example * [example] llama2 add finetune example * [example] llama2 add finetune example * [example] llama2 add finetune example * fix * update llama2 example * update llama2 example * fix * update llama2 example * update llama2 example * update llama2 example * update llama2 example * update llama2 example * update llama2 example * Update requirements.txt * update llama2 example * update llama2 example * update llama2 example --- .../hybrid_parallel_checkpoint_io.py | 4 +- examples/language/bert/finetune.py | 7 +- examples/language/llama2/README.md | 39 ++- examples/language/llama2/finetune.py | 295 ++++++++++++++++++ examples/language/llama2/pretrain.py | 79 +++-- examples/language/llama2/requirements.txt | 2 +- examples/language/opt/README.md | 7 +- examples/language/opt/requirements.txt | 4 +- 8 files changed, 402 insertions(+), 35 deletions(-) create mode 100644 examples/language/llama2/finetune.py diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 6eee3ace0308..270fd8564754 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -13,6 +13,7 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from colossalai.cluster import DistCoordinator from colossalai.interface import OptimizerWrapper from .general_checkpoint_io import GeneralCheckpointIO 
@@ -71,6 +72,7 @@ def __init__(self, self.verbose = verbose self.working_to_master_map = None self.master_to_working_map = None + self.coordinator = DistCoordinator() @staticmethod def _model_sharder(model: nn.Module, @@ -655,7 +657,7 @@ def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor, dist.all_gather(gather_tensor, v, group=tp_group) v = torch.cat(gather_tensor, dim=partition_dim) - state_[k] = v.detach().clone().cpu() + state_[k] = v.detach().clone().cpu() return state_ diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index 2e8780806f19..fb6e4332c2f9 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -129,14 +129,13 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() + print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage) total_step = len(train_dataloader) model.train() optimizer.zero_grad() train_dataloader_iter = iter(train_dataloader) - with tqdm(range(total_step), - desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', - disable=not (coordinator.is_master() or is_pp_last_stage)) as pbar: + with tqdm(range(total_step), desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not print_flag) as pbar: # Forward pass for _ in pbar: if use_pipeline: @@ -192,13 +191,13 @@ def main(): model_name = "albert-xxlarge-v2" else: raise RuntimeError + # ============================== # Launch Distributed Environment # ============================== colossalai.launch_from_torch(config={}, seed=42) coordinator = DistCoordinator() - # local_batch_size = BATCH_SIZE // coordinator.world_size lr = LEARNING_RATE * coordinator.world_size # ============================== diff --git a/examples/language/llama2/README.md b/examples/language/llama2/README.md index c8fc86d29d97..83ef99b57d42 100644 --- a/examples/language/llama2/README.md +++ b/examples/language/llama2/README.md @@ -92,7 +92,7 @@ Make sure master node can access all nodes (including itself) by ssh without pas Here is details about CLI arguments: - Model configuration: `-c`, `--config`. `7b`, `13b`, `30b` and `65b` are supported for LLaMA-1, `7b`, `13b`, and `70b` are supported for LLaMA-2. -- Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins). +- Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2`, `hybrid_parallel` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins). - Dataset path: `-d`, `--dataset`. The default dataset is `togethercomputer/RedPajama-Data-1T-Sample`. It support any dataset from `datasets` with the same data format as RedPajama. - Number of epochs: `-e`, `--num_epochs`. The default value is 1. - Local batch size: `-b`, `--batch_size`. Batch size per GPU. The default value is 2. @@ -195,3 +195,40 @@ If you run the above command successfully, you will get the following results: year={2023} } ``` + + +# Fine-tune Llama2 + +We also provide a example to fine-tune llama2 in `finetune.py`, + +Make sure master node can access all nodes (including itself) by ssh without password. 
+
+Here are the details about the CLI arguments:
+
+- Pretrained checkpoint path: `--model_path`. The path of your model checkpoint; it can be a local directory or a Hugging Face tag.
+- Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2`, `hybrid_parallel` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins).
+- Dataset path: `-d`, `--dataset`. The default dataset is `yizhongw/self_instruct`. It supports any dataset from `datasets` with the same data format as `yizhongw/self_instruct`.
+- Task name: `--task_name`. The task to fine-tune on; it also determines which subset of the dataset is loaded. The default value is `super_natural_instructions`.
+- Number of epochs: `-e`, `--num_epochs`. The default value is 1.
+- Local batch size: `-b`, `--batch_size`. Batch size per GPU. The default value is 2.
+- Learning rate: `--lr`. The default value is 3e-4.
+- Weight decay: `-w`, `--weight_decay`. The default value is 0.1.
+- Gradient checkpointing: `-g`, `--grad_checkpoint`. The default value is `False`. This saves memory at the cost of speed. It is recommended to enable this option when training with a large batch size.
+- Max length: `-l`, `--max_length`. The default value is 4096.
+- Mixed precision: `-x`, `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported.
+- Save interval: `-i`, `--save_interval`. The interval (in steps) at which checkpoints are saved. The default value is 1000.
+- Checkpoint directory: `-o`, `--save_dir`. The directory path to save checkpoints. The default value is `checkpoint`.
+- Checkpoint to load: `-f`, `--load`. The checkpoint path to load. The default value is `None`.
+- Gradient clipping: `--grad_clip`. The default value is 1.0.
+- Tensorboard log directory: `-t`, `--tensorboard_dir`. The directory path to save tensorboard logs. The default value is `tb_logs`.
+- Flash attention: `-a`, `--flash_attention`. If you want to use flash attention, you must install `flash-attn`. The default value is `False`. This helps accelerate training while saving memory. We recommend you always use flash attention.
+ + +```shell +torchrun --standalone --nproc_per_node 8 finetune.py \ + --plugin "hybrid_parallel" \ + --dataset "yizhongw/self_instruct" \ + --model_path "/path/llama" \ + --task_name "super_natural_instructions" \ + --save_dir "/path/output" +``` diff --git a/examples/language/llama2/finetune.py b/examples/language/llama2/finetune.py new file mode 100644 index 000000000000..0efbf193c9a9 --- /dev/null +++ b/examples/language/llama2/finetune.py @@ -0,0 +1,295 @@ +import argparse +import math +import os +import resource +from contextlib import nullcontext +from functools import partial +from typing import Optional, Tuple + +import torch +import torch.distributed as dist +import torch.nn as nn +from attn import SUPPORT_XFORMERS, replace_xformers +from data_utils import load_json, prepare_dataloader, save_json +from datasets import load_dataset +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaForCausalLM +from transformers.models.llama.tokenization_llama import LlamaTokenizer + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin +from colossalai.cluster import DistCoordinator +from colossalai.lazy import LazyInitContext +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device + + +def get_model_numel(model: nn.Module) -> int: + return sum(p.numel() for p in model.parameters()) + + +def format_numel_str(numel: int) -> str: + B = 1024**3 + M = 1024**2 + K = 1024 + if numel >= B: + return f'{numel / B:.2f} B' + elif numel >= M: + return f'{numel / M:.2f} M' + elif numel >= K: + return f'{numel / K:.2f} K' + else: + return f'{numel}' + + +def tokenize_batch_for_finetune(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048): + texts = [sample['prompt'] + sample['completion'] for sample in batch] + data = tokenizer(texts, return_tensors="pt", padding='max_length', truncation=True, max_length=max_length) + data = {k: v.cuda() for k, v in data.items()} + data['labels'] = data['input_ids'].clone() + return data + + +def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + tensor.div_(dist.get_world_size()) + return tensor + + +def save(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, epoch: int, step: int, + batch_size: int, coordinator: DistCoordinator, save_dir: str): + save_dir = os.path.join(save_dir, f'epoch{epoch}-step{step}') + os.makedirs(os.path.join(save_dir, 'model'), exist_ok=True) + + booster.save_model(model, os.path.join(save_dir, 'model'), shard=True) + booster.save_optimizer(optimizer, os.path.join(save_dir, 'optimizer'), shard=True) + booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, 'lr_scheduler')) + running_states = { + 'epoch': epoch, + 'step': step, + 'sample_start_index': step * batch_size, + } + if coordinator.is_master(): + save_json(running_states, os.path.join(save_dir, 'running_states.json')) + + +def load(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, + load_dir: str) -> Tuple[int, int, int]: + booster.load_model(model, os.path.join(load_dir, 'model')) + 
booster.load_optimizer(optimizer, os.path.join(load_dir, 'optimizer')) + booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, 'lr_scheduler')) + running_states = load_json(os.path.join(load_dir, 'running_states.json')) + return running_states['epoch'], running_states['step'], running_states['sample_start_index'] + + +def _criterion(outputs, inputs): + return outputs.loss + + +def main(): + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + parser.add_argument('--model_path', type=str, help="pretrained checkpoint path, used with mode==finetune") + parser.add_argument('-p', + '--plugin', + choices=['gemini', 'gemini_auto', 'zero2', 'zero2_cpu', 'hybrid_parallel'], + default='gemini', + help='Choose which plugin to use') + parser.add_argument('-d', '--dataset', type=str, default='yizhongw/self_instruct', help='Data set path') + parser.add_argument('--task_name', type=str, default="super_natural_instructions", help='task to run') + parser.add_argument('-e', '--num_epochs', type=int, default=1, help='Number of epochs') + parser.add_argument('-b', '--batch_size', type=int, default=2, help='Local batch size') + parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate') + parser.add_argument('-w', '--weigth_decay', type=float, default=0.1, help='Weight decay') + parser.add_argument('-g', '--grad_checkpoint', action='store_true', help='Use gradient checkpointing') + parser.add_argument('-l', '--max_length', type=int, default=4096, help='Max sequence length') + parser.add_argument('-x', '--mixed_precision', default='fp16', choices=['fp16', 'bf16'], help='Mixed precision') + parser.add_argument('-i', '--save_interval', type=int, default=1000, help='Save interval') + parser.add_argument('-o', '--save_dir', type=str, default='checkpoint', help='Checkpoint directory') + parser.add_argument('-f', '--load', type=str, default=None, help='Load checkpoint') + parser.add_argument('--grad_clip', type=float, default=1.0, help='Gradient clipping') + parser.add_argument('-t', '--tensorboard_dir', type=str, default='tb_logs', help='Tensorboard directory') + parser.add_argument('-a', '--flash_attention', action='store_true', help='Use Flash Attention') + args = parser.parse_args() + + # ============================== + # Initialize Distributed Training + # ============================== + colossalai.launch_from_torch({}) + coordinator = DistCoordinator() + + # ============================== + # Initialize Booster + # ============================== + if args.plugin == 'gemini': + plugin = GeminiPlugin(precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip) + elif args.plugin == 'gemini_auto': + plugin = GeminiPlugin(precision=args.mixed_precision, + placement_policy='auto', + initial_scale=2**16, + max_norm=args.grad_clip) + elif args.plugin == 'zero2': + plugin = LowLevelZeroPlugin(stage=2, + precision=args.mixed_precision, + initial_scale=2**16, + max_norm=args.grad_clip) + elif args.plugin == 'zero2_cpu': + plugin = LowLevelZeroPlugin(stage=2, + precision=args.mixed_precision, + initial_scale=2**16, + cpu_offload=True, + max_norm=args.grad_clip) + elif args.plugin == 'hybrid_parallel': + # modify the param accordingly, default configuration is for llama2-7b + plugin = HybridParallelPlugin(tp_size=4, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_jit_fused=False, + zero_stage=0, + precision='fp32', + initial_scale=1) + else: + raise ValueError(f'Unknown plugin 
{args.plugin}') + + booster = Booster(plugin=plugin) + + use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 + is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() + print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage) + + # ============================== + # Initialize Tensorboard + # ============================== + if print_flag: + os.makedirs(args.tensorboard_dir, exist_ok=True) + writer = SummaryWriter(args.tensorboard_dir) + + # ============================== + # Initialize Model, Optimizer and LR Scheduler + # ============================== + + config = LlamaConfig.from_pretrained(args.model_path) + # use lazy init when using GeminiPlugin + init_ctx = LazyInitContext( + default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext() + + with init_ctx: + model = LlamaForCausalLM(config) + + # ============================== + # Initialize Tokenizer, Dataset and Dataloader + # ============================== + tokenizer = LlamaTokenizer.from_pretrained('hf-internal-testing/llama-tokenizer') + # follows fast chat: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L257 + tokenizer.pad_token = tokenizer.unk_token + + dataset = load_dataset(args.dataset, args.task_name) + train_ds = dataset['train'] + dataloader = prepare_dataloader(train_ds, + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + collate_fn=partial(tokenize_batch_for_finetune, + tokenizer=tokenizer, + max_length=args.max_length)) + + if args.grad_checkpoint: + model.gradient_checkpointing_enable() + if args.flash_attention: + assert SUPPORT_XFORMERS, 'Use flash attention while xfomers is not installed' + replace_xformers(model) + + model_numel = get_model_numel(model) + coordinator.print_on_master(f'Model params: {format_numel_str(model_numel)}') + + optimizer = HybridAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weigth_decay) + total_step = args.num_epochs * len(dataloader) + lr_scheduler = CosineAnnealingWarmupLR(optimizer, + total_steps=total_step, + warmup_steps=math.ceil(total_step * 0.03), + eta_min=0.1 * args.lr) + default_dtype = torch.float16 if args.mixed_precision == 'fp16' else torch.bfloat16 + torch.set_default_dtype(default_dtype) + model, optimizer, _, dataloader, lr_scheduler = booster.boost(model, + optimizer, + dataloader=dataloader, + lr_scheduler=lr_scheduler) + torch.set_default_dtype(torch.float) + + booster.load_model(model, args.model_path) + + coordinator.print_on_master(f'Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB') + coordinator.print_on_master( + f'Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB') + + # load checkpoint if specified + start_epoch = 0 + start_step = 0 + sampler_start_idx = 0 + if args.load is not None: + coordinator.print_on_master('Loading checkpoint') + start_epoch, start_step, sampler_start_idx = load(booster, model, optimizer, lr_scheduler, args.load) + coordinator.print_on_master(f'Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}') + + num_steps_per_epoch = len(dataloader) + + # if resume training, set the sampler start index to the correct value + dataloader.sampler.set_start_index(sampler_start_idx) + for epoch in range(start_epoch, args.num_epochs): + dataloader.sampler.set_epoch(epoch) + step_nums = num_steps_per_epoch - start_step + dataloader_iter = iter(dataloader) + 
+ with tqdm(range(step_nums), + desc=f'Epoch {epoch}', + disable=not print_flag, + total=num_steps_per_epoch, + initial=start_step) as pbar: + for step in pbar: + if use_pipeline: + outputs = booster.execute_pipeline(dataloader_iter, + model, + _criterion, + optimizer, + return_loss=True, + return_outputs=True) + loss = outputs["loss"] + else: + batch = next(dataloader_iter) + outputs = model(**batch) + loss = outputs[0] + booster.backward(loss, optimizer) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + if not use_pipeline: + all_reduce_mean(loss) + if print_flag: + pbar.set_postfix({'loss': loss.item()}) + writer.add_scalar('loss', loss.item(), epoch * num_steps_per_epoch + step) + + if args.save_interval > 0 and (step + 1) % args.save_interval == 0: + coordinator.print_on_master(f'Saving checkpoint') + save(booster, model, optimizer, lr_scheduler, epoch, step + 1, args.batch_size, coordinator, + args.save_dir) + coordinator.print_on_master(f'Saved checkpoint at epoch {epoch} step {step + 1}') + # the continue epochs are not resumed, so we need to reset the sampler start index and start step + dataloader.sampler.set_start_index(0) + start_step = 0 + + coordinator.print_on_master(f'Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB') + + +if __name__ == '__main__': + main() diff --git a/examples/language/llama2/pretrain.py b/examples/language/llama2/pretrain.py index b72a3019692e..0eeac4035401 100644 --- a/examples/language/llama2/pretrain.py +++ b/examples/language/llama2/pretrain.py @@ -21,7 +21,7 @@ import colossalai from colossalai.booster import Booster -from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR @@ -65,9 +65,10 @@ def format_numel_str(numel: int) -> str: return f'{numel}' -def tokenize_batch(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048): +def tokenize_batch_for_pretrain(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048): texts = [sample['text'] for sample in batch] data = tokenizer(texts, return_tensors="pt", padding='max_length', truncation=True, max_length=max_length) + data = {k: v.cuda() for k, v in data.items()} data['labels'] = data['input_ids'].clone() return data @@ -104,6 +105,10 @@ def load(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: return running_states['epoch'], running_states['step'], running_states['sample_start_index'] +def _criterion(outputs, inputs): + return outputs.loss + + def main(): # ============================== # Parse Arguments @@ -112,7 +117,7 @@ def main(): parser.add_argument('-c', '--config', type=str, default='7b', help='Model configuration') parser.add_argument('-p', '--plugin', - choices=['gemini', 'gemini_auto', 'zero2', 'zero2_cpu'], + choices=['gemini', 'gemini_auto', 'zero2', 'zero2_cpu', 'hybrid_parallel'], default='gemini', help='Choose which plugin to use') parser.add_argument('-d', @@ -142,13 +147,6 @@ def main(): colossalai.launch_from_torch({}) coordinator = DistCoordinator() - # ============================== - # Initialize Tensorboard - # ============================== - if coordinator.is_master(): - os.makedirs(args.tensorboard_dir, exist_ok=True) - writer = SummaryWriter(args.tensorboard_dir) - # ============================== # Initialize 
Booster # ============================== @@ -170,11 +168,32 @@ def main(): initial_scale=2**16, cpu_offload=True, max_norm=args.grad_clip) + elif args.plugin == 'hybrid_parallel': + # modify the param accordingly, default configuration is for llama2-7b + plugin = HybridParallelPlugin(tp_size=4, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_jit_fused=False, + zero_stage=0, + precision='fp32', + initial_scale=1) else: raise ValueError(f'Unknown plugin {args.plugin}') booster = Booster(plugin=plugin) + use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 + is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() + print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage) + + # ============================== + # Initialize Tensorboard + # ============================== + if print_flag: + os.makedirs(args.tensorboard_dir, exist_ok=True) + writer = SummaryWriter(args.tensorboard_dir) + # ============================== # Initialize Tokenizer, Dataset and Dataloader # ============================== @@ -188,12 +207,15 @@ def main(): batch_size=args.batch_size, shuffle=True, drop_last=True, - collate_fn=partial(tokenize_batch, tokenizer=tokenizer, max_length=args.max_length)) + collate_fn=partial(tokenize_batch_for_pretrain, + tokenizer=tokenizer, + max_length=args.max_length)) # ============================== # Initialize Model, Optimizer and LR Scheduler # ============================== config = MODEL_CONFIGS[args.config] + # use lazy init when using GeminiPlugin init_ctx = LazyInitContext( default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext() @@ -236,27 +258,42 @@ def main(): coordinator.print_on_master(f'Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}') num_steps_per_epoch = len(dataloader) + # if resume training, set the sampler start index to the correct value dataloader.sampler.set_start_index(sampler_start_idx) for epoch in range(start_epoch, args.num_epochs): dataloader.sampler.set_epoch(epoch) - with tqdm(enumerate(dataloader), + step_nums = num_steps_per_epoch - start_step + dataloader_iter = iter(dataloader) + + with tqdm(range(step_nums), desc=f'Epoch {epoch}', - disable=not coordinator.is_master(), + disable=not print_flag, total=num_steps_per_epoch, initial=start_step) as pbar: - for step, batch in pbar: - batch = {k: v.cuda() for k, v in batch.items()} - outputs = model(**batch) - loss = outputs[0] - booster.backward(loss, optimizer) + for step in pbar: + if use_pipeline: + outputs = booster.execute_pipeline(dataloader_iter, + model, + _criterion, + optimizer, + return_loss=True, + return_outputs=True) + loss = outputs["loss"] + else: + batch = next(dataloader_iter) + outputs = model(**batch) + loss = outputs[0] + booster.backward(loss, optimizer) + optimizer.step() lr_scheduler.step() optimizer.zero_grad() - all_reduce_mean(loss) - pbar.set_postfix({'loss': loss.item()}) - if coordinator.is_master(): + if not use_pipeline: + all_reduce_mean(loss) + if print_flag: + pbar.set_postfix({'loss': loss.item()}) writer.add_scalar('loss', loss.item(), epoch * num_steps_per_epoch + step) if args.save_interval > 0 and (step + 1) % args.save_interval == 0: diff --git a/examples/language/llama2/requirements.txt b/examples/language/llama2/requirements.txt index 3ddf21ffe534..6b475682dad0 100644 --- a/examples/language/llama2/requirements.txt +++ b/examples/language/llama2/requirements.txt @@ -1,4 +1,4 @@ 
-colossalai>=0.3.0 +colossalai>=0.3.2 datasets numpy torch>=1.12.0,<=2.0.0 diff --git a/examples/language/opt/README.md b/examples/language/opt/README.md index 37e1ff4d9008..af1e794374ed 100644 --- a/examples/language/opt/README.md +++ b/examples/language/opt/README.md @@ -23,9 +23,9 @@ The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) ## Our Modifications We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before -the tokenization). +the tokenization). -We adapt the OPT training code to ColossalAI by leveraging [Boosting API](https://colossalai.org/docs/basics/booster_api) loaded with a chosen plugin, where each plugin corresponds to a specific kind of training strategy. This example supports plugins including TorchDDPPlugin, LowLevelZeroPlugin, and GeminiPlugin. +We adapt the OPT training code to ColossalAI by leveraging [Boosting API](https://colossalai.org/docs/basics/booster_api) loaded with a chosen plugin, where each plugin corresponds to a specific kind of training strategy. This example supports plugins including TorchDDPPlugin, LowLevelZeroPlugin, HybridParallelPlugin and GeminiPlugin. ## Run Demo @@ -48,6 +48,3 @@ You can run benchmark for OPT model by running the following script: bash run_benchmark.sh ``` The script will test performance (throughput & peak memory usage) for each combination of hyperparameters. You can also play with this script to configure your set of hyperparameters for testing. - - - diff --git a/examples/language/opt/requirements.txt b/examples/language/opt/requirements.txt index 4422216e6a1c..45bfbc37195f 100644 --- a/examples/language/opt/requirements.txt +++ b/examples/language/opt/requirements.txt @@ -1,4 +1,4 @@ -colossalai >= 0.1.12 +colossalai >= 0.3.2 torch >= 1.8.1 datasets >= 1.8.0 -transformers >= 4.20.0 \ No newline at end of file +transformers >= 4.30.2 From d151dcab740eaae784333c93d85100c3641bd115 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Fri, 15 Sep 2023 21:04:07 +0800 Subject: [PATCH 20/58] [doc] explaination of loading large pretrained models (#4741) --- docs/source/en/basics/booster_checkpoint.md | 24 +++++++++++++++++++ .../zh-Hans/basics/booster_checkpoint.md | 24 +++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/docs/source/en/basics/booster_checkpoint.md b/docs/source/en/basics/booster_checkpoint.md index 4ef35dc9a9bb..ea6c11ae2cdc 100644 --- a/docs/source/en/basics/booster_checkpoint.md +++ b/docs/source/en/basics/booster_checkpoint.md @@ -19,6 +19,30 @@ Model must be boosted by `colossalai.booster.Booster` before saving. `checkpoint Model must be boosted by `colossalai.booster.Booster` before loading. It will detect the checkpoint format automatically, and load in corresponding way. +If you want to load a pretrained model from Huggingface while the model is too large to be directly loaded through `from_pretrained` on a single device, a recommended way is to download the pretrained weights to a local directory, and use `booster.load` to load from that directory after boosting the model. Also, the model should be initialized under lazy initialization context to avoid OOM. Here is an example pseudocode: +```python +from colossalai.lazy import LazyInitContext +from huggingface_hub import snapshot_download +... + +# Initialize model under lazy init context +init_ctx = LazyInitContext(default_device=get_current_device) +with init_ctx: + model = LlamaForCausalLM(config) + +... 
+ +# Wrap the model through Booster.boost +model, optimizer, _, _, _ = booster.boost(model, optimizer) + +# download huggingface pretrained model to local directory. +model_dir = snapshot_download(repo_id="lysandre/arxiv-nlp") + +# load model using booster.load +booster.load(model, model_dir) +... +``` + ## Optimizer Checkpoint {{ autodoc:colossalai.booster.Booster.save_optimizer }} diff --git a/docs/source/zh-Hans/basics/booster_checkpoint.md b/docs/source/zh-Hans/basics/booster_checkpoint.md index 02557ad47d56..1ff2e330521c 100644 --- a/docs/source/zh-Hans/basics/booster_checkpoint.md +++ b/docs/source/zh-Hans/basics/booster_checkpoint.md @@ -19,6 +19,30 @@ 模型在加载前必须被 `colossalai.booster.Booster` 封装。它会自动检测 checkpoint 格式,并以相应的方式加载。 +如果您想从Huggingface加载预训练好的模型,但模型太大以至于无法在单个设备上通过“from_pretrained”直接加载,推荐的方法是将预训练的模型权重下载到本地,并在封装模型后使用`booster.load`直接从本地路径加载。为了避免内存不足,模型需要在`Lazy Initialization`的环境下初始化。以下是示例伪代码: +```python +from colossalai.lazy import LazyInitContext +from huggingface_hub import snapshot_download +... + +# Initialize model under lazy init context +init_ctx = LazyInitContext(default_device=get_current_device) +with init_ctx: + model = LlamaForCausalLM(config) + +... + +# Wrap the model through Booster.boost +model, optimizer, _, _, _ = booster.boost(model, optimizer) + +# download huggingface pretrained model to local directory. +model_dir = snapshot_download(repo_id="lysandre/arxiv-nlp") + +# load model using booster.load +booster.load(model, model_dir) +... +``` + ## 优化器 Checkpoint From 32e7f99416c846402d6098419777edee3ddbce7b Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com> Date: Mon, 18 Sep 2023 09:44:27 +0800 Subject: [PATCH 21/58] [kernel] update triton init #4740 (#4740) --- colossalai/kernel/triton/__init__.py | 30 ++++++++++++++++++---------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py index 5840ad2918be..75812db036a9 100644 --- a/colossalai/kernel/triton/__init__.py +++ b/colossalai/kernel/triton/__init__.py @@ -1,12 +1,20 @@ -from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd -from .copy_kv_cache_dest import copy_kv_cache_to_dest -from .fused_layernorm import layer_norm -from .rms_norm import rmsnorm_forward -from .rotary_embedding_kernel import rotary_embedding_fwd -from .softmax import softmax -from .token_attention_kernel import token_attention_fwd +try: + import triton + HAS_TRITON = True -__all__ = [ - "llama_context_attn_fwd", "bloom_context_attn_fwd", "softmax", "layer_norm", "rmsnorm_forward", - "copy_kv_cache_to_dest", "rotary_embedding_fwd", "token_attention_fwd" -] + from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd + from .copy_kv_cache_dest import copy_kv_cache_to_dest + from .fused_layernorm import layer_norm + from .rms_norm import rmsnorm_forward + from .rotary_embedding_kernel import rotary_embedding_fwd + from .softmax import softmax + from .token_attention_kernel import token_attention_fwd + + __all__ = [ + "llama_context_attn_fwd", "bloom_context_attn_fwd", "softmax", "layer_norm", "rmsnorm_forward", + "copy_kv_cache_to_dest", "rotary_embedding_fwd", "token_attention_fwd" + ] + +except ImportError: + HAS_TRITON = False + print("Triton is not installed. 
Please install Triton to use Triton kernels.") From b5f9e37c709656b286940f1b5e05abddfa257e3d Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Mon, 18 Sep 2023 16:31:06 +0800 Subject: [PATCH 22/58] [legacy] clean up legacy code (#4743) * [legacy] remove outdated codes of pipeline (#4692) * [legacy] remove cli of benchmark and update optim (#4690) * [legacy] remove cli of benchmark and update optim * [doc] fix cli doc test * [legacy] fix engine clip grad norm * [legacy] remove outdated colo tensor (#4694) * [legacy] remove outdated colo tensor * [test] fix test import * [legacy] move outdated zero to legacy (#4696) * [legacy] clean up utils (#4700) * [legacy] clean up utils * [example] update examples * [legacy] clean up amp * [legacy] fix amp module * [legacy] clean up gpc (#4742) * [legacy] clean up context * [legacy] clean core, constants and global vars * [legacy] refactor initialize * [example] fix examples ci * [example] fix examples ci * [legacy] fix tests * [example] fix gpt example * [example] fix examples ci * [devops] fix ci installation * [example] fix examples ci --- .github/workflows/doc_test_on_pr.yml | 2 +- .github/workflows/doc_test_on_schedule.yml | 2 +- .../workflows/example_check_on_dispatch.yml | 2 +- .github/workflows/example_check_on_pr.yml | 2 +- .../workflows/example_check_on_schedule.yml | 2 +- colossalai/__init__.py | 11 +- colossalai/amp/__init__.py | 54 -- colossalai/amp/naive_amp/__init__.py | 60 -- .../auto_parallel/offload/amp_optimizer.py | 4 +- colossalai/checkpoint_io/utils.py | 4 - colossalai/cli/benchmark/__init__.py | 28 - colossalai/cli/benchmark/benchmark.py | 105 --- colossalai/cli/benchmark/models.py | 18 - colossalai/cli/benchmark/utils.py | 159 ----- colossalai/cli/cli.py | 2 - colossalai/context/__init__.py | 12 +- colossalai/context/moe_context.py | 7 +- colossalai/core.py | 6 - colossalai/fx/passes/shard_1d_pass.py | 16 +- colossalai/initialize.py | 326 +-------- colossalai/legacy/__init__.py | 9 + colossalai/legacy/amp/__init__.py | 54 ++ colossalai/{ => legacy}/amp/amp_type.py | 0 .../{ => legacy}/amp/apex_amp/__init__.py | 0 .../{ => legacy}/amp/apex_amp/apex_amp.py | 6 +- colossalai/legacy/amp/naive_amp/__init__.py | 60 ++ .../amp/naive_amp/_fp16_optimizer.py | 9 +- .../{ => legacy}/amp/naive_amp/_utils.py | 0 .../{ => legacy}/amp/naive_amp/naive_amp.py | 10 +- .../{ => legacy}/amp/torch_amp/__init__.py | 0 .../amp/torch_amp/_grad_scaler.py | 4 +- .../{ => legacy}/amp/torch_amp/torch_amp.py | 6 +- colossalai/legacy/communication/collective.py | 14 +- colossalai/legacy/communication/p2p.py | 4 +- colossalai/legacy/communication/p2p_v2.py | 4 +- colossalai/legacy/communication/ring.py | 4 +- colossalai/legacy/communication/utils.py | 4 +- colossalai/{ => legacy}/constants.py | 0 colossalai/legacy/context/__init__.py | 4 + .../{ => legacy}/context/parallel_context.py | 66 +- .../{ => legacy}/context/parallel_mode.py | 0 .../process_group_initializer/__init__.py | 2 +- .../initializer_1d.py | 2 +- .../initializer_2d.py | 2 +- .../initializer_2p5d.py | 2 +- .../initializer_3d.py | 2 +- .../initializer_data.py | 0 .../initializer_model.py | 0 .../initializer_pipeline.py | 0 .../initializer_sequence.py | 0 .../initializer_tensor.py | 0 .../process_group_initializer.py | 0 .../{ => legacy}/context/random/__init__.py | 0 .../{ => legacy}/context/random/_helper.py | 12 +- .../context/random/seed_manager.py | 10 +- colossalai/legacy/core.py | 6 + colossalai/legacy/engine/_base_engine.py | 10 +- .../_gradient_accumulation.py | 6 +- 
.../_data_parallel_gradient_handler.py | 4 +- .../gradient_handler/_moe_gradient_handler.py | 4 +- .../_pipeline_parallel_gradient_handler.py | 2 +- .../_sequence_parallel_gradient_handler.py | 4 +- .../engine/schedule/_pipeline_schedule.py | 10 +- .../engine/schedule/_pipeline_schedule_v2.py | 4 +- colossalai/{ => legacy}/global_variables.py | 0 colossalai/legacy/initialize.py | 472 +++++++++++++ colossalai/legacy/nn/__init__.py | 1 - colossalai/legacy/nn/_ops/__init__.py | 10 +- colossalai/legacy/nn/_ops/_utils.py | 5 +- colossalai/legacy/nn/_ops/addmm.py | 90 --- colossalai/legacy/nn/_ops/batch_norm.py | 33 - colossalai/legacy/nn/_ops/element_wise.py | 250 ------- colossalai/legacy/nn/_ops/embedding.py | 142 ---- colossalai/legacy/nn/_ops/embedding_bag.py | 127 ---- colossalai/legacy/nn/_ops/layernorm.py | 28 - colossalai/legacy/nn/_ops/linear.py | 171 ----- colossalai/legacy/nn/_ops/loss.py | 51 -- colossalai/legacy/nn/_ops/view.py | 96 --- colossalai/legacy/nn/layer/base_layer.py | 4 +- .../nn/layer/colossalai_layer/dropout.py | 2 +- .../legacy/nn/layer/parallel_1d/_operation.py | 2 +- .../legacy/nn/layer/parallel_1d/_utils.py | 4 +- .../legacy/nn/layer/parallel_1d/layers.py | 10 +- .../legacy/nn/layer/parallel_2d/_operation.py | 40 +- .../legacy/nn/layer/parallel_2d/_utils.py | 6 +- .../legacy/nn/layer/parallel_2d/layers.py | 11 +- .../nn/layer/parallel_2p5d/_operation.py | 34 +- .../legacy/nn/layer/parallel_2p5d/_utils.py | 6 +- .../legacy/nn/layer/parallel_2p5d/layers.py | 10 +- .../legacy/nn/layer/parallel_3d/_operation.py | 44 +- .../legacy/nn/layer/parallel_3d/_utils.py | 12 +- .../legacy/nn/layer/parallel_3d/layers.py | 18 +- .../nn/layer/parallel_sequence/_operation.py | 4 +- .../nn/layer/parallel_sequence/layers.py | 6 +- colossalai/legacy/nn/layer/utils/common.py | 6 +- colossalai/legacy/nn/layer/vanilla/layers.py | 6 +- .../nn/layer/wrapper/pipeline_wrapper.py | 4 +- colossalai/legacy/nn/loss/__init__.py | 2 +- colossalai/legacy/nn/loss/loss_1d.py | 4 +- colossalai/legacy/nn/loss/loss_2d.py | 4 +- colossalai/legacy/nn/loss/loss_2p5d.py | 4 +- colossalai/legacy/nn/loss/loss_3d.py | 4 +- colossalai/legacy/nn/metric/accuracy_3d.py | 2 +- .../legacy/nn/parallel/data_parallel.py | 6 +- .../parallel_cached_embedding.py | 3 +- .../parallel_cached_embedding_tablewise.py | 2 +- ..._cached_embedding_tablewise_split_cache.py | 2 +- .../legacy/nn/parallel/layers/colo_module.py | 4 +- .../legacy/nn/parallel/layers/embedding.py | 2 +- .../legacy/nn/parallel/layers/linear.py | 2 +- .../legacy/nn/parallel/layers/module_utils.py | 3 +- colossalai/legacy/pipeline/__init__.py | 4 + .../{ => legacy}/pipeline/layer_spec.py | 6 +- .../legacy/pipeline/middleware/__init__.py | 3 + .../pipeline/middleware/adaptor/__init__.py | 2 +- .../pipeline/middleware/adaptor/fx.py | 34 +- .../{ => legacy}/pipeline/middleware/topo.py | 86 +-- .../{ => legacy}/pipeline/pipelinable.py | 26 +- .../pipeline/pipeline_process_group.py | 6 +- colossalai/legacy/pipeline/rpc/__init__.py | 4 + .../pipeline/rpc/_pipeline_base.py | 6 +- .../pipeline/rpc/_pipeline_schedule.py | 8 +- colossalai/{ => legacy}/pipeline/rpc/utils.py | 2 +- colossalai/{ => legacy}/pipeline/utils.py | 0 colossalai/legacy/tensor/__init__.py | 17 + .../{ => legacy}/tensor/compute_spec.py | 0 colossalai/{ => legacy}/tensor/const.py | 0 .../{ => legacy}/tensor/dist_spec_mgr.py | 6 +- colossalai/{ => legacy}/tensor/distspec.py | 0 colossalai/{ => legacy}/tensor/op_wrapper.py | 5 +- .../{ => legacy}/tensor/process_group.py | 0 colossalai/{ => 
legacy}/tensor/tensor_spec.py | 4 +- colossalai/legacy/trainer/_trainer.py | 3 +- .../legacy/trainer/hooks/_checkpoint_hook.py | 2 +- colossalai/legacy/trainer/hooks/_log_hook.py | 11 +- .../legacy/trainer/hooks/_metric_hook.py | 7 +- colossalai/legacy/utils/__init__.py | 53 ++ .../utils/activation_checkpoint.py | 16 +- .../legacy/utils/checkpoint/__init__.py | 3 + .../utils/checkpoint/module_checkpoint.py | 21 +- .../{ => legacy}/utils/checkpoint/utils.py | 128 ++-- .../{ => legacy}/utils/checkpointing.py | 8 +- colossalai/legacy/utils/common.py | 434 ++++++++++++ .../utils/data_sampler/__init__.py | 0 .../utils/data_sampler/base_sampler.py | 0 .../data_sampler/data_parallel_sampler.py | 4 +- colossalai/{ => legacy}/utils/memory.py | 18 +- .../{ => legacy}/utils/profiler/__init__.py | 0 .../{ => legacy}/utils/profiler/extention.py | 0 .../utils/profiler/legacy/__init__.py | 12 +- .../utils/profiler/legacy/comm_profiler.py | 619 +++++++++--------- .../utils/profiler/legacy/pcie_profiler.py | 298 ++++----- .../utils/profiler/legacy/prof_utils.py | 263 ++++---- .../{ => legacy}/utils/profiler/profiler.py | 4 +- .../profiler/stateful_tensor_mem_extention.py | 2 +- .../{zero/legacy => legacy/zero}/__init__.py | 0 .../legacy => legacy/zero}/gemini/__init__.py | 0 .../zero}/gemini/gemini_context.py | 0 .../zero}/gemini/ophooks/__init__.py | 0 .../gemini/ophooks/_shard_grad_ophook.py | 0 .../gemini/ophooks/_shard_param_ophook.py | 0 .../gemini/ophooks/runtime_mem_tracer_hook.py | 2 +- .../zero}/gemini/ophooks/utils.py | 0 .../zero}/gemini/paramhooks/__init__.py | 0 .../zero}/gemini/paramhooks/_param_hookmgr.py | 0 .../zero}/gemini/stateful_tensor.py | 0 .../zero}/gemini/stateful_tensor_mgr.py | 0 .../zero}/gemini/tensor_placement_policy.py | 2 +- .../zero}/gemini/tensor_utils.py | 0 .../zero}/init_ctx/__init__.py | 0 .../zero}/init_ctx/init_context.py | 12 +- .../zero}/shard_utils/__init__.py | 0 .../zero}/shard_utils/base_shard_strategy.py | 2 +- .../bucket_tensor_shard_strategy.py | 2 +- .../zero}/shard_utils/commons.py | 0 .../shard_utils/tensor_shard_strategy.py | 8 +- .../zero}/sharded_model/__init__.py | 0 .../zero}/sharded_model/_utils.py | 2 +- .../zero}/sharded_model/reduce_scatter.py | 0 .../zero}/sharded_model/sharded_model_v2.py | 24 +- .../zero}/sharded_model/utils.py | 2 +- .../zero}/sharded_model/zero_hook.py | 8 +- .../zero}/sharded_optim/__init__.py | 0 .../zero}/sharded_optim/sharded_optim_v2.py | 18 +- .../zero}/sharded_param/__init__.py | 0 .../zero}/sharded_param/sharded_param.py | 4 +- .../zero}/sharded_param/sharded_tensor.py | 2 +- colossalai/logging/logger.py | 8 - colossalai/nn/layer/__init__.py | 2 +- colossalai/nn/layer/moe/experts.py | 4 +- colossalai/nn/layer/moe/layers.py | 2 +- colossalai/nn/loss/__init__.py | 2 +- colossalai/nn/optimizer/__init__.py | 7 +- .../nn/optimizer/colossalai_optimizer.py | 44 -- colossalai/pipeline/__init__.py | 13 +- colossalai/pipeline/middleware/__init__.py | 3 - colossalai/pipeline/rpc/__init__.py | 4 - colossalai/pipeline/schedule/__init__.py | 2 + colossalai/tensor/__init__.py | 11 +- colossalai/utils/__init__.py | 59 +- colossalai/utils/checkpoint/__init__.py | 3 - colossalai/utils/common.py | 438 +------------ colossalai/utils/cuda.py | 11 +- colossalai/utils/moe.py | 106 +-- colossalai/zero/gemini/colo_init_context.py | 3 +- .../zero/gemini/memory_tracer/__init__.py | 5 +- .../memory_tracer/chunk_memstats_collector.py | 2 +- .../gemini/memory_tracer/memory_monitor.py | 3 +- .../memory_tracer/memstats_collector.py | 2 +- 
.../memory_tracer/runtime_mem_tracer.py | 6 +- colossalai/zero/gemini/placement_policy.py | 2 +- colossalai/zero/low_level/_utils.py | 3 - docs/README.md | 2 +- .../advanced_tutorials/add_your_parallel.md | 2 +- .../train_gpt_using_hybrid_parallelism.md | 2 +- docs/source/en/basics/command_line_tool.md | 22 +- .../advanced_tutorials/add_your_parallel.md | 2 +- .../train_gpt_using_hybrid_parallelism.md | 2 +- .../zh-Hans/basics/command_line_tool.md | 20 +- .../roberta/pretraining/pretrain_utils.py | 2 +- .../roberta/pretraining/run_pretraining.py | 2 +- .../roberta/pretraining/utils/exp_util.py | 2 +- examples/images/dreambooth/test_ci.sh | 42 +- .../dreambooth/train_dreambooth_colossalai.py | 9 +- .../train_dreambooth_colossalai_lora.py | 4 +- .../auto_parallel/auto_parallel_with_gpt.py | 2 +- .../pipeline_parallel/train_gpt_pp.py | 8 +- examples/language/gpt/gemini/run_gemini.sh | 7 +- .../language/gpt/gemini/train_gpt_demo.py | 4 +- examples/language/gpt/test_ci.sh | 2 +- examples/language/gpt/titans/model/embed.py | 4 +- examples/language/gpt/titans/model/gpt1d.py | 4 +- .../gpt/titans/model/pipeline_gpt1d.py | 6 +- examples/language/gpt/titans/train_gpt.py | 6 +- .../auto_parallel_with_resnet.py | 2 +- examples/tutorial/auto_parallel/test_ci.sh | 8 +- examples/tutorial/hybrid_parallel/config.py | 2 +- examples/tutorial/hybrid_parallel/train.py | 6 +- .../tutorial/large_batch_optimizer/config.py | 2 +- .../tutorial/large_batch_optimizer/test_ci.sh | 7 +- .../tutorial/large_batch_optimizer/train.py | 2 +- examples/tutorial/opt/opt/colossalai_zero.py | 2 +- examples/tutorial/opt/opt/context.py | 4 +- examples/tutorial/opt/opt/run_clm.py | 9 +- examples/tutorial/opt/opt/test_ci.sh | 32 +- examples/tutorial/sequence_parallel/config.py | 2 +- .../sequence_parallel/data/__init__.py | 34 +- .../sequence_parallel/data/bert_helper.py | 23 +- .../data/datasets/bert_dataset.py | 4 +- .../data/datasets/data_samplers.py | 8 +- .../data/tokenizer/tokenizer.py | 30 +- .../sequence_parallel/loss_func/bert_loss.py | 28 +- .../loss_func/cross_entropy.py | 12 +- .../tutorial/sequence_parallel/model/bert.py | 8 +- .../sequence_parallel/model/layers/head.py | 25 +- .../model/layers/preprocess.py | 9 +- .../tutorial/sequence_parallel/test_ci.sh | 5 +- examples/tutorial/sequence_parallel/train.py | 9 +- tests/components_to_test/resnet.py | 13 +- .../test_C_solver_consistency.py | 2 +- .../test_ckpt_torchvision.py | 2 +- .../test_compatibility_with_gemini.py | 5 +- .../test_autochunk_alphafold_utils.py | 2 +- .../test_autochunk_diffuser_utils.py | 2 +- .../test_autochunk_vit_utils.py | 2 +- tests/test_cluster/test_process_group_mesh.py | 6 +- .../test_context/configs/parallel_2d_init.py | 10 - .../configs/parallel_2p5d_init.py | 11 - .../test_context/configs/parallel_3d_init.py | 10 - tests/test_device/test_init_logical_pg.py | 5 +- .../test_activation_checkpoint_codegen.py | 2 +- ...st_nested_activation_checkpoint_codegen.py | 2 +- .../test_codegen/test_offload_codegen.py | 2 +- tests/test_fx/test_parallel_1d.py | 2 +- .../test_pipeline/test_topo/topo_utils.py | 33 +- .../test_amp/test_naive_fp16.py | 4 +- .../test_amp/test_torch_fp16.py | 4 +- .../test_comm/test_boardcast_send_recv_v2.py | 6 +- tests/test_legacy/test_comm/test_comm.py | 6 +- .../test_comm/test_object_list_p2p.py | 6 +- .../test_comm/test_object_list_p2p_v2.py | 6 +- .../test_context/configs/parallel_2d_init.py | 4 + .../configs/parallel_2p5d_init.py | 4 + .../test_context/configs/parallel_3d_init.py | 4 + 
.../test_context/test_hybrid_parallel.py | 10 +- .../test_data/test_cifar10_dataset.py | 0 .../test_data/test_data_parallel_sampler.py | 9 +- .../test_deterministic_dataloader.py | 74 +++ tests/test_legacy/test_engine/test_engine.py | 20 +- .../test_engine/test_gradient_accumluation.py | 19 +- .../test_1d/checks_1d/check_layer_1d.py | 9 +- .../test_layers/test_1d/test_1d.py | 4 +- .../test_2d/checks_2d/check_layer_2d.py | 7 +- .../test_2d/checks_2d/check_operation_2d.py | 7 +- .../test_layers/test_2d/test_2d.py | 4 +- .../test_2p5d/checks_2p5d/check_layer_2p5d.py | 7 +- .../checks_2p5d/check_operation_2p5d.py | 7 +- .../test_layers/test_2p5d/test_2p5d.py | 4 +- .../test_3d/checks_3d/check_layer_3d.py | 7 +- .../test_layers/test_3d/test_3d.py | 4 +- .../test_layers/test_cache_embedding.py | 5 +- .../checks_seq/check_layer_seq.py | 4 +- .../test_sequence/test_sequence.py | 6 +- .../test_pipeline/rpc_test_utils.py | 4 +- .../test_pipeline/test_cuda_rpc_chimera.py | 6 +- .../test_pipeline/test_cuda_rpc_optimizer.py | 9 +- .../test_pipeline/test_cuda_rpc_pipeline.py | 4 +- .../test_cuda_rpc_value_correctness.py | 7 +- .../test_pipeline/test_middleware_1f1b.py | 8 +- .../test_pipeline/test_pipelinable.py | 2 +- .../test_pipeline_process_group.py | 4 +- .../test_tensor/common_utils/__init__.py | 2 +- .../test_tensor/common_utils/_utils.py | 6 +- .../test_tensor/core/test_dist_spec_mgr.py | 4 +- .../test_tensor/test_parameter.py | 4 +- .../test_trainer/test_pipeline/test_p2p.py | 6 +- .../test_pipeline/test_pipeline_schedule.py | 10 +- .../test_trainer_with_non_pipe_schedule.py | 17 +- .../test_trainer_with_pipe_schedule.py | 22 +- .../test_activation_checkpointing.py | 6 +- .../test_checkpoint/test_checkpoint_1d.py | 12 +- .../test_checkpoint/test_checkpoint_2d.py | 12 +- .../test_checkpoint/test_checkpoint_2p5d.py | 12 +- .../test_checkpoint/test_checkpoint_3d.py | 12 +- .../test_utils/test_memory.py | 4 +- .../test_utils/test_norm_gradient_clipping.py | 6 +- .../test_zero}/test_commons.py | 6 +- tests/test_moe/test_kernel.py | 4 +- tests/test_moe/test_moe_zero_optim.py | 2 +- tests/test_tensor/test_comm_spec_apply.py | 5 +- .../test_dtensor/test_comm_spec.py | 6 +- tests/test_tensor/test_mix_gather.py | 4 +- .../test_zero_gradient_clippling.py | 111 ---- .../test_zero/test_gemini/test_chunk_mgrv2.py | 2 - tests/test_zero/test_gemini/test_fwd_bwd.py | 4 +- .../test_gemini/test_gemini_use_rmt.py | 2 +- tests/test_zero/test_gemini/test_grad_clip.py | 4 +- tests/test_zero/test_gemini/test_inference.py | 4 +- tests/test_zero/test_gemini/test_optim.py | 4 +- .../test_gemini/test_zeroddp_state_dict.py | 2 +- .../test_gemini/test_zerooptim_state_dict.py | 2 +- .../test_zero/test_low_level/test_zero_tp.py | 96 --- 342 files changed, 2917 insertions(+), 4180 deletions(-) delete mode 100644 colossalai/cli/benchmark/__init__.py delete mode 100644 colossalai/cli/benchmark/benchmark.py delete mode 100644 colossalai/cli/benchmark/models.py delete mode 100644 colossalai/cli/benchmark/utils.py delete mode 100644 colossalai/core.py create mode 100644 colossalai/legacy/amp/__init__.py rename colossalai/{ => legacy}/amp/amp_type.py (100%) rename colossalai/{ => legacy}/amp/apex_amp/__init__.py (100%) rename colossalai/{ => legacy}/amp/apex_amp/apex_amp.py (86%) create mode 100644 colossalai/legacy/amp/naive_amp/__init__.py rename colossalai/{ => legacy}/amp/naive_amp/_fp16_optimizer.py (97%) rename colossalai/{ => legacy}/amp/naive_amp/_utils.py (100%) rename colossalai/{ => legacy}/amp/naive_amp/naive_amp.py 
(94%) rename colossalai/{ => legacy}/amp/torch_amp/__init__.py (100%) rename colossalai/{ => legacy}/amp/torch_amp/_grad_scaler.py (99%) rename colossalai/{ => legacy}/amp/torch_amp/torch_amp.py (95%) rename colossalai/{ => legacy}/constants.py (100%) create mode 100644 colossalai/legacy/context/__init__.py rename colossalai/{ => legacy}/context/parallel_context.py (88%) rename colossalai/{ => legacy}/context/parallel_mode.py (100%) rename colossalai/{ => legacy}/context/process_group_initializer/__init__.py (100%) rename colossalai/{ => legacy}/context/process_group_initializer/initializer_1d.py (96%) rename colossalai/{ => legacy}/context/process_group_initializer/initializer_2d.py (98%) rename colossalai/{ => legacy}/context/process_group_initializer/initializer_2p5d.py (99%) rename colossalai/{ => legacy}/context/process_group_initializer/initializer_3d.py (99%) rename colossalai/{ => legacy}/context/process_group_initializer/initializer_data.py (100%) rename colossalai/{ => legacy}/context/process_group_initializer/initializer_model.py (100%) rename colossalai/{ => legacy}/context/process_group_initializer/initializer_pipeline.py (100%) rename colossalai/{ => legacy}/context/process_group_initializer/initializer_sequence.py (100%) rename colossalai/{ => legacy}/context/process_group_initializer/initializer_tensor.py (100%) rename colossalai/{ => legacy}/context/process_group_initializer/process_group_initializer.py (100%) rename colossalai/{ => legacy}/context/random/__init__.py (100%) rename colossalai/{ => legacy}/context/random/_helper.py (90%) rename colossalai/{ => legacy}/context/random/seed_manager.py (86%) create mode 100644 colossalai/legacy/core.py rename colossalai/{ => legacy}/global_variables.py (100%) create mode 100644 colossalai/legacy/initialize.py delete mode 100644 colossalai/legacy/nn/_ops/addmm.py delete mode 100644 colossalai/legacy/nn/_ops/batch_norm.py delete mode 100644 colossalai/legacy/nn/_ops/element_wise.py delete mode 100644 colossalai/legacy/nn/_ops/embedding.py delete mode 100644 colossalai/legacy/nn/_ops/embedding_bag.py delete mode 100644 colossalai/legacy/nn/_ops/layernorm.py delete mode 100644 colossalai/legacy/nn/_ops/linear.py delete mode 100644 colossalai/legacy/nn/_ops/loss.py delete mode 100644 colossalai/legacy/nn/_ops/view.py create mode 100644 colossalai/legacy/pipeline/__init__.py rename colossalai/{ => legacy}/pipeline/layer_spec.py (97%) create mode 100644 colossalai/legacy/pipeline/middleware/__init__.py rename colossalai/{ => legacy}/pipeline/middleware/adaptor/__init__.py (62%) rename colossalai/{ => legacy}/pipeline/middleware/adaptor/fx.py (92%) rename colossalai/{ => legacy}/pipeline/middleware/topo.py (95%) rename colossalai/{ => legacy}/pipeline/pipelinable.py (93%) rename colossalai/{ => legacy}/pipeline/pipeline_process_group.py (98%) create mode 100644 colossalai/legacy/pipeline/rpc/__init__.py rename colossalai/{ => legacy}/pipeline/rpc/_pipeline_base.py (99%) rename colossalai/{ => legacy}/pipeline/rpc/_pipeline_schedule.py (97%) rename colossalai/{ => legacy}/pipeline/rpc/utils.py (98%) rename colossalai/{ => legacy}/pipeline/utils.py (100%) create mode 100644 colossalai/legacy/tensor/__init__.py rename colossalai/{ => legacy}/tensor/compute_spec.py (100%) rename colossalai/{ => legacy}/tensor/const.py (100%) rename colossalai/{ => legacy}/tensor/dist_spec_mgr.py (97%) rename colossalai/{ => legacy}/tensor/distspec.py (100%) rename colossalai/{ => legacy}/tensor/op_wrapper.py (97%) rename colossalai/{ => 
legacy}/tensor/process_group.py (100%) rename colossalai/{ => legacy}/tensor/tensor_spec.py (79%) create mode 100644 colossalai/legacy/utils/__init__.py rename colossalai/{ => legacy}/utils/activation_checkpoint.py (95%) create mode 100644 colossalai/legacy/utils/checkpoint/__init__.py rename colossalai/{ => legacy}/utils/checkpoint/module_checkpoint.py (90%) rename colossalai/{ => legacy}/utils/checkpoint/utils.py (91%) rename colossalai/{ => legacy}/utils/checkpointing.py (98%) create mode 100644 colossalai/legacy/utils/common.py rename colossalai/{ => legacy}/utils/data_sampler/__init__.py (100%) rename colossalai/{ => legacy}/utils/data_sampler/base_sampler.py (100%) rename colossalai/{ => legacy}/utils/data_sampler/data_parallel_sampler.py (98%) rename colossalai/{ => legacy}/utils/memory.py (95%) rename colossalai/{ => legacy}/utils/profiler/__init__.py (100%) rename colossalai/{ => legacy}/utils/profiler/extention.py (100%) rename colossalai/{ => legacy}/utils/profiler/legacy/__init__.py (77%) rename colossalai/{ => legacy}/utils/profiler/legacy/comm_profiler.py (96%) rename colossalai/{ => legacy}/utils/profiler/legacy/pcie_profiler.py (95%) rename colossalai/{ => legacy}/utils/profiler/legacy/prof_utils.py (94%) rename colossalai/{ => legacy}/utils/profiler/profiler.py (97%) rename colossalai/{ => legacy}/utils/profiler/stateful_tensor_mem_extention.py (98%) rename colossalai/{zero/legacy => legacy/zero}/__init__.py (100%) rename colossalai/{zero/legacy => legacy/zero}/gemini/__init__.py (100%) rename colossalai/{zero/legacy => legacy/zero}/gemini/gemini_context.py (100%) rename colossalai/{zero/legacy => legacy/zero}/gemini/ophooks/__init__.py (100%) rename colossalai/{zero/legacy => legacy/zero}/gemini/ophooks/_shard_grad_ophook.py (100%) rename colossalai/{zero/legacy => legacy/zero}/gemini/ophooks/_shard_param_ophook.py (100%) rename colossalai/{zero/legacy => legacy/zero}/gemini/ophooks/runtime_mem_tracer_hook.py (98%) rename colossalai/{zero/legacy => legacy/zero}/gemini/ophooks/utils.py (100%) rename colossalai/{zero/legacy => legacy/zero}/gemini/paramhooks/__init__.py (100%) rename colossalai/{zero/legacy => legacy/zero}/gemini/paramhooks/_param_hookmgr.py (100%) rename colossalai/{zero/legacy => legacy/zero}/gemini/stateful_tensor.py (100%) rename colossalai/{zero/legacy => legacy/zero}/gemini/stateful_tensor_mgr.py (100%) rename colossalai/{zero/legacy => legacy/zero}/gemini/tensor_placement_policy.py (98%) rename colossalai/{zero/legacy => legacy/zero}/gemini/tensor_utils.py (100%) rename colossalai/{zero/legacy => legacy/zero}/init_ctx/__init__.py (100%) rename colossalai/{zero/legacy => legacy/zero}/init_ctx/init_context.py (96%) rename colossalai/{zero/legacy => legacy/zero}/shard_utils/__init__.py (100%) rename colossalai/{zero/legacy => legacy/zero}/shard_utils/base_shard_strategy.py (90%) rename colossalai/{zero/legacy => legacy/zero}/shard_utils/bucket_tensor_shard_strategy.py (97%) rename colossalai/{zero/legacy => legacy/zero}/shard_utils/commons.py (100%) rename colossalai/{zero/legacy => legacy/zero}/shard_utils/tensor_shard_strategy.py (90%) rename colossalai/{zero/legacy => legacy/zero}/sharded_model/__init__.py (100%) rename colossalai/{zero/legacy => legacy/zero}/sharded_model/_utils.py (97%) rename colossalai/{zero/legacy => legacy/zero}/sharded_model/reduce_scatter.py (100%) rename colossalai/{zero/legacy => legacy/zero}/sharded_model/sharded_model_v2.py (97%) rename colossalai/{zero/legacy => legacy/zero}/sharded_model/utils.py (92%) rename 
colossalai/{zero/legacy => legacy/zero}/sharded_model/zero_hook.py (94%) rename colossalai/{zero/legacy => legacy/zero}/sharded_optim/__init__.py (100%) rename colossalai/{zero/legacy => legacy/zero}/sharded_optim/sharded_optim_v2.py (97%) rename colossalai/{zero/legacy => legacy/zero}/sharded_param/__init__.py (100%) rename colossalai/{zero/legacy => legacy/zero}/sharded_param/sharded_param.py (96%) rename colossalai/{zero/legacy => legacy/zero}/sharded_param/sharded_tensor.py (94%) delete mode 100644 colossalai/nn/optimizer/colossalai_optimizer.py delete mode 100644 colossalai/pipeline/middleware/__init__.py delete mode 100644 colossalai/pipeline/rpc/__init__.py delete mode 100644 colossalai/utils/checkpoint/__init__.py delete mode 100644 tests/test_context/configs/parallel_2d_init.py delete mode 100644 tests/test_context/configs/parallel_2p5d_init.py delete mode 100644 tests/test_context/configs/parallel_3d_init.py rename tests/{ => test_legacy}/test_amp/test_naive_fp16.py (94%) rename tests/{ => test_legacy}/test_amp/test_torch_fp16.py (95%) create mode 100644 tests/test_legacy/test_context/configs/parallel_2d_init.py create mode 100644 tests/test_legacy/test_context/configs/parallel_2p5d_init.py create mode 100644 tests/test_legacy/test_context/configs/parallel_3d_init.py rename tests/{ => test_legacy}/test_context/test_hybrid_parallel.py (95%) rename tests/{ => test_legacy}/test_data/test_cifar10_dataset.py (100%) rename tests/{ => test_legacy}/test_data/test_data_parallel_sampler.py (87%) create mode 100644 tests/test_legacy/test_data/test_deterministic_dataloader.py rename tests/{ => test_legacy}/test_pipeline/rpc_test_utils.py (97%) rename tests/{ => test_legacy}/test_pipeline/test_cuda_rpc_chimera.py (94%) rename tests/{ => test_legacy}/test_pipeline/test_cuda_rpc_optimizer.py (89%) rename tests/{ => test_legacy}/test_pipeline/test_cuda_rpc_pipeline.py (87%) rename tests/{ => test_legacy}/test_pipeline/test_cuda_rpc_value_correctness.py (91%) rename tests/{ => test_legacy}/test_pipeline/test_middleware_1f1b.py (94%) rename tests/{ => test_legacy}/test_pipeline/test_pipelinable.py (96%) rename tests/{ => test_legacy}/test_pipeline/test_pipeline_process_group.py (91%) rename tests/{ => test_legacy}/test_tensor/common_utils/__init__.py (95%) rename tests/{ => test_legacy}/test_tensor/common_utils/_utils.py (93%) rename tests/{ => test_legacy}/test_tensor/core/test_dist_spec_mgr.py (91%) rename tests/{ => test_legacy}/test_tensor/test_parameter.py (82%) rename tests/{ => test_legacy}/test_utils/test_activation_checkpointing.py (94%) rename tests/{ => test_legacy}/test_utils/test_checkpoint/test_checkpoint_1d.py (83%) rename tests/{ => test_legacy}/test_utils/test_checkpoint/test_checkpoint_2d.py (83%) rename tests/{ => test_legacy}/test_utils/test_checkpoint/test_checkpoint_2p5d.py (84%) rename tests/{ => test_legacy}/test_utils/test_checkpoint/test_checkpoint_3d.py (83%) rename tests/{ => test_legacy}/test_utils/test_memory.py (76%) rename tests/{ => test_legacy}/test_utils/test_norm_gradient_clipping.py (91%) rename tests/{test_utils => test_legacy/test_zero}/test_commons.py (82%) delete mode 100644 tests/test_utils/test_zero_gradient_clippling.py delete mode 100644 tests/test_zero/test_low_level/test_zero_tp.py diff --git a/.github/workflows/doc_test_on_pr.yml b/.github/workflows/doc_test_on_pr.yml index a3df2c50e6d3..f1e7a2d0cab0 100644 --- a/.github/workflows/doc_test_on_pr.yml +++ b/.github/workflows/doc_test_on_pr.yml @@ -89,7 +89,7 @@ jobs: - name: Install ColossalAI run: | 
source activate pytorch - pip install -v . + CUDA_EXT=1 pip install -v . - name: Test the Doc run: | diff --git a/.github/workflows/doc_test_on_schedule.yml b/.github/workflows/doc_test_on_schedule.yml index 6b4f5d1f908c..027fbfd0aaeb 100644 --- a/.github/workflows/doc_test_on_schedule.yml +++ b/.github/workflows/doc_test_on_schedule.yml @@ -32,7 +32,7 @@ jobs: - name: Install ColossalAI run: | - pip install -v . + CUDA_EXT=1 pip install -v . - name: Install Doc Test Requirements run: | diff --git a/.github/workflows/example_check_on_dispatch.yml b/.github/workflows/example_check_on_dispatch.yml index 620d4771af55..9d3bd9a48235 100644 --- a/.github/workflows/example_check_on_dispatch.yml +++ b/.github/workflows/example_check_on_dispatch.yml @@ -53,7 +53,7 @@ jobs: uses: actions/checkout@v3 - name: Install Colossal-AI run: | - pip install -v . + CUDA_EXT=1 pip install -v . - name: Test the example run: | dir=${{ matrix.directory }} diff --git a/.github/workflows/example_check_on_pr.yml b/.github/workflows/example_check_on_pr.yml index ec23b9d1c59f..5934704f4102 100644 --- a/.github/workflows/example_check_on_pr.yml +++ b/.github/workflows/example_check_on_pr.yml @@ -88,7 +88,7 @@ jobs: - name: Install Colossal-AI run: | - pip install -v . + CUDA_EXT=1 pip install -v . - name: Test the example run: | diff --git a/.github/workflows/example_check_on_schedule.yml b/.github/workflows/example_check_on_schedule.yml index bd52ca4321a2..5ed128c3ebc5 100644 --- a/.github/workflows/example_check_on_schedule.yml +++ b/.github/workflows/example_check_on_schedule.yml @@ -42,7 +42,7 @@ jobs: - name: Install Colossal-AI run: | - pip install -v . + CUDA_EXT=1 pip install -v . - name: Traverse all files run: | diff --git a/colossalai/__init__.py b/colossalai/__init__.py index f859161f7810..fa6f72a605c0 100644 --- a/colossalai/__init__.py +++ b/colossalai/__init__.py @@ -1,11 +1,4 @@ -from .initialize import ( - get_default_parser, - initialize, - launch, - launch_from_openmpi, - launch_from_slurm, - launch_from_torch, -) +from .initialize import launch, launch_from_openmpi, launch_from_slurm, launch_from_torch try: # .version will be created by setup.py @@ -15,3 +8,5 @@ # and directly set PYTHONPATH to use Colossal-AI which is a bad practice __version__ = '0.0.0' print('please install Colossal-AI from https://www.colossalai.org/download or from source') + +__all__ = ['launch', 'launch_from_openmpi', 'launch_from_slurm', 'launch_from_torch', '__version__'] diff --git a/colossalai/amp/__init__.py b/colossalai/amp/__init__.py index 963215476b6b..e69de29bb2d1 100644 --- a/colossalai/amp/__init__.py +++ b/colossalai/amp/__init__.py @@ -1,54 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import torch.nn as nn -from torch.nn.modules.loss import _Loss -from torch.optim import Optimizer - -from colossalai.context import Config - -from .amp_type import AMP_TYPE -from .apex_amp import convert_to_apex_amp -from .naive_amp import convert_to_naive_amp -from .torch_amp import convert_to_torch_amp - -__all__ = ['convert_to_amp', 'convert_to_naive_amp', 'convert_to_apex_amp', 'convert_to_torch_amp', 'AMP_TYPE'] - - -def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mode: AMP_TYPE, amp_config: Config = None): - """A helper function to wrap training components with Torch AMP modules. - - Args: - param model (:class:`torch.nn.Module`): your model object. - optimizer (:class:`torch.optim.Optimizer`): your optimizer object. 
- criterion (:class:`torch.nn.modules.loss._Loss`): your loss function object. - mode (:class:`colossalai.amp.AMP_TYPE`): amp mode. - amp_config (Union[:class:`colossalai.context.Config`, dict]): configuration for different amp modes. - - Returns: - A tuple (model, optimizer, criterion). - - Note: - ``amp_config`` may vary from different mode you choose. You should check the corresponding amp mode - for more details about ``amp_config``. - For ``apex_amp``, please check - `apex_amp config `_. - For ``naive_amp``, please check - `naive_amp config `_. - For ``torch_amp``, please check - `torch_amp config `_. - """ - assert isinstance(mode, AMP_TYPE), \ - f'expected the argument mode be AMP_TYPE, but got {type(mode)}' - - if amp_config is None: - amp_config = Config() - - if mode == AMP_TYPE.TORCH: - model, optimizer, criterion = convert_to_torch_amp(model, optimizer, criterion, amp_config) - elif mode == AMP_TYPE.APEX: - model, optimizer = convert_to_apex_amp(model, optimizer, amp_config) - elif mode == AMP_TYPE.NAIVE: - model, optimizer = convert_to_naive_amp(model, optimizer, amp_config) - - return model, optimizer, criterion diff --git a/colossalai/amp/naive_amp/__init__.py b/colossalai/amp/naive_amp/__init__.py index 5b2f71d3ced7..e69de29bb2d1 100644 --- a/colossalai/amp/naive_amp/__init__.py +++ b/colossalai/amp/naive_amp/__init__.py @@ -1,60 +0,0 @@ -import inspect - -import torch.nn as nn -from torch.optim import Optimizer - -from colossalai.utils import is_no_pp_or_last_stage - -from ._fp16_optimizer import FP16Optimizer -from .grad_scaler import ConstantGradScaler, DynamicGradScaler -from .naive_amp import NaiveAMPModel, NaiveAMPOptimizer - - -def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config): - """A helper function to wrap training components with naive AMP modules. In this mode, - we forcibly cast the model weights and inputs to FP16, and cast the model outputs to FP32 to calculate loss, - which is equivalent to Apex O3. - - Args: - model (:class:`torch.nn.Module`): your model object - optimizer (:class:`torch.optim.Optimizer`): your optimizer object - amp_config (:class:`colossalai.context.Config` or dict): configuration for naive mode amp. - - Returns: - Tuple: A tuple (model, optimizer) - - The ``amp_config`` should contain parameters below:: - - verbose (bool, optional): if set to `True`, will print debug info (Default: False). - clip_grad_norm (float, optional): clip gradients with this global L2 norm (Default 0). - Note that clipping is ignored if clip_grad == 0. - dynamic_grad_scale (bool): whether to use dynamic grad scaler. 
- """ - if isinstance(model, nn.ModuleList): - # interleaved pipeline - module_list = [] - for chunk, m in enumerate(model): - output_to_fp32 = is_no_pp_or_last_stage() and chunk == len(model) - 1 - module_list.append(NaiveAMPModel(m, output_to_fp32=output_to_fp32)) - model = nn.ModuleList(module_list) - else: - output_to_fp32 = is_no_pp_or_last_stage() - model = NaiveAMPModel(model, output_to_fp32=output_to_fp32) - - use_dynamic_grad_scaler = amp_config.pop('dynamic_grad_scale', True) - if use_dynamic_grad_scaler: - scaler_class = DynamicGradScaler - else: - scaler_class = ConstantGradScaler - - sig = inspect.signature(scaler_class.__init__) - kwargs = dict() - for param in sig.parameters.values(): - if param.name in amp_config: - kwargs[param.name] = amp_config.pop(param.name) - grad_scaler = scaler_class(**kwargs) - optimizer = NaiveAMPOptimizer(optimizer, grad_scaler, **amp_config) - return model, optimizer - - -__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer', 'FP16Optimizer'] diff --git a/colossalai/auto_parallel/offload/amp_optimizer.py b/colossalai/auto_parallel/offload/amp_optimizer.py index 19d85b80dd3d..353133bd6f2d 100644 --- a/colossalai/auto_parallel/offload/amp_optimizer.py +++ b/colossalai/auto_parallel/offload/amp_optimizer.py @@ -5,8 +5,8 @@ from torch.optim import Optimizer from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler +from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger -from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.utils import get_current_device from .base_offload_module import BaseOffloadModule @@ -19,7 +19,7 @@ class OptimState(Enum): UNSCALED = 1 -class AMPOptimizer(ColossalaiOptimizer): +class AMPOptimizer(OptimizerWrapper): """ A wrapper for Optimizer. Code reference: https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/optimizer/zero_optimizer.py diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 3441eca38ce7..664ac63e45ac 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -13,7 +13,6 @@ from torch.optim import Optimizer from colossalai.interface import ModelWrapper, OptimizerWrapper -from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.tensor.d_tensor import ( is_customized_distributed_tensor, is_distributed_tensor, @@ -130,10 +129,7 @@ def unwrap_optimizer(optimizer: OptimizerWrapper): This method should be used before saving/loading it to/from sharded checkpoints. 
''' - # TODO(Baizhou): ColossalaiOptimizer will be replaced with OptimizerWrapper in the future unwrapped_optim = optimizer.optim - if isinstance(unwrapped_optim, ColossalaiOptimizer): - unwrapped_optim = unwrapped_optim.optim return unwrapped_optim diff --git a/colossalai/cli/benchmark/__init__.py b/colossalai/cli/benchmark/__init__.py deleted file mode 100644 index 618ff8c61dd4..000000000000 --- a/colossalai/cli/benchmark/__init__.py +++ /dev/null @@ -1,28 +0,0 @@ -import click - -from colossalai.context import Config - -from .benchmark import run_benchmark -from .utils import * - -__all__ = ['benchmark'] - - -@click.command() -@click.option("-g", "--gpus", type=int, default=None, help="Total number of devices to use.") -@click.option("-b", "--batch_size", type=int, default=8, help="Batch size of the input tensor.") -@click.option("-s", "--seq_len", type=int, default=512, help="Sequence length of the input tensor.") -@click.option("-d", "--dimension", type=int, default=1024, help="Hidden dimension of the input tensor.") -@click.option("-w", "--warmup_steps", type=int, default=10, help="The number of warmup steps.") -@click.option("-p", "--profile_steps", type=int, default=50, help="The number of profiling steps.") -@click.option("-l", "--layers", type=int, default=2) -@click.option("-m", - "--model", - type=click.Choice(['mlp'], case_sensitive=False), - default='mlp', - help="Select the model to benchmark, currently only supports MLP") -def benchmark(gpus: int, batch_size: int, seq_len: int, dimension: int, warmup_steps: int, profile_steps: int, - layers: int, model: str): - args_dict = locals() - args = Config(args_dict) - run_benchmark(args) diff --git a/colossalai/cli/benchmark/benchmark.py b/colossalai/cli/benchmark/benchmark.py deleted file mode 100644 index 97a9f45722dd..000000000000 --- a/colossalai/cli/benchmark/benchmark.py +++ /dev/null @@ -1,105 +0,0 @@ -from functools import partial -from typing import Dict, List - -import click -import torch.multiprocessing as mp - -import colossalai -from colossalai.cli.benchmark.utils import find_all_configs, get_batch_data, profile_model -from colossalai.context import Config -from colossalai.context.random import reset_seeds -from colossalai.core import global_context as gpc -from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.testing import free_port -from colossalai.utils import MultiTimer - -from .models import MLP - - -def run_benchmark(args: Config) -> None: - """ - Run benchmarking with torch.multiprocessing. - """ - - # sanity checks - if args.gpus is None: - click.echo("Error: --num_gpus is not given") - exit() - if args.gpus <= 1: - click.echo("Warning: tensor parallel will be activated with at least 2 devices.") - - click.echo("=== Benchmarking Parameters ===") - for k, v in args.items(): - click.echo(f'{k}: {v}') - click.echo('') - - config_list = find_all_configs(args.gpus) - - avail_ports = [free_port() for _ in range(len(config_list))] - run_func = partial(run_dist_profiling, - world_size=args.gpus, - port_list=avail_ports, - config_list=config_list, - hyperparams=args) - mp.spawn(run_func, nprocs=args.gpus) - - -def run_dist_profiling(rank: int, world_size: int, port_list: List[int], config_list: List[Dict], - hyperparams: Config) -> None: - """ - A function executed for profiling, this function should be spawn by torch.multiprocessing. 
- - Args: - rank (int): rank of the process - world_size (int): the number of processes - port_list (List[int]): a list of free ports for initializing distributed networks - config_list (List[Dict]): a list of configuration - hyperparams (Config): the hyperparameters given by the user - - """ - - # disable logging for clean output - disable_existing_loggers() - logger = get_dist_logger() - logger.set_level('WARNING') - - for config, port in zip(config_list, port_list): - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - timer = MultiTimer() - - # 1D parallel should be skipped if in_features or out_features is not able to be divided exactly by 1D parallel size. - if config.parallel.tensor.mode == '1d' and hyperparams.dimension % config.parallel.tensor.size != 0: - click.echo( - "1D parallel will be skipped because in_features or out_features is not able to be divided exactly by 1D parallel size." - ) - continue - - if hyperparams.model == 'mlp': - model = MLP(dim=hyperparams.dimension, layers=hyperparams.layers) - else: - if gpc.get_global_rank() == 0: - click.echo("Error: Invalid argument for --model") - exit() - - data_func = partial(get_batch_data, - dim=hyperparams.dimension, - batch_size=hyperparams.batch_size, - seq_length=hyperparams.seq_len, - mode=config.parallel.tensor.mode) - - fwd_time, bwd_time, max_allocated, max_cached = profile_model(model=model, - warmup_steps=hyperparams.warmup_steps, - profile_steps=hyperparams.profile_steps, - data_func=data_func, - timer=timer) - - gpc.destroy() - reset_seeds() - - if gpc.get_global_rank() == 0: - config_str = ', '.join([f'{k}: {v}' for k, v in config.parallel.tensor.items()]) - click.echo(f"=== {config_str} ===") - click.echo(f"Average forward time: {fwd_time}") - click.echo(f"Average backward time: {bwd_time}") - click.echo(f"Max allocated GPU memory: {max_allocated}") - click.echo(f"Max cached GPU memory: {max_cached}\n") diff --git a/colossalai/cli/benchmark/models.py b/colossalai/cli/benchmark/models.py deleted file mode 100644 index 385b485b6016..000000000000 --- a/colossalai/cli/benchmark/models.py +++ /dev/null @@ -1,18 +0,0 @@ -import torch - -import colossalai.legacy.nn as col_nn - - -class MLP(torch.nn.Module): - - def __init__(self, dim: int, layers: int): - super().__init__() - self.layers = torch.nn.ModuleList() - - for _ in range(layers): - self.layers.append(col_nn.Linear(dim, dim)) - - def forward(self, x): - for layer in self.layers: - x = layer(x) - return x diff --git a/colossalai/cli/benchmark/utils.py b/colossalai/cli/benchmark/utils.py deleted file mode 100644 index ee7d92d6ea6a..000000000000 --- a/colossalai/cli/benchmark/utils.py +++ /dev/null @@ -1,159 +0,0 @@ -import math -import time -from typing import Callable, Dict, List, Tuple - -import torch - -from colossalai.context import Config, ParallelMode -from colossalai.utils import MultiTimer - - -def get_time_stamp() -> int: - """ - Return the time stamp for profiling. - - Returns: - time_stamp (int): the time given by time.time() - """ - - torch.cuda.synchronize() - time_stamp = time.time() - return time_stamp - - -def get_memory_states() -> Tuple[float]: - """ - Return the memory statistics. 
- - Returns: - max_allocated (float): the allocated CUDA memory - max_cached (float): the cached CUDA memory - """ - - max_allocated = torch.cuda.max_memory_allocated() / (1024**3) - max_cached = torch.cuda.max_memory_reserved() / (1024**3) - torch.cuda.reset_peak_memory_stats() - torch.cuda.empty_cache() - return max_allocated, max_cached - - -def find_all_configs(device_cnt: int) -> List[Dict]: - """ - Find all possible configurations for tensor parallelism - - Args: - device_cnt (int): the number of devices - - Returns: - config_list (List[Dict]): a list of configurations - """ - - def _is_square(num): - # 2D parallel should be implemented with at least 2 devices. - if num <= 1: - return False - return math.floor(math.sqrt(num))**2 == num - - def _is_cube(num): - # 3D parallel should be implemented with at least 2 devices. - if num <= 1: - return False - return math.floor(num**(1. / 3.))**3 == num - - config_list = [] - - # add non-parallel config - config = dict(parallel=dict(tensor=dict(size=device_cnt, mode=None))) - config_list.append(config) - - # add 1D config - config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='1d'))) - config_list.append(config) - - # add 2D config only if device_cnt is a square - if _is_square(device_cnt): - config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='2d'))) - config_list.append(config) - - # check for 2.5D - # iterate over depth - for depth in range(1, device_cnt): - if device_cnt % depth == 0 and _is_square(device_cnt // depth): - config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='2.5d', depth=depth))) - config_list.append(config) - - # check for 3D if device_cnt is a cube - if _is_cube(device_cnt): - config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='3d'))) - config_list.append(config) - - config_list = [Config(cfg) for cfg in config_list] - return config_list - - -def profile_model(model: torch.nn.Module, warmup_steps: int, profile_steps: int, data_func: Callable, - timer: MultiTimer) -> Tuple[float]: - """ - Profile the forward and backward of a model - - Args: - model (torch.nn.Module): a PyTorch model - warmup_steps (int): the number of steps for warmup - profile_steps (int): the number of steps for profiling - data_func (Callable): a function to generate random data - timer (colossalai.utils.Multitimer): a timer instance for time recording - - Returns: - fwd_time (float): the average forward time taken by forward pass in second - bwd_time (float): the average backward time taken by forward pass in second - max_allocated (float): the maximum GPU memory allocated in GB - max_cached (float): the maximum GPU memory cached in GB - """ - - def _run_step(data): - timer.start('forward') - out = model(data) - timer.stop('forward', keep_in_history=True) - timer.start('backward') - out.mean().backward() - timer.stop('backward', keep_in_history=True) - - data_list = [data_func() for _ in range(warmup_steps)] - for data in data_list: - _run_step(data) - timer.reset('forward') - timer.reset('backward') - - for _ in range(profile_steps): - data = data_func() - _run_step(data) - - max_allocated, max_cached = get_memory_states() - fwd_time = timer.get_timer('forward').get_history_mean() - bwd_time = timer.get_timer('backward').get_history_mean() - return fwd_time, bwd_time, max_allocated, max_cached - - -def get_batch_data(dim: int, batch_size: int, seq_length: int, mode: ParallelMode) -> torch.Tensor: - """ - Return a random data of shape (batch_size, seq_length, dim) for profiling. 
- - Args: - dim (int): hidden size - batch_size (int): the number of data samples - seq_length (int): the number of tokens - mode (ParallelMode): Colossal-AI ParallelMode enum - - Returns: - data (torch.Tensor): random data - """ - - if mode in ['2d', '2.5d']: - batch_size = batch_size // 2 - dim = dim // 2 - elif mode == '3d': - batch_size = batch_size // 4 - dim = dim // 2 - - data = torch.rand(batch_size, seq_length, dim).cuda() - return data diff --git a/colossalai/cli/cli.py b/colossalai/cli/cli.py index a94e1150e49f..0dea7c504957 100644 --- a/colossalai/cli/cli.py +++ b/colossalai/cli/cli.py @@ -1,6 +1,5 @@ import click -from .benchmark import benchmark from .check import check from .launcher import run @@ -19,7 +18,6 @@ def cli(): cli.add_command(run) cli.add_command(check) -cli.add_command(benchmark) if __name__ == '__main__': cli() diff --git a/colossalai/context/__init__.py b/colossalai/context/__init__.py index 50178b5fa850..eb6d5d05a008 100644 --- a/colossalai/context/__init__.py +++ b/colossalai/context/__init__.py @@ -1,6 +1,8 @@ from .config import Config, ConfigException -from .parallel_context import ParallelContext -from .parallel_mode import ParallelMode -from .moe_context import MOE_CONTEXT -from .process_group_initializer import * -from .random import * + +# from .moe_context import MOE_CONTEXT + +__all__ = [ + 'Config', + 'ConfigException', +] diff --git a/colossalai/context/moe_context.py b/colossalai/context/moe_context.py index b41f4072a405..b6e3b52017b2 100644 --- a/colossalai/context/moe_context.py +++ b/colossalai/context/moe_context.py @@ -3,13 +3,12 @@ import torch import torch.distributed as dist -from colossalai.context.parallel_mode import ParallelMode from colossalai.context.singleton_meta import SingletonMeta -from colossalai.tensor import ProcessGroup +from colossalai.legacy.tensor import ProcessGroup def _check_sanity(): - from colossalai.core import global_context as gpc + from colossalai.legacy.core import global_context as gpc if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1: raise NotImplementedError("Moe is not compatible with tensor or " "pipeline parallel at present.") @@ -61,7 +60,7 @@ def setup(self, seed: int, use_kernel_optim: bool = True): self.world_size = dist.get_world_size() - from colossalai.core import global_context as gpc + from colossalai.legacy.core import global_context as gpc self.max_ep_size = gpc.config.get('max_ep_size', self.world_size) assert self.world_size % self.max_ep_size == 0, \ "Maximum expert parallel size must be a factor of the number of GPUs" diff --git a/colossalai/core.py b/colossalai/core.py deleted file mode 100644 index 153247bbed9c..000000000000 --- a/colossalai/core.py +++ /dev/null @@ -1,6 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from colossalai.context.parallel_context import global_context - -__all__ = ['global_context'] \ No newline at end of file diff --git a/colossalai/fx/passes/shard_1d_pass.py b/colossalai/fx/passes/shard_1d_pass.py index d2bad06bb45a..ccbab0c38a29 100644 --- a/colossalai/fx/passes/shard_1d_pass.py +++ b/colossalai/fx/passes/shard_1d_pass.py @@ -1,9 +1,11 @@ +import operator + import torch import torch.nn as nn -import operator -from colossalai.tensor import ProcessGroup -from colossalai.tensor.distspec import ShardSpec -from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec + +from colossalai.legacy.tensor import ProcessGroup +from colossalai.legacy.tensor.compute_spec import ComputePattern, ComputeSpec +from 
colossalai.legacy.tensor.distspec import ShardSpec ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU] ELEMENTWISE_FUNC_OP = [ @@ -13,7 +15,7 @@ def weight_split(weight: torch.nn.parameter.Parameter, dim: int, col_normal: bool) -> torch.nn.parameter.Parameter: - """weight_split + """weight_split split a nn.Parameter Args: @@ -60,9 +62,9 @@ def row_shard_linear_pass(gm: torch.fx.GraphModule): def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: ProcessGroup): """ - This IR pass checks for transformer MLP like structure and annotate column and row sharding to the linear layers. + This IR pass checks for transformer MLP like structure and annotate column and row sharding to the linear layers. """ - #TODO: Needs to handle special cases, like x = linear(x) + linear(x) + # TODO: Needs to handle special cases, like x = linear(x) + linear(x) graph = graph_module.graph world_size = process_group.world_size() diff --git a/colossalai/initialize.py b/colossalai/initialize.py index a1694e059fb4..b8718abc80bd 100644 --- a/colossalai/initialize.py +++ b/colossalai/initialize.py @@ -1,58 +1,17 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -import argparse import os -import pprint +import warnings from pathlib import Path -from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Dict, Union import torch -import torch.nn as nn -from torch.nn.modules.loss import _Loss -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim.lr_scheduler import _LRScheduler -from torch.optim.optimizer import Optimizer -from torch.utils.data import DataLoader +import torch.distributed as dist -from colossalai.amp import AMP_TYPE, convert_to_amp -from colossalai.amp.naive_amp import NaiveAMPModel -from colossalai.context import Config, ConfigException, ParallelMode -from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.core import global_context as gpc -from colossalai.legacy.builder.builder import build_gradient_handler -from colossalai.legacy.engine import Engine -from colossalai.legacy.engine.gradient_accumulation import accumulate_gradient -from colossalai.legacy.engine.schedule import ( - InterleavedPipelineSchedule, - NonPipelineSchedule, - PipelineSchedule, - get_tensor_shape, -) +from colossalai.context import Config from colossalai.logging import get_dist_logger -from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer -from colossalai.utils import get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param -from colossalai.utils.moe import sync_moe_model_param -from colossalai.zero.legacy import ShardedOptimizerV2, convert_to_zero_v2 -from colossalai.zero.legacy.gemini.ophooks import BaseOpHook - - -def get_default_parser(): - """Reads user command line and uses an argument parser to parse the input arguments. - Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed. - - Returns: - Namespace: Returns the parser with the default arguments, the user may add customized arguments into this parser. 
- """ - parser = argparse.ArgumentParser() - parser.add_argument('--config', type=str, help='path to the config file') - parser.add_argument('--host', type=str, help='the master address for distributed training') - parser.add_argument('--port', type=int, help='the master port for distributed training') - parser.add_argument('--world_size', type=int, help='world size for distributed training') - parser.add_argument('--rank', type=int, help='rank for the default process group') - parser.add_argument('--local_rank', type=int, help='local rank on the node') - parser.add_argument('--backend', type=str, default='nccl', help='backend for distributed communication') - return parser +from colossalai.utils import set_device, set_seed def launch(config: Union[str, Path, Config, Dict], @@ -83,40 +42,23 @@ def launch(config: Union[str, Path, Config, Dict], Raises: Exception: Raise exception when config type is wrong """ - gpc.verbose = verbose - - # set config - assert isinstance(config, (Config, str, Path, dict)), \ - f'expected argument config to be Config, str or Path, but got {type(config)}' - if not isinstance(config, Config) and isinstance(config, dict): - config = Config(config) - if isinstance(config, (str, Path)): - config = Config.from_file(config) - gpc.load_config(config) + if rank == 0: + warnings.warn("`config` is deprecated and will be removed soon.") # init default process group - gpc.init_global_dist(rank, world_size, backend, host, port) - - # init process groups for different parallel modes from config - gpc.init_parallel_groups() + init_method = f'tcp://[{host}]:{port}' + dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method) # set cuda device if torch.cuda.is_available(): # if local rank is not given, calculate automatically - gpc.set_device(local_rank) - - # set the number of processes running on the same node - gpc.detect_num_processes_on_current_node() + set_device(local_rank) - gpc.set_seed(seed) + set_seed(seed) if verbose: logger = get_dist_logger() - logger.info( - f'Distributed environment is initialized, ' - f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, ' - f'tensor parallel size: {gpc.tensor_parallel_size}', - ranks=[0]) + logger.info(f'Distributed environment is initialized, world size: {dist.get_world_size()}', ranks=[0]) def launch_from_slurm(config: Union[str, Path, Config, Dict], @@ -224,247 +166,3 @@ def launch_from_torch(config: Union[str, Path, Config, Dict], backend=backend, seed=seed, verbose=verbose) - - -def initialize(model: nn.Module, - optimizer: Optimizer, - criterion: Optional[_Loss] = None, - train_dataloader: Optional[Iterable] = None, - test_dataloader: Optional[Iterable] = None, - lr_scheduler: Optional[_LRScheduler] = None, - ophooks: Optional[List[BaseOpHook]] = None, - verbose: bool = True) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]: - """Core function to wrap the essential training components with our functionality based on the config which is - loaded into gpc.config. - - Args: - model (:class:`torch.nn.Module` or Callable): Your model instance or a function to build the model. - optimizer (:class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.optimizer]`): - Your optimizer instance. - criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance. - train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training. 
- test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing. - lr_scheduler (:class:`torch.nn.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance, optional. - verbose (bool, optional): Whether to print logs. - - Returns: - Tuple (engine, train_dataloader, test_dataloader, lr_scheduler): - A tuple of ``(engine, train_dataloader, test_dataloader, lr_scheduler)`` - where only ``engine`` could not be None. - """ - # get logger - logger = get_dist_logger() - gpc.verbose = verbose - - # get config from gpc - config = gpc.config - - # print config - if verbose: - logger.info( - f"\n========== Your Config ========\n" - f"{pprint.pformat(gpc.config)}\n" - f"================================\n", - ranks=[0]) - - # cudnn - cudnn_benchmark = config.get('cudnn_benchmark', False) - cudnn_deterministic = config.get('cudnn_deterministic', False) - torch.backends.cudnn.benchmark = cudnn_benchmark - torch.backends.cudnn.deterministic = cudnn_deterministic - if verbose: - logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0]) - - # zero - use_zero = hasattr(gpc.config, 'zero') - if use_zero: - zero_cfg = gpc.config.get('zero', None) - if zero_cfg is not None: - cfg_ = zero_cfg.copy() - else: - cfg_ = {} - optimizer_config = zero_cfg.get('optimizer_config', None) - model_config = zero_cfg.get('model_config', None) - model, optimizer = convert_to_zero_v2(model, - optimizer, - model_config=model_config, - optimizer_config=optimizer_config) - - logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0]) - else: - if isinstance(model, nn.Module): - # first sync model across dp ranks - model.to(get_current_device()) - elif isinstance(model, Callable): - model = model().to(get_current_device()) - - # optimizer maybe a optimizer_cls - if isinstance(optimizer, Callable): - optimizer = optimizer(model.parameters()) - logger.warning("Initializing an non ZeRO model with optimizer class") - - if not use_zero: - if is_using_sequence(): - sync_model_param(model, ParallelMode.SEQUENCE_DP) - elif MOE_CONTEXT.is_initialized: - sync_moe_model_param(model) - elif is_using_ddp(): - sync_model_param(model, ParallelMode.DATA) - else: - logger.warning( - "The parameters of models is not automatically synchronized.\n" - "Please make sure that all parameters are the same in data parallel group.", - ranks=[0]) - - # check amp and zero - fp16_cfg = gpc.config.get('fp16', None) - - if fp16_cfg is not None and fp16_cfg.mode is not None and use_zero: - raise ConfigException( - "It is not allowed to set fp16 and zero configuration in your config file at the same time") - - # clip grad norm - clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0) - - # initialize amp - amp_mode = None - if fp16_cfg is not None and fp16_cfg.mode is not None: - cfg_ = fp16_cfg.copy() - amp_mode = cfg_.pop('mode') - if is_using_pp(): - assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline only support NaiveAMP currently' - if amp_mode == AMP_TYPE.NAIVE: - cfg_['clip_grad_norm'] = clip_grad_norm - model, optimizer, criterion = convert_to_amp(model=model, - optimizer=optimizer, - criterion=criterion, - mode=amp_mode, - amp_config=cfg_) - - # get torch ddp config - torch_ddp_cfg = gpc.config.get('torch_ddp', dict()) - - # gradient handler - gradient_handler_cfg = gpc.config.get('gradient_handler', None) - if gradient_handler_cfg is None: - # if gradient handler is not specified in the configuration file, - # check in the following order - # 1. 
if optimizer is ZERO, then use zero grad handler - # 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp - # 3. if using pipeline and dp size larger than 1, use data parallel grad handler - if isinstance(optimizer, ShardedOptimizerV2): - gradient_handler_cfg = [dict(type='ZeROGradientHandler')] - if verbose: - logger.info( - "Training with zero is detected, ZeROGradientHandler is automatically " - "added even though not specified in the configuration", - ranks=[0]) - elif is_using_ddp() and MOE_CONTEXT.is_initialized: - gradient_handler_cfg = [dict(type='MoeGradientHandler')] - if verbose: - logger.info( - "Data parallel training is detected with moe parallel, MoeGradientHandler is automatically " - "added even though not specified in the configuration", - ranks=[0]) - elif is_using_sequence(): - model = DDP(model, - process_group=gpc.get_group(ParallelMode.SEQUENCE_DP), - device_ids=[torch.cuda.current_device()], - **torch_ddp_cfg) - if verbose: - logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism', - ranks=[0]) - elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE: - model = DDP(model, - process_group=gpc.get_group(ParallelMode.DATA), - device_ids=[torch.cuda.current_device()], - **torch_ddp_cfg) - if verbose: - logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0]) - elif is_using_ddp(): - gradient_handler_cfg = [dict(type='DataParallelGradientHandler')] - if verbose: - logger.info( - "Data parallel training is detected when using pipeline parallel, " - "DataParallelGradientHandler is automatically " - "added even though not specified in the configuration", - ranks=[0]) - # add pipeline parallel gradient handler, if pipeline shared module is detected - for param in model.parameters(): - if getattr(param, 'pipeline_shared_module_pg', None) is not None: - if gradient_handler_cfg is None: - gradient_handler_cfg = [dict(type='PipelineSharedModuleGradientHandler')] - else: - gradient_handler_cfg.append(dict(type='PipelineSharedModuleGradientHandler')) - if verbose: - logger.info( - "pipeline_shared_module is detected, PipelineSharedModuleGradientHandler is automatically " - "added even though not specified in the configuration", - ranks=[0]) - break - else: - if not isinstance(gradient_handler_cfg, list): - raise ConfigException( - f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}" - ) - - # turn off sync buffer for NaiveAMPModel if using torch DDP and NaiveAMPModel at the same time - # to avoid duplicated buffer synchronization - if isinstance(model, DDP) and isinstance(model.module, NaiveAMPModel): - model.module.sync_buffer = False - - # initialize schedule for engine - if is_using_pp(): - tensor_shape = get_tensor_shape() - use_interleaved = hasattr(gpc.config, 'model') and hasattr(gpc.config.model, 'num_chunks') - if gpc.is_initialized(ParallelMode.PARALLEL_1D): - scatter_gather = True - else: - scatter_gather = False - if use_interleaved: - if isinstance(model, nn.Sequential): - model = nn.ModuleList([model]) - schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES, - gpc.config.model.num_chunks, - tensor_shape=tensor_shape, - scatter_gather_tensors=scatter_gather) - else: - schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES, - tensor_shape=tensor_shape, - scatter_gather_tensors=scatter_gather) - else: - schedule = NonPipelineSchedule() - - if gradient_handler_cfg is None: - 
gradient_handlers = None - if verbose and not isinstance(model, DDP): - logger.warning( - "No PyTorch DDP or gradient handler is set up, please make sure you do not need " - "to all-reduce the gradients after a training step.", - ranks=[0]) - else: - gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg] - - # check if optimizer is ColossalaiOptimizer - if not isinstance(optimizer, (ColossalaiOptimizer, ShardedOptimizerV2)): - optimizer = ColossalaiOptimizer(optim=optimizer) - - # gradient accumulation - grad_accum_size = gpc.config.get('gradient_accumulation', None) - if grad_accum_size is not None: - optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient( - model=model, - optimizer=optimizer, - dataloader=train_dataloader, - accumulate_size=grad_accum_size, - gradient_handlers=gradient_handlers, - lr_scheduler=lr_scheduler) - engine = Engine(model=model, - optimizer=optimizer, - criterion=criterion, - gradient_handlers=gradient_handlers, - clip_grad_norm=clip_grad_norm, - ophook_list=ophooks, - schedule=schedule) - - return engine, train_dataloader, test_dataloader, lr_scheduler diff --git a/colossalai/legacy/__init__.py b/colossalai/legacy/__init__.py index e69de29bb2d1..f51941ee800b 100644 --- a/colossalai/legacy/__init__.py +++ b/colossalai/legacy/__init__.py @@ -0,0 +1,9 @@ +from .initialize import initialize, launch, launch_from_openmpi, launch_from_slurm, launch_from_torch + +__all__ = [ + 'launch', + 'launch_from_openmpi', + 'launch_from_slurm', + 'launch_from_torch', + 'initialize', +] diff --git a/colossalai/legacy/amp/__init__.py b/colossalai/legacy/amp/__init__.py new file mode 100644 index 000000000000..e83a7f6ac5cd --- /dev/null +++ b/colossalai/legacy/amp/__init__.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch.nn as nn +from torch.nn.modules.loss import _Loss +from torch.optim import Optimizer + +from colossalai.context import Config + +from .amp_type import AMP_TYPE +from .apex_amp import convert_to_apex_amp +from .naive_amp import convert_to_naive_amp +from .torch_amp import convert_to_torch_amp + +__all__ = ['convert_to_amp', 'convert_to_naive_amp', 'convert_to_apex_amp', 'convert_to_torch_amp', 'AMP_TYPE'] + + +def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mode: AMP_TYPE, amp_config: Config = None): + """A helper function to wrap training components with Torch AMP modules. + + Args: + param model (:class:`torch.nn.Module`): your model object. + optimizer (:class:`torch.optim.Optimizer`): your optimizer object. + criterion (:class:`torch.nn.modules.loss._Loss`): your loss function object. + mode (:class:`colossalai.legacy.amp.AMP_TYPE`): amp mode. + amp_config (Union[:class:`colossalai.context.Config`, dict]): configuration for different amp modes. + + Returns: + A tuple (model, optimizer, criterion). + + Note: + ``amp_config`` may vary from different mode you choose. You should check the corresponding amp mode + for more details about ``amp_config``. + For ``apex_amp``, please check + `apex_amp config `_. + For ``naive_amp``, please check + `naive_amp config `_. + For ``torch_amp``, please check + `torch_amp config `_. 
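# Editor's illustrative sketch (not part of this patch): the AMP helpers now live under
# colossalai.legacy.amp; a plain torch model/optimizer/criterion can be wrapped as below.
# amp_config keys depend on the chosen mode (apex/naive/torch) as noted in the docstring above.
import torch
from colossalai.legacy.amp import AMP_TYPE, convert_to_amp

model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()
model, optimizer, criterion = convert_to_amp(model, optimizer, criterion, mode=AMP_TYPE.TORCH)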
+ """ + assert isinstance(mode, AMP_TYPE), \ + f'expected the argument mode be AMP_TYPE, but got {type(mode)}' + + if amp_config is None: + amp_config = Config() + + if mode == AMP_TYPE.TORCH: + model, optimizer, criterion = convert_to_torch_amp(model, optimizer, criterion, amp_config) + elif mode == AMP_TYPE.APEX: + model, optimizer = convert_to_apex_amp(model, optimizer, amp_config) + elif mode == AMP_TYPE.NAIVE: + model, optimizer = convert_to_naive_amp(model, optimizer, amp_config) + + return model, optimizer, criterion diff --git a/colossalai/amp/amp_type.py b/colossalai/legacy/amp/amp_type.py similarity index 100% rename from colossalai/amp/amp_type.py rename to colossalai/legacy/amp/amp_type.py diff --git a/colossalai/amp/apex_amp/__init__.py b/colossalai/legacy/amp/apex_amp/__init__.py similarity index 100% rename from colossalai/amp/apex_amp/__init__.py rename to colossalai/legacy/amp/apex_amp/__init__.py diff --git a/colossalai/amp/apex_amp/apex_amp.py b/colossalai/legacy/amp/apex_amp/apex_amp.py similarity index 86% rename from colossalai/amp/apex_amp/apex_amp.py rename to colossalai/legacy/amp/apex_amp/apex_amp.py index e6bdbe4520f9..acc051181562 100644 --- a/colossalai/amp/apex_amp/apex_amp.py +++ b/colossalai/legacy/amp/apex_amp/apex_amp.py @@ -10,11 +10,11 @@ from torch import Tensor -from colossalai.nn.optimizer import ColossalaiOptimizer -from colossalai.utils import clip_grad_norm_fp32 +from colossalai.interface import OptimizerWrapper +from colossalai.legacy.utils import clip_grad_norm_fp32 -class ApexAMPOptimizer(ColossalaiOptimizer): +class ApexAMPOptimizer(OptimizerWrapper): """ A wrapper class for APEX optimizer and it implements apex-specific backward and clip_grad_norm methods """ diff --git a/colossalai/legacy/amp/naive_amp/__init__.py b/colossalai/legacy/amp/naive_amp/__init__.py new file mode 100644 index 000000000000..2ee84fc763b1 --- /dev/null +++ b/colossalai/legacy/amp/naive_amp/__init__.py @@ -0,0 +1,60 @@ +import inspect + +import torch.nn as nn +from torch.optim import Optimizer + +from colossalai.amp.naive_amp.grad_scaler import ConstantGradScaler, DynamicGradScaler +from colossalai.legacy.utils import is_no_pp_or_last_stage + +from ._fp16_optimizer import FP16Optimizer +from .naive_amp import NaiveAMPModel, NaiveAMPOptimizer + + +def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config): + """A helper function to wrap training components with naive AMP modules. In this mode, + we forcibly cast the model weights and inputs to FP16, and cast the model outputs to FP32 to calculate loss, + which is equivalent to Apex O3. + + Args: + model (:class:`torch.nn.Module`): your model object + optimizer (:class:`torch.optim.Optimizer`): your optimizer object + amp_config (:class:`colossalai.context.Config` or dict): configuration for naive mode amp. + + Returns: + Tuple: A tuple (model, optimizer) + + The ``amp_config`` should contain parameters below:: + + verbose (bool, optional): if set to `True`, will print debug info (Default: False). + clip_grad_norm (float, optional): clip gradients with this global L2 norm (Default 0). + Note that clipping is ignored if clip_grad == 0. + dynamic_grad_scale (bool): whether to use dynamic grad scaler. 
+ """ + if isinstance(model, nn.ModuleList): + # interleaved pipeline + module_list = [] + for chunk, m in enumerate(model): + output_to_fp32 = is_no_pp_or_last_stage() and chunk == len(model) - 1 + module_list.append(NaiveAMPModel(m, output_to_fp32=output_to_fp32)) + model = nn.ModuleList(module_list) + else: + output_to_fp32 = is_no_pp_or_last_stage() + model = NaiveAMPModel(model, output_to_fp32=output_to_fp32) + + use_dynamic_grad_scaler = amp_config.pop('dynamic_grad_scale', True) + if use_dynamic_grad_scaler: + scaler_class = DynamicGradScaler + else: + scaler_class = ConstantGradScaler + + sig = inspect.signature(scaler_class.__init__) + kwargs = dict() + for param in sig.parameters.values(): + if param.name in amp_config: + kwargs[param.name] = amp_config.pop(param.name) + grad_scaler = scaler_class(**kwargs) + optimizer = NaiveAMPOptimizer(optimizer, grad_scaler, **amp_config) + return model, optimizer + + +__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer', 'FP16Optimizer'] diff --git a/colossalai/amp/naive_amp/_fp16_optimizer.py b/colossalai/legacy/amp/naive_amp/_fp16_optimizer.py similarity index 97% rename from colossalai/amp/naive_amp/_fp16_optimizer.py rename to colossalai/legacy/amp/naive_amp/_fp16_optimizer.py index e4699f92b944..2733477599f7 100644 --- a/colossalai/amp/naive_amp/_fp16_optimizer.py +++ b/colossalai/legacy/amp/naive_amp/_fp16_optimizer.py @@ -6,14 +6,15 @@ from torch.distributed import ProcessGroup from torch.optim import Optimizer -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.amp.naive_amp.grad_scaler import BaseGradScaler from colossalai.kernel.op_builder import FusedOptimBuilder +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes from colossalai.logging import get_dist_logger -from colossalai.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes, multi_tensor_applier +from colossalai.utils import multi_tensor_applier from ._utils import has_inf_or_nan, zero_gard_by_list -from .grad_scaler import BaseGradScaler try: from colossalai._C import fused_optim diff --git a/colossalai/amp/naive_amp/_utils.py b/colossalai/legacy/amp/naive_amp/_utils.py similarity index 100% rename from colossalai/amp/naive_amp/_utils.py rename to colossalai/legacy/amp/naive_amp/_utils.py diff --git a/colossalai/amp/naive_amp/naive_amp.py b/colossalai/legacy/amp/naive_amp/naive_amp.py similarity index 94% rename from colossalai/amp/naive_amp/naive_amp.py rename to colossalai/legacy/amp/naive_amp/naive_amp.py index 6a39d518d3f4..1fab3e5a0d0d 100644 --- a/colossalai/amp/naive_amp/naive_amp.py +++ b/colossalai/legacy/amp/naive_amp/naive_amp.py @@ -11,14 +11,14 @@ from torch.distributed import ReduceOp from torch.optim import Optimizer -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.nn.optimizer import ColossalaiOptimizer +from colossalai.interface import OptimizerWrapper +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc from ._fp16_optimizer import FP16Optimizer -class NaiveAMPOptimizer(ColossalaiOptimizer): +class NaiveAMPOptimizer(OptimizerWrapper): """A wrapper class for optimizer to cast all parameters to fp16 Args: @@ -57,7 +57,7 @@ class NaiveAMPModel(nn.Module): Args: model (torch.nn.Module): torch.nn.Module to be wrapped. 
output_to_fp32 (bool, optional): Whether cast output of this module into fp32. (Default: True) - parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this module. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): Parallel group mode used in this module. (Default: ``ParallelMode.DATA``) sync_buffer (bool, optional): whether to synchronize buffer. (Default: True) diff --git a/colossalai/amp/torch_amp/__init__.py b/colossalai/legacy/amp/torch_amp/__init__.py similarity index 100% rename from colossalai/amp/torch_amp/__init__.py rename to colossalai/legacy/amp/torch_amp/__init__.py diff --git a/colossalai/amp/torch_amp/_grad_scaler.py b/colossalai/legacy/amp/torch_amp/_grad_scaler.py similarity index 99% rename from colossalai/amp/torch_amp/_grad_scaler.py rename to colossalai/legacy/amp/torch_amp/_grad_scaler.py index ed4b8e484436..543dac6ab5ef 100644 --- a/colossalai/amp/torch_amp/_grad_scaler.py +++ b/colossalai/legacy/amp/torch_amp/_grad_scaler.py @@ -13,8 +13,8 @@ from packaging import version from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc class _MultiDeviceReplicator(object): diff --git a/colossalai/amp/torch_amp/torch_amp.py b/colossalai/legacy/amp/torch_amp/torch_amp.py similarity index 95% rename from colossalai/amp/torch_amp/torch_amp.py rename to colossalai/legacy/amp/torch_amp/torch_amp.py index 65718d77c2e0..c45a5956a205 100644 --- a/colossalai/amp/torch_amp/torch_amp.py +++ b/colossalai/legacy/amp/torch_amp/torch_amp.py @@ -7,13 +7,13 @@ from torch.nn.modules.loss import _Loss from torch.optim import Optimizer -from colossalai.nn.optimizer import ColossalaiOptimizer -from colossalai.utils import clip_grad_norm_fp32 +from colossalai.interface import OptimizerWrapper +from colossalai.legacy.utils import clip_grad_norm_fp32 from ._grad_scaler import GradScaler -class TorchAMPOptimizer(ColossalaiOptimizer): +class TorchAMPOptimizer(OptimizerWrapper): """A wrapper class which integrate Pytorch AMP with an optimizer Args: diff --git a/colossalai/legacy/communication/collective.py b/colossalai/legacy/communication/collective.py index 64fb5b8b5296..7471188226f0 100644 --- a/colossalai/legacy/communication/collective.py +++ b/colossalai/legacy/communication/collective.py @@ -6,8 +6,8 @@ from torch import Tensor from torch.distributed import ReduceOp -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc _all_gather_func = dist._all_gather_base \ if "all_gather_into_tensor" not in dir(dist) else dist.all_gather_into_tensor @@ -26,7 +26,7 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: Args: tensor (:class:`torch.Tensor`): Tensor to be gathered. dim (int): The dimension concatenating in. - parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): Parallel group mode used in this communication. async_op (bool, optional): Whether operations are asynchronous. Returns: @@ -65,7 +65,7 @@ def reduce_scatter(tensor: Tensor, Args: tensor (:class:`torch.Tensor`): Tensor to be reduce_scattered. 
dim (int): The dimension concatenating in. - parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): Parallel group mode used in this communication. op (torch.distributed.ReduceOp, optional): The type of reduce operation, should be included in [SUM, AVG, PRODUCT, MIN, MAX, BAND, BOR, BXOR]. More details about ReduceOp please refer to @@ -105,7 +105,7 @@ def all_reduce(tensor: Tensor, Args: tensor (:class:`torch.Tensor`): Tensor to be all-reduced. - parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): Parallel group mode used in this communication. op (torch.distributed.ReduceOp, optional): The type of reduce operation, should be included in [SUM, AVG, PRODUCT, MIN, MAX, BAND, BOR, BXOR]. More details about ReduceOp please refer to @@ -141,7 +141,7 @@ def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: b Args: tensor (:class:`torch.Tensor`): Tensor to be broadcast. src (int): Source rank. - parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): Parallel group mode used in this communication. async_op (bool, optional): Whether operations are asynchronous. Returns: @@ -173,7 +173,7 @@ def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp = Args: tensor (:class:`torch.Tensor`): Tensor to be reduced. dst (int): Destination rank. - parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): Parallel group mode used in this communication. async_op (bool, optional): Whether operations are asynchronous. 
Returns: diff --git a/colossalai/legacy/communication/p2p.py b/colossalai/legacy/communication/p2p.py index d28d140168fd..e3f9108ab840 100644 --- a/colossalai/legacy/communication/p2p.py +++ b/colossalai/legacy/communication/p2p.py @@ -8,8 +8,8 @@ import torch import torch.distributed as dist -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.utils import get_current_device from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks diff --git a/colossalai/legacy/communication/p2p_v2.py b/colossalai/legacy/communication/p2p_v2.py index 090311cb35f2..66af214950f2 100644 --- a/colossalai/legacy/communication/p2p_v2.py +++ b/colossalai/legacy/communication/p2p_v2.py @@ -10,8 +10,8 @@ from torch.distributed import ProcessGroupNCCL from torch.distributed import distributed_c10d as c10d -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc TensorShape = Union[torch.Size, List[int], Tuple[int]] _pg_manager = {} diff --git a/colossalai/legacy/communication/ring.py b/colossalai/legacy/communication/ring.py index aece7574b7c4..e80192fb578d 100644 --- a/colossalai/legacy/communication/ring.py +++ b/colossalai/legacy/communication/ring.py @@ -3,8 +3,8 @@ import torch -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.utils import get_current_device, synchronize diff --git a/colossalai/legacy/communication/utils.py b/colossalai/legacy/communication/utils.py index 1516df356278..7e3dcf1e9820 100644 --- a/colossalai/legacy/communication/utils.py +++ b/colossalai/legacy/communication/utils.py @@ -3,8 +3,8 @@ import torch import torch.distributed as dist -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.utils import get_current_device TensorShape = Union[torch.Size, List[int], Tuple[int]] diff --git a/colossalai/constants.py b/colossalai/legacy/constants.py similarity index 100% rename from colossalai/constants.py rename to colossalai/legacy/constants.py diff --git a/colossalai/legacy/context/__init__.py b/colossalai/legacy/context/__init__.py new file mode 100644 index 000000000000..7027945ead7c --- /dev/null +++ b/colossalai/legacy/context/__init__.py @@ -0,0 +1,4 @@ +from .parallel_context import ParallelContext +from .parallel_mode import ParallelMode +from .process_group_initializer import * +from .random import * diff --git a/colossalai/context/parallel_context.py b/colossalai/legacy/context/parallel_context.py similarity index 88% rename from colossalai/context/parallel_context.py rename to colossalai/legacy/context/parallel_context.py index 7186f052ecec..8fdc3d6fea68 100644 --- a/colossalai/context/parallel_context.py +++ b/colossalai/legacy/context/parallel_context.py @@ -11,10 +11,10 @@ import torch import torch.distributed as dist -from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING from 
colossalai.context.config import Config from colossalai.context.singleton_meta import SingletonMeta -from colossalai.global_variables import tensor_parallel_env as env +from colossalai.legacy.constants import ALLOWED_MODES, INITIALIZER_MAPPING +from colossalai.legacy.global_variables import tensor_parallel_env as env from colossalai.legacy.registry import DIST_GROUP_INITIALIZER from colossalai.logging import get_dist_logger @@ -110,12 +110,12 @@ def add_global_rank(self, parallel_mode: ParallelMode, rank: int): """Adds the global rank of the current device for `parallel_mode` to the context. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode for the rank. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The parallel mode for the rank. rank (int): The rank to be added Raises: AssertionError: Raises an AssertionError if `parallel_mode` is not an instance - of :class:`colossalai.context.ParallelMode`. + of :class:`colossalai.legacy.context.ParallelMode`. """ self._check_parallel_mode(parallel_mode) self._global_ranks[parallel_mode] = rank @@ -124,11 +124,11 @@ def get_local_rank(self, parallel_mode: ParallelMode): """Returns the local rank of the current device. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. Raises: AssertionError: Raises an AssertionError if `parallel_mode` is not an instance - of :class:`colossalai.context.ParallelMode`. + of :class:`colossalai.legacy.context.ParallelMode`. Returns: int: The local rank of the current device for `parallel_mode`. @@ -140,12 +140,12 @@ def _add_local_rank(self, parallel_mode: ParallelMode, rank: int): """Adds the local rank of the current device for `parallel_mode` to the context. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode for the rank. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The parallel mode for the rank. rank (int): The rank to be added. Raises: AssertionError: Raises an AssertionError if `parallel_mode` is not an instance - of :class:`colossalai.context.ParallelMode`. + of :class:`colossalai.legacy.context.ParallelMode`. """ self._check_parallel_mode(parallel_mode) self._local_ranks[parallel_mode] = rank @@ -154,11 +154,11 @@ def get_next_global_rank(self, parallel_mode: ParallelMode): """Returns the global rank of the next device. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. Raises: AssertionError: Raises an AssertionError if `parallel_mode` is not an instance - of :class:`colossalai.context.ParallelMode`. + of :class:`colossalai.legacy.context.ParallelMode`. Returns: int: The global rank of the next device for `parallel_mode`. @@ -176,11 +176,11 @@ def get_prev_global_rank(self, parallel_mode: ParallelMode): """Returns the global rank of the previous device. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. Raises: AssertionError: Raises an AssertionError if `parallel_mode` is not an instance - of :class:`colossalai.context.ParallelMode`. + of :class:`colossalai.legacy.context.ParallelMode`. Returns: int: The global rank of the previous device for `parallel_mode`. 
@@ -199,11 +199,11 @@ def is_first_rank(self, parallel_mode: ParallelMode): among its group for `parallel_mode`. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. Raises: AssertionError: Raises an AssertionError if `parallel_mode` is not an instance - of :class:`colossalai.context.ParallelMode`. + of :class:`colossalai.legacy.context.ParallelMode`. Returns: bool: a boolean value indicating whether the current device is the first one @@ -217,11 +217,11 @@ def is_last_rank(self, parallel_mode: ParallelMode): among its group for `parallel_mode`. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. Raises: AssertionError: Raises an AssertionError if `parallel_mode` is not an instance - of :class:`colossalai.context.ParallelMode`. + of :class:`colossalai.legacy.context.ParallelMode`. Returns: bool: a boolean value indicating whether the current device is the first one @@ -248,11 +248,11 @@ def get_world_size(self, parallel_mode: ParallelMode): """Returns the world size for `parallel_mode`. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. Raises: AssertionError: Raises an AssertionError if `parallel_mode` is not an instance - of :class:`colossalai.context.ParallelMode`. + of :class:`colossalai.legacy.context.ParallelMode`. Returns: int: The world size for `parallel_mode`. @@ -264,12 +264,12 @@ def _add_world_size(self, parallel_mode: ParallelMode, world_size: int): """Adds world size for `parallel_mode`. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode corresponding to the process group + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The parallel mode corresponding to the process group world_size (int): The world size to be added Raises: AssertionError: Raises an AssertionError if `parallel_mode` is not an instance - of :class:`colossalai.context.ParallelMode`. + of :class:`colossalai.legacy.context.ParallelMode`. """ self._check_parallel_mode(parallel_mode) self._world_sizes[parallel_mode] = world_size @@ -278,11 +278,11 @@ def get_group(self, parallel_mode: ParallelMode): """Returns the group of the current device for `parallel_mode`. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. Raises: AssertionError: Raises an AssertionError if `parallel_mode` is not an instance - of :class:`colossalai.context.ParallelMode`. + of :class:`colossalai.legacy.context.ParallelMode`. Returns: torch.distributed.ProcessGroup: The group of the current device for `parallel_mode`. @@ -294,12 +294,12 @@ def _add_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup): """Adds the group of the current device for `parallel_mode`. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. group (torch.distributed.ProcessGroup): The group to be added Raises: AssertionError: Raises an AssertionError if `parallel_mode` is not an instance - of :class:`colossalai.context.ParallelMode`. 
+ of :class:`colossalai.legacy.context.ParallelMode`. """ self._check_parallel_mode(parallel_mode) self._groups[parallel_mode] = group @@ -308,9 +308,9 @@ def get_cpu_group(self, parallel_mode: ParallelMode): """Returns the Gloo group of the current device for `parallel_mode`. :param parallel_mode: The chosen parallel mode - :type parallel_mode: :class:`colossalai.context.ParallelMode` + :type parallel_mode: :class:`colossalai.legacy.context.ParallelMode` :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance - of :class:`colossalai.context.ParallelMode` + of :class:`colossalai.legacy.context.ParallelMode` :return: The group of the current device for `parallel_mode` :rtype: torch.distributed.ProcessGroup """ @@ -321,11 +321,11 @@ def _add_cpu_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup): """Adds the Gloo group of the current device for `parallel_mode`. :param parallel_mode: The chosen parallel mode - :type parallel_mode: :class:`colossalai.context.ParallelMode` + :type parallel_mode: :class:`colossalai.legacy.context.ParallelMode` :param group: The group to be added :type group: torch.distributed.ProcessGroup :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance - of :class:`colossalai.context.ParallelMode` + of :class:`colossalai.legacy.context.ParallelMode` """ self._check_parallel_mode(parallel_mode) self._cpu_groups[parallel_mode] = group @@ -334,11 +334,11 @@ def get_ranks_in_group(self, parallel_mode: ParallelMode): """Returns the rank of the current device for `parallel_mode` in the group. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. Raises: AssertionError: Raises an AssertionError if `parallel_mode` is not an instance - of :class:`colossalai.context.ParallelMode`. + of :class:`colossalai.legacy.context.ParallelMode`. Returns: int: The rank of the current device for `parallel_mode` in the group. @@ -350,12 +350,12 @@ def _add_ranks_in_group(self, parallel_mode: ParallelMode, ranks: list): """Adds the ranks of the current device for `parallel_mode` in the group. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. ranks (list): List of ranks to be added Raises: AssertionError: Raises an AssertionError if `parallel_mode` is not an instance - of :class:`colossalai.context.ParallelMode`. + of :class:`colossalai.legacy.context.ParallelMode`. """ self._check_parallel_mode(parallel_mode) self._ranks_in_group[parallel_mode] = ranks @@ -489,7 +489,7 @@ def is_initialized(self, parallel_mode: ParallelMode): in the current system. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. Returns: bool: a boolean value indicating whether `parallel_mode` is initialized in the current system. 
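For orientation, the accessors documented above keep their original query interface after the move to `colossalai.legacy`; a minimal sketch of calling code (illustrative, assuming the context has already been initialized via `launch`) looks like:

```python
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc

if gpc.is_initialized(ParallelMode.DATA):
    dp_rank = gpc.get_local_rank(ParallelMode.DATA)       # rank inside the data-parallel group
    dp_world_size = gpc.get_world_size(ParallelMode.DATA)
    dp_group = gpc.get_group(ParallelMode.DATA)            # torch.distributed ProcessGroup
```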
diff --git a/colossalai/context/parallel_mode.py b/colossalai/legacy/context/parallel_mode.py similarity index 100% rename from colossalai/context/parallel_mode.py rename to colossalai/legacy/context/parallel_mode.py diff --git a/colossalai/context/process_group_initializer/__init__.py b/colossalai/legacy/context/process_group_initializer/__init__.py similarity index 100% rename from colossalai/context/process_group_initializer/__init__.py rename to colossalai/legacy/context/process_group_initializer/__init__.py index d3937a947437..48d52d7b9e52 100644 --- a/colossalai/context/process_group_initializer/__init__.py +++ b/colossalai/legacy/context/process_group_initializer/__init__.py @@ -3,10 +3,10 @@ from .initializer_2p5d import Initializer_2p5D from .initializer_3d import Initializer_3D from .initializer_data import Initializer_Data +from .initializer_model import Initializer_Model from .initializer_pipeline import Initializer_Pipeline from .initializer_sequence import Initializer_Sequence from .initializer_tensor import Initializer_Tensor -from .initializer_model import Initializer_Model from .process_group_initializer import ProcessGroupInitializer __all__ = [ diff --git a/colossalai/context/process_group_initializer/initializer_1d.py b/colossalai/legacy/context/process_group_initializer/initializer_1d.py similarity index 96% rename from colossalai/context/process_group_initializer/initializer_1d.py rename to colossalai/legacy/context/process_group_initializer/initializer_1d.py index ba601d0bf61a..d853c6f06fc0 100644 --- a/colossalai/context/process_group_initializer/initializer_1d.py +++ b/colossalai/legacy/context/process_group_initializer/initializer_1d.py @@ -3,7 +3,7 @@ import torch.distributed as dist -from colossalai.global_variables import tensor_parallel_env as env +from colossalai.legacy.global_variables import tensor_parallel_env as env from colossalai.legacy.registry import DIST_GROUP_INITIALIZER from ..parallel_mode import ParallelMode diff --git a/colossalai/context/process_group_initializer/initializer_2d.py b/colossalai/legacy/context/process_group_initializer/initializer_2d.py similarity index 98% rename from colossalai/context/process_group_initializer/initializer_2d.py rename to colossalai/legacy/context/process_group_initializer/initializer_2d.py index 999cd5f0cfc6..39f6a46890b6 100644 --- a/colossalai/context/process_group_initializer/initializer_2d.py +++ b/colossalai/legacy/context/process_group_initializer/initializer_2d.py @@ -2,7 +2,7 @@ import torch.distributed as dist -from colossalai.global_variables import tensor_parallel_env as env +from colossalai.legacy.global_variables import tensor_parallel_env as env from colossalai.legacy.registry import DIST_GROUP_INITIALIZER from ..parallel_mode import ParallelMode diff --git a/colossalai/context/process_group_initializer/initializer_2p5d.py b/colossalai/legacy/context/process_group_initializer/initializer_2p5d.py similarity index 99% rename from colossalai/context/process_group_initializer/initializer_2p5d.py rename to colossalai/legacy/context/process_group_initializer/initializer_2p5d.py index b92ae2eec07e..bb7a3509572f 100644 --- a/colossalai/context/process_group_initializer/initializer_2p5d.py +++ b/colossalai/legacy/context/process_group_initializer/initializer_2p5d.py @@ -6,7 +6,7 @@ import torch.distributed as dist from colossalai.context import Config -from colossalai.global_variables import tensor_parallel_env as env +from colossalai.legacy.global_variables import tensor_parallel_env as env from 
colossalai.legacy.registry import DIST_GROUP_INITIALIZER from ..parallel_mode import ParallelMode diff --git a/colossalai/context/process_group_initializer/initializer_3d.py b/colossalai/legacy/context/process_group_initializer/initializer_3d.py similarity index 99% rename from colossalai/context/process_group_initializer/initializer_3d.py rename to colossalai/legacy/context/process_group_initializer/initializer_3d.py index 6bca05ad7d5f..3dfbf5223b12 100644 --- a/colossalai/context/process_group_initializer/initializer_3d.py +++ b/colossalai/legacy/context/process_group_initializer/initializer_3d.py @@ -5,7 +5,7 @@ import torch.distributed as dist -from colossalai.global_variables import tensor_parallel_env as env +from colossalai.legacy.global_variables import tensor_parallel_env as env from colossalai.legacy.registry import DIST_GROUP_INITIALIZER from ..parallel_mode import ParallelMode diff --git a/colossalai/context/process_group_initializer/initializer_data.py b/colossalai/legacy/context/process_group_initializer/initializer_data.py similarity index 100% rename from colossalai/context/process_group_initializer/initializer_data.py rename to colossalai/legacy/context/process_group_initializer/initializer_data.py diff --git a/colossalai/context/process_group_initializer/initializer_model.py b/colossalai/legacy/context/process_group_initializer/initializer_model.py similarity index 100% rename from colossalai/context/process_group_initializer/initializer_model.py rename to colossalai/legacy/context/process_group_initializer/initializer_model.py diff --git a/colossalai/context/process_group_initializer/initializer_pipeline.py b/colossalai/legacy/context/process_group_initializer/initializer_pipeline.py similarity index 100% rename from colossalai/context/process_group_initializer/initializer_pipeline.py rename to colossalai/legacy/context/process_group_initializer/initializer_pipeline.py diff --git a/colossalai/context/process_group_initializer/initializer_sequence.py b/colossalai/legacy/context/process_group_initializer/initializer_sequence.py similarity index 100% rename from colossalai/context/process_group_initializer/initializer_sequence.py rename to colossalai/legacy/context/process_group_initializer/initializer_sequence.py diff --git a/colossalai/context/process_group_initializer/initializer_tensor.py b/colossalai/legacy/context/process_group_initializer/initializer_tensor.py similarity index 100% rename from colossalai/context/process_group_initializer/initializer_tensor.py rename to colossalai/legacy/context/process_group_initializer/initializer_tensor.py diff --git a/colossalai/context/process_group_initializer/process_group_initializer.py b/colossalai/legacy/context/process_group_initializer/process_group_initializer.py similarity index 100% rename from colossalai/context/process_group_initializer/process_group_initializer.py rename to colossalai/legacy/context/process_group_initializer/process_group_initializer.py diff --git a/colossalai/context/random/__init__.py b/colossalai/legacy/context/random/__init__.py similarity index 100% rename from colossalai/context/random/__init__.py rename to colossalai/legacy/context/random/__init__.py diff --git a/colossalai/context/random/_helper.py b/colossalai/legacy/context/random/_helper.py similarity index 90% rename from colossalai/context/random/_helper.py rename to colossalai/legacy/context/random/_helper.py index 973c4d9faa32..4b5d5ef2fe55 100644 --- a/colossalai/context/random/_helper.py +++ 
b/colossalai/legacy/context/random/_helper.py @@ -7,8 +7,8 @@ import torch.cuda from torch import Tensor -from .seed_manager import SeedManager from ..parallel_mode import ParallelMode +from .seed_manager import SeedManager _SEED_MANAGER = SeedManager() @@ -53,11 +53,11 @@ def add_seed(parallel_mode: ParallelMode, seed: int, overwrite: bool = False): """Adds a seed to the seed manager for `parallel_mode`. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. seed (int): The seed to be added Raises: AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of - :class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added. + :class:`colossalai.legacy.context.ParallelMode` or the seed for `parallel_mode` has been added. Note: The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found @@ -70,7 +70,7 @@ def set_mode(parallel_mode: ParallelMode): """Sets the current mode of the seed manager. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. Note: The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found @@ -83,7 +83,7 @@ def set_seed_states(parallel_mode: ParallelMode, state: Tensor): """Sets the state of the seed manager for `parallel_mode`. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. state (:class:`torch.Tensor`): the state to be set. Raises: @@ -161,7 +161,7 @@ def wrapper(*args, **kwargs): def moe_set_seed(seed): if torch.cuda.is_available(): - from colossalai.core import global_context as gpc + from colossalai.legacy.core import global_context as gpc global_rank = gpc.get_global_rank() diff_seed = seed + global_rank add_seed(ParallelMode.TENSOR, diff_seed, True) diff --git a/colossalai/context/random/seed_manager.py b/colossalai/legacy/context/random/seed_manager.py similarity index 86% rename from colossalai/context/random/seed_manager.py rename to colossalai/legacy/context/random/seed_manager.py index 956f9001200d..b657ff7e1d32 100644 --- a/colossalai/context/random/seed_manager.py +++ b/colossalai/legacy/context/random/seed_manager.py @@ -4,7 +4,7 @@ import torch from torch import Tensor -from colossalai.context.parallel_mode import ParallelMode +from colossalai.legacy.context.parallel_mode import ParallelMode class SeedManager: @@ -36,7 +36,7 @@ def set_state(self, parallel_mode: ParallelMode, state: Tensor): """Sets the state of the seed manager for `parallel_mode`. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. state (:class:`torch.Tensor`): the state to be set. Raises: @@ -49,7 +49,7 @@ def set_mode(self, parallel_mode: ParallelMode): """Sets the current mode of the seed manager. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. 
""" if self.current_mode: # save the current state for current mode @@ -63,12 +63,12 @@ def add_seed(self, parallel_mode: ParallelMode, seed: int, overwrite: bool = Fal """Adds a seed to the seed manager for `parallel_mode`. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. seed (int): The seed to be added. overwrite (bool, optional): Whether allows to overwrite the seed that has been set already Raises: - AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of :class:`colossalai.context.ParallelMode` + AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of :class:`colossalai.legacy.context.ParallelMode` or the seed for `parallel_mode` has been added. """ assert isinstance(parallel_mode, ParallelMode), 'A valid ParallelMode must be provided' diff --git a/colossalai/legacy/core.py b/colossalai/legacy/core.py new file mode 100644 index 000000000000..0aaf1ee47730 --- /dev/null +++ b/colossalai/legacy/core.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from colossalai.legacy.context.parallel_context import global_context + +__all__ = ['global_context'] diff --git a/colossalai/legacy/engine/_base_engine.py b/colossalai/legacy/engine/_base_engine.py index 9af4469f403f..930caf20c1dd 100644 --- a/colossalai/legacy/engine/_base_engine.py +++ b/colossalai/legacy/engine/_base_engine.py @@ -8,6 +8,7 @@ from torch.nn import Module from torch.nn.modules.loss import _Loss +from colossalai.interface import OptimizerWrapper from colossalai.legacy.engine.gradient_handler import BaseGradientHandler from colossalai.legacy.engine.schedule import ( BaseSchedule, @@ -15,9 +16,8 @@ NonPipelineSchedule, PipelineSchedule, ) +from colossalai.legacy.zero.gemini import BaseOpHook, register_ophooks_recursively from colossalai.logging import get_dist_logger -from colossalai.nn.optimizer import ColossalaiOptimizer -from colossalai.zero.legacy.gemini import BaseOpHook, register_ophooks_recursively class Engine: @@ -27,7 +27,7 @@ class Engine: Args: model (``torch.nn.Module``): The neural network model. - optimizer (``colossalai.nn.optimizer.ColossalaiOptimizer``): Optimizer for updating the parameters. + optimizer (``colossalai.interface.OptimizerWrapper``): Optimizer for updating the parameters. criterion (``torch.nn.modules.loss._Loss``, optional): Loss function for calculating loss. gradient_handlers (List[``BaseGradientHandler``], optional): A list of gradient handler used in backward. clip_grad_norm (float, optional): The norm of gradient clipping. 
@@ -61,7 +61,7 @@ class Engine: def __init__(self, model: Module, - optimizer: "ColossalaiOptimizer", + optimizer: "OptimizerWrapper", criterion: Optional[_Loss] = None, gradient_handlers: Optional[List[BaseGradientHandler]] = None, clip_grad_norm: float = 0.0, @@ -157,7 +157,7 @@ def step(self): """Execute parameter update """ self._all_reduce_gradients() - self.optimizer.clip_grad_norm(self.model, self._clip_grad_norm) + self.optimizer.clip_grad_by_norm(self._clip_grad_norm) return self.optimizer.step() def backward(self, loss: Tensor): diff --git a/colossalai/legacy/engine/gradient_accumulation/_gradient_accumulation.py b/colossalai/legacy/engine/gradient_accumulation/_gradient_accumulation.py index c466f7e2d03b..c2270dc53a50 100644 --- a/colossalai/legacy/engine/gradient_accumulation/_gradient_accumulation.py +++ b/colossalai/legacy/engine/gradient_accumulation/_gradient_accumulation.py @@ -10,12 +10,12 @@ from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader +from colossalai.interface import OptimizerWrapper from colossalai.legacy.engine import BaseGradientHandler -from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.utils import conditional_context -class GradAccumOptimizer(ColossalaiOptimizer): +class GradAccumOptimizer(OptimizerWrapper): """A wrapper for the optimizer to enable gradient accumulation by skipping the steps before accumulation size is reached. @@ -74,7 +74,7 @@ def clip_grad_norm(self, model: nn.Module, max_norm: float) -> None: if self.accumulate_step < self.accumulate_size: pass else: - self.optim.clip_grad_norm(model, max_norm) + self.optim.clip_grad_by_norm(max_norm) def backward(self, loss: Tensor) -> None: """Execute backward pass. diff --git a/colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py index c5da2e55a0ed..c692ee903442 100644 --- a/colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py +++ b/colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py @@ -1,5 +1,5 @@ -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.legacy.registry import GRADIENT_HANDLER from ._base_gradient_handler import BaseGradientHandler diff --git a/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py index 395d83da0478..e7a6df2d8ae8 100644 --- a/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py +++ b/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py @@ -1,6 +1,6 @@ from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.legacy.registry import GRADIENT_HANDLER from colossalai.utils.moe import get_moe_epsize_param_dict diff --git a/colossalai/legacy/engine/gradient_handler/_pipeline_parallel_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_pipeline_parallel_gradient_handler.py index 7d4d9d73afc8..3eae7d58ac95 100644 --- a/colossalai/legacy/engine/gradient_handler/_pipeline_parallel_gradient_handler.py +++ 
b/colossalai/legacy/engine/gradient_handler/_pipeline_parallel_gradient_handler.py @@ -6,7 +6,7 @@ import torch.distributed as dist from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors -from colossalai.core import global_context as gpc +from colossalai.legacy.core import global_context as gpc from colossalai.legacy.registry import GRADIENT_HANDLER from ._base_gradient_handler import BaseGradientHandler diff --git a/colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py index 41098ab39d0c..38b7f5993b73 100644 --- a/colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py +++ b/colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py @@ -1,5 +1,5 @@ -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.legacy.registry import GRADIENT_HANDLER from ._base_gradient_handler import BaseGradientHandler diff --git a/colossalai/legacy/engine/schedule/_pipeline_schedule.py b/colossalai/legacy/engine/schedule/_pipeline_schedule.py index 4571fd679e8c..37eed82f8a28 100644 --- a/colossalai/legacy/engine/schedule/_pipeline_schedule.py +++ b/colossalai/legacy/engine/schedule/_pipeline_schedule.py @@ -7,11 +7,11 @@ import torch.cuda import colossalai.legacy.communication as comm -from colossalai.amp.naive_amp import NaiveAMPModel -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.amp.naive_amp import NaiveAMPModel +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.utils import switch_virtual_pipeline_parallel_rank from colossalai.logging import get_dist_logger -from colossalai.utils import switch_virtual_pipeline_parallel_rank from colossalai.utils.cuda import get_current_device from ._base_schedule import BaseSchedule @@ -157,7 +157,7 @@ def load_micro_batch(self): return self._move_to_device(micro_batch_data) def pre_processing(self, engine): - from colossalai.zero.legacy import ShardedModelV2 + from colossalai.legacy.zero import ShardedModelV2 # TODO: remove this after testing new zero with pipeline parallelism model = engine.model diff --git a/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py b/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py index 385c615372f5..bf8b599a81ae 100644 --- a/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py +++ b/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py @@ -6,8 +6,8 @@ import torch.cuda import colossalai.legacy.communication.p2p_v2 as comm -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.legacy.engine import Engine from colossalai.utils.cuda import get_current_device diff --git a/colossalai/global_variables.py b/colossalai/legacy/global_variables.py similarity index 100% rename from colossalai/global_variables.py rename to colossalai/legacy/global_variables.py diff --git a/colossalai/legacy/initialize.py b/colossalai/legacy/initialize.py new file mode 100644 index 
000000000000..2c253adbaf38 --- /dev/null +++ b/colossalai/legacy/initialize.py @@ -0,0 +1,472 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import argparse +import os +import pprint +from pathlib import Path +from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn.modules.loss import _Loss +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim.lr_scheduler import _LRScheduler +from torch.optim.optimizer import Optimizer +from torch.utils.data import DataLoader + +from colossalai.context import Config, ConfigException +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.interface import OptimizerWrapper +from colossalai.legacy.amp import AMP_TYPE, convert_to_amp +from colossalai.legacy.amp.naive_amp import NaiveAMPModel +from colossalai.legacy.builder.builder import build_gradient_handler +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.engine import Engine +from colossalai.legacy.engine.gradient_accumulation import accumulate_gradient +from colossalai.legacy.engine.schedule import ( + InterleavedPipelineSchedule, + NonPipelineSchedule, + PipelineSchedule, + get_tensor_shape, +) +from colossalai.legacy.utils import is_using_ddp, is_using_pp, is_using_sequence, sync_model_param +from colossalai.legacy.zero import ShardedOptimizerV2, convert_to_zero_v2 +from colossalai.legacy.zero.gemini.ophooks import BaseOpHook +from colossalai.logging import get_dist_logger +from colossalai.utils import get_current_device +from colossalai.utils.moe import sync_moe_model_param + + +def get_default_parser(): + """Reads user command line and uses an argument parser to parse the input arguments. + Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed. + + Returns: + Namespace: Returns the parser with the default arguments, the user may add customized arguments into this parser. + """ + parser = argparse.ArgumentParser() + parser.add_argument('--config', type=str, help='path to the config file') + parser.add_argument('--host', type=str, help='the master address for distributed training') + parser.add_argument('--port', type=int, help='the master port for distributed training') + parser.add_argument('--world_size', type=int, help='world size for distributed training') + parser.add_argument('--rank', type=int, help='rank for the default process group') + parser.add_argument('--local_rank', type=int, help='local rank on the node') + parser.add_argument('--backend', type=str, default='nccl', help='backend for distributed communication') + return parser + + +def launch(config: Union[str, Path, Config, Dict], + rank: int, + world_size: int, + host: str, + port: int, + backend: str = 'nccl', + local_rank: int = None, + seed: int = 1024, + verbose: bool = True): + """This function first parses the configuration arguments, using :func:`parse_args()` in case one of the input + arguments are not given. Then initialize and set distributed environment by calling global_context's functions. 
+ + Args: + config (Union[str, dict, Config]): Config file or config file path are both acceptable + rank (int): Rank for the default process group + world_size (int): World size of the default process group + host (str): The master address for distributed training + port (str): The master port for distributed training + backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl`` + local_rank (int, optional): + Rank for the process on the node and is used to set the default CUDA device, + defaults to None. If local_rank = None, the default device ordinal will be calculated automatically. + seed (int, optional): Specified random seed for every process. Defaults to 1024. + verbose (bool, optional): Whether to print logs. Defaults to True. + + Raises: + Exception: Raise exception when config type is wrong + """ + gpc.verbose = verbose + + # set config + assert isinstance(config, (Config, str, Path, dict)), \ + f'expected argument config to be Config, str or Path, but got {type(config)}' + if not isinstance(config, Config) and isinstance(config, dict): + config = Config(config) + if isinstance(config, (str, Path)): + config = Config.from_file(config) + gpc.load_config(config) + + # init default process group + gpc.init_global_dist(rank, world_size, backend, host, port) + + # init process groups for different parallel modes from config + gpc.init_parallel_groups() + + # set cuda device + if torch.cuda.is_available(): + # if local rank is not given, calculate automatically + gpc.set_device(local_rank) + + # set the number of processes running on the same node + gpc.detect_num_processes_on_current_node() + + gpc.set_seed(seed) + + if verbose: + logger = get_dist_logger() + logger.info( + f'Distributed environment is initialized, ' + f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, ' + f'tensor parallel size: {gpc.tensor_parallel_size}', + ranks=[0]) + + +def launch_from_slurm(config: Union[str, Path, Config, Dict], + host: str, + port: int, + backend: str = 'nccl', + seed: int = 1024, + verbose: bool = True): + """A wrapper for colossalai.launch for SLURM launcher by reading rank and world size from the environment variables + set by SLURM + + Args: + config (Union[str, dict, Config]): Config file or config file path are both acceptable + host (str): The master address for distributed training + port (str): The master port for distributed training + backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl`` + seed (int, optional): Specified random seed for every process. Defaults to 1024. + verbose (bool, optional): Whether to print logs. Defaults to True. 
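        A possible invocation from a training script might be (illustrative; the config
        path, host and port below are placeholders)::

            import colossalai.legacy as legacy

            # SLURM_PROCID / SLURM_NPROCS set by srun provide the rank and world size
            legacy.launch_from_slurm(config='./config.py', host=args.host, port=29500)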
+ """ + try: + rank = int(os.environ['SLURM_PROCID']) + world_size = int(os.environ['SLURM_NPROCS']) + except KeyError as e: + raise RuntimeError( + f"Could not find {e} in the SLURM environment, visit https://www.colossalai.org/ for more information on launching with SLURM" + ) + + launch(config=config, + rank=rank, + world_size=world_size, + host=host, + port=port, + backend=backend, + seed=seed, + verbose=verbose) + + +def launch_from_openmpi(config: Union[str, Path, Config, Dict], + host: str, + port: int, + backend: str = 'nccl', + seed: int = 1024, + verbose: bool = True): + """A wrapper for colossalai.launch for OpenMPI launcher by reading rank and world size from the environment variables + set by OpenMPI + + Args: + config (Union[str, dict, Config]): Config file or config file path are both acceptable + host (str): The master address for distributed training + port (str): The master port for distributed training + backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl`` + seed (int, optional): Specified random seed for every process. Defaults to 1024. + verbose (bool, optional): Whether to print logs. Defaults to True. + """ + try: + rank = int(os.environ['OMPI_COMM_WORLD_RANK']) + local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) + world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) + except KeyError as e: + raise RuntimeError( + f"Could not find {e} in the OpenMPI environment, visit https://www.colossalai.org/ for more information on launching with OpenMPI" + ) + + launch(config=config, + local_rank=local_rank, + rank=rank, + world_size=world_size, + host=host, + port=port, + backend=backend, + seed=seed, + verbose=verbose) + + +def launch_from_torch(config: Union[str, Path, Config, Dict], + backend: str = 'nccl', + seed: int = 1024, + verbose: bool = True): + """A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size + from the environment variables set by PyTorch + + Args: + config (Union[str, dict, Config]): Config file or config file path are both acceptable + backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl`` + seed (int, optional): Specified random seed for every process. Defaults to 1024. + verbose (bool, optional): Whether to print logs. Defaults to True. + """ + try: + rank = int(os.environ['RANK']) + local_rank = int(os.environ['LOCAL_RANK']) + world_size = int(os.environ['WORLD_SIZE']) + host = os.environ['MASTER_ADDR'] + port = int(os.environ['MASTER_PORT']) + except KeyError as e: + raise RuntimeError( + f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch" + ) + + launch(config=config, + local_rank=local_rank, + rank=rank, + world_size=world_size, + host=host, + port=port, + backend=backend, + seed=seed, + verbose=verbose) + + +def initialize(model: nn.Module, + optimizer: Optimizer, + criterion: Optional[_Loss] = None, + train_dataloader: Optional[Iterable] = None, + test_dataloader: Optional[Iterable] = None, + lr_scheduler: Optional[_LRScheduler] = None, + ophooks: Optional[List[BaseOpHook]] = None, + verbose: bool = True) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]: + """Core function to wrap the essential training components with our functionality based on the config which is + loaded into gpc.config. + + Args: + model (:class:`torch.nn.Module` or Callable): Your model instance or a function to build the model. 
+ optimizer (:class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.optimizer]`): + Your optimizer instance. + criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance. + train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training. + test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing. + lr_scheduler (:class:`torch.nn.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance, optional. + verbose (bool, optional): Whether to print logs. + + Returns: + Tuple (engine, train_dataloader, test_dataloader, lr_scheduler): + A tuple of ``(engine, train_dataloader, test_dataloader, lr_scheduler)`` + where only ``engine`` could not be None. + """ + # get logger + logger = get_dist_logger() + gpc.verbose = verbose + + # get config from gpc + config = gpc.config + + # print config + if verbose: + logger.info( + f"\n========== Your Config ========\n" + f"{pprint.pformat(gpc.config)}\n" + f"================================\n", + ranks=[0]) + + # cudnn + cudnn_benchmark = config.get('cudnn_benchmark', False) + cudnn_deterministic = config.get('cudnn_deterministic', False) + torch.backends.cudnn.benchmark = cudnn_benchmark + torch.backends.cudnn.deterministic = cudnn_deterministic + if verbose: + logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0]) + + # zero + use_zero = hasattr(gpc.config, 'zero') + if use_zero: + zero_cfg = gpc.config.get('zero', None) + if zero_cfg is not None: + cfg_ = zero_cfg.copy() + else: + cfg_ = {} + optimizer_config = zero_cfg.get('optimizer_config', None) + model_config = zero_cfg.get('model_config', None) + model, optimizer = convert_to_zero_v2(model, + optimizer, + model_config=model_config, + optimizer_config=optimizer_config) + + logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0]) + else: + if isinstance(model, nn.Module): + # first sync model across dp ranks + model.to(get_current_device()) + elif isinstance(model, Callable): + model = model().to(get_current_device()) + + # optimizer maybe a optimizer_cls + if isinstance(optimizer, Callable): + optimizer = optimizer(model.parameters()) + logger.warning("Initializing an non ZeRO model with optimizer class") + + if not use_zero: + if is_using_sequence(): + sync_model_param(model, ParallelMode.SEQUENCE_DP) + elif MOE_CONTEXT.is_initialized: + sync_moe_model_param(model) + elif is_using_ddp(): + sync_model_param(model, ParallelMode.DATA) + else: + logger.warning( + "The parameters of models is not automatically synchronized.\n" + "Please make sure that all parameters are the same in data parallel group.", + ranks=[0]) + + # check amp and zero + fp16_cfg = gpc.config.get('fp16', None) + + if fp16_cfg is not None and fp16_cfg.mode is not None and use_zero: + raise ConfigException( + "It is not allowed to set fp16 and zero configuration in your config file at the same time") + + # clip grad norm + clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0) + + # initialize amp + amp_mode = None + if fp16_cfg is not None and fp16_cfg.mode is not None: + cfg_ = fp16_cfg.copy() + amp_mode = cfg_.pop('mode') + if is_using_pp(): + assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline only support NaiveAMP currently' + if amp_mode == AMP_TYPE.NAIVE: + cfg_['clip_grad_norm'] = clip_grad_norm + model, optimizer, criterion = convert_to_amp(model=model, + optimizer=optimizer, + criterion=criterion, + mode=amp_mode, + amp_config=cfg_) + + # get torch ddp config + 
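    # (illustrative) this optional `torch_ddp` section of the config is forwarded verbatim
    # as keyword arguments to torch.nn.parallel.DistributedDataParallel below, e.g.
    #     torch_ddp = dict(find_unused_parameters=True, broadcast_buffers=False)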
torch_ddp_cfg = gpc.config.get('torch_ddp', dict()) + + # gradient handler + gradient_handler_cfg = gpc.config.get('gradient_handler', None) + if gradient_handler_cfg is None: + # if gradient handler is not specified in the configuration file, + # check in the following order + # 1. if optimizer is ZERO, then use zero grad handler + # 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp + # 3. if using pipeline and dp size larger than 1, use data parallel grad handler + if isinstance(optimizer, ShardedOptimizerV2): + gradient_handler_cfg = [dict(type='ZeROGradientHandler')] + if verbose: + logger.info( + "Training with zero is detected, ZeROGradientHandler is automatically " + "added even though not specified in the configuration", + ranks=[0]) + elif is_using_ddp() and MOE_CONTEXT.is_initialized: + gradient_handler_cfg = [dict(type='MoeGradientHandler')] + if verbose: + logger.info( + "Data parallel training is detected with moe parallel, MoeGradientHandler is automatically " + "added even though not specified in the configuration", + ranks=[0]) + elif is_using_sequence(): + model = DDP(model, + process_group=gpc.get_group(ParallelMode.SEQUENCE_DP), + device_ids=[torch.cuda.current_device()], + **torch_ddp_cfg) + if verbose: + logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism', + ranks=[0]) + elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE: + model = DDP(model, + process_group=gpc.get_group(ParallelMode.DATA), + device_ids=[torch.cuda.current_device()], + **torch_ddp_cfg) + if verbose: + logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0]) + elif is_using_ddp(): + gradient_handler_cfg = [dict(type='DataParallelGradientHandler')] + if verbose: + logger.info( + "Data parallel training is detected when using pipeline parallel, " + "DataParallelGradientHandler is automatically " + "added even though not specified in the configuration", + ranks=[0]) + # add pipeline parallel gradient handler, if pipeline shared module is detected + for param in model.parameters(): + if getattr(param, 'pipeline_shared_module_pg', None) is not None: + if gradient_handler_cfg is None: + gradient_handler_cfg = [dict(type='PipelineSharedModuleGradientHandler')] + else: + gradient_handler_cfg.append(dict(type='PipelineSharedModuleGradientHandler')) + if verbose: + logger.info( + "pipeline_shared_module is detected, PipelineSharedModuleGradientHandler is automatically " + "added even though not specified in the configuration", + ranks=[0]) + break + else: + if not isinstance(gradient_handler_cfg, list): + raise ConfigException( + f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}" + ) + + # turn off sync buffer for NaiveAMPModel if using torch DDP and NaiveAMPModel at the same time + # to avoid duplicated buffer synchronization + if isinstance(model, DDP) and isinstance(model.module, NaiveAMPModel): + model.module.sync_buffer = False + + # initialize schedule for engine + if is_using_pp(): + tensor_shape = get_tensor_shape() + use_interleaved = hasattr(gpc.config, 'model') and hasattr(gpc.config.model, 'num_chunks') + if gpc.is_initialized(ParallelMode.PARALLEL_1D): + scatter_gather = True + else: + scatter_gather = False + if use_interleaved: + if isinstance(model, nn.Sequential): + model = nn.ModuleList([model]) + schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES, + gpc.config.model.num_chunks, + 
tensor_shape=tensor_shape, + scatter_gather_tensors=scatter_gather) + else: + schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES, + tensor_shape=tensor_shape, + scatter_gather_tensors=scatter_gather) + else: + schedule = NonPipelineSchedule() + + if gradient_handler_cfg is None: + gradient_handlers = None + if verbose and not isinstance(model, DDP): + logger.warning( + "No PyTorch DDP or gradient handler is set up, please make sure you do not need " + "to all-reduce the gradients after a training step.", + ranks=[0]) + else: + gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg] + + # check if optimizer is OptimizerWrapper + if not isinstance(optimizer, (OptimizerWrapper, ShardedOptimizerV2)): + optimizer = OptimizerWrapper(optim=optimizer) + + # gradient accumulation + grad_accum_size = gpc.config.get('gradient_accumulation', None) + if grad_accum_size is not None: + optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient( + model=model, + optimizer=optimizer, + dataloader=train_dataloader, + accumulate_size=grad_accum_size, + gradient_handlers=gradient_handlers, + lr_scheduler=lr_scheduler) + engine = Engine(model=model, + optimizer=optimizer, + criterion=criterion, + gradient_handlers=gradient_handlers, + clip_grad_norm=clip_grad_norm, + ophook_list=ophooks, + schedule=schedule) + + return engine, train_dataloader, test_dataloader, lr_scheduler diff --git a/colossalai/legacy/nn/__init__.py b/colossalai/legacy/nn/__init__.py index 500162901905..d30ebf8d5406 100644 --- a/colossalai/legacy/nn/__init__.py +++ b/colossalai/legacy/nn/__init__.py @@ -1,4 +1,3 @@ -from ._ops import * from .layer import * from .loss import * from .metric import * diff --git a/colossalai/legacy/nn/_ops/__init__.py b/colossalai/legacy/nn/_ops/__init__.py index 4991ad9a2217..9a35d02ce5ed 100644 --- a/colossalai/legacy/nn/_ops/__init__.py +++ b/colossalai/legacy/nn/_ops/__init__.py @@ -1,9 +1 @@ -from .addmm import colo_addmm -from .batch_norm import colo_batch_norm -from .element_wise import * -from .embedding import colo_embedding -from .embedding_bag import colo_embedding_bag -from .layernorm import colo_layernorm -from .linear import colo_linear -from .loss import colo_cross_entropy -from .view import colo_view +from ._utils import * diff --git a/colossalai/legacy/nn/_ops/_utils.py b/colossalai/legacy/nn/_ops/_utils.py index 131c2154771b..a4228fa2116e 100644 --- a/colossalai/legacy/nn/_ops/_utils.py +++ b/colossalai/legacy/nn/_ops/_utils.py @@ -3,9 +3,10 @@ import torch import torch.distributed as dist -from colossalai.global_variables import tensor_parallel_env as env +from colossalai.legacy.global_variables import tensor_parallel_env as env from colossalai.legacy.nn.layer.utils import divide -from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup +from colossalai.legacy.tensor import ColoTensorSpec, ProcessGroup +from colossalai.tensor import ColoTensor GeneralTensor = Union[ColoTensor, torch.Tensor] Number = Union[int, float] diff --git a/colossalai/legacy/nn/_ops/addmm.py b/colossalai/legacy/nn/_ops/addmm.py deleted file mode 100644 index 660b48a71d57..000000000000 --- a/colossalai/legacy/nn/_ops/addmm.py +++ /dev/null @@ -1,90 +0,0 @@ -import torch - -from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec, distspec -from colossalai.tensor.op_wrapper import colo_op_impl - -from ._utils import GeneralTensor, Number, convert_to_colo_tensor, reduce_grad, 
reduce_input - - -def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number, - alpha: Number) -> ColoTensor: - # mat1:S[1] x mat2:S[0] = Output:P - # beta * input + alpha * All-Reduce(Output) = res - - mat1 = mat1.redistribute(ShardSpec([-1], [mat2.get_tp_world_size()]), mat2.get_process_group()) - - # Output:P - partial_output = torch.mm(mat1, mat2) - # Reduce(Output) - output = reduce_input(partial_output, mat2.get_process_group()) - # input - assert not input_tensor.has_compute_spec(), 'Invalid input spec for 1Drow addmm op' - output = beta * input_tensor + alpha * output - output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(input_tensor.get_process_group())) - return output - - -def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number, - alpha: Number) -> ColoTensor: - # mat1:B x mat2:S[1] + input:S[1] = Output:S[1] - compute_spec = mat2.compute_spec - mat1 = mat1.redistribute(ReplicaSpec()) - mat1 = reduce_grad(mat1, mat1.get_process_group()) - - output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha) - output_spec = ColoTensorSpec(input_tensor.get_process_group(), ShardSpec([-1], [mat2.get_tp_world_size()]), - ComputeSpec(ComputePattern.TP1D)) - output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec) - - if compute_spec.output_replicate: - return output.to_replicate() - else: - return output - - -def colo_addmm_1d(mode: str, input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number, - alpha: Number) -> ColoTensor: - assert mode in ('row', 'col') - funcs = {'row': colo_addmm_1Drow, 'col': colo_addmm_1Dcol} - return funcs[mode](input_tensor, mat1, mat2, beta, alpha) - - -@colo_op_impl(torch.addmm) -def colo_addmm(input_tensor: GeneralTensor, - mat1: ColoTensor, - mat2: ColoTensor, - beta: Number = 1, - alpha: Number = 1, - **kargs) -> ColoTensor: - """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``. - This method computes a linear. - """ - # At least one of the tensor should be ColoTensor - assert isinstance(mat2, ColoTensor) - input_tensor = convert_to_colo_tensor(input_tensor, mat2.get_process_group()) - mat1 = convert_to_colo_tensor(mat1, mat2.get_process_group()) - - # Add communication logic before and after linear call. 
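(Editorial aside, not part of the patch.) The 1D-row branch above relies on a simple partial-sum identity: if `mat1` is split along its columns and `mat2` along its rows, the per-shard matmuls sum to the full product, which is exactly what the all-reduce restores. A minimal single-process sketch of that identity, with made-up shapes and no ColoTensor machinery:

```python
# Single-process sketch (plain PyTorch, no process groups) of the identity that
# colo_addmm_1Drow builds on: per-shard matmuls over a split inner dimension are
# partial sums of the full product, so summing them (the all-reduce in the real
# op) recovers torch.addmm.
import torch

torch.manual_seed(0)
world_size = 2                      # pretend tensor-parallel size
beta, alpha = 1.0, 1.0

inp = torch.randn(4, 6)             # the "input" of addmm
mat1 = torch.randn(4, 8)
mat2 = torch.randn(8, 6)

# shard mat1 by columns and mat2 by rows, as S[1] x S[0] would
mat1_shards = mat1.chunk(world_size, dim=-1)
mat2_shards = mat2.chunk(world_size, dim=0)

# each "rank" produces a partial output; the sum stands in for the all-reduce
partial_outputs = [m1 @ m2 for m1, m2 in zip(mat1_shards, mat2_shards)]
reduced = torch.stack(partial_outputs).sum(dim=0)

sharded_result = beta * inp + alpha * reduced
reference = torch.addmm(inp, mat1, mat2, beta=beta, alpha=alpha)
assert torch.allclose(sharded_result, reference, atol=1e-5)
```

The column branch is the mirror image: the second matrix is split along its output dimension, so each rank already holds a complete slice of the result and only the optional replication at the end needs communication.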
- ret_tensor = None - if not mat2.has_compute_spec(): # No Model Parallel Applied - assert mat2.is_replicate(), 'Invalid mat2 spec for native addmm op' - assert input_tensor.is_replicate(), 'Invalid input spec for native addmm op' - ret_tensor = ColoTensor.from_torch_tensor(tensor=torch.addmm(input_tensor, - mat1, - mat2, - beta=beta, - alpha=alpha, - **kargs), - spec=ColoTensorSpec(mat2.get_process_group())) - elif mat2.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied - if mat2.is_shard_1drow() and input_tensor.is_replicate(): - mode = 'row' - elif mat2.is_shard_1dcol() and (input_tensor.is_shard_1dcol() or input_tensor.is_shard_1drow()): - mode = 'col' - else: - raise NotImplementedError - ret_tensor = colo_addmm_1d(mode, input_tensor, mat1, mat2, beta, alpha) - else: - raise NotImplementedError - - return ret_tensor diff --git a/colossalai/legacy/nn/_ops/batch_norm.py b/colossalai/legacy/nn/_ops/batch_norm.py deleted file mode 100644 index 54ecc88f420a..000000000000 --- a/colossalai/legacy/nn/_ops/batch_norm.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import Optional - -import torch.nn.functional as F - -from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec -from colossalai.tensor.op_wrapper import colo_op_impl - -from ._utils import GeneralTensor, convert_to_colo_tensor - - -@colo_op_impl(F.batch_norm) -def colo_batch_norm( - input: GeneralTensor, - running_mean: Optional[GeneralTensor], - running_var: Optional[GeneralTensor], - weight: Optional[GeneralTensor] = None, - bias: Optional[GeneralTensor] = None, - training: bool = False, - momentum: float = 0.1, - eps: float = 1e-5, -): - assert isinstance(weight, ColoTensor) - running_mean = running_mean.detach() - running_var = running_var.detach() - - input = convert_to_colo_tensor(input, weight.get_process_group()) - bias = convert_to_colo_tensor(bias, weight.get_process_group()) - input = input.redistribute(ReplicaSpec()) - bias = bias.redistribute(ReplicaSpec()) - - output = F.batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps) - output = ColoTensor.from_torch_tensor(tensor=output, spec=ColoTensorSpec(pg=weight.get_process_group())) - return output diff --git a/colossalai/legacy/nn/_ops/element_wise.py b/colossalai/legacy/nn/_ops/element_wise.py deleted file mode 100644 index 2de51e24a6dd..000000000000 --- a/colossalai/legacy/nn/_ops/element_wise.py +++ /dev/null @@ -1,250 +0,0 @@ -import torch -import torch.nn.functional as F -from torch import Tensor - -from colossalai.tensor import ColoTensor, ColoTensorSpec -from colossalai.tensor.op_wrapper import colo_op_impl - -from ._utils import GeneralTensor, convert_to_colo_tensor - - -def register_elementwise_op(op): - - @colo_op_impl(op) - def elementwise_op(input_tensor: GeneralTensor, *args, **kwargs): - """ - Handles ``__torch_function__`` dispatch for the elementwise op such - as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``. - This method computes on either a normal tensor or a sharded tensor. 
- """ - if 'inplace' in kwargs: - # TODO(jiaruifang) inplace will cause bugs - input_tensor = input_tensor.clone() - return op(input_tensor, *args, **kwargs) - else: - output = op(input_tensor, *args, **kwargs) - # return output - if isinstance(input_tensor, ColoTensor): - if isinstance(output, str): - return output - if not isinstance(output, torch.Tensor): - raise NotImplementedError - return ColoTensor.from_torch_tensor(output, - spec=ColoTensorSpec(input_tensor.get_process_group(), - dist_attr=input_tensor.dist_spec)) - - -# @colo_op_impl(torch.relu_) -# def elementwise_op(input_tensor): -# torch.relu_(input_tensor.data) -# return input_tensor - -# @colo_op_impl(Tensor.add_) -# def elementwise_op(input_tensor: ColoTensor, *args, **kwargs): -# input_tensor = input_tensor.data.add_(*args, **kwargs) -# return input_tensor - -# Tensor op -register_elementwise_op(Tensor.abs) -register_elementwise_op(Tensor.absolute) -register_elementwise_op(Tensor.acos) -register_elementwise_op(Tensor.arccos) -register_elementwise_op(Tensor.angle) -register_elementwise_op(Tensor.asin) -register_elementwise_op(Tensor.arcsin) -register_elementwise_op(Tensor.atan) -register_elementwise_op(Tensor.arctan) -register_elementwise_op(Tensor.all) -register_elementwise_op(Tensor.any) -register_elementwise_op(Tensor.bernoulli) -register_elementwise_op(Tensor.bfloat16) -register_elementwise_op(Tensor.bitwise_not) -register_elementwise_op(Tensor.bool) -register_elementwise_op(Tensor.byte) -register_elementwise_op(Tensor.ceil) -register_elementwise_op(Tensor.char) -register_elementwise_op(Tensor.clamp) -register_elementwise_op(Tensor.clamp_max) -register_elementwise_op(Tensor.clamp_min) -register_elementwise_op(Tensor.clip) -register_elementwise_op(Tensor.clone) -register_elementwise_op(Tensor.contiguous) -register_elementwise_op(Tensor.copysign) -register_elementwise_op(Tensor.cos) -register_elementwise_op(Tensor.cosh) -register_elementwise_op(Tensor.acosh) -register_elementwise_op(Tensor.arccosh) -register_elementwise_op(Tensor.cpu) -register_elementwise_op(Tensor.cuda) -register_elementwise_op(Tensor.deg2rad) -register_elementwise_op(Tensor.detach) -register_elementwise_op(Tensor.digamma) -register_elementwise_op(Tensor.double) -register_elementwise_op(Tensor.erf) -register_elementwise_op(Tensor.erfc) -register_elementwise_op(Tensor.erfinv) -register_elementwise_op(Tensor.exp) -register_elementwise_op(Tensor.expm1) -register_elementwise_op(Tensor.fix) -register_elementwise_op(Tensor.trunc) -register_elementwise_op(Tensor.float) -register_elementwise_op(Tensor.float_power) -register_elementwise_op(Tensor.floor) -register_elementwise_op(Tensor.frac) -register_elementwise_op(Tensor.half) -register_elementwise_op(Tensor.hardshrink) -register_elementwise_op(Tensor.heaviside) -register_elementwise_op(Tensor.i0) -register_elementwise_op(Tensor.int) -register_elementwise_op(Tensor.isfinite) -register_elementwise_op(Tensor.isinf) -register_elementwise_op(Tensor.isposinf) -register_elementwise_op(Tensor.isneginf) -register_elementwise_op(Tensor.isnan) -register_elementwise_op(Tensor.lgamma) -register_elementwise_op(Tensor.log) -register_elementwise_op(Tensor.log10) -register_elementwise_op(Tensor.log1p) -register_elementwise_op(Tensor.log2) -register_elementwise_op(Tensor.logical_not) -register_elementwise_op(Tensor.logit) -register_elementwise_op(Tensor.long) -register_elementwise_op(Tensor.nan_to_num) -register_elementwise_op(Tensor.neg) -register_elementwise_op(Tensor.negative) -register_elementwise_op(Tensor.positive) 
-register_elementwise_op(Tensor.pow) -register_elementwise_op(Tensor.rad2deg) -register_elementwise_op(Tensor.reciprocal) -register_elementwise_op(Tensor.round) -register_elementwise_op(Tensor.rsqrt) -register_elementwise_op(Tensor.short) -register_elementwise_op(Tensor.sigmoid) -register_elementwise_op(Tensor.sign) -register_elementwise_op(Tensor.signbit) -register_elementwise_op(Tensor.sgn) -register_elementwise_op(Tensor.sin) -register_elementwise_op(Tensor.sinc) -register_elementwise_op(Tensor.sinh) -register_elementwise_op(Tensor.asinh) -register_elementwise_op(Tensor.arcsinh) -register_elementwise_op(Tensor.sqrt) -register_elementwise_op(Tensor.square) -register_elementwise_op(Tensor.to) -register_elementwise_op(Tensor.tan) -register_elementwise_op(Tensor.tanh) -register_elementwise_op(Tensor.atanh) -register_elementwise_op(Tensor.arctanh) -register_elementwise_op(Tensor.type) -register_elementwise_op(Tensor.type_as) - -# torch OP -register_elementwise_op(torch.abs) -register_elementwise_op(torch.absolute) -register_elementwise_op(torch.acos) -register_elementwise_op(torch.arccos) -register_elementwise_op(torch.angle) -register_elementwise_op(torch.asin) -register_elementwise_op(torch.arcsin) -register_elementwise_op(torch.atan) -register_elementwise_op(torch.arctan) -register_elementwise_op(torch.all) -register_elementwise_op(torch.any) -register_elementwise_op(torch.bernoulli) -register_elementwise_op(torch.bitwise_not) -register_elementwise_op(torch.ceil) -register_elementwise_op(torch.clamp) -register_elementwise_op(torch.clamp_max) -register_elementwise_op(torch.clamp_min) -register_elementwise_op(torch.clip) -register_elementwise_op(torch.clone) -register_elementwise_op(torch.copysign) -register_elementwise_op(torch.cos) -register_elementwise_op(torch.cosh) -register_elementwise_op(torch.acosh) -register_elementwise_op(torch.arccosh) -register_elementwise_op(torch.deg2rad) -register_elementwise_op(torch.digamma) -register_elementwise_op(torch.erf) -register_elementwise_op(torch.erfc) -register_elementwise_op(torch.erfinv) -register_elementwise_op(torch.exp) -register_elementwise_op(torch.expm1) -register_elementwise_op(torch.fix) -register_elementwise_op(torch.trunc) -register_elementwise_op(torch.float_power) -register_elementwise_op(torch.floor) -register_elementwise_op(torch.frac) -register_elementwise_op(torch.hardshrink) -register_elementwise_op(torch.heaviside) -register_elementwise_op(torch.i0) -register_elementwise_op(torch.isfinite) -register_elementwise_op(torch.isinf) -register_elementwise_op(torch.isposinf) -register_elementwise_op(torch.isneginf) -register_elementwise_op(torch.isnan) -register_elementwise_op(torch.lgamma) -register_elementwise_op(torch.log) -register_elementwise_op(torch.log10) -register_elementwise_op(torch.log1p) -register_elementwise_op(torch.log2) -register_elementwise_op(torch.logical_not) -register_elementwise_op(torch.logit) -register_elementwise_op(torch.nan_to_num) -register_elementwise_op(torch.neg) -register_elementwise_op(torch.negative) -register_elementwise_op(torch.positive) -register_elementwise_op(torch.pow) -register_elementwise_op(torch.rad2deg) -register_elementwise_op(torch.reciprocal) -register_elementwise_op(torch.round) -register_elementwise_op(torch.rsqrt) -register_elementwise_op(torch.sigmoid) -register_elementwise_op(torch.sign) -register_elementwise_op(torch.signbit) -register_elementwise_op(torch.sgn) -register_elementwise_op(torch.sin) -register_elementwise_op(torch.sinc) -register_elementwise_op(torch.sinh) 
-register_elementwise_op(torch.asinh) -register_elementwise_op(torch.arcsinh) -register_elementwise_op(torch.sqrt) -register_elementwise_op(torch.square) -register_elementwise_op(torch.tan) -register_elementwise_op(torch.tanh) -register_elementwise_op(torch.atanh) -register_elementwise_op(torch.arctanh) -register_elementwise_op(torch.zeros_like) - -# nn.functional OP -register_elementwise_op(F.threshold) -register_elementwise_op(F.relu) -register_elementwise_op(F.hardtanh) -register_elementwise_op(F.hardswish) -register_elementwise_op(F.relu6) -register_elementwise_op(F.elu) -register_elementwise_op(F.selu) -register_elementwise_op(F.celu) -register_elementwise_op(F.leaky_relu) -register_elementwise_op(F.prelu) -register_elementwise_op(F.rrelu) -register_elementwise_op(F.gelu) -register_elementwise_op(F.logsigmoid) -register_elementwise_op(F.hardshrink) -register_elementwise_op(F.tanhshrink) -register_elementwise_op(F.softsign) -register_elementwise_op(F.softplus) -register_elementwise_op(F.softmin) -register_elementwise_op(F.softmax) -register_elementwise_op(F.softshrink) -register_elementwise_op(F.gumbel_softmax) -register_elementwise_op(F.log_softmax) -register_elementwise_op(F.tanh) -register_elementwise_op(F.sigmoid) -register_elementwise_op(F.hardsigmoid) -register_elementwise_op(F.silu) -register_elementwise_op(F.mish) -# TODO(ver217): dropout handles seed -register_elementwise_op(F.dropout) -register_elementwise_op(F.alpha_dropout) -register_elementwise_op(F.feature_alpha_dropout) diff --git a/colossalai/legacy/nn/_ops/embedding.py b/colossalai/legacy/nn/_ops/embedding.py deleted file mode 100644 index b145d1763380..000000000000 --- a/colossalai/legacy/nn/_ops/embedding.py +++ /dev/null @@ -1,142 +0,0 @@ -from typing import Optional - -import torch.nn.functional as F - -from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec -from colossalai.tensor.op_wrapper import colo_op_impl - -from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_input - - -def colo_embedding_1Dcol(input_tensor: ColoTensor, - weight: ColoTensor, - padding_idx: Optional[int] = None, - max_norm: Optional[float] = None, - norm_type: float = 2.0, - scale_grad_by_freq: bool = False, - sparse: bool = False) -> ColoTensor: - # embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P) - # Gather splitted lookup table - input_tensor = input_tensor.redistribute(ReplicaSpec()) - - output_parallel = F.embedding(input_tensor, - weight, - padding_idx=padding_idx, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse) - output_spec = ColoTensorSpec(weight.get_process_group(), ShardSpec([-1], [weight.get_tp_world_size()]), - ComputeSpec(ComputePattern.TP1D)) - output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec) - - compute_spec = weight.compute_spec - - if compute_spec.output_replicate: - return output.to_replicate() - else: - return output - - -def colo_embedding_1Drow(input_tensor: ColoTensor, - weight: ColoTensor, - padding_idx: Optional[int] = None, - max_norm: Optional[float] = None, - norm_type: float = 2.0, - scale_grad_by_freq: bool = False, - sparse: bool = False) -> ColoTensor: - # embedding_1Drow splits the weight(lookup table) to the shape, [num_embeddings/P, embedding_dim] - # get the index of current segment and mask other segments with 0 - - # get complete input tensor through all-gather - input_tensor = input_tensor.redistribute(ReplicaSpec()) - - # 
tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - tensor_parallel_rank = weight.get_process_group().tp_local_rank() - num_embeddings_per_partition = weight.size_local(0) - vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition - vocab_end_index = vocab_start_index + num_embeddings_per_partition - - # build the mask. - input_mask = (input_tensor < vocab_start_index) | (input_tensor >= vocab_end_index) - # mask the input. - # TODO(jzy) masked_input may be an activation managed by ColoTensor. - masked_input = input_tensor - vocab_start_index - masked_input[input_mask] = 0 - - partial_output = F.embedding(masked_input, - weight, - padding_idx=padding_idx, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse) - - # Mask the output embedding. - partial_output[input_mask, :] = 0. - # Reduce across all the model parallel GPUs. - output = reduce_input(partial_output, weight.get_process_group()) - output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(weight.get_process_group(), ReplicaSpec())) - return output - - -def colo_embedding_1d(mode: str, - input_tensor: ColoTensor, - weight: ColoTensor, - padding_idx: Optional[int] = None, - max_norm: Optional[float] = None, - norm_type: float = 2.0, - scale_grad_by_freq: bool = False, - sparse: bool = False) -> ColoTensor: - assert mode in ('row', 'col') - funcs = {'row': colo_embedding_1Drow, 'col': colo_embedding_1Dcol} - return funcs[mode](input_tensor, - weight, - padding_idx=padding_idx, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse) - - -@colo_op_impl(F.embedding) -def colo_embedding(input_tensor: GeneralTensor, - weight: GeneralTensor, - padding_idx: Optional[int] = None, - max_norm: Optional[float] = None, - norm_type: float = 2.0, - scale_grad_by_freq: bool = False, - sparse: bool = False): - """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding``. - This method looks up an embedding table. 
- """ - assert isinstance(weight, ColoTensor) - input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group()) - - if not weight.has_compute_spec(): # No Model Parallel Applied - assert weight.is_replicate(), 'Invalid weight spec for native embedding op' - return ColoTensor.from_torch_tensor(tensor=F.embedding(input_tensor, - weight, - padding_idx=padding_idx, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse), - spec=ColoTensorSpec(weight.get_process_group())) - elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied - if weight.is_shard_1drow(): - mode = 'row' - elif weight.is_shard_1dcol(): - mode = 'col' - else: - raise NotImplementedError - return colo_embedding_1d(mode, - input_tensor, - weight, - padding_idx=padding_idx, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse) - else: - raise NotImplementedError diff --git a/colossalai/legacy/nn/_ops/embedding_bag.py b/colossalai/legacy/nn/_ops/embedding_bag.py deleted file mode 100644 index 9a656d5871a3..000000000000 --- a/colossalai/legacy/nn/_ops/embedding_bag.py +++ /dev/null @@ -1,127 +0,0 @@ -from typing import Optional - -import torch.nn.functional as F -from torch import Tensor - -from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec, distspec -from colossalai.tensor.op_wrapper import colo_op_impl - -from ._utils import GeneralTensor, convert_to_colo_tensor - - -def colo_embedding_bag_1Dcol(input_tensor: ColoTensor, - weight: ColoTensor, - offsets: Optional[Tensor] = None, - max_norm: Optional[float] = None, - norm_type: float = 2, - scale_grad_by_freq: bool = False, - mode: str = "mean", - sparse: bool = False, - per_sample_weights: Optional[Tensor] = None, - include_last_offset: bool = False, - padding_idx: Optional[int] = None) -> ColoTensor: - # embedding_bag_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P) - # Gather splitted lookup table - pg = weight.get_process_group() - input_tensor = input_tensor.redistribute(ReplicaSpec()) - - output_parallel = F.embedding_bag(input_tensor, - weight, - offsets=offsets, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - mode=mode, - sparse=sparse, - per_sample_weights=per_sample_weights, - include_last_offset=include_last_offset, - padding_idx=padding_idx) - output_spec = ColoTensorSpec(pg, ShardSpec([-1], [weight.get_tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec) - - if weight.compute_spec.output_replicate: - return output.to_replicate() - else: - return output - - -def colo_embedding_bag_1d(tp_mode: str, - input_tensor: ColoTensor, - weight: ColoTensor, - offsets: Optional[Tensor] = None, - max_norm: Optional[float] = None, - norm_type: float = 2, - scale_grad_by_freq: bool = False, - mode: str = "mean", - sparse: bool = False, - per_sample_weights: Optional[Tensor] = None, - include_last_offset: bool = False, - padding_idx: Optional[int] = None) -> ColoTensor: - assert tp_mode in ('col',) - funcs = {'col': colo_embedding_bag_1Dcol} - return funcs[tp_mode](input_tensor, - weight, - offsets=offsets, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - mode=mode, - sparse=sparse, - per_sample_weights=per_sample_weights, - include_last_offset=include_last_offset, - padding_idx=padding_idx) - - 
-@colo_op_impl(F.embedding_bag) -def colo_embedding_bag(input_tensor: GeneralTensor, - weight: GeneralTensor, - offsets: Optional[Tensor] = None, - max_norm: Optional[float] = None, - norm_type: float = 2, - scale_grad_by_freq: bool = False, - mode: str = "mean", - sparse: bool = False, - per_sample_weights: Optional[Tensor] = None, - include_last_offset: bool = False, - padding_idx: Optional[int] = None): - """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding_bag``. - This method looks up an embedding table. - """ - assert isinstance(weight, ColoTensor) - input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group()) - - # Handle different parallel actions. - - if not weight.has_compute_spec(): # No Model Parallel Applied - assert weight.is_replicate(), 'Invalid weight spec for native embedding op' - return ColoTensor.from_torch_tensor(tensor=F.embedding_bag(input_tensor, - weight, - offsets=offsets, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - mode=mode, - sparse=sparse, - per_sample_weights=per_sample_weights, - include_last_offset=include_last_offset, - padding_idx=padding_idx), - spec=ColoTensorSpec(weight.get_process_group())) - elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied - if weight.is_shard_1dcol(): - tp_mode = 'col' - else: - raise NotImplementedError - return colo_embedding_bag_1d(tp_mode, - input_tensor, - weight, - offsets=offsets, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - mode=mode, - sparse=sparse, - per_sample_weights=per_sample_weights, - include_last_offset=include_last_offset, - padding_idx=padding_idx) - else: - raise NotImplementedError diff --git a/colossalai/legacy/nn/_ops/layernorm.py b/colossalai/legacy/nn/_ops/layernorm.py deleted file mode 100644 index 9960c5d48096..000000000000 --- a/colossalai/legacy/nn/_ops/layernorm.py +++ /dev/null @@ -1,28 +0,0 @@ -from typing import List, Optional - -import torch.nn.functional as F - -from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec, distspec -from colossalai.tensor.op_wrapper import colo_op_impl - -from ._utils import GeneralTensor, convert_to_colo_tensor - - -@colo_op_impl(F.layer_norm) -def colo_layernorm( - input_tensor: GeneralTensor, - normalized_shape: List[int], - weight: Optional[GeneralTensor] = None, - bias: Optional[GeneralTensor] = None, - eps: float = 1e-5, -): - assert isinstance(weight, ColoTensor) - input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group()) - bias = convert_to_colo_tensor(bias, weight.get_process_group()) - input_tensor = input_tensor.redistribute(ReplicaSpec()) - - output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps) - output = ColoTensor.from_torch_tensor(tensor=output, - spec=ColoTensorSpec(pg=input_tensor.get_process_group(), - dist_attr=input_tensor.dist_spec)) - return output diff --git a/colossalai/legacy/nn/_ops/linear.py b/colossalai/legacy/nn/_ops/linear.py deleted file mode 100644 index 2f2088c61fa8..000000000000 --- a/colossalai/legacy/nn/_ops/linear.py +++ /dev/null @@ -1,171 +0,0 @@ -from copy import deepcopy -from typing import Optional - -import torch.nn.functional as F - -from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec -from colossalai.tensor.op_wrapper import colo_op_impl -from colossalai.tensor.sharding_spec import ShardingSpec - -from ._utils import GeneralTensor, 
convert_to_colo_tensor, reduce_grad, reduce_input - - -def colo_linear_1drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor': - # Input:S[1] x Weight:S[0] = Output:P - # All-Reduce(Output) + bias = res - # Input:S[1] - pg = weight.get_process_group() - input_tensor = input_tensor.redistribute(ShardSpec([-1], [weight.get_tp_world_size()]), pg) - - # Output:P - partial_output = F.linear(input_tensor, weight) - # Reduce(Output) - - output = reduce_input(partial_output, pg) - # Bias - if bias is not None: - assert not bias.has_compute_spec(), 'Invalid bias spec for 1Drow Linear op' - output = output + bias - - output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(pg, ReplicaSpec())) - return output - - -def colo_linear_1dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor': - # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1] - # All-Gather(Output) - # Input:B - compute_spec = weight.compute_spec - input_tensor = input_tensor.redistribute(ReplicaSpec()) - input_parallel = reduce_grad(input_tensor, weight.get_process_group()) - - output_parallel = F.linear(input_parallel, weight, bias) - output = ColoTensor.from_torch_tensor(output_parallel, - spec=ColoTensorSpec(weight.get_process_group(), - ShardSpec([-1], [weight.get_tp_world_size()]), - ComputeSpec(ComputePattern.TP1D))) - if compute_spec.output_replicate: - return output.to_replicate() - else: - return output - - -def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor': - assert mode in ('row', 'col') - funcs = {'row': colo_linear_1drow, 'col': colo_linear_1dcol} - return funcs[mode](input_tensor, weight, bias) - - -# @register_colo_graph(input_pos=[1], param_pos=[2, 3]) -def colo_linear_imp(input_tensor: GeneralTensor, - weight: GeneralTensor, - bias: Optional[GeneralTensor] = None) -> 'ColoTensor': - """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``. - This method computes a linear. - """ - assert isinstance(weight, ColoTensor) - pg = weight.get_process_group() - assert pg - input_tensor = convert_to_colo_tensor(input_tensor, pg) - bias = convert_to_colo_tensor(bias, pg) - # input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias))) - - # Add communication logic before and after linear call. - ret_tensor = None - if not weight.has_compute_spec(): # No Model Parallel Applied - assert weight.is_replicate(), 'Invalid weight spec for native Linear op' - assert bias is None or bias.is_replicate(), 'Invalid bias spec for native Linear op' - ret_tensor = ColoTensor.from_torch_tensor(F.linear(input_tensor, weight, bias), spec=ColoTensorSpec(pg)) - elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied - if weight.is_shard_1dcol() and (bias is None or bias.is_replicate()): - mode = 'row' - elif weight.is_shard_1drow() and (bias is None or bias.is_shard_1drow() or bias.is_shard_1dcol()): - mode = 'col' - else: - raise RuntimeError(f"the weight or bias tensor spec is not valid, weight {weight}, bias {bias}") - ret_tensor = colo_linear_1d(mode, input_tensor, weight, bias) - else: - raise NotImplementedError - - return ret_tensor - - -def _new_colo_linear_imp(input_tensor: GeneralTensor, - weight: GeneralTensor, - bias: Optional[GeneralTensor] = None) -> 'ColoTensor': - """ - A tentative function to compute the distributed linear layer with the latest sharding spec. 
- This function is subject to future change as the current sharding API is not stable. - """ - # get mesh info - input_sharding_seq = input_tensor.sharding_spec.sharding_sequence - weight_sharding_seq = weight.sharding_spec.sharding_sequence - if bias is not None: - bias_sharding_seq = bias.sharding_spec.sharding_sequence - device_mesh = weight.sharding_spec.device_mesh - pg_axis0 = weight.pg_axis0 - pg_axis1 = weight.pg_axis1 - - # the last dim of input should have the same spec as the first dim of weight - # the weight is transposed, so we look at the second dimension - assert input_sharding_seq[-1] == weight_sharding_seq[1] - - if bias is not None: - assert bias_sharding_seq[0] == weight_sharding_seq[0] - - # compute the output sharding sequence - # as weight is transposed, so we look at the first dimension - output_shard_seq = input_sharding_seq[:-1] + weight_sharding_seq[:1] - output_shard_seq = deepcopy(output_shard_seq) - - # TODO: add reduce grad logic - - # handle column and row parallel linear - # by reusing the implementation above - out = F.linear(input_tensor, weight) - - # run all reduce if necessary - last_dim_spec = input_sharding_seq[-1] - if last_dim_spec.is_replica: - pass - elif last_dim_spec.shard_list is not None: - for dim in last_dim_spec.shard_list: - if dim == 0: - reduce_input(out, pg_axis0) - elif dim == 1: - reduce_input(out, pg_axis1) - else: - raise RuntimeError("Found invalid sharding axis {dim}, only 0 or 1 is expected") - # add bias - if bias is not None: - out += bias - - # convert shard seq to partition dict - output_partition_dict = {} - for index, dim_spec in enumerate(output_shard_seq): - if not dim_spec.is_replica: - if index not in output_partition_dict: - output_partition_dict[index] = [] - output_partition_dict[index].extend(dim_spec.shard_list) - - entire_shape = out.shape - output_sharding_spec = ShardingSpec(device_mesh, entire_shape, output_partition_dict) - ret_tensor = ColoTensor.from_torch_tensor(out) - setattr(ret_tensor, 'sharding_spec', output_sharding_spec) - return ret_tensor - - -def _has_sharding_spec(tensor): - """ - A tentative function to check whether the tensor is using the new sharding spec API. We assume that the sharding spec object is - set as the attribute `sharding_spec` on a tensor. 
- """ - return hasattr(tensor, 'sharding_spec') - - -@colo_op_impl(F.linear) -def colo_linear(input: GeneralTensor, weight: GeneralTensor, bias: Optional[GeneralTensor] = None) -> 'ColoTensor': - if _has_sharding_spec(weight): - return _new_colo_linear_imp(input, weight, bias) - else: - return colo_linear_imp(input, weight, bias) diff --git a/colossalai/legacy/nn/_ops/loss.py b/colossalai/legacy/nn/_ops/loss.py deleted file mode 100644 index 90efbfa36f2a..000000000000 --- a/colossalai/legacy/nn/_ops/loss.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import Optional - -import torch -import torch.nn.functional as F - -from colossalai.legacy.nn.loss.loss_1d import VocabParallelCrossEntropyLoss1D -from colossalai.tensor import ColoTensor, ColoTensorSpec -from colossalai.tensor.op_wrapper import colo_op_impl - -from ._utils import GeneralTensor, convert_to_colo_tensor - - -@colo_op_impl(F.cross_entropy) -def colo_cross_entropy(input_tensor: GeneralTensor, - target: GeneralTensor, - weight: Optional[GeneralTensor] = None, - size_average: Optional[bool] = None, - ignore_index: int = -100, - reduce: Optional[bool] = None, - reduction: str = "mean", - label_smoothing: float = 0.0): - assert isinstance(weight, ColoTensor) or isinstance(target, ColoTensor) or isinstance(input_tensor, ColoTensor) - pg = input_tensor.get_process_group() if isinstance(input_tensor, ColoTensor) else isinstance(target, ColoTensor) - weight = convert_to_colo_tensor(weight, pg) - target = convert_to_colo_tensor(target, pg) - input_tensor = convert_to_colo_tensor(input_tensor, pg) - - if input_tensor.is_replicate(): # Input is gathered - assert target.is_replicate() and (weight is None or weight.is_replicate()), \ - "Target tensor and weight tensor both should be complete" - output = F.cross_entropy(input_tensor, - target, - weight=weight, - size_average=size_average, - ignore_index=ignore_index, - reduce=reduce, - reduction=reduction, - label_smoothing=label_smoothing) - return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg)) - elif input_tensor.has_compute_spec(): # Single Model Parallel Applied - if input_tensor.is_shard_1dcol(): - assert weight is None, "Current TP cross entropy loss function doesn't support passing weight tensor in" - assert target.is_replicate(), "Target tensor should be complete in TP cross entropy loss function" - output = VocabParallelCrossEntropyLoss1D()(input_tensor, - target, - process_group=input_tensor.process_group.tp_process_group()) - return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg)) - else: - raise NotImplementedError - else: - raise NotImplementedError diff --git a/colossalai/legacy/nn/_ops/view.py b/colossalai/legacy/nn/_ops/view.py deleted file mode 100644 index 3c0bc52337ce..000000000000 --- a/colossalai/legacy/nn/_ops/view.py +++ /dev/null @@ -1,96 +0,0 @@ -import operator -from functools import reduce -from typing import Optional, Union - -import torch - -from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec -from colossalai.tensor.op_wrapper import colo_op_impl - - -def _all_int(my_iter): - return all(isinstance(i, int) for i in my_iter) - - -def _get_valid_shape(shape): - if isinstance(shape, list): - if _all_int(shape): - return tuple(shape) - else: - raise RuntimeError("expects type(int) but finds an other type") - elif isinstance(shape, tuple): - if _all_int(shape): - return shape - else: - return _get_valid_shape(shape[0]) - else: - raise RuntimeError("expects an iterable array but finds '{}'".format(type(shape))) - - -def 
_shape_infer(org_sp, tgt_sp): - cnt = 0 - pos = 0 - for idx, dim in enumerate(tgt_sp): - if dim < -1: - raise RuntimeError("invalid shape dimension {}".format(dim)) - elif dim == -1: - cnt += 1 - pos = idx - - if cnt > 1: - raise RuntimeError("only one dimension can be inferred") - - org_prod = reduce(operator.mul, org_sp, 1) - tgt_prod = reduce(operator.mul, tgt_sp, 1) - - if cnt == 0: - if org_prod != tgt_prod: - raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod)) - else: - return tgt_sp - elif org_prod % tgt_prod != 0: - raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod)) - - infer_dim = -(org_prod // tgt_prod) - return tgt_sp[:pos] + (infer_dim,) + tgt_sp[pos + 1:] - - -@colo_op_impl(torch.Tensor.view) -def colo_view(self: ColoTensor, *shape) -> 'ColoTensor': - """Handles ``__torch_function__`` dispatch for ``torch.Tensor.view``. - Changes the shape of the current tensor. - """ - assert isinstance(self, ColoTensor) - # apply original `view` function for replicated colo tensors - if self.is_replicate(): - return self.view(*shape) - - cur_sp = self.size() - org_sp = self.size_global() - # parse the passed arguments - tgt_sp = _get_valid_shape(shape) - # get the correct shape from inference - inf_sp = _shape_infer(org_sp, tgt_sp) - - if self.is_shard_1drow() and org_sp[0] == inf_sp[0]: - new_shape = (cur_sp[0],) + tgt_sp[1:] - res = self.view(*new_shape) - elif self.is_shard_1dcol() and org_sp[-1] == inf_sp[-1]: - new_shape = tgt_sp[:-1] + (cur_sp[-1],) - res = self.view(*new_shape) - else: - replicated_t = self.redistribute(dist_spec=ReplicaSpec()) - return ColoTensor.from_torch_tensor(tensor=replicated_t.view(*shape), - spec=ColoTensorSpec(self.get_process_group())) - - return ColoTensor.from_torch_tensor(tensor=res, - spec=ColoTensorSpec(pg=self.get_process_group(), dist_attr=self.dist_spec)) - - -@colo_op_impl(torch.Tensor.size) -def colo_size(self: ColoTensor, dim: Optional[int] = None) -> Union[torch.Size, int]: - size = self.size_global() - if dim is None: - return size - else: - return size[dim] diff --git a/colossalai/legacy/nn/layer/base_layer.py b/colossalai/legacy/nn/layer/base_layer.py index 4a06bdcb7629..01fd9b3e8943 100644 --- a/colossalai/legacy/nn/layer/base_layer.py +++ b/colossalai/legacy/nn/layer/base_layer.py @@ -5,8 +5,8 @@ import torch.nn as nn -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc class ParallelLayer(nn.Module): diff --git a/colossalai/legacy/nn/layer/colossalai_layer/dropout.py b/colossalai/legacy/nn/layer/colossalai_layer/dropout.py index 0c049cb3f408..7b0481a3f53c 100644 --- a/colossalai/legacy/nn/layer/colossalai_layer/dropout.py +++ b/colossalai/legacy/nn/layer/colossalai_layer/dropout.py @@ -1,6 +1,6 @@ import torch.nn as nn -from colossalai.context import ParallelMode, seed +from colossalai.legacy.context import ParallelMode, seed from ..parallel_1d import * from ..utils import get_tensor_parallel_mode diff --git a/colossalai/legacy/nn/layer/parallel_1d/_operation.py b/colossalai/legacy/nn/layer/parallel_1d/_operation.py index 300baf9c12ba..db9dfa3667b4 100644 --- a/colossalai/legacy/nn/layer/parallel_1d/_operation.py +++ b/colossalai/legacy/nn/layer/parallel_1d/_operation.py @@ -1,7 +1,7 @@ import torch import torch.distributed as dist -from colossalai.core import global_context as gpc +from colossalai.legacy.core 
import global_context as gpc try: import fused_mix_prec_layer_norm_cuda diff --git a/colossalai/legacy/nn/layer/parallel_1d/_utils.py b/colossalai/legacy/nn/layer/parallel_1d/_utils.py index fddf4e73db51..15b41e305cba 100644 --- a/colossalai/legacy/nn/layer/parallel_1d/_utils.py +++ b/colossalai/legacy/nn/layer/parallel_1d/_utils.py @@ -4,8 +4,8 @@ import torch import torch.distributed as dist -from colossalai.core import global_context as gpc -from colossalai.global_variables import tensor_parallel_env as env +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.global_variables import tensor_parallel_env as env from ..utils import divide diff --git a/colossalai/legacy/nn/layer/parallel_1d/layers.py b/colossalai/legacy/nn/layer/parallel_1d/layers.py index c0a169c1596f..db7986b8e8e5 100644 --- a/colossalai/legacy/nn/layer/parallel_1d/layers.py +++ b/colossalai/legacy/nn/layer/parallel_1d/layers.py @@ -10,18 +10,18 @@ from torch import Tensor from torch.nn.parameter import Parameter -from colossalai.context import ParallelMode, seed -from colossalai.core import global_context as gpc -from colossalai.global_variables import tensor_parallel_env as env from colossalai.kernel import LayerNorm from colossalai.legacy.communication import broadcast +from colossalai.legacy.context import ParallelMode, seed +from colossalai.legacy.context.parallel_context import global_context as gpc +from colossalai.legacy.global_variables import tensor_parallel_env as env from colossalai.legacy.registry import LAYERS -from colossalai.nn import init as init -from colossalai.utils.checkpointing import ( +from colossalai.legacy.utils.checkpointing import ( broadcast_state_dict, gather_tensor_parallel_state_dict, partition_tensor_parallel_state_dict, ) +from colossalai.nn import init as init from colossalai.utils.cuda import get_current_device from ..base_layer import ParallelLayer diff --git a/colossalai/legacy/nn/layer/parallel_2d/_operation.py b/colossalai/legacy/nn/layer/parallel_2d/_operation.py index fa9b49bcf53f..43e14d4a47a5 100644 --- a/colossalai/legacy/nn/layer/parallel_2d/_operation.py +++ b/colossalai/legacy/nn/layer/parallel_2d/_operation.py @@ -5,10 +5,10 @@ from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.global_variables import tensor_parallel_env as env from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce, reduce_scatter +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.global_variables import tensor_parallel_env as env from colossalai.utils import get_current_device @@ -31,9 +31,9 @@ def matmul_2d( out_shape (:class:`torch.size`): shape of output tensor. row_rank (int, optional): the rank of row, defaults to None. col_rank (int, optional): the rank of column, defaults to None. - row_parallel_mode (:class:`colossalai.context.ParallelMode`, optional): + row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`, optional): row parallel mode, defaults to ParallelMode.PARALLEL_2D_ROW. - col_parallel_mode (:class:`colossalai.context.ParallelMode`, optional): + col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`, optional): column parallel mode, defaults to ParallelMode.PARALLEL_2D_COL. 
Returns: @@ -146,8 +146,8 @@ def classifier_2d(A: Tensor, B: Tensor, bias: Optional[Tensor], summa_dim: int, out_shape (:class:`torch.size`): shape of output tensor. row_rank (int, optional): the rank of row, defaults to None. col_rank (int, optional): the rank of column, defaults to None. - row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. - col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode. data_parallel_rank (int): data parallel rank. pipeline_parallel_rank (int): pipeline parallel rank pipeline_parallel_size (int): pipeline parallel size. @@ -172,8 +172,8 @@ class Matmul_AB_2D(torch.autograd.Function): out_shape (:class:`torch.size`): shape of output tensor. row_rank (int, optional): the rank of row, defaults to None. col_rank (int, optional): the rank of column, defaults to None. - row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. - col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode. data_parallel_rank (int): data parallel rank. pipeline_parallel_rank (int): pipeline parallel rank pipeline_parallel_size (int): pipeline parallel size. @@ -299,8 +299,8 @@ class Matmul_ABT_2D(torch.autograd.Function): out_shape (:class:`torch.size`): shape of output tensor. row_rank (int, optional): the rank of row, defaults to None. col_rank (int, optional): the rank of column, defaults to None. - row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. - col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode. column parallel mode, defaults to ParallelMode.PARALLEL_2D_COL. data_parallel_rank (int): data parallel rank. pipeline_parallel_rank (int): pipeline parallel rank @@ -433,8 +433,8 @@ class Matmul_ATB_2D(torch.autograd.Function): out_shape (:class:`torch.size`): shape of output tensor. row_rank (int, optional): the rank of row, defaults to None. col_rank (int, optional): the rank of column, defaults to None. - row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. - col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode. data_parallel_rank (int): data parallel rank. pipeline_parallel_rank (int): pipeline parallel rank pipeline_parallel_size (int): pipeline parallel size. @@ -620,8 +620,8 @@ def add_bias_2d(input_: Tensor, bias: Tensor, output_size_per_partition: int, ro output_size_per_partition (int): size of output per partition. row_rank (int, optional): the rank of row, defaults to None. col_rank (int, optional): the rank of column, defaults to None. - row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. - col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. 
+ row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode. skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion. data_parallel_rank (int): data parallel rank. @@ -685,8 +685,8 @@ def layernorm_2d(input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, r E_x (:class:`torch.tensor`): mean. Var_x (:class:`torch.tensor`): variance. hidden_size (int): hidden size. - row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. - col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode. Note: The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found @@ -719,7 +719,7 @@ def all_gather_tensor_2d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) Args: tensor (:class:`torch.tensor`): Input tensor. dim (int): Dimension to gather. - parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode tensor used. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The parallel mode tensor used. Note: The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found @@ -767,7 +767,7 @@ def reduce_tensor_2d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor: Args: input_ (:class:`torch.tensor`): Input tensor. - parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode tensor used. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The parallel mode tensor used. Note: The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found @@ -795,7 +795,7 @@ def reduce_scatter_tensor_2d(tensor: Tensor, dim: int, parallel_mode: ParallelMo Args: tensor (:class:`torch.tensor`): Input tensor. dim (int): Dimension to reduce. - parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode tensor used. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The parallel mode tensor used. Note: The parallel_mode should be concluded in ``ParallelMode``. 
More details about ``ParallelMode`` could be found diff --git a/colossalai/legacy/nn/layer/parallel_2d/_utils.py b/colossalai/legacy/nn/layer/parallel_2d/_utils.py index 012fec41c802..87ba1bf69691 100644 --- a/colossalai/legacy/nn/layer/parallel_2d/_utils.py +++ b/colossalai/legacy/nn/layer/parallel_2d/_utils.py @@ -1,6 +1,6 @@ -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.global_variables import tensor_parallel_env as env +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.global_variables import tensor_parallel_env as env def get_summa_dim_from_env() -> int: diff --git a/colossalai/legacy/nn/layer/parallel_2d/layers.py b/colossalai/legacy/nn/layer/parallel_2d/layers.py index b458d15c78e7..893bc74b57d9 100644 --- a/colossalai/legacy/nn/layer/parallel_2d/layers.py +++ b/colossalai/legacy/nn/layer/parallel_2d/layers.py @@ -8,13 +8,16 @@ from torch import Tensor from torch.nn import Parameter -from colossalai.context import ParallelMode, seed -from colossalai.core import global_context as gpc -from colossalai.global_variables import tensor_parallel_env as env from colossalai.legacy.communication import broadcast +from colossalai.legacy.context import ParallelMode, seed +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.global_variables import tensor_parallel_env as env from colossalai.legacy.registry import LAYERS +from colossalai.legacy.utils.checkpointing import ( + gather_tensor_parallel_state_dict, + partition_tensor_parallel_state_dict, +) from colossalai.nn import init as init -from colossalai.utils.checkpointing import gather_tensor_parallel_state_dict, partition_tensor_parallel_state_dict from colossalai.utils.cuda import get_current_device from ..base_layer import ParallelLayer diff --git a/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py b/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py index 55defa4a328d..1226162ae399 100644 --- a/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py +++ b/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py @@ -5,9 +5,9 @@ from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce_scatter +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.utils import get_current_device @@ -112,8 +112,8 @@ def classifier_2p5d(A: Tensor, B: Tensor, bias, tesseract_dim: int, out_shape: T out_shape (:class:`torch.size`): shape of output tensor. row_rank (int): the rank of row. col_rank (int): the rank of column. - row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. - col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode. data_parallel_rank (int): data parallel rank. pipeline_parallel_rank (int): pipeline parallel rank pipeline_parallel_size (int): pipeline parallel size. @@ -139,8 +139,8 @@ class Matmul_AB_2p5D(torch.autograd.Function): row_rank (int): the rank of row. col_rank (int): the rank of column. 
dep_rank (int): the rank of depth. - row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. - col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode. data_parallel_rank (int): data parallel rank. pipeline_parallel_rank (int): pipeline parallel rank pipeline_parallel_size (int): pipeline parallel size. @@ -264,8 +264,8 @@ class Matmul_ABT_2p5D(torch.autograd.Function): row_rank (int): the rank of row. col_rank (int): the rank of column. dep_rank (int): the rank of depth. - row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. - col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode. data_parallel_rank (int): data parallel rank. pipeline_parallel_rank (int): pipeline parallel rank pipeline_parallel_size (int): pipeline parallel size. @@ -394,8 +394,8 @@ class Matmul_ATB_2p5D(torch.autograd.Function): row_rank (int): the rank of row. col_rank (int): the rank of column. dep_rank (int): the rank of depth. - row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. - col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode. data_parallel_rank (int): data parallel rank. pipeline_parallel_rank (int): pipeline parallel rank pipeline_parallel_size (int): pipeline parallel size. @@ -606,7 +606,7 @@ def add_bias_2p5d(input: Tensor, bias: Tensor, output_size_per_partition: int, t row_rank (int): the rank of row. col_rank (int): the rank of column. dep_rank (int): the rank of depth. - col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode. skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion. data_parallel_rank (int): data parallel rank. @@ -631,7 +631,7 @@ class _Layernorm2p5D(torch.autograd.Function): E_x (:class:`torch.tensor`): mean. Var_x (:class:`torch.tensor`): variance. hidden_size (int): hidden size. - row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. + row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode. Note: The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found @@ -682,7 +682,7 @@ def layernorm_2p5d(input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, E_x (:class:`torch.tensor`): mean. Var_x (:class:`torch.tensor`): variance. hidden_size (int): hidden size. - row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. + row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode. Note: The parallel_mode should be concluded in ``ParallelMode``. 
More details about ``ParallelMode`` could be found @@ -715,7 +715,7 @@ def all_gather_tensor_2p5d(inputs: Tensor, dim: int, col_parallel_mode: Parallel Args: inputs (:class:`torch.tensor`): input tensor. dim (int): dimension of all-gather. - col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode. Note: The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found @@ -730,7 +730,7 @@ class SplitFirst(torch.autograd.Function): Args: inputs (:class:`torch.tensor`): input tensor. tesseract_dim (int): dimension of TESSERACT fo 2.5D parallelism - col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode. Note: The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found @@ -798,7 +798,7 @@ def reduce_tensor_2p5d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor: Args: input_ (:class:`torch.tensor`): Input tensor. - parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode tensor used. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The parallel mode tensor used. Note: The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found @@ -826,7 +826,7 @@ def reduce_scatter_tensor_2p5d(input_: Tensor, dim: int, parallel_mode: Parallel Args: input_ (:class:`torch.tensor`): Input tensor. dim (int): Dimension to reduce. - parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode tensor used. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The parallel mode tensor used. Note: The parallel_mode should be concluded in ``ParallelMode``. 
More details about ``ParallelMode`` could be found diff --git a/colossalai/legacy/nn/layer/parallel_2p5d/_utils.py b/colossalai/legacy/nn/layer/parallel_2p5d/_utils.py index 1478b25de618..69a350a977ac 100644 --- a/colossalai/legacy/nn/layer/parallel_2p5d/_utils.py +++ b/colossalai/legacy/nn/layer/parallel_2p5d/_utils.py @@ -1,6 +1,6 @@ -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.global_variables import tensor_parallel_env as env +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.global_variables import tensor_parallel_env as env def get_tesseract_dim_dep_from_env(): diff --git a/colossalai/legacy/nn/layer/parallel_2p5d/layers.py b/colossalai/legacy/nn/layer/parallel_2p5d/layers.py index 04acc2bb0f4c..b4aa9f16ddf0 100644 --- a/colossalai/legacy/nn/layer/parallel_2p5d/layers.py +++ b/colossalai/legacy/nn/layer/parallel_2p5d/layers.py @@ -8,17 +8,17 @@ from torch import Tensor from torch.nn import Parameter -from colossalai.context import ParallelMode, seed -from colossalai.core import global_context as gpc -from colossalai.global_variables import tensor_parallel_env as env from colossalai.legacy.communication import broadcast +from colossalai.legacy.context import ParallelMode, seed +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.global_variables import tensor_parallel_env as env from colossalai.legacy.registry import LAYERS -from colossalai.nn import init as init -from colossalai.utils.checkpointing import ( +from colossalai.legacy.utils.checkpointing import ( broadcast_state_dict, gather_tensor_parallel_state_dict, partition_tensor_parallel_state_dict, ) +from colossalai.nn import init as init from colossalai.utils.cuda import get_current_device from ..base_layer import ParallelLayer diff --git a/colossalai/legacy/nn/layer/parallel_3d/_operation.py b/colossalai/legacy/nn/layer/parallel_3d/_operation.py index ca0b0e62783a..c6374efb7124 100755 --- a/colossalai/legacy/nn/layer/parallel_3d/_operation.py +++ b/colossalai/legacy/nn/layer/parallel_3d/_operation.py @@ -7,10 +7,10 @@ from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd -from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc from colossalai.legacy.communication import all_gather, all_reduce, broadcast, reduce, reduce_scatter +from colossalai.legacy.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc from ._utils import get_parallel_mode_from_env, push_async_grad @@ -73,9 +73,9 @@ def linear_3d( Args: input_ (:class:`torch.tensor`): input matrix. weight (:class:`torch.tensor`): matrix of weight. - input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode. - weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode. - output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode. + input_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): input parallel mode. + weight_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): weight parallel mode. 
+ output_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): output parallel mode. Note: The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found @@ -166,9 +166,9 @@ def classifier_3d( input_ (:class:`torch.tensor`): input matrix. weight (:class:`torch.tensor`): matrix of weight. bias (:class:`torch.tensor`): matrix of bias. - input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode. - weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode. - output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode. + input_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): input parallel mode. + weight_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): weight parallel mode. + output_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): output parallel mode. Note: The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found @@ -260,9 +260,9 @@ def vocab_parallel_classifier_3d( input_ (:class:`torch.tensor`): input matrix. weight (:class:`torch.tensor`): matrix of weight. bias (:class:`torch.tensor`): matrix of bias. - input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode. - weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode. - output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode. + input_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): input parallel mode. + weight_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): weight parallel mode. + output_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): output parallel mode. Note: The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found @@ -378,8 +378,8 @@ def layernorm_3d( If a single integer is used, it is treated as a singleton list, and this module will normalize over the last dimension which is expected to be of that specific size. eps (float): a value added to the denominator for numerical stability - output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode. - input_x_weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input x weight parallel mode. + output_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): output parallel mode. + input_x_weight_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): input x weight parallel mode. Note: The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found @@ -404,7 +404,7 @@ def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Te Args: tensor (:class:`torch.tensor`): Input tensor. dim (int): Specified dimension in which to split. - parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`, optional): Parallel mode. + parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`, optional): Parallel mode. Returns: :class:`torch.tensor`: The tensor has been split. @@ -434,8 +434,8 @@ def split_batch_3d(input_: Tensor, Args: input_ (:class:`torch.tensor`): Input tensor. 
dim (int): Specified dimension in which to split. - input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`, optional): input parallel mode. - weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`, optional): weight parallel mode. + input_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`, optional): input parallel mode. + weight_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`, optional): weight parallel mode. Returns: :class:`torch.tensor`: The tensor has been split. @@ -471,7 +471,7 @@ def reduce_tensor_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor: Args: tensor (:class:`torch.tensor`): Input tensor. - parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): Parallel mode. + parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): Parallel mode. Note: The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found @@ -501,7 +501,7 @@ def all_gather_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) Args: tensor (:class:`torch.tensor`): Input tensor. dim (int): Dimension to gather. - parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): Parallel mode. + parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): Parallel mode. Note: The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found @@ -530,7 +530,7 @@ def reduce_scatter_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMo Args: tensor (:class:`torch.tensor`): Input tensor. dim (int): Dimension to scatter. - parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): Parallel mode. + parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): Parallel mode. Note: The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found @@ -578,8 +578,8 @@ def reduce_by_batch_3d(tensor: Tensor, r"""All-reduce the input from the model parallel region. Args: - input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode. - weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode. + input_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): input parallel mode. + weight_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): weight parallel mode. reduce_mean (bool, optional): If set to ``True``, it will divide the output by (input parallel size * weight parallel size), default to False. 
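The 2D/2.5D/3D hunks above only retarget docstrings and imports to `colossalai.legacy.*`; the collective semantics they document (`all_gather_tensor_*`, `reduce_scatter_tensor_*`) are unchanged. As a point of reference, the sketch below uses plain `torch.distributed` rather than the ColossalAI API to illustrate "gather along a dimension" and "reduce-scatter along a dimension"; `group` stands for an already-initialized process group, and the size along `dim` is assumed divisible by the world size.

```python
# Illustrative only: torch.distributed equivalents of the dim-wise
# gather / reduce-scatter semantics described in the docstrings above.
import torch
import torch.distributed as dist


def all_gather_along_dim(tensor: torch.Tensor, dim: int, group) -> torch.Tensor:
    """Each rank contributes its shard; results are concatenated along `dim`."""
    world_size = dist.get_world_size(group)
    shards = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(shards, tensor.contiguous(), group=group)
    return torch.cat(shards, dim=dim)


def reduce_scatter_along_dim(tensor: torch.Tensor, dim: int, group) -> torch.Tensor:
    """Sum across ranks, then keep only this rank's chunk along `dim`."""
    world_size = dist.get_world_size(group)
    # assumes tensor.size(dim) is divisible by world_size
    chunks = [c.contiguous() for c in tensor.chunk(world_size, dim=dim)]
    out = torch.empty_like(chunks[0])
    dist.reduce_scatter(out, chunks, op=dist.ReduceOp.SUM, group=group)
    return out
```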
diff --git a/colossalai/legacy/nn/layer/parallel_3d/_utils.py b/colossalai/legacy/nn/layer/parallel_3d/_utils.py index 364191a79f88..cb300c2a9684 100644 --- a/colossalai/legacy/nn/layer/parallel_3d/_utils.py +++ b/colossalai/legacy/nn/layer/parallel_3d/_utils.py @@ -4,9 +4,15 @@ import torch from torch import Tensor -from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D -from colossalai.core import global_context as gpc -from colossalai.global_variables import tensor_parallel_env as env +from colossalai.legacy.constants import ( + INPUT_GROUP_3D, + INPUT_X_WEIGHT_3D, + OUTPUT_GROUP_3D, + OUTPUT_X_WEIGHT_3D, + WEIGHT_GROUP_3D, +) +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.global_variables import tensor_parallel_env as env def get_depth_from_env() -> int: diff --git a/colossalai/legacy/nn/layer/parallel_3d/layers.py b/colossalai/legacy/nn/layer/parallel_3d/layers.py index b815a842ca52..d6aaa427b9e6 100644 --- a/colossalai/legacy/nn/layer/parallel_3d/layers.py +++ b/colossalai/legacy/nn/layer/parallel_3d/layers.py @@ -8,19 +8,25 @@ from torch import Tensor from torch.nn import Parameter -from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D -from colossalai.context import ParallelMode, seed -from colossalai.core import global_context as gpc -from colossalai.global_variables import tensor_parallel_env as env from colossalai.legacy.communication import all_reduce, broadcast +from colossalai.legacy.constants import ( + INPUT_GROUP_3D, + INPUT_X_WEIGHT_3D, + OUTPUT_GROUP_3D, + OUTPUT_X_WEIGHT_3D, + WEIGHT_GROUP_3D, +) +from colossalai.legacy.context import ParallelMode, seed +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.global_variables import tensor_parallel_env as env from colossalai.legacy.nn.layer.base_layer import ParallelLayer from colossalai.legacy.registry import LAYERS -from colossalai.nn import init as init -from colossalai.utils.checkpointing import ( +from colossalai.legacy.utils.checkpointing import ( broadcast_state_dict, gather_tensor_parallel_state_dict, partition_tensor_parallel_state_dict, ) +from colossalai.nn import init as init from colossalai.utils.cuda import get_current_device from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple diff --git a/colossalai/legacy/nn/layer/parallel_sequence/_operation.py b/colossalai/legacy/nn/layer/parallel_sequence/_operation.py index fcf2962017a3..ea1863f0b474 100644 --- a/colossalai/legacy/nn/layer/parallel_sequence/_operation.py +++ b/colossalai/legacy/nn/layer/parallel_sequence/_operation.py @@ -5,9 +5,9 @@ from torch import distributed as dist from torch.cuda.amp import custom_bwd, custom_fwd -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc from colossalai.legacy.communication import ring_forward +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_sequence._utils import _calc_current_device_range, _calc_incoming_device_range from colossalai.utils import get_current_device diff --git a/colossalai/legacy/nn/layer/parallel_sequence/layers.py b/colossalai/legacy/nn/layer/parallel_sequence/layers.py index e44e61c2fb7d..033c1be962ae 100644 --- a/colossalai/legacy/nn/layer/parallel_sequence/layers.py +++ 
b/colossalai/legacy/nn/layer/parallel_sequence/layers.py @@ -9,11 +9,11 @@ from torch.nn import Parameter import colossalai -from colossalai.context import seed -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc from colossalai.kernel import FusedScaleMaskSoftmax from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType +from colossalai.legacy.context import seed +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_sequence._operation import RingAV, RingQK from colossalai.legacy.registry import LAYERS diff --git a/colossalai/legacy/nn/layer/utils/common.py b/colossalai/legacy/nn/layer/utils/common.py index d8f3ad2a7eca..3148a0bed570 100644 --- a/colossalai/legacy/nn/layer/utils/common.py +++ b/colossalai/legacy/nn/layer/utils/common.py @@ -8,9 +8,9 @@ import torch from torch import Tensor, nn -from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS -from colossalai.global_variables import tensor_parallel_env as env -from colossalai.utils import checkpoint +from colossalai.legacy.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS +from colossalai.legacy.global_variables import tensor_parallel_env as env +from colossalai.legacy.utils import checkpoint class CheckpointModule(nn.Module): diff --git a/colossalai/legacy/nn/layer/vanilla/layers.py b/colossalai/legacy/nn/layer/vanilla/layers.py index 0e11fc4d0dab..71ca1d421de6 100644 --- a/colossalai/legacy/nn/layer/vanilla/layers.py +++ b/colossalai/legacy/nn/layer/vanilla/layers.py @@ -7,7 +7,7 @@ from torch import nn as nn from torch.nn.parameter import Parameter -from colossalai.context import seed +from colossalai.legacy.context import seed from colossalai.legacy.registry import LAYERS from colossalai.nn import init as init from colossalai.utils.cuda import get_current_device @@ -64,7 +64,7 @@ class WrappedDropout(nn.Module): Args: p (float, optional): probability of an element to be zeroed, defaults 0.5. inplace (bool, optional): whether to do dropout in-place, default to be False. - mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. Note: The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found @@ -101,7 +101,7 @@ class WrappedDropPath(nn.Module): Args: p (float, optional): probability of dropping path, defaults 0.0. - mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. Note: The parallel_mode should be concluded in ``ParallelMode``. 
More details about ``ParallelMode`` could be found diff --git a/colossalai/legacy/nn/layer/wrapper/pipeline_wrapper.py b/colossalai/legacy/nn/layer/wrapper/pipeline_wrapper.py index 68fea8622c5c..ec19d1b707d8 100644 --- a/colossalai/legacy/nn/layer/wrapper/pipeline_wrapper.py +++ b/colossalai/legacy/nn/layer/wrapper/pipeline_wrapper.py @@ -3,8 +3,8 @@ import torch.distributed as dist import torch.nn as nn -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc class PipelineSharedModuleWrapper: diff --git a/colossalai/legacy/nn/loss/__init__.py b/colossalai/legacy/nn/loss/__init__.py index 1bd8872d9c3a..abb7ec3ef824 100644 --- a/colossalai/legacy/nn/loss/__init__.py +++ b/colossalai/legacy/nn/loss/__init__.py @@ -2,7 +2,7 @@ from torch.nn.modules.loss import * from torch.nn.modules.loss import _Loss -from colossalai.global_variables import tensor_parallel_env as env +from colossalai.legacy.global_variables import tensor_parallel_env as env from colossalai.legacy.nn.layer.utils import get_tensor_parallel_mode from .loss_1d import VocabParallelCrossEntropyLoss1D diff --git a/colossalai/legacy/nn/loss/loss_1d.py b/colossalai/legacy/nn/loss/loss_1d.py index 8c9483fccaec..2582e8b359d5 100644 --- a/colossalai/legacy/nn/loss/loss_1d.py +++ b/colossalai/legacy/nn/loss/loss_1d.py @@ -3,8 +3,8 @@ from torch.cuda.amp import custom_bwd, custom_fwd from torch.nn.modules.loss import _Loss -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.legacy.registry import LOSSES diff --git a/colossalai/legacy/nn/loss/loss_2d.py b/colossalai/legacy/nn/loss/loss_2d.py index 6191602b71ee..7ab58415608a 100644 --- a/colossalai/legacy/nn/loss/loss_2d.py +++ b/colossalai/legacy/nn/loss/loss_2d.py @@ -4,8 +4,8 @@ from torch.nn.functional import cross_entropy from torch.nn.modules.loss import _Loss -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d from colossalai.legacy.nn.layer.parallel_2d._utils import assert_summa_initialization from colossalai.legacy.registry import LOSSES diff --git a/colossalai/legacy/nn/loss/loss_2p5d.py b/colossalai/legacy/nn/loss/loss_2p5d.py index 2746b201152c..8a5d04a8c788 100644 --- a/colossalai/legacy/nn/loss/loss_2p5d.py +++ b/colossalai/legacy/nn/loss/loss_2p5d.py @@ -4,8 +4,8 @@ from torch.nn.functional import cross_entropy from torch.nn.modules.loss import _Loss -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d from colossalai.legacy.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization from colossalai.legacy.registry import LOSSES diff --git a/colossalai/legacy/nn/loss/loss_3d.py b/colossalai/legacy/nn/loss/loss_3d.py index 2aeb1bd9825d..a576d84f71cd 100644 --- a/colossalai/legacy/nn/loss/loss_3d.py +++ b/colossalai/legacy/nn/loss/loss_3d.py @@ -4,8 +4,8 @@ from 
torch.nn.functional import cross_entropy from torch.nn.modules.loss import _Loss -from colossalai.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D -from colossalai.core import global_context as gpc +from colossalai.legacy.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D +from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env from colossalai.legacy.registry import LOSSES diff --git a/colossalai/legacy/nn/metric/accuracy_3d.py b/colossalai/legacy/nn/metric/accuracy_3d.py index 1aaac73ecabd..675f5c2b5120 100644 --- a/colossalai/legacy/nn/metric/accuracy_3d.py +++ b/colossalai/legacy/nn/metric/accuracy_3d.py @@ -1,7 +1,7 @@ import torch from torch import nn -from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D +from colossalai.legacy.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D from colossalai.legacy.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env diff --git a/colossalai/legacy/nn/parallel/data_parallel.py b/colossalai/legacy/nn/parallel/data_parallel.py index f839d6b28444..2b2ad36a74f4 100644 --- a/colossalai/legacy/nn/parallel/data_parallel.py +++ b/colossalai/legacy/nn/parallel/data_parallel.py @@ -5,7 +5,7 @@ import torch import torch.distributed as dist -from colossalai.tensor import ProcessGroup as ColoProcessGroup +from colossalai.legacy.tensor import ProcessGroup as ColoProcessGroup from colossalai.utils import is_ddp_ignored from .reducer import Reducer @@ -34,8 +34,8 @@ class ColoDDP(torch.nn.Module): """Distributed data parallel for ColoTensor. Nested ColoDDP is not supported now. 
Example: - >>> from colossalai.core import global_context as gpc - >>> from colossalai.context import ParallelMode + >>> from colossalai.legacy.core import global_context as gpc + >>> from colossalai.legacy.context import ParallelMode >>> model = torch.nn.Linear(20, 1) >>> pg = ProcessGroup(tp_degree = world_size//2) >>> model = ColoDDP(model, pg) diff --git a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py index 79d7672b26bc..522fb4f4497f 100644 --- a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py @@ -4,7 +4,8 @@ import torch.nn.functional as F from colossalai.legacy.nn._ops._utils import dual_all_to_all -from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec, ComputePattern, ProcessGroup, ShardSpec +from colossalai.legacy.tensor import ColoTensorSpec, ComputePattern, ProcessGroup, ShardSpec +from colossalai.tensor import ColoParameter, ColoTensor from .cache_mgr import CachedParamMgr, EvictionStrategy from .cached_embedding import CachedEmbeddingBag diff --git a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py index 116d836b7139..a1feda2bdb0e 100644 --- a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py @@ -6,7 +6,7 @@ import torch.nn.functional as F from colossalai.legacy.nn._ops._utils import dual_all_to_all_tablewise -from colossalai.tensor import ProcessGroup +from colossalai.legacy.tensor import ProcessGroup from .cache_mgr import EvictionStrategy from .cached_embedding import CachedEmbeddingBag diff --git a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py index 0014c784fba1..8017ee72b0b4 100644 --- a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py @@ -7,7 +7,7 @@ from torch.profiler import record_function from colossalai.legacy.nn._ops._utils import dual_all_to_all_tablewise -from colossalai.tensor import ProcessGroup +from colossalai.legacy.tensor import ProcessGroup from .cache_mgr import EvictionStrategy from .cached_embedding import CachedEmbeddingBag diff --git a/colossalai/legacy/nn/parallel/layers/colo_module.py b/colossalai/legacy/nn/parallel/layers/colo_module.py index a0a3eb40cf08..69d92afaaa94 100644 --- a/colossalai/legacy/nn/parallel/layers/colo_module.py +++ b/colossalai/legacy/nn/parallel/layers/colo_module.py @@ -1,7 +1,7 @@ from typing import Dict, List -from colossalai.tensor import ComputePattern -from colossalai.tensor.distspec import _DistSpec +from colossalai.legacy.tensor import ComputePattern +from colossalai.legacy.tensor.distspec import _DistSpec class ColoModule(object): diff --git a/colossalai/legacy/nn/parallel/layers/embedding.py b/colossalai/legacy/nn/parallel/layers/embedding.py index 3e4e7ffd8de7..4796699fc57f 100644 --- a/colossalai/legacy/nn/parallel/layers/embedding.py +++ 
b/colossalai/legacy/nn/parallel/layers/embedding.py @@ -1,4 +1,4 @@ -from colossalai.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec +from colossalai.legacy.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec from .colo_module import ColoModule diff --git a/colossalai/legacy/nn/parallel/layers/linear.py b/colossalai/legacy/nn/parallel/layers/linear.py index e391cf808933..51a8d4c976a6 100644 --- a/colossalai/legacy/nn/parallel/layers/linear.py +++ b/colossalai/legacy/nn/parallel/layers/linear.py @@ -1,4 +1,4 @@ -from colossalai.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec +from colossalai.legacy.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec from .colo_module import ColoModule diff --git a/colossalai/legacy/nn/parallel/layers/module_utils.py b/colossalai/legacy/nn/parallel/layers/module_utils.py index 191266fa70fd..09326d2d6f9a 100644 --- a/colossalai/legacy/nn/parallel/layers/module_utils.py +++ b/colossalai/legacy/nn/parallel/layers/module_utils.py @@ -2,7 +2,8 @@ import torch -from colossalai.tensor import ColoParameter, ComputeSpec, ProcessGroup, distspec +from colossalai.legacy.tensor import ComputeSpec, ProcessGroup, distspec +from colossalai.tensor import ColoParameter from . import ColoModule diff --git a/colossalai/legacy/pipeline/__init__.py b/colossalai/legacy/pipeline/__init__.py new file mode 100644 index 000000000000..f36f54ac9307 --- /dev/null +++ b/colossalai/legacy/pipeline/__init__.py @@ -0,0 +1,4 @@ +from .layer_spec import LayerSpec +from .pipelinable import PipelinableContext, PipelinableModel + +__all__ = ['PipelinableModel', 'PipelinableContext', 'LayerSpec'] diff --git a/colossalai/pipeline/layer_spec.py b/colossalai/legacy/pipeline/layer_spec.py similarity index 97% rename from colossalai/pipeline/layer_spec.py rename to colossalai/legacy/pipeline/layer_spec.py index 7e9169efff78..3960debd7f72 100644 --- a/colossalai/pipeline/layer_spec.py +++ b/colossalai/legacy/pipeline/layer_spec.py @@ -1,9 +1,11 @@ import torch + from colossalai.utils.model.utils import call_to_str + class LayerSpec: """ - + """ def __init__(self, typename, *module_args, **module_kwargs): @@ -52,4 +54,4 @@ def count_params(self): return self._param_count def reset_param_count(self): - self._param_count = 0 \ No newline at end of file + self._param_count = 0 diff --git a/colossalai/legacy/pipeline/middleware/__init__.py b/colossalai/legacy/pipeline/middleware/__init__.py new file mode 100644 index 000000000000..481741bfee31 --- /dev/null +++ b/colossalai/legacy/pipeline/middleware/__init__.py @@ -0,0 +1,3 @@ +from .topo import Partition, PartitionInputVal, PartitionOutputVal, Topo + +__all__ = ['Topo', 'Partition', 'PartitionOutputVal', 'PartitionInputVal'] diff --git a/colossalai/pipeline/middleware/adaptor/__init__.py b/colossalai/legacy/pipeline/middleware/adaptor/__init__.py similarity index 62% rename from colossalai/pipeline/middleware/adaptor/__init__.py rename to colossalai/legacy/pipeline/middleware/adaptor/__init__.py index 949700a2c49d..0b0d36d2ffe5 100644 --- a/colossalai/pipeline/middleware/adaptor/__init__.py +++ b/colossalai/legacy/pipeline/middleware/adaptor/__init__.py @@ -1,3 +1,3 @@ from .fx import get_topology as get_fx_topology -__all__ = ['get_fx_topology'] \ No newline at end of file +__all__ = ['get_fx_topology'] diff --git a/colossalai/pipeline/middleware/adaptor/fx.py b/colossalai/legacy/pipeline/middleware/adaptor/fx.py similarity index 92% rename from colossalai/pipeline/middleware/adaptor/fx.py rename to 
colossalai/legacy/pipeline/middleware/adaptor/fx.py index 8437c5194762..8cc40f120f15 100644 --- a/colossalai/pipeline/middleware/adaptor/fx.py +++ b/colossalai/legacy/pipeline/middleware/adaptor/fx.py @@ -1,6 +1,8 @@ -from torch.fx.graph_module import GraphModule -from colossalai.pipeline.middleware.topo import Partition, PartitionInputVal, PartitionOutputVal, Topo import torch +from torch.fx.graph_module import GraphModule + +from colossalai.legacy.pipeline.middleware.topo import Partition, PartitionInputVal, PartitionOutputVal, Topo + def partition_name_to_id(partition_name, is_input=False, is_output=False): if is_input: @@ -12,6 +14,7 @@ def partition_name_to_id(partition_name, is_input=False, is_output=False): partition_id = int(partition_name.split(prefix)[-1]) + 2 return partition_id + # There are two kinds of def in fx.graph # 1. non direct_use & non direct_def, which means the output is used by next partition with a temporary mid value. # e.g. submod1 = call_module(...) @@ -20,6 +23,8 @@ def partition_name_to_id(partition_name, is_input=False, is_output=False): # 2. direct_use & direct_def, which means the output is used by next partition directly. # e.g. submod1 = call_module(...) # submod2 = call_module(submod1, ...) + + def find_input_in_partition(node, partitions, input_partitions=None): p_input_val = None direct_def = not node.name.startswith('getitem') @@ -45,9 +50,10 @@ def find_input_in_partition(node, partitions, input_partitions=None): partition_id = partition_name_to_id(partition.name) p_input_val = PartitionInputVal(partition_id=partition_id, offset=offset) return p_input_val - + return p_input_val - + + def find_output_in_partition(node, partitions, output_partitions=None): p_output_val = PartitionOutputVal() for user in node.users: @@ -70,7 +76,7 @@ def find_output_in_partition(node, partitions, output_partitions=None): if arg == user: p_output_val.add(partition_id=partition_id, offset=i) break - + # user is output if output_partitions is not None: output_node = output_partitions[0] @@ -84,10 +90,11 @@ def find_output_in_partition(node, partitions, output_partitions=None): break return p_output_val + def get_topology(gm: GraphModule): topo = Topo() topo_output_partition = Partition() - + input_partitions = [] partitions = [] output_partitions = [] @@ -109,7 +116,7 @@ def get_topology(gm: GraphModule): topo_input_partition.add_output_val(p_output_val) topo.set_partitions(partition_id=0, partition=topo_input_partition) topo.set_input_partition_id(partition_id=0) - + for i, partition in enumerate(partitions): topo_mid_partition = Partition() # set input for submodule @@ -131,15 +138,16 @@ def get_topology(gm: GraphModule): for user in partition.users: cur_node = user p_output_val = find_output_in_partition(cur_node, partitions, output_partitions) - topo_mid_partition.add_output_val(p_output_val) - topo.set_partitions(partition_id=i+2, partition=topo_mid_partition) - + topo_mid_partition.add_output_val(p_output_val) + topo.set_partitions(partition_id=i + 2, partition=topo_mid_partition) + # set input for output_partition for partition in output_partitions: topo_output_partition = Partition() - torch.fx.graph.map_arg(partition.args[0], lambda n: topo_output_partition.add_input_val( - find_input_in_partition(n, partitions, input_partitions))) + torch.fx.graph.map_arg( + partition.args[0], + lambda n: topo_output_partition.add_input_val(find_input_in_partition(n, partitions, input_partitions))) topo.set_partitions(partition_id=1, partition=topo_output_partition) 
topo.set_output_partition_id(partition_id=1) - return topo \ No newline at end of file + return topo diff --git a/colossalai/pipeline/middleware/topo.py b/colossalai/legacy/pipeline/middleware/topo.py similarity index 95% rename from colossalai/pipeline/middleware/topo.py rename to colossalai/legacy/pipeline/middleware/topo.py index e798e2ed9cab..3c21cce6dc0e 100644 --- a/colossalai/pipeline/middleware/topo.py +++ b/colossalai/legacy/pipeline/middleware/topo.py @@ -1,49 +1,54 @@ -from typing import Dict, List from dataclasses import dataclass +from typing import Dict, List # This file includes data structure used by Pipeline Middleware. + @dataclass class ValPosition: partition_id: int offset: int - + def __str__(self) -> str: res = f'[partition_id:{self.partition_id},offset:{self.offset}]' return res - + def __repr__(self) -> str: return self.__str__() + class PartitionInputVal(object): + def __init__(self, partition_id, offset) -> None: # every input from which partition_id and which offset val_pos = ValPosition(partition_id, offset) self._from_partition_and_offset: ValPosition = val_pos - + def get(self): return self._from_partition_and_offset - + def __str__(self) -> str: res = '' res += f'<-({self._from_partition_and_offset})' return res - + def __repr__(self) -> str: return self.__str__() - + + class PartitionOutputVal(object): + def __init__(self) -> None: # every output to which partition_id and which offset self._to_partition_and_offset: List[ValPosition] = [] - + def add(self, partition_id, offset): val_pos = ValPosition(partition_id, offset) self._to_partition_and_offset.append(val_pos) - + def get(self): return self._to_partition_and_offset - + def __str__(self) -> str: res = '' res += '->(' @@ -51,27 +56,29 @@ def __str__(self) -> str: res += f'{val_pos},' res += ')' return res - + def __repr__(self) -> str: return self.__str__() + class Partition(object): + def __init__(self) -> None: self._input_vals: List[PartitionInputVal] = [] self._output_vals: List[PartitionOutputVal] = [] - + def add_input_val(self, input_val: PartitionInputVal): self._input_vals.append(input_val) - + def add_output_val(self, output_val: PartitionOutputVal): self._output_vals.append(output_val) - + def get_input_vals(self): return self._input_vals - + def get_output_vals(self): return self._output_vals - + # get the output offsets sent to dst_partition_id def get_output_offsets(self, dst_partition_id): res = [] @@ -80,9 +87,9 @@ def get_output_offsets(self, dst_partition_id): for val_pos in outputs: if val_pos.partition_id == dst_partition_id: res.append(offset) - + return res - + # get all input dst partition_ids def get_input_partition_ids(self): res = [] @@ -91,7 +98,7 @@ def get_input_partition_ids(self): if val_pos.partition_id not in res: res.append(val_pos.partition_id) return res - + # get all output dst partition_ids def get_output_partition_ids(self): res = [] @@ -101,24 +108,25 @@ def get_output_partition_ids(self): if val_pos.partition_id not in res: res.append(val_pos.partition_id) return res - + def __str__(self) -> str: res = '' res += f' input:\n' res += f' length:{len(self._input_vals)}\n' for i, input_val in enumerate(self._input_vals): res += f' offset={i}:{input_val}\n' - + res += f' output:\n' res += f' length:{len(self._output_vals)}\n' for i, output_val in enumerate(self._output_vals): res += f' offset={i}:{output_val}\n' - + return res - + def __repr__(self) -> str: return self.__str__() + # This class is a middleware between partition splitter # and Pipeline Scheduler. 
It records the graph info about # partition input/output and provides it to scheduler. @@ -132,42 +140,43 @@ def __repr__(self) -> str: # _input_partition_id: the key represents input_partition # _output_partition_id: the key represents output_partition class Topo(object): + def __init__(self, input_partition_id=None, output_partition_id=None) -> None: self._partitions: Dict[int, Partition] = {} self._input_partition_id = input_partition_id self._output_partition_id = output_partition_id - + def set_input_partition_id(self, partition_id: int): self._input_partition_id = partition_id - + def set_output_partition_id(self, partition_id: int): self._output_partition_id = partition_id - + def get_input_partition_id(self): return self._input_partition_id - + def get_output_partition_id(self): return self._output_partition_id - + def set_partitions(self, partition_id: int, partition: Partition): self._partitions[partition_id] = partition - + def get_mid_partitions(self): - res = {} #{partition_id: Partition} + res = {} #{partition_id: Partition} for partition_id, partition in self._partitions.items(): if self._input_partition_id == partition_id or self._output_partition_id == partition_id: continue res[partition_id] = partition return res - + def get_mid_partition_ids(self): return list(self.get_mid_partitions().keys()) - + def get_input_partition(self): if self._input_partition_id is not None: return self._partitions[self._input_partition_id] return None - + def get_output_partition(self): if self._output_partition_id is not None: return self._partitions[self._output_partition_id] @@ -175,7 +184,7 @@ def get_output_partition(self): def get_partition_by_id(self, partition_id): return self._partitions[partition_id] - + def __str__(self) -> str: res = '' if len(self._partitions) == 0: @@ -186,21 +195,20 @@ def __str__(self) -> str: res += '{\n' res += f'InputPartition:\n partition_id={self._input_partition_id}\n{input_part}' res += '}\n' - + mid_parts = self.get_mid_partitions() for i, (partition_id, part) in enumerate(mid_parts.items()): res += '{\n' res += f'SubPartition_{i}:\n partition_id={partition_id}\n {part}' res += '}\n' - + output_part = self.get_output_partition() if output_part is not None: res += '{\n' res += f'OutputPartition:\n partition_id={self._output_partition_id}\n{output_part}' res += '}\n' - + return res - + def __repr__(self) -> str: return self.__str__() - \ No newline at end of file diff --git a/colossalai/pipeline/pipelinable.py b/colossalai/legacy/pipeline/pipelinable.py similarity index 93% rename from colossalai/pipeline/pipelinable.py rename to colossalai/legacy/pipeline/pipelinable.py index ba8b1591da9d..e74cad0ad1b0 100644 --- a/colossalai/pipeline/pipelinable.py +++ b/colossalai/legacy/pipeline/pipelinable.py @@ -1,20 +1,16 @@ -import inspect - import torch -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.utils import CheckpointModule from colossalai.tensor import ColoParameter from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses from .layer_spec import LayerSpec from .utils import ( - build_kwargs_for_function, build_kwargs_for_module, call_module, customized_partition, - exec_func_with_kwargs, exec_funcs_with_kwargs, partition_balanced, partition_uniform, @@ -135,8 +131,10 @@ def to_layer_list(self, exec_seq=None): children_name = [] for child in 
self._root_children: layer_spec = self._layer_spec_dict[id(child)] - if layer_spec.typename in (torch.nn.modules.container.ModuleList, - torch.nn.modules.container.Sequential): + if layer_spec.typename in ( + torch.nn.modules.container.ModuleList, + torch.nn.modules.container.Sequential, + ): for child_in_container in layer_spec.children: self._layer_spec_list.append(self._layer_spec_dict[id(child_in_container)]) for name, module in self._model.named_modules(): @@ -155,9 +153,11 @@ def to_layer_list(self, exec_seq=None): named_modules = dict(self._model.named_modules()) for index, element in enumerate(exec_seq): if isinstance(element, str): - if element == 'SPLIT_NODE': + if element == "SPLIT_NODE": continue - assert element in named_modules, f'Found invalid module name {element}, please check if you spell the module name correctly.' + assert ( + element in named_modules + ), f"Found invalid module name {element}, please check if you spell the module name correctly." # get the layer spec based on the module ID module = named_modules[element] @@ -198,11 +198,12 @@ def partition(self, num_chunks, pipeline_size, rank): param_counts.append(layer_spec.count_params()) parts = partition_balanced(param_counts, pipeline_size, num_chunks)[rank] elif self._policy == "customized": - assert self._exec_seq is not None, f'An explicit exec_seq must be defined by user in customized policy mode.' + assert (self._exec_seq + is not None), f"An explicit exec_seq must be defined by user in customized policy mode." self.customized_parts = customized_partition(self._exec_seq) assert len(self.customized_parts) == gpc.get_world_size( ParallelMode.PIPELINE - ), f'World size is {gpc.get_world_size(ParallelMode.PIPELINE)}, but the number of partitions is {len(self.customized_parts)}' + ), f"World size is {gpc.get_world_size(ParallelMode.PIPELINE)}, but the number of partitions is {len(self.customized_parts)}" parts = self.customized_parts[rank] else: raise ValueError("A string partition policy should be one of ['uniform', 'balanced', 'customized'].") @@ -241,7 +242,6 @@ def __init__(self, module_list, front_func_dict, behind_func_dict): def forward(self, *input_tensor, **kwargs): for module in self._module_list: - if id(module) in self._front_func_dict: input_tensor = exec_funcs_with_kwargs(self._front_func_dict, id(module), input_tensor, kwargs) diff --git a/colossalai/pipeline/pipeline_process_group.py b/colossalai/legacy/pipeline/pipeline_process_group.py similarity index 98% rename from colossalai/pipeline/pipeline_process_group.py rename to colossalai/legacy/pipeline/pipeline_process_group.py index c61d97ebabfa..1168158defaf 100644 --- a/colossalai/pipeline/pipeline_process_group.py +++ b/colossalai/legacy/pipeline/pipeline_process_group.py @@ -1,11 +1,11 @@ -from typing import List, Dict, Tuple import os import threading +from typing import Dict, List, Tuple -from torch.distributed import rpc import torch.distributed as dist +from torch.distributed import rpc -from colossalai.tensor import ProcessGroup +from colossalai.legacy.tensor import ProcessGroup class PipelineProcessGroup: diff --git a/colossalai/legacy/pipeline/rpc/__init__.py b/colossalai/legacy/pipeline/rpc/__init__.py new file mode 100644 index 000000000000..15b65a4138a8 --- /dev/null +++ b/colossalai/legacy/pipeline/rpc/__init__.py @@ -0,0 +1,4 @@ +from ._pipeline_schedule import ChimeraPipelineEngine, FillDrainPipelineEngine, OneFOneBPipelineEngine +from .utils import pytree_map + +__all__ = ['FillDrainPipelineEngine', 'OneFOneBPipelineEngine', 
'ChimeraPipelineEngine', 'pytree_map'] diff --git a/colossalai/pipeline/rpc/_pipeline_base.py b/colossalai/legacy/pipeline/rpc/_pipeline_base.py similarity index 99% rename from colossalai/pipeline/rpc/_pipeline_base.py rename to colossalai/legacy/pipeline/rpc/_pipeline_base.py index 9e549df58214..88ddb9e98eb2 100644 --- a/colossalai/pipeline/rpc/_pipeline_base.py +++ b/colossalai/legacy/pipeline/rpc/_pipeline_base.py @@ -12,9 +12,9 @@ from torch._C._distributed_rpc import PyRRef from torch.futures import Future -from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo -from colossalai.pipeline.pipeline_process_group import ppg -from colossalai.pipeline.rpc.utils import ( +from colossalai.legacy.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo +from colossalai.legacy.pipeline.pipeline_process_group import ppg +from colossalai.legacy.pipeline.rpc.utils import ( get_batch_lengths, pyobj_map, pytree_filter, diff --git a/colossalai/pipeline/rpc/_pipeline_schedule.py b/colossalai/legacy/pipeline/rpc/_pipeline_schedule.py similarity index 97% rename from colossalai/pipeline/rpc/_pipeline_schedule.py rename to colossalai/legacy/pipeline/rpc/_pipeline_schedule.py index 6eda8f3b34b7..f53a4835edf2 100644 --- a/colossalai/pipeline/rpc/_pipeline_schedule.py +++ b/colossalai/legacy/pipeline/rpc/_pipeline_schedule.py @@ -6,8 +6,8 @@ from torch._C._distributed_rpc import PyRRef from torch.futures import Future -from colossalai.pipeline.pipeline_process_group import ppg -from colossalai.pipeline.rpc._pipeline_base import Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem +from colossalai.legacy.pipeline.pipeline_process_group import ppg +from colossalai.legacy.pipeline.rpc._pipeline_base import Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem # Implementation of different Pipeline schedule # Worker defines the worker for each stage @@ -78,7 +78,7 @@ def _get_work_item_key(self) -> UniqueKey: # 1. forward times reach actual_stage_num, this is the end of continuous forward # 2. forward times reach num_microbatches, this is the end of 1F1B mode if not is_last_stage and \ - target_key.phase == Phase.FORWARD: + target_key.phase == Phase.FORWARD: if target_key.microbatch_id == actual_stage_num - 1 and num_microbatches > 2: # Why need num_microbatches > 2 ? 
Because there is no steady stage when num_microbatches <= 2 outstanding_min = actual_stage_num - pp_rank - 1 @@ -144,7 +144,7 @@ def _get_work_item_key(self) -> UniqueKey: forward_block_num = self.forward_times // forward_block_size if self.forward_times >= real_microbatch_num or \ - ((pp_rank + 1) % stage_num == 0 and forward_block_num > self.backward_times): + ((pp_rank + 1) % stage_num == 0 and forward_block_num > self.backward_times): target_phase = Phase.BACKWARD target_microbatch_id = self.backward_times else: # others diff --git a/colossalai/pipeline/rpc/utils.py b/colossalai/legacy/pipeline/rpc/utils.py similarity index 98% rename from colossalai/pipeline/rpc/utils.py rename to colossalai/legacy/pipeline/rpc/utils.py index 06e6d976d771..d1033fbde920 100644 --- a/colossalai/pipeline/rpc/utils.py +++ b/colossalai/legacy/pipeline/rpc/utils.py @@ -10,7 +10,7 @@ from torch.futures import Future from colossalai.initialize import launch -from colossalai.pipeline.pipeline_process_group import ppg +from colossalai.legacy.pipeline.pipeline_process_group import ppg def pyobj_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = ()) -> Any: diff --git a/colossalai/pipeline/utils.py b/colossalai/legacy/pipeline/utils.py similarity index 100% rename from colossalai/pipeline/utils.py rename to colossalai/legacy/pipeline/utils.py diff --git a/colossalai/legacy/tensor/__init__.py b/colossalai/legacy/tensor/__init__.py new file mode 100644 index 000000000000..d3278bf1e420 --- /dev/null +++ b/colossalai/legacy/tensor/__init__.py @@ -0,0 +1,17 @@ +from . import distspec +from .compute_spec import ComputePattern, ComputeSpec +from .dist_spec_mgr import DistSpecManager +from .distspec import ReplicaSpec, ShardSpec +from .process_group import ProcessGroup +from .tensor_spec import ColoTensorSpec + +__all__ = [ + 'ComputePattern', + 'ComputeSpec', + 'distspec', + 'DistSpecManager', + 'ProcessGroup', + 'ColoTensorSpec', + 'ShardSpec', + 'ReplicaSpec', +] diff --git a/colossalai/tensor/compute_spec.py b/colossalai/legacy/tensor/compute_spec.py similarity index 100% rename from colossalai/tensor/compute_spec.py rename to colossalai/legacy/tensor/compute_spec.py diff --git a/colossalai/tensor/const.py b/colossalai/legacy/tensor/const.py similarity index 100% rename from colossalai/tensor/const.py rename to colossalai/legacy/tensor/const.py diff --git a/colossalai/tensor/dist_spec_mgr.py b/colossalai/legacy/tensor/dist_spec_mgr.py similarity index 97% rename from colossalai/tensor/dist_spec_mgr.py rename to colossalai/legacy/tensor/dist_spec_mgr.py index 4740a316b7f5..d97308b04bef 100644 --- a/colossalai/tensor/dist_spec_mgr.py +++ b/colossalai/legacy/tensor/dist_spec_mgr.py @@ -4,12 +4,12 @@ import torch.distributed as dist from numpy import prod -from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec -from colossalai.tensor.process_group import ProcessGroup +from colossalai.legacy.tensor.distspec import DistPlacementPattern, _DistSpec +from colossalai.legacy.tensor.process_group import ProcessGroup # TODO(jiaruifang) circle import, move the divide to colossalai.commons. -# colossalai.tensor shall not import any submodule from colossal.nn +# colossalai.legacy.tensor shall not import any submodule from colossal.nn def divide(numerator, denominator): """Only allow exact division. 
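The `topo.py` and `fx.py` moves above are mechanical (only the import path changes to `colossalai.legacy.pipeline.middleware`), but the topology data structure is easy to misread in diff form. Below is a minimal usage sketch based on the class definitions shown in these hunks, following the same id convention `get_topology` uses (0 = input partition, 1 = output partition, middle stages from 2); the offsets and the single middle stage are made up for illustration.

```python
# Minimal sketch of the middleware topology classes relocated above,
# using the new legacy import path introduced by this patch.
from colossalai.legacy.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo

topo = Topo()

# Partition 0: the input partition, sending its value to partition 2 (offset 0).
input_part = Partition()
input_out = PartitionOutputVal()
input_out.add(partition_id=2, offset=0)
input_part.add_output_val(input_out)
topo.set_partitions(partition_id=0, partition=input_part)
topo.set_input_partition_id(partition_id=0)

# Partition 2: a middle stage reading from the input partition, writing to the output partition.
mid_part = Partition()
mid_part.add_input_val(PartitionInputVal(partition_id=0, offset=0))
mid_out = PartitionOutputVal()
mid_out.add(partition_id=1, offset=0)
mid_part.add_output_val(mid_out)
topo.set_partitions(partition_id=2, partition=mid_part)

# Partition 1: the output partition, consuming the middle stage's result.
output_part = Partition()
output_part.add_input_val(PartitionInputVal(partition_id=2, offset=0))
topo.set_partitions(partition_id=1, partition=output_part)
topo.set_output_partition_id(partition_id=1)

print(topo)  # prints the InputPartition / SubPartition / OutputPartition summaries
```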
diff --git a/colossalai/tensor/distspec.py b/colossalai/legacy/tensor/distspec.py similarity index 100% rename from colossalai/tensor/distspec.py rename to colossalai/legacy/tensor/distspec.py diff --git a/colossalai/tensor/op_wrapper.py b/colossalai/legacy/tensor/op_wrapper.py similarity index 97% rename from colossalai/tensor/op_wrapper.py rename to colossalai/legacy/tensor/op_wrapper.py index 1c00066f7465..63ebaa264279 100644 --- a/colossalai/tensor/op_wrapper.py +++ b/colossalai/legacy/tensor/op_wrapper.py @@ -1,8 +1,5 @@ -from typing import ( - Callable, - Dict, -) import functools +from typing import Callable, Dict # Custom sharded ops _COLOSSAL_OPS: Dict[str, Callable] = {} diff --git a/colossalai/tensor/process_group.py b/colossalai/legacy/tensor/process_group.py similarity index 100% rename from colossalai/tensor/process_group.py rename to colossalai/legacy/tensor/process_group.py diff --git a/colossalai/tensor/tensor_spec.py b/colossalai/legacy/tensor/tensor_spec.py similarity index 79% rename from colossalai/tensor/tensor_spec.py rename to colossalai/legacy/tensor/tensor_spec.py index 580df9f8f310..aa792e507639 100644 --- a/colossalai/tensor/tensor_spec.py +++ b/colossalai/legacy/tensor/tensor_spec.py @@ -1,8 +1,8 @@ from dataclasses import dataclass from typing import Optional -from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec -from colossalai.tensor.process_group import ProcessGroup +from colossalai.legacy.tensor.distspec import DistPlacementPattern, _DistSpec +from colossalai.legacy.tensor.process_group import ProcessGroup from .compute_spec import ComputeSpec diff --git a/colossalai/legacy/trainer/_trainer.py b/colossalai/legacy/trainer/_trainer.py index 1847e56222a1..1cb99fcc90ed 100644 --- a/colossalai/legacy/trainer/_trainer.py +++ b/colossalai/legacy/trainer/_trainer.py @@ -6,8 +6,9 @@ from colossalai.legacy.engine import Engine from colossalai.legacy.trainer.hooks import BaseHook +from colossalai.legacy.utils import is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0 from colossalai.logging import DistributedLogger -from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0 +from colossalai.utils import MultiTimer class Trainer: diff --git a/colossalai/legacy/trainer/hooks/_checkpoint_hook.py b/colossalai/legacy/trainer/hooks/_checkpoint_hook.py index 6b150d29139f..cda10030bf65 100644 --- a/colossalai/legacy/trainer/hooks/_checkpoint_hook.py +++ b/colossalai/legacy/trainer/hooks/_checkpoint_hook.py @@ -4,8 +4,8 @@ from colossalai.legacy.registry import HOOKS from colossalai.legacy.trainer.hooks import BaseHook +from colossalai.legacy.utils.checkpointing import save_checkpoint from colossalai.logging import get_dist_logger -from colossalai.utils.checkpointing import save_checkpoint from ._lr_scheduler_hook import LRSchedulerHook diff --git a/colossalai/legacy/trainer/hooks/_log_hook.py b/colossalai/legacy/trainer/hooks/_log_hook.py index 7d9ad19aa9e9..b1a398ce7f71 100644 --- a/colossalai/legacy/trainer/hooks/_log_hook.py +++ b/colossalai/legacy/trainer/hooks/_log_hook.py @@ -5,12 +5,13 @@ import os.path as osp from typing import List -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.legacy.registry import HOOKS from colossalai.legacy.trainer.hooks._metric_hook import ThroughputMetric +from colossalai.legacy.utils import is_dp_rank_0, 
is_no_pp_or_last_stage, is_tp_rank_0, report_memory_usage from colossalai.logging import DistributedLogger -from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0, report_memory_usage +from colossalai.utils import MultiTimer from ._base_hook import BaseHook from ._commons_ import _format_number @@ -112,8 +113,8 @@ class TensorboardHook(BaseHook): Args: log_dir (str): Directory of log. ranks (list): Ranks of processors. - parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`, optional): Parallel mode used in trainer, - defaults to colossalai.context.parallel_mode.ParallelMode.GLOBAL. + parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`, optional): Parallel mode used in trainer, + defaults to colossalai.legacy.context.parallel_mode.ParallelMode.GLOBAL. priority (int, optional): Priority in the printing, hooks with small priority will be printed in front, defaults to 10. If different hooks share same priority, the order of printing would depend on the hooks order in the hook list. diff --git a/colossalai/legacy/trainer/hooks/_metric_hook.py b/colossalai/legacy/trainer/hooks/_metric_hook.py index f1bd19387cb5..899e4d08a5c9 100644 --- a/colossalai/legacy/trainer/hooks/_metric_hook.py +++ b/colossalai/legacy/trainer/hooks/_metric_hook.py @@ -7,11 +7,12 @@ import torch import torch.distributed as dist -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc from colossalai.legacy.communication import all_reduce +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.legacy.registry import HOOKS -from colossalai.utils import get_current_device, is_no_pp_or_last_stage +from colossalai.legacy.utils import is_no_pp_or_last_stage +from colossalai.utils import get_current_device from ._base_hook import BaseHook from ._commons_ import _format_number diff --git a/colossalai/legacy/utils/__init__.py b/colossalai/legacy/utils/__init__.py new file mode 100644 index 000000000000..ae358f8bebcb --- /dev/null +++ b/colossalai/legacy/utils/__init__.py @@ -0,0 +1,53 @@ +from .checkpointing import load_checkpoint, save_checkpoint +from .common import ( + clip_grad_norm_fp32, + copy_tensor_parallel_attributes, + count_zeros_fp32, + is_dp_rank_0, + is_model_parallel_parameter, + is_no_pp_or_last_stage, + is_tp_rank_0, + is_using_ddp, + is_using_pp, + is_using_sequence, + param_is_not_tensor_parallel_duplicate, + print_rank_0, + switch_virtual_pipeline_parallel_rank, + sync_model_param, +) +from .data_sampler import DataParallelSampler, get_dataloader +from .memory import ( + colo_device_memory_capacity, + colo_device_memory_used, + colo_get_cpu_memory_capacity, + colo_set_cpu_memory_capacity, + colo_set_process_memory_fraction, + report_memory_usage, +) + +__all__ = [ + 'DataParallelSampler', + 'get_dataloader', + 'save_checkpoint', + 'load_checkpoint', + 'colo_device_memory_capacity', + 'colo_device_memory_used', + 'colo_get_cpu_memory_capacity', + 'colo_set_cpu_memory_capacity', + 'colo_set_process_memory_fraction', + 'report_memory_usage', + 'clip_grad_norm_fp32', + 'copy_tensor_parallel_attributes', + 'count_zeros_fp32', + 'is_dp_rank_0', + 'is_model_parallel_parameter', + 'is_no_pp_or_last_stage', + 'is_tp_rank_0', + 'is_using_ddp', + 'is_using_pp', + 'is_using_sequence', + 'param_is_not_tensor_parallel_duplicate', + 'print_rank_0', + 'switch_virtual_pipeline_parallel_rank', + 'sync_model_param', +] diff --git 
a/colossalai/utils/activation_checkpoint.py b/colossalai/legacy/utils/activation_checkpoint.py similarity index 95% rename from colossalai/utils/activation_checkpoint.py rename to colossalai/legacy/utils/activation_checkpoint.py index fa9ed827a8a7..add690f28cc0 100644 --- a/colossalai/utils/activation_checkpoint.py +++ b/colossalai/legacy/utils/activation_checkpoint.py @@ -1,13 +1,13 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +import weakref + import torch from torch.utils.checkpoint import check_backward_validity, detach_variable -from colossalai.context.random import get_states, get_current_mode, set_seed_states, set_mode, sync_states -from .cuda import get_current_device - -import weakref +from colossalai.legacy.context.random import get_current_mode, get_states, set_mode, set_seed_states, sync_states +from colossalai.utils import get_current_device def copy_to_device(obj, device): @@ -143,7 +143,7 @@ def checkpoint(function, activation_offload, *args, use_reentrant: bool = True): Args: function: Describe the forward pass function. It should know how to handle the input tuples. - activation_offload: The variable to check whether we should offload activation to cpu + activation_offload: The variable to check whether we should offload activation to cpu args (list): Tuple containing the parameters of the function use_reentrant: Bool type to check if we need to use_reentrant, if use_reentrant=False, there might be more flexibility for user to define there checkpoint function @@ -227,12 +227,12 @@ def inner_unpack(packed): # rerun forward, the inner_pack will store all the activations in storage if has_autocast_in_fwd: with torch.enable_grad(), \ - torch.cuda.amp.autocast(), \ - torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack): + torch.cuda.amp.autocast(), \ + torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack): _unused = function(*args) else: with torch.enable_grad(), \ - torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack): + torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack): _unused = function(*args) if x not in storage: diff --git a/colossalai/legacy/utils/checkpoint/__init__.py b/colossalai/legacy/utils/checkpoint/__init__.py new file mode 100644 index 000000000000..558a956b31ac --- /dev/null +++ b/colossalai/legacy/utils/checkpoint/__init__.py @@ -0,0 +1,3 @@ +from .module_checkpoint import load_checkpoint, save_checkpoint + +__all__ = ['save_checkpoint', 'load_checkpoint'] diff --git a/colossalai/utils/checkpoint/module_checkpoint.py b/colossalai/legacy/utils/checkpoint/module_checkpoint.py similarity index 90% rename from colossalai/utils/checkpoint/module_checkpoint.py rename to colossalai/legacy/utils/checkpoint/module_checkpoint.py index d390da864cd3..9bd2907abf9d 100644 --- a/colossalai/utils/checkpoint/module_checkpoint.py +++ b/colossalai/legacy/utils/checkpoint/module_checkpoint.py @@ -1,25 +1,28 @@ +from typing import Dict, Optional + import torch import torch.distributed as dist + +from colossalai.interface import OptimizerWrapper from colossalai.tensor import ColoTensor -from colossalai.nn.optimizer import ColossalaiOptimizer -from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor -from typing import Optional, Dict + +from .utils import gather_tensor, scatter_tensor def save_checkpoint(path: str, epoch: int, model: torch.nn.Module, - optimizer: Optional[ColossalaiOptimizer] = None, + optimizer: Optional[OptimizerWrapper] = None, lr_scheduler: torch.optim.lr_scheduler._LRScheduler = 
None, *args, **kwargs): - """save_checkpoint + """save_checkpoint save a model, whose parameters are `ColoTensor`s. Args: path (str): directory to save the checkpoint files. epoch (int): the number of epoch model (torch.nn.Module): a torch module initialized by ColoInitContext - optimizer (ColossalaiOptimizer, optional): optimizers. Defaults to None. + optimizer (OptimizerWrapper, optional): optimizers. Defaults to None. lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None. """ rank = dist.get_rank() @@ -74,17 +77,17 @@ def save_checkpoint(path: str, def load_checkpoint(path: str, epoch: int, model: torch.nn.Module, - optimizer: Optional[ColossalaiOptimizer] = None, + optimizer: Optional[OptimizerWrapper] = None, lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, torch_load_kwargs: Optional[Dict] = None, load_state_dict_kwargs: Optional[Dict] = None): - """load_checkpoint + """load_checkpoint load a model, whose parameters are `ColoTensor`s. Args: path (str): directory to save the checkpoint files. epoch (int): the number of epoch model (torch.nn.Module): a torch module initialized by ColoInitContext - optimizer (ColossalaiOptimizer, optional): optimizers. Defaults to None. + optimizer (OptimizerWrapper, optional): optimizers. Defaults to None. lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None. torch_load_kwargs: (dict, optional): The kwargs of torch.load inside the function load_state_dict_kwargs (dict, optional): The kwargs of load_state_dict inside the function diff --git a/colossalai/utils/checkpoint/utils.py b/colossalai/legacy/utils/checkpoint/utils.py similarity index 91% rename from colossalai/utils/checkpoint/utils.py rename to colossalai/legacy/utils/checkpoint/utils.py index 682cd0903d5b..c830d4811463 100644 --- a/colossalai/utils/checkpoint/utils.py +++ b/colossalai/legacy/utils/checkpoint/utils.py @@ -1,63 +1,65 @@ -import torch -import torch.distributed as dist -from colossalai.tensor import ColoTensor, ColoTensorSpec -from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern - - -def robust_broadcast(tensor): - with torch.no_grad(): - is_cpu_ten = tensor.device.type == 'cpu' - if is_cpu_ten: - b_data = tensor.cuda() - else: - b_data = tensor - - dist.broadcast(b_data, 0) - - if is_cpu_ten: - tensor.copy_(b_data) - - -def gather_tensor(colo_tensor: ColoTensor) -> None: - """Make colo_tensor replicated when the rank is 0 - """ - if not colo_tensor.is_replicate(): - pg = colo_tensor.get_process_group() - # for the group which contains rank 0 - if pg.dp_local_rank() == 0: - old_dist_spec = colo_tensor.dist_spec - colo_tensor.to_replicate_() - if dist.get_rank() != 0: - colo_tensor.set_dist_spec(old_dist_spec) - - # synchronize all processes for unexpected problems - dist.barrier() - - if dist.get_rank() == 0: - setattr(colo_tensor, 'save_ready', True) # set saving signature - - -def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None: - """Reversal operation of `gather_tensor`. 
- """ - if dist_spec.placement == DistPlacementPattern.REPLICATE: - robust_broadcast(colo_tensor.data) - else: - global_size = colo_tensor.size_global() - - if dist.get_rank() == 0: - entire_data = colo_tensor.data - else: - entire_data = torch.empty(global_size, device=colo_tensor.device) - robust_broadcast(entire_data) - - if dist.get_rank() == 0: - colo_tensor.set_dist_spec(dist_spec) - else: - rep_tensor = ColoTensor( - entire_data, ColoTensorSpec(pg=colo_tensor.get_process_group(), compute_attr=colo_tensor.compute_spec)) - rep_tensor.set_dist_spec(dist_spec) - with torch.no_grad(): - colo_tensor.data.copy_(rep_tensor.data) - # synchronize all processes for unexpected problems - dist.barrier() +import torch +import torch.distributed as dist + +from colossalai.legacy.tensor import ColoTensorSpec +from colossalai.legacy.tensor.distspec import DistPlacementPattern, _DistSpec +from colossalai.tensor import ColoTensor + + +def robust_broadcast(tensor): + with torch.no_grad(): + is_cpu_ten = tensor.device.type == 'cpu' + if is_cpu_ten: + b_data = tensor.cuda() + else: + b_data = tensor + + dist.broadcast(b_data, 0) + + if is_cpu_ten: + tensor.copy_(b_data) + + +def gather_tensor(colo_tensor: ColoTensor) -> None: + """Make colo_tensor replicated when the rank is 0 + """ + if not colo_tensor.is_replicate(): + pg = colo_tensor.get_process_group() + # for the group which contains rank 0 + if pg.dp_local_rank() == 0: + old_dist_spec = colo_tensor.dist_spec + colo_tensor.to_replicate_() + if dist.get_rank() != 0: + colo_tensor.set_dist_spec(old_dist_spec) + + # synchronize all processes for unexpected problems + dist.barrier() + + if dist.get_rank() == 0: + setattr(colo_tensor, 'save_ready', True) # set saving signature + + +def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None: + """Reversal operation of `gather_tensor`. 
+ """ + if dist_spec.placement == DistPlacementPattern.REPLICATE: + robust_broadcast(colo_tensor.data) + else: + global_size = colo_tensor.size_global() + + if dist.get_rank() == 0: + entire_data = colo_tensor.data + else: + entire_data = torch.empty(global_size, device=colo_tensor.device) + robust_broadcast(entire_data) + + if dist.get_rank() == 0: + colo_tensor.set_dist_spec(dist_spec) + else: + rep_tensor = ColoTensor( + entire_data, ColoTensorSpec(pg=colo_tensor.get_process_group(), compute_attr=colo_tensor.compute_spec)) + rep_tensor.set_dist_spec(dist_spec) + with torch.no_grad(): + colo_tensor.data.copy_(rep_tensor.data) + # synchronize all processes for unexpected problems + dist.barrier() diff --git a/colossalai/utils/checkpointing.py b/colossalai/legacy/utils/checkpointing.py similarity index 98% rename from colossalai/utils/checkpointing.py rename to colossalai/legacy/utils/checkpointing.py index d1c6b6370ede..b7b29cc984d6 100644 --- a/colossalai/utils/checkpointing.py +++ b/colossalai/legacy/utils/checkpointing.py @@ -3,9 +3,11 @@ import torch import torch.distributed as dist -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.constants import IS_TENSOR_PARALLEL + +from colossalai.legacy.constants import IS_TENSOR_PARALLEL +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc + try: from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX except ImportError: diff --git a/colossalai/legacy/utils/common.py b/colossalai/legacy/utils/common.py new file mode 100644 index 000000000000..35095161c2f2 --- /dev/null +++ b/colossalai/legacy/utils/common.py @@ -0,0 +1,434 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +from collections import defaultdict +from contextlib import contextmanager +from typing import Dict, List, Optional, Union + +import torch +import torch.distributed as dist +from torch import inf +from torch.nn.parameter import Parameter + +from colossalai.legacy.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.global_variables import tensor_parallel_env as env +from colossalai.legacy.tensor import ProcessGroup +from colossalai.tensor import ColoParameter +from colossalai.utils.multi_tensor_apply import multi_tensor_applier + +try: + from colossalai._C import fused_optim +except: + fused_optim = None + + +def print_rank_0(msg: str, logger=None): + """Print messages and save logs(optional). This is executed only if you are the rank-0 gpu. + + Args: + msg (str): A string message to output. + logger (:class:`colossalai.logging.DistributedLogger`, optional): + The logger to record the message, defaults to None. + """ + if gpc.get_global_rank() == 0: + if logger is None: + print(msg, flush=True) + else: + logger.info(msg) + + +def sync_model_param(model, parallel_mode): + r"""Make sure data parameters are consistent during Data Parallel Mode. + + Args: + model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): Parallel mode to be checked. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. 
More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1: + for param in model.parameters(): + ranks = gpc.get_ranks_in_group(parallel_mode) + dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode)) + + +def is_dp_rank_0(): + return not gpc.is_initialized(ParallelMode.DATA) or gpc.is_first_rank(ParallelMode.DATA) + + +def is_tp_rank_0(): + return not gpc.is_initialized(ParallelMode.TENSOR) or gpc.is_first_rank(ParallelMode.TENSOR) + + +def is_no_pp_or_last_stage(): + return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE) + + +def is_using_ddp(): + return gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1 + + +def is_using_pp(): + return gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1 + + +def is_using_sequence(): + return gpc.is_initialized(ParallelMode.SEQUENCE) and gpc.get_world_size(ParallelMode.SEQUENCE) > 1 + + +class model_branch_context(object): + + def __enter__(self): + self.env_status = env.save() + + def __exit__(self, *exc_info): + env.load(**self.env_status) + + +def is_model_parallel_parameter(p): + return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL) + + +def _calc_l2_norm(grads): + # we should not + global fused_optim + + if fused_optim is None: + from colossalai.kernel.op_builder import FusedOptimBuilder + fused_optim = FusedOptimBuilder().load() + + norm = 0.0 + if len(grads) > 0: + dummy_overflow_buf = torch.cuda.IntTensor([0]) + norm, _ = multi_tensor_applier( + fused_optim.multi_tensor_l2norm, + dummy_overflow_buf, + [grads], + False # no per-parameter norm + ) + return norm + + +def _calc_lp(grads, norm_type): + norm = 0.0 + for grad in grads: + grad_norm = torch.norm(grad, norm_type) + norm += grad_norm**norm_type + return norm + + +def _move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]: + if torch.is_tensor(norm) and norm.device.type != 'cuda': + norm = norm.to(torch.cuda.current_device()) + return norm + + +def _get_tensor_norm(norm: Union[float, torch.Tensor], move_to_cuda) -> torch.Tensor: + if isinstance(norm, float): + norm = torch.Tensor([norm]) + if move_to_cuda: + norm = norm.to(torch.cuda.current_device()) + return norm + + +# ======== Gradient Clipping ========= + + +def _compute_local_lp(params: List[ColoParameter], norm_type: float) -> float: + if len(params) == 0: + return 0.0 + grads = [p.grad for p in params] + use_cuda_kernel = grads[0].device.type == 'cuda' + if norm_type == inf: + local_lp = max([g.abs().max() for g in grads]) + elif norm_type == 2.0 and use_cuda_kernel: + local_lp = _calc_l2_norm(grads)**norm_type + else: + local_lp = _calc_lp(grads, norm_type) + if isinstance(local_lp, torch.Tensor): + return local_lp.item() + return local_lp + + +def _compute_buckets_lp(params: List[ColoParameter], norm_type: float) -> float: + if len(params) == 0: + return 0.0 + buckets: Dict[Optional[ProcessGroup], List[ColoParameter]] = defaultdict(list) + for p in params: + if p.is_replicate(): + buckets[None].append(p) + else: + buckets[p.get_process_group().tp_process_group()].append(p) + total_lp = 0.0 + for group, bucket in buckets.items(): + local_lp = _compute_local_lp(bucket, norm_type) + if group is not None: + local_lp_tensor = torch.tensor([local_lp], device=torch.cuda.current_device()) + if norm_type == inf: + dist.all_reduce(local_lp_tensor, 
op=dist.ReduceOp.MAX, group=group) + else: + dist.all_reduce(local_lp_tensor, group=group) + local_lp = local_lp_tensor.item() + if norm_type == inf: + total_lp = max(total_lp, local_lp) + else: + total_lp += local_lp + return total_lp + + +def _compute_pp_grad_lp(total_lp: float, norm_type: float) -> float: + if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1: + total_lp_tensor = torch.tensor([total_lp], device=torch.cuda.current_device()) + if norm_type == inf: + dist.all_reduce(total_lp_tensor, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.PIPELINE)) + else: + dist.all_reduce(total_lp_tensor, group=gpc.get_group(ParallelMode.PIPELINE)) + total_lp = total_lp_tensor.item() + return total_lp + + +def _compute_grad_lp(parameters, norm_type: float = 2.0) -> float: + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + grad_dtype = None + cpu_grad_params: List[ColoParameter] = [] + cuda_grad_params: List[ColoParameter] = [] + for p in parameters: + if p.grad is None: + continue + assert isinstance(p, ColoParameter) + if grad_dtype is None: + grad_dtype = p.grad.dtype + assert p.grad.dtype == grad_dtype, f'Expected all grads are {grad_dtype}, got {p.grad.dtype}' + if p.grad.device.type == 'cuda': + cuda_grad_params.append(p) + else: + cpu_grad_params.append(p) + norm_type = float(norm_type) + cpu_lp = _compute_buckets_lp(cpu_grad_params, norm_type) + cuda_lp = _compute_buckets_lp(cuda_grad_params, norm_type) + if norm_type == inf: + total_lp = max(cpu_lp, cuda_lp) + else: + total_lp = cpu_lp + cuda_lp + return _compute_pp_grad_lp(total_lp, norm_type) + + +def compute_grad_norm(parameters, norm_type: float = 2.0) -> float: + norm_type = float(norm_type) + total_norm = _compute_grad_lp(parameters, norm_type) + if norm_type != inf: + total_norm = total_norm**(1 / norm_type) + return total_norm + + +def _clip_grad_norm(parameters, max_norm: float, total_norm: float) -> None: + clip_coef = max_norm / (total_norm + 1e-6) + if clip_coef < 1.0: + cuda_grads: List[torch.Tensor] = [] + cpu_grads: List[torch.Tensor] = [] + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + for p in parameters: + if p.grad is None: + continue + if p.grad.device.type == 'cuda': + cuda_grads.append(p.grad.detach()) + else: + cpu_grads.append(p.grad.detach()) + if len(cuda_grads) > 0: + dummy_overflow_buf = torch.cuda.IntTensor([0]) + multi_tensor_applier(fused_optim.multi_tensor_scale, dummy_overflow_buf, [cuda_grads, cuda_grads], + clip_coef) + for g in cpu_grads: + g.mul_(clip_coef) + + +def clip_grad_norm(parameters, max_norm: float, norm_type: float = 2.0) -> float: + total_norm = compute_grad_norm(parameters, norm_type) + _clip_grad_norm(parameters, max_norm, total_norm) + return total_norm + + +def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): + """Clips gradient norm of an iterable of parameters whose gradients are in fp32. + + This is adapted from :func:`torch.nn.utils.clip_grad.clip_grad_norm_` and + added functionality to handle model parallel parameters. + + Note: + the gradients are modified in place. + + Args: + parameters (Iterable[:class:`torch.tensor`] or :class:`torch.tensor`): + An iterable of Tensors or a single Tensor that will have gradients normalized. + max_norm (Union[float, int]): Max norm of the gradients. + norm_type (Union[float, int, 'inf']): Type of the used p-norm. Can be ``'inf'`` for infinity norm. + + Returns: + float: Total norm of the parameters. 
+ """ + + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + + # Filter parameters based on: + # - grad should not be none + # - parameter should not be shared + # - should not be a replica due to tensor model parallelism + params: List[Parameter] = [] + has_zero_shared_param: bool = False + for param in parameters: + if param.grad is not None: + # Make sure the grads are in fp32 + assert param.grad.dtype == torch.float, \ + f'expected gradient to be dtype torch.float, but got {param.grad.type()}' + if hasattr(param, 'colo_attr') and param.colo_attr.sharded_data_tensor.is_sharded: + has_zero_shared_param = True + params.append(param) + + if len(params) == 0: + enable_cuda_kernels = False + else: + enable_cuda_kernels = params[0].grad.device.type == 'cuda' + # Norm parameters. + max_norm = float(max_norm) + norm_type = float(norm_type) + + # Parameters can be on CPU or CUDA + # If parameters are on CPU, disable CUDA kernels + + # Calculate norm. + if norm_type == inf: + total_norm = max(p.grad.data.abs().max() for p in params) + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + # Take max across all model-parallel GPUs. + if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1: + dist.all_reduce(total_norm_cuda, + op=dist.ReduceOp.MAX, + group=gpc.get_group(ParallelMode.MODEL), + async_op=False) + if has_zero_shared_param: + dist.all_reduce(total_norm_cuda, + op=dist.ReduceOp.MAX, + group=gpc.get_group(ParallelMode.DATA), + async_op=False) + total_norm = total_norm_cuda[0].item() + else: + tensor_parallel_grads = [] + no_tensor_parallel_grads = [] + zero_sharded_grads = [] + for p in params: + if is_model_parallel_parameter(p): + reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS))**(1 / norm_type) + tensor_parallel_grads.append(p.grad.data / reductor) + elif hasattr(p, 'colo_attr') and p.colo_attr.sharded_data_tensor.is_sharded: + zero_sharded_grads.append(p.grad.data) + else: + no_tensor_parallel_grads.append(p.grad.data) + + if norm_type == 2.0 and enable_cuda_kernels: + tensor_parallel_norm = _calc_l2_norm(tensor_parallel_grads)**norm_type + no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads)**norm_type + zero_sharded_norm = _calc_l2_norm(zero_sharded_grads)**norm_type + else: + tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type) + no_tensor_parallel_norm = _calc_lp(no_tensor_parallel_grads, norm_type) + zero_sharded_norm = _calc_lp(zero_sharded_grads, norm_type) + # If norm is type of float, then we convert them into torch.Tensor. + tensor_parallel_norm = _get_tensor_norm(tensor_parallel_norm, enable_cuda_kernels) + no_tensor_parallel_norm = _get_tensor_norm(no_tensor_parallel_norm, enable_cuda_kernels) + zero_sharded_norm = _get_tensor_norm(zero_sharded_norm, enable_cuda_kernels) + # If grads are on CPU, the norms is also on CPU. Cast them to CUDA tensors + if not enable_cuda_kernels: + tensor_parallel_norm = _move_norm_to_cuda(tensor_parallel_norm) + no_tensor_parallel_norm = _move_norm_to_cuda(no_tensor_parallel_norm) + zero_sharded_norm = _move_norm_to_cuda(zero_sharded_norm) + + # Sum across all model-parallel GPUs. 
+ if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0: + dist.all_reduce(tensor_parallel_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR)) + # Sum across all zero sharded GPUs + if len(zero_sharded_grads) > 0: + dist.all_reduce(zero_sharded_norm, group=gpc.get_group(ParallelMode.DATA)) + no_tensor_parallel_norm += zero_sharded_norm + total_norm = tensor_parallel_norm + no_tensor_parallel_norm + if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1: + dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE)) + total_norm = total_norm**(1.0 / norm_type) + if torch.is_tensor(total_norm): + total_norm = total_norm.item() + + # Scale. + clip_coeff = max_norm / (total_norm + 1.0e-6) + if clip_coeff < 1.0: + if enable_cuda_kernels: + grads = [p.grad.detach() for p in params] + dummy_overflow_buf = torch.cuda.IntTensor([0]) + multi_tensor_applier(fused_optim.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff) + else: + for p in params: + p.grad.detach().mul_(clip_coeff) + return total_norm + + +def count_zeros_fp32(parameters): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + + # Filter parameters based on: + # - grad should not be none + # - parameter should not be shared + # - should not be a replica due to tensor model parallelism + total_num_zeros = 0.0 + for param in parameters: + grad_not_none = param.grad is not None + is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) + if grad_not_none and is_not_tp_duplicate: + grad = param.grad.detach() + num_zeros = grad.numel() - torch.count_nonzero(grad) + total_num_zeros = num_zeros + total_num_zeros + + total_num_zeros = torch.IntTensor([int(total_num_zeros)]).cuda() + + # Sum across all model-parallel GPUs. 
+ ops = [] + ops.append( + dist.all_reduce(total_num_zeros, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR), async_op=True)) + if gpc.is_initialized(ParallelMode.PIPELINE): + ops.append( + dist.all_reduce(total_num_zeros, + op=dist.ReduceOp.SUM, + group=gpc.get_group(ParallelMode.PIPELINE), + async_op=True)) + + for req in ops: + req.wait() + total_num_zeros = total_num_zeros.item() + + return total_num_zeros + + +def copy_tensor_parallel_attributes(src_tensor, dst_tensor): + for attr in TENSOR_PARALLEL_ATTRIBUTES: + if hasattr(src_tensor, attr): + val = getattr(src_tensor, attr) + setattr(dst_tensor, attr, val) + + +def param_is_not_tensor_parallel_duplicate(param): + return (hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL)) or (gpc.get_local_rank( + ParallelMode.TENSOR) == 0) + + +@contextmanager +def switch_virtual_pipeline_parallel_rank(rank): + prev_rank = gpc.virtual_pipeline_parallel_rank + try: + gpc.set_virtual_pipeline_parallel_rank(rank) + yield + finally: + gpc.set_virtual_pipeline_parallel_rank(prev_rank) diff --git a/colossalai/utils/data_sampler/__init__.py b/colossalai/legacy/utils/data_sampler/__init__.py similarity index 100% rename from colossalai/utils/data_sampler/__init__.py rename to colossalai/legacy/utils/data_sampler/__init__.py diff --git a/colossalai/utils/data_sampler/base_sampler.py b/colossalai/legacy/utils/data_sampler/base_sampler.py similarity index 100% rename from colossalai/utils/data_sampler/base_sampler.py rename to colossalai/legacy/utils/data_sampler/base_sampler.py diff --git a/colossalai/utils/data_sampler/data_parallel_sampler.py b/colossalai/legacy/utils/data_sampler/data_parallel_sampler.py similarity index 98% rename from colossalai/utils/data_sampler/data_parallel_sampler.py rename to colossalai/legacy/utils/data_sampler/data_parallel_sampler.py index 881ddde78648..66a5fdd3694d 100644 --- a/colossalai/utils/data_sampler/data_parallel_sampler.py +++ b/colossalai/legacy/utils/data_sampler/data_parallel_sampler.py @@ -10,8 +10,8 @@ import torch from torch.utils.data import DataLoader, Dataset, Sampler -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc T_co = TypeVar('T_co', covariant=True) diff --git a/colossalai/utils/memory.py b/colossalai/legacy/utils/memory.py similarity index 95% rename from colossalai/utils/memory.py rename to colossalai/legacy/utils/memory.py index 434e90edd3b9..360bf0da4a77 100644 --- a/colossalai/utils/memory.py +++ b/colossalai/legacy/utils/memory.py @@ -1,15 +1,15 @@ -import torch import gc -import psutil from collections import namedtuple -from colossalai.context.parallel_mode import ParallelMode -from colossalai.utils import get_current_device -from colossalai.core import global_context as gpc -from colossalai.context.parallel_mode import ParallelMode -from colossalai.logging import get_dist_logger +import psutil +import torch +import torch.distributed as dist from packaging import version +from colossalai.legacy.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.utils import get_current_device + _GLOBAL_CUDA_MEM_FRACTION = 1.0 _GLOBAL_CPU_MEM_CAPACITY = -1 @@ -68,7 +68,7 @@ def report_memory_usage(message, logger=None, report_cpu=False): Raises: EnvironmentError: Raise error if no distributed environment has been initialized. 
""" - if not gpc.is_initialized(ParallelMode.GLOBAL): + if not dist.is_initialized(): raise EnvironmentError("No distributed environment is initialized") gpu_allocated = _bytes_to_MB(torch.cuda.memory_allocated()) @@ -138,7 +138,7 @@ def colo_device_memory_used(device: torch.device) -> int: def colo_set_process_memory_fraction(ratio: float) -> None: - """colo_set_process_memory_fraction + """colo_set_process_memory_fraction set how much cuda memory used on the gpu belonging to the current process. diff --git a/colossalai/utils/profiler/__init__.py b/colossalai/legacy/utils/profiler/__init__.py similarity index 100% rename from colossalai/utils/profiler/__init__.py rename to colossalai/legacy/utils/profiler/__init__.py diff --git a/colossalai/utils/profiler/extention.py b/colossalai/legacy/utils/profiler/extention.py similarity index 100% rename from colossalai/utils/profiler/extention.py rename to colossalai/legacy/utils/profiler/extention.py diff --git a/colossalai/utils/profiler/legacy/__init__.py b/colossalai/legacy/utils/profiler/legacy/__init__.py similarity index 77% rename from colossalai/utils/profiler/legacy/__init__.py rename to colossalai/legacy/utils/profiler/legacy/__init__.py index 849c7fca3053..88beed86d7de 100644 --- a/colossalai/utils/profiler/legacy/__init__.py +++ b/colossalai/legacy/utils/profiler/legacy/__init__.py @@ -1,6 +1,6 @@ -from .comm_profiler import CommProfiler -from .pcie_profiler import PcieProfiler -from .prof_utils import ProfilerContext, BaseProfiler -from .mem_profiler import MemProfiler - -__all__ = ['BaseProfiler', 'CommProfiler', 'PcieProfiler', 'MemProfiler', 'ProfilerContext'] +from .comm_profiler import CommProfiler +from .mem_profiler import MemProfiler +from .pcie_profiler import PcieProfiler +from .prof_utils import BaseProfiler, ProfilerContext + +__all__ = ['BaseProfiler', 'CommProfiler', 'PcieProfiler', 'MemProfiler', 'ProfilerContext'] diff --git a/colossalai/utils/profiler/legacy/comm_profiler.py b/colossalai/legacy/utils/profiler/legacy/comm_profiler.py similarity index 96% rename from colossalai/utils/profiler/legacy/comm_profiler.py rename to colossalai/legacy/utils/profiler/legacy/comm_profiler.py index 334f0113ee90..bb7e2654c740 100644 --- a/colossalai/utils/profiler/legacy/comm_profiler.py +++ b/colossalai/legacy/utils/profiler/legacy/comm_profiler.py @@ -1,308 +1,311 @@ -import inspect -from pathlib import Path -from functools import partial -import torch -from torch.autograd.profiler import profile -import torch.distributed as dist -from torch.distributed import ReduceOp -from colossalai.utils import get_current_device -from .prof_utils import BaseProfiler, _format_time, _format_memory, _format_bandwidth -from typing import List, Optional - - -def _get_code_location(depth: int): - ret = [] - length = min(len(inspect.stack()), depth + 1) - for i in range(3, length): - upper_frame = inspect.stack()[i] - function_name = inspect.stack()[i - 1].function - ret.append(upper_frame.filename) - ret.append('(') - ret.append(str(upper_frame.lineno)) - ret.append('): ') - ret.append(function_name) - if i != length - 1: - ret.append('\n') - - return ''.join(ret) - - -torch_all_reduce = dist.all_reduce -torch_all_gather = dist.all_gather -torch_reduce_scatter = dist.reduce_scatter -torch_broadcast = dist.broadcast -torch_reduce = dist.reduce - - -class CommEvent(object): - """Communication Event. Used for communication time and communication - volume recording. 
- """ - - def __init__(self, count: int = 0, comm_vol: float = 0., cuda_time: int = 0): - self.self_count = count - self.self_comm_vol = comm_vol - self.self_cuda_time = cuda_time - - def add(self, rhs): - self.self_count += rhs.self_count - self.self_comm_vol += rhs.self_comm_vol - self.self_cuda_time += rhs.self_cuda_time - - -class CommProfiler(BaseProfiler): - """Communication profiler. Records all communication events. - """ - - def __init__(self, depth: int = 0, total_count: int = 0, total_comm_vol: float = 0, total_cuda_time: int = 0): - super().__init__(profiler_name="Collective_Communication", priority=0) - self.depth = 3 + depth - self.total_count = total_count - self.total_comm_vol = total_comm_vol - self.total_cuda_time = total_cuda_time - - self.ops_record = dict() - self.profiler = None - self.pending_op = None - self.pending_metadata = None - self.warn_flag = False - - def reset(self): - self.total_count = 0 - self.total_comm_vol = 0 - self.total_cuda_time = 0 - - self.ops_record = dict() - self.profiler = None - self.pending_op = None - self.pending_metadata = None - self.warn_flag = False - - def enable(self): - dist.all_reduce = partial(all_reduce, profiler=self) - dist.all_gather = partial(all_gather, profiler=self) - dist.reduce_scatter = partial(reduce_scatter, profiler=self) - dist.broadcast = partial(broadcast, profiler=self) - dist.reduce = partial(reduce, profiler=self) - - def disable(self): - dist.all_reduce = torch_all_reduce - dist.all_gather = torch_all_gather - dist.reduce_scatter = torch_reduce_scatter - dist.broadcast = torch_broadcast - dist.reduce = torch_reduce - - def to_tensorboard(self, writer): - writer.add_text(tag="Collective Communication", text_string=self.result_str("\n\n")) - - def to_file(self, filename: Path): - with open(filename, "w") as f: - f.write(self.result_str()) - - def show(self): - print(self.result_str()) - - def result_str(self, sep: str = "\n"): - res = [] - - def append(s: str = None): - if s is not None: - res.append(s) - res.append(sep) - - if self.warn_flag: - append("Warning: there exists multiple communication operations in the same time. As a result, " - "the profiling result is not accurate.") - - if self.total_cuda_time == 0: - return "No collective communication has been called yet!" 
- - append("Collective communication profiling result:") - append("total cuda time: {}".format(_format_time(self.total_cuda_time))) - append("average bandwidth: {}".format(_format_bandwidth(self.total_comm_vol, self.total_cuda_time))) - append("total number of calls: {}".format(self.total_count)) - append("All events:") - - separation = '-' * 74 - row_format = '{:^10}' + '{:^12}' * 2 + '{:^16}' + '{:^12}' * 2 - - append(separation) - append(row_format.format('Location', 'GPU time', 'Percentage', 'Comm volume', 'Bandwidth', 'Num of calls')) - append(separation) - - show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time) - for location, event in show_list: - append(location) - append( - row_format.format('', _format_time(event.self_cuda_time), - '{:.1f}%'.format(event.self_cuda_time / self.total_cuda_time * 100.0), - _format_memory(event.self_comm_vol), - _format_bandwidth(event.self_comm_vol, event.self_cuda_time), event.self_count)) - append() - - return ''.join(res) - - @property - def has_aync_op(self): - return self.pending_op is not None - - def activate_profiler(self, kn: str, vol: float): - self.pending_metadata = (kn, _get_code_location(self.depth), vol) - self.profiler = profile(enabled=True, use_cuda=True, use_cpu=True, use_kineto=True) - self.profiler.__enter__() - - def close_profiler(self, group=None): - assert self.profiler is not None, "There is no running dist op" - kernel_name, code_location, vol = self.pending_metadata - self.profiler.__exit__(None, None, None) - - if self.profiler.enabled and dist.get_world_size(group) > 1: - assert_flag = 0 - current_comm_event = None - events = self.profiler.function_events - for event in events: - if kernel_name in event.name: - assert assert_flag == 0, "Multiple dist ops has been called " - current_comm_event = CommEvent(1, vol, event.self_cuda_time_total) - assert_flag += 1 - - assert current_comm_event is not None, "dist op has not been found" - - buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_current_device()) - torch_all_reduce(buffer, op=ReduceOp.MIN, group=group) - current_comm_event.self_cuda_time = buffer.item() - - self.total_count += current_comm_event.self_count - self.total_comm_vol += current_comm_event.self_comm_vol - self.total_cuda_time += current_comm_event.self_cuda_time - if code_location in self.ops_record: - self.ops_record[code_location].add(current_comm_event) - else: - self.ops_record[code_location] = current_comm_event - - self.profiler = None - self.pending_op = None - self.pending_metadata = None - - def wait_async_op(self): - if self.pending_op is not None: - op = self.pending_op - op.wait() - self.close_profiler() - - -class CommHandler(object): - """Communication handler. A dummy handler to wait aync operations. 
- """ - - def __init__(self, profiler: CommProfiler): - super().__init__() - self.prof = profiler - - def wait(self): - self.prof.wait_async_op() - - -def async_check(profiler: CommProfiler): - if profiler.pending_op is not None: - profiler.warn_flag = True - profiler.wait_async_op() - - -def all_reduce(tensor: torch.Tensor, - op: ReduceOp = ReduceOp.SUM, - group=None, - async_op: bool = False, - profiler: CommProfiler = None) -> Optional[CommHandler]: - async_check(profiler) - - comm_size = dist.get_world_size(group) - correction = 2 * (comm_size - 1) / comm_size - comm_vol = correction * tensor.element_size() * tensor.numel() - profiler.activate_profiler("ncclKernel_AllReduce_", comm_vol) - profiler.pending_op = torch_all_reduce(tensor, op, group, async_op) - - if async_op: - return CommHandler(profiler) - - profiler.close_profiler(group) - - -def reduce_scatter(output: torch.Tensor, - input_list: List[torch.Tensor], - op: ReduceOp = ReduceOp.SUM, - group=None, - async_op: bool = False, - profiler: CommProfiler = None) -> Optional[CommHandler]: - async_check(profiler) - - comm_size = dist.get_world_size(group) - correction = (comm_size - 1) / comm_size - comm_vol = 0 - for tensor in input_list: - comm_vol += tensor.element_size() * tensor.numel() - comm_vol *= correction - profiler.activate_profiler("ncclKernel_ReduceScatter_", comm_vol) - profiler.pending_op = torch_reduce_scatter(output, input_list, op, group, async_op) - - if async_op: - return CommHandler(profiler) - - profiler.close_profiler(group) - - -def all_gather(tensor_list: List[torch.Tensor], - tensor: torch.Tensor, - group=None, - async_op: bool = False, - profiler: CommProfiler = None) -> Optional[CommHandler]: - async_check(profiler) - - comm_size = dist.get_world_size(group) - correction = (comm_size - 1) / comm_size - comm_vol = 0 - for ten in tensor_list: - comm_vol += ten.element_size() * ten.numel() - comm_vol *= correction - profiler.activate_profiler("ncclKernel_AllGather_", comm_vol) - profiler.pending_op = torch_all_gather(tensor_list, tensor, group, async_op) - - if async_op: - return CommHandler(profiler) - - profiler.close_profiler(group) - - -def broadcast(tensor: torch.Tensor, - src: int, - group=None, - async_op: bool = False, - profiler: CommProfiler = None) -> Optional[CommHandler]: - async_check(profiler) - - comm_vol = 1.0 * tensor.element_size() * tensor.numel() - profiler.activate_profiler("ncclKernel_Broadcast_", comm_vol) - profiler.pending_op = torch_broadcast(tensor, src, group, async_op) - - if async_op: - return CommHandler(profiler) - - profiler.close_profiler(group) - - -def reduce(tensor: torch.Tensor, - dst: int, - op: ReduceOp = ReduceOp.SUM, - group=None, - async_op: bool = False, - profiler: CommProfiler = None) -> Optional[CommHandler]: - async_check(profiler) - - comm_vol = 1.0 * tensor.element_size() * tensor.numel() - profiler.activate_profiler("ncclKernel_Reduce_", comm_vol) - profiler.pending_op = torch_reduce(tensor, dst, op, group, async_op) - - if async_op: - return CommHandler(profiler) - - profiler.close_profiler(group) +import inspect +from functools import partial +from pathlib import Path +from typing import List, Optional + +import torch +import torch.distributed as dist +from torch.autograd.profiler import profile +from torch.distributed import ReduceOp + +from colossalai.utils import get_current_device + +from .prof_utils import BaseProfiler, _format_bandwidth, _format_memory, _format_time + + +def _get_code_location(depth: int): + ret = [] + length = 
min(len(inspect.stack()), depth + 1) + for i in range(3, length): + upper_frame = inspect.stack()[i] + function_name = inspect.stack()[i - 1].function + ret.append(upper_frame.filename) + ret.append('(') + ret.append(str(upper_frame.lineno)) + ret.append('): ') + ret.append(function_name) + if i != length - 1: + ret.append('\n') + + return ''.join(ret) + + +torch_all_reduce = dist.all_reduce +torch_all_gather = dist.all_gather +torch_reduce_scatter = dist.reduce_scatter +torch_broadcast = dist.broadcast +torch_reduce = dist.reduce + + +class CommEvent(object): + """Communication Event. Used for communication time and communication + volume recording. + """ + + def __init__(self, count: int = 0, comm_vol: float = 0., cuda_time: int = 0): + self.self_count = count + self.self_comm_vol = comm_vol + self.self_cuda_time = cuda_time + + def add(self, rhs): + self.self_count += rhs.self_count + self.self_comm_vol += rhs.self_comm_vol + self.self_cuda_time += rhs.self_cuda_time + + +class CommProfiler(BaseProfiler): + """Communication profiler. Records all communication events. + """ + + def __init__(self, depth: int = 0, total_count: int = 0, total_comm_vol: float = 0, total_cuda_time: int = 0): + super().__init__(profiler_name="Collective_Communication", priority=0) + self.depth = 3 + depth + self.total_count = total_count + self.total_comm_vol = total_comm_vol + self.total_cuda_time = total_cuda_time + + self.ops_record = dict() + self.profiler = None + self.pending_op = None + self.pending_metadata = None + self.warn_flag = False + + def reset(self): + self.total_count = 0 + self.total_comm_vol = 0 + self.total_cuda_time = 0 + + self.ops_record = dict() + self.profiler = None + self.pending_op = None + self.pending_metadata = None + self.warn_flag = False + + def enable(self): + dist.all_reduce = partial(all_reduce, profiler=self) + dist.all_gather = partial(all_gather, profiler=self) + dist.reduce_scatter = partial(reduce_scatter, profiler=self) + dist.broadcast = partial(broadcast, profiler=self) + dist.reduce = partial(reduce, profiler=self) + + def disable(self): + dist.all_reduce = torch_all_reduce + dist.all_gather = torch_all_gather + dist.reduce_scatter = torch_reduce_scatter + dist.broadcast = torch_broadcast + dist.reduce = torch_reduce + + def to_tensorboard(self, writer): + writer.add_text(tag="Collective Communication", text_string=self.result_str("\n\n")) + + def to_file(self, filename: Path): + with open(filename, "w") as f: + f.write(self.result_str()) + + def show(self): + print(self.result_str()) + + def result_str(self, sep: str = "\n"): + res = [] + + def append(s: str = None): + if s is not None: + res.append(s) + res.append(sep) + + if self.warn_flag: + append("Warning: there exists multiple communication operations in the same time. As a result, " + "the profiling result is not accurate.") + + if self.total_cuda_time == 0: + return "No collective communication has been called yet!" 
+ + append("Collective communication profiling result:") + append("total cuda time: {}".format(_format_time(self.total_cuda_time))) + append("average bandwidth: {}".format(_format_bandwidth(self.total_comm_vol, self.total_cuda_time))) + append("total number of calls: {}".format(self.total_count)) + append("All events:") + + separation = '-' * 74 + row_format = '{:^10}' + '{:^12}' * 2 + '{:^16}' + '{:^12}' * 2 + + append(separation) + append(row_format.format('Location', 'GPU time', 'Percentage', 'Comm volume', 'Bandwidth', 'Num of calls')) + append(separation) + + show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time) + for location, event in show_list: + append(location) + append( + row_format.format('', _format_time(event.self_cuda_time), + '{:.1f}%'.format(event.self_cuda_time / self.total_cuda_time * 100.0), + _format_memory(event.self_comm_vol), + _format_bandwidth(event.self_comm_vol, event.self_cuda_time), event.self_count)) + append() + + return ''.join(res) + + @property + def has_aync_op(self): + return self.pending_op is not None + + def activate_profiler(self, kn: str, vol: float): + self.pending_metadata = (kn, _get_code_location(self.depth), vol) + self.profiler = profile(enabled=True, use_cuda=True, use_cpu=True, use_kineto=True) + self.profiler.__enter__() + + def close_profiler(self, group=None): + assert self.profiler is not None, "There is no running dist op" + kernel_name, code_location, vol = self.pending_metadata + self.profiler.__exit__(None, None, None) + + if self.profiler.enabled and dist.get_world_size(group) > 1: + assert_flag = 0 + current_comm_event = None + events = self.profiler.function_events + for event in events: + if kernel_name in event.name: + assert assert_flag == 0, "Multiple dist ops has been called " + current_comm_event = CommEvent(1, vol, event.self_cuda_time_total) + assert_flag += 1 + + assert current_comm_event is not None, "dist op has not been found" + + buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_current_device()) + torch_all_reduce(buffer, op=ReduceOp.MIN, group=group) + current_comm_event.self_cuda_time = buffer.item() + + self.total_count += current_comm_event.self_count + self.total_comm_vol += current_comm_event.self_comm_vol + self.total_cuda_time += current_comm_event.self_cuda_time + if code_location in self.ops_record: + self.ops_record[code_location].add(current_comm_event) + else: + self.ops_record[code_location] = current_comm_event + + self.profiler = None + self.pending_op = None + self.pending_metadata = None + + def wait_async_op(self): + if self.pending_op is not None: + op = self.pending_op + op.wait() + self.close_profiler() + + +class CommHandler(object): + """Communication handler. A dummy handler to wait aync operations. 
+ """ + + def __init__(self, profiler: CommProfiler): + super().__init__() + self.prof = profiler + + def wait(self): + self.prof.wait_async_op() + + +def async_check(profiler: CommProfiler): + if profiler.pending_op is not None: + profiler.warn_flag = True + profiler.wait_async_op() + + +def all_reduce(tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, + group=None, + async_op: bool = False, + profiler: CommProfiler = None) -> Optional[CommHandler]: + async_check(profiler) + + comm_size = dist.get_world_size(group) + correction = 2 * (comm_size - 1) / comm_size + comm_vol = correction * tensor.element_size() * tensor.numel() + profiler.activate_profiler("ncclKernel_AllReduce_", comm_vol) + profiler.pending_op = torch_all_reduce(tensor, op, group, async_op) + + if async_op: + return CommHandler(profiler) + + profiler.close_profiler(group) + + +def reduce_scatter(output: torch.Tensor, + input_list: List[torch.Tensor], + op: ReduceOp = ReduceOp.SUM, + group=None, + async_op: bool = False, + profiler: CommProfiler = None) -> Optional[CommHandler]: + async_check(profiler) + + comm_size = dist.get_world_size(group) + correction = (comm_size - 1) / comm_size + comm_vol = 0 + for tensor in input_list: + comm_vol += tensor.element_size() * tensor.numel() + comm_vol *= correction + profiler.activate_profiler("ncclKernel_ReduceScatter_", comm_vol) + profiler.pending_op = torch_reduce_scatter(output, input_list, op, group, async_op) + + if async_op: + return CommHandler(profiler) + + profiler.close_profiler(group) + + +def all_gather(tensor_list: List[torch.Tensor], + tensor: torch.Tensor, + group=None, + async_op: bool = False, + profiler: CommProfiler = None) -> Optional[CommHandler]: + async_check(profiler) + + comm_size = dist.get_world_size(group) + correction = (comm_size - 1) / comm_size + comm_vol = 0 + for ten in tensor_list: + comm_vol += ten.element_size() * ten.numel() + comm_vol *= correction + profiler.activate_profiler("ncclKernel_AllGather_", comm_vol) + profiler.pending_op = torch_all_gather(tensor_list, tensor, group, async_op) + + if async_op: + return CommHandler(profiler) + + profiler.close_profiler(group) + + +def broadcast(tensor: torch.Tensor, + src: int, + group=None, + async_op: bool = False, + profiler: CommProfiler = None) -> Optional[CommHandler]: + async_check(profiler) + + comm_vol = 1.0 * tensor.element_size() * tensor.numel() + profiler.activate_profiler("ncclKernel_Broadcast_", comm_vol) + profiler.pending_op = torch_broadcast(tensor, src, group, async_op) + + if async_op: + return CommHandler(profiler) + + profiler.close_profiler(group) + + +def reduce(tensor: torch.Tensor, + dst: int, + op: ReduceOp = ReduceOp.SUM, + group=None, + async_op: bool = False, + profiler: CommProfiler = None) -> Optional[CommHandler]: + async_check(profiler) + + comm_vol = 1.0 * tensor.element_size() * tensor.numel() + profiler.activate_profiler("ncclKernel_Reduce_", comm_vol) + profiler.pending_op = torch_reduce(tensor, dst, op, group, async_op) + + if async_op: + return CommHandler(profiler) + + profiler.close_profiler(group) diff --git a/colossalai/utils/profiler/legacy/pcie_profiler.py b/colossalai/legacy/utils/profiler/legacy/pcie_profiler.py similarity index 95% rename from colossalai/utils/profiler/legacy/pcie_profiler.py rename to colossalai/legacy/utils/profiler/legacy/pcie_profiler.py index 8f812f5cfc7b..514d3c6fabfa 100644 --- a/colossalai/utils/profiler/legacy/pcie_profiler.py +++ b/colossalai/legacy/utils/profiler/legacy/pcie_profiler.py @@ -1,148 +1,150 @@ -from pathlib 
import Path -from torch.autograd.profiler import profile -from .prof_utils import BaseProfiler, _format_time, _format_memory, _format_bandwidth -from typing import List - - -def _get_size(dtype: str): - if dtype == "fp16": - return 2 - elif dtype == "fp32": - return 4 - else: - raise NotImplementedError - - -def _get_numel(my_list: List[int]) -> int: - from functools import reduce - from operator import mul - return reduce(mul, my_list) - - -def _reduce_location(locations: List[str]) -> str: - ret = [] - for lo in locations: - ret.append(lo) - ret.append("\n") - ret = ret[:-1] - return ''.join(ret) - - -class PcieEvent(object): - """Pcie Event. - """ - - def __init__(self, count: int = 0, pcie_vol: int = 0, cuda_time: int = 0): - self.count = count - self.pcie_vol = pcie_vol - self.cuda_time = cuda_time - - def add(self, rhs): - self.count += rhs.count - self.pcie_vol += rhs.pcie_vol - self.cuda_time += rhs.cuda_time - - -class PcieProfiler(BaseProfiler): - """Pcie profiler. Records all data transmission between CPU and GPU. - - TODO: Merge pcie profiler into communication profiler - """ - - def __init__(self, dtype: str = "fp32", depth: int = 1): - super().__init__(profiler_name="Pcie", priority=10) - self.depth = depth - self.data_size = _get_size(dtype) - self.h2d_count = 0 - self.h2d_time = 0 - self.d2h_count = 0 - self.d2h_time = 0 - - self.ops_record = dict() - self.profiler = None - - def reset(self): - self.h2d_count = 0 - self.h2d_time = 0 - self.d2h_count = 0 - self.d2h_time = 0 - - self.ops_record = dict() - self.profiler = None - - def enable(self): - self.profiler = profile(enabled=True, - use_cuda=True, - use_cpu=True, - use_kineto=True, - record_shapes=True, - with_stack=True) - self.profiler.__enter__() - - def disable(self): - self.profiler.__exit__(None, None, None) - - if self.profiler.enabled: - events = self.profiler.function_events - for event in events: - if event.name == "aten::copy_": - t_shape = event.input_shapes[0] - if len(t_shape) == 0 or event.cuda_time_total == 0 or len(event.stack) == 0: - continue - current_comm_event = PcieEvent(1, self.data_size * _get_numel(t_shape), event.cuda_time_total) - code_location = _reduce_location(event.stack[:self.depth]) - if code_location in self.ops_record: - self.ops_record[code_location].add(current_comm_event) - else: - self.ops_record[code_location] = current_comm_event - elif 'Memcpy HtoD' in event.name: - self.h2d_count += 1 - self.h2d_time += event.cuda_time_total - elif 'Memcpy DtoH' in event.name: - self.d2h_count += 1 - self.d2h_time += event.cuda_time_total - - self.profiler = None - - def to_tensorboard(self, writer): - writer.add_text(tag="Data Transmission", text_string=self.result_str("\n\n")) - - def to_file(self, filename: Path): - with open(filename, "w") as f: - f.write(self.result_str()) - - def show(self): - print(self.result_str()) - - def result_str(self, sep: str = "\n"): - res = [] - - def append(s: str = None): - if s is not None: - res.append(s) - res.append(sep) - - append("Pcie profiling result:") - append("time of data transmission (CPU -> GPU): {}".format(_format_time(self.h2d_time))) - append("number of transmission (CPU -> GPU): {}".format(self.h2d_count)) - append("time of data transmission (GPU -> CPU): {}".format(_format_time(self.d2h_time))) - append("number of transmission (GPU -> CPU): {}".format(self.d2h_count)) - - append("Possible data transmission events in PCIE:") - - separation = '-' * 62 - row_format = '{:^10}' + '{:^12}' + '{:^16}' + '{:^12}' * 2 - - append(separation) - 
append(row_format.format('Location', 'GPU time', 'Trans volume', 'Bandwidth', 'Num of calls')) - append(separation) - - show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].cuda_time) - for location, event in show_list: - append(location) - append( - row_format.format('', _format_time(event.cuda_time), _format_memory(event.pcie_vol), - _format_bandwidth(event.pcie_vol, event.cuda_time), event.count)) - append() - - return ''.join(res) +from pathlib import Path +from typing import List + +from torch.autograd.profiler import profile + +from .prof_utils import BaseProfiler, _format_bandwidth, _format_memory, _format_time + + +def _get_size(dtype: str): + if dtype == "fp16": + return 2 + elif dtype == "fp32": + return 4 + else: + raise NotImplementedError + + +def _get_numel(my_list: List[int]) -> int: + from functools import reduce + from operator import mul + return reduce(mul, my_list) + + +def _reduce_location(locations: List[str]) -> str: + ret = [] + for lo in locations: + ret.append(lo) + ret.append("\n") + ret = ret[:-1] + return ''.join(ret) + + +class PcieEvent(object): + """Pcie Event. + """ + + def __init__(self, count: int = 0, pcie_vol: int = 0, cuda_time: int = 0): + self.count = count + self.pcie_vol = pcie_vol + self.cuda_time = cuda_time + + def add(self, rhs): + self.count += rhs.count + self.pcie_vol += rhs.pcie_vol + self.cuda_time += rhs.cuda_time + + +class PcieProfiler(BaseProfiler): + """Pcie profiler. Records all data transmission between CPU and GPU. + + TODO: Merge pcie profiler into communication profiler + """ + + def __init__(self, dtype: str = "fp32", depth: int = 1): + super().__init__(profiler_name="Pcie", priority=10) + self.depth = depth + self.data_size = _get_size(dtype) + self.h2d_count = 0 + self.h2d_time = 0 + self.d2h_count = 0 + self.d2h_time = 0 + + self.ops_record = dict() + self.profiler = None + + def reset(self): + self.h2d_count = 0 + self.h2d_time = 0 + self.d2h_count = 0 + self.d2h_time = 0 + + self.ops_record = dict() + self.profiler = None + + def enable(self): + self.profiler = profile(enabled=True, + use_cuda=True, + use_cpu=True, + use_kineto=True, + record_shapes=True, + with_stack=True) + self.profiler.__enter__() + + def disable(self): + self.profiler.__exit__(None, None, None) + + if self.profiler.enabled: + events = self.profiler.function_events + for event in events: + if event.name == "aten::copy_": + t_shape = event.input_shapes[0] + if len(t_shape) == 0 or event.cuda_time_total == 0 or len(event.stack) == 0: + continue + current_comm_event = PcieEvent(1, self.data_size * _get_numel(t_shape), event.cuda_time_total) + code_location = _reduce_location(event.stack[:self.depth]) + if code_location in self.ops_record: + self.ops_record[code_location].add(current_comm_event) + else: + self.ops_record[code_location] = current_comm_event + elif 'Memcpy HtoD' in event.name: + self.h2d_count += 1 + self.h2d_time += event.cuda_time_total + elif 'Memcpy DtoH' in event.name: + self.d2h_count += 1 + self.d2h_time += event.cuda_time_total + + self.profiler = None + + def to_tensorboard(self, writer): + writer.add_text(tag="Data Transmission", text_string=self.result_str("\n\n")) + + def to_file(self, filename: Path): + with open(filename, "w") as f: + f.write(self.result_str()) + + def show(self): + print(self.result_str()) + + def result_str(self, sep: str = "\n"): + res = [] + + def append(s: str = None): + if s is not None: + res.append(s) + res.append(sep) + + append("Pcie profiling result:") + append("time of data transmission 
(CPU -> GPU): {}".format(_format_time(self.h2d_time))) + append("number of transmission (CPU -> GPU): {}".format(self.h2d_count)) + append("time of data transmission (GPU -> CPU): {}".format(_format_time(self.d2h_time))) + append("number of transmission (GPU -> CPU): {}".format(self.d2h_count)) + + append("Possible data transmission events in PCIE:") + + separation = '-' * 62 + row_format = '{:^10}' + '{:^12}' + '{:^16}' + '{:^12}' * 2 + + append(separation) + append(row_format.format('Location', 'GPU time', 'Trans volume', 'Bandwidth', 'Num of calls')) + append(separation) + + show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].cuda_time) + for location, event in show_list: + append(location) + append( + row_format.format('', _format_time(event.cuda_time), _format_memory(event.pcie_vol), + _format_bandwidth(event.pcie_vol, event.cuda_time), event.count)) + append() + + return ''.join(res) diff --git a/colossalai/utils/profiler/legacy/prof_utils.py b/colossalai/legacy/utils/profiler/legacy/prof_utils.py similarity index 94% rename from colossalai/utils/profiler/legacy/prof_utils.py rename to colossalai/legacy/utils/profiler/legacy/prof_utils.py index 2f7eee827651..9b948c9ec1cd 100644 --- a/colossalai/utils/profiler/legacy/prof_utils.py +++ b/colossalai/legacy/utils/profiler/legacy/prof_utils.py @@ -1,131 +1,132 @@ -from abc import ABC, abstractmethod -from pathlib import Path -from typing import Union, List -from colossalai.core import global_context as gpc - - -# copied from high version pytorch to support low version -def _format_time(time_us): - """Defines how to format time in FunctionEvent""" - US_IN_SECOND = 1000.0 * 1000.0 - US_IN_MS = 1000.0 - if time_us >= US_IN_SECOND: - return '{:.3f}s'.format(time_us / US_IN_SECOND) - if time_us >= US_IN_MS: - return '{:.3f}ms'.format(time_us / US_IN_MS) - return '{:.3f}us'.format(time_us) - - -# copied from high version pytorch to support low version -def _format_memory(nbytes): - """Returns a formatted memory size string""" - KB = 1024 - MB = 1024 * KB - GB = 1024 * MB - if (abs(nbytes) >= GB): - return '{:.2f} GB'.format(nbytes * 1.0 / GB) - elif (abs(nbytes) >= MB): - return '{:.2f} MB'.format(nbytes * 1.0 / MB) - elif (abs(nbytes) >= KB): - return '{:.2f} KB'.format(nbytes * 1.0 / KB) - else: - return str(nbytes) + ' B' - - -def _format_bandwidth(volume: float or int, time_us: int): - sec_div_mb = (1000.0 / 1024.0)**2 - mb_per_sec = volume / time_us * sec_div_mb - - if mb_per_sec >= 1024.0: - return '{:.3f} GB/s'.format(mb_per_sec / 1024.0) - else: - return '{:.3f} MB/s'.format(mb_per_sec) - - -class BaseProfiler(ABC): - - def __init__(self, profiler_name: str, priority: int): - self.name = profiler_name - self.priority = priority - - @abstractmethod - def enable(self): - pass - - @abstractmethod - def disable(self): - pass - - @abstractmethod - def to_tensorboard(self, writer): - pass - - @abstractmethod - def to_file(self, filename: Path): - pass - - @abstractmethod - def show(self): - pass - - -class ProfilerContext(object): - """Profiler context manager - - Usage:: - - world_size = 4 - inputs = torch.randn(10, 10, dtype=torch.float32, device=get_current_device()) - outputs = torch.empty(world_size, 10, 10, dtype=torch.float32, device=get_current_device()) - outputs_list = list(torch.chunk(outputs, chunks=world_size, dim=0)) - - cc_prof = CommProfiler() - - with ProfilerContext([cc_prof]) as prof: - op = dist.all_reduce(inputs, async_op=True) - dist.all_gather(outputs_list, inputs) - op.wait() - dist.reduce_scatter(inputs, 
outputs_list) - dist.broadcast(inputs, 0) - dist.reduce(inputs, 0) - - prof.show() - """ - - def __init__(self, profilers: List[BaseProfiler] = None, enable: bool = True): - self.enable = enable - self.profilers = sorted(profilers, key=lambda prof: prof.priority) - - def __enter__(self): - if self.enable: - for prof in self.profilers: - prof.enable() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if self.enable: - for prof in self.profilers: - prof.disable() - - def to_tensorboard(self, writer): - from torch.utils.tensorboard import SummaryWriter - - assert isinstance(writer, SummaryWriter), \ - f'torch.utils.tensorboard.SummaryWriter is required, but found {type(writer)}.' - - for prof in self.profilers: - prof.to_tensorboard(writer) - - def to_file(self, log_dir: Union[str, Path]): - if isinstance(log_dir, str): - log_dir = Path(log_dir) - - if not log_dir.exists(): - log_dir.mkdir(parents=True, exist_ok=True) - for prof in self.profilers: - log_file = log_dir.joinpath(f'{prof.name}_rank_{gpc.get_global_rank()}.log') - prof.to_file(log_file) - - def show(self): - for prof in self.profilers: - prof.show() +from abc import ABC, abstractmethod +from pathlib import Path +from typing import List, Union + +from colossalai.legacy.core import global_context as gpc + + +# copied from high version pytorch to support low version +def _format_time(time_us): + """Defines how to format time in FunctionEvent""" + US_IN_SECOND = 1000.0 * 1000.0 + US_IN_MS = 1000.0 + if time_us >= US_IN_SECOND: + return '{:.3f}s'.format(time_us / US_IN_SECOND) + if time_us >= US_IN_MS: + return '{:.3f}ms'.format(time_us / US_IN_MS) + return '{:.3f}us'.format(time_us) + + +# copied from high version pytorch to support low version +def _format_memory(nbytes): + """Returns a formatted memory size string""" + KB = 1024 + MB = 1024 * KB + GB = 1024 * MB + if (abs(nbytes) >= GB): + return '{:.2f} GB'.format(nbytes * 1.0 / GB) + elif (abs(nbytes) >= MB): + return '{:.2f} MB'.format(nbytes * 1.0 / MB) + elif (abs(nbytes) >= KB): + return '{:.2f} KB'.format(nbytes * 1.0 / KB) + else: + return str(nbytes) + ' B' + + +def _format_bandwidth(volume: float or int, time_us: int): + sec_div_mb = (1000.0 / 1024.0)**2 + mb_per_sec = volume / time_us * sec_div_mb + + if mb_per_sec >= 1024.0: + return '{:.3f} GB/s'.format(mb_per_sec / 1024.0) + else: + return '{:.3f} MB/s'.format(mb_per_sec) + + +class BaseProfiler(ABC): + + def __init__(self, profiler_name: str, priority: int): + self.name = profiler_name + self.priority = priority + + @abstractmethod + def enable(self): + pass + + @abstractmethod + def disable(self): + pass + + @abstractmethod + def to_tensorboard(self, writer): + pass + + @abstractmethod + def to_file(self, filename: Path): + pass + + @abstractmethod + def show(self): + pass + + +class ProfilerContext(object): + """Profiler context manager + + Usage:: + + world_size = 4 + inputs = torch.randn(10, 10, dtype=torch.float32, device=get_current_device()) + outputs = torch.empty(world_size, 10, 10, dtype=torch.float32, device=get_current_device()) + outputs_list = list(torch.chunk(outputs, chunks=world_size, dim=0)) + + cc_prof = CommProfiler() + + with ProfilerContext([cc_prof]) as prof: + op = dist.all_reduce(inputs, async_op=True) + dist.all_gather(outputs_list, inputs) + op.wait() + dist.reduce_scatter(inputs, outputs_list) + dist.broadcast(inputs, 0) + dist.reduce(inputs, 0) + + prof.show() + """ + + def __init__(self, profilers: List[BaseProfiler] = None, enable: bool = True): + self.enable = enable + 
self.profilers = sorted(profilers, key=lambda prof: prof.priority) + + def __enter__(self): + if self.enable: + for prof in self.profilers: + prof.enable() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.enable: + for prof in self.profilers: + prof.disable() + + def to_tensorboard(self, writer): + from torch.utils.tensorboard import SummaryWriter + + assert isinstance(writer, SummaryWriter), \ + f'torch.utils.tensorboard.SummaryWriter is required, but found {type(writer)}.' + + for prof in self.profilers: + prof.to_tensorboard(writer) + + def to_file(self, log_dir: Union[str, Path]): + if isinstance(log_dir, str): + log_dir = Path(log_dir) + + if not log_dir.exists(): + log_dir.mkdir(parents=True, exist_ok=True) + for prof in self.profilers: + log_file = log_dir.joinpath(f'{prof.name}_rank_{gpc.get_global_rank()}.log') + prof.to_file(log_file) + + def show(self): + for prof in self.profilers: + prof.show() diff --git a/colossalai/utils/profiler/profiler.py b/colossalai/legacy/utils/profiler/profiler.py similarity index 97% rename from colossalai/utils/profiler/profiler.py rename to colossalai/legacy/utils/profiler/profiler.py index 3026d723deb0..0827f06b586c 100644 --- a/colossalai/utils/profiler/profiler.py +++ b/colossalai/legacy/utils/profiler/profiler.py @@ -9,9 +9,9 @@ from torch.profiler.profiler import ProfilerAction from colossalai.legacy.engine import Engine +from colossalai.legacy.utils.profiler.extention import ProfilerExtension +from colossalai.legacy.utils.profiler.stateful_tensor_mem_extention import StatefulTensorMemoryProfilerExtention from colossalai.logging import get_dist_logger -from colossalai.utils.profiler.extention import ProfilerExtension -from colossalai.utils.profiler.stateful_tensor_mem_extention import StatefulTensorMemoryProfilerExtention class profile(torch_profile): diff --git a/colossalai/utils/profiler/stateful_tensor_mem_extention.py b/colossalai/legacy/utils/profiler/stateful_tensor_mem_extention.py similarity index 98% rename from colossalai/utils/profiler/stateful_tensor_mem_extention.py rename to colossalai/legacy/utils/profiler/stateful_tensor_mem_extention.py index 412bd7277eee..f3bb66ced583 100644 --- a/colossalai/utils/profiler/stateful_tensor_mem_extention.py +++ b/colossalai/legacy/utils/profiler/stateful_tensor_mem_extention.py @@ -9,7 +9,7 @@ from colossalai.gemini.ophooks import BaseOpHook from colossalai.gemini.stateful_tensor import StatefulTensor from colossalai.legacy.engine import Engine -from colossalai.utils.profiler.extention import ProfilerExtension +from colossalai.legacy.utils.profiler.extention import ProfilerExtension class DeviceType(Enum): diff --git a/colossalai/zero/legacy/__init__.py b/colossalai/legacy/zero/__init__.py similarity index 100% rename from colossalai/zero/legacy/__init__.py rename to colossalai/legacy/zero/__init__.py diff --git a/colossalai/zero/legacy/gemini/__init__.py b/colossalai/legacy/zero/gemini/__init__.py similarity index 100% rename from colossalai/zero/legacy/gemini/__init__.py rename to colossalai/legacy/zero/gemini/__init__.py diff --git a/colossalai/zero/legacy/gemini/gemini_context.py b/colossalai/legacy/zero/gemini/gemini_context.py similarity index 100% rename from colossalai/zero/legacy/gemini/gemini_context.py rename to colossalai/legacy/zero/gemini/gemini_context.py diff --git a/colossalai/zero/legacy/gemini/ophooks/__init__.py b/colossalai/legacy/zero/gemini/ophooks/__init__.py similarity index 100% rename from colossalai/zero/legacy/gemini/ophooks/__init__.py 
rename to colossalai/legacy/zero/gemini/ophooks/__init__.py diff --git a/colossalai/zero/legacy/gemini/ophooks/_shard_grad_ophook.py b/colossalai/legacy/zero/gemini/ophooks/_shard_grad_ophook.py similarity index 100% rename from colossalai/zero/legacy/gemini/ophooks/_shard_grad_ophook.py rename to colossalai/legacy/zero/gemini/ophooks/_shard_grad_ophook.py diff --git a/colossalai/zero/legacy/gemini/ophooks/_shard_param_ophook.py b/colossalai/legacy/zero/gemini/ophooks/_shard_param_ophook.py similarity index 100% rename from colossalai/zero/legacy/gemini/ophooks/_shard_param_ophook.py rename to colossalai/legacy/zero/gemini/ophooks/_shard_param_ophook.py diff --git a/colossalai/zero/legacy/gemini/ophooks/runtime_mem_tracer_hook.py b/colossalai/legacy/zero/gemini/ophooks/runtime_mem_tracer_hook.py similarity index 98% rename from colossalai/zero/legacy/gemini/ophooks/runtime_mem_tracer_hook.py rename to colossalai/legacy/zero/gemini/ophooks/runtime_mem_tracer_hook.py index f40d6ced1ee0..eebcf86e0e58 100644 --- a/colossalai/zero/legacy/gemini/ophooks/runtime_mem_tracer_hook.py +++ b/colossalai/legacy/zero/gemini/ophooks/runtime_mem_tracer_hook.py @@ -5,9 +5,9 @@ import torch +from colossalai.legacy.zero.gemini.tensor_utils import alloc_storage, free_storage from colossalai.tensor.param_op_hook import ColoParamOpHook from colossalai.zero.gemini.memory_tracer import MemStats, SyncCudaMemoryMonitor -from colossalai.zero.legacy.gemini.tensor_utils import alloc_storage, free_storage class TrainingPhase(Enum): diff --git a/colossalai/zero/legacy/gemini/ophooks/utils.py b/colossalai/legacy/zero/gemini/ophooks/utils.py similarity index 100% rename from colossalai/zero/legacy/gemini/ophooks/utils.py rename to colossalai/legacy/zero/gemini/ophooks/utils.py diff --git a/colossalai/zero/legacy/gemini/paramhooks/__init__.py b/colossalai/legacy/zero/gemini/paramhooks/__init__.py similarity index 100% rename from colossalai/zero/legacy/gemini/paramhooks/__init__.py rename to colossalai/legacy/zero/gemini/paramhooks/__init__.py diff --git a/colossalai/zero/legacy/gemini/paramhooks/_param_hookmgr.py b/colossalai/legacy/zero/gemini/paramhooks/_param_hookmgr.py similarity index 100% rename from colossalai/zero/legacy/gemini/paramhooks/_param_hookmgr.py rename to colossalai/legacy/zero/gemini/paramhooks/_param_hookmgr.py diff --git a/colossalai/zero/legacy/gemini/stateful_tensor.py b/colossalai/legacy/zero/gemini/stateful_tensor.py similarity index 100% rename from colossalai/zero/legacy/gemini/stateful_tensor.py rename to colossalai/legacy/zero/gemini/stateful_tensor.py diff --git a/colossalai/zero/legacy/gemini/stateful_tensor_mgr.py b/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py similarity index 100% rename from colossalai/zero/legacy/gemini/stateful_tensor_mgr.py rename to colossalai/legacy/zero/gemini/stateful_tensor_mgr.py diff --git a/colossalai/zero/legacy/gemini/tensor_placement_policy.py b/colossalai/legacy/zero/gemini/tensor_placement_policy.py similarity index 98% rename from colossalai/zero/legacy/gemini/tensor_placement_policy.py rename to colossalai/legacy/zero/gemini/tensor_placement_policy.py index 165ae51fee60..275933ec2cfb 100644 --- a/colossalai/zero/legacy/gemini/tensor_placement_policy.py +++ b/colossalai/legacy/zero/gemini/tensor_placement_policy.py @@ -5,8 +5,8 @@ import torch +from colossalai.legacy.utils.memory import colo_device_memory_capacity from colossalai.utils import get_current_device -from colossalai.utils.memory import colo_device_memory_capacity from 
colossalai.zero.gemini.memory_tracer import MemStatsCollector from .stateful_tensor import StatefulTensor diff --git a/colossalai/zero/legacy/gemini/tensor_utils.py b/colossalai/legacy/zero/gemini/tensor_utils.py similarity index 100% rename from colossalai/zero/legacy/gemini/tensor_utils.py rename to colossalai/legacy/zero/gemini/tensor_utils.py diff --git a/colossalai/zero/legacy/init_ctx/__init__.py b/colossalai/legacy/zero/init_ctx/__init__.py similarity index 100% rename from colossalai/zero/legacy/init_ctx/__init__.py rename to colossalai/legacy/zero/init_ctx/__init__.py diff --git a/colossalai/zero/legacy/init_ctx/init_context.py b/colossalai/legacy/zero/init_ctx/init_context.py similarity index 96% rename from colossalai/zero/legacy/init_ctx/init_context.py rename to colossalai/legacy/zero/init_ctx/init_context.py index 84e2d2f4f8e1..4a7e46408583 100644 --- a/colossalai/zero/legacy/init_ctx/init_context.py +++ b/colossalai/legacy/zero/init_ctx/init_context.py @@ -8,15 +8,15 @@ import torch.distributed as dist import torch.nn as nn -from colossalai.context.parallel_mode import ParallelMode from colossalai.context.singleton_meta import SingletonMeta -from colossalai.core import global_context as gpc +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.zero.shard_utils import BaseShardStrategy +from colossalai.legacy.zero.sharded_model._utils import cast_tensor_to_bf16, cast_tensor_to_fp16 +from colossalai.legacy.zero.sharded_model.sharded_model_v2 import ShardedModelV2 +from colossalai.legacy.zero.sharded_param import ShardedParamV2 from colossalai.logging import get_dist_logger from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses -from colossalai.zero.legacy.shard_utils import BaseShardStrategy -from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_bf16, cast_tensor_to_fp16 -from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2 -from colossalai.zero.legacy.sharded_param import ShardedParamV2 @dataclass diff --git a/colossalai/zero/legacy/shard_utils/__init__.py b/colossalai/legacy/zero/shard_utils/__init__.py similarity index 100% rename from colossalai/zero/legacy/shard_utils/__init__.py rename to colossalai/legacy/zero/shard_utils/__init__.py diff --git a/colossalai/zero/legacy/shard_utils/base_shard_strategy.py b/colossalai/legacy/zero/shard_utils/base_shard_strategy.py similarity index 90% rename from colossalai/zero/legacy/shard_utils/base_shard_strategy.py rename to colossalai/legacy/zero/shard_utils/base_shard_strategy.py index 7ca951091640..9fb80f57ae77 100644 --- a/colossalai/zero/legacy/shard_utils/base_shard_strategy.py +++ b/colossalai/legacy/zero/shard_utils/base_shard_strategy.py @@ -3,7 +3,7 @@ import torch.distributed as dist -from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor +from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor class BaseShardStrategy(ABC): diff --git a/colossalai/zero/legacy/shard_utils/bucket_tensor_shard_strategy.py b/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py similarity index 97% rename from colossalai/zero/legacy/shard_utils/bucket_tensor_shard_strategy.py rename to colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py index d663104831ce..1f7baad57816 100644 --- a/colossalai/zero/legacy/shard_utils/bucket_tensor_shard_strategy.py +++ 
b/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py @@ -4,8 +4,8 @@ import torch.distributed as dist from torch._utils import _flatten_dense_tensors as flatten +from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor from colossalai.utils import get_current_device -from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor from .tensor_shard_strategy import TensorShardStrategy diff --git a/colossalai/zero/legacy/shard_utils/commons.py b/colossalai/legacy/zero/shard_utils/commons.py similarity index 100% rename from colossalai/zero/legacy/shard_utils/commons.py rename to colossalai/legacy/zero/shard_utils/commons.py diff --git a/colossalai/zero/legacy/shard_utils/tensor_shard_strategy.py b/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py similarity index 90% rename from colossalai/zero/legacy/shard_utils/tensor_shard_strategy.py rename to colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py index d1df4803b820..cc43907f6655 100644 --- a/colossalai/zero/legacy/shard_utils/tensor_shard_strategy.py +++ b/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py @@ -3,11 +3,11 @@ import torch import torch.distributed as dist +from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_tensor_move_inline +from colossalai.legacy.zero.shard_utils import BaseShardStrategy +from colossalai.legacy.zero.shard_utils.commons import get_shard +from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor from colossalai.utils import get_current_device -from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move_inline -from colossalai.zero.legacy.shard_utils import BaseShardStrategy -from colossalai.zero.legacy.shard_utils.commons import get_shard -from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor class TensorShardStrategy(BaseShardStrategy): diff --git a/colossalai/zero/legacy/sharded_model/__init__.py b/colossalai/legacy/zero/sharded_model/__init__.py similarity index 100% rename from colossalai/zero/legacy/sharded_model/__init__.py rename to colossalai/legacy/zero/sharded_model/__init__.py diff --git a/colossalai/zero/legacy/sharded_model/_utils.py b/colossalai/legacy/zero/sharded_model/_utils.py similarity index 97% rename from colossalai/zero/legacy/sharded_model/_utils.py rename to colossalai/legacy/zero/sharded_model/_utils.py index f1d642cf3f13..b8a618ef5a0d 100644 --- a/colossalai/zero/legacy/sharded_model/_utils.py +++ b/colossalai/legacy/zero/sharded_model/_utils.py @@ -3,7 +3,7 @@ import torch import torch.nn.functional as F -from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor +from colossalai.legacy.zero.gemini.stateful_tensor import StatefulTensor def get_gradient_predivide_factor(world_size: int) -> float: diff --git a/colossalai/zero/legacy/sharded_model/reduce_scatter.py b/colossalai/legacy/zero/sharded_model/reduce_scatter.py similarity index 100% rename from colossalai/zero/legacy/sharded_model/reduce_scatter.py rename to colossalai/legacy/zero/sharded_model/reduce_scatter.py diff --git a/colossalai/zero/legacy/sharded_model/sharded_model_v2.py b/colossalai/legacy/zero/sharded_model/sharded_model_v2.py similarity index 97% rename from colossalai/zero/legacy/sharded_model/sharded_model_v2.py rename to colossalai/legacy/zero/sharded_model/sharded_model_v2.py index e7064277fb3c..91c21ccf9516 100644 --- a/colossalai/zero/legacy/sharded_model/sharded_model_v2.py +++ 
b/colossalai/legacy/zero/sharded_model/sharded_model_v2.py @@ -11,20 +11,20 @@ from torch.distributed import ProcessGroup from torch.nn.parameter import Parameter -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.utils.memory import colo_device_memory_capacity +from colossalai.legacy.zero.gemini.ophooks import register_ophooks_recursively +from colossalai.legacy.zero.gemini.paramhooks import BaseParamHookMgr +from colossalai.legacy.zero.gemini.stateful_tensor import TensorState +from colossalai.legacy.zero.gemini.stateful_tensor_mgr import StatefulTensorMgr +from colossalai.legacy.zero.gemini.tensor_placement_policy import TensorPlacementPolicy, TensorPlacementPolicyFactory +from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_move_to_cpu +from colossalai.legacy.zero.shard_utils import BaseShardStrategy +from colossalai.legacy.zero.sharded_model.reduce_scatter import ReduceScatterBucketer from colossalai.logging import get_dist_logger from colossalai.utils import disposable, get_current_device -from colossalai.utils.memory import colo_device_memory_capacity -from colossalai.zero.gemini.memory_tracer import MemStatsCollector, StaticMemStatsCollector -from colossalai.zero.legacy.gemini.ophooks import register_ophooks_recursively -from colossalai.zero.legacy.gemini.paramhooks import BaseParamHookMgr -from colossalai.zero.legacy.gemini.stateful_tensor import TensorState -from colossalai.zero.legacy.gemini.stateful_tensor_mgr import StatefulTensorMgr -from colossalai.zero.legacy.gemini.tensor_placement_policy import TensorPlacementPolicy, TensorPlacementPolicyFactory -from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_move_to_cpu -from colossalai.zero.legacy.shard_utils import BaseShardStrategy -from colossalai.zero.legacy.sharded_model.reduce_scatter import ReduceScatterBucketer +from colossalai.zero.gemini.memory_tracer import MemStatsCollector from ._utils import ( cast_float_arguments, diff --git a/colossalai/zero/legacy/sharded_model/utils.py b/colossalai/legacy/zero/sharded_model/utils.py similarity index 92% rename from colossalai/zero/legacy/sharded_model/utils.py rename to colossalai/legacy/zero/sharded_model/utils.py index 08806e78ea3b..7a411669900b 100644 --- a/colossalai/zero/legacy/sharded_model/utils.py +++ b/colossalai/legacy/zero/sharded_model/utils.py @@ -2,7 +2,7 @@ import torch -from colossalai.zero.legacy.sharded_model import ShardedModelV2 +from colossalai.legacy.zero.sharded_model import ShardedModelV2 def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Module): diff --git a/colossalai/zero/legacy/sharded_model/zero_hook.py b/colossalai/legacy/zero/sharded_model/zero_hook.py similarity index 94% rename from colossalai/zero/legacy/sharded_model/zero_hook.py rename to colossalai/legacy/zero/sharded_model/zero_hook.py index 1815bee3a9e0..3fc373e5ca44 100644 --- a/colossalai/zero/legacy/sharded_model/zero_hook.py +++ b/colossalai/legacy/zero/sharded_model/zero_hook.py @@ -4,13 +4,13 @@ import torch.distributed as dist from colossalai.legacy.registry import OPHOOKS +from colossalai.legacy.zero.gemini.ophooks import BaseOpHook +from colossalai.legacy.zero.gemini.stateful_tensor import TensorState +from colossalai.legacy.zero.gemini.stateful_tensor_mgr import StatefulTensorMgr +from 
colossalai.legacy.zero.shard_utils import BaseShardStrategy from colossalai.logging import get_dist_logger from colossalai.utils import get_current_device from colossalai.zero.gemini.memory_tracer import MemStatsCollector -from colossalai.zero.legacy.gemini.ophooks import BaseOpHook -from colossalai.zero.legacy.gemini.stateful_tensor import TensorState -from colossalai.zero.legacy.gemini.stateful_tensor_mgr import StatefulTensorMgr -from colossalai.zero.legacy.shard_utils import BaseShardStrategy @OPHOOKS.register_module diff --git a/colossalai/zero/legacy/sharded_optim/__init__.py b/colossalai/legacy/zero/sharded_optim/__init__.py similarity index 100% rename from colossalai/zero/legacy/sharded_optim/__init__.py rename to colossalai/legacy/zero/sharded_optim/__init__.py diff --git a/colossalai/zero/legacy/sharded_optim/sharded_optim_v2.py b/colossalai/legacy/zero/sharded_optim/sharded_optim_v2.py similarity index 97% rename from colossalai/zero/legacy/sharded_optim/sharded_optim_v2.py rename to colossalai/legacy/zero/sharded_optim/sharded_optim_v2.py index 41dd174cb65a..e21f1cea04df 100644 --- a/colossalai/zero/legacy/sharded_optim/sharded_optim_v2.py +++ b/colossalai/legacy/zero/sharded_optim/sharded_optim_v2.py @@ -12,15 +12,15 @@ from torch.optim import Optimizer from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.interface import OptimizerWrapper +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.zero.gemini.stateful_tensor import StatefulTensor, TensorState +from colossalai.legacy.zero.gemini.tensor_placement_policy import AutoTensorPlacementPolicy +from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage +from colossalai.legacy.zero.sharded_model import ShardedModelV2 +from colossalai.legacy.zero.sharded_model._utils import cast_tensor_to_fp32 from colossalai.logging import get_dist_logger -from colossalai.nn.optimizer import ColossalaiOptimizer -from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState -from colossalai.zero.legacy.gemini.tensor_placement_policy import AutoTensorPlacementPolicy -from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage -from colossalai.zero.legacy.sharded_model import ShardedModelV2 -from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp32 class OptimState(Enum): @@ -28,7 +28,7 @@ class OptimState(Enum): UNSCALED = 2 -class ShardedOptimizerV2(ColossalaiOptimizer): +class ShardedOptimizerV2(OptimizerWrapper): """A wrapper for optimizer. ``ShardedOptimizerV2`` and ``ShardedModelV2`` implement Zero Redundancy Optimizer (ZeRO). By default the ZeRO optimizer stage 3 offload Optimizer States on CPU. 
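# The hunk above moves ShardedOptimizerV2 under colossalai.legacy.zero and switches its base
# class from ColossalaiOptimizer to colossalai.interface.OptimizerWrapper. The sketch below is
# a minimal, hypothetical illustration of that wrapper pattern: the method set mirrors the
# ColossalaiOptimizer definition deleted later in this patch, and the exact signatures of the
# real OptimizerWrapper are an assumption, not a copy of its API.
import torch
from torch import Tensor
from torch.optim import SGD, Optimizer


class OptimizerWrapperSketch:
    """Delegating optimizer wrapper (simplified); ZeRO optimizers subclass such a wrapper."""

    def __init__(self, optim: Optimizer):
        self.optim = optim

    @property
    def param_groups(self):
        return self.optim.param_groups

    def step(self, *args, **kwargs):
        return self.optim.step(*args, **kwargs)

    def zero_grad(self, *args, **kwargs):
        self.optim.zero_grad(*args, **kwargs)

    def backward(self, loss: Tensor):
        # a ZeRO subclass would scale the loss and handle sharded gradients here
        loss.backward()

    def state_dict(self):
        return self.optim.state_dict()


if __name__ == "__main__":
    # toy usage: wrap a plain SGD optimizer and run one update step
    model = torch.nn.Linear(4, 2)
    wrapped = OptimizerWrapperSketch(SGD(model.parameters(), lr=0.1))
    loss = model(torch.randn(8, 4)).sum()
    wrapped.backward(loss)
    wrapped.step()
    wrapped.zero_grad()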
diff --git a/colossalai/zero/legacy/sharded_param/__init__.py b/colossalai/legacy/zero/sharded_param/__init__.py similarity index 100% rename from colossalai/zero/legacy/sharded_param/__init__.py rename to colossalai/legacy/zero/sharded_param/__init__.py diff --git a/colossalai/zero/legacy/sharded_param/sharded_param.py b/colossalai/legacy/zero/sharded_param/sharded_param.py similarity index 96% rename from colossalai/zero/legacy/sharded_param/sharded_param.py rename to colossalai/legacy/zero/sharded_param/sharded_param.py index 4bcc4b62104a..454a722cf7e7 100644 --- a/colossalai/zero/legacy/sharded_param/sharded_param.py +++ b/colossalai/legacy/zero/sharded_param/sharded_param.py @@ -2,8 +2,8 @@ import torch -from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState -from colossalai.zero.legacy.gemini.tensor_utils import colo_tensor_mem_usage +from colossalai.legacy.zero.gemini.stateful_tensor import StatefulTensor, TensorState +from colossalai.legacy.zero.gemini.tensor_utils import colo_tensor_mem_usage from .sharded_tensor import ShardedTensor diff --git a/colossalai/zero/legacy/sharded_param/sharded_tensor.py b/colossalai/legacy/zero/sharded_param/sharded_tensor.py similarity index 94% rename from colossalai/zero/legacy/sharded_param/sharded_tensor.py rename to colossalai/legacy/zero/sharded_param/sharded_tensor.py index af60312600f2..43c7576b93b5 100644 --- a/colossalai/zero/legacy/sharded_param/sharded_tensor.py +++ b/colossalai/legacy/zero/sharded_param/sharded_tensor.py @@ -1,6 +1,6 @@ import torch -from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState +from colossalai.legacy.zero.gemini.stateful_tensor import StatefulTensor, TensorState class ShardedTensor(StatefulTensor): diff --git a/colossalai/logging/logger.py b/colossalai/logging/logger.py index f9abe4a2a2b6..fd05ddf1d50f 100644 --- a/colossalai/logging/logger.py +++ b/colossalai/logging/logger.py @@ -134,8 +134,6 @@ def info(self, message: str, ranks: List[int] = None) -> None: Args: message (str): The message to be logged. - parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): - The parallel mode used for logging. Defaults to ParallelMode.GLOBAL. ranks (List[int]): List of parallel ranks. """ message_prefix = "{}:{} {}".format(*self.__get_call_info()) @@ -147,8 +145,6 @@ def warning(self, message: str, ranks: List[int] = None) -> None: Args: message (str): The message to be logged. - parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): - The parallel mode used for logging. Defaults to ParallelMode.GLOBAL. ranks (List[int]): List of parallel ranks. """ message_prefix = "{}:{} {}".format(*self.__get_call_info()) @@ -160,8 +156,6 @@ def debug(self, message: str, ranks: List[int] = None) -> None: Args: message (str): The message to be logged. - parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): - The parallel mode used for logging. Defaults to ParallelMode.GLOBAL. ranks (List[int]): List of parallel ranks. """ message_prefix = "{}:{} {}".format(*self.__get_call_info()) @@ -173,8 +167,6 @@ def error(self, message: str, ranks: List[int] = None) -> None: Args: message (str): The message to be logged. - parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): - The parallel mode used for logging. Defaults to ParallelMode.GLOBAL. ranks (List[int]): List of parallel ranks. 
""" message_prefix = "{}:{} {}".format(*self.__get_call_info()) diff --git a/colossalai/nn/layer/__init__.py b/colossalai/nn/layer/__init__.py index edd986ef5e82..9aeab9f44a6d 100644 --- a/colossalai/nn/layer/__init__.py +++ b/colossalai/nn/layer/__init__.py @@ -1,2 +1,2 @@ -from .moe import * +# from .moe import * from .utils import * diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/nn/layer/moe/experts.py index 56b11f4d9e08..712d872bb921 100644 --- a/colossalai/nn/layer/moe/experts.py +++ b/colossalai/nn/layer/moe/experts.py @@ -6,10 +6,10 @@ import torch.distributed as dist import torch.nn as nn -from colossalai.context import ParallelMode, seed from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.legacy.context import ParallelMode, seed +from colossalai.legacy.zero.init_ctx import no_shard_zero_decrator from colossalai.utils import get_current_device -from colossalai.zero.legacy.init_ctx import no_shard_zero_decrator class MoeExperts(nn.Module): diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py index 03f55d91f3a8..9293d3208f11 100644 --- a/colossalai/nn/layer/moe/layers.py +++ b/colossalai/nn/layer/moe/layers.py @@ -6,6 +6,7 @@ import torch.nn.functional as F from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.legacy.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator from colossalai.nn.layer.moe._operation import ( COL_MOE_KERNEL_FLAG, AllGather, @@ -18,7 +19,6 @@ from colossalai.nn.layer.moe.routers import MoeRouter, Top1Router, Top2Router from colossalai.nn.layer.moe.utils import NormalNoiseGenerator, UniformNoiseGenerator from colossalai.utils import get_current_device -from colossalai.zero.legacy.init_ctx import no_shard_zero_context, no_shard_zero_decrator @no_shard_zero_decrator(is_replicated=True) diff --git a/colossalai/nn/loss/__init__.py b/colossalai/nn/loss/__init__.py index ee2add48ab91..7c6fb099d272 100644 --- a/colossalai/nn/loss/__init__.py +++ b/colossalai/nn/loss/__init__.py @@ -1 +1 @@ -from .loss_moe import MoeCrossEntropyLoss, MoeLoss +# from .loss_moe import MoeCrossEntropyLoss, MoeLoss diff --git a/colossalai/nn/optimizer/__init__.py b/colossalai/nn/optimizer/__init__.py index 06072648beba..7e310793f515 100644 --- a/colossalai/nn/optimizer/__init__.py +++ b/colossalai/nn/optimizer/__init__.py @@ -1,10 +1,9 @@ -from .colossalai_optimizer import ColossalaiOptimizer +from .cpu_adam import CPUAdam from .fused_adam import FusedAdam from .fused_lamb import FusedLAMB from .fused_sgd import FusedSGD +from .hybrid_adam import HybridAdam from .lamb import Lamb from .lars import Lars -from .cpu_adam import CPUAdam -from .hybrid_adam import HybridAdam -__all__ = ['ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars', 'CPUAdam', 'HybridAdam'] +__all__ = ['FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars', 'CPUAdam', 'HybridAdam'] diff --git a/colossalai/nn/optimizer/colossalai_optimizer.py b/colossalai/nn/optimizer/colossalai_optimizer.py deleted file mode 100644 index 34f5a9541975..000000000000 --- a/colossalai/nn/optimizer/colossalai_optimizer.py +++ /dev/null @@ -1,44 +0,0 @@ -import torch -import torch.nn as nn -from torch import Tensor -from torch.optim import Optimizer -from colossalai.utils import clip_grad_norm_fp32 - - -class ColossalaiOptimizer(Optimizer): - - def __init__(self, optim: Optimizer): - self.optim = optim - - @property - def param_groups(self): - return self.optim.param_groups - - @property - def defaults(self): - return 
self.optim.defaults - - def add_param_group(self, *args, **kwargs): - return self.optim.add_param_group(*args, **kwargs) - - def step(self, *args, **kwargs): - return self.optim.step(*args, **kwargs) - - def zero_grad(self, *args, **kwargs): - self.optim.zero_grad(*args, **kwargs) - - def load_state_dict(self, *args, **kwargs): - self.optim.load_state_dict(*args, **kwargs) - - def state_dict(self): - return self.optim.state_dict() - - def backward(self, loss: Tensor): - loss.backward() - - def backward_by_grad(self, tensor: Tensor, grad: Tensor): - torch.autograd.backward(tensors=tensor, grad_tensors=grad) - - def clip_grad_norm(self, model: nn.Module, max_norm: float): - if max_norm > 0.0: - clip_grad_norm_fp32(model.parameters(), max_norm) diff --git a/colossalai/pipeline/__init__.py b/colossalai/pipeline/__init__.py index 0fcde9707646..e88a1f00a1b7 100644 --- a/colossalai/pipeline/__init__.py +++ b/colossalai/pipeline/__init__.py @@ -1,4 +1,11 @@ -from .pipelinable import PipelinableContext, PipelinableModel -from .layer_spec import LayerSpec +from .p2p import PipelineP2PCommunication +from .schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, PipelineSchedule +from .stage_manager import PipelineStageManager -__all__ = ['PipelinableModel', 'PipelinableContext', 'LayerSpec'] \ No newline at end of file +__all__ = [ + 'PipelineSchedule', + 'OneForwardOneBackwardSchedule', + 'InterleavedSchedule', + 'PipelineP2PCommunication', + 'PipelineStageManager', +] diff --git a/colossalai/pipeline/middleware/__init__.py b/colossalai/pipeline/middleware/__init__.py deleted file mode 100644 index 79e19f9eaf77..000000000000 --- a/colossalai/pipeline/middleware/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .topo import Topo, Partition, PartitionOutputVal, PartitionInputVal - -__all__ = ['Topo', 'Partition', 'PartitionOutputVal', 'PartitionInputVal'] \ No newline at end of file diff --git a/colossalai/pipeline/rpc/__init__.py b/colossalai/pipeline/rpc/__init__.py deleted file mode 100644 index 9d9e9d44f46c..000000000000 --- a/colossalai/pipeline/rpc/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from ._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine, ChimeraPipelineEngine -from .utils import pytree_map - -__all__ = ['FillDrainPipelineEngine', 'OneFOneBPipelineEngine', 'ChimeraPipelineEngine', 'pytree_map'] \ No newline at end of file diff --git a/colossalai/pipeline/schedule/__init__.py b/colossalai/pipeline/schedule/__init__.py index 8b13413b1a31..07c0f5927060 100644 --- a/colossalai/pipeline/schedule/__init__.py +++ b/colossalai/pipeline/schedule/__init__.py @@ -1,7 +1,9 @@ from .base import PipelineSchedule +from .interleaved_pp import InterleavedSchedule from .one_f_one_b import OneForwardOneBackwardSchedule __all__ = [ 'PipelineSchedule', 'OneForwardOneBackwardSchedule', + 'InterleavedSchedule', ] diff --git a/colossalai/tensor/__init__.py b/colossalai/tensor/__init__.py index b2da64e6c33a..099376d931e8 100644 --- a/colossalai/tensor/__init__.py +++ b/colossalai/tensor/__init__.py @@ -1,18 +1,11 @@ -from . 
import distspec from .colo_parameter import ColoParameter from .colo_tensor import ColoTensor from .comm_spec import CollectiveCommPattern, CommSpec -from .compute_spec import ComputePattern, ComputeSpec -from .dist_spec_mgr import DistSpecManager -from .distspec import ReplicaSpec, ShardSpec from .param_op_hook import ColoParamOpHook, ColoParamOpHookManager -from .process_group import ProcessGroup -from .tensor_spec import ColoTensorSpec from .utils import convert_dim_partition_dict, convert_parameter, merge_same_dim_mesh_list, named_params_with_colotensor __all__ = [ - 'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter', - 'distspec', 'DistSpecManager', 'ColoParamOpHook', 'ColoParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', - 'ShardSpec', 'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict', + 'ColoTensor', 'convert_parameter', 'named_params_with_colotensor', 'ColoParameter', 'ColoParamOpHook', + 'ColoParamOpHookManager', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict', 'merge_same_dim_mesh_list' ] diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py index 6f9717d353e6..5226f688b43b 100644 --- a/colossalai/utils/__init__.py +++ b/colossalai/utils/__init__.py @@ -1,79 +1,32 @@ -from .activation_checkpoint import checkpoint -from .checkpointing import load_checkpoint, save_checkpoint from .common import ( _cast_float, - clip_grad_norm_fp32, conditional_context, - copy_tensor_parallel_attributes, - count_zeros_fp32, disposable, ensure_path_exists, free_storage, is_ddp_ignored, - is_dp_rank_0, - is_model_parallel_parameter, - is_no_pp_or_last_stage, - is_tp_rank_0, - is_using_ddp, - is_using_pp, - is_using_sequence, - multi_tensor_applier, - param_is_not_tensor_parallel_duplicate, - print_rank_0, - switch_virtual_pipeline_parallel_rank, - sync_model_param, -) -from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize -from .data_sampler import DataParallelSampler, get_dataloader -from .memory import ( - colo_device_memory_capacity, - colo_device_memory_used, - colo_get_cpu_memory_capacity, - colo_set_cpu_memory_capacity, - colo_set_process_memory_fraction, - report_memory_usage, + set_seed, ) +from .cuda import empty_cache, get_current_device, set_device, set_to_cuda, synchronize +from .multi_tensor_apply import multi_tensor_applier from .tensor_detector import TensorDetector from .timer import MultiTimer, Timer __all__ = [ - 'checkpoint', - 'print_rank_0', - 'sync_model_param', - 'is_ddp_ignored', - 'is_dp_rank_0', - 'is_tp_rank_0', - 'is_no_pp_or_last_stage', - 'is_using_ddp', - 'is_using_pp', - 'is_using_sequence', 'conditional_context', - 'is_model_parallel_parameter', - 'clip_grad_norm_fp32', - 'count_zeros_fp32', - 'copy_tensor_parallel_attributes', - 'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda', - 'report_memory_usage', - 'colo_device_memory_capacity', - 'colo_device_memory_used', - 'colo_set_process_memory_fraction', 'Timer', 'MultiTimer', 'multi_tensor_applier', - 'DataParallelSampler', - 'get_dataloader', - 'switch_virtual_pipeline_parallel_rank', 'TensorDetector', - 'load_checkpoint', - 'save_checkpoint', 'ensure_path_exists', 'disposable', - 'colo_set_cpu_memory_capacity', - 'colo_get_cpu_memory_capacity', '_cast_float', 'free_storage', + 'set_seed', + 'is_ddp_ignored', + 'set_device', ] diff --git a/colossalai/utils/checkpoint/__init__.py 
b/colossalai/utils/checkpoint/__init__.py deleted file mode 100644 index 1795b4ce36f4..000000000000 --- a/colossalai/utils/checkpoint/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .module_checkpoint import save_checkpoint, load_checkpoint - -__all__ = ['save_checkpoint', 'load_checkpoint'] diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py index 998901708239..8c769c5b13c0 100644 --- a/colossalai/utils/common.py +++ b/colossalai/utils/common.py @@ -3,44 +3,12 @@ import functools import os import random -import socket -from collections import defaultdict from contextlib import contextmanager from pathlib import Path -from typing import Callable, Dict, List, Optional, Union +from typing import Callable +import numpy as np import torch -import torch.distributed as dist -from torch import inf -from torch.nn.parameter import Parameter - -from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.global_variables import tensor_parallel_env as env -from colossalai.tensor import ColoParameter, ProcessGroup - -from .multi_tensor_apply import multi_tensor_applier - -try: - from colossalai._C import fused_optim -except: - fused_optim = None - - -def print_rank_0(msg: str, logger=None): - """Print messages and save logs(optional). This is executed only if you are the rank-0 gpu. - - Args: - msg (str): A string message to output. - logger (:class:`colossalai.logging.DistributedLogger`, optional): - The logger to record the message, defaults to None. - """ - if gpc.get_global_rank() == 0: - if logger is None: - print(msg, flush=True) - else: - logger.info(msg) def ensure_path_exists(filename: str): @@ -50,47 +18,6 @@ def ensure_path_exists(filename: str): Path(dirpath).mkdir(parents=True, exist_ok=True) -def sync_model_param(model, parallel_mode): - r"""Make sure data parameters are consistent during Data Parallel Mode. - - Args: - model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency. - parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel mode to be checked. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. 
More details about ``ParallelMode`` could be found - in `parallel_mode `_ - """ - if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1: - for param in model.parameters(): - ranks = gpc.get_ranks_in_group(parallel_mode) - dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode)) - - -def is_dp_rank_0(): - return not gpc.is_initialized(ParallelMode.DATA) or gpc.is_first_rank(ParallelMode.DATA) - - -def is_tp_rank_0(): - return not gpc.is_initialized(ParallelMode.TENSOR) or gpc.is_first_rank(ParallelMode.TENSOR) - - -def is_no_pp_or_last_stage(): - return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE) - - -def is_using_ddp(): - return gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1 - - -def is_using_pp(): - return gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1 - - -def is_using_sequence(): - return gpc.is_initialized(ParallelMode.SEQUENCE) and gpc.get_world_size(ParallelMode.SEQUENCE) > 1 - - @contextmanager def conditional_context(context_manager, enable=True): if enable: @@ -100,365 +27,10 @@ def conditional_context(context_manager, enable=True): yield -class model_branch_context(object): - - def __enter__(self): - self.env_status = env.save() - - def __exit__(self, *exc_info): - env.load(**self.env_status) - - -def is_model_parallel_parameter(p): - return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL) - - def is_ddp_ignored(p): return getattr(p, '_ddp_to_ignore', False) -def _calc_l2_norm(grads): - # we should not - global fused_optim - - if fused_optim is None: - from colossalai.kernel.op_builder import FusedOptimBuilder - fused_optim = FusedOptimBuilder().load() - - norm = 0.0 - if len(grads) > 0: - dummy_overflow_buf = torch.cuda.IntTensor([0]) - norm, _ = multi_tensor_applier( - fused_optim.multi_tensor_l2norm, - dummy_overflow_buf, - [grads], - False # no per-parameter norm - ) - return norm - - -def _calc_lp(grads, norm_type): - norm = 0.0 - for grad in grads: - grad_norm = torch.norm(grad, norm_type) - norm += grad_norm**norm_type - return norm - - -def _move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]: - if torch.is_tensor(norm) and norm.device.type != 'cuda': - norm = norm.to(torch.cuda.current_device()) - return norm - - -def _get_tensor_norm(norm: Union[float, torch.Tensor], move_to_cuda) -> torch.Tensor: - if isinstance(norm, float): - norm = torch.Tensor([norm]) - if move_to_cuda: - norm = norm.to(torch.cuda.current_device()) - return norm - - -# ======== Gradient Clipping ========= - - -def _compute_local_lp(params: List[ColoParameter], norm_type: float) -> float: - if len(params) == 0: - return 0.0 - grads = [p.grad for p in params] - use_cuda_kernel = grads[0].device.type == 'cuda' - if norm_type == inf: - local_lp = max([g.abs().max() for g in grads]) - elif norm_type == 2.0 and use_cuda_kernel: - local_lp = _calc_l2_norm(grads)**norm_type - else: - local_lp = _calc_lp(grads, norm_type) - if isinstance(local_lp, torch.Tensor): - return local_lp.item() - return local_lp - - -def _compute_buckets_lp(params: List[ColoParameter], norm_type: float) -> float: - if len(params) == 0: - return 0.0 - buckets: Dict[Optional[ProcessGroup], List[ColoParameter]] = defaultdict(list) - for p in params: - if p.is_replicate(): - buckets[None].append(p) - else: - buckets[p.get_process_group().tp_process_group()].append(p) - total_lp = 0.0 - for group, bucket in buckets.items(): - 
local_lp = _compute_local_lp(bucket, norm_type) - if group is not None: - local_lp_tensor = torch.tensor([local_lp], device=torch.cuda.current_device()) - if norm_type == inf: - dist.all_reduce(local_lp_tensor, op=dist.ReduceOp.MAX, group=group) - else: - dist.all_reduce(local_lp_tensor, group=group) - local_lp = local_lp_tensor.item() - if norm_type == inf: - total_lp = max(total_lp, local_lp) - else: - total_lp += local_lp - return total_lp - - -def _compute_pp_grad_lp(total_lp: float, norm_type: float) -> float: - if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1: - total_lp_tensor = torch.tensor([total_lp], device=torch.cuda.current_device()) - if norm_type == inf: - dist.all_reduce(total_lp_tensor, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.PIPELINE)) - else: - dist.all_reduce(total_lp_tensor, group=gpc.get_group(ParallelMode.PIPELINE)) - total_lp = total_lp_tensor.item() - return total_lp - - -def _compute_grad_lp(parameters, norm_type: float = 2.0) -> float: - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - grad_dtype = None - cpu_grad_params: List[ColoParameter] = [] - cuda_grad_params: List[ColoParameter] = [] - for p in parameters: - if p.grad is None: - continue - assert isinstance(p, ColoParameter) - if grad_dtype is None: - grad_dtype = p.grad.dtype - assert p.grad.dtype == grad_dtype, f'Expected all grads are {grad_dtype}, got {p.grad.dtype}' - if p.grad.device.type == 'cuda': - cuda_grad_params.append(p) - else: - cpu_grad_params.append(p) - norm_type = float(norm_type) - cpu_lp = _compute_buckets_lp(cpu_grad_params, norm_type) - cuda_lp = _compute_buckets_lp(cuda_grad_params, norm_type) - if norm_type == inf: - total_lp = max(cpu_lp, cuda_lp) - else: - total_lp = cpu_lp + cuda_lp - return _compute_pp_grad_lp(total_lp, norm_type) - - -def compute_grad_norm(parameters, norm_type: float = 2.0) -> float: - norm_type = float(norm_type) - total_norm = _compute_grad_lp(parameters, norm_type) - if norm_type != inf: - total_norm = total_norm**(1 / norm_type) - return total_norm - - -def _clip_grad_norm(parameters, max_norm: float, total_norm: float) -> None: - clip_coef = max_norm / (total_norm + 1e-6) - if clip_coef < 1.0: - cuda_grads: List[torch.Tensor] = [] - cpu_grads: List[torch.Tensor] = [] - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - for p in parameters: - if p.grad is None: - continue - if p.grad.device.type == 'cuda': - cuda_grads.append(p.grad.detach()) - else: - cpu_grads.append(p.grad.detach()) - if len(cuda_grads) > 0: - dummy_overflow_buf = torch.cuda.IntTensor([0]) - multi_tensor_applier(fused_optim.multi_tensor_scale, dummy_overflow_buf, [cuda_grads, cuda_grads], - clip_coef) - for g in cpu_grads: - g.mul_(clip_coef) - - -def clip_grad_norm(parameters, max_norm: float, norm_type: float = 2.0) -> float: - total_norm = compute_grad_norm(parameters, norm_type) - _clip_grad_norm(parameters, max_norm, total_norm) - return total_norm - - -def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): - """Clips gradient norm of an iterable of parameters whose gradients are in fp32. - - This is adapted from :func:`torch.nn.utils.clip_grad.clip_grad_norm_` and - added functionality to handle model parallel parameters. - - Note: - the gradients are modified in place. - - Args: - parameters (Iterable[:class:`torch.tensor`] or :class:`torch.tensor`): - An iterable of Tensors or a single Tensor that will have gradients normalized. 
- max_norm (Union[float, int]): Max norm of the gradients. - norm_type (Union[float, int, 'inf']): Type of the used p-norm. Can be ``'inf'`` for infinity norm. - - Returns: - float: Total norm of the parameters. - """ - - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - - # Filter parameters based on: - # - grad should not be none - # - parameter should not be shared - # - should not be a replica due to tensor model parallelism - params: List[Parameter] = [] - has_zero_shared_param: bool = False - for param in parameters: - if param.grad is not None: - # Make sure the grads are in fp32 - assert param.grad.dtype == torch.float, \ - f'expected gradient to be dtype torch.float, but got {param.grad.type()}' - if hasattr(param, 'colo_attr') and param.colo_attr.sharded_data_tensor.is_sharded: - has_zero_shared_param = True - params.append(param) - - if len(params) == 0: - enable_cuda_kernels = False - else: - enable_cuda_kernels = params[0].grad.device.type == 'cuda' - # Norm parameters. - max_norm = float(max_norm) - norm_type = float(norm_type) - - # Parameters can be on CPU or CUDA - # If parameters are on CPU, disable CUDA kernels - - # Calculate norm. - if norm_type == inf: - total_norm = max(p.grad.data.abs().max() for p in params) - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) - # Take max across all model-parallel GPUs. - if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1: - dist.all_reduce(total_norm_cuda, - op=dist.ReduceOp.MAX, - group=gpc.get_group(ParallelMode.MODEL), - async_op=False) - if has_zero_shared_param: - dist.all_reduce(total_norm_cuda, - op=dist.ReduceOp.MAX, - group=gpc.get_group(ParallelMode.DATA), - async_op=False) - total_norm = total_norm_cuda[0].item() - else: - tensor_parallel_grads = [] - no_tensor_parallel_grads = [] - zero_sharded_grads = [] - for p in params: - if is_model_parallel_parameter(p): - reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS))**(1 / norm_type) - tensor_parallel_grads.append(p.grad.data / reductor) - elif hasattr(p, 'colo_attr') and p.colo_attr.sharded_data_tensor.is_sharded: - zero_sharded_grads.append(p.grad.data) - else: - no_tensor_parallel_grads.append(p.grad.data) - - if norm_type == 2.0 and enable_cuda_kernels: - tensor_parallel_norm = _calc_l2_norm(tensor_parallel_grads)**norm_type - no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads)**norm_type - zero_sharded_norm = _calc_l2_norm(zero_sharded_grads)**norm_type - else: - tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type) - no_tensor_parallel_norm = _calc_lp(no_tensor_parallel_grads, norm_type) - zero_sharded_norm = _calc_lp(zero_sharded_grads, norm_type) - # If norm is type of float, then we convert them into torch.Tensor. - tensor_parallel_norm = _get_tensor_norm(tensor_parallel_norm, enable_cuda_kernels) - no_tensor_parallel_norm = _get_tensor_norm(no_tensor_parallel_norm, enable_cuda_kernels) - zero_sharded_norm = _get_tensor_norm(zero_sharded_norm, enable_cuda_kernels) - # If grads are on CPU, the norms is also on CPU. Cast them to CUDA tensors - if not enable_cuda_kernels: - tensor_parallel_norm = _move_norm_to_cuda(tensor_parallel_norm) - no_tensor_parallel_norm = _move_norm_to_cuda(no_tensor_parallel_norm) - zero_sharded_norm = _move_norm_to_cuda(zero_sharded_norm) - - # Sum across all model-parallel GPUs. 
- if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0: - dist.all_reduce(tensor_parallel_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR)) - # Sum across all zero sharded GPUs - if len(zero_sharded_grads) > 0: - dist.all_reduce(zero_sharded_norm, group=gpc.get_group(ParallelMode.DATA)) - no_tensor_parallel_norm += zero_sharded_norm - total_norm = tensor_parallel_norm + no_tensor_parallel_norm - if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1: - dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE)) - total_norm = total_norm**(1.0 / norm_type) - if torch.is_tensor(total_norm): - total_norm = total_norm.item() - - # Scale. - clip_coeff = max_norm / (total_norm + 1.0e-6) - if clip_coeff < 1.0: - if enable_cuda_kernels: - grads = [p.grad.detach() for p in params] - dummy_overflow_buf = torch.cuda.IntTensor([0]) - multi_tensor_applier(fused_optim.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff) - else: - for p in params: - p.grad.detach().mul_(clip_coeff) - return total_norm - - -def count_zeros_fp32(parameters): - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - - # Filter parameters based on: - # - grad should not be none - # - parameter should not be shared - # - should not be a replica due to tensor model parallelism - total_num_zeros = 0.0 - for param in parameters: - grad_not_none = param.grad is not None - is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) - if grad_not_none and is_not_tp_duplicate: - grad = param.grad.detach() - num_zeros = grad.numel() - torch.count_nonzero(grad) - total_num_zeros = num_zeros + total_num_zeros - - total_num_zeros = torch.IntTensor([int(total_num_zeros)]).cuda() - - # Sum across all model-parallel GPUs. 
- ops = [] - ops.append( - dist.all_reduce(total_num_zeros, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR), async_op=True)) - if gpc.is_initialized(ParallelMode.PIPELINE): - ops.append( - dist.all_reduce(total_num_zeros, - op=dist.ReduceOp.SUM, - group=gpc.get_group(ParallelMode.PIPELINE), - async_op=True)) - - for req in ops: - req.wait() - total_num_zeros = total_num_zeros.item() - - return total_num_zeros - - -def copy_tensor_parallel_attributes(src_tensor, dst_tensor): - for attr in TENSOR_PARALLEL_ATTRIBUTES: - if hasattr(src_tensor, attr): - val = getattr(src_tensor, attr) - setattr(dst_tensor, attr, val) - - -def param_is_not_tensor_parallel_duplicate(param): - return (hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL)) or (gpc.get_local_rank( - ParallelMode.TENSOR) == 0) - - -@contextmanager -def switch_virtual_pipeline_parallel_rank(rank): - prev_rank = gpc.virtual_pipeline_parallel_rank - try: - gpc.set_virtual_pipeline_parallel_rank(rank) - yield - finally: - gpc.set_virtual_pipeline_parallel_rank(prev_rank) - - def disposable(func: Callable) -> Callable: executed = False @@ -489,3 +61,9 @@ def _cast_float(args, dtype: torch.dtype): elif isinstance(args, dict): args = {k: _cast_float(v, dtype) for k, v in args.items()} return args + + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) diff --git a/colossalai/utils/cuda.py b/colossalai/utils/cuda.py index 60f3ccb60883..6b5d17cf04e7 100644 --- a/colossalai/utils/cuda.py +++ b/colossalai/utils/cuda.py @@ -1,7 +1,10 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +from typing import Optional + import torch +import torch.distributed as dist def set_to_cuda(models): @@ -23,7 +26,7 @@ def set_to_cuda(models): def get_current_device() -> torch.device: """ Returns currently selected device (gpu/cpu). - If cuda available, return gpu, otherwise return cpu. + If cuda available, return gpu, otherwise return cpu. """ if torch.cuda.is_available(): return torch.device(f'cuda:{torch.cuda.current_device()}') @@ -45,3 +48,9 @@ def empty_cache(): """ if torch.cuda.is_available(): torch.cuda.empty_cache() + + +def set_device(index: Optional[int] = None) -> None: + if index is None: + index = dist.get_rank() % torch.cuda.device_count() + torch.cuda.set_device(index) diff --git a/colossalai/utils/moe.py b/colossalai/utils/moe.py index 86d04c11958b..6456dfb905b0 100644 --- a/colossalai/utils/moe.py +++ b/colossalai/utils/moe.py @@ -1,52 +1,54 @@ -import torch.nn as nn -import torch.distributed as dist -from colossalai.core import global_context as gpc -from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.context import ParallelMode -from .common import is_using_ddp -from typing import Dict, List - - -def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]: - """Returns a parameter dictionary, the key of which is the expert parallel - size of every parameter. Since the parameters in data parallelism is replicated - in each GPU, we set their ep_size to 1. - - Args: - model (:class:`torch.nn.Module`): A pyTorch `nn.Module` from which we get dict. 
- """ - epsize_param_dict = dict() - for param in model.parameters(): - if not hasattr(param, 'moe_info'): - ep_size = 1 # set ep_size to 1 for dp parameters - else: - ep_size = param.moe_info.ep_size - if ep_size not in epsize_param_dict: - epsize_param_dict[ep_size] = [] - epsize_param_dict[ep_size].append(param) - - return epsize_param_dict - - -def sync_moe_model_param(model: nn.Module): - """Make sure model parameters are consistent in MoE parallel context. - - Args: - model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency. - """ - if is_using_ddp(): - - param_dict = get_moe_epsize_param_dict(model) - - # synchronize the parameters whose dp_group is the whole world - if 1 in param_dict: - src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0] - for param in param_dict[1]: - dist.broadcast(param, src=src_rank, group=gpc.get_group(ParallelMode.DATA)) - - for ep_size in param_dict: - # When ep_size = world_size, communication is not needed - if ep_size != 1 and ep_size != MOE_CONTEXT.world_size: - src_rank = dist.get_rank(MOE_CONTEXT.parallel_info_dict[ep_size].ep_group) - for param in param_dict[ep_size]: - dist.broadcast(param, src=src_rank, group=param.moe_info.dp_group) +from typing import Dict, List + +import torch.distributed as dist +import torch.nn as nn + +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.utils import is_using_ddp + + +def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]: + """Returns a parameter dictionary, the key of which is the expert parallel + size of every parameter. Since the parameters in data parallelism is replicated + in each GPU, we set their ep_size to 1. + + Args: + model (:class:`torch.nn.Module`): A pyTorch `nn.Module` from which we get dict. + """ + epsize_param_dict = dict() + for param in model.parameters(): + if not hasattr(param, 'moe_info'): + ep_size = 1 # set ep_size to 1 for dp parameters + else: + ep_size = param.moe_info.ep_size + if ep_size not in epsize_param_dict: + epsize_param_dict[ep_size] = [] + epsize_param_dict[ep_size].append(param) + + return epsize_param_dict + + +def sync_moe_model_param(model: nn.Module): + """Make sure model parameters are consistent in MoE parallel context. + + Args: + model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency. 
+ """ + if is_using_ddp(): + + param_dict = get_moe_epsize_param_dict(model) + + # synchronize the parameters whose dp_group is the whole world + if 1 in param_dict: + src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0] + for param in param_dict[1]: + dist.broadcast(param, src=src_rank, group=gpc.get_group(ParallelMode.DATA)) + + for ep_size in param_dict: + # When ep_size = world_size, communication is not needed + if ep_size != 1 and ep_size != MOE_CONTEXT.world_size: + src_rank = dist.get_rank(MOE_CONTEXT.parallel_info_dict[ep_size].ep_group) + for param in param_dict[ep_size]: + dist.broadcast(param, src=src_rank, group=param.moe_info.dp_group) diff --git a/colossalai/zero/gemini/colo_init_context.py b/colossalai/zero/gemini/colo_init_context.py index dad852a34a71..549635af4332 100644 --- a/colossalai/zero/gemini/colo_init_context.py +++ b/colossalai/zero/gemini/colo_init_context.py @@ -3,7 +3,8 @@ import torch from torch import nn -from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup +from colossalai.legacy.tensor import ProcessGroup +from colossalai.tensor import ColoParameter, ColoTensor from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses # find named_params includes replica diff --git a/colossalai/zero/gemini/memory_tracer/__init__.py b/colossalai/zero/gemini/memory_tracer/__init__.py index 02c9d5754ec9..e1fe904ebf1a 100644 --- a/colossalai/zero/gemini/memory_tracer/__init__.py +++ b/colossalai/zero/gemini/memory_tracer/__init__.py @@ -3,9 +3,8 @@ from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip from .memstats_collector import MemStatsCollector # isort:skip from .chunk_memstats_collector import ChunkMemStatsCollector # isort:skip -from .static_memstats_collector import StaticMemStatsCollector # isort:skip __all__ = [ - 'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector', - 'StaticMemStatsCollector', 'MemStats', 'OrderedParamGenerator' + 'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector', 'MemStats', + 'OrderedParamGenerator' ] diff --git a/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py b/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py index 83903bbf4023..b93ad2c44104 100644 --- a/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py +++ b/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py @@ -1,7 +1,6 @@ from typing import Optional from colossalai.utils import get_current_device -from colossalai.utils.memory import colo_device_memory_capacity from colossalai.zero.gemini.chunk import ChunkManager from .memory_stats import MemStats @@ -33,4 +32,5 @@ def record_model_data_volume(self) -> None: @property def cuda_margin_mem(self) -> float: + from colossalai.legacy.utils.memory import colo_device_memory_capacity return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda diff --git a/colossalai/zero/gemini/memory_tracer/memory_monitor.py b/colossalai/zero/gemini/memory_tracer/memory_monitor.py index 4bb585677d5b..2a65d4b55409 100644 --- a/colossalai/zero/gemini/memory_tracer/memory_monitor.py +++ b/colossalai/zero/gemini/memory_tracer/memory_monitor.py @@ -5,7 +5,7 @@ import torch -from colossalai.utils import colo_device_memory_used, get_current_device +from colossalai.utils import get_current_device class MemoryMonitor: @@ -110,6 +110,7 @@ def finish(self): return max_usage def _measure_usage(self): + from colossalai.legacy.utils 
import colo_device_memory_used max_usage = 0 while self.keep_measuring: max_usage = max( diff --git a/colossalai/zero/gemini/memory_tracer/memstats_collector.py b/colossalai/zero/gemini/memory_tracer/memstats_collector.py index 0694be48550a..abb3dcc74b27 100644 --- a/colossalai/zero/gemini/memory_tracer/memstats_collector.py +++ b/colossalai/zero/gemini/memory_tracer/memstats_collector.py @@ -70,7 +70,7 @@ def record_model_data_volume(self) -> None: Sampling model data statistics. """ if self._start_flag and not self.use_outside_memstats: - from colossalai.zero.legacy.gemini import StatefulTensor + from colossalai.legacy.zero.gemini import StatefulTensor # The following code work for ZeroInitContext, which is deprecated in v0.1.12 cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda'] diff --git a/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py b/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py index e5466965cc48..6656821fef74 100644 --- a/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py +++ b/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py @@ -1,12 +1,12 @@ import torch.nn -from colossalai.tensor.param_op_hook import ColoParamOpHookManager -from colossalai.utils import _cast_float -from colossalai.zero.legacy.gemini.ophooks.runtime_mem_tracer_hook import ( +from colossalai.legacy.zero.gemini.ophooks.runtime_mem_tracer_hook import ( GradMemStats, GradMemTracerHook, ParamMemTracerHook, ) +from colossalai.tensor.param_op_hook import ColoParamOpHookManager +from colossalai.utils import _cast_float from .memory_stats import MemStats diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index cd775da5e11f..a35529723a68 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -6,8 +6,8 @@ import torch +from colossalai.legacy.utils.memory import colo_device_memory_capacity from colossalai.utils import get_current_device -from colossalai.utils.memory import colo_device_memory_capacity from colossalai.zero.gemini.chunk import Chunk from .chunk import Chunk, ChunkManager diff --git a/colossalai/zero/low_level/_utils.py b/colossalai/zero/low_level/_utils.py index 4205a9891534..ece92fe02e28 100644 --- a/colossalai/zero/low_level/_utils.py +++ b/colossalai/zero/low_level/_utils.py @@ -7,9 +7,6 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from torch.distributed import ProcessGroup -from colossalai.tensor import ColoParameter -from colossalai.utils import is_model_parallel_parameter - def flatten(input_): return _flatten_dense_tensors(input_) diff --git a/docs/README.md b/docs/README.md index f0cb50ffe217..a5ae2ce96a99 100644 --- a/docs/README.md +++ b/docs/README.md @@ -108,5 +108,5 @@ We support `autodoc` to extract the docstring and transform it into a Web elemen You just need to add `{{ autodoc: }}` in your markdown as a single line. An example is given below and you can see the outcome in [this PR](https://github.com/hpcaitech/ColossalAI-Documentation/pull/175). 
```markdown -{{ autodoc:colossalai.amp.apex_amp.convert_to_apex_amp }} +{{ autodoc:colossalai.legacy.amp.apex_amp.convert_to_apex_amp }} ``` diff --git a/docs/source/en/advanced_tutorials/add_your_parallel.md b/docs/source/en/advanced_tutorials/add_your_parallel.md index 384221596885..63434a526228 100644 --- a/docs/source/en/advanced_tutorials/add_your_parallel.md +++ b/docs/source/en/advanced_tutorials/add_your_parallel.md @@ -31,7 +31,7 @@ global context for users to easily manage their process groups. If you wish to a define a new class and set it in your configuration file. To define your own way of creating process groups, you can follow the steps below to create a new distributed initialization. -1. Add your parallel mode in `colossalai.context.parallel_mode.ParallelMode`. +1. Add your parallel mode in `colossalai.legacy.context.parallel_mode.ParallelMode`. ```python class ParallelMode(Enum): GLOBAL = 'global' diff --git a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md index 36c94fb492cd..0218264cc258 100644 --- a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md +++ b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md @@ -37,7 +37,7 @@ import torch.nn as nn from colossalai import nn as col_nn from colossalai.amp import AMP_TYPE from colossalai.legacy.builder.pipeline import partition_uniform -from colossalai.context.parallel_mode import ParallelMode +from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule, PipelineSchedule) diff --git a/docs/source/en/basics/command_line_tool.md b/docs/source/en/basics/command_line_tool.md index 48b199cf78e9..4c278aaa0c6a 100644 --- a/docs/source/en/basics/command_line_tool.md +++ b/docs/source/en/basics/command_line_tool.md @@ -30,24 +30,4 @@ This command will inform you information regarding the version compatibility and To launch distributed jobs on single or multiple nodes, the command `colossalai run` can be used for process launching. You may refer to [Launch Colossal-AI](./launch_colossalai.md) for more details. -## Tensor Parallel Micro-Benchmarking - -As Colossal-AI provides an array of tensor parallelism methods, it is not intuitive to choose one for your hardware and -model. Therefore, we provide a simple benchmarking to evaluate the performance of various tensor parallelisms on your system. -This benchmarking is run on a simple MLP model where the input data is of the shape `(batch_size, seq_length, hidden_size)`. -Based on the number of GPUs, the CLI will look for all possible tensor parallel configurations and display the benchmarking results. -You can customize the benchmarking configurations by checking out `colossalai benchmark --help`. - -```shell -# run on 4 GPUs -colossalai benchmark --gpus 4 - -# run on 8 GPUs -colossalai benchmark --gpus 8 -``` - -:::caution - -Only single-node benchmarking is supported currently. 
- -::: + diff --git a/docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md b/docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md index c4b0f6557926..812b9c34e4da 100644 --- a/docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md +++ b/docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md @@ -24,7 +24,7 @@ 并行通常由进程组来管理,参与相同并行算法的进程被置于同一进程组。对于不同的并行算法,需要创建不同的进程组。 Colossal-AI 为用户提供了一个全局 context,使他们能够轻松地管理进程组。如果你想添加新的进程组,你可以很容易地定义一个新的类并在你的配置文件中设置它。为了定义你自己的进程组创建方式,你可以按照下面的步骤来创建一个新的分布式初始化。 -1. 在 `colossalai.context.parallel_mode.ParallelMode` 中添加你自己的并行模式。 +1. 在 `colossalai.legacy.context.parallel_mode.ParallelMode` 中添加你自己的并行模式。 ```python class ParallelMode(Enum): GLOBAL = 'global' diff --git a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md index 3f57f39f2838..a1d58e9fddc2 100644 --- a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md +++ b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md @@ -37,7 +37,7 @@ import torch.nn as nn from colossalai import nn as col_nn from colossalai.amp import AMP_TYPE from colossalai.legacy.builder.pipeline import partition_uniform -from colossalai.context.parallel_mode import ParallelMode +from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule, PipelineSchedule) diff --git a/docs/source/zh-Hans/basics/command_line_tool.md b/docs/source/zh-Hans/basics/command_line_tool.md index 9b0275a6cedd..5c4c18989c17 100644 --- a/docs/source/zh-Hans/basics/command_line_tool.md +++ b/docs/source/zh-Hans/basics/command_line_tool.md @@ -26,22 +26,4 @@ Colossal-AI给用户提供了命令行工具,目前命令行工具可以用来 在分布式训练时,我们可以使用`colossalai run`来启动单节点或者多节点的多进程,详细的内容可以参考[启动 Colossal-AI](./launch_colossalai.md)。 -## 张量并行基准测试 - -Colossal-AI提供了多种张量并行,想要充分理解这些方法需要一定的学习成本,对于新手来说很难靠经验选择一个并行方式。 -所以我们提供了一个简单的基准测试,能够让用户在自己的机器上测试不同张量并行的性能。这个基准测试跑一个并行的MLP模型, -输入数据的维度为`(批大小,序列长度,隐藏层维度)`。通过指定GPU的数量,Colossal-AI会搜索所有可行的并行配置。用户可以通过查看`colossalai benchmark --help`来自定义相关的测试参数。 - -```shell -# 使用4个GPU -colossalai benchmark --gpus 4 - -# 使用8个GPU -colossalai benchmark --gpus 8 -``` - -:::caution - -目前仅支持单节点的基准测试。 - -::: + diff --git a/examples/community/roberta/pretraining/pretrain_utils.py b/examples/community/roberta/pretraining/pretrain_utils.py index cea6ac2c36e5..e6a393a57dda 100644 --- a/examples/community/roberta/pretraining/pretrain_utils.py +++ b/examples/community/roberta/pretraining/pretrain_utils.py @@ -16,7 +16,7 @@ get_linear_schedule_with_warmup, ) -from colossalai.core import global_context as gpc +from colossalai.legacy.core import global_context as gpc from colossalai.nn.lr_scheduler import LinearWarmupLR from colossalai.nn.optimizer import FusedAdam, HybridAdam diff --git a/examples/community/roberta/pretraining/run_pretraining.py b/examples/community/roberta/pretraining/run_pretraining.py index 53fa9f489c10..fa6457cab328 100644 --- a/examples/community/roberta/pretraining/run_pretraining.py +++ b/examples/community/roberta/pretraining/run_pretraining.py @@ -17,7 +17,7 @@ import colossalai from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.core import global_context as gpc from colossalai.nn.parallel import GeminiDDP, zero_model_wrapper, zero_optim_wrapper from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, 
ProcessGroup, ReplicaSpec, ShardSpec from colossalai.utils import get_current_device diff --git a/examples/community/roberta/pretraining/utils/exp_util.py b/examples/community/roberta/pretraining/utils/exp_util.py index 4a2c9d8a47ad..1fcaa428b277 100644 --- a/examples/community/roberta/pretraining/utils/exp_util.py +++ b/examples/community/roberta/pretraining/utils/exp_util.py @@ -5,7 +5,7 @@ import psutil import torch -from colossalai.core import global_context as gpc +from colossalai.legacy.core import global_context as gpc def logging(s, log_path, print_=True, log_=True): diff --git a/examples/images/dreambooth/test_ci.sh b/examples/images/dreambooth/test_ci.sh index 84345f589bb5..b0a96ec70075 100644 --- a/examples/images/dreambooth/test_ci.sh +++ b/examples/images/dreambooth/test_ci.sh @@ -1,24 +1,26 @@ #!/bin/bash set -xe -pip install -r requirements.txt +echo "this test is slow" -HF_DATASETS_OFFLINE=1 -TRANSFORMERS_OFFLINE=1 -DIFFUSERS_OFFLINE=1 +# pip install -r requirements.txt -# "torch_ddp" "torch_ddp_fp16" "low_level_zero" -for plugin in "gemini"; do - torchrun --nproc_per_node 4 --standalone train_dreambooth_colossalai.py \ - --pretrained_model_name_or_path="/data/dreambooth/diffuser/stable-diffusion-v1-4" \ - --instance_data_dir="/data/dreambooth/Teyvat/data" \ - --output_dir="./weight_output" \ - --instance_prompt="a picture of a dog" \ - --resolution=512 \ - --plugin=$plugin \ - --train_batch_size=1 \ - --learning_rate=5e-6 \ - --lr_scheduler="constant" \ - --lr_warmup_steps=0 \ - --test_run=True \ - --num_class_images=200 -done +# HF_DATASETS_OFFLINE=1 +# TRANSFORMERS_OFFLINE=1 +# DIFFUSERS_OFFLINE=1 + +# # "torch_ddp" "torch_ddp_fp16" "low_level_zero" +# for plugin in "gemini"; do +# torchrun --nproc_per_node 4 --standalone train_dreambooth_colossalai.py \ +# --pretrained_model_name_or_path="/data/dreambooth/diffuser/stable-diffusion-v1-4" \ +# --instance_data_dir="/data/dreambooth/Teyvat/data" \ +# --output_dir="./weight_output" \ +# --instance_prompt="a picture of a dog" \ +# --resolution=512 \ +# --plugin=$plugin \ +# --train_batch_size=1 \ +# --learning_rate=5e-6 \ +# --lr_scheduler="constant" \ +# --lr_warmup_steps=0 \ +# --test_run=True \ +# --num_class_images=200 +# don diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py index f60704650b7e..9b2ed3b971ae 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai.py @@ -7,6 +7,7 @@ from typing import Optional import torch +import torch.distributed as dist import torch.nn.functional as F import torch.utils.checkpoint from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel @@ -21,13 +22,9 @@ import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam from colossalai.utils import get_current_device -from colossalai.zero import ColoInitContext -from colossalai.zero.gemini import get_static_torch_model disable_existing_loggers() logger = get_dist_logger() @@ -366,8 +363,8 @@ def main(args): else: colossalai.launch_from_torch(config={}, seed=args.seed) - local_rank = gpc.get_local_rank(ParallelMode.DATA) - world_size = 
gpc.get_world_size(ParallelMode.DATA) + local_rank = dist.get_rank() + world_size = dist.get_world_size() if args.with_prior_preservation: class_images_dir = Path(args.class_data_dir) diff --git a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py index c98950fd795d..654bce36ccb7 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py @@ -23,8 +23,8 @@ import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam from colossalai.utils import get_current_device diff --git a/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py b/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py index e331fc8fcf10..84b02633e775 100644 --- a/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py +++ b/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py @@ -7,8 +7,8 @@ from gpt_modules import GPT2LMHeadModel, GPTLMLoss from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize -from colossalai.core import global_context as gpc from colossalai.initialize import launch_from_torch +from colossalai.legacy.core import global_context as gpc from colossalai.logging import disable_existing_loggers, get_dist_logger BATCH_SIZE = 16 diff --git a/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py b/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py index ad69888b8cc8..30d6aab4f12f 100644 --- a/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py +++ b/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py @@ -3,7 +3,6 @@ from functools import partial import torch -from model_zoo import model_builder from torch import nn from tqdm import tqdm @@ -14,11 +13,12 @@ split_with_split_nodes_pass, ) from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.legacy.pipeline.middleware.adaptor import get_fx_topology +from colossalai.legacy.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine +from colossalai.legacy.pipeline.rpc.utils import rpc_run from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam -from colossalai.pipeline.middleware.adaptor import get_fx_topology -from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine -from colossalai.pipeline.rpc.utils import rpc_run +from model_zoo import model_builder def parse_args(): diff --git a/examples/language/gpt/gemini/run_gemini.sh b/examples/language/gpt/gemini/run_gemini.sh index 57ce6ab64c5b..5eaa4af4df78 100644 --- a/examples/language/gpt/gemini/run_gemini.sh +++ b/examples/language/gpt/gemini/run_gemini.sh @@ -9,11 +9,6 @@ export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"} export TRAIN_STEP=${TRAIN_STEP:-10} # export PYTHONPATH=$PWD:$PYTHONPATH -if [ ${USE_SHARD_INIT} = "True" ]; then - USE_SHARD_INIT="--shardinit" -else - USE_SHARD_INIT="" -fi mkdir -p gemini_logs 
@@ -22,4 +17,4 @@ torchrun --standalone --nproc_per_node=${GPUNUM} ./train_gpt_demo.py \ --batch_size=${BATCH_SIZE} \ --distplan=${DISTPLAN} \ --train_step=${TRAIN_STEP} \ -2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPLAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}_${PLACEMENT}.log +2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPLAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}.log diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py index 347251ca5631..f9d30fd15c7b 100644 --- a/examples/language/gpt/gemini/train_gpt_demo.py +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -1,3 +1,4 @@ +import argparse import os from contextlib import nullcontext from functools import partial @@ -9,7 +10,6 @@ from commons.model_zoo import model_builder from commons.utils import get_data, get_profile_context, get_tflops, get_time_stamp from packaging import version -from torch.nn.parallel import DistributedDataParallel as DDP import colossalai from colossalai.booster import Booster @@ -23,7 +23,7 @@ def parse_args(): - parser = colossalai.get_default_parser() + parser = argparse.ArgumentParser() parser.add_argument( "--distplan", type=str, diff --git a/examples/language/gpt/test_ci.sh b/examples/language/gpt/test_ci.sh index b9e4e43a8d35..db742220d97e 100644 --- a/examples/language/gpt/test_ci.sh +++ b/examples/language/gpt/test_ci.sh @@ -2,4 +2,4 @@ set -x pip install -r requirements.txt cd gemini && bash test_ci.sh -cd ../hybridparallelism && bash run.sh +# cd ../hybridparallelism && bash run.sh diff --git a/examples/language/gpt/titans/model/embed.py b/examples/language/gpt/titans/model/embed.py index e521193a97da..a6c80394c50f 100644 --- a/examples/language/gpt/titans/model/embed.py +++ b/examples/language/gpt/titans/model/embed.py @@ -6,8 +6,8 @@ from torch.nn import functional as F from torch.nn.parameter import Parameter -from colossalai.context import ParallelMode, seed -from colossalai.core import global_context as gpc +from colossalai.legacy.context import ParallelMode, seed +from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.base_layer import ParallelLayer from colossalai.legacy.nn.layer.parallel_1d._utils import gather_forward_split_backward, reduce_grad, reduce_input from colossalai.legacy.nn.layer.parallel_1d.layers import Linear1D_Row diff --git a/examples/language/gpt/titans/model/gpt1d.py b/examples/language/gpt/titans/model/gpt1d.py index 72297c540da1..746acbf7dccd 100644 --- a/examples/language/gpt/titans/model/gpt1d.py +++ b/examples/language/gpt/titans/model/gpt1d.py @@ -9,13 +9,13 @@ from colossalai import kernel from colossalai import nn as col_nn -from colossalai.core import global_context as gpc from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType +from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer import Linear1D_Col, Linear1D_Row from colossalai.legacy.nn.layer.base_layer import ParallelLayer from colossalai.legacy.nn.layer.utils import ACT2FN, divide +from colossalai.legacy.utils.activation_checkpoint import checkpoint from colossalai.utils import checkpoint -from colossalai.utils.activation_checkpoint import checkpoint __all__ = [ 'GPTMLP1D', 'GPTSelfAttention1D', 'GPTTransformerLayer1D', 'FusedGPTSelfAttention1D', 'FusedGPTTransformerLayer1D' diff --git a/examples/language/gpt/titans/model/pipeline_gpt1d.py b/examples/language/gpt/titans/model/pipeline_gpt1d.py index 9b22d156bbcd..a9da246faf82 100644 --- 
a/examples/language/gpt/titans/model/pipeline_gpt1d.py +++ b/examples/language/gpt/titans/model/pipeline_gpt1d.py @@ -7,11 +7,11 @@ from colossalai import kernel from colossalai import nn as col_nn -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.wrapper import PipelineSharedModuleWrapper +from colossalai.legacy.pipeline.utils import partition_uniform from colossalai.logging import get_dist_logger -from colossalai.pipeline.utils import partition_uniform from .embed import HiddenParallelEmbedding, HiddenParallelGPTLMHead1D, VocabParallelEmbedding, VocabParallelGPTLMHead1D from .gpt1d import FusedGPTTransformerLayer1D, GPTTransformerLayer1D diff --git a/examples/language/gpt/titans/train_gpt.py b/examples/language/gpt/titans/train_gpt.py index b239b626c07f..3ed18b21fff5 100644 --- a/examples/language/gpt/titans/train_gpt.py +++ b/examples/language/gpt/titans/train_gpt.py @@ -8,14 +8,14 @@ import colossalai import colossalai.utils as utils -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.legacy.trainer import Trainer, hooks +from colossalai.legacy.zero.init_ctx import ZeroInitContext from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn import LinearWarmupLR from colossalai.utils import colo_set_process_memory_fraction, is_using_pp from colossalai.utils.timer import MultiTimer -from colossalai.zero.legacy.init_ctx import ZeroInitContext def calc_local_model_size(model: torch.nn.Module): diff --git a/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py b/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py index a6a9ad0a312c..33aa5990f7c1 100644 --- a/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py +++ b/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py @@ -4,8 +4,8 @@ import colossalai from colossalai.auto_parallel.tensor_shard.initialize import initialize_model -from colossalai.core import global_context as gpc from colossalai.device.device_mesh import DeviceMesh +from colossalai.legacy.core import global_context as gpc from colossalai.logging import get_dist_logger from colossalai.nn.lr_scheduler import CosineAnnealingLR diff --git a/examples/tutorial/auto_parallel/test_ci.sh b/examples/tutorial/auto_parallel/test_ci.sh index bf6275b673ff..b27e36217117 100644 --- a/examples/tutorial/auto_parallel/test_ci.sh +++ b/examples/tutorial/auto_parallel/test_ci.sh @@ -1,6 +1,8 @@ #!/bin/bash set -euxo pipefail -pip install -r requirements.txt -conda install -c conda-forge coin-or-cbc -colossalai run --nproc_per_node 4 auto_parallel_with_resnet.py +echo "this test is outdated" + +# pip install -r requirements.txt +# conda install -c conda-forge coin-or-cbc +# colossalai run --nproc_per_node 4 auto_parallel_with_resnet.py diff --git a/examples/tutorial/hybrid_parallel/config.py b/examples/tutorial/hybrid_parallel/config.py index fe9abf2f1955..287f62aa7a90 100644 --- a/examples/tutorial/hybrid_parallel/config.py +++ b/examples/tutorial/hybrid_parallel/config.py @@ -1,4 +1,4 @@ -from colossalai.amp import AMP_TYPE +from colossalai.legacy.amp import AMP_TYPE # hyperparameters # BATCH_SIZE is as per GPU diff --git 
a/examples/tutorial/hybrid_parallel/train.py b/examples/tutorial/hybrid_parallel/train.py index 12cdec902400..21a568168e33 100644 --- a/examples/tutorial/hybrid_parallel/train.py +++ b/examples/tutorial/hybrid_parallel/train.py @@ -5,12 +5,12 @@ from tqdm import tqdm import colossalai -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn import CrossEntropyLoss +from colossalai.legacy.pipeline.pipelinable import PipelinableContext from colossalai.logging import get_dist_logger from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR -from colossalai.pipeline.pipelinable import PipelinableContext from colossalai.utils import is_using_pp diff --git a/examples/tutorial/large_batch_optimizer/config.py b/examples/tutorial/large_batch_optimizer/config.py index 2efa0ffd0556..c6d9f94505f1 100644 --- a/examples/tutorial/large_batch_optimizer/config.py +++ b/examples/tutorial/large_batch_optimizer/config.py @@ -1,4 +1,4 @@ -from colossalai.amp import AMP_TYPE +from colossalai.legacy.amp import AMP_TYPE # hyperparameters # BATCH_SIZE is as per GPU diff --git a/examples/tutorial/large_batch_optimizer/test_ci.sh b/examples/tutorial/large_batch_optimizer/test_ci.sh index 89f426c542b1..f4393938220d 100644 --- a/examples/tutorial/large_batch_optimizer/test_ci.sh +++ b/examples/tutorial/large_batch_optimizer/test_ci.sh @@ -1,8 +1,9 @@ #!/bin/bash set -euxo pipefail +echo "this test is outdated" -pip install -r requirements.txt +# pip install -r requirements.txt # run test -colossalai run --nproc_per_node 4 --master_port 29500 train.py --config config.py --optimizer lars -colossalai run --nproc_per_node 4 --master_port 29501 train.py --config config.py --optimizer lamb +# colossalai run --nproc_per_node 4 --master_port 29500 train.py --config config.py --optimizer lars +# colossalai run --nproc_per_node 4 --master_port 29501 train.py --config config.py --optimizer lamb diff --git a/examples/tutorial/large_batch_optimizer/train.py b/examples/tutorial/large_batch_optimizer/train.py index 35e54582f494..6ebd8d68083d 100644 --- a/examples/tutorial/large_batch_optimizer/train.py +++ b/examples/tutorial/large_batch_optimizer/train.py @@ -4,7 +4,7 @@ from tqdm import tqdm import colossalai -from colossalai.core import global_context as gpc +from colossalai.legacy.core import global_context as gpc from colossalai.logging import get_dist_logger from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.optimizer import Lamb, Lars diff --git a/examples/tutorial/opt/opt/colossalai_zero.py b/examples/tutorial/opt/opt/colossalai_zero.py index 7c2c152450c5..8fbed6e83d52 100644 --- a/examples/tutorial/opt/opt/colossalai_zero.py +++ b/examples/tutorial/opt/opt/colossalai_zero.py @@ -2,7 +2,7 @@ from colossalai.zero.shard_utils import TensorShardStrategy except ImportError: # colossalai > 0.2.8 - from colossalai.zero.legacy import TensorShardStrategy + from colossalai.legacy.zero import TensorShardStrategy zero = dict(model_config=dict(shard_strategy=TensorShardStrategy(), tensor_placement_policy="auto", diff --git a/examples/tutorial/opt/opt/context.py b/examples/tutorial/opt/opt/context.py index 95f0abf1d8c9..dfcd3b382d3c 100644 --- a/examples/tutorial/opt/opt/context.py +++ b/examples/tutorial/opt/opt/context.py @@ -1,7 +1,7 @@ import torch.distributed as dist -from colossalai.context import ParallelMode -from 
colossalai.core import global_context as gpc +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc class barrier_context(): diff --git a/examples/tutorial/opt/opt/run_clm.py b/examples/tutorial/opt/opt/run_clm.py index 91380e243fb8..8cbf3d2a2850 100755 --- a/examples/tutorial/opt/opt/run_clm.py +++ b/examples/tutorial/opt/opt/run_clm.py @@ -51,12 +51,13 @@ from transformers.utils.versions import require_version import colossalai -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.tensor import ProcessGroup +from colossalai.legacy.utils import get_dataloader from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam -from colossalai.tensor import ProcessGroup -from colossalai.utils import get_current_device, get_dataloader +from colossalai.utils import get_current_device from colossalai.zero import GeminiOptimizer require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") diff --git a/examples/tutorial/opt/opt/test_ci.sh b/examples/tutorial/opt/opt/test_ci.sh index 431b37c12004..9cbc49c7b001 100755 --- a/examples/tutorial/opt/opt/test_ci.sh +++ b/examples/tutorial/opt/opt/test_ci.sh @@ -1,21 +1,21 @@ #!/bin/bash set -xue +echo "this test is outdated" +# pip install -r requirements.txt -pip install -r requirements.txt +# BS=4 +# MEMCAP=0 +# GPUNUM=4 +# MODLE="facebook/opt-125m" -BS=4 -MEMCAP=0 -GPUNUM=4 -MODLE="facebook/opt-125m" - -torchrun \ - --nproc_per_node ${GPUNUM} \ - --master_port 19198 \ - run_clm.py \ - -s \ - --output_dir $PWD \ - --mem_cap ${MEMCAP} \ - --model_name_or_path ${MODLE} \ - --per_device_train_batch_size ${BS} \ - --num_train_epochs 1 +# torchrun \ +# --nproc_per_node ${GPUNUM} \ +# --master_port 19198 \ +# run_clm.py \ +# -s \ +# --output_dir $PWD \ +# --mem_cap ${MEMCAP} \ +# --model_name_or_path ${MODLE} \ +# --per_device_train_batch_size ${BS} \ +# --num_train_epochs 1 diff --git a/examples/tutorial/sequence_parallel/config.py b/examples/tutorial/sequence_parallel/config.py index 6edf9cc2c7e5..887de7164e12 100644 --- a/examples/tutorial/sequence_parallel/config.py +++ b/examples/tutorial/sequence_parallel/config.py @@ -1,4 +1,4 @@ -from colossalai.amp import AMP_TYPE +from colossalai.legacy.amp import AMP_TYPE # hyper-parameters TRAIN_ITERS = 10 diff --git a/examples/tutorial/sequence_parallel/data/__init__.py b/examples/tutorial/sequence_parallel/data/__init__.py index 1ef2d999389f..6fdf07ba5b69 100644 --- a/examples/tutorial/sequence_parallel/data/__init__.py +++ b/examples/tutorial/sequence_parallel/data/__init__.py @@ -1,10 +1,12 @@ -from colossalai.context.parallel_context import ParallelContext -from colossalai.core import global_context as gpc +import torch + +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.context.parallel_context import ParallelContext +from colossalai.legacy.core import global_context as gpc from colossalai.logging import get_dist_logger -from colossalai.context import ParallelMode -from .datasets.data_samplers import build_pretraining_data_loader + from .datasets.builder import build_train_valid_test_datasets -import torch +from .datasets.data_samplers import build_pretraining_data_loader def cyclic_iter(iter): @@ -18,8 +20,7 @@ def 
build_train_valid_test_data_iterators(train_iters, eval_interval, eval_iters, dataloader_type='single', - **kwargs - ): + **kwargs): (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None) logger = get_dist_logger() @@ -42,9 +43,7 @@ def build_train_valid_test_data_iterators(train_iters, train_samples = train_iters * global_batch_size eval_iters_ = (train_iters // eval_interval + 1) * eval_iters test_iters = eval_iters - train_val_test_num_samples = [train_samples, - eval_iters_ * global_batch_size, - test_iters * global_batch_size] + train_val_test_num_samples = [train_samples, eval_iters_ * global_batch_size, test_iters * global_batch_size] logger.info(' > datasets target sizes (minimum size):') logger.info(' train: {}'.format(train_val_test_num_samples[0]), ranks=[0]) logger.info(' validation: {}'.format(train_val_test_num_samples[1]), ranks=[0]) @@ -56,19 +55,20 @@ def build_train_valid_test_data_iterators(train_iters, # Build dataloaders. dp_size = gpc.get_world_size(ParallelMode.DATA) - train_dataloader = build_pretraining_data_loader( - train_ds, consumed_samples=0, micro_batch_size=global_batch_size//dp_size) - valid_dataloader = build_pretraining_data_loader( - valid_ds, consumed_samples=0, micro_batch_size=global_batch_size//dp_size) - test_dataloader = build_pretraining_data_loader(test_ds, 0, micro_batch_size=global_batch_size//dp_size) + train_dataloader = build_pretraining_data_loader(train_ds, + consumed_samples=0, + micro_batch_size=global_batch_size // dp_size) + valid_dataloader = build_pretraining_data_loader(valid_ds, + consumed_samples=0, + micro_batch_size=global_batch_size // dp_size) + test_dataloader = build_pretraining_data_loader(test_ds, 0, micro_batch_size=global_batch_size // dp_size) # Flags to know if we need to do training/validation/testing. do_train = train_dataloader is not None and train_iters > 0 do_valid = valid_dataloader is not None and eval_iters > 0 do_test = test_dataloader is not None and eval_iters > 0 # Need to broadcast num_tokens and num_type_tokens. - flags = torch.cuda.LongTensor( - [int(do_train), int(do_valid), int(do_test)]) + flags = torch.cuda.LongTensor([int(do_train), int(do_valid), int(do_test)]) else: flags = torch.cuda.LongTensor([0, 0, 0]) diff --git a/examples/tutorial/sequence_parallel/data/bert_helper.py b/examples/tutorial/sequence_parallel/data/bert_helper.py index d092db3e7dd8..b65ca1e64f3c 100644 --- a/examples/tutorial/sequence_parallel/data/bert_helper.py +++ b/examples/tutorial/sequence_parallel/data/bert_helper.py @@ -1,7 +1,8 @@ -from colossalai.core import global_context as gpc -from colossalai.context import ParallelMode import torch +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc + _MAX_DATA_DIM = 5 @@ -22,7 +23,8 @@ def _build_key_size_numel_dictionaries(keys, data): # Move to GPU and broadcast. sizes_cuda = torch.cuda.LongTensor(sizes) - torch.distributed.broadcast(sizes_cuda, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], + torch.distributed.broadcast(sizes_cuda, + gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], group=gpc.get_group(ParallelMode.TENSOR)) # Move back to cpu and unpack. @@ -60,19 +62,15 @@ def broadcast_data(keys, data, datatype): """ # Build (key, size) and (key, number of elements) dictionaries along # with the total number of elements on all ranks. 
- key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, - data) + key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, data) # Pack on rank zero. if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0: # Check that all keys have the same data type. # Flatten the data associated with the keys - flatten_data = torch.cat( - [data[key].contiguous().view(-1) for key in keys], dim=0).cuda() + flatten_data = torch.cat([data[key].contiguous().view(-1) for key in keys], dim=0).cuda() else: - flatten_data = torch.empty(total_numel, - device=torch.cuda.current_device(), - dtype=datatype) + flatten_data = torch.empty(total_numel, device=torch.cuda.current_device(), dtype=datatype) # Broadcast torch.distributed.broadcast(flatten_data, @@ -139,7 +137,7 @@ def get_batch_for_sequence_parallel(data_iterator): seq_length = data_b['text'].size(1) sub_seq_length = seq_length // local_world_size sub_seq_start = local_rank * sub_seq_length - sub_seq_end = (local_rank+1) * sub_seq_length + sub_seq_end = (local_rank + 1) * sub_seq_length # # # Unpack. tokens = data_b['text'][:, sub_seq_start:sub_seq_end].long() @@ -156,10 +154,9 @@ class SequenceParallelDataIterator: def __init__(self, data_iter): self.data_iter = data_iter - def __iter__(self): return self.data_iter def __next__(self): - return get_batch_for_sequence_parallel(self.data_iter) \ No newline at end of file + return get_batch_for_sequence_parallel(self.data_iter) diff --git a/examples/tutorial/sequence_parallel/data/datasets/bert_dataset.py b/examples/tutorial/sequence_parallel/data/datasets/bert_dataset.py index d6388bd9f8e4..70c1269122dc 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/bert_dataset.py +++ b/examples/tutorial/sequence_parallel/data/datasets/bert_dataset.py @@ -21,8 +21,8 @@ import torch from torch.utils.data import Dataset -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.logging import get_dist_logger from ..tokenizer import get_tokenizer diff --git a/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py b/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py index cf547ad97558..b9c197c95ae3 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py +++ b/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py @@ -14,10 +14,12 @@ # limitations under the License. """Dataloaders.""" -import torch import random -from colossalai.core import global_context as gpc -from colossalai.context import ParallelMode + +import torch + +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc def build_pretraining_data_loader(dataset, consumed_samples, micro_batch_size, dataloader_type='single', num_workers=0): diff --git a/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py b/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py index ee3c923e8e76..ba832b5cdce9 100644 --- a/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py +++ b/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py @@ -12,13 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- """Megatron tokenizers.""" -from abc import ABC -from abc import abstractmethod -from colossalai.core import global_context as gpc -from colossalai.context import ParallelMode +from abc import ABC, abstractmethod + +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc from .bert_tokenization import FullTokenizer as FullBertTokenizer @@ -26,18 +25,13 @@ def build_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids=0): """Initialize tokenizer.""" if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0: - print('> building {} tokenizer ...'.format(tokenizer_type), - flush=True) + print('> building {} tokenizer ...'.format(tokenizer_type), flush=True) # Select and instantiate the tokenizer. if tokenizer_type == 'BertWordPieceLowerCase': - tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, - lower_case=True, - vocab_extra_ids=vocab_extra_ids) + tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, lower_case=True, vocab_extra_ids=vocab_extra_ids) elif tokenizer_type == 'BertWordPieceCase': - tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, - lower_case=False, - vocab_extra_ids=vocab_extra_ids) + tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, lower_case=False, vocab_extra_ids=vocab_extra_ids) else: raise NotImplementedError('{} tokenizer is not ' 'implemented.'.format(tokenizer_type)) @@ -62,8 +56,8 @@ def _vocab_size_with_padding(orig_vocab_size, make_vocab_size_divisible_by=128): after += 1 if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0: print(' > padded vocab (size: {}) with {} dummy tokens ' - '(new size: {})'.format( - orig_vocab_size, after - orig_vocab_size, after), flush=True) + '(new size: {})'.format(orig_vocab_size, after - orig_vocab_size, after), + flush=True) return after @@ -142,8 +136,7 @@ def __init__(self, vocab_file, lower_case=True, vocab_extra_ids=0): self._additional_special_tokens = [] # (dsachan) Add BOS and EOS tokens - SPECIAL_TOKENS = {'eos_token': '[EOS]', - 'bos_token': '[BOS]'} + SPECIAL_TOKENS = {'eos_token': '[EOS]', 'bos_token': '[BOS]'} self._bos_token = '[BOS]' self.add_token(self._bos_token) self._bos_token_id = self.vocab.get(self._bos_token) @@ -155,8 +148,7 @@ def __init__(self, vocab_file, lower_case=True, vocab_extra_ids=0): # (dsachan) Add additional special tokens # These can be used as sentinel tokens in T5 model inputs additional_special_tokens = [] - additional_special_tokens.extend( - ["".format(i) for i in range(vocab_extra_ids)]) + additional_special_tokens.extend(["".format(i) for i in range(vocab_extra_ids)]) self.add_additional_special_tokens(additional_special_tokens) def add_token(self, token): diff --git a/examples/tutorial/sequence_parallel/loss_func/bert_loss.py b/examples/tutorial/sequence_parallel/loss_func/bert_loss.py index e87a778cf5d5..b3f2487a438b 100644 --- a/examples/tutorial/sequence_parallel/loss_func/bert_loss.py +++ b/examples/tutorial/sequence_parallel/loss_func/bert_loss.py @@ -1,37 +1,29 @@ import torch +import torch.distributed as dist import torch.nn as nn -from colossalai.core import global_context as gpc -from colossalai.context import ParallelMode -from colossalai.logging import get_dist_logger import torch.nn.functional as F -import torch.distributed as dist + +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.logging import get_dist_logger + from .cross_entropy import vocab_cross_entropy class 
BertLoss(nn.Module): - def forward(self, - lm_loss, - sop_logits, - loss_mask, - sentence_order): + def forward(self, lm_loss, sop_logits, loss_mask, sentence_order): lm_loss_ = lm_loss.float() loss_mask = loss_mask.float() loss_mask_sum = loss_mask.sum() - lm_loss = torch.sum( - lm_loss_.view(-1) * loss_mask.reshape(-1)) + lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) lm_loss /= loss_mask_sum - torch.distributed.all_reduce( - lm_loss, - group=gpc.get_group(ParallelMode.SEQUENCE) - ) + torch.distributed.all_reduce(lm_loss, group=gpc.get_group(ParallelMode.SEQUENCE)) if sop_logits is not None: - sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(), - sentence_order.view(-1), - ignore_index=-1) + sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(), sentence_order.view(-1), ignore_index=-1) sop_loss = sop_loss.float() loss = lm_loss + sop_loss * gpc.get_world_size(ParallelMode.SEQUENCE) else: diff --git a/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py b/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py index 54553c29a61f..ed15c6ea8054 100644 --- a/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py +++ b/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py @@ -1,7 +1,8 @@ -from colossalai.context.parallel_mode import ParallelMode import torch from torch.cuda.amp import custom_bwd, custom_fwd +from colossalai.legacy.context.parallel_mode import ParallelMode + class _VocabCrossEntropy(torch.autograd.Function): @@ -24,8 +25,7 @@ def forward(ctx, vocab_parallel_logits, target): # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. logits_2d = vocab_parallel_logits.view(-1, vocab_parallel_logits.size(-1)) masked_target_1d = masked_target.view(-1) - arange_1d = torch.arange(start=0, end=logits_2d.size()[0], - device=logits_2d.device) + arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] predicted_logits_1d = predicted_logits_1d.clone().contiguous() predicted_logits = predicted_logits_1d.view_as(target) @@ -58,10 +58,8 @@ def backward(ctx, grad_output): grad_2d = grad_input.view(-1, partition_vocab_size) # Add the gradient from matching classes. - arange_1d = torch.arange(start=0, end=grad_2d.size()[0], - device=grad_2d.device) - grad_2d[arange_1d, masked_target_1d] -= ( - 1.0 - target_mask.view(-1).float()) + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) + grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float()) # Finally elementwise multiplication with the output gradients. 
grad_input.mul_(grad_output.unsqueeze(dim=-1)) diff --git a/examples/tutorial/sequence_parallel/model/bert.py b/examples/tutorial/sequence_parallel/model/bert.py index b8adb501f95e..4ba64bbe2b9f 100644 --- a/examples/tutorial/sequence_parallel/model/bert.py +++ b/examples/tutorial/sequence_parallel/model/bert.py @@ -3,13 +3,13 @@ import torch import torch.nn as nn -from colossalai.context import ParallelMode -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc from colossalai.kernel import LayerNorm +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.wrapper import PipelineSharedModuleWrapper +from colossalai.legacy.pipeline.utils import partition_uniform from colossalai.logging import get_dist_logger -from colossalai.pipeline.utils import partition_uniform from .layers import BertDualHead, BertLayer, Embedding, PreProcessor, VocabEmbedding from .layers.init_method import init_normal, output_init_normal diff --git a/examples/tutorial/sequence_parallel/model/layers/head.py b/examples/tutorial/sequence_parallel/model/layers/head.py index ea336b9d131e..9e25157e1b40 100644 --- a/examples/tutorial/sequence_parallel/model/layers/head.py +++ b/examples/tutorial/sequence_parallel/model/layers/head.py @@ -1,15 +1,17 @@ -import colossalai import torch import torch.nn as nn import torch.nn.functional as F -from .pooler import Pooler -from .linear import Linear -from .embedding import VocabEmbedding -from colossalai.core import global_context as gpc -from colossalai.context import ParallelMode -from colossalai.kernel import LayerNorm from loss_func.cross_entropy import vocab_cross_entropy +import colossalai +from colossalai.kernel import LayerNorm +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc + +from .embedding import VocabEmbedding +from .linear import Linear +from .pooler import Pooler + class BertLMHead(nn.Module): """Masked LM head for Bert @@ -19,10 +21,11 @@ class BertLMHead(nn.Module): layernorm_epsilon: tolerance for layer norm divisions """ - def __init__(self, - vocab_size, - hidden_size, - ): + def __init__( + self, + vocab_size, + hidden_size, + ): super(BertLMHead, self).__init__() self.bias = torch.nn.Parameter(torch.zeros(vocab_size)) diff --git a/examples/tutorial/sequence_parallel/model/layers/preprocess.py b/examples/tutorial/sequence_parallel/model/layers/preprocess.py index 53a326ddacf1..dd66bfe13585 100644 --- a/examples/tutorial/sequence_parallel/model/layers/preprocess.py +++ b/examples/tutorial/sequence_parallel/model/layers/preprocess.py @@ -1,7 +1,8 @@ -from colossalai.context.parallel_mode import ParallelMode import torch import torch.nn as nn -from colossalai.core import global_context as gpc + +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc class PreProcessor(nn.Module): @@ -14,8 +15,8 @@ def bert_position_ids(self, token_ids): # Create position ids seq_length = token_ids.size(1) local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) - position_ids = torch.arange(seq_length*local_rank, - seq_length * (local_rank+1), + position_ids = torch.arange(seq_length * local_rank, + seq_length * (local_rank + 1), dtype=torch.long, device=token_ids.device) position_ids = position_ids.unsqueeze(0).expand_as(token_ids) diff --git 
a/examples/tutorial/sequence_parallel/test_ci.sh b/examples/tutorial/sequence_parallel/test_ci.sh index 7bc20de3b6e4..1cd646526d99 100644 --- a/examples/tutorial/sequence_parallel/test_ci.sh +++ b/examples/tutorial/sequence_parallel/test_ci.sh @@ -1,7 +1,8 @@ #!/bin/bash set -euxo pipefail -pip install -r requirements.txt +echo "this test is outdated" +# pip install -r requirements.txt # run test -colossalai run --nproc_per_node 4 train.py +# colossalai run --nproc_per_node 4 train.py diff --git a/examples/tutorial/sequence_parallel/train.py b/examples/tutorial/sequence_parallel/train.py index 86c4edeb5550..b8b89cda5525 100644 --- a/examples/tutorial/sequence_parallel/train.py +++ b/examples/tutorial/sequence_parallel/train.py @@ -8,14 +8,15 @@ from model.bert import BertForPretrain, build_pipeline_bert import colossalai -from colossalai.amp import AMP_TYPE -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc from colossalai.kernel import LayerNorm +from colossalai.legacy.amp import AMP_TYPE +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.legacy.engine.schedule import PipelineSchedule +from colossalai.legacy.utils import is_using_pp from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import FusedAdam -from colossalai.utils import MultiTimer, is_using_pp +from colossalai.utils import MultiTimer def process_batch_data(batch_data): diff --git a/tests/components_to_test/resnet.py b/tests/components_to_test/resnet.py index 193832ebc12d..df01e4c4847e 100644 --- a/tests/components_to_test/resnet.py +++ b/tests/components_to_test/resnet.py @@ -1,11 +1,14 @@ -from torchvision.models import resnet18 -from .registry import non_distributed_component_funcs -from pathlib import Path import os +from pathlib import Path + import torch -from torchvision.transforms import transforms from torchvision.datasets import CIFAR10 -from colossalai.utils import get_dataloader +from torchvision.models import resnet18 +from torchvision.transforms import transforms + +from colossalai.legacy.utils import get_dataloader + +from .registry import non_distributed_component_funcs def get_cifar10_dataloader(train): diff --git a/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py b/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py index f184f64b35d0..b65e6d0d8863 100644 --- a/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py +++ b/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py @@ -6,12 +6,12 @@ import torchvision.models as tm import colossalai -from colossalai.core import global_context as gpc from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.fx._compatibility import is_compatible_with_meta # from colossalai.fx.passes.algorithms import solver_rotor # from colossalai.fx.passes.algorithms.operation import Sequence from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.legacy.core import global_context as gpc from colossalai.testing import rerun_if_address_is_in_use, spawn if is_compatible_with_meta(): diff --git a/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py index db268b91d0a0..babdddfada18 100644 --- a/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py +++ 
b/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py @@ -8,12 +8,12 @@ from torch.fx import GraphModule import colossalai -from colossalai.core import global_context as gpc from colossalai.fx import ColoTracer from colossalai.fx._compatibility import is_compatible_with_meta from colossalai.fx.graph_module import ColoGraphModule # from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.legacy.core import global_context as gpc from colossalai.testing import rerun_if_address_is_in_use, spawn if is_compatible_with_meta(): diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py index 4e3c26c1ba9c..715f62358e2d 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py @@ -13,10 +13,9 @@ from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.nn.optimizer import HybridAdam -from colossalai.tensor.process_group import ProcessGroup from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn from colossalai.utils import get_current_device -from colossalai.zero import post_process_colo_init_ctx, zero_model_wrapper, zero_optim_wrapper +from colossalai.zero import zero_model_wrapper, zero_optim_wrapper class MLP(torch.nn.Module): @@ -70,14 +69,12 @@ def check_auto_parallel_with_gemini(rank, world_size, port): print(strategy) print('=' * msg_length) - dp_process_group = ProcessGroup(rank=rank, ranks=[0, 1, 2, 3], tp_degree=2, dp_degree=2) gemini_config = dict(strict_ddp_mode=False, device=get_current_device(), placement_policy='cpu', pin_memory=True, search_range_m=128) - post_process_colo_init_ctx(gm, device=get_current_device(), default_pg=dp_process_group) gm = zero_model_wrapper(gm, zero_stage=3, gemini_config=gemini_config) optimizer = HybridAdam(gm.parameters(), betas=(0, 0)) optimizer = zero_optim_wrapper(gm, optimizer, initial_scale=1) diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py index 15610e2b50dc..593658fd1368 100644 --- a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py +++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py @@ -6,9 +6,9 @@ import colossalai from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE from colossalai.autochunk.utils import flat_list -from colossalai.core import global_context as gpc from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.legacy.core import global_context as gpc from colossalai.testing import free_port if AUTOCHUNK_AVAILABLE: diff --git a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py index b6a792f5652c..264331a5fef0 100644 --- a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py +++ b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py @@ -5,9 +5,9 @@ import colossalai from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE -from colossalai.core import 
global_context as gpc from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.legacy.core import global_context as gpc if AUTOCHUNK_AVAILABLE: from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen diff --git a/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py index 3202318fb6d1..65d1e9c4d090 100644 --- a/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py +++ b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py @@ -5,9 +5,9 @@ import colossalai from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE -from colossalai.core import global_context as gpc from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.legacy.core import global_context as gpc if AUTOCHUNK_AVAILABLE: from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen diff --git a/tests/test_cluster/test_process_group_mesh.py b/tests/test_cluster/test_process_group_mesh.py index 13b7119424e4..2304203d1e04 100644 --- a/tests/test_cluster/test_process_group_mesh.py +++ b/tests/test_cluster/test_process_group_mesh.py @@ -7,8 +7,8 @@ def check_process_group_mesh_with_gpc(): - from colossalai.context import ParallelMode - from colossalai.core import global_context as gpc + from colossalai.legacy.context import ParallelMode + from colossalai.legacy.core import global_context as gpc DP_DIM, PP_DIM, TP_DIM = 0, 1, 2 pg_mesh = ProcessGroupMesh(1, 2, 2) @@ -138,7 +138,7 @@ def run_dist(rank, world_size, port): port=port, host='localhost') # TODO(ver217): this function should be removed when gpc is removed - check_process_group_mesh_with_gpc() + # check_process_group_mesh_with_gpc() check_process_group_mesh_with_cases() diff --git a/tests/test_context/configs/parallel_2d_init.py b/tests/test_context/configs/parallel_2d_init.py deleted file mode 100644 index 6af884450ad0..000000000000 --- a/tests/test_context/configs/parallel_2d_init.py +++ /dev/null @@ -1,10 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -parallel = dict( - pipeline=dict(size=2), - tensor=dict( - size=4, - mode='2d' - ) -) diff --git a/tests/test_context/configs/parallel_2p5d_init.py b/tests/test_context/configs/parallel_2p5d_init.py deleted file mode 100644 index c2d896d383e2..000000000000 --- a/tests/test_context/configs/parallel_2p5d_init.py +++ /dev/null @@ -1,11 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -parallel = dict( - pipeline=dict(size=2), - tensor=dict( - size=8, - depth=2, - mode='2.5d' - ) -) diff --git a/tests/test_context/configs/parallel_3d_init.py b/tests/test_context/configs/parallel_3d_init.py deleted file mode 100644 index 0ec724f8bb4f..000000000000 --- a/tests/test_context/configs/parallel_3d_init.py +++ /dev/null @@ -1,10 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -parallel = dict( - pipeline=dict(size=2), - tensor=dict( - size=8, - mode='3d' - ) -) diff --git a/tests/test_device/test_init_logical_pg.py b/tests/test_device/test_init_logical_pg.py index 7c6339eff67e..c18bf56752fb 100644 --- a/tests/test_device/test_init_logical_pg.py +++ b/tests/test_device/test_init_logical_pg.py @@ -3,7 +3,6 @@ import torch.distributed as dist from torch.distributed import ReduceOp -from colossalai.core import global_context as gpc from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize 
import launch from colossalai.testing import rerun_if_address_is_in_use, spawn @@ -13,7 +12,7 @@ def check_layer(rank, world_size, port): launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') physical_mesh_id = torch.arange(0, 4) - assert rank == gpc.get_global_rank() + assert rank == dist.get_rank() tensor_to_check = torch.tensor([2, 2, 2, 2]).cuda() mesh_shape = (2, 2) @@ -27,8 +26,6 @@ def check_layer(rank, world_size, port): dist.all_reduce(tensor, op=ReduceOp.SUM, group=pg) assert tensor.equal(tensor_to_check) - gpc.destroy() - @pytest.mark.dist @rerun_if_address_is_in_use() diff --git a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py index bcac2ec426d9..6a12f5bc848e 100644 --- a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py +++ b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py @@ -4,9 +4,9 @@ from torch.utils.checkpoint import checkpoint import colossalai -from colossalai.core import global_context as gpc from colossalai.fx import ColoTracer from colossalai.fx.graph_module import ColoGraphModule +from colossalai.legacy.core import global_context as gpc from colossalai.testing import rerun_if_address_is_in_use, spawn try: diff --git a/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py index 5b327807a57b..ebcfb4d7b633 100644 --- a/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py +++ b/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py @@ -2,9 +2,9 @@ import torch import colossalai -from colossalai.core import global_context as gpc from colossalai.fx import ColoTracer from colossalai.fx.graph_module import ColoGraphModule +from colossalai.legacy.core import global_context as gpc from colossalai.testing import rerun_if_address_is_in_use, spawn try: diff --git a/tests/test_fx/test_codegen/test_offload_codegen.py b/tests/test_fx/test_codegen/test_offload_codegen.py index c217b96586fe..dac59c23655e 100644 --- a/tests/test_fx/test_codegen/test_offload_codegen.py +++ b/tests/test_fx/test_codegen/test_offload_codegen.py @@ -5,9 +5,9 @@ from torch.fx import GraphModule import colossalai -from colossalai.core import global_context as gpc from colossalai.fx import ColoTracer from colossalai.fx.graph_module import ColoGraphModule +from colossalai.legacy.core import global_context as gpc from colossalai.testing import rerun_if_address_is_in_use, spawn try: diff --git a/tests/test_fx/test_parallel_1d.py b/tests/test_fx/test_parallel_1d.py index 1044be7db1f4..29135b45f997 100644 --- a/tests/test_fx/test_parallel_1d.py +++ b/tests/test_fx/test_parallel_1d.py @@ -5,9 +5,9 @@ import torch from torch.fx import symbolic_trace -from colossalai.core import global_context as gpc from colossalai.fx.passes import column_shard_linear_pass from colossalai.initialize import launch +from colossalai.legacy.core import global_context as gpc from colossalai.logging import disable_existing_loggers from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn diff --git a/tests/test_fx/test_pipeline/test_topo/topo_utils.py b/tests/test_fx/test_pipeline/test_topo/topo_utils.py index 55dd65201acd..db6cadfc544c 100644 --- a/tests/test_fx/test_pipeline/test_topo/topo_utils.py +++ b/tests/test_fx/test_pipeline/test_topo/topo_utils.py @@ -1,18 +1,22 @@ +import random + +import numpy as np import torch 
from torch.fx import GraphModule -from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass + from colossalai.fx import ColoTracer -from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo -from colossalai.pipeline.middleware.adaptor import get_fx_topology -import random -import numpy as np +from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass +from colossalai.legacy.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo +from colossalai.legacy.pipeline.middleware.adaptor import get_fx_topology MANUAL_SEED = 0 random.seed(MANUAL_SEED) np.random.seed(MANUAL_SEED) torch.manual_seed(MANUAL_SEED) + class MLP(torch.nn.Module): + def __init__(self, config={}): super().__init__() dim = config['dim'] @@ -27,6 +31,7 @@ def forward(self, x): x = layer(x) return x + def split_model_and_get_DAG(model, data_gen): model.eval() @@ -46,7 +51,7 @@ def split_model_and_get_DAG(model, data_gen): # apply transform passes annotated_model = balanced_split_pass(gm, 2) top_module, split_submodules = split_with_split_nodes_pass(annotated_model) - + topo = get_fx_topology(top_module) for submodule in split_submodules: if isinstance(submodule, torch.fx.GraphModule): @@ -54,6 +59,7 @@ def split_model_and_get_DAG(model, data_gen): return top_module, split_submodules[0]._topo + def check_input(top_module, input_partition: Partition): partition_output = input_partition.get_output_vals() arg_pos = 0 @@ -63,13 +69,14 @@ def check_input(top_module, input_partition: Partition): to_partition_and_offset = cur_checkee.get() assert len(to_partition_and_offset) == len(node.users.keys()) arg_pos += 1 - + assert arg_pos == len(partition_output) - + + def check_submod(top_module, part_id, mid_partition: Partition): partition_input = mid_partition.get_input_vals() partition_output = mid_partition.get_output_vals() - + cnt = 1 cur_node = None for node in top_module.graph.nodes: @@ -78,15 +85,15 @@ def check_submod(top_module, part_id, mid_partition: Partition): if cnt == part_id: cur_node = node break - + assert len(partition_input) == len(cur_node.args) assert len(partition_output) == len(cur_node.users) -def check_topo(top_module, topo: Topo): + +def check_topo(top_module, topo: Topo): input_partition = topo.get_input_partition() mid_partitions = topo.get_mid_partitions() - + check_input(top_module, input_partition) for part_id, submod in mid_partitions.items(): check_submod(top_module, part_id, submod) - \ No newline at end of file diff --git a/tests/test_amp/test_naive_fp16.py b/tests/test_legacy/test_amp/test_naive_fp16.py similarity index 94% rename from tests/test_amp/test_naive_fp16.py rename to tests/test_legacy/test_amp/test_naive_fp16.py index 6ce4c7f49725..54bf6498549c 100644 --- a/tests/test_amp/test_naive_fp16.py +++ b/tests/test_legacy/test_amp/test_naive_fp16.py @@ -4,7 +4,7 @@ import torch import colossalai -from colossalai.amp import convert_to_apex_amp, convert_to_naive_amp +from colossalai.legacy.amp import convert_to_apex_amp, convert_to_naive_amp from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn from tests.components_to_test.registry import non_distributed_component_funcs @@ -78,7 +78,7 @@ def run_naive_amp(): def run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.legacy.launch(config=dict(), 
rank=rank, world_size=world_size, port=port, host='localhost') run_naive_amp() diff --git a/tests/test_amp/test_torch_fp16.py b/tests/test_legacy/test_amp/test_torch_fp16.py similarity index 95% rename from tests/test_amp/test_torch_fp16.py rename to tests/test_legacy/test_amp/test_torch_fp16.py index 6451aa6264a3..89810b5d0351 100644 --- a/tests/test_amp/test_torch_fp16.py +++ b/tests/test_legacy/test_amp/test_torch_fp16.py @@ -4,7 +4,7 @@ import torch import colossalai -from colossalai.amp import convert_to_apex_amp, convert_to_torch_amp +from colossalai.legacy.amp import convert_to_apex_amp, convert_to_torch_amp from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn from tests.components_to_test.registry import non_distributed_component_funcs @@ -78,7 +78,7 @@ def run_torch_amp(): def run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.legacy.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') run_torch_amp() diff --git a/tests/test_legacy/test_comm/test_boardcast_send_recv_v2.py b/tests/test_legacy/test_comm/test_boardcast_send_recv_v2.py index c5fb049fe93f..4851b3e36bbc 100644 --- a/tests/test_legacy/test_comm/test_boardcast_send_recv_v2.py +++ b/tests/test_legacy/test_comm/test_boardcast_send_recv_v2.py @@ -1,10 +1,10 @@ import pytest import torch -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch from colossalai.legacy.communication.p2p_v2 import _recv_object, _send_object +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.testing import rerun_if_address_is_in_use, spawn diff --git a/tests/test_legacy/test_comm/test_comm.py b/tests/test_legacy/test_comm/test_comm.py index 3251d8d46f0b..fccfcd973000 100644 --- a/tests/test_legacy/test_comm/test_comm.py +++ b/tests/test_legacy/test_comm/test_comm.py @@ -2,10 +2,10 @@ import torch import torch.distributed as dist -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch from colossalai.legacy.communication import all_gather, all_reduce, reduce_scatter +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.initialize import launch from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device diff --git a/tests/test_legacy/test_comm/test_object_list_p2p.py b/tests/test_legacy/test_comm/test_object_list_p2p.py index f50982ee1c2d..a1322e6f28db 100644 --- a/tests/test_legacy/test_comm/test_object_list_p2p.py +++ b/tests/test_legacy/test_comm/test_object_list_p2p.py @@ -1,9 +1,6 @@ import pytest import torch -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch from colossalai.legacy.communication.p2p import ( recv_backward, recv_forward, @@ -12,6 +9,9 @@ send_forward, send_forward_recv_backward, ) +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.initialize import launch from colossalai.testing import rerun_if_address_is_in_use, 
spawn CONFIG = dict(parallel=dict(pipeline=2)) diff --git a/tests/test_legacy/test_comm/test_object_list_p2p_v2.py b/tests/test_legacy/test_comm/test_object_list_p2p_v2.py index 040c63322f2b..f805bd19d7e8 100644 --- a/tests/test_legacy/test_comm/test_object_list_p2p_v2.py +++ b/tests/test_legacy/test_comm/test_object_list_p2p_v2.py @@ -1,10 +1,10 @@ import pytest import torch -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch from colossalai.legacy.communication.p2p_v2 import recv_backward, recv_forward, send_backward, send_forward +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.testing import rerun_if_address_is_in_use, spawn diff --git a/tests/test_legacy/test_context/configs/parallel_2d_init.py b/tests/test_legacy/test_context/configs/parallel_2d_init.py new file mode 100644 index 000000000000..6cf816942fdd --- /dev/null +++ b/tests/test_legacy/test_context/configs/parallel_2d_init.py @@ -0,0 +1,4 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +parallel = dict(pipeline=dict(size=2), tensor=dict(size=4, mode='2d')) diff --git a/tests/test_legacy/test_context/configs/parallel_2p5d_init.py b/tests/test_legacy/test_context/configs/parallel_2p5d_init.py new file mode 100644 index 000000000000..b946d45b3a91 --- /dev/null +++ b/tests/test_legacy/test_context/configs/parallel_2p5d_init.py @@ -0,0 +1,4 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +parallel = dict(pipeline=dict(size=2), tensor=dict(size=8, depth=2, mode='2.5d')) diff --git a/tests/test_legacy/test_context/configs/parallel_3d_init.py b/tests/test_legacy/test_context/configs/parallel_3d_init.py new file mode 100644 index 000000000000..a1564bbb2d51 --- /dev/null +++ b/tests/test_legacy/test_context/configs/parallel_3d_init.py @@ -0,0 +1,4 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +parallel = dict(pipeline=dict(size=2), tensor=dict(size=8, mode='3d')) diff --git a/tests/test_context/test_hybrid_parallel.py b/tests/test_legacy/test_context/test_hybrid_parallel.py similarity index 95% rename from tests/test_context/test_hybrid_parallel.py rename to tests/test_legacy/test_context/test_hybrid_parallel.py index d25668afd430..05cd1d294dcd 100644 --- a/tests/test_context/test_hybrid_parallel.py +++ b/tests/test_legacy/test_context/test_hybrid_parallel.py @@ -6,11 +6,11 @@ import pytest import torch -from colossalai import launch -from colossalai.context import reset_seeds -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.global_variables import tensor_parallel_env as tp_env +from colossalai.legacy import launch +from colossalai.legacy.context import reset_seeds +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.global_variables import tensor_parallel_env as tp_env from colossalai.testing import free_port, rerun_if_address_is_in_use, spawn CONFIG_PATH_LIST = list(Path(__file__).parent.glob('configs/*.py')) diff --git a/tests/test_data/test_cifar10_dataset.py b/tests/test_legacy/test_data/test_cifar10_dataset.py similarity index 100% rename from tests/test_data/test_cifar10_dataset.py rename to tests/test_legacy/test_data/test_cifar10_dataset.py diff --git 
a/tests/test_data/test_data_parallel_sampler.py b/tests/test_legacy/test_data/test_data_parallel_sampler.py similarity index 87% rename from tests/test_data/test_data_parallel_sampler.py rename to tests/test_legacy/test_data/test_data_parallel_sampler.py index 7beef707c096..cf10fe9dfa3c 100644 --- a/tests/test_data/test_data_parallel_sampler.py +++ b/tests/test_legacy/test_data/test_data_parallel_sampler.py @@ -10,10 +10,11 @@ from torchvision import datasets, transforms import colossalai -from colossalai.context import Config, ParallelMode -from colossalai.core import global_context as gpc +from colossalai.context import Config +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.utils import get_dataloader from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_dataloader CONFIG = Config(dict( parallel=dict( @@ -26,7 +27,7 @@ def run_data_sampler(rank, world_size, port): dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port=port, host='localhost') - colossalai.launch(**dist_args) + colossalai.legacy.launch(**dist_args) print('finished initialization') # build dataset diff --git a/tests/test_legacy/test_data/test_deterministic_dataloader.py b/tests/test_legacy/test_data/test_deterministic_dataloader.py new file mode 100644 index 000000000000..421b8d255318 --- /dev/null +++ b/tests/test_legacy/test_data/test_deterministic_dataloader.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import os +from pathlib import Path + +import pytest +import torch +import torch.distributed as dist +from torchvision import datasets, transforms + +import colossalai +from colossalai.context import Config +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.utils import get_dataloader +from colossalai.testing import rerun_if_address_is_in_use, spawn + +CONFIG = Config( + dict( + train_data=dict( + dataset=dict( + type='CIFAR10', + root=Path(os.environ['DATA']), + train=True, + download=True, + ), + dataloader=dict(num_workers=2, batch_size=2, shuffle=True), + ), + parallel=dict( + pipeline=dict(size=1), + tensor=dict(size=1, mode=None), + ), + seed=1024, + )) + + +def run_data_sampler(rank, world_size, port): + dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port=port, host='localhost') + colossalai.legacy.launch(**dist_args) + + # build dataset + transform_pipeline = [transforms.ToTensor(), transforms.RandomCrop(size=32, padding=4)] + transform_pipeline = transforms.Compose(transform_pipeline) + dataset = datasets.CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_pipeline) + + # build dataloader + dataloader = get_dataloader(dataset, batch_size=8, add_sampler=False) + + data_iter = iter(dataloader) + img, label = data_iter.next() + img = img[0] + + if gpc.get_local_rank(ParallelMode.DATA) != 0: + img_to_compare = img.clone() + else: + img_to_compare = img + dist.broadcast(img_to_compare, src=0, group=gpc.get_group(ParallelMode.DATA)) + + if gpc.get_local_rank(ParallelMode.DATA) != 0: + # this is without sampler + # this should be false if data parallel sampler to given to the dataloader + assert torch.equal(img, + img_to_compare), 'Same image was distributed across ranks and expected it to be the same' + torch.cuda.empty_cache() + + +@rerun_if_address_is_in_use() +def 
test_data_sampler(): + spawn(run_data_sampler, 4) + + +if __name__ == '__main__': + test_data_sampler() diff --git a/tests/test_legacy/test_engine/test_engine.py b/tests/test_legacy/test_engine/test_engine.py index 62493cf3712d..8499784038d2 100644 --- a/tests/test_legacy/test_engine/test_engine.py +++ b/tests/test_legacy/test_engine/test_engine.py @@ -1,8 +1,8 @@ import pytest import colossalai -from colossalai.amp import AMP_TYPE -from colossalai.core import global_context as gpc +from colossalai.legacy.amp import AMP_TYPE +from colossalai.legacy.core import global_context as gpc from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from tests.components_to_test.registry import non_distributed_component_funcs @@ -20,10 +20,11 @@ def run_train(model_name, amp_mode): model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() model = model_builder(checkpoint=False) - engine, train_dataloader, *args = colossalai.initialize(model=model, - optimizer=optimizer_class(model.parameters(), lr=1e-3), - criterion=criterion, - train_dataloader=train_dataloader) + engine, train_dataloader, *args = colossalai.legacy.initialize(model=model, + optimizer=optimizer_class(model.parameters(), + lr=1e-3), + criterion=criterion, + train_dataloader=train_dataloader) try: engine.train() @@ -48,7 +49,12 @@ def run_train(model_name, amp_mode): def run_engine(rank, world_size, port): # init dist env - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.legacy.launch(config=CONFIG, + rank=rank, + world_size=world_size, + host='localhost', + port=port, + backend='nccl') run_train() diff --git a/tests/test_legacy/test_engine/test_gradient_accumluation.py b/tests/test_legacy/test_engine/test_gradient_accumluation.py index 7783827c7c44..168c93c1a572 100644 --- a/tests/test_legacy/test_engine/test_gradient_accumluation.py +++ b/tests/test_legacy/test_engine/test_gradient_accumluation.py @@ -10,10 +10,10 @@ from torchvision.models import resnet18 import colossalai -from colossalai.core import global_context as gpc +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.utils import get_dataloader from colossalai.logging import get_dist_logger from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_dataloader # Config BATCH_SIZE = 2 @@ -27,7 +27,12 @@ def run_no_pipeline(rank, world_size, port): # init dist env - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.legacy.launch(config=CONFIG, + rank=rank, + world_size=world_size, + host='localhost', + port=port, + backend='nccl') # build model model = resnet18(num_classes=10) @@ -49,10 +54,10 @@ def run_no_pipeline(rank, world_size, port): optimizer = Adam(model.parameters(), lr=0.001) criterion = nn.CrossEntropyLoss() - engine, train_dataloader, *args = colossalai.initialize(model=model, - optimizer=optimizer, - criterion=criterion, - train_dataloader=train_dataloader) + engine, train_dataloader, *args = colossalai.legacy.initialize(model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader) logger = get_dist_logger() rank = torch.distributed.get_rank() param_track = [] diff --git a/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py b/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py index dcb2be62671b..859707e6129d 100644 --- 
a/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py +++ b/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py @@ -2,9 +2,9 @@ import torch.distributed as dist from torch.nn import Parameter -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.global_variables import tensor_parallel_env as env +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.global_variables import tensor_parallel_env as env from colossalai.legacy.nn import ( Classifier1D, Embedding1D, @@ -15,7 +15,8 @@ VocabParallelCrossEntropyLoss1D, VocabParallelEmbedding1D, ) -from colossalai.utils import get_current_device, print_rank_0 +from colossalai.legacy.utils import print_rank_0 +from colossalai.utils import get_current_device from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal diff --git a/tests/test_legacy/test_layers/test_1d/test_1d.py b/tests/test_legacy/test_layers/test_1d/test_1d.py index 891512542475..2a016ed7b33d 100644 --- a/tests/test_legacy/test_layers/test_1d/test_1d.py +++ b/tests/test_legacy/test_layers/test_1d/test_1d.py @@ -5,8 +5,8 @@ import torch from checks_1d.check_layer_1d import * -from colossalai.core import global_context as gpc -from colossalai.initialize import launch +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.testing import rerun_if_address_is_in_use, spawn diff --git a/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py b/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py index 0ee88c26035f..494497be33e2 100644 --- a/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py +++ b/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py @@ -1,7 +1,7 @@ import torch -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn import ( Classifier2D, CrossEntropyLoss2D, @@ -15,7 +15,8 @@ VocabParallelCrossEntropyLoss2D, VocabParallelEmbedding2D, ) -from colossalai.utils import get_current_device, print_rank_0 +from colossalai.legacy.utils import print_rank_0 +from colossalai.utils import get_current_device from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal diff --git a/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py b/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py index ae1d1120cfb9..034dbe5ca29c 100644 --- a/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py +++ b/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py @@ -3,10 +3,11 @@ import torch -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_2d._operation import Matmul_AB_2D, Matmul_ABT_2D, Matmul_ATB_2D -from colossalai.utils import get_current_device, print_rank_0 +from colossalai.legacy.utils import print_rank_0 +from colossalai.utils import get_current_device from 
.common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, SEQ_LENGTH, check_equal diff --git a/tests/test_legacy/test_layers/test_2d/test_2d.py b/tests/test_legacy/test_layers/test_2d/test_2d.py index bcea5ce7b25d..a4b46793f19d 100644 --- a/tests/test_legacy/test_layers/test_2d/test_2d.py +++ b/tests/test_legacy/test_layers/test_2d/test_2d.py @@ -18,8 +18,8 @@ ) from checks_2d.check_operation_2d import check_AB, check_ABT, check_ATB -from colossalai.core import global_context as gpc -from colossalai.initialize import launch +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.testing import rerun_if_address_is_in_use, spawn diff --git a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py index 5a99b05cfe7e..e7a9a8be45d0 100644 --- a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py +++ b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py @@ -1,8 +1,8 @@ import torch from torch.nn import Parameter -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn import ( Classifier2p5D, CrossEntropyLoss2p5D, @@ -16,7 +16,8 @@ VocabParallelCrossEntropyLoss2p5D, VocabParallelEmbedding2p5D, ) -from colossalai.utils import get_current_device, print_rank_0 +from colossalai.legacy.utils import print_rank_0 +from colossalai.utils import get_current_device from .common import * diff --git a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py index db19967676d2..fe78ef669bf0 100644 --- a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py +++ b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py @@ -1,9 +1,10 @@ import torch -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_2p5d._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, Matmul_ATB_2p5D -from colossalai.utils import get_current_device, print_rank_0 +from colossalai.legacy.utils import print_rank_0 +from colossalai.utils import get_current_device from .common import * diff --git a/tests/test_legacy/test_layers/test_2p5d/test_2p5d.py b/tests/test_legacy/test_layers/test_2p5d/test_2p5d.py index 373d834d0032..38ba3ba78575 100644 --- a/tests/test_legacy/test_layers/test_2p5d/test_2p5d.py +++ b/tests/test_legacy/test_layers/test_2p5d/test_2p5d.py @@ -3,8 +3,8 @@ from checks_2p5d.check_layer_2p5d import * from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB -from colossalai.core import global_context as gpc -from colossalai.initialize import launch +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.testing import rerun_if_address_is_in_use, spawn diff --git a/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py b/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py index cee639a9f00a..2a9dcc3cdc16 
100644 --- a/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py +++ b/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py @@ -5,8 +5,8 @@ import torch -from colossalai.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D -from colossalai.core import global_context +from colossalai.legacy.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D +from colossalai.legacy.core import global_context from colossalai.legacy.nn import ( Classifier3D, CrossEntropyLoss3D, @@ -21,8 +21,9 @@ VocabParallelEmbedding3D, ) from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env +from colossalai.legacy.utils import print_rank_0 from colossalai.logging import get_dist_logger -from colossalai.utils import get_current_device, print_rank_0 +from colossalai.utils import get_current_device from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal diff --git a/tests/test_legacy/test_layers/test_3d/test_3d.py b/tests/test_legacy/test_layers/test_3d/test_3d.py index fde71a4a0d26..2a32d8935c00 100644 --- a/tests/test_legacy/test_layers/test_3d/test_3d.py +++ b/tests/test_legacy/test_layers/test_3d/test_3d.py @@ -15,8 +15,8 @@ check_vocab_parallel_loss, ) -from colossalai.core import global_context as gpc -from colossalai.initialize import launch +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn diff --git a/tests/test_legacy/test_layers/test_cache_embedding.py b/tests/test_legacy/test_layers/test_cache_embedding.py index 0760a3f1ec38..c58445a396ec 100644 --- a/tests/test_legacy/test_layers/test_cache_embedding.py +++ b/tests/test_legacy/test_layers/test_cache_embedding.py @@ -14,7 +14,8 @@ ParallelCachedEmbeddingBagTablewise, TablewiseEmbeddingBagConfig, ) -from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec +from colossalai.legacy.tensor import ComputePattern, ComputeSpec, ProcessGroup, ShardSpec +from colossalai.tensor import ColoTensor from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn NUM_EMBED, EMBED_DIM = 10, 8 @@ -359,7 +360,7 @@ def run_parallel_freq_aware_embed_columnwise(rank, world_size): def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') # run_parallel_freq_aware_embed_columnwise(rank, world_size) run_parallel_freq_aware_embed_tablewise(rank, world_size) diff --git a/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py b/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py index 7ff91a7b76e0..ac9493adab2e 100644 --- a/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py +++ b/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py @@ -1,7 +1,7 @@ import torch -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn import TransformerSelfAttentionRing from colossalai.utils import get_current_device diff --git 
a/tests/test_legacy/test_layers/test_sequence/test_sequence.py b/tests/test_legacy/test_layers/test_sequence/test_sequence.py index b9e6c12479ee..85226f9d934a 100644 --- a/tests/test_legacy/test_layers/test_sequence/test_sequence.py +++ b/tests/test_legacy/test_layers/test_sequence/test_sequence.py @@ -3,8 +3,8 @@ import torch.distributed as dist import colossalai -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_sequence import RingAV, RingQK from colossalai.testing import rerun_if_address_is_in_use, spawn @@ -120,7 +120,7 @@ def check_ring_av(rank, world_size): def run_test(rank, world_size, port): - colossalai.launch(rank=rank, world_size=world_size, config=CONFIG, host='localhost', port=port) + colossalai.legacy.launch(rank=rank, world_size=world_size, config=CONFIG, host='localhost', port=port) # check_ring_qk(rank, world_size) check_ring_av(rank, world_size) diff --git a/tests/test_pipeline/rpc_test_utils.py b/tests/test_legacy/test_pipeline/rpc_test_utils.py similarity index 97% rename from tests/test_pipeline/rpc_test_utils.py rename to tests/test_legacy/test_pipeline/rpc_test_utils.py index dab474a4ee21..9a336c4224be 100644 --- a/tests/test_pipeline/rpc_test_utils.py +++ b/tests/test_legacy/test_pipeline/rpc_test_utils.py @@ -10,9 +10,9 @@ from torch._C._distributed_rpc import _is_current_rpc_agent_set from torch.optim import SGD, Adam, Optimizer, RMSprop -from colossalai import launch +from colossalai.legacy import launch +from colossalai.legacy.pipeline.pipeline_process_group import ppg from colossalai.logging import disable_existing_loggers -from colossalai.pipeline.pipeline_process_group import ppg rpc_is_initialized = _is_current_rpc_agent_set diff --git a/tests/test_pipeline/test_cuda_rpc_chimera.py b/tests/test_legacy/test_pipeline/test_cuda_rpc_chimera.py similarity index 94% rename from tests/test_pipeline/test_cuda_rpc_chimera.py rename to tests/test_legacy/test_pipeline/test_cuda_rpc_chimera.py index 45ad8f828e61..3bff08318d40 100644 --- a/tests/test_pipeline/test_cuda_rpc_chimera.py +++ b/tests/test_legacy/test_pipeline/test_cuda_rpc_chimera.py @@ -1,10 +1,10 @@ import torch -from torch import nn import torch.autograd as autograd +from rpc_test_utils import RpcTestModel, parse_args, rpc_run +from torch import nn -from colossalai.pipeline.rpc import ChimeraPipelineEngine +from colossalai.legacy.pipeline.rpc import ChimeraPipelineEngine from colossalai.testing import assert_close -from rpc_test_utils import rpc_run, parse_args, RpcTestModel # global variable for model created feat_num = 100 diff --git a/tests/test_pipeline/test_cuda_rpc_optimizer.py b/tests/test_legacy/test_pipeline/test_cuda_rpc_optimizer.py similarity index 89% rename from tests/test_pipeline/test_cuda_rpc_optimizer.py rename to tests/test_legacy/test_pipeline/test_cuda_rpc_optimizer.py index 842566730caf..eff031ff8faa 100644 --- a/tests/test_pipeline/test_cuda_rpc_optimizer.py +++ b/tests/test_legacy/test_pipeline/test_cuda_rpc_optimizer.py @@ -1,11 +1,10 @@ import torch -from torch import nn -from torch import autograd -from torch.optim import SGD, Adam, RMSprop, Optimizer +from rpc_test_utils import RpcTestModel, parse_args, rpc_run +from torch import autograd, nn +from torch.optim import SGD, Adam, Optimizer, RMSprop -from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, 
OneFOneBPipelineEngine +from colossalai.legacy.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine from colossalai.testing import assert_close -from rpc_test_utils import rpc_run, parse_args, RpcTestModel # global variable for model created feat_num = 100 diff --git a/tests/test_pipeline/test_cuda_rpc_pipeline.py b/tests/test_legacy/test_pipeline/test_cuda_rpc_pipeline.py similarity index 87% rename from tests/test_pipeline/test_cuda_rpc_pipeline.py rename to tests/test_legacy/test_pipeline/test_cuda_rpc_pipeline.py index 8d03e79813e8..1a6077f8d3e9 100644 --- a/tests/test_pipeline/test_cuda_rpc_pipeline.py +++ b/tests/test_legacy/test_pipeline/test_cuda_rpc_pipeline.py @@ -1,8 +1,8 @@ import torch +from rpc_test_utils import RpcTestModel, parse_args, rpc_run from torch import nn -from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine -from rpc_test_utils import rpc_run, parse_args, RpcTestModel +from colossalai.legacy.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine # global variable for model created feat_num = 100 diff --git a/tests/test_pipeline/test_cuda_rpc_value_correctness.py b/tests/test_legacy/test_pipeline/test_cuda_rpc_value_correctness.py similarity index 91% rename from tests/test_pipeline/test_cuda_rpc_value_correctness.py rename to tests/test_legacy/test_pipeline/test_cuda_rpc_value_correctness.py index e6713478baec..43966ce3dbda 100644 --- a/tests/test_pipeline/test_cuda_rpc_value_correctness.py +++ b/tests/test_legacy/test_pipeline/test_cuda_rpc_value_correctness.py @@ -1,10 +1,9 @@ import torch -from torch import nn -from torch import autograd +from rpc_test_utils import RpcTestModel, parse_args, rpc_run +from torch import autograd, nn -from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine +from colossalai.legacy.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine from colossalai.testing import assert_close -from rpc_test_utils import rpc_run, parse_args, RpcTestModel feat_num = 100 h = 100 diff --git a/tests/test_pipeline/test_middleware_1f1b.py b/tests/test_legacy/test_pipeline/test_middleware_1f1b.py similarity index 94% rename from tests/test_pipeline/test_middleware_1f1b.py rename to tests/test_legacy/test_pipeline/test_middleware_1f1b.py index 5b3aad703275..4e43d52f8aee 100644 --- a/tests/test_pipeline/test_middleware_1f1b.py +++ b/tests/test_legacy/test_pipeline/test_middleware_1f1b.py @@ -7,13 +7,13 @@ from rpc_test_utils import DAG_MLP, MLP from torch._C._distributed_rpc import _is_current_rpc_agent_set -from colossalai import launch from colossalai.fx import ColoTracer from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass +from colossalai.legacy import launch +from colossalai.legacy.pipeline.middleware.adaptor import get_fx_topology +from colossalai.legacy.pipeline.pipeline_process_group import ppg +from colossalai.legacy.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine from colossalai.logging import disable_existing_loggers -from colossalai.pipeline.middleware.adaptor import get_fx_topology -from colossalai.pipeline.pipeline_process_group import ppg -from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn # global variable for model created diff --git a/tests/test_pipeline/test_pipelinable.py 
b/tests/test_legacy/test_pipeline/test_pipelinable.py similarity index 96% rename from tests/test_pipeline/test_pipelinable.py rename to tests/test_legacy/test_pipeline/test_pipelinable.py index bb016596beea..2ba5d0aa24d8 100644 --- a/tests/test_pipeline/test_pipelinable.py +++ b/tests/test_legacy/test_pipeline/test_pipelinable.py @@ -1,7 +1,7 @@ import pytest import torch -from colossalai.pipeline.pipelinable import PipelinableContext +from colossalai.legacy.pipeline.pipelinable import PipelinableContext from colossalai.testing import rerun_if_address_is_in_use, rerun_on_exception, spawn NUM_CHUNKS = 1 diff --git a/tests/test_pipeline/test_pipeline_process_group.py b/tests/test_legacy/test_pipeline/test_pipeline_process_group.py similarity index 91% rename from tests/test_pipeline/test_pipeline_process_group.py rename to tests/test_legacy/test_pipeline/test_pipeline_process_group.py index 2a00e3ac55b1..e6b95660279b 100644 --- a/tests/test_pipeline/test_pipeline_process_group.py +++ b/tests/test_legacy/test_pipeline/test_pipeline_process_group.py @@ -3,9 +3,9 @@ import torch.distributed.rpc as rpc from rpc_test_utils import pg_parse_args, rpc_is_initialized -from colossalai.initialize import launch +from colossalai.legacy.initialize import launch +from colossalai.legacy.pipeline.pipeline_process_group import ppg from colossalai.logging import disable_existing_loggers -from colossalai.pipeline.pipeline_process_group import ppg from colossalai.testing import spawn diff --git a/tests/test_tensor/common_utils/__init__.py b/tests/test_legacy/test_tensor/common_utils/__init__.py similarity index 95% rename from tests/test_tensor/common_utils/__init__.py rename to tests/test_legacy/test_tensor/common_utils/__init__.py index 5387db70445f..9a35d02ce5ed 100644 --- a/tests/test_tensor/common_utils/__init__.py +++ b/tests/test_legacy/test_tensor/common_utils/__init__.py @@ -1 +1 @@ -from ._utils import * +from ._utils import * diff --git a/tests/test_tensor/common_utils/_utils.py b/tests/test_legacy/test_tensor/common_utils/_utils.py similarity index 93% rename from tests/test_tensor/common_utils/_utils.py rename to tests/test_legacy/test_tensor/common_utils/_utils.py index b405f8cd2108..b6fea28e4c8a 100644 --- a/tests/test_tensor/common_utils/_utils.py +++ b/tests/test_legacy/test_tensor/common_utils/_utils.py @@ -6,9 +6,9 @@ import torch.distributed as dist from torch.testing import assert_close -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.tensor import ComputePattern, ComputeSpec, ShardSpec +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.tensor import ComputePattern, ComputeSpec, ShardSpec def set_seed(seed): diff --git a/tests/test_tensor/core/test_dist_spec_mgr.py b/tests/test_legacy/test_tensor/core/test_dist_spec_mgr.py similarity index 91% rename from tests/test_tensor/core/test_dist_spec_mgr.py rename to tests/test_legacy/test_tensor/core/test_dist_spec_mgr.py index 89476a35b63a..b6d6bcee66ce 100644 --- a/tests/test_tensor/core/test_dist_spec_mgr.py +++ b/tests/test_legacy/test_tensor/core/test_dist_spec_mgr.py @@ -5,7 +5,7 @@ import torch.distributed as dist import colossalai -from colossalai.tensor import DistSpecManager, ProcessGroup, ReplicaSpec, ShardSpec +from colossalai.legacy.tensor import DistSpecManager, ProcessGroup, ReplicaSpec, ShardSpec from colossalai.testing import rerun_if_address_is_in_use, spawn @@ -48,7 +48,7 @@ def 
check_mem(): def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') check_mem() run() diff --git a/tests/test_tensor/test_parameter.py b/tests/test_legacy/test_tensor/test_parameter.py similarity index 82% rename from tests/test_tensor/test_parameter.py rename to tests/test_legacy/test_tensor/test_parameter.py index 9c3f05da1ffa..7a8694ff6789 100644 --- a/tests/test_tensor/test_parameter.py +++ b/tests/test_legacy/test_tensor/test_parameter.py @@ -3,13 +3,13 @@ from common_utils import tensor_equal import colossalai -from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec, ProcessGroup +from colossalai.tensor import ColoParameter, ColoTensor from colossalai.testing import free_port @pytest.mark.skip def test_multiinheritance(): - colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl') + colossalai.legacy.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl') colo_param = ColoParameter(None, requires_grad=True) assert colo_param.dist_spec.placement.value == 'r' assert isinstance(colo_param, ColoTensor) diff --git a/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py b/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py index 5fb678525bb3..84652093a9fd 100644 --- a/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py +++ b/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py @@ -5,9 +5,6 @@ import torch import torch.distributed as dist -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch from colossalai.legacy.communication import ( recv_backward, recv_forward, @@ -18,6 +15,9 @@ send_forward_recv_backward, send_obj_meta, ) +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.initialize import launch from colossalai.logging import get_dist_logger from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device diff --git a/tests/test_legacy/test_trainer/test_pipeline/test_pipeline_schedule.py b/tests/test_legacy/test_trainer/test_pipeline/test_pipeline_schedule.py index 6d7bf6b3d89f..fd94c279b6fb 100644 --- a/tests/test_legacy/test_trainer/test_pipeline/test_pipeline_schedule.py +++ b/tests/test_legacy/test_trainer/test_pipeline/test_pipeline_schedule.py @@ -11,11 +11,11 @@ from torchvision.models import resnet18 import colossalai -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.initialize import launch +from colossalai.legacy.utils import get_dataloader, print_rank_0 from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_dataloader, print_rank_0 BATCH_SIZE = 8 @@ -63,7 +63,7 @@ def forward(self, x): optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0) # initialize - engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion, train_dataloader) + engine, train_dataloader, _, _ = colossalai.legacy.initialize(model, optimizer, 
criterion, train_dataloader) # build pipeline schedule schedule = engine.schedule diff --git a/tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py b/tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py index dab0e53a4c32..4a240533474c 100644 --- a/tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py +++ b/tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py @@ -2,7 +2,7 @@ import torch import colossalai -from colossalai.amp.amp_type import AMP_TYPE +from colossalai.legacy.amp.amp_type import AMP_TYPE from colossalai.legacy.trainer import Trainer from colossalai.logging import get_dist_logger from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn @@ -22,10 +22,10 @@ def run_trainer(model_name): model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() model = model_builder() optimizer = optimizer_class(model.parameters(), lr=1e-3) - engine, train_dataloader, *_ = colossalai.initialize(model=model, - optimizer=optimizer, - criterion=criterion, - train_dataloader=train_dataloader) + engine, train_dataloader, *_ = colossalai.legacy.initialize(model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader) logger = get_dist_logger() logger.info("engine is built", ranks=[0]) @@ -45,7 +45,12 @@ def run_trainer(model_name): def run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.legacy.launch(config=CONFIG, + rank=rank, + world_size=world_size, + host='localhost', + port=port, + backend='nccl') @pytest.mark.dist diff --git a/tests/test_legacy/test_trainer/test_trainer_with_pipe_schedule.py b/tests/test_legacy/test_trainer/test_trainer_with_pipe_schedule.py index 7dfbec854ccc..521b2f32f22d 100644 --- a/tests/test_legacy/test_trainer/test_trainer_with_pipe_schedule.py +++ b/tests/test_legacy/test_trainer/test_trainer_with_pipe_schedule.py @@ -10,12 +10,13 @@ from torchvision.models import resnet18 import colossalai -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.legacy.trainer import Trainer +from colossalai.legacy.utils import get_dataloader from colossalai.logging import get_dist_logger from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import MultiTimer, get_dataloader +from colossalai.utils import MultiTimer BATCH_SIZE = 4 IMG_SIZE = 32 @@ -28,7 +29,12 @@ def run_trainer_with_pipeline(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.legacy.launch(config=CONFIG, + rank=rank, + world_size=world_size, + host='localhost', + port=port, + backend='nccl') # build model model = resnet18(num_classes=10) @@ -63,10 +69,10 @@ def forward(self, x): optimizer = Adam(model.parameters(), lr=0.001) criterion = nn.CrossEntropyLoss() - engine, train_dataloader, *args = colossalai.initialize(model=model, - optimizer=optimizer, - criterion=criterion, - train_dataloader=train_dataloader) + engine, train_dataloader, *args = colossalai.legacy.initialize(model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader) logger = get_dist_logger() logger.info("engine is built", 
ranks=[0]) diff --git a/tests/test_utils/test_activation_checkpointing.py b/tests/test_legacy/test_utils/test_activation_checkpointing.py similarity index 94% rename from tests/test_utils/test_activation_checkpointing.py rename to tests/test_legacy/test_utils/test_activation_checkpointing.py index b7764c2f4371..19984ae120b5 100644 --- a/tests/test_utils/test_activation_checkpointing.py +++ b/tests/test_legacy/test_utils/test_activation_checkpointing.py @@ -5,10 +5,10 @@ import torch import torch.nn.functional as F -from colossalai.context.parallel_mode import ParallelMode -from colossalai.context.random import add_seed, reset_seeds, seed, set_mode +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.context.random import add_seed, reset_seeds, seed, set_mode +from colossalai.legacy.utils.activation_checkpoint import checkpoint from colossalai.testing import clear_cache_before_run, parameterize -from colossalai.utils.activation_checkpoint import checkpoint def forward(x, weight): diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_1d.py b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_1d.py similarity index 83% rename from tests/test_utils/test_checkpoint/test_checkpoint_1d.py rename to tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_1d.py index 9c3a7e2161d2..88cd89a217fe 100644 --- a/tests/test_utils/test_checkpoint/test_checkpoint_1d.py +++ b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_1d.py @@ -8,17 +8,17 @@ import torch.nn as nn import colossalai.legacy.nn as col_nn -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.initialize import launch +from colossalai.legacy.utils import is_using_pp +from colossalai.legacy.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint from colossalai.logging import disable_existing_loggers from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn -from colossalai.utils import is_using_pp -from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint def build_pipeline(model): - from colossalai.pipeline.utils import partition_uniform + from colossalai.legacy.pipeline.utils import partition_uniform pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_2d.py b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2d.py similarity index 83% rename from tests/test_utils/test_checkpoint/test_checkpoint_2d.py rename to tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2d.py index 03b2e4f2a9b2..591cd714fc65 100644 --- a/tests/test_utils/test_checkpoint/test_checkpoint_2d.py +++ b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2d.py @@ -8,17 +8,17 @@ import torch.nn as nn import colossalai.legacy.nn as col_nn -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.initialize import launch +from 
colossalai.legacy.utils import is_using_pp +from colossalai.legacy.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint from colossalai.logging import disable_existing_loggers from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn -from colossalai.utils import is_using_pp -from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint def build_pipeline(model): - from colossalai.pipeline.utils import partition_uniform + from colossalai.legacy.pipeline.utils import partition_uniform pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2p5d.py similarity index 84% rename from tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py rename to tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2p5d.py index cafffd0a6202..b165b4276f10 100644 --- a/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py +++ b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2p5d.py @@ -8,17 +8,17 @@ import torch.nn as nn import colossalai.legacy.nn as col_nn -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.initialize import launch +from colossalai.legacy.utils import is_using_pp +from colossalai.legacy.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint from colossalai.logging import disable_existing_loggers from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn -from colossalai.utils import is_using_pp -from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint def build_pipeline(model): - from colossalai.pipeline.utils import partition_uniform + from colossalai.legacy.pipeline.utils import partition_uniform pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_3d.py b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_3d.py similarity index 83% rename from tests/test_utils/test_checkpoint/test_checkpoint_3d.py rename to tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_3d.py index 9b43be9e8cc5..2ce054d33b2d 100644 --- a/tests/test_utils/test_checkpoint/test_checkpoint_3d.py +++ b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_3d.py @@ -8,17 +8,17 @@ import torch.nn as nn import colossalai.legacy.nn as col_nn -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.initialize import launch +from colossalai.legacy.utils import is_using_pp +from colossalai.legacy.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint from colossalai.logging import disable_existing_loggers from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, 
spawn -from colossalai.utils import is_using_pp -from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint def build_pipeline(model): - from colossalai.pipeline.utils import partition_uniform + from colossalai.legacy.pipeline.utils import partition_uniform pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) diff --git a/tests/test_utils/test_memory.py b/tests/test_legacy/test_utils/test_memory.py similarity index 76% rename from tests/test_utils/test_memory.py rename to tests/test_legacy/test_utils/test_memory.py index c88c2f8ec3c5..2e25dc773b68 100644 --- a/tests/test_utils/test_memory.py +++ b/tests/test_legacy/test_utils/test_memory.py @@ -1,9 +1,9 @@ import pytest import colossalai +from colossalai.legacy.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction from colossalai.testing import spawn from colossalai.utils.cuda import get_current_device -from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity(): @@ -14,7 +14,7 @@ def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity(): def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity() diff --git a/tests/test_utils/test_norm_gradient_clipping.py b/tests/test_legacy/test_utils/test_norm_gradient_clipping.py similarity index 91% rename from tests/test_utils/test_norm_gradient_clipping.py rename to tests/test_legacy/test_utils/test_norm_gradient_clipping.py index 4fd7c3c60a95..918f174aba76 100644 --- a/tests/test_utils/test_norm_gradient_clipping.py +++ b/tests/test_legacy/test_utils/test_norm_gradient_clipping.py @@ -4,12 +4,12 @@ from torch.nn.utils import clip_grad_norm_ import colossalai +from colossalai.legacy.tensor import ColoTensorSpec, ProcessGroup, distspec +from colossalai.legacy.utils.common import clip_grad_norm from colossalai.logging import disable_existing_loggers -from colossalai.tensor import ColoTensorSpec, ProcessGroup, distspec from colossalai.tensor.colo_parameter import ColoParameter from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device -from colossalai.utils.common import clip_grad_norm def close(num: float, other: float, rtol: float = 1e-5, atol: float = 1e-8): @@ -62,7 +62,7 @@ def run_grad_clip_norm(world_size: int, dtype: torch.dtype, device: str, norm_ty def run_dist(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run_grad_clip_norm(world_size=world_size) diff --git a/tests/test_utils/test_commons.py b/tests/test_legacy/test_zero/test_commons.py similarity index 82% rename from tests/test_utils/test_commons.py rename to tests/test_legacy/test_zero/test_commons.py index 2633d7da21aa..42a9f1eecb95 100644 --- a/tests/test_utils/test_commons.py +++ b/tests/test_legacy/test_zero/test_commons.py @@ -1,13 +1,13 @@ import torch import colossalai +from 
colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline +from colossalai.legacy.zero.sharded_param import ShardedTensor from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline -from colossalai.zero.legacy.sharded_param import ShardedTensor def run_tensor_move(rank, world_size, port): - colossalai.launch(config={}, rank=0, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.legacy.launch(config={}, rank=0, world_size=world_size, host='localhost', port=port, backend='nccl') src_t = torch.ones(2, 3).cuda() tgt_t = torch.zeros(2, 3) diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py index 39603c158731..c096b6075005 100644 --- a/tests/test_moe/test_kernel.py +++ b/tests/test_moe/test_kernel.py @@ -3,9 +3,9 @@ import torch.nn as nn import colossalai -from colossalai.context import ParallelMode from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.core import global_context as gpc +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.nn.layer.moe import Experts, MoeLayer, Top1Router, Top2Router from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py index a43ae764dccd..35fde6f10f3f 100644 --- a/tests/test_moe/test_moe_zero_optim.py +++ b/tests/test_moe/test_moe_zero_optim.py @@ -2,8 +2,8 @@ import torch import colossalai -from colossalai.amp import convert_to_apex_amp from colossalai.context import MOE_CONTEXT +from colossalai.legacy.amp import convert_to_apex_amp from colossalai.legacy.engine.gradient_handler import MoeGradientHandler from colossalai.nn import MoeLoss from colossalai.nn.optimizer import CPUAdam diff --git a/tests/test_tensor/test_comm_spec_apply.py b/tests/test_tensor/test_comm_spec_apply.py index 2c68633aabc8..4a3199c1c53d 100644 --- a/tests/test_tensor/test_comm_spec_apply.py +++ b/tests/test_tensor/test_comm_spec_apply.py @@ -1,7 +1,7 @@ import pytest import torch +import torch.distributed as dist -from colossalai.core import global_context as gpc from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers @@ -184,7 +184,7 @@ def check_comm(rank, world_size, port): launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') physical_mesh_id = torch.arange(0, 4) - assert rank == gpc.get_global_rank() + assert rank == dist.get_rank() mesh_shape = (2, 2) # [[0, 1, @@ -205,7 +205,6 @@ def check_comm(rank, world_size, port): # test all reduce in 1D flatten device mesh check_all_reduce_in_flatten_device_mesh(device_mesh, rank) - gpc.destroy() @pytest.mark.dist diff --git a/tests/test_tensor/test_dtensor/test_comm_spec.py b/tests/test_tensor/test_dtensor/test_comm_spec.py index 95fcd2aaf8f3..a1ea2946e6e7 100644 --- a/tests/test_tensor/test_dtensor/test_comm_spec.py +++ b/tests/test_tensor/test_dtensor/test_comm_spec.py @@ -1,7 +1,7 @@ import pytest import torch +import torch.distributed as dist -from colossalai.core import global_context as gpc from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import 
disable_existing_loggers @@ -127,7 +127,7 @@ def check_comm(rank, world_size, port): launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') physical_mesh_id = torch.arange(0, 4) - assert rank == gpc.get_global_rank() + assert rank == dist.get_rank() mesh_shape = (2, 2) # [[0, 1, @@ -149,8 +149,6 @@ def check_comm(rank, world_size, port): check_all_reduce_fwd(process_group_dict, rank) check_all_reduce_bwd(process_group_dict, rank) - gpc.destroy() - @pytest.mark.dist @rerun_if_address_is_in_use() diff --git a/tests/test_tensor/test_mix_gather.py b/tests/test_tensor/test_mix_gather.py index 9122808eb5a3..bd71bffccc70 100644 --- a/tests/test_tensor/test_mix_gather.py +++ b/tests/test_tensor/test_mix_gather.py @@ -1,7 +1,7 @@ import pytest import torch +import torch.distributed as dist -from colossalai.core import global_context as gpc from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers @@ -295,7 +295,7 @@ def check_comm(rank, world_size, port): launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') physical_mesh_id = torch.arange(0, 8) - assert rank == gpc.get_global_rank() + assert rank == dist.get_rank() mesh_shape = (2, 4) # [[0, 1, 2, 3], diff --git a/tests/test_utils/test_zero_gradient_clippling.py b/tests/test_utils/test_zero_gradient_clippling.py deleted file mode 100644 index e99cf388e929..000000000000 --- a/tests/test_utils/test_zero_gradient_clippling.py +++ /dev/null @@ -1,111 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from functools import partial - -import pytest -import torch -import torch.distributed as dist -import torch.nn as nn -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils import clip_grad_norm_ - -import colossalai -from colossalai.logging import disable_existing_loggers -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import checkpoint, clip_grad_norm_fp32 -from colossalai.zero.legacy.shard_utils.tensor_shard_strategy import TensorShardStrategy -from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2 - - -def checkpoint_wrapper(module, enable=True): - if enable: - module.forward = partial(checkpoint, module.forward, False) - return module - - -class Net(nn.Module): - - def __init__(self, checkpoint=False) -> None: - super().__init__() - self.fc1 = nn.Linear(5, 5) - self.fc2 = nn.Linear(5, 5) - self.fc3 = nn.Linear(5, 1) - if checkpoint: - self.fc1 = checkpoint_wrapper(self.fc1) - self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3] - - def forward(self, x): - for layer in self.layers: - x = layer(x) - return x - - -def run_step(model, optimizer, x, enable_autocast=False, norm_type=2.0): - model.train() - optimizer.zero_grad() - with torch.cuda.amp.autocast(enabled=enable_autocast): - y = model(x) - loss = y.sum() - loss = loss.float() - loss.backward() - clip_grad(model, norm_type) - optimizer.step() - - -def clip_grad(model, norm_type): - if isinstance(model, DDP): - clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=norm_type) - else: - clip_grad_norm_fp32(model.parameters(), max_norm=1.0, norm_type=norm_type) - - -def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool: - if loose: - return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3) - return torch.allclose(tensor_a, tensor_b) - - -def check_grads(model, zero_model, 
loose=False): - rank = dist.get_rank() - for p, zero_p in zip(model.parameters(), zero_model.parameters()): - zero_grad = zero_p.grad.clone().to(p.device) - chunks = torch.flatten(p.grad).chunk(4) - if rank >= len(chunks): - continue - grad = chunks[rank] - if zero_p.zero_shard_padding > 0: - zero_grad = zero_grad[:-zero_p.zero_shard_padding] - assert grad.dtype == zero_grad.dtype - assert allclose(grad, zero_grad, loose=loose) - - -def check_params(model, zero_model, loose=False): - rank = dist.get_rank() - for p, zero_p in zip(model.parameters(), zero_model.parameters()): - zero_shard_padding = zero_p.zero_shard_padding - zero_p = zero_p.clone().to(p.device) - chunks = torch.flatten(p).chunk(4) - if rank >= len(chunks): - continue - p = chunks[rank] - if zero_shard_padding > 0: - zero_p = zero_p[:-zero_shard_padding] - assert p.dtype == zero_p.dtype - assert allclose(p, zero_p, loose=loose) - - -def run_dist(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_zero_clip_grad(): - world_size = 4 - spawn(run_dist, world_size) - - -if __name__ == '__main__': - test_zero_clip_grad() diff --git a/tests/test_zero/test_gemini/test_chunk_mgrv2.py b/tests/test_zero/test_gemini/test_chunk_mgrv2.py index d6c4f8bd8aac..f05ccfdbd41b 100644 --- a/tests/test_zero/test_gemini/test_chunk_mgrv2.py +++ b/tests/test_zero/test_gemini/test_chunk_mgrv2.py @@ -6,7 +6,6 @@ from colossalai.tensor import ColoTensor from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.zero.gemini.chunk import ChunkManager -from tests.test_tensor.common_utils import debug_print CUDA_MEM_0 = {False: 512, True: 1024} CUDA_MEM_1 = {False: 0, True: 1024} @@ -16,7 +15,6 @@ @parameterize('keep_gathered', [True, False]) @parameterize('pin_memory', [True, False]) def exam_chunk_memory(keep_gathered, pin_memory): - debug_print([0], "keep_gathered: {}, pin_memory: {}".format(keep_gathered, pin_memory)) params = [ColoTensor(torch.rand(8, 8)) for _ in range(3)] config = {2: dict(chunk_size=128, keep_gathered=keep_gathered)} diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py index 4cbf564ecfb9..fabdd6072c31 100644 --- a/tests/test_zero/test_gemini/test_fwd_bwd.py +++ b/tests/test_zero/test_gemini/test_fwd_bwd.py @@ -5,15 +5,15 @@ from torch.testing import assert_close import colossalai -from colossalai.amp import convert_to_apex_amp +from colossalai.legacy.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import set_seed from colossalai.utils.cuda import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_tensor.common_utils import set_seed PLACEMENT_CONFIGS = [ { diff --git a/tests/test_zero/test_gemini/test_gemini_use_rmt.py b/tests/test_zero/test_gemini/test_gemini_use_rmt.py index a80a2f62de22..614a96ccdbcd 100644 --- a/tests/test_zero/test_gemini/test_gemini_use_rmt.py +++ b/tests/test_zero/test_gemini/test_gemini_use_rmt.py @@ -4,12 +4,12 @@ import colossalai from colossalai.testing import parameterize, 
rerun_if_address_is_in_use, spawn +from colossalai.utils import set_seed from colossalai.zero import GeminiDDP from colossalai.zero.gemini.chunk import search_chunk_configuration from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_tensor.common_utils import set_seed # run gemini use the runtime memory tracer diff --git a/tests/test_zero/test_gemini/test_grad_clip.py b/tests/test_zero/test_gemini/test_grad_clip.py index 82b9133b89c1..860d6efa899a 100644 --- a/tests/test_zero/test_gemini/test_grad_clip.py +++ b/tests/test_zero/test_gemini/test_grad_clip.py @@ -5,14 +5,14 @@ from torch.testing import assert_close import colossalai -from colossalai.amp import convert_to_apex_amp +from colossalai.legacy.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import set_seed from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_tensor.common_utils import set_seed PLACEMENT_CONFIGS = [ { diff --git a/tests/test_zero/test_gemini/test_inference.py b/tests/test_zero/test_gemini/test_inference.py index 20d145f9661f..99ee08c1d7e7 100644 --- a/tests/test_zero/test_gemini/test_inference.py +++ b/tests/test_zero/test_gemini/test_inference.py @@ -7,15 +7,15 @@ from torch.testing import assert_close import colossalai -from colossalai.amp import convert_to_apex_amp +from colossalai.legacy.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import set_seed from colossalai.utils.cuda import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_tensor.common_utils import set_seed PLACEMENT_CONFIGS = [ { diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index edcbada0acbb..3454959199d2 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -5,15 +5,15 @@ from torch.testing import assert_close import colossalai -from colossalai.amp import convert_to_apex_amp +from colossalai.legacy.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import set_seed from colossalai.utils.cuda import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_tensor.common_utils import set_seed PLACEMENT_CONFIGS = [ { diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py index 656bd709e2a1..602e3ad3519d 100644 --- 
a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py +++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py @@ -4,10 +4,10 @@ import colossalai from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import set_seed from colossalai.zero import GeminiDDP from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_tensor.common_utils import set_seed PLACEMENT_CONFIGS = [ { diff --git a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py index 09725e11ec0c..5f7b51510d58 100644 --- a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py +++ b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py @@ -5,10 +5,10 @@ import colossalai from colossalai.nn.optimizer import HybridAdam from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import set_seed from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_tensor.common_utils import set_seed PLACEMENT_CONFIGS = [ { diff --git a/tests/test_zero/test_low_level/test_zero_tp.py b/tests/test_zero/test_low_level/test_zero_tp.py deleted file mode 100644 index 4a2b49f63b7e..000000000000 --- a/tests/test_zero/test_low_level/test_zero_tp.py +++ /dev/null @@ -1,96 +0,0 @@ -import pytest -import torch -import torch.nn as nn -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.testing import assert_close - -import colossalai -from colossalai.tensor import ProcessGroup -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device -from colossalai.zero import ColoInitContext, LowLevelZeroOptimizer -from tests.test_tensor.common_utils import set_seed, split_param_col_tp1d, split_param_row_tp1d, tensor_shard_equal - - -def strict_shard_equal(tensor, shard, tp_pg, rtol=1e-3, atol=1e-4): - return tensor_shard_equal(tensor, shard, tp_pg.tp_local_rank(), tp_pg.tp_world_size(), rtol, atol) - - -class MlpModel(nn.Module): - - def __init__(self): - super(MlpModel, self).__init__() - self.linear1 = nn.Linear(32, 128) - self.act = nn.GELU() - self.linear2 = nn.Linear(128, 32) - - def forward(self, x): - y = self.linear1(x) - y = self.act(y) - y = self.linear2(y) - return x + y - - -@parameterize("overlap_flag", [False, True]) -@parameterize("partition_flag", [False, True]) -def exam_zero_with_tp(overlap_flag, partition_flag): - set_seed(233010) - tp_pg = ProcessGroup(tp_degree=2) - - with ColoInitContext(device=get_current_device(), default_pg=tp_pg): - hybrid_model = MlpModel() - torch_model = MlpModel().cuda() - for pt, ph in zip(torch_model.parameters(), hybrid_model.parameters()): - pt.data.copy_(ph.data) - - for name, param in hybrid_model.named_parameters(): - if 'linear1' in name: - split_param_row_tp1d(param, tp_pg) - param.compute_spec.set_output_replicate(False) - if 'linear2.weight' in name: - split_param_col_tp1d(param, tp_pg) - - torch_model = DDP(torch_model, device_ids=[tp_pg.rank()], process_group=tp_pg.dp_process_group()) - torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-2) # set to 1e-2 for torch-1.11 - hybrid_optim = torch.optim.Adam(hybrid_model.parameters(), lr=1e-2) - hybrid_optim = LowLevelZeroOptimizer(hybrid_optim, - 
initial_scale=2, - clip_grad_norm=1.0, - overlap_communication=overlap_flag, - partition_grad=partition_flag, - dp_process_group=tp_pg.dp_process_group(), - tp_process_group=tp_pg.tp_process_group()) - - dp_local_rank = tp_pg.dp_local_rank() - set_seed(255 + dp_local_rank) - - data = torch.randn(8, 32, device=get_current_device()) - torch_loss = torch_model(data).sum() - hybrid_loss = hybrid_model(data).sum() - assert_close(torch_loss, hybrid_loss) - - torch_loss.backward() - torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0) - hybrid_optim.backward(hybrid_loss) - - torch_optim.step() - hybrid_optim.step() - - for (name, pt), ph in zip(torch_model.named_parameters(), hybrid_model.parameters()): - assert strict_shard_equal(pt.data, ph.data, tp_pg) - - -def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') - exam_zero_with_tp() - - -@pytest.mark.skip('this will be rewritten by shardformer') -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_zero_with_tp(): - spawn(run_dist, 4) - - -if __name__ == '__main__': - test_zero_with_tp() From 3c6b831c264d0657a97034b5cf036c913a762083 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Mon, 18 Sep 2023 16:52:42 +0800 Subject: [PATCH 23/58] [format] applied code formatting on changed files in pull request 4743 (#4750) Co-authored-by: github-actions --- .../language/gpt/experiments/pipeline_parallel/train_gpt_pp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py b/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py index 30d6aab4f12f..749243e57836 100644 --- a/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py +++ b/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py @@ -3,6 +3,7 @@ from functools import partial import torch +from model_zoo import model_builder from torch import nn from tqdm import tqdm @@ -18,7 +19,6 @@ from colossalai.legacy.pipeline.rpc.utils import rpc_run from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam -from model_zoo import model_builder def parse_args(): From 079bf3cb26a502fc647b1aad15fd14d6266be66c Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Tue, 19 Sep 2023 14:20:26 +0800 Subject: [PATCH 24/58] [misc] update pre-commit and run all files (#4752) * [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format --- .flake8 | 22 - .github/workflows/scripts/check_doc_i18n.py | 12 +- .../example_checks/check_dispatch_inputs.py | 10 +- .../example_checks/check_example_weekly.py | 10 +- .../example_checks/detect_changed_example.py | 6 +- .../generate_leaderboard_and_send_to_lark.py | 137 +- .../scripts/generate_release_draft.py | 51 +- .../workflows/scripts/send_message_to_lark.py | 6 +- .isort.cfg | 1 + .pre-commit-config.yaml | 18 +- .style.yapf | 5 - .../benchmarks/benchmark_opt_lora_dummy.py | 212 +- .../Chat/benchmarks/ray/1mmt_dummy.py | 116 +- .../Chat/benchmarks/ray/mmmt_dummy.py | 128 +- applications/Chat/coati/dataset/__init__.py | 9 +- .../Chat/coati/dataset/conversation.py | 18 +- .../Chat/coati/dataset/prompt_dataset.py | 20 +- .../Chat/coati/dataset/reward_dataset.py | 102 +- .../Chat/coati/dataset/sft_dataset.py | 127 +- .../Chat/coati/experience_buffer/__init__.py | 2 +- 
.../Chat/coati/experience_buffer/base.py | 6 +- .../Chat/coati/experience_buffer/naive.py | 12 +- .../Chat/coati/experience_buffer/utils.py | 15 +- .../Chat/coati/experience_maker/__init__.py | 2 +- .../Chat/coati/experience_maker/base.py | 11 +- .../Chat/coati/experience_maker/naive.py | 11 +- applications/Chat/coati/kernels/__init__.py | 4 +- applications/Chat/coati/kernels/opt_attn.py | 23 +- applications/Chat/coati/models/__init__.py | 11 +- .../Chat/coati/models/base/__init__.py | 9 +- applications/Chat/coati/models/base/actor.py | 13 +- applications/Chat/coati/models/base/critic.py | 15 +- .../Chat/coati/models/base/reward_model.py | 16 +- .../Chat/coati/models/bloom/__init__.py | 2 +- .../Chat/coati/models/bloom/bloom_actor.py | 17 +- .../Chat/coati/models/bloom/bloom_critic.py | 17 +- .../Chat/coati/models/bloom/bloom_rm.py | 14 +- .../Chat/coati/models/chatglm/__init__.py | 2 +- .../coati/models/chatglm/chatglm_actor.py | 13 +- .../coati/models/chatglm/chatglm_tokenizer.py | 118 +- .../models/chatglm/configuration_chatglm.py | 48 +- .../coati/models/chatglm/modeling_chatglm.py | 406 +- applications/Chat/coati/models/generation.py | 101 +- .../Chat/coati/models/gpt/__init__.py | 2 +- .../Chat/coati/models/gpt/gpt_actor.py | 16 +- .../Chat/coati/models/gpt/gpt_critic.py | 14 +- applications/Chat/coati/models/gpt/gpt_rm.py | 12 +- .../Chat/coati/models/llama/__init__.py | 2 +- .../Chat/coati/models/llama/llama_actor.py | 18 +- .../Chat/coati/models/llama/llama_critic.py | 15 +- .../Chat/coati/models/llama/llama_rm.py | 15 +- applications/Chat/coati/models/lora.py | 32 +- applications/Chat/coati/models/loss.py | 28 +- .../Chat/coati/models/opt/__init__.py | 2 +- .../Chat/coati/models/opt/opt_actor.py | 14 +- .../Chat/coati/models/opt/opt_critic.py | 14 +- applications/Chat/coati/models/opt/opt_rm.py | 12 +- applications/Chat/coati/models/utils.py | 20 +- applications/Chat/coati/quant/__init__.py | 4 +- .../Chat/coati/quant/llama_gptq/__init__.py | 2 +- .../Chat/coati/quant/llama_gptq/loader.py | 5 +- .../coati/quant/llama_gptq/model_utils.py | 5 +- .../Chat/coati/quant/llama_gptq/quant.py | 36 +- applications/Chat/coati/quant/utils.py | 3 +- applications/Chat/coati/ray/callbacks/base.py | 3 +- .../ray/callbacks/performance_evaluator.py | 58 +- .../Chat/coati/ray/detached_replay_buffer.py | 25 +- .../Chat/coati/ray/detached_trainer_base.py | 38 +- .../Chat/coati/ray/detached_trainer_ppo.py | 85 +- .../Chat/coati/ray/experience_maker_holder.py | 113 +- .../Chat/coati/ray/lora_constructor.py | 53 +- applications/Chat/coati/ray/utils.py | 80 +- applications/Chat/coati/trainer/__init__.py | 6 +- applications/Chat/coati/trainer/base.py | 14 +- .../Chat/coati/trainer/callbacks/__init__.py | 2 +- .../Chat/coati/trainer/callbacks/base.py | 2 +- .../callbacks/performance_evaluator.py | 48 +- .../trainer/callbacks/save_checkpoint.py | 29 +- applications/Chat/coati/trainer/ppo.py | 101 +- applications/Chat/coati/trainer/rm.py | 18 +- applications/Chat/coati/trainer/sft.py | 51 +- .../Chat/coati/trainer/strategies/__init__.py | 5 +- .../Chat/coati/trainer/strategies/base.py | 24 +- .../coati/trainer/strategies/colossalai.py | 113 +- .../Chat/coati/trainer/strategies/ddp.py | 62 +- .../Chat/coati/trainer/strategies/sampler.py | 7 +- applications/Chat/coati/trainer/utils.py | 1 - .../Chat/evaluate/config/config_cn.json | 24 +- applications/Chat/evaluate/eval.py | 88 +- applications/Chat/evaluate/evaluator.py | 46 +- applications/Chat/evaluate/gpt_evaluate.py | 136 +- 
applications/Chat/evaluate/metrics.py | 39 +- .../Chat/evaluate/unieval/__init__.py | 7 +- .../Chat/evaluate/unieval/evaluator.py | 234 +- applications/Chat/evaluate/unieval/scorer.py | 47 +- applications/Chat/evaluate/unieval/utils.py | 165 +- applications/Chat/evaluate/utils.py | 9 +- .../examples/community/peft/easy_dataset.py | 108 +- .../examples/community/peft/easy_models.py | 39 +- .../community/peft/train_peft_prompts.py | 156 +- .../examples/community/peft/train_peft_sft.py | 171 +- .../examples/community/ray/ray_job_script.py | 25 +- .../community/ray/train_prompts_on_ray.py | 250 +- applications/Chat/examples/download_model.py | 19 +- .../examples/generate_conversation_dataset.py | 33 +- .../Chat/examples/generate_prompt_dataset.py | 18 +- applications/Chat/examples/inference.py | 51 +- applications/Chat/examples/ray/1mmt_prompt.py | 102 +- applications/Chat/examples/ray/mmmt_prompt.py | 116 +- applications/Chat/examples/requirements.txt | 2 +- applications/Chat/examples/train_prompts.py | 174 +- .../Chat/examples/train_reward_model.py | 205 +- applications/Chat/examples/train_sft.py | 233 +- applications/Chat/inference/benchmark.py | 48 +- applications/Chat/inference/locustfile.py | 30 +- applications/Chat/inference/server.py | 99 +- .../Chat/inference/tests/test_chat_prompt.py | 70 +- applications/Chat/inference/utils.py | 91 +- applications/Chat/requirements-test.txt | 2 +- applications/Chat/setup.py | 42 +- applications/Chat/tests/test_checkpoint.py | 29 +- applications/Chat/tests/test_dataset.py | 115 +- applications/Chat/tests/test_experience.py | 44 +- applications/Chat/tests/test_models.py | 153 +- colossalai/__init__.py | 6 +- .../_subclasses/_meta_registration.py | 161 +- .../_analyzer/_subclasses/_monkey_patch.py | 3 +- .../_analyzer/_subclasses/flop_tensor.py | 77 +- .../_analyzer/_subclasses/meta_tensor.py | 46 +- colossalai/_analyzer/fx/codegen.py | 181 +- colossalai/_analyzer/fx/graph_module.py | 54 +- colossalai/_analyzer/fx/node_util.py | 54 +- .../_analyzer/fx/passes/graph_profile.py | 104 +- colossalai/_analyzer/fx/passes/shape_prop.py | 36 +- colossalai/_analyzer/fx/symbolic_profile.py | 4 - .../_analyzer/fx/tracer/bias_addition.py | 190 +- .../_analyzer/fx/tracer/custom_leaf_module.py | 1 + colossalai/_analyzer/fx/tracer/proxy.py | 15 +- .../_analyzer/fx/tracer/symbolic_trace.py | 12 +- colossalai/_analyzer/fx/tracer/tracer.py | 106 +- .../amp/naive_amp/grad_scaler/__init__.py | 2 +- .../naive_amp/grad_scaler/base_grad_scaler.py | 17 +- .../grad_scaler/constant_grad_scaler.py | 3 +- .../grad_scaler/dynamic_grad_scaler.py | 64 +- .../mixed_precision_mixin/__init__.py | 6 +- .../naive_amp/mixed_precision_mixin/base.py | 9 +- .../naive_amp/mixed_precision_mixin/fp16.py | 37 +- .../naive_amp/mixed_precision_optimizer.py | 101 +- .../auto_parallel/checkpoint/build_c_ext.py | 16 +- .../checkpoint/ckpt_solver_base.py | 35 +- .../checkpoint/ckpt_solver_chen.py | 7 +- .../checkpoint/ckpt_solver_rotor.py | 118 +- .../auto_parallel/checkpoint/operation.py | 39 +- .../auto_parallel/meta_profiler/constants.py | 2 - .../meta_profiler/meta_registry/activation.py | 43 +- .../meta_registry/binary_elementwise_ops.py | 6 +- .../meta_profiler/meta_registry/conv.py | 62 +- .../meta_profiler/meta_registry/embedding.py | 12 +- .../meta_profiler/meta_registry/linear.py | 307 +- .../meta_profiler/meta_registry/non_spmd.py | 2 +- .../meta_profiler/meta_registry/norm.py | 102 +- .../meta_profiler/meta_registry/pooling.py | 14 +- .../meta_profiler/meta_registry/tensor.py | 43 +- 
.../meta_profiler/meta_registry/where.py | 27 +- .../auto_parallel/meta_profiler/registry.py | 8 +- .../meta_profiler/shard_metainfo.py | 32 +- .../auto_parallel/offload/amp_optimizer.py | 60 +- .../offload/base_offload_module.py | 12 +- .../auto_parallel/offload/mem_optimize.py | 14 +- colossalai/auto_parallel/offload/region.py | 10 +- .../auto_parallel/offload/region_manager.py | 137 +- colossalai/auto_parallel/offload/runtime.py | 68 +- colossalai/auto_parallel/offload/solver.py | 105 +- .../offload/training_simulator.py | 130 +- colossalai/auto_parallel/offload/util.py | 22 +- .../passes/comm_metainfo_pass.py | 57 +- .../auto_parallel/passes/meta_info_prop.py | 22 +- .../passes/runtime_apply_pass.py | 129 +- .../passes/runtime_preparation_pass.py | 167 +- .../auto_parallel/tensor_shard/constants.py | 56 +- .../auto_parallel/tensor_shard/initialize.py | 256 +- .../tensor_shard/node_handler/__init__.py | 36 +- .../node_handler/addmm_handler.py | 46 +- .../node_handler/batch_norm_handler.py | 59 +- .../binary_elementwise_handler.py | 48 +- .../tensor_shard/node_handler/bmm_handler.py | 44 +- .../tensor_shard/node_handler/conv_handler.py | 68 +- .../node_handler/default_reshape_handler.py | 18 +- .../node_handler/embedding_handler.py | 104 +- .../node_handler/getattr_handler.py | 2 +- .../node_handler/getitem_handler.py | 8 +- .../node_handler/layer_norm_handler.py | 30 +- .../node_handler/linear_handler.py | 175 +- .../node_handler/matmul_handler.py | 134 +- .../tensor_shard/node_handler/node_handler.py | 89 +- .../node_handler/normal_pooling_handler.py | 10 +- .../node_handler/output_handler.py | 11 +- .../node_handler/permute_handler.py | 14 +- .../node_handler/placeholder_handler.py | 10 +- .../tensor_shard/node_handler/registry.py | 6 +- .../node_handler/softmax_handler.py | 8 +- .../node_handler/split_handler.py | 8 +- .../node_handler/strategy/__init__.py | 34 +- .../strategy/batch_norm_generator.py | 176 +- .../strategy/binary_elementwise_generator.py | 40 +- .../strategy/conv_strategy_generator.py | 231 +- .../strategy/embedding_generator.py | 116 +- .../strategy/getattr_generator.py | 13 +- .../strategy/getitem_generator.py | 54 +- .../strategy/layer_norm_generator.py | 57 +- .../strategy/matmul_strategy_generator.py | 688 +- .../strategy/normal_pooling_generator.py | 35 +- .../node_handler/strategy/output_generator.py | 51 +- .../strategy/placeholder_generator.py | 43 +- .../strategy/reshape_generator.py | 97 +- .../strategy/softmax_generator.py | 51 +- .../strategy/strategy_generator.py | 106 +- .../node_handler/strategy/sum_generator.py | 54 +- .../strategy/tensor_constructor_generator.py | 29 +- .../strategy/unary_elementwise_generator.py | 23 +- .../node_handler/strategy/where_generator.py | 28 +- .../tensor_shard/node_handler/sum_handler.py | 8 +- .../tensor_constructor_handler.py | 2 +- .../node_handler/transpose_handler.py | 10 +- .../node_handler/unary_elementwise_handler.py | 10 +- .../tensor_shard/node_handler/view_handler.py | 6 +- .../node_handler/where_handler.py | 33 +- .../auto_parallel/tensor_shard/options.py | 6 +- .../tensor_shard/sharding_strategy.py | 46 +- .../tensor_shard/solver/__init__.py | 2 +- .../tensor_shard/solver/cost_graph.py | 18 +- .../tensor_shard/solver/graph_analysis.py | 25 +- .../tensor_shard/solver/solver.py | 128 +- .../solver/strategies_constructor.py | 118 +- .../tensor_shard/utils/__init__.py | 22 +- .../tensor_shard/utils/broadcast.py | 42 +- .../tensor_shard/utils/factory.py | 57 +- .../auto_parallel/tensor_shard/utils/misc.py | 27 +- 
.../tensor_shard/utils/reshape.py | 16 +- .../tensor_shard/utils/sharding.py | 23 +- colossalai/autochunk/autochunk_codegen.py | 162 +- colossalai/autochunk/estimate_memory.py | 24 +- colossalai/autochunk/search_chunk.py | 39 +- colossalai/autochunk/select_chunk.py | 70 +- colossalai/autochunk/trace_flow.py | 36 +- colossalai/autochunk/trace_indice.py | 30 +- colossalai/autochunk/utils.py | 24 +- colossalai/booster/accelerator.py | 18 +- colossalai/booster/booster.py | 103 +- .../booster/mixed_precision/__init__.py | 22 +- .../booster/mixed_precision/fp16_apex.py | 26 +- .../booster/mixed_precision/fp16_naive.py | 18 +- .../booster/mixed_precision/fp16_torch.py | 77 +- colossalai/booster/plugin/__init__.py | 7 +- colossalai/booster/plugin/dp_plugin_base.py | 38 +- colossalai/booster/plugin/gemini_plugin.py | 138 +- .../booster/plugin/hybrid_parallel_plugin.py | 428 +- .../booster/plugin/low_level_zero_plugin.py | 109 +- colossalai/booster/plugin/plugin_base.py | 29 +- colossalai/booster/plugin/pp_plugin_base.py | 17 +- colossalai/booster/plugin/torch_ddp_plugin.py | 75 +- .../booster/plugin/torch_fsdp_plugin.py | 69 +- colossalai/checkpoint_io/__init__.py | 2 +- .../checkpoint_io/checkpoint_io_base.py | 69 +- .../checkpoint_io/general_checkpoint_io.py | 100 +- .../hybrid_parallel_checkpoint_io.py | 297 +- colossalai/checkpoint_io/index_file.py | 6 +- colossalai/checkpoint_io/utils.py | 165 +- colossalai/cli/__init__.py | 2 +- colossalai/cli/check/__init__.py | 5 +- colossalai/cli/check/check_installation.py | 29 +- colossalai/cli/cli.py | 5 +- colossalai/cli/launcher/__init__.py | 99 +- colossalai/cli/launcher/hostinfo.py | 5 +- colossalai/cli/launcher/multinode_runner.py | 22 +- colossalai/cli/launcher/run.py | 70 +- colossalai/cluster/__init__.py | 2 +- colossalai/cluster/device_mesh_manager.py | 21 +- colossalai/cluster/dist_coordinator.py | 12 +- colossalai/cluster/process_group_manager.py | 8 +- colossalai/cluster/process_group_mesh.py | 23 +- colossalai/context/__init__.py | 4 +- colossalai/context/config.py | 11 +- colossalai/context/moe_context.py | 26 +- colossalai/context/singleton_meta.py | 5 +- colossalai/device/__init__.py | 2 +- colossalai/device/alpha_beta_profiler.py | 71 +- colossalai/device/calc_pipeline_strategy.py | 58 +- colossalai/device/device_mesh.py | 113 +- colossalai/fx/_compatibility.py | 8 +- colossalai/fx/_meta_regist_12.py | 144 +- .../codegen/activation_checkpoint_codegen.py | 368 +- colossalai/fx/graph_module.py | 56 +- .../fx/passes/adding_split_node_pass.py | 81 +- colossalai/fx/passes/concrete_info_prop.py | 82 +- .../adding_shape_consistency_pass.py | 63 +- colossalai/fx/passes/meta_info_prop.py | 97 +- colossalai/fx/passes/passes_for_gpt2_test.py | 83 +- colossalai/fx/passes/shard_1d_pass.py | 47 +- colossalai/fx/passes/split_module.py | 67 +- colossalai/fx/passes/utils.py | 29 +- colossalai/fx/profiler/__init__.py | 11 +- colossalai/fx/profiler/constants.py | 2 +- colossalai/fx/profiler/dataflow.py | 19 +- .../fx/profiler/experimental/constants.py | 34 +- .../fx/profiler/experimental/profiler.py | 27 +- .../profiler_function/activation_function.py | 2 + .../profiler_function/arithmetic.py | 36 +- .../profiler_function/embedding.py | 4 +- .../experimental/profiler_function/linear.py | 2 + .../profiler_function/normalization.py | 14 +- .../experimental/profiler_function/pooling.py | 4 +- .../profiler_function/python_ops.py | 2 +- .../profiler_function/torch_ops.py | 14 +- .../profiler_module/activation_function.py | 2 + 
.../experimental/profiler_module/attention.py | 40 +- .../profiler_module/convolution.py | 86 +- .../experimental/profiler_module/dropout.py | 2 + .../experimental/profiler_module/linear.py | 2 + .../profiler_module/normalization.py | 9 +- .../experimental/profiler_module/pooling.py | 2 + .../experimental/profiler_module/rnn.py | 27 +- .../experimental/profiler_module/torch_op.py | 5 +- .../fx/profiler/experimental/registry.py | 6 +- .../fx/profiler/experimental/shard_utils.py | 8 +- colossalai/fx/profiler/memory_utils.py | 5 +- colossalai/fx/profiler/opcount.py | 37 +- colossalai/fx/profiler/profiler.py | 52 +- colossalai/fx/profiler/shard_utils.py | 4 +- colossalai/fx/profiler/tensor.py | 31 +- colossalai/fx/proxy.py | 16 +- colossalai/fx/tracer/_meta_trace.py | 54 +- colossalai/fx/tracer/_tracer_utils.py | 7 +- .../patched_bias_addition_function/addbmm.py | 28 +- .../patched_bias_addition_function/addmm.py | 26 +- .../bias_addition_function.py | 8 +- .../patched_bias_addition_function/linear.py | 12 +- .../bias_addition_module.py | 14 +- .../patched_bias_addition_module/conv.py | 21 +- .../patched_bias_addition_module/linear.py | 2 - colossalai/fx/tracer/experimental.py | 229 +- .../patched_function/activation_function.py | 2 +- .../meta_patch/patched_function/arithmetic.py | 12 +- .../patched_function/convolution.py | 59 +- .../meta_patch/patched_function/embedding.py | 10 +- .../patched_function/normalization.py | 15 +- .../meta_patch/patched_function/python_ops.py | 4 +- .../meta_patch/patched_function/torch_ops.py | 25 +- .../meta_patch/patched_module/__init__.py | 2 +- .../patched_module/activation_function.py | 2 +- .../meta_patch/patched_module/convolution.py | 96 +- .../meta_patch/patched_module/embedding.py | 2 +- .../meta_patch/patched_module/linear.py | 4 +- .../patched_module/normalization.py | 1 + .../meta_patch/patched_module/pooling.py | 30 +- .../tracer/meta_patch/patched_module/rnn.py | 12 +- colossalai/fx/tracer/registry.py | 12 +- colossalai/fx/tracer/tracer.py | 71 +- .../inference/tensor_parallel/__init__.py | 2 +- .../tensor_parallel/batch_infer_state.py | 15 +- .../inference/tensor_parallel/engine.py | 77 +- .../tensor_parallel/kvcache_manager.py | 51 +- .../tensor_parallel/modeling/__init__.py | 2 +- .../tensor_parallel/modeling/bloom.py | 143 +- .../tensor_parallel/modeling/llama.py | 106 +- .../tensor_parallel/policies/__init__.py | 2 +- .../tensor_parallel/policies/bloom.py | 33 +- .../tensor_parallel/policies/llama.py | 28 +- colossalai/initialize.py | 145 +- colossalai/interface/__init__.py | 2 +- colossalai/interface/model.py | 4 +- colossalai/interface/optimizer.py | 22 +- colossalai/kernel/cuda_native/__init__.py | 8 +- colossalai/kernel/cuda_native/csrc/compat.h | 2 +- .../cuda_native/csrc/kernels/cuda_util.cu | 1 - .../csrc/kernels/dropout_kernels.cu | 2004 ++--- .../csrc/kernels/general_kernels.cu | 464 +- .../csrc/kernels/include/dropout.h | 192 +- .../csrc/kernels/include/kernels.h | 27 +- .../csrc/kernels/include/normalize_layer.h | 129 +- .../csrc/kernels/include/softmax.h | 84 +- .../csrc/kernels/normalize_kernels.cu | 2341 +++--- .../csrc/kernels/softmax_kernels.cu | 730 +- .../csrc/kernels/transform_kernels.cu | 626 +- .../cuda_native/csrc/layer_norm_cuda.cpp | 2 +- .../csrc/layer_norm_cuda_kernel.cu | 2 +- .../kernel/cuda_native/csrc/moe_cuda.cpp | 194 +- .../cuda_native/csrc/moe_cuda_kernel.cu | 1318 ++-- .../csrc/multi_tensor_l2norm_kernel.cu | 2 +- .../cuda_native/csrc/multi_tensor_lamb.cu | 2 +- .../csrc/multi_tensor_scale_kernel.cu | 2 +- 
.../csrc/multi_tensor_sgd_kernel.cu | 2 +- .../csrc/scaled_masked_softmax.cpp | 84 +- .../cuda_native/csrc/scaled_masked_softmax.h | 868 +- .../csrc/scaled_upper_triang_masked_softmax.h | 928 ++- colossalai/kernel/cuda_native/layer_norm.py | 15 +- colossalai/kernel/cuda_native/mha/__init__.py | 2 +- .../kernel/cuda_native/mha/flash_attn_2.py | 46 +- .../kernel/cuda_native/mha/mem_eff_attn.py | 35 +- colossalai/kernel/cuda_native/mha/mha.py | 70 +- colossalai/kernel/cuda_native/mha/utils.py | 8 +- .../kernel/cuda_native/multihead_attention.py | 169 +- .../kernel/cuda_native/scaled_softmax.py | 22 +- colossalai/kernel/jit/__init__.py | 10 +- colossalai/kernel/jit/bias_dropout_add.py | 15 +- colossalai/kernel/jit/bias_gelu.py | 1 - colossalai/kernel/jit/option.py | 21 +- colossalai/kernel/triton/__init__.py | 11 +- colossalai/kernel/triton/context_attention.py | 146 +- .../kernel/triton/copy_kv_cache_dest.py | 36 +- colossalai/kernel/triton/fused_layernorm.py | 35 +- colossalai/kernel/triton/qkv_matmul_kernel.py | 54 +- colossalai/kernel/triton/rms_norm.py | 21 +- .../kernel/triton/rotary_embedding_kernel.py | 46 +- .../kernel/triton/self_attention_nofusion.py | 29 +- colossalai/kernel/triton/softmax.py | 65 +- .../kernel/triton/token_attention_kernel.py | 211 +- colossalai/lazy/__init__.py | 4 +- colossalai/lazy/lazy_init.py | 189 +- colossalai/legacy/__init__.py | 10 +- colossalai/legacy/amp/__init__.py | 5 +- colossalai/legacy/amp/amp_type.py | 6 +- colossalai/legacy/amp/apex_amp/__init__.py | 3 +- colossalai/legacy/amp/apex_amp/apex_amp.py | 2 +- colossalai/legacy/amp/naive_amp/__init__.py | 4 +- .../legacy/amp/naive_amp/_fp16_optimizer.py | 75 +- colossalai/legacy/amp/naive_amp/_utils.py | 2 +- colossalai/legacy/amp/naive_amp/naive_amp.py | 20 +- colossalai/legacy/amp/torch_amp/__init__.py | 9 +- .../legacy/amp/torch_amp/_grad_scaler.py | 86 +- colossalai/legacy/amp/torch_amp/torch_amp.py | 3 +- colossalai/legacy/builder/__init__.py | 2 +- colossalai/legacy/builder/builder.py | 16 +- colossalai/legacy/communication/__init__.py | 34 +- colossalai/legacy/communication/collective.py | 32 +- colossalai/legacy/communication/p2p.py | 244 +- colossalai/legacy/communication/p2p_v2.py | 16 +- colossalai/legacy/communication/ring.py | 17 +- colossalai/legacy/communication/utils.py | 6 +- colossalai/legacy/constants.py | 40 +- colossalai/legacy/context/parallel_context.py | 98 +- colossalai/legacy/context/parallel_mode.py | 37 +- .../process_group_initializer/__init__.py | 12 +- .../initializer_1d.py | 2 +- .../initializer_2d.py | 16 +- .../initializer_2p5d.py | 46 +- .../initializer_3d.py | 24 +- .../initializer_data.py | 2 +- .../initializer_model.py | 2 +- .../initializer_pipeline.py | 18 +- .../initializer_sequence.py | 12 +- .../initializer_tensor.py | 2 +- .../process_group_initializer.py | 11 +- colossalai/legacy/context/random/__init__.py | 13 +- colossalai/legacy/context/random/_helper.py | 3 +- .../legacy/context/random/seed_manager.py | 6 +- colossalai/legacy/core.py | 2 +- colossalai/legacy/engine/__init__.py | 2 +- colossalai/legacy/engine/_base_engine.py | 42 +- .../engine/gradient_accumulation/__init__.py | 21 +- .../_gradient_accumulation.py | 5 +- .../engine/gradient_handler/__init__.py | 8 +- .../_base_gradient_handler.py | 1 - .../_data_parallel_gradient_handler.py | 3 +- .../gradient_handler/_moe_gradient_handler.py | 5 +- .../_pipeline_parallel_gradient_handler.py | 18 +- .../_sequence_parallel_gradient_handler.py | 3 +- .../_zero_gradient_handler.py | 3 +- 
colossalai/legacy/engine/schedule/__init__.py | 2 +- .../legacy/engine/schedule/_base_schedule.py | 39 +- .../engine/schedule/_non_pipeline_schedule.py | 28 +- .../engine/schedule/_pipeline_schedule.py | 316 +- .../engine/schedule/_pipeline_schedule_v2.py | 40 +- colossalai/legacy/global_variables.py | 54 +- colossalai/legacy/initialize.py | 338 +- colossalai/legacy/nn/_ops/_utils.py | 33 +- colossalai/legacy/nn/layer/base_layer.py | 56 +- .../nn/layer/colossalai_layer/__init__.py | 2 +- .../nn/layer/colossalai_layer/_utils.py | 5 +- .../nn/layer/colossalai_layer/dropout.py | 2 +- .../nn/layer/colossalai_layer/embedding.py | 53 +- .../nn/layer/colossalai_layer/linear.py | 63 +- .../legacy/nn/layer/parallel_1d/__init__.py | 12 +- .../legacy/nn/layer/parallel_1d/_operation.py | 15 +- .../legacy/nn/layer/parallel_1d/_utils.py | 7 +- .../legacy/nn/layer/parallel_1d/layers.py | 551 +- .../legacy/nn/layer/parallel_2d/__init__.py | 11 +- .../legacy/nn/layer/parallel_2d/_operation.py | 303 +- .../legacy/nn/layer/parallel_2d/_utils.py | 16 +- .../legacy/nn/layer/parallel_2d/layers.py | 615 +- .../legacy/nn/layer/parallel_2p5d/__init__.py | 11 +- .../nn/layer/parallel_2p5d/_operation.py | 510 +- .../legacy/nn/layer/parallel_2p5d/_utils.py | 27 +- .../legacy/nn/layer/parallel_2p5d/layers.py | 562 +- .../legacy/nn/layer/parallel_3d/__init__.py | 12 +- .../legacy/nn/layer/parallel_3d/_operation.py | 65 +- .../legacy/nn/layer/parallel_3d/_utils.py | 23 +- .../legacy/nn/layer/parallel_3d/layers.py | 509 +- .../nn/layer/parallel_sequence/__init__.py | 2 +- .../nn/layer/parallel_sequence/_operation.py | 31 +- .../nn/layer/parallel_sequence/layers.py | 119 +- colossalai/legacy/nn/layer/utils/__init__.py | 10 +- colossalai/legacy/nn/layer/utils/common.py | 9 +- .../legacy/nn/layer/vanilla/__init__.py | 9 +- colossalai/legacy/nn/layer/vanilla/layers.py | 93 +- .../legacy/nn/layer/wrapper/__init__.py | 2 +- .../nn/layer/wrapper/pipeline_wrapper.py | 17 +- colossalai/legacy/nn/loss/__init__.py | 19 +- colossalai/legacy/nn/loss/loss_1d.py | 4 +- colossalai/legacy/nn/loss/loss_2d.py | 12 +- colossalai/legacy/nn/loss/loss_2p5d.py | 12 +- colossalai/legacy/nn/loss/loss_3d.py | 6 +- colossalai/legacy/nn/metric/__init__.py | 7 +- colossalai/legacy/nn/metric/accuracy_2d.py | 3 +- colossalai/legacy/nn/metric/accuracy_2p5d.py | 3 +- colossalai/legacy/nn/metric/accuracy_3d.py | 5 +- colossalai/legacy/nn/parallel/__init__.py | 2 +- .../legacy/nn/parallel/data_parallel.py | 37 +- .../legacy/nn/parallel/layers/__init__.py | 20 +- .../layers/cache_embedding/__init__.py | 11 +- .../layers/cache_embedding/base_embedding.py | 9 +- .../layers/cache_embedding/cache_mgr.py | 144 +- .../cache_embedding/cached_embedding.py | 116 +- .../parallel/layers/cache_embedding/copyer.py | 2 +- .../cache_embedding/embedding_config.py | 22 +- .../parallel_cached_embedding.py | 147 +- .../parallel_cached_embedding_tablewise.py | 124 +- ..._cached_embedding_tablewise_split_cache.py | 91 +- .../legacy/nn/parallel/layers/colo_module.py | 19 +- .../legacy/nn/parallel/layers/embedding.py | 15 +- .../legacy/nn/parallel/layers/linear.py | 21 +- .../legacy/nn/parallel/layers/module_utils.py | 20 +- colossalai/legacy/nn/parallel/reducer.py | 12 +- colossalai/legacy/pipeline/__init__.py | 2 +- colossalai/legacy/pipeline/layer_spec.py | 6 +- .../legacy/pipeline/middleware/__init__.py | 2 +- .../pipeline/middleware/adaptor/__init__.py | 2 +- .../legacy/pipeline/middleware/adaptor/fx.py | 21 +- colossalai/legacy/pipeline/middleware/topo.py | 56 +- 
colossalai/legacy/pipeline/pipelinable.py | 15 +- .../legacy/pipeline/pipeline_process_group.py | 32 +- colossalai/legacy/pipeline/rpc/__init__.py | 2 +- .../legacy/pipeline/rpc/_pipeline_base.py | 343 +- .../legacy/pipeline/rpc/_pipeline_schedule.py | 151 +- colossalai/legacy/pipeline/rpc/utils.py | 50 +- colossalai/legacy/pipeline/utils.py | 26 +- colossalai/legacy/registry/registry.py | 2 +- colossalai/legacy/tensor/__init__.py | 16 +- colossalai/legacy/tensor/compute_spec.py | 2 +- colossalai/legacy/tensor/const.py | 2 +- colossalai/legacy/tensor/dist_spec_mgr.py | 56 +- colossalai/legacy/tensor/distspec.py | 13 +- colossalai/legacy/tensor/process_group.py | 55 +- colossalai/legacy/tensor/tensor_spec.py | 3 +- colossalai/legacy/trainer/__init__.py | 2 +- colossalai/legacy/trainer/_trainer.py | 5 +- colossalai/legacy/trainer/hooks/__init__.py | 15 +- colossalai/legacy/trainer/hooks/_base_hook.py | 46 +- .../legacy/trainer/hooks/_checkpoint_hook.py | 37 +- colossalai/legacy/trainer/hooks/_commons_.py | 4 +- colossalai/legacy/trainer/hooks/_log_hook.py | 126 +- .../trainer/hooks/_lr_scheduler_hook.py | 9 +- .../legacy/trainer/hooks/_metric_hook.py | 71 +- colossalai/legacy/utils/__init__.py | 48 +- .../legacy/utils/activation_checkpoint.py | 35 +- .../legacy/utils/checkpoint/__init__.py | 2 +- .../utils/checkpoint/module_checkpoint.py | 68 +- colossalai/legacy/utils/checkpoint/utils.py | 13 +- colossalai/legacy/utils/checkpointing.py | 66 +- colossalai/legacy/utils/common.py | 77 +- .../legacy/utils/data_sampler/__init__.py | 2 +- .../legacy/utils/data_sampler/base_sampler.py | 1 - .../data_sampler/data_parallel_sampler.py | 57 +- colossalai/legacy/utils/memory.py | 22 +- colossalai/legacy/utils/profiler/extention.py | 1 - .../legacy/utils/profiler/legacy/__init__.py | 2 +- .../utils/profiler/legacy/comm_profiler.py | 99 +- .../utils/profiler/legacy/pcie_profiler.py | 39 +- .../utils/profiler/legacy/prof_utils.py | 34 +- colossalai/legacy/utils/profiler/profiler.py | 62 +- .../profiler/stateful_tensor_mem_extention.py | 25 +- colossalai/legacy/zero/__init__.py | 21 +- colossalai/legacy/zero/gemini/__init__.py | 9 +- .../legacy/zero/gemini/gemini_context.py | 35 +- .../zero/gemini/ophooks/_shard_grad_ophook.py | 2 +- .../gemini/ophooks/_shard_param_ophook.py | 8 +- .../gemini/ophooks/runtime_mem_tracer_hook.py | 14 +- .../legacy/zero/gemini/ophooks/utils.py | 15 +- .../zero/gemini/paramhooks/_param_hookmgr.py | 7 +- .../legacy/zero/gemini/stateful_tensor.py | 14 +- .../legacy/zero/gemini/stateful_tensor_mgr.py | 28 +- .../zero/gemini/tensor_placement_policy.py | 44 +- colossalai/legacy/zero/gemini/tensor_utils.py | 22 +- colossalai/legacy/zero/init_ctx/__init__.py | 2 +- .../legacy/zero/init_ctx/init_context.py | 57 +- .../legacy/zero/shard_utils/__init__.py | 2 +- .../zero/shard_utils/base_shard_strategy.py | 4 +- .../bucket_tensor_shard_strategy.py | 5 +- .../zero/shard_utils/tensor_shard_strategy.py | 8 +- .../legacy/zero/sharded_model/__init__.py | 2 +- .../legacy/zero/sharded_model/_utils.py | 2 +- .../zero/sharded_model/reduce_scatter.py | 26 +- .../zero/sharded_model/sharded_model_v2.py | 218 +- colossalai/legacy/zero/sharded_model/utils.py | 2 +- .../legacy/zero/sharded_model/zero_hook.py | 25 +- .../legacy/zero/sharded_optim/__init__.py | 2 +- .../zero/sharded_optim/sharded_optim_v2.py | 127 +- .../legacy/zero/sharded_param/__init__.py | 2 +- .../zero/sharded_param/sharded_param.py | 4 +- .../zero/sharded_param/sharded_tensor.py | 1 - colossalai/logging/__init__.py | 8 +- 
colossalai/logging/logger.py | 35 +- colossalai/nn/init.py | 52 +- colossalai/nn/layer/moe/__init__.py | 15 +- colossalai/nn/layer/moe/_operation.py | 12 +- colossalai/nn/layer/moe/checkpoint.py | 12 +- colossalai/nn/layer/moe/experts.py | 40 +- colossalai/nn/layer/moe/layers.py | 64 +- colossalai/nn/layer/moe/routers.py | 461 +- colossalai/nn/layer/moe/utils.py | 139 +- colossalai/nn/layer/utils.py | 5 +- colossalai/nn/lr_scheduler/__init__.py | 19 +- colossalai/nn/lr_scheduler/cosine.py | 31 +- colossalai/nn/lr_scheduler/delayed.py | 39 +- colossalai/nn/lr_scheduler/linear.py | 6 +- colossalai/nn/lr_scheduler/multistep.py | 36 +- colossalai/nn/lr_scheduler/onecycle.py | 52 +- colossalai/nn/lr_scheduler/poly.py | 38 +- colossalai/nn/optimizer/README.md | 2 +- colossalai/nn/optimizer/__init__.py | 2 +- colossalai/nn/optimizer/cpu_adam.py | 141 +- colossalai/nn/optimizer/fused_adam.py | 80 +- colossalai/nn/optimizer/fused_lamb.py | 147 +- colossalai/nn/optimizer/fused_sgd.py | 55 +- colossalai/nn/optimizer/hybrid_adam.py | 135 +- colossalai/nn/optimizer/lamb.py | 30 +- colossalai/nn/optimizer/lars.py | 38 +- colossalai/nn/optimizer/nvme_optimizer.py | 28 +- colossalai/pipeline/__init__.py | 10 +- colossalai/pipeline/p2p.py | 29 +- colossalai/pipeline/schedule/__init__.py | 6 +- colossalai/pipeline/schedule/_utils.py | 23 +- colossalai/pipeline/schedule/base.py | 17 +- .../pipeline/schedule/interleaved_pp.py | 56 +- colossalai/pipeline/schedule/one_f_one_b.py | 70 +- colossalai/pipeline/stage_manager.py | 11 +- colossalai/shardformer/_utils.py | 22 +- .../examples/convergence_benchmark.py | 111 +- colossalai/shardformer/examples/data.py | 33 +- .../examples/performance_benchmark.py | 44 +- colossalai/shardformer/layer/__init__.py | 16 +- colossalai/shardformer/layer/_operation.py | 66 +- colossalai/shardformer/layer/dropout.py | 11 +- colossalai/shardformer/layer/embedding.py | 131 +- colossalai/shardformer/layer/linear.py | 195 +- colossalai/shardformer/layer/loss.py | 14 +- colossalai/shardformer/layer/normalization.py | 53 +- .../shardformer/layer/parallel_module.py | 45 +- .../shardformer/layer/qkv_fused_linear.py | 284 +- colossalai/shardformer/layer/utils.py | 16 +- colossalai/shardformer/modeling/bert.py | 420 +- colossalai/shardformer/modeling/blip2.py | 14 +- colossalai/shardformer/modeling/bloom.py | 280 +- colossalai/shardformer/modeling/chatglm2.py | 181 +- .../chatglm2_6b/configuration_chatglm.py | 54 +- .../modeling/chatglm2_6b/modeling_chatglm.py | 196 +- colossalai/shardformer/modeling/gpt2.py | 507 +- colossalai/shardformer/modeling/jit.py | 3 - colossalai/shardformer/modeling/llama.py | 117 +- colossalai/shardformer/modeling/opt.py | 152 +- colossalai/shardformer/modeling/sam.py | 32 +- colossalai/shardformer/modeling/t5.py | 189 +- colossalai/shardformer/modeling/vit.py | 93 +- colossalai/shardformer/modeling/whisper.py | 227 +- .../shardformer/policies/auto_policy.py | 228 +- .../shardformer/policies/base_policy.py | 24 +- colossalai/shardformer/policies/bert.py | 431 +- colossalai/shardformer/policies/blip2.py | 496 +- colossalai/shardformer/policies/bloom.py | 309 +- colossalai/shardformer/policies/chatglm2.py | 215 +- colossalai/shardformer/policies/gpt2.py | 214 +- colossalai/shardformer/policies/llama.py | 145 +- colossalai/shardformer/policies/opt.py | 210 +- colossalai/shardformer/policies/sam.py | 236 +- colossalai/shardformer/policies/t5.py | 426 +- colossalai/shardformer/policies/vit.py | 185 +- colossalai/shardformer/policies/whisper.py | 420 +- 
colossalai/shardformer/shard/__init__.py | 2 +- colossalai/shardformer/shard/shard_config.py | 5 +- colossalai/shardformer/shard/sharder.py | 63 +- colossalai/tensor/__init__.py | 13 +- colossalai/tensor/colo_parameter.py | 12 +- colossalai/tensor/colo_tensor.py | 19 +- colossalai/tensor/comm_spec.py | 90 +- colossalai/tensor/d_tensor/__init__.py | 23 +- colossalai/tensor/d_tensor/api.py | 62 +- colossalai/tensor/d_tensor/comm_spec.py | 68 +- colossalai/tensor/d_tensor/layout.py | 12 +- .../tensor/d_tensor/layout_converter.py | 97 +- colossalai/tensor/d_tensor/sharding_spec.py | 96 +- colossalai/tensor/d_tensor/utils.py | 4 +- colossalai/tensor/param_op_hook.py | 3 +- colossalai/tensor/shape_consistency.py | 147 +- colossalai/tensor/sharding_spec.py | 103 +- colossalai/tensor/utils.py | 35 +- colossalai/testing/__init__.py | 18 +- colossalai/testing/comparison.py | 52 +- colossalai/testing/pytest_wrapper.py | 9 +- colossalai/testing/random.py | 2 +- colossalai/testing/utils.py | 15 +- colossalai/utils/__init__.py | 32 +- colossalai/utils/common.py | 2 +- colossalai/utils/cuda.py | 4 +- colossalai/utils/model/utils.py | 19 +- colossalai/utils/moe.py | 5 +- .../multi_tensor_apply/multi_tensor_apply.py | 4 +- colossalai/utils/rank_recorder/README.md | 8 +- colossalai/utils/rank_recorder/__init__.py | 2 +- .../utils/rank_recorder/rank_recorder.py | 59 +- colossalai/utils/tensor_detector/__init__.py | 2 +- colossalai/utils/tensor_detector/readme.md | 3 +- .../utils/tensor_detector/tensor_detector.py | 85 +- colossalai/utils/timer.py | 15 +- colossalai/zero/__init__.py | 11 +- colossalai/zero/gemini/__init__.py | 13 +- colossalai/zero/gemini/chunk/__init__.py | 2 +- colossalai/zero/gemini/chunk/chunk.py | 157 +- colossalai/zero/gemini/chunk/manager.py | 58 +- colossalai/zero/gemini/chunk/search_utils.py | 24 +- colossalai/zero/gemini/chunk/utils.py | 28 +- colossalai/zero/gemini/colo_init_context.py | 69 +- colossalai/zero/gemini/gemini_ddp.py | 283 +- colossalai/zero/gemini/gemini_hook.py | 7 +- colossalai/zero/gemini/gemini_mgr.py | 46 +- colossalai/zero/gemini/gemini_optimizer.py | 281 +- .../zero/gemini/memory_tracer/__init__.py | 18 +- .../memory_tracer/chunk_memstats_collector.py | 4 +- .../gemini/memory_tracer/memory_monitor.py | 1 + .../zero/gemini/memory_tracer/memory_stats.py | 11 +- .../memory_tracer/memstats_collector.py | 15 +- .../memory_tracer/param_runtime_order.py | 1 - .../memory_tracer/runtime_mem_tracer.py | 4 +- .../static_memstats_collector.py | 22 +- colossalai/zero/gemini/memory_tracer/utils.py | 8 +- colossalai/zero/gemini/placement_policy.py | 93 +- colossalai/zero/gemini/utils.py | 27 +- colossalai/zero/low_level/__init__.py | 2 +- colossalai/zero/low_level/_utils.py | 27 +- .../zero/low_level/bookkeeping/__init__.py | 2 +- .../zero/low_level/bookkeeping/base_store.py | 1 - .../low_level/bookkeeping/bucket_store.py | 9 +- .../low_level/bookkeeping/gradient_store.py | 2 - .../low_level/bookkeeping/parameter_store.py | 1 - .../low_level/bookkeeping/tensor_bucket.py | 4 +- colossalai/zero/low_level/low_level_optim.py | 186 +- colossalai/zero/wrapper.py | 57 +- examples/community/fp8/mnist/main.py | 37 +- .../roberta/preprocessing/get_mask.py | 72 +- .../roberta/preprocessing/sentence_split.py | 59 +- .../roberta/preprocessing/tokenize_mask.py | 100 +- .../roberta/pretraining/arguments.py | 93 +- .../pretraining/bert_dataset_provider.py | 1 - .../roberta/pretraining/evaluation.py | 42 +- .../community/roberta/pretraining/loss.py | 5 +- .../roberta/pretraining/model/bert.py 
| 135 +- .../roberta/pretraining/model/deberta_v2.py | 148 +- .../nvidia_bert_dataset_provider.py | 57 +- .../roberta/pretraining/pretrain_utils.py | 70 +- .../roberta/pretraining/run_pretraining.py | 159 +- .../roberta/pretraining/utils/WandbLog.py | 8 +- .../roberta/pretraining/utils/exp_util.py | 51 +- .../roberta/pretraining/utils/global_vars.py | 22 +- .../roberta/pretraining/utils/logger.py | 14 +- examples/images/diffusion/README.md | 2 +- .../images/diffusion/configs/train_ddp.yaml | 2 +- examples/images/diffusion/ldm/data/base.py | 27 +- examples/images/diffusion/ldm/data/cifar10.py | 75 +- .../images/diffusion/ldm/data/imagenet.py | 123 +- examples/images/diffusion/ldm/data/lsun.py | 116 +- examples/images/diffusion/ldm/data/teyvat.py | 61 +- examples/images/diffusion/ldm/lr_scheduler.py | 31 +- .../diffusion/ldm/models/autoencoder.py | 106 +- .../ldm/models/diffusion/classifier.py | 145 +- .../diffusion/ldm/models/diffusion/ddim.py | 344 +- .../diffusion/ldm/models/diffusion/ddpm.py | 1099 +-- .../models/diffusion/dpm_solver/__init__.py | 2 +- .../models/diffusion/dpm_solver/dpm_solver.py | 564 +- .../models/diffusion/dpm_solver/sampler.py | 69 +- .../diffusion/ldm/models/diffusion/plms.py | 258 +- .../ldm/models/diffusion/sampling_util.py | 8 +- .../images/diffusion/ldm/modules/attention.py | 177 +- .../ldm/modules/diffusionmodules/model.py | 475 +- .../modules/diffusionmodules/openaimodel.py | 169 +- .../ldm/modules/diffusionmodules/upscaling.py | 49 +- .../ldm/modules/diffusionmodules/util.py | 42 +- .../modules/distributions/distributions.py | 37 +- examples/images/diffusion/ldm/modules/ema.py | 13 +- .../diffusion/ldm/modules/encoders/modules.py | 104 +- .../ldm/modules/image_degradation/bsrgan.py | 209 +- .../modules/image_degradation/bsrgan_light.py | 174 +- .../modules/image_degradation/utils_image.py | 278 +- .../images/diffusion/ldm/modules/midas/api.py | 33 +- .../ldm/modules/midas/midas/base_model.py | 2 +- .../ldm/modules/midas/midas/blocks.py | 130 +- .../ldm/modules/midas/midas/dpt_depth.py | 16 +- .../ldm/modules/midas/midas/midas_net.py | 7 +- .../modules/midas/midas/midas_net_custom.py | 87 +- .../ldm/modules/midas/midas/transforms.py | 50 +- .../diffusion/ldm/modules/midas/midas/vit.py | 69 +- .../diffusion/ldm/modules/midas/utils.py | 21 +- examples/images/diffusion/ldm/util.py | 116 +- examples/images/diffusion/main.py | 338 +- .../scripts/download_first_stages.sh | 2 +- examples/images/diffusion/scripts/img2img.py | 76 +- examples/images/diffusion/scripts/inpaint.py | 54 +- examples/images/diffusion/scripts/knn2img.py | 153 +- .../diffusion/scripts/sample_diffusion.py | 158 +- .../scripts/tests/test_checkpoint.py | 20 +- .../diffusion/scripts/tests/test_watermark.py | 8 +- .../diffusion/scripts/train_searcher.py | 166 +- examples/images/diffusion/scripts/txt2img.py | 154 +- examples/images/diffusion/scripts/utils.py | 38 +- examples/images/diffusion/setup.py | 16 +- examples/images/diffusion/train_colossalai.sh | 1 - examples/images/diffusion/train_ddp.sh | 6 +- examples/images/dreambooth/README.md | 4 +- examples/images/dreambooth/debug.py | 8 +- examples/images/dreambooth/inference.py | 4 +- .../images/dreambooth/train_dreambooth.py | 115 +- .../dreambooth/train_dreambooth_colossalai.py | 149 +- .../train_dreambooth_colossalai_lora.py | 142 +- .../dreambooth/train_dreambooth_inpaint.py | 150 +- examples/images/resnet/eval.py | 11 +- examples/images/resnet/requirements.txt | 2 +- examples/images/resnet/train.py | 99 +- examples/images/vit/args.py | 96 +- 
examples/images/vit/data.py | 16 +- examples/images/vit/requirements.txt | 2 +- examples/images/vit/vit_benchmark.py | 51 +- examples/images/vit/vit_train_demo.py | 132 +- examples/inference/bench_bloom.py | 18 +- examples/inference/bench_llama.py | 23 +- examples/language/bert/benchmark.py | 56 +- examples/language/bert/benchmark_utils.py | 9 +- examples/language/bert/data.py | 16 +- examples/language/bert/finetune.py | 126 +- .../gpt/experiments/auto_offload/model_zoo.py | 28 +- .../experiments/auto_offload/requirements.txt | 2 +- .../auto_offload/train_gpt_offload.py | 37 +- .../auto_parallel/auto_parallel_with_gpt.py | 21 +- .../experiments/auto_parallel/gpt_modules.py | 24 +- .../pipeline_parallel/model_zoo.py | 33 +- .../pipeline_parallel/train_gpt_pp.py | 79 +- .../language/gpt/gemini/commons/model_zoo.py | 33 +- examples/language/gpt/gemini/commons/utils.py | 13 +- .../language/gpt/gemini/train_gpt_demo.py | 44 +- .../language/gpt/hybridparallelism/data.py | 16 +- .../gpt/hybridparallelism/finetune.py | 130 +- .../titans/configs/gpt2_small_zero3_pp1d.py | 8 +- .../gpt/titans/configs/gpt3_zero3_pp1d.py | 8 +- .../language/gpt/titans/dataset/webtext.py | 15 +- examples/language/gpt/titans/model/embed.py | 184 +- examples/language/gpt/titans/model/gpt1d.py | 252 +- .../gpt/titans/model/pipeline_gpt1d.py | 321 +- examples/language/gpt/titans/train_gpt.py | 91 +- examples/language/llama2/attn.py | 9 +- examples/language/llama2/benchmark.py | 214 +- examples/language/llama2/data_utils.py | 75 +- examples/language/llama2/finetune.py | 256 +- examples/language/llama2/model_utils.py | 8 +- .../language/llama2/performance_evaluator.py | 31 +- examples/language/llama2/pretrain.py | 293 +- examples/language/opt/args.py | 76 +- examples/language/opt/data.py | 29 +- examples/language/opt/opt_benchmark.py | 18 +- examples/language/opt/opt_train_demo.py | 76 +- examples/language/opt/run_benchmark.sh | 2 +- .../palm_pytorch/autoregressive_wrapper.py | 2 - .../palm/palm_pytorch/palm_pytorch.py | 28 +- examples/language/palm/train.py | 48 +- examples/tutorial/README.md | 2 +- .../auto_parallel/auto_ckpt_batchsize_test.py | 16 +- .../auto_parallel/auto_ckpt_solver_test.py | 34 +- .../auto_parallel_with_resnet.py | 9 +- .../tutorial/auto_parallel/bench_utils.py | 64 +- examples/tutorial/auto_parallel/setup.py | 12 +- examples/tutorial/download_cifar10.py | 4 +- examples/tutorial/hybrid_parallel/config.py | 6 +- examples/tutorial/hybrid_parallel/train.py | 47 +- .../tutorial/large_batch_optimizer/train.py | 32 +- .../tutorial/new_api/cifar_resnet/eval.py | 11 +- .../tutorial/new_api/cifar_resnet/train.py | 101 +- examples/tutorial/new_api/cifar_vit/train.py | 124 +- examples/tutorial/new_api/glue_bert/data.py | 16 +- .../tutorial/new_api/glue_bert/finetune.py | 80 +- examples/tutorial/opt/inference/batch.py | 29 +- .../opt/inference/benchmark/locustfile.py | 9 +- examples/tutorial/opt/inference/cache.py | 4 +- .../tutorial/opt/inference/opt_fastapi.py | 101 +- examples/tutorial/opt/inference/opt_server.py | 119 +- .../script/process-opt-175b/README.md | 1 - .../script/process-opt-175b/convert_ckpt.py | 39 +- .../script/process-opt-175b/flat-meta.json | 6945 ++++++++++++++++- .../inference/script/processing_ckpt_66b.py | 24 +- examples/tutorial/opt/opt/colossalai_zero.py | 8 +- examples/tutorial/opt/opt/context.py | 2 +- examples/tutorial/opt/opt/run_clm.py | 154 +- examples/tutorial/sequence_parallel/config.py | 6 +- .../sequence_parallel/data/__init__.py | 52 +- 
.../sequence_parallel/data/bert_helper.py | 45 +- .../data/datasets/bert_dataset.py | 153 +- .../data/datasets/blendable_dataset.py | 18 +- .../data/datasets/builder.py | 134 +- .../data/datasets/data_samplers.py | 85 +- .../data/datasets/dataset_utils.py | 274 +- .../data/datasets/helpers.cpp | 1163 ++- .../data/datasets/ict_dataset.py | 67 +- .../data/datasets/indexed_dataset.py | 147 +- .../datasets/test/test_indexed_dataset.py | 59 +- .../data/dummy_dataloader.py | 55 +- .../data/tokenizer/__init__.py | 1 - .../data/tokenizer/bert_tokenization.py | 67 +- .../data/tokenizer/tokenizer.py | 69 +- .../sequence_parallel/loss_func/bert_loss.py | 5 - .../loss_func/cross_entropy.py | 5 +- .../sequence_parallel/loss_func/utils.py | 17 +- .../lr_scheduler/annealing_lr.py | 104 +- .../sequence_parallel/model/__init__.py | 2 - .../tutorial/sequence_parallel/model/bert.py | 126 +- .../model/layers/__init__.py | 2 +- .../model/layers/bert_layer.py | 34 +- .../sequence_parallel/model/layers/dropout.py | 4 +- .../model/layers/embedding.py | 35 +- .../sequence_parallel/model/layers/head.py | 5 - .../model/layers/init_method.py | 4 +- .../sequence_parallel/model/layers/linear.py | 25 +- .../sequence_parallel/model/layers/mlp.py | 20 +- .../sequence_parallel/model/layers/pooler.py | 1 + .../model/layers/preprocess.py | 10 +- examples/tutorial/sequence_parallel/train.py | 115 +- op_builder/__init__.py | 29 +- op_builder/builder.py | 74 +- op_builder/cpu_adam.py | 24 +- op_builder/fused_optim.py | 25 +- op_builder/layernorm.py | 12 +- op_builder/moe.py | 20 +- op_builder/multi_head_attn.py | 32 +- op_builder/scaled_masked_softmax.py | 27 +- .../scaled_upper_triangle_masked_softmax.py | 24 +- op_builder/utils.py | 61 +- setup.py | 133 +- tests/components_to_test/__init__.py | 16 +- tests/components_to_test/albert.py | 51 +- tests/components_to_test/beit.py | 30 +- tests/components_to_test/bert.py | 61 +- tests/components_to_test/gpt2.py | 60 +- .../components_to_test/hanging_param_model.py | 5 +- tests/components_to_test/inline_op_model.py | 7 +- tests/components_to_test/nested_model.py | 6 +- tests/components_to_test/registry.py | 3 +- .../repeated_computed_layers.py | 4 +- tests/components_to_test/resnet.py | 17 +- tests/components_to_test/simple_net.py | 5 +- .../utils/dummy_data_generator.py | 1 - tests/kit/model_zoo/__init__.py | 2 +- tests/kit/model_zoo/diffusers/diffusers.py | 76 +- tests/kit/model_zoo/registry.py | 21 +- tests/kit/model_zoo/timm/timm.py | 316 +- tests/kit/model_zoo/torchaudio/torchaudio.py | 135 +- tests/kit/model_zoo/torchrec/torchrec.py | 126 +- .../kit/model_zoo/torchvision/torchvision.py | 202 +- tests/kit/model_zoo/transformers/albert.py | 98 +- tests/kit/model_zoo/transformers/bert.py | 453 +- tests/kit/model_zoo/transformers/blip2.py | 30 +- tests/kit/model_zoo/transformers/bloom.py | 101 +- tests/kit/model_zoo/transformers/chatglm2.py | 60 +- tests/kit/model_zoo/transformers/gpt.py | 124 +- tests/kit/model_zoo/transformers/llama.py | 63 +- tests/kit/model_zoo/transformers/opt.py | 59 +- tests/kit/model_zoo/transformers/sam.py | 24 +- tests/kit/model_zoo/transformers/t5.py | 46 +- tests/kit/model_zoo/transformers/vit.py | 48 +- tests/kit/model_zoo/transformers/whisper.py | 52 +- .../test_fx/test_bias_addition.py | 58 +- tests/test_analyzer/test_fx/test_mod_dir.py | 35 +- .../test_analyzer/test_fx/test_nested_ckpt.py | 7 +- .../test_analyzer/test_fx/test_shape_prop.py | 20 +- .../test_fx/test_symbolic_profile.py | 11 +- .../test_subclasses/test_aten.py | 45 +- 
.../test_subclasses/test_flop_tensor.py | 49 +- .../test_subclasses/test_meta_mode.py | 21 +- .../test_C_solver_consistency.py | 19 +- .../test_ckpt_torchvision.py | 37 +- .../test_ckpt_solvers/test_linearize.py | 26 +- .../test_offload/model_utils.py | 51 +- .../test_offload/test_perf.py | 58 +- .../test_offload/test_solver.py | 15 +- .../test_pass/test_node_converting_pass.py | 12 +- .../test_size_value_converting_pass.py | 12 +- .../test_bias_addition_forward.py | 26 +- .../test_tensor_shard/test_broadcast.py | 14 +- .../test_tensor_shard/test_checkpoint.py | 26 +- .../test_compatibility_with_ddp.py | 42 +- .../test_compatibility_with_gemini.py | 50 +- .../test_find_repeat_block.py | 19 +- .../test_tensor_shard/test_gpt/gpt_modules.py | 22 +- .../test_gpt/test_runtime_with_gpt_modules.py | 68 +- .../test_gpt/test_solver_with_gpt_module.py | 25 +- .../test_liveness_analysis.py | 11 +- .../test_metainfo/test_activation_metainfo.py | 39 +- .../test_binary_elementwise_metainfo.py | 24 +- .../test_metainfo/test_conv_metainfo.py | 51 +- .../test_metainfo/test_embedding_metainfo.py | 20 +- .../test_metainfo/test_linear_metainfo.py | 51 +- .../test_metainfo/test_matmul_metainfo.py | 47 +- .../test_metainfo/test_norm_metainfo.py | 55 +- .../test_metainfo/test_pooling_metainfo.py | 42 +- .../test_metainfo/test_tensor_metainfo.py | 23 +- .../test_metainfo/test_where_metainfo.py | 22 +- .../test_tensor_shard/test_metainfo/utils.py | 87 +- .../test_node_handler/test_addbmm_handler.py | 174 +- .../test_node_handler/test_addmm_handler.py | 94 +- .../test_batch_norm_handler.py | 60 +- .../test_bias_linear_function_node.py | 107 +- .../test_bias_linear_module_node.py | 105 +- .../test_binary_elementwise_handler.py | 151 +- .../test_node_handler/test_bmm_handler.py | 136 +- .../test_node_handler/test_conv_handler.py | 199 +- .../test_default_reshape_handler.py | 41 +- .../test_embedding_handler.py | 198 +- .../test_node_handler/test_getattr_handler.py | 37 +- .../test_node_handler/test_getitem_handler.py | 94 +- .../test_layer_norm_handler.py | 62 +- .../test_node_handler/test_linear_handler.py | 229 +- .../test_node_handler/test_matmul_handler.py | 79 +- .../test_norm_pooling_handler.py | 26 +- .../test_node_handler/test_output_handler.py | 25 +- .../test_permute_and_transpose_handler.py | 365 +- .../test_placeholder_handler.py | 37 +- .../test_node_handler/test_shard_option.py | 74 +- .../test_node_handler/test_softmax_handler.py | 169 +- .../test_node_handler/test_split_handler.py | 263 +- .../test_node_handler/test_sum_handler.py | 255 +- .../test_tensor_constructor.py | 17 +- .../test_unary_element_wise_handler.py | 41 +- .../test_node_handler/test_view_handler.py | 261 +- .../test_node_handler/test_where_handler.py | 55 +- .../test_node_handler/utils.py | 102 +- .../test_solver_with_resnet_v2.py | 21 +- .../benchmark_autochunk_alphafold.py | 5 +- .../test_autochunk_alphafold_utils.py | 7 +- .../test_autochunk_evoformer_block.py | 53 +- .../test_autochunk_evoformer_stack.py | 47 +- .../test_autochunk_extramsa_block.py | 43 +- .../benchmark_autochunk_diffuser.py | 17 +- .../test_autochunk_diffuser_utils.py | 10 +- .../test_autochunk_unet.py | 2 + .../benchmark_autochunk_transformer.py | 17 +- .../test_autochunk_gpt.py | 23 +- .../test_autochunk_transformer_utils.py | 34 +- .../test_autochunk_vit/test_autochunk_vit.py | 3 +- .../test_autochunk_vit_utils.py | 8 +- tests/test_booster/test_accelerator.py | 2 +- .../test_mixed_precision/test_fp16_torch.py | 8 +- .../test_plugin/test_3d_plugin.py | 31 +- 
.../test_plugin/test_dp_plugin_base.py | 10 +- .../test_plugin/test_gemini_plugin.py | 82 +- .../test_plugin/test_low_level_zero_plugin.py | 22 +- .../test_plugin/test_torch_ddp_plugin.py | 14 +- .../test_plugin/test_torch_fsdp_plugin.py | 22 +- .../test_gemini_checkpoint_io.py | 77 +- .../test_gemini_torch_compability.py | 47 +- .../test_general_checkpoint_io.py | 33 +- ...st_hybrid_parallel_plugin_checkpoint_io.py | 92 +- .../test_low_level_zero_checkpoint_io.py | 15 +- .../test_plugins_huggingface_compatibility.py | 31 +- .../test_torch_ddp_checkpoint_io.py | 14 +- .../test_torch_fsdp_checkpoint_io.py | 19 +- tests/test_checkpoint_io/utils.py | 2 +- .../test_cluster/test_device_mesh_manager.py | 8 +- tests/test_cluster/test_process_group_mesh.py | 56 +- tests/test_config/sample_config.py | 20 +- tests/test_config/test_load_config.py | 13 +- tests/test_device/test_alpha_beta.py | 6 +- tests/test_device/test_device_mesh.py | 22 +- tests/test_device/test_extract_alpha_beta.py | 6 +- tests/test_device/test_init_logical_pg.py | 4 +- .../test_search_logical_device_mesh.py | 6 +- .../test_activation_checkpoint_codegen.py | 61 +- ...st_nested_activation_checkpoint_codegen.py | 75 +- .../test_codegen/test_offload_codegen.py | 76 +- tests/test_fx/test_coloproxy.py | 9 +- tests/test_fx/test_comm_size_compute.py | 5 +- tests/test_fx/test_graph_manipulation.py | 10 +- tests/test_fx/test_meta/test_aten.py | 45 +- tests/test_fx/test_meta/test_backward.py | 31 +- tests/test_fx/test_meta/test_meta_trace.py | 31 +- tests/test_fx/test_meta_info_prop.py | 14 +- tests/test_fx/test_parallel_1d.py | 7 +- .../test_pipeline/test_hf_model/hf_utils.py | 25 +- .../test_hf_model/test_albert.py | 18 +- .../test_pipeline/test_hf_model/test_bert.py | 12 +- .../test_pipeline/test_hf_model/test_gpt.py | 6 +- .../test_pipeline/test_hf_model/test_opt.py | 4 +- .../test_pipeline/test_hf_model/test_t5.py | 4 +- .../test_timm_model/test_timm.py | 17 +- .../test_timm_model/timm_utils.py | 13 +- .../test_pipeline/test_topo/test_topo.py | 9 +- .../test_pipeline/test_topo/topo_utils.py | 15 +- .../test_torchvision/test_torchvision.py | 19 +- tests/test_fx/test_pipeline_passes.py | 7 +- tests/test_fx/test_profiler/gpt_utils.py | 34 +- .../test_profiler_meta_info_prop.py | 81 +- .../test_activation_checkpoint_annotation.py | 15 +- .../test_tracer/test_bias_addition_module.py | 23 +- .../test_fx/test_tracer/test_control_flow.py | 19 +- .../test_tracer/test_functional_conv.py | 2 +- .../test_hf_model/hf_tracer_utils.py | 7 +- .../test_hf_model/test_hf_albert.py | 6 +- .../test_tracer/test_hf_model/test_hf_bert.py | 8 +- .../test_hf_model/test_hf_diffuser.py | 10 +- .../test_tracer/test_hf_model/test_hf_gpt.py | 10 +- .../test_tracer/test_hf_model/test_hf_opt.py | 8 +- .../test_tracer/test_hf_model/test_hf_t5.py | 10 +- .../test_tracer/test_patched_module.py | 515 +- tests/test_fx/test_tracer/test_patched_op.py | 36 +- .../test_timm_model/test_timm_model.py | 15 +- .../test_torchaudio_model.py | 10 +- .../test_torchaudio_model/torchaudio_utils.py | 7 +- .../test_torchrec_model/test_deepfm_model.py | 22 +- .../test_torchrec_model/test_dlrm_model.py | 24 +- .../test_torchvision_model.py | 8 +- tests/test_infer/_utils.py | 26 +- tests/test_infer/test_bloom_infer.py | 27 +- tests/test_infer/test_infer_engine.py | 36 +- tests/test_infer/test_kvcache_manager.py | 39 +- tests/test_infer/test_llama_infer.py | 31 +- .../test_infer_ops/cuda/test_vllm_rmsnorm.py | 14 +- .../cuda/test_vllm_rotary_embedding.py | 47 +- 
tests/test_infer_ops/triton/kernel_utils.py | 9 +- .../triton/test_bloom_context_attention.py | 20 +- .../triton/test_copy_kv_dest.py | 18 +- .../triton/test_layernorm_triton.py | 21 +- .../triton/test_llama_context_attention.py | 20 +- .../triton/test_rotary_embedding.py | 21 +- .../triton/test_self_attention_nonfusion.py | 104 +- tests/test_infer_ops/triton/test_softmax.py | 25 +- .../triton/test_token_attn_1.py | 18 +- .../triton/test_token_attn_2.py | 26 +- .../triton/test_token_attn_fwd.py | 14 +- .../triton/test_token_softmax.py | 12 +- tests/test_lazy/lazy_init_utils.py | 38 +- tests/test_lazy/test_models.py | 18 +- tests/test_legacy/test_amp/test_naive_fp16.py | 15 +- tests/test_legacy/test_amp/test_torch_fp16.py | 16 +- .../test_comm/test_boardcast_send_recv_v2.py | 16 +- tests/test_legacy/test_comm/test_comm.py | 24 +- .../test_comm/test_object_list_p2p.py | 16 +- .../test_comm/test_object_list_p2p_v2.py | 12 +- .../test_context/configs/parallel_2d_init.py | 2 +- .../configs/parallel_2p5d_init.py | 2 +- .../test_context/configs/parallel_3d_init.py | 2 +- .../test_context/test_hybrid_parallel.py | 32 +- .../test_data/test_cifar10_dataset.py | 5 +- .../test_data/test_data_parallel_sampler.py | 28 +- .../test_deterministic_dataloader.py | 19 +- tests/test_legacy/test_engine/test_engine.py | 34 +- .../test_engine/test_gradient_accumluation.py | 57 +- .../test_1d/checks_1d/check_layer_1d.py | 44 +- .../test_layers/test_1d/test_1d.py | 8 +- .../test_2d/checks_2d/check_layer_2d.py | 48 +- .../test_2d/checks_2d/check_operation_2d.py | 107 +- .../test_layers/test_2d/test_2d.py | 8 +- .../test_2p5d/checks_2p5d/check_layer_2p5d.py | 70 +- .../checks_2p5d/check_operation_2p5d.py | 113 +- .../test_layers/test_2p5d/test_2p5d.py | 14 +- .../test_3d/checks_3d/check_layer_3d.py | 241 +- .../test_layers/test_3d/test_3d.py | 6 +- .../test_layers/test_cache_embedding.py | 162 +- .../checks_seq/check_layer_seq.py | 7 +- .../test_sequence/test_sequence.py | 32 +- .../test_pipeline/rpc_test_utils.py | 71 +- .../test_pipeline/test_cuda_rpc_chimera.py | 18 +- .../test_pipeline/test_cuda_rpc_optimizer.py | 21 +- .../test_pipeline/test_cuda_rpc_pipeline.py | 17 +- .../test_cuda_rpc_value_correctness.py | 19 +- .../test_pipeline/test_middleware_1f1b.py | 33 +- .../test_pipeline/test_pipelinable.py | 5 +- .../test_pipeline_process_group.py | 20 +- .../test_tensor/common_utils/_utils.py | 19 +- .../test_tensor/core/test_dist_spec_mgr.py | 6 +- .../test_legacy/test_tensor/test_parameter.py | 9 +- .../test_trainer/test_pipeline/test_p2p.py | 30 +- .../test_pipeline/test_pipeline_schedule.py | 25 +- .../test_trainer_with_non_pipe_schedule.py | 34 +- .../test_trainer_with_pipe_schedule.py | 55 +- .../test_activation_checkpointing.py | 14 +- .../test_checkpoint/test_checkpoint_1d.py | 4 +- .../test_checkpoint/test_checkpoint_2d.py | 4 +- .../test_checkpoint/test_checkpoint_2p5d.py | 4 +- .../test_checkpoint/test_checkpoint_3d.py | 4 +- tests/test_legacy/test_utils/test_memory.py | 4 +- .../test_utils/test_norm_gradient_clipping.py | 26 +- tests/test_legacy/test_zero/test_commons.py | 18 +- tests/test_moe/test_grad_handler.py | 6 +- tests/test_moe/test_kernel.py | 14 +- tests/test_moe/test_moe_checkpoint.py | 10 +- tests/test_moe/test_moe_colo_init.py | 11 +- tests/test_moe/test_moe_group.py | 6 +- tests/test_moe/test_moe_zero_init.py | 46 +- tests/test_moe/test_moe_zero_model.py | 12 +- tests/test_moe/test_moe_zero_optim.py | 43 +- tests/test_optimizer/test_adam_kernel.py | 92 +- 
tests/test_optimizer/test_adam_optim.py | 33 +- tests/test_optimizer/test_nvme.py | 22 +- tests/test_pipeline/test_p2p_communication.py | 12 +- .../test_t5_pipeline_utils.py | 42 +- .../test_whisper_pipeline_utils.py | 46 +- .../test_schedule/test_interleaved.py | 44 +- .../test_schedule/test_oneF_oneB.py | 28 +- .../test_pipeline_schedule_utils.py | 24 +- tests/test_pipeline/test_stage_manager.py | 4 +- .../test_layer/test_dist_crossentropy.py | 15 +- .../test_layer/test_dropout.py | 4 +- .../test_layer/test_embedding.py | 6 +- .../test_gpt2_qkv_fused_linear_1d.py | 26 +- .../test_layer/test_layernorm.py | 6 +- .../test_layer/test_linear_1d.py | 54 +- .../test_layer/test_qkv_fused_linear_1d.py | 15 +- .../test_vocab_parallel_embedding_1d.py | 12 +- tests/test_shardformer/test_model/_utils.py | 212 +- .../test_model/test_shard_bert.py | 202 +- .../test_model/test_shard_blip2.py | 39 +- .../test_model/test_shard_bloom.py | 198 +- .../test_model/test_shard_chatglm2.py | 214 +- .../test_model/test_shard_gpt2.py | 228 +- .../test_model/test_shard_llama.py | 231 +- .../test_model/test_shard_opt.py | 213 +- .../test_model/test_shard_sam.py | 31 +- .../test_model/test_shard_t5.py | 205 +- .../test_model/test_shard_vit.py | 209 +- .../test_model/test_shard_whisper.py | 202 +- tests/test_shardformer/test_shard_utils.py | 1 - tests/test_shardformer/test_with_torch_ddp.py | 7 +- tests/test_tensor/test_comm_spec_apply.py | 19 +- .../test_dtensor/test_comm_spec.py | 30 +- .../test_tensor/test_dtensor/test_dtensor.py | 17 +- .../test_dtensor_sharding_spec.py | 7 +- .../test_dtensor/test_layout_converter.py | 36 +- tests/test_tensor/test_mix_gather.py | 150 +- tests/test_tensor/test_shape_consistency.py | 63 +- .../test_shape_consistency_apply.py | 6 +- tests/test_tensor/test_sharding_spec.py | 4 +- tests/test_utils/test_flash_attention.py | 34 +- .../test_zero/test_gemini/test_chunk_mgrv2.py | 35 +- tests/test_zero/test_gemini/test_chunkv2.py | 42 +- tests/test_zero/test_gemini/test_fwd_bwd.py | 41 +- .../test_gemini/test_gemini_use_rmt.py | 45 +- tests/test_zero/test_gemini/test_grad_clip.py | 63 +- tests/test_zero/test_gemini/test_inference.py | 37 +- tests/test_zero/test_gemini/test_optim.py | 110 +- .../test_gemini/test_runtime_mem_tracer.py | 8 +- tests/test_zero/test_gemini/test_search.py | 36 +- .../test_gemini/test_zeroddp_state_dict.py | 51 +- .../test_gemini/test_zerooptim_state_dict.py | 44 +- .../test_zero/test_low_level/test_grad_acc.py | 29 +- .../test_zero/test_low_level/test_zero1_2.py | 29 +- .../test_low_level/test_zero_ckpt.py | 19 +- 1268 files changed, 50252 insertions(+), 38659 deletions(-) delete mode 100644 .flake8 delete mode 100644 .style.yapf diff --git a/.flake8 b/.flake8 deleted file mode 100644 index 229856aa4366..000000000000 --- a/.flake8 +++ /dev/null @@ -1,22 +0,0 @@ -[flake8] -ignore = - ;W503 line break before binary operator - W503, - ;E203 whitespace before ':' - E203, - -; exclude file -exclude = - .tox, - .git, - __pycache__, - build, - dist, - *.pyc, - *.egg-info, - .cache, - .eggs - -max-line-length = 120 - -per-file-ignores = __init__.py:F401 diff --git a/.github/workflows/scripts/check_doc_i18n.py b/.github/workflows/scripts/check_doc_i18n.py index 1aa7283e9e52..1e7f0c33a785 100644 --- a/.github/workflows/scripts/check_doc_i18n.py +++ b/.github/workflows/scripts/check_doc_i18n.py @@ -22,13 +22,13 @@ def compare_dirs(dir1, dir2): # If the corresponding item doesn't exist in the second directory, the directories are different if not os.path.exists(item_path2): - 
-            print(f'Found mismatch: {item_path1}, {item_path2}')
+            print(f"Found mismatch: {item_path1}, {item_path2}")
             return False
 
         # If the corresponding item is a directory, we compare the two directories recursively
         if os.path.isdir(item_path1) and os.path.isdir(item_path2):
             if not compare_dirs(item_path1, item_path2):
-                print(f'Found mismatch: {item_path1}, {item_path2}')
+                print(f"Found mismatch: {item_path1}, {item_path2}")
                 return False
 
         # both are files
@@ -37,16 +37,16 @@ def compare_dirs(dir1, dir2):
 
         # If the corresponding item is not a file or a directory, the directories are different
         else:
-            print(f'Found mismatch: {item_path1}, {item_path2}')
+            print(f"Found mismatch: {item_path1}, {item_path2}")
             return False
 
     # If all items are the same, the directories are the same
     return True
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('-d', '--directory', help="The directory where the multi-language source files are kept.")
+    parser.add_argument("-d", "--directory", help="The directory where the multi-language source files are kept.")
     args = parser.parse_args()
     i18n_folders = os.listdir(args.directory)
@@ -56,7 +56,7 @@ def compare_dirs(dir1, dir2):
     for i in range(1, len(i18n_folders)):
         dir1 = i18n_folders[0]
         dir2 = i18n_folders[i]
-        print(f'comparing {dir1} vs {dir2}')
+        print(f"comparing {dir1} vs {dir2}")
         match = compare_dirs(i18n_folders[0], i18n_folders[i])
 
         if not match:
diff --git a/.github/workflows/scripts/example_checks/check_dispatch_inputs.py b/.github/workflows/scripts/example_checks/check_dispatch_inputs.py
index 5bec96187e0c..91778f692cc6 100644
--- a/.github/workflows/scripts/example_checks/check_dispatch_inputs.py
+++ b/.github/workflows/scripts/example_checks/check_dispatch_inputs.py
@@ -4,7 +4,7 @@
 
 def check_inputs(input_list):
     for path in input_list:
-        real_path = os.path.join('examples', path)
+        real_path = os.path.join("examples", path)
         if not os.path.exists(real_path):
             return False
     return True
@@ -12,16 +12,16 @@ def check_inputs(input_list):
 
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument('-f', '--fileNameList', type=str, help="List of file names")
+    parser.add_argument("-f", "--fileNameList", type=str, help="List of file names")
     args = parser.parse_args()
     name_list = args.fileNameList.split(",")
     is_correct = check_inputs(name_list)
 
     if is_correct:
-        print('success')
+        print("success")
     else:
-        print('failure')
+        print("failure")
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
diff --git a/.github/workflows/scripts/example_checks/check_example_weekly.py b/.github/workflows/scripts/example_checks/check_example_weekly.py
index 83eff644e315..95a3d24c9a78 100644
--- a/.github/workflows/scripts/example_checks/check_example_weekly.py
+++ b/.github/workflows/scripts/example_checks/check_example_weekly.py
@@ -17,21 +17,21 @@ def show_files(path, all_files):
 
 
 def join(input_list, sep=None):
-    return (sep or ' ').join(input_list)
+    return (sep or " ").join(input_list)
 
 
 def main():
-    contents = show_files('examples/', [])
+    contents = show_files("examples/", [])
     all_loc = []
     for file_loc in contents:
-        split_loc = file_loc.split('/')
+        split_loc = file_loc.split("/")
         # must have two sub-folder levels after examples folder, such as examples/images/vit is acceptable, examples/images/README.md is not, examples/requirements.txt is not.
         if len(split_loc) >= 4:
-            re_loc = '/'.join(split_loc[1:3])
+            re_loc = "/".join(split_loc[1:3])
             if re_loc not in all_loc:
                 all_loc.append(re_loc)
     print(all_loc)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
diff --git a/.github/workflows/scripts/example_checks/detect_changed_example.py b/.github/workflows/scripts/example_checks/detect_changed_example.py
index c69d95a552e9..95f671dfb32b 100644
--- a/.github/workflows/scripts/example_checks/detect_changed_example.py
+++ b/.github/workflows/scripts/example_checks/detect_changed_example.py
@@ -3,7 +3,7 @@
 
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument('-f', '--fileNameList', type=str, help="The list of changed files")
+    parser.add_argument("-f", "--fileNameList", type=str, help="The list of changed files")
     args = parser.parse_args()
     name_list = args.fileNameList.split(":")
     folder_need_check = set()
@@ -15,10 +15,10 @@ def main():
         # - application
         # - file
         if loc.split("/")[0] == "examples" and len(loc.split("/")) >= 4:
-            folder_need_check.add('/'.join(loc.split("/")[1:3]))
+            folder_need_check.add("/".join(loc.split("/")[1:3]))
     # Output the result using print. Then the shell can get the values.
     print(list(folder_need_check))
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
diff --git a/.github/workflows/scripts/generate_leaderboard_and_send_to_lark.py b/.github/workflows/scripts/generate_leaderboard_and_send_to_lark.py
index 2884e38dd3dd..412b14c7b283 100644
--- a/.github/workflows/scripts/generate_leaderboard_and_send_to_lark.py
+++ b/.github/workflows/scripts/generate_leaderboard_and_send_to_lark.py
@@ -74,16 +74,16 @@ def get_organization_repositories(github_token, organization_name) -> List[str]:
     # prepare header
     headers = {
-        'Authorization': f'Bearer {github_token}',
-        'Accept': 'application/vnd.github+json',
-        'X-GitHub-Api-Version': '2022-11-28'
+        "Authorization": f"Bearer {github_token}",
+        "Accept": "application/vnd.github+json",
+        "X-GitHub-Api-Version": "2022-11-28",
     }
 
     res = requests.get(url, headers=headers).json()
 
     repo_list = []
     for item in res:
-        repo_list.append(item['name'])
+        repo_list.append(item["name"])
     return repo_list
@@ -97,9 +97,9 @@ def get_issue_pull_request_comments(github_token: str, org_name: str, repo_name:
     """
     # prepare header
     headers = {
-        'Authorization': f'Bearer {github_token}',
-        'Accept': 'application/vnd.github+json',
-        'X-GitHub-Api-Version': '2022-11-28'
+        "Authorization": f"Bearer {github_token}",
+        "Accept": "application/vnd.github+json",
+        "X-GitHub-Api-Version": "2022-11-28",
     }
 
     user_engagement_count = {}
@@ -107,28 +107,28 @@ def get_issue_pull_request_comments(github_token: str, org_name: str, repo_name:
     # do pagination to the API
     page = 1
     while True:
-        comment_api = f'https://api.github.com/repos/{org_name}/{repo_name}/issues/comments?since={since}&page={page}'
+        comment_api = f"https://api.github.com/repos/{org_name}/{repo_name}/issues/comments?since={since}&page={page}"
         comment_response = requests.get(comment_api, headers=headers).json()
 
         if len(comment_response) == 0:
             break
         else:
             for item in comment_response:
-                comment_author_relationship = item['author_association']
-                if comment_author_relationship != 'MEMBER':
+                comment_author_relationship = item["author_association"]
+                if comment_author_relationship != "MEMBER":
                     # if the comment is not made by our member
                     # we don't count this comment towards user engagement
                     continue
 
-                issue_id = item['issue_url'].split('/')[-1]
-                issue_api =
f'https://api.github.com/repos/{org_name}/{repo_name}/issues/{issue_id}' + issue_id = item["issue_url"].split("/")[-1] + issue_api = f"https://api.github.com/repos/{org_name}/{repo_name}/issues/{issue_id}" issue_response = requests.get(issue_api, headers=headers).json() - issue_author_relationship = issue_response['author_association'] + issue_author_relationship = issue_response["author_association"] - if issue_author_relationship != 'MEMBER': + if issue_author_relationship != "MEMBER": # this means that the issue/PR is not created by our own people # any comments in this issue/PR by our member will be counted towards the leaderboard - member_name = item['user']['login'] + member_name = item["user"]["login"] if member_name in user_engagement_count: user_engagement_count[member_name] += 1 @@ -153,7 +153,7 @@ def _generate_discussion_query(num, cursor: str = None): if cursor is None: offset_str = "" else: - offset_str = f", after: \"{cursor}\"" + offset_str = f', after: "{cursor}"' query = f""" {{ repository(owner: "{org_name}", name: "{repo_name}"){{ @@ -182,7 +182,7 @@ def _generate_comment_reply_count_for_discussion(discussion_number, num, cursor: if cursor is None: offset_str = "" else: - offset_str = f", before: \"{cursor}\"" + offset_str = f', before: "{cursor}"' query = f""" {{ repository(owner: "{org_name}", name: "{repo_name}"){{ @@ -220,8 +220,8 @@ def _generate_comment_reply_count_for_discussion(discussion_number, num, cursor: # a utility function to make call to Github GraphQL API def _call_graphql_api(query): headers = {"Authorization": f"Bearer {github_token}"} - json_data = {'query': query} - response = requests.post('https://api.github.com/graphql', json=json_data, headers=headers) + json_data = {"query": query} + response = requests.post("https://api.github.com/graphql", json=json_data, headers=headers) data = response.json() return data @@ -234,21 +234,21 @@ def _call_graphql_api(query): data = _call_graphql_api(query) found_discussion_out_of_time_range = False - edges = data['data']['repository']['discussions']['edges'] + edges = data["data"]["repository"]["discussions"]["edges"] if len(edges) == 0: break else: # keep the discussion whose author is not a member for edge in edges: # print the discussion title - discussion = edge['node'] - discussion_updated_at = str2datetime(discussion['updatedAt']) + discussion = edge["node"] + discussion_updated_at = str2datetime(discussion["updatedAt"]) # check if the updatedAt is within the last 7 days # if yes, add it to discussion_numbers if discussion_updated_at > since: - if discussion['authorAssociation'] != 'MEMBER': - discussion_numbers.append(discussion['number']) + if discussion["authorAssociation"] != "MEMBER": + discussion_numbers.append(discussion["number"]) else: found_discussion_out_of_time_range = True @@ -256,7 +256,7 @@ def _call_graphql_api(query): break else: # update cursor - cursor = edges[-1]['cursor'] + cursor = edges[-1]["cursor"] # get the discussion comments and replies made by our member user_engagement_count = {} @@ -269,42 +269,42 @@ def _call_graphql_api(query): data = _call_graphql_api(query) # get the comments - edges = data['data']['repository']['discussion']['comments']['edges'] + edges = data["data"]["repository"]["discussion"]["comments"]["edges"] # update the cursor if len(edges) == 0: break else: # update cursor for pagination - cursor = edges[-1]['cursor'] + cursor = edges[-1]["cursor"] for edge in edges: - comment = edge['node'] - if comment['authorAssociation'] == 'MEMBER': + comment = 
edge["node"] + if comment["authorAssociation"] == "MEMBER": # check if the updatedAt is within the last 7 days # if yes, add it to user_engagement_count - comment_updated_at = datetime.strptime(comment['updatedAt'], "%Y-%m-%dT%H:%M:%SZ") + comment_updated_at = datetime.strptime(comment["updatedAt"], "%Y-%m-%dT%H:%M:%SZ") if comment_updated_at > since: - member_name = comment['author']['login'] + member_name = comment["author"]["login"] if member_name in user_engagement_count: user_engagement_count[member_name] += 1 else: user_engagement_count[member_name] = 1 # get the replies - reply_edges = comment['replies']['edges'] + reply_edges = comment["replies"]["edges"] if len(reply_edges) == 0: continue else: for reply_edge in reply_edges: - reply = reply_edge['node'] - if reply['authorAssociation'] == 'MEMBER': + reply = reply_edge["node"] + if reply["authorAssociation"] == "MEMBER": # check if the updatedAt is within the last 7 days # if yes, add it to discussion_numbers - reply_updated_at = datetime.strptime(reply['updatedAt'], "%Y-%m-%dT%H:%M:%SZ") + reply_updated_at = datetime.strptime(reply["updatedAt"], "%Y-%m-%dT%H:%M:%SZ") if reply_updated_at > since: - member_name = reply['author']['login'] + member_name = reply["author"]["login"] if member_name in user_engagement_count: user_engagement_count[member_name] += 1 else: @@ -312,7 +312,9 @@ def _call_graphql_api(query): return user_engagement_count -def generate_user_engagement_leaderboard_image(github_token: str, org_name: str, repo_list: List[str], output_path: str) -> bool: +def generate_user_engagement_leaderboard_image( + github_token: str, org_name: str, repo_list: List[str], output_path: str +) -> bool: """ Generate the user engagement leaderboard image for stats within the last 7 days @@ -335,16 +337,19 @@ def _update_count(counter): else: total_engagement_count[name] = count - for repo_name in repo_list: print(f"Fetching user engagement count for {repo_name}/{repo_name}") - issue_pr_engagement_count = get_issue_pull_request_comments(github_token=github_token, org_name=org_name, repo_name=repo_name, since=start_datetime_str) - discussion_engagement_count = get_discussion_comments(github_token=github_token, org_name=org_name, repo_name=repo_name, since=start_datetime) + issue_pr_engagement_count = get_issue_pull_request_comments( + github_token=github_token, org_name=org_name, repo_name=repo_name, since=start_datetime_str + ) + discussion_engagement_count = get_discussion_comments( + github_token=github_token, org_name=org_name, repo_name=repo_name, since=start_datetime + ) # update the total engagement count _update_count(issue_pr_engagement_count) _update_count(discussion_engagement_count) - + # prepare the data for plotting x = [] y = [] @@ -363,7 +368,7 @@ def _update_count(counter): # plot the leaderboard xlabel = f"Number of Comments made (since {start_datetime_str})" ylabel = "Member" - title = 'Active User Engagement Leaderboard' + title = "Active User Engagement Leaderboard" plot_bar_chart(x, y, xlabel=xlabel, ylabel=ylabel, title=title, output_path=output_path) return True else: @@ -380,16 +385,16 @@ def generate_contributor_leaderboard_image(github_token, org_name, repo_list, ou """ # request to the Github API to get the users who have contributed in the last 7 days headers = { - 'Authorization': f'Bearer {github_token}', - 'Accept': 'application/vnd.github+json', - 'X-GitHub-Api-Version': '2022-11-28' + "Authorization": f"Bearer {github_token}", + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": 
"2022-11-28", } counter = Counter() start_datetime = get_utc_time_one_week_ago() def _get_url(org_name, repo_name, page): - return f'https://api.github.com/repos/{org_name}/{repo_name}/pulls?per_page=50&page={page}&state=closed' + return f"https://api.github.com/repos/{org_name}/{repo_name}/pulls?per_page=50&page={page}&state=closed" def _iterate_by_page(org_name, repo_name): page = 1 @@ -415,8 +420,8 @@ def _iterate_by_page(org_name, repo_name): # count the pull request and author from response for pr_data in response: - merged_at = pr_data['merged_at'] - author = pr_data['user']['login'] + merged_at = pr_data["merged_at"] + author = pr_data["user"]["login"] if merged_at is None: continue @@ -439,7 +444,7 @@ def _iterate_by_page(org_name, repo_name): _iterate_by_page(org_name, repo_name) # convert unix timestamp to Beijing datetime - bj_start_datetime = datetime.fromtimestamp(start_datetime.timestamp(), tz=pytz.timezone('Asia/Shanghai')) + bj_start_datetime = datetime.fromtimestamp(start_datetime.timestamp(), tz=pytz.timezone("Asia/Shanghai")) bj_start_datetime_str = datetime2str(bj_start_datetime) contribution_list = counter.to_sorted_list() @@ -452,7 +457,7 @@ def _iterate_by_page(org_name, repo_name): if len(author_list) > 0: xlabel = f"Number of Pull Requests (since {bj_start_datetime_str})" ylabel = "Contributor" - title = 'Active Contributor Leaderboard' + title = "Active Contributor Leaderboard" plot_bar_chart(num_commit_list, author_list, xlabel=xlabel, ylabel=ylabel, title=title, output_path=output_path) return True else: @@ -468,14 +473,14 @@ def upload_image_to_lark(lark_tenant_token: str, image_path: str) -> str: image_path (str): the path to the image to be uploaded """ url = "https://open.feishu.cn/open-apis/im/v1/images" - form = {'image_type': 'message', 'image': (open(image_path, 'rb'))} # 需要替换具体的path + form = {"image_type": "message", "image": (open(image_path, "rb"))} # 需要替换具体的path multi_form = MultipartEncoder(form) headers = { - 'Authorization': f'Bearer {lark_tenant_token}', ## 获取tenant_access_token, 需要替换为实际的token + "Authorization": f"Bearer {lark_tenant_token}", ## 获取tenant_access_token, 需要替换为实际的token } - headers['Content-Type'] = multi_form.content_type + headers["Content-Type"] = multi_form.content_type response = requests.request("POST", url, headers=headers, data=multi_form).json() - return response['data']['image_key'] + return response["data"]["image_key"] def generate_lark_tenant_access_token(app_id: str, app_secret: str) -> str: @@ -486,10 +491,10 @@ def generate_lark_tenant_access_token(app_id: str, app_secret: str) -> str: app_id (str): Lark app id app_secret (str): Lark app secret """ - url = 'https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal' - data = {'app_id': app_id, 'app_secret': app_secret} + url = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal" + data = {"app_id": app_id, "app_secret": app_secret} response = requests.post(url, json=data).json() - return response['tenant_access_token'] + return response["tenant_access_token"] def send_image_to_lark(image_key: str, webhook_url: str) -> None: @@ -516,10 +521,10 @@ def send_message_to_lark(message: str, webhook_url: str): requests.post(webhook_url, json=data) -if __name__ == '__main__': - GITHUB_TOKEN = os.environ['GITHUB_TOKEN'] - CONTRIBUTOR_IMAGE_PATH = 'contributor_leaderboard.png' - USER_ENGAGEMENT_IMAGE_PATH = 'engagement_leaderboard.png' +if __name__ == "__main__": + GITHUB_TOKEN = os.environ["GITHUB_TOKEN"] + CONTRIBUTOR_IMAGE_PATH = 
"contributor_leaderboard.png" + USER_ENGAGEMENT_IMAGE_PATH = "engagement_leaderboard.png" ORG_NAME = "hpcaitech" # get all open source repositories @@ -527,17 +532,19 @@ def send_message_to_lark(message: str, webhook_url: str): # generate images contrib_success = generate_contributor_leaderboard_image(GITHUB_TOKEN, ORG_NAME, REPO_LIST, CONTRIBUTOR_IMAGE_PATH) - engagement_success = generate_user_engagement_leaderboard_image(GITHUB_TOKEN, ORG_NAME, REPO_LIST, USER_ENGAGEMENT_IMAGE_PATH) + engagement_success = generate_user_engagement_leaderboard_image( + GITHUB_TOKEN, ORG_NAME, REPO_LIST, USER_ENGAGEMENT_IMAGE_PATH + ) # upload images - APP_ID = os.environ['LARK_APP_ID'] - APP_SECRET = os.environ['LARK_APP_SECRET'] + APP_ID = os.environ["LARK_APP_ID"] + APP_SECRET = os.environ["LARK_APP_SECRET"] LARK_TENANT_TOKEN = generate_lark_tenant_access_token(app_id=APP_ID, app_secret=APP_SECRET) contributor_image_key = upload_image_to_lark(LARK_TENANT_TOKEN, CONTRIBUTOR_IMAGE_PATH) user_engagement_image_key = upload_image_to_lark(LARK_TENANT_TOKEN, USER_ENGAGEMENT_IMAGE_PATH) # send message to lark - LARK_WEBHOOK_URL = os.environ['LARK_WEBHOOK_URL'] + LARK_WEBHOOK_URL = os.environ["LARK_WEBHOOK_URL"] message = """本周的社区榜单出炉啦! 1. 开发贡献者榜单 2. 用户互动榜单 diff --git a/.github/workflows/scripts/generate_release_draft.py b/.github/workflows/scripts/generate_release_draft.py index dc592e4c977b..7374481005ef 100644 --- a/.github/workflows/scripts/generate_release_draft.py +++ b/.github/workflows/scripts/generate_release_draft.py @@ -7,27 +7,27 @@ import requests -COMMIT_API = 'https://api.github.com/repos/hpcaitech/ColossalAI/commits' -TAGS_API = 'https://api.github.com/repos/hpcaitech/ColossalAI/tags' +COMMIT_API = "https://api.github.com/repos/hpcaitech/ColossalAI/commits" +TAGS_API = "https://api.github.com/repos/hpcaitech/ColossalAI/tags" def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('--out', type=str, help='output path for the release draft', required=True) - parser.add_argument('--version', type=str, help='current version to release', required=True) + parser.add_argument("--out", type=str, help="output path for the release draft", required=True) + parser.add_argument("--version", type=str, help="current version to release", required=True) return parser.parse_args() def get_latest_tag_commit(headers=None): res = requests.get(url=TAGS_API, headers=headers) data = res.json() - commit_hash = data[0]['commit']['sha'] - version = data[0]['name'] + commit_hash = data[0]["commit"]["sha"] + version = data[0]["name"] return commit_hash, version def get_commit_info(commit_hash, headers=None): - api = f'{COMMIT_API}/{commit_hash}' + api = f"{COMMIT_API}/{commit_hash}" res = requests.get(url=api, headers=headers) return res.json() @@ -37,7 +37,7 @@ def get_all_commit_info(since, headers=None): results = [] while True: - api = f'{COMMIT_API}?since={since}&per_page=100&page={page}' + api = f"{COMMIT_API}?since={since}&per_page=100&page={page}" resp = requests.get(url=api, headers=headers) data = resp.json() @@ -53,21 +53,21 @@ def get_all_commit_info(since, headers=None): def collate_release_info(commit_info_list): results = dict() - pattern = pattern = r'\[.*\]' + pattern = pattern = r"\[.*\]" for commit_info in commit_info_list: - author = commit_info['commit']['author']['name'] + author = commit_info["commit"]["author"]["name"] try: - author_url = commit_info['author']['url'] + author_url = commit_info["author"]["url"] except: # author can be None author_url = None - msg = 
commit_info['commit']['message'] + msg = commit_info["commit"]["message"] match = re.search(pattern, msg) if match: - tag = match.group().lstrip('[').rstrip(']').capitalize() + tag = match.group().lstrip("[").rstrip("]").capitalize() if tag not in results: results[tag] = [] results[tag].append((msg, author, author_url)) @@ -89,42 +89,43 @@ def generate_release_post_markdown(current_version, last_version, release_info): for msg, author, author_url in v: # only keep the first line - msg = msg.split('\n')[0] + msg = msg.split("\n")[0] if author_url: - item = f'{msg} by [{author}]({author_url})\n' + item = f"{msg} by [{author}]({author_url})\n" else: - item = f'{msg} by {author}\n' - text.append(f'- {item}') + item = f"{msg} by {author}\n" + text.append(f"- {item}") - text.append('\n') + text.append("\n") # add full change log text.append( - f'**Full Changelog**: https://github.com/hpcaitech/ColossalAI/compare/{current_version}...{last_version}') + f"**Full Changelog**: https://github.com/hpcaitech/ColossalAI/compare/{current_version}...{last_version}" + ) return text -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() - token = os.environ['GITHUB_API_TOKEN'] - headers = {'Authorization': token} + token = os.environ["GITHUB_API_TOKEN"] + headers = {"Authorization": token} # get previous release tag last_release_commit, last_version = get_latest_tag_commit(headers) last_release_commit_info = get_commit_info(last_release_commit, headers=headers) - last_release_date = last_release_commit_info['commit']['author']['date'] + last_release_date = last_release_commit_info["commit"]["author"]["date"] # get the commits since last release commit_info = get_all_commit_info(since=last_release_date, headers=headers) - commit_info = commit_info[:-1] # remove the release commit + commit_info = commit_info[:-1] # remove the release commit # collate into markdown release_info = collate_release_info(commit_info) markdown_text = generate_release_post_markdown(args.version, last_version, release_info) # write into a file - with open(args.out, 'w') as f: + with open(args.out, "w") as f: for line in markdown_text: f.write(line) diff --git a/.github/workflows/scripts/send_message_to_lark.py b/.github/workflows/scripts/send_message_to_lark.py index a113327a786e..bc005d93c3f5 100644 --- a/.github/workflows/scripts/send_message_to_lark.py +++ b/.github/workflows/scripts/send_message_to_lark.py @@ -5,8 +5,8 @@ def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('-m', '--message', type=str) - parser.add_argument('-u', '--url', type=str) + parser.add_argument("-m", "--message", type=str) + parser.add_argument("-u", "--url", type=str) return parser.parse_args() @@ -15,6 +15,6 @@ def send_message_to_lark(message, webhook_url): requests.post(webhook_url, json=data) -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() send_message_to_lark(args.message, args.url) diff --git a/.isort.cfg b/.isort.cfg index 090aa28e39f3..4f881c8b3dda 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -3,3 +3,4 @@ line_length = 120 multi_line_output=3 include_trailing_comma = true ignore_comments = true +profile = black diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 725d266375ef..9871e1184462 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,23 +1,31 @@ repos: + - repo: https://github.com/PyCQA/autoflake + rev: v2.2.1 + hooks: + - id: autoflake + name: autoflake (python) + args: ['--in-place', '--remove-unused-variables', 
'--remove-all-unused-imports', '--ignore-init-module-imports'] + - repo: https://github.com/pycqa/isort rev: 5.12.0 hooks: - id: isort name: sort all imports (python) - - repo: https://github.com/pre-commit/mirrors-yapf - rev: v0.32.0 + - repo: https://github.com/psf/black-pre-commit-mirror + rev: 23.9.1 hooks: - - id: yapf - name: yapf formatter - args: ['--style=.style.yapf', '--parallel', '--in-place'] + - id: black + name: black formatter + args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310'] - repo: https://github.com/pre-commit/mirrors-clang-format rev: v13.0.1 hooks: - id: clang-format name: clang formatter + types_or: [c++, c] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.3.0 diff --git a/.style.yapf b/.style.yapf deleted file mode 100644 index 05be0dc6a3a5..000000000000 --- a/.style.yapf +++ /dev/null @@ -1,5 +0,0 @@ -[style] -based_on_style = google -spaces_before_comment = 4 -split_before_logical_operator = true -column_limit = 120 diff --git a/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py b/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py index 90471ed727b0..04f779821405 100644 --- a/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py +++ b/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py @@ -27,7 +27,7 @@ def get_model_numel(model: nn.Module, strategy: Strategy) -> int: def preprocess_batch(samples) -> dict: input_ids = torch.stack(samples) attention_mask = torch.ones_like(input_ids, dtype=torch.long) - return {'input_ids': input_ids, 'attention_mask': attention_mask} + return {"input_ids": input_ids, "attention_mask": attention_mask} def print_rank_0(*args, **kwargs) -> None: @@ -39,32 +39,32 @@ def print_model_numel(model_dict: dict) -> None: B = 1024**3 M = 1024**2 K = 1024 - outputs = '' + outputs = "" for name, numel in model_dict.items(): - outputs += f'{name}: ' + outputs += f"{name}: " if numel >= B: - outputs += f'{numel / B:.2f} B\n' + outputs += f"{numel / B:.2f} B\n" elif numel >= M: - outputs += f'{numel / M:.2f} M\n' + outputs += f"{numel / M:.2f} M\n" elif numel >= K: - outputs += f'{numel / K:.2f} K\n' + outputs += f"{numel / K:.2f} K\n" else: - outputs += f'{numel}\n' + outputs += f"{numel}\n" print_rank_0(outputs) def get_gpt_config(model_name: str) -> OPTConfig: model_map = { - '125m': OPTConfig.from_pretrained('facebook/opt-125m'), - '350m': OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16), - '700m': OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20), - '1.3b': OPTConfig.from_pretrained('facebook/opt-1.3b'), - '2.7b': OPTConfig.from_pretrained('facebook/opt-2.7b'), - '3.5b': OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32), - '5.5b': OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32), - '6.7b': OPTConfig.from_pretrained('facebook/opt-6.7b'), - '10b': OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32), - '13b': OPTConfig.from_pretrained('facebook/opt-13b'), + "125m": OPTConfig.from_pretrained("facebook/opt-125m"), + "350m": OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16), + "700m": OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20), + "1.3b": OPTConfig.from_pretrained("facebook/opt-1.3b"), + "2.7b": OPTConfig.from_pretrained("facebook/opt-2.7b"), + "3.5b": 
OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32), + "5.5b": OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32), + "6.7b": OPTConfig.from_pretrained("facebook/opt-6.7b"), + "10b": OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32), + "13b": OPTConfig.from_pretrained("facebook/opt-13b"), } try: return model_map[model_name] @@ -73,20 +73,20 @@ def get_gpt_config(model_name: str) -> OPTConfig: def main(args): - if args.strategy == 'ddp': + if args.strategy == "ddp": strategy = DDPStrategy() - elif args.strategy == 'colossalai_gemini': - strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5) - elif args.strategy == 'colossalai_gemini_cpu': - strategy = GeminiStrategy(placement_policy='cpu', initial_scale=2**5) - elif args.strategy == 'colossalai_zero2': - strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda') - elif args.strategy == 'colossalai_zero2_cpu': - strategy = LowLevelZeroStrategy(stage=2, placement_policy='cpu') - elif args.strategy == 'colossalai_zero1': - strategy = LowLevelZeroStrategy(stage=1, placement_policy='cuda') - elif args.strategy == 'colossalai_zero1_cpu': - strategy = LowLevelZeroStrategy(stage=1, placement_policy='cpu') + elif args.strategy == "colossalai_gemini": + strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5) + elif args.strategy == "colossalai_gemini_cpu": + strategy = GeminiStrategy(placement_policy="cpu", initial_scale=2**5) + elif args.strategy == "colossalai_zero2": + strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") + elif args.strategy == "colossalai_zero2_cpu": + strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu") + elif args.strategy == "colossalai_zero1": + strategy = LowLevelZeroStrategy(stage=1, placement_policy="cuda") + elif args.strategy == "colossalai_zero1_cpu": + strategy = LowLevelZeroStrategy(stage=1, placement_policy="cpu") else: raise ValueError(f'Unsupported strategy "{args.strategy}"') @@ -103,90 +103,106 @@ def main(args): if args.use_kernels: from coati.kernels import convert_to_xformer_model - actor, critic, initial_model, reward_model = map(convert_to_xformer_model, - (actor, critic, initial_model, reward_model)) + + actor, critic, initial_model, reward_model = map( + convert_to_xformer_model, (actor, critic, initial_model, reward_model) + ) actor_numel = get_model_numel(actor, strategy) critic_numel = get_model_numel(critic, strategy) initial_model_numel = get_model_numel(initial_model, strategy) reward_model_numel = get_model_numel(reward_model, strategy) - print_model_numel({ - 'Actor': actor_numel, - 'Critic': critic_numel, - 'Initial model': initial_model_numel, - 'Reward model': reward_model_numel - }) - performance_evaluator = PerformanceEvaluator(actor_numel, - critic_numel, - initial_model_numel, - reward_model_numel, - enable_grad_checkpoint=False, - ignore_episodes=1) - - if args.strategy.startswith('colossalai'): + print_model_numel( + { + "Actor": actor_numel, + "Critic": critic_numel, + "Initial model": initial_model_numel, + "Reward model": reward_model_numel, + } + ) + performance_evaluator = PerformanceEvaluator( + actor_numel, + critic_numel, + initial_model_numel, + reward_model_numel, + enable_grad_checkpoint=False, + ignore_episodes=1, + ) + + if args.strategy.startswith("colossalai"): actor_optim = HybridAdam(actor.parameters(), lr=5e-6) critic_optim = HybridAdam(critic.parameters(), lr=5e-6) else: actor_optim = 
Adam(actor.parameters(), lr=5e-6) critic_optim = Adam(critic.parameters(), lr=5e-6) - tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m') + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") tokenizer.pad_token = tokenizer.eos_token (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim)) random_prompts = torch.randint(tokenizer.vocab_size, (1000, 256), device=torch.cuda.current_device()) - dataloader = DataLoader(random_prompts, - batch_size=args.experience_batch_size, - shuffle=True, - collate_fn=preprocess_batch) - - trainer = PPOTrainer(strategy, - actor, - critic, - reward_model, - initial_model, - actor_optim, - critic_optim, - ptx_coef=0, - train_batch_size=args.train_batch_size, - offload_inference_models=args.offload_inference_models, - max_length=512, - do_sample=True, - temperature=1.0, - top_k=50, - use_cache=True, - pad_token_id=tokenizer.pad_token_id, - eos_token_id=tokenizer.eos_token_id, - callbacks=[performance_evaluator]) - - trainer.fit(prompt_dataloader=dataloader, - pretrain_dataloader=None, - num_episodes=args.num_episodes, - num_update_steps=args.num_update_steps, - num_collect_steps=args.num_collect_steps) - - print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB') - - -if __name__ == '__main__': + dataloader = DataLoader( + random_prompts, batch_size=args.experience_batch_size, shuffle=True, collate_fn=preprocess_batch + ) + + trainer = PPOTrainer( + strategy, + actor, + critic, + reward_model, + initial_model, + actor_optim, + critic_optim, + ptx_coef=0, + train_batch_size=args.train_batch_size, + offload_inference_models=args.offload_inference_models, + max_length=512, + do_sample=True, + temperature=1.0, + top_k=50, + use_cache=True, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + callbacks=[performance_evaluator], + ) + + trainer.fit( + prompt_dataloader=dataloader, + pretrain_dataloader=None, + num_episodes=args.num_episodes, + num_update_steps=args.num_update_steps, + num_collect_steps=args.num_collect_steps, + ) + + print_rank_0(f"Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB") + + +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--model', default='125m') - parser.add_argument('--critic_model', default='125m') - parser.add_argument('--strategy', - choices=[ - 'ddp', 'colossalai_gemini', 'colossalai_gemini_cpu', 'colossalai_zero2', - 'colossalai_zero2_cpu', 'colossalai_zero1', 'colossalai_zero1_cpu' - ], - default='ddp') - parser.add_argument('--num_episodes', type=int, default=3) - parser.add_argument('--num_collect_steps', type=int, default=8) - parser.add_argument('--num_update_steps', type=int, default=1) - parser.add_argument('--train_batch_size', type=int, default=8) - parser.add_argument('--experience_batch_size', type=int, default=8) - parser.add_argument('--lora_rank', type=int, default=0) - parser.add_argument('--cuda_mem_frac', type=float, default=1.0) - parser.add_argument('--offload_inference_models', action='store_true', default=False) - parser.add_argument('--use_kernels', action='store_true', default=False) + parser.add_argument("--model", default="125m") + parser.add_argument("--critic_model", default="125m") + parser.add_argument( + "--strategy", + choices=[ + "ddp", + "colossalai_gemini", + "colossalai_gemini_cpu", + "colossalai_zero2", + "colossalai_zero2_cpu", + "colossalai_zero1", + "colossalai_zero1_cpu", + ], + default="ddp", + ) + 
parser.add_argument("--num_episodes", type=int, default=3) + parser.add_argument("--num_collect_steps", type=int, default=8) + parser.add_argument("--num_update_steps", type=int, default=1) + parser.add_argument("--train_batch_size", type=int, default=8) + parser.add_argument("--experience_batch_size", type=int, default=8) + parser.add_argument("--lora_rank", type=int, default=0) + parser.add_argument("--cuda_mem_frac", type=float, default=1.0) + parser.add_argument("--offload_inference_models", action="store_true", default=False) + parser.add_argument("--use_kernels", action="store_true", default=False) args = parser.parse_args() main(args) diff --git a/applications/Chat/benchmarks/ray/1mmt_dummy.py b/applications/Chat/benchmarks/ray/1mmt_dummy.py index 7fc990448805..98ace3869450 100644 --- a/applications/Chat/benchmarks/ray/1mmt_dummy.py +++ b/applications/Chat/benchmarks/ray/1mmt_dummy.py @@ -22,13 +22,13 @@ def get_free_port(): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(('', 0)) + s.bind(("", 0)) return s.getsockname()[1] def get_local_ip(): with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: - s.connect(('8.8.8.8', 80)) + s.connect(("8.8.8.8", 80)) return s.getsockname()[0] @@ -36,22 +36,25 @@ def main(args): master_addr = str(get_local_ip()) # trainer_env_info trainer_port = str(get_free_port()) - env_info_trainers = [{ - 'local_rank': '0', - 'rank': str(rank), - 'world_size': str(args.num_trainers), - 'master_port': trainer_port, - 'master_addr': master_addr - } for rank in range(args.num_trainers)] + env_info_trainers = [ + { + "local_rank": "0", + "rank": str(rank), + "world_size": str(args.num_trainers), + "master_port": trainer_port, + "master_addr": master_addr, + } + for rank in range(args.num_trainers) + ] # maker_env_info maker_port = str(get_free_port()) env_info_maker = { - 'local_rank': '0', - 'rank': '0', - 'world_size': '1', - 'master_port': maker_port, - 'master_addr': master_addr + "local_rank": "0", + "rank": "0", + "world_size": "1", + "master_port": maker_port, + "master_addr": master_addr, } # configure tokenizer @@ -63,21 +66,27 @@ def model_fn(): critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain) actor = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda() critic = get_critic_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda() - reward_model = get_reward_model_from_args(args.critic_model, - config=critic_cfg).requires_grad_(False).half().cuda() - if args.initial_model_quant_ckpt is not None and args.model == 'llama': + reward_model = ( + get_reward_model_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda() + ) + if args.initial_model_quant_ckpt is not None and args.model == "llama": # quantize initial model with low_resource_init(), no_init_weights(): initial_model = get_actor_from_args(args.model, config=actor_cfg) - initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, - args.quant_group_size).cuda().requires_grad_(False) + initial_model.model = ( + llama_load_quant( + initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size + ) + .cuda() + .requires_grad_(False) + ) else: initial_model = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda() return actor, critic, reward_model, initial_model # configure Experience Maker experience_holder_ref = ExperienceMakerHolder.options(name="maker0", num_gpus=1, 
max_concurrency=2).remote( - detached_trainer_name_list=[f'trainer{i}' for i in range(args.num_trainers)], + detached_trainer_name_list=[f"trainer{i}" for i in range(args.num_trainers)], strategy_fn=partial(get_strategy_from_args, args.maker_strategy), model_fn=model_fn, env_info=env_info_maker, @@ -97,15 +106,18 @@ def model_fn(): def trainer_model_fn(): actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda() - critic = get_critic_from_args(args.critic_model, - config=AutoConfig.from_pretrained(args.critic_pretrain)).half().cuda() + critic = ( + get_critic_from_args(args.critic_model, config=AutoConfig.from_pretrained(args.critic_pretrain)) + .half() + .cuda() + ) return actor, critic # configure Trainer trainer_refs = [ DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote( experience_maker_holder_name_list=[ - f'maker{x}' for x in get_receivers_per_sender(i, args.num_trainers, 1, allow_idle_sender=True) + f"maker{x}" for x in get_receivers_per_sender(i, args.num_trainers, 1, allow_idle_sender=True) ], strategy_fn=partial(get_strategy_from_args, args.trainer_strategy), model_fn=trainer_model_fn, @@ -114,7 +126,8 @@ def trainer_model_fn(): buffer_limit=16, eval_performance=True, debug=args.debug, - ) for i, env_info_trainer in enumerate(env_info_trainers) + ) + for i, env_info_trainer in enumerate(env_info_trainers) ] dataset_size = args.experience_batch_size * 4 @@ -122,7 +135,7 @@ def trainer_model_fn(): def data_gen_fn(): input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device()) attn_mask = torch.ones_like(input_ids) - return {'input_ids': input_ids, 'attention_mask': attn_mask} + return {"input_ids": input_ids, "attention_mask": attn_mask} def build_dataloader(size): dataset = [data_gen_fn() for _ in range(size)] @@ -138,8 +151,10 @@ def build_dataloader(size): wait_tasks = [] wait_tasks.append( - experience_holder_ref.workingloop.remote(partial(build_dataloader, dataset_size), - num_steps=args.experience_steps)) + experience_holder_ref.workingloop.remote( + partial(build_dataloader, dataset_size), num_steps=args.experience_steps + ) + ) total_steps = args.experience_batch_size * args.experience_steps // (args.num_trainers * args.train_batch_size) for trainer_ref in trainer_refs: @@ -148,31 +163,30 @@ def build_dataloader(size): ray.get(wait_tasks) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--num_trainers', type=int, default=1) - parser.add_argument('--trainer_strategy', - choices=[ - 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu', - 'colossalai_zero2_cpu' - ], - default='ddp') - parser.add_argument('--maker_strategy', choices=['naive'], default='naive') - parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) - parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) - parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--critic_pretrain', type=str, default=None) - parser.add_argument('--experience_steps', type=int, default=4) - parser.add_argument('--experience_batch_size', type=int, default=8) - parser.add_argument('--train_epochs', type=int, default=1) - parser.add_argument('--update_steps', type=int, default=2) - parser.add_argument('--train_batch_size', type=int, default=8) - parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") 
- - parser.add_argument('--initial_model_quant_ckpt', type=str, default=None) - parser.add_argument('--quant_bits', type=int, default=4) - parser.add_argument('--quant_group_size', type=int, default=128) - parser.add_argument('--debug', action='store_true') + parser.add_argument("--num_trainers", type=int, default=1) + parser.add_argument( + "--trainer_strategy", + choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"], + default="ddp", + ) + parser.add_argument("--maker_strategy", choices=["naive"], default="naive") + parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--critic_pretrain", type=str, default=None) + parser.add_argument("--experience_steps", type=int, default=4) + parser.add_argument("--experience_batch_size", type=int, default=8) + parser.add_argument("--train_epochs", type=int, default=1) + parser.add_argument("--update_steps", type=int, default=2) + parser.add_argument("--train_batch_size", type=int, default=8) + parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + + parser.add_argument("--initial_model_quant_ckpt", type=str, default=None) + parser.add_argument("--quant_bits", type=int, default=4) + parser.add_argument("--quant_group_size", type=int, default=128) + parser.add_argument("--debug", action="store_true") args = parser.parse_args() ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)}) main(args) diff --git a/applications/Chat/benchmarks/ray/mmmt_dummy.py b/applications/Chat/benchmarks/ray/mmmt_dummy.py index ca1df22070fc..f8860f2979ee 100644 --- a/applications/Chat/benchmarks/ray/mmmt_dummy.py +++ b/applications/Chat/benchmarks/ray/mmmt_dummy.py @@ -22,13 +22,13 @@ def get_free_port(): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(('', 0)) + s.bind(("", 0)) return s.getsockname()[1] def get_local_ip(): with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: - s.connect(('8.8.8.8', 80)) + s.connect(("8.8.8.8", 80)) return s.getsockname()[0] @@ -36,23 +36,29 @@ def main(args): master_addr = str(get_local_ip()) # trainer_env_info trainer_port = str(get_free_port()) - env_info_trainers = [{ - 'local_rank': '0', - 'rank': str(rank), - 'world_size': str(args.num_trainers), - 'master_port': trainer_port, - 'master_addr': master_addr - } for rank in range(args.num_trainers)] + env_info_trainers = [ + { + "local_rank": "0", + "rank": str(rank), + "world_size": str(args.num_trainers), + "master_port": trainer_port, + "master_addr": master_addr, + } + for rank in range(args.num_trainers) + ] # maker_env_info maker_port = str(get_free_port()) - env_info_makers = [{ - 'local_rank': '0', - 'rank': str(rank), - 'world_size': str(args.num_makers), - 'master_port': maker_port, - 'master_addr': master_addr - } for rank in range(args.num_makers)] + env_info_makers = [ + { + "local_rank": "0", + "rank": str(rank), + "world_size": str(args.num_makers), + "master_port": maker_port, + "master_addr": master_addr, + } + for rank in range(args.num_makers) + ] # configure tokenizer tokenizer = AutoTokenizer.from_pretrained(args.pretrain) @@ -63,14 +69,20 @@ def model_fn(): critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain) actor = get_actor_from_args(args.model, 
config=actor_cfg).requires_grad_(False).half().cuda() critic = get_critic_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda() - reward_model = get_reward_model_from_args(args.critic_model, - config=critic_cfg).requires_grad_(False).half().cuda() - if args.initial_model_quant_ckpt is not None and args.model == 'llama': + reward_model = ( + get_reward_model_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda() + ) + if args.initial_model_quant_ckpt is not None and args.model == "llama": # quantize initial model with low_resource_init(), no_init_weights(): initial_model = get_actor_from_args(args.model, config=actor_cfg) - initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, - args.quant_group_size).cuda().requires_grad_(False) + initial_model.model = ( + llama_load_quant( + initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size + ) + .cuda() + .requires_grad_(False) + ) else: initial_model = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda() return actor, critic, reward_model, initial_model @@ -79,7 +91,7 @@ def model_fn(): experience_holder_refs = [ ExperienceMakerHolder.options(name=f"maker{i}", num_gpus=1, max_concurrency=2).remote( detached_trainer_name_list=[ - f'trainer{x}' + f"trainer{x}" for x in get_receivers_per_sender(i, args.num_makers, args.num_trainers, allow_idle_sender=False) ], strategy_fn=partial(get_strategy_from_args, args.maker_strategy), @@ -103,8 +115,11 @@ def model_fn(): def trainer_model_fn(): actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda() - critic = get_critic_from_args(args.critic_model, - config=AutoConfig.from_pretrained(args.critic_pretrain)).half().cuda() + critic = ( + get_critic_from_args(args.critic_model, config=AutoConfig.from_pretrained(args.critic_pretrain)) + .half() + .cuda() + ) return actor, critic # configure Trainer @@ -130,7 +145,7 @@ def trainer_model_fn(): def data_gen_fn(): input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device()) attn_mask = torch.ones_like(input_ids) - return {'input_ids': input_ids, 'attention_mask': attn_mask} + return {"input_ids": input_ids, "attention_mask": attn_mask} def build_dataloader(size): dataset = [data_gen_fn() for _ in range(size)] @@ -147,43 +162,48 @@ def build_dataloader(size): for experience_holder_ref in experience_holder_refs: wait_tasks.append( - experience_holder_ref.workingloop.remote(partial(build_dataloader, dataset_size), - num_steps=args.experience_steps)) + experience_holder_ref.workingloop.remote( + partial(build_dataloader, dataset_size), num_steps=args.experience_steps + ) + ) - total_steps = args.experience_batch_size * args.experience_steps * \ - args.num_makers // (args.num_trainers * args.train_batch_size) + total_steps = ( + args.experience_batch_size + * args.experience_steps + * args.num_makers + // (args.num_trainers * args.train_batch_size) + ) for trainer_ref in trainer_refs: wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs)) ray.get(wait_tasks) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--num_makers', type=int, default=1) - parser.add_argument('--num_trainers', type=int, default=1) - parser.add_argument('--trainer_strategy', - choices=[ - 'ddp', 'colossalai_gemini', 'colossalai_zero2', 
'colossalai_gemini_cpu', - 'colossalai_zero2_cpu' - ], - default='ddp') - parser.add_argument('--maker_strategy', choices=['naive'], default='naive') - parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) - parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) - parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--critic_pretrain', type=str, default=None) - parser.add_argument('--experience_steps', type=int, default=4) - parser.add_argument('--experience_batch_size', type=int, default=8) - parser.add_argument('--train_epochs', type=int, default=1) - parser.add_argument('--update_steps', type=int, default=2) - parser.add_argument('--train_batch_size', type=int, default=8) - parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") - - parser.add_argument('--initial_model_quant_ckpt', type=str, default=None) - parser.add_argument('--quant_bits', type=int, default=4) - parser.add_argument('--quant_group_size', type=int, default=128) - parser.add_argument('--debug', action='store_true') + parser.add_argument("--num_makers", type=int, default=1) + parser.add_argument("--num_trainers", type=int, default=1) + parser.add_argument( + "--trainer_strategy", + choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"], + default="ddp", + ) + parser.add_argument("--maker_strategy", choices=["naive"], default="naive") + parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--critic_pretrain", type=str, default=None) + parser.add_argument("--experience_steps", type=int, default=4) + parser.add_argument("--experience_batch_size", type=int, default=8) + parser.add_argument("--train_epochs", type=int, default=1) + parser.add_argument("--update_steps", type=int, default=2) + parser.add_argument("--train_batch_size", type=int, default=8) + parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + + parser.add_argument("--initial_model_quant_ckpt", type=str, default=None) + parser.add_argument("--quant_bits", type=int, default=4) + parser.add_argument("--quant_group_size", type=int, default=128) + parser.add_argument("--debug", action="store_true") args = parser.parse_args() ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)}) main(args) diff --git a/applications/Chat/coati/dataset/__init__.py b/applications/Chat/coati/dataset/__init__.py index bd4e5460d11e..599b57609775 100644 --- a/applications/Chat/coati/dataset/__init__.py +++ b/applications/Chat/coati/dataset/__init__.py @@ -4,7 +4,10 @@ from .utils import is_rank_0 __all__ = [ - 'RmStaticDataset', 'HhRlhfDataset', - 'SFTDataset', 'SupervisedDataset', - 'PromptDataset', 'is_rank_0', + "RmStaticDataset", + "HhRlhfDataset", + "SFTDataset", + "SupervisedDataset", + "PromptDataset", + "is_rank_0", ] diff --git a/applications/Chat/coati/dataset/conversation.py b/applications/Chat/coati/dataset/conversation.py index 465fa867c7ab..f2180d96b0d3 100644 --- a/applications/Chat/coati/dataset/conversation.py +++ b/applications/Chat/coati/dataset/conversation.py @@ -49,7 +49,7 @@ def append_message(self, role, message): def to_gradio_chatbot(self): ret = [] - for i, (role, msg) in 
enumerate(self.messages[self.offset:]): + for i, (role, msg) in enumerate(self.messages[self.offset :]): if i % 2 == 0: ret.append([msg, None]) else: @@ -57,12 +57,14 @@ def to_gradio_chatbot(self): return ret def copy(self): - return Conversation(system=self.system, - roles=self.roles, - messages=[[x, y] for x, y in self.messages], - offset=self.offset, - sep_style=self.sep_style, - sep=self.sep) + return Conversation( + system=self.system, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + ) def dict(self): return { @@ -70,7 +72,7 @@ def dict(self): "roles": self.roles, "messages": self.messages, "offset": self.offset, - "sep": self.sep + "sep": self.sep, } diff --git a/applications/Chat/coati/dataset/prompt_dataset.py b/applications/Chat/coati/dataset/prompt_dataset.py index 2c953fffa513..17120e6064b5 100644 --- a/applications/Chat/coati/dataset/prompt_dataset.py +++ b/applications/Chat/coati/dataset/prompt_dataset.py @@ -13,11 +13,13 @@ class PromptDataset(Dataset): """Dataset for supervised fine-tuning.""" - def __init__(self, - data_path: str, - tokenizer: transformers.PreTrainedTokenizer, - max_datasets_size: int = None, - max_length: int = 96): + def __init__( + self, + data_path: str, + tokenizer: transformers.PreTrainedTokenizer, + max_datasets_size: int = None, + max_length: int = 96, + ): super(PromptDataset, self).__init__() self.keyed_prompt = defaultdict(list) self.logger = get_dist_logger() @@ -30,11 +32,9 @@ def __init__(self, list_data_dict = list_data_dict[:max_datasets_size] instructions = [data_dict["instruction"] for data_dict in list_data_dict] - tokens = tokenizer(instructions, - return_tensors='pt', - max_length=max_length, - padding='max_length', - truncation=True) + tokens = tokenizer( + instructions, return_tensors="pt", max_length=max_length, padding="max_length", truncation=True + ) for k, tensor in tokens.items(): self.keyed_prompt[k] = tensor.to(torch.cuda.current_device()).unbind() diff --git a/applications/Chat/coati/dataset/reward_dataset.py b/applications/Chat/coati/dataset/reward_dataset.py index 3c4ec8b214bb..3afcd7b69238 100644 --- a/applications/Chat/coati/dataset/reward_dataset.py +++ b/applications/Chat/coati/dataset/reward_dataset.py @@ -20,44 +20,31 @@ class RmStaticDataset(Dataset): def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None: super().__init__() - self.end_token = tokenizer.eos_token \ - if special_token is None else special_token - - chosen = [ - data["prompt"] + data["chosen"] + self.end_token - for data in tqdm(dataset, disable=not is_rank_0()) - ] - chosen_token = tokenizer(chosen, - max_length=max_length, - padding="max_length", - truncation=True, - return_tensors="pt") - self.chosen = { - "input_ids": chosen_token["input_ids"], - "attention_mask": chosen_token["attention_mask"] - } - - reject = [ - data["prompt"] + data["rejected"] + self.end_token - for data in tqdm(dataset, disable=not is_rank_0()) - ] - reject_token = tokenizer(reject, - max_length=max_length, - padding="max_length", - truncation=True, - return_tensors="pt") - self.reject = { - "input_ids": reject_token["input_ids"], - "attention_mask": reject_token["attention_mask"] - } + self.end_token = tokenizer.eos_token if special_token is None else special_token + + chosen = [data["prompt"] + data["chosen"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())] + chosen_token = tokenizer( + chosen, max_length=max_length, 
padding="max_length", truncation=True, return_tensors="pt" + ) + self.chosen = {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]} + + reject = [data["prompt"] + data["rejected"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())] + reject_token = tokenizer( + reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + self.reject = {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]} def __len__(self): length = self.chosen["input_ids"].shape[0] return length def __getitem__(self, idx): - return self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx], \ - self.reject["input_ids"][idx], self.reject["attention_mask"][idx] + return ( + self.chosen["input_ids"][idx], + self.chosen["attention_mask"][idx], + self.reject["input_ids"][idx], + self.reject["attention_mask"][idx], + ) # Anthropic/hh-rlhf @@ -74,41 +61,28 @@ class HhRlhfDataset(Dataset): def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None: super().__init__() - self.end_token = tokenizer.eos_token \ - if special_token is None else special_token - - chosen = [ - data["chosen"] + self.end_token - for data in tqdm(dataset, disable=not is_rank_0()) - ] - chosen_token = tokenizer(chosen, - max_length=max_length, - padding="max_length", - truncation=True, - return_tensors="pt") - self.chosen = { - "input_ids": chosen_token["input_ids"], - "attention_mask": chosen_token["attention_mask"] - } - - reject = [ - data["rejected"] + self.end_token - for data in tqdm(dataset, disable=not is_rank_0()) - ] - reject_token = tokenizer(reject, - max_length=max_length, - padding="max_length", - truncation=True, - return_tensors="pt") - self.reject = { - "input_ids": reject_token["input_ids"], - "attention_mask": reject_token["attention_mask"] - } + self.end_token = tokenizer.eos_token if special_token is None else special_token + + chosen = [data["chosen"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())] + chosen_token = tokenizer( + chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + self.chosen = {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]} + + reject = [data["rejected"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())] + reject_token = tokenizer( + reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + self.reject = {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]} def __len__(self): length = self.chosen["input_ids"].shape[0] return length def __getitem__(self, idx): - return self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx], \ - self.reject["input_ids"][idx], self.reject["attention_mask"][idx] + return ( + self.chosen["input_ids"][idx], + self.chosen["attention_mask"][idx], + self.reject["input_ids"][idx], + self.reject["attention_mask"][idx], + ) diff --git a/applications/Chat/coati/dataset/sft_dataset.py b/applications/Chat/coati/dataset/sft_dataset.py index 2959d3fac81c..d6be09ca5cc9 100644 --- a/applications/Chat/coati/dataset/sft_dataset.py +++ b/applications/Chat/coati/dataset/sft_dataset.py @@ -16,10 +16,11 @@ from typing import Dict, Sequence, Tuple import torch +from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer from torch.utils.data import Dataset from tqdm import tqdm from transformers import PreTrainedTokenizer -from 
coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer + from colossalai.logging import get_dist_logger from .utils import is_rank_0, jload @@ -28,32 +29,33 @@ IGNORE_INDEX = -100 PROMPT_DICT = { - "prompt_input": ("Below is an instruction that describes a task, paired with an input that provides further context. " - "Write a response that appropriately completes the request.\n\n" - "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"), - "prompt_no_input": ("Below is an instruction that describes a task. " - "Write a response that appropriately completes the request.\n\n" - "### Instruction:\n{instruction}\n\n### Response:"), + "prompt_input": ( + "Below is an instruction that describes a task, paired with an input that provides further context. " + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" + ), + "prompt_no_input": ( + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Response:" + ), } -def _preprocess(sources: Sequence[str], - targets: Sequence[str], - tokenizer: PreTrainedTokenizer, - max_length: int, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +def _preprocess( + sources: Sequence[str], + targets: Sequence[str], + tokenizer: PreTrainedTokenizer, + max_length: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Preprocess the data by tokenizing.""" sequences = [s + t for s, t in zip(sources, targets)] - sequences_token = tokenizer(sequences, - max_length=max_length, - padding="max_length", - truncation=True, - return_tensors="pt") - sources_token = tokenizer(sources, - max_length=max_length, - padding="max_length", - truncation=True, - return_tensors="pt") + sequences_token = tokenizer( + sequences, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + sources_token = tokenizer( + sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + ) labels = copy.deepcopy(sequences_token["input_ids"]) for i in range(labels.shape[0]): @@ -64,23 +66,24 @@ def _preprocess(sources: Sequence[str], labels[i][:source_len] = IGNORE_INDEX elif tokenizer.padding_side == "left": # |pad|prompt|completion|eos| - labels[i][pad_len:pad_len + source_len] = IGNORE_INDEX + labels[i][pad_len : pad_len + source_len] = IGNORE_INDEX else: raise RuntimeError() return sequences_token["input_ids"], labels, sequences_token["attention_mask"] -def _preprocess_chatglm(sources: Sequence[str], - targets: Sequence[str], - tokenizer: PreTrainedTokenizer, - max_length: int, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +def _preprocess_chatglm( + sources: Sequence[str], + targets: Sequence[str], + tokenizer: PreTrainedTokenizer, + max_length: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Preprocess the data by tokenizing. 
None for attention mask, ChatGLM will calculate attention mask according to input ids """ - + labels = [] input_ids = [] for source, target in zip(sources, targets): @@ -90,16 +93,16 @@ def _preprocess_chatglm(sources: Sequence[str], # truncate sp_token_list = [tokenizer.gmask_token_id, tokenizer.bos_token_id] truncate_length = max(0, len(input_id) - max_length) - input_id = input_id[truncate_length: ] + input_id = input_id[truncate_length:] if truncate_length == len(source_id) + 1: - input_id = sp_token_list + input_id[1: ] + input_id = sp_token_list + input_id[1:] elif truncate_length > len(source_id) + 1: - input_id = sp_token_list + input_id[2: ] - + input_id = sp_token_list + input_id[2:] + context_length = input_id.index(tokenizer.bos_token_id) mask_position = context_length - 1 - label = [IGNORE_INDEX] * context_length + input_id[mask_position+1:] - + label = [IGNORE_INDEX] * context_length + input_id[mask_position + 1 :] + pad_len = max_length - len(input_id) input_id = input_id + [tokenizer.pad_token_id] * pad_len input_ids.append(input_id) @@ -117,25 +120,18 @@ class SFTDataset(Dataset): max_length: max length of input """ - def __init__(self, - dataset: Dict, - tokenizer: PreTrainedTokenizer, - max_length: int = 512 - ) -> None: + def __init__(self, dataset: Dict, tokenizer: PreTrainedTokenizer, max_length: int = 512) -> None: super().__init__() self.input_ids = [] sources = [data["prompt"] for data in dataset] - targets = [ - data["completion"] + tokenizer.eos_token - for data in tqdm(dataset, disable=not is_rank_0()) - ] + targets = [data["completion"] + tokenizer.eos_token for data in tqdm(dataset, disable=not is_rank_0())] if isinstance(tokenizer, ChatGLMTokenizer): - self.input_ids, self.labels, self.attention_mask = \ - _preprocess_chatglm(sources, targets, tokenizer, max_length) + self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm( + sources, targets, tokenizer, max_length + ) else: - self.input_ids, self.labels, self.attention_mask = \ - _preprocess(sources, targets, tokenizer, max_length) + self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length) def __len__(self): length = self.input_ids.shape[0] @@ -143,22 +139,17 @@ def __len__(self): def __getitem__(self, idx): if self.attention_mask is not None: - return dict(input_ids=self.input_ids[idx], - labels=self.labels[idx], - attention_mask=self.attention_mask[idx]) + return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx]) else: - return dict(input_ids=self.input_ids[idx], - labels=self.labels[idx]) + return dict(input_ids=self.input_ids[idx], labels=self.labels[idx]) class SupervisedDataset(Dataset): """Dataset for supervised fine-tuning.""" - def __init__(self, - data_path: str, - tokenizer: PreTrainedTokenizer, - max_datasets_size: int = None, - max_length: int = 512): + def __init__( + self, data_path: str, tokenizer: PreTrainedTokenizer, max_datasets_size: int = None, max_length: int = 512 + ): super().__init__() logger.info("Loading data...") list_data_dict = jload(data_path) @@ -174,18 +165,15 @@ def __init__(self, prompt_input.format_map(example) if "input" in example else prompt_no_input.format_map(example) for example in list_data_dict ] - targets = [ - example['output'] + tokenizer.eos_token - for example in list_data_dict - ] + targets = [example["output"] + tokenizer.eos_token for example in list_data_dict] logger.info("Tokenizing inputs... 
This may take some time...") if isinstance(tokenizer, ChatGLMTokenizer): - self.input_ids, self.labels, self.attention_mask = \ - _preprocess_chatglm(sources, targets, tokenizer, max_length) + self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm( + sources, targets, tokenizer, max_length + ) else: - self.input_ids, self.labels, self.attention_mask = \ - _preprocess(sources, targets, tokenizer, max_length) + self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length) def __len__(self): length = self.input_ids.shape[0] @@ -193,9 +181,6 @@ def __len__(self): def __getitem__(self, idx): if self.attention_mask is not None: - return dict(input_ids=self.input_ids[idx], - labels=self.labels[idx], - attention_mask=self.attention_mask[idx]) + return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx]) else: - return dict(input_ids=self.input_ids[idx], - labels=self.labels[idx]) + return dict(input_ids=self.input_ids[idx], labels=self.labels[idx]) diff --git a/applications/Chat/coati/experience_buffer/__init__.py b/applications/Chat/coati/experience_buffer/__init__.py index c0188dc4a471..f2a48d0a3b20 100644 --- a/applications/Chat/coati/experience_buffer/__init__.py +++ b/applications/Chat/coati/experience_buffer/__init__.py @@ -1,4 +1,4 @@ from .base import ExperienceBuffer from .naive import NaiveExperienceBuffer -__all__ = ['ExperienceBuffer', 'NaiveExperienceBuffer'] +__all__ = ["ExperienceBuffer", "NaiveExperienceBuffer"] diff --git a/applications/Chat/coati/experience_buffer/base.py b/applications/Chat/coati/experience_buffer/base.py index 9ccdc935d506..7047785308f3 100644 --- a/applications/Chat/coati/experience_buffer/base.py +++ b/applications/Chat/coati/experience_buffer/base.py @@ -7,9 +7,9 @@ class ExperienceBuffer(ABC): """Experience buffer base class. It stores experience. - Args: - sample_batch_size (int): Batch size when sampling. - limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0. + Args: + sample_batch_size (int): Batch size when sampling. + limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0. """ def __init__(self, sample_batch_size: int, limit: int = 0) -> None: diff --git a/applications/Chat/coati/experience_buffer/naive.py b/applications/Chat/coati/experience_buffer/naive.py index bd5213b38993..acc0fbe88ab4 100644 --- a/applications/Chat/coati/experience_buffer/naive.py +++ b/applications/Chat/coati/experience_buffer/naive.py @@ -11,23 +11,23 @@ class NaiveExperienceBuffer(ExperienceBuffer): """Naive experience buffer class. It stores experience. - Args: - sample_batch_size (int): Batch size when sampling. - limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0. - cpu_offload (bool, optional): Whether to offload experience to cpu when sampling. Defaults to True. + Args: + sample_batch_size (int): Batch size when sampling. + limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0. + cpu_offload (bool, optional): Whether to offload experience to cpu when sampling. Defaults to True. 
""" def __init__(self, sample_batch_size: int, limit: int = 0, cpu_offload: bool = True) -> None: super().__init__(sample_batch_size, limit) self.cpu_offload = cpu_offload - self.target_device = torch.device(f'cuda:{torch.cuda.current_device()}') + self.target_device = torch.device(f"cuda:{torch.cuda.current_device()}") # TODO(ver217): add prefetch self.items: List[BufferItem] = [] @torch.no_grad() def append(self, experience: Experience) -> None: if self.cpu_offload: - experience.to_device(torch.device('cpu')) + experience.to_device(torch.device("cpu")) items = split_experience_batch(experience) self.items.extend(items) if self.limit > 0: diff --git a/applications/Chat/coati/experience_buffer/utils.py b/applications/Chat/coati/experience_buffer/utils.py index c2a34212e2f4..baedbebd184f 100644 --- a/applications/Chat/coati/experience_buffer/utils.py +++ b/applications/Chat/coati/experience_buffer/utils.py @@ -21,6 +21,7 @@ class BufferItem: "A" is the number of actions. """ + sequences: torch.Tensor action_log_probs: torch.Tensor values: torch.Tensor @@ -33,8 +34,7 @@ class BufferItem: def split_experience_batch(experience: Experience) -> List[BufferItem]: batch_size = experience.sequences.size(0) batch_kwargs = [{} for _ in range(batch_size)] - keys = ('sequences', 'action_log_probs', 'values', - 'reward', 'advantages', 'attention_mask', 'action_mask') + keys = ("sequences", "action_log_probs", "values", "reward", "advantages", "attention_mask", "action_mask") for key in keys: value = getattr(experience, key) if isinstance(value, torch.Tensor): @@ -49,22 +49,21 @@ def split_experience_batch(experience: Experience) -> List[BufferItem]: return items -def _zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor: - assert side in ('left', 'right') +def _zero_pad_sequences(sequences: List[torch.Tensor], side: str = "left") -> torch.Tensor: + assert side in ("left", "right") max_len = max(seq.size(0) for seq in sequences) padded_sequences = [] for seq in sequences: pad_len = max_len - seq.size(0) - padding = (pad_len, 0) if side == 'left' else (0, pad_len) + padding = (pad_len, 0) if side == "left" else (0, pad_len) padded_sequences.append(F.pad(seq, padding)) return torch.stack(padded_sequences, dim=0) def make_experience_batch(items: List[BufferItem]) -> Experience: kwargs = {} - to_pad_keys = set(('action_log_probs', 'action_mask')) - keys = ('sequences', 'action_log_probs', 'values', - 'reward', 'advantages', 'attention_mask', 'action_mask') + to_pad_keys = set(("action_log_probs", "action_mask")) + keys = ("sequences", "action_log_probs", "values", "reward", "advantages", "attention_mask", "action_mask") for key in keys: vals = [getattr(item, key) for item in items] if key in to_pad_keys: diff --git a/applications/Chat/coati/experience_maker/__init__.py b/applications/Chat/coati/experience_maker/__init__.py index 39ca7576b227..06452292e77c 100644 --- a/applications/Chat/coati/experience_maker/__init__.py +++ b/applications/Chat/coati/experience_maker/__init__.py @@ -1,4 +1,4 @@ from .base import Experience, ExperienceMaker from .naive import NaiveExperienceMaker -__all__ = ['Experience', 'ExperienceMaker', 'NaiveExperienceMaker'] +__all__ = ["Experience", "ExperienceMaker", "NaiveExperienceMaker"] diff --git a/applications/Chat/coati/experience_maker/base.py b/applications/Chat/coati/experience_maker/base.py index b4646f282f0c..727f0a4a52e8 100644 --- a/applications/Chat/coati/experience_maker/base.py +++ b/applications/Chat/coati/experience_maker/base.py @@ 
-24,6 +24,7 @@ class Experience: "A" is the number of actions. """ + sequences: torch.Tensor action_log_probs: torch.Tensor values: torch.Tensor @@ -58,13 +59,9 @@ def pin_memory(self): class ExperienceMaker(ABC): - - def __init__(self, - actor: Actor, - critic: nn.Module, - reward_model: nn.Module, - initial_model: Actor, - kl_coef: float = 0.1) -> None: + def __init__( + self, actor: Actor, critic: nn.Module, reward_model: nn.Module, initial_model: Actor, kl_coef: float = 0.1 + ) -> None: super().__init__() self.actor = actor self.critic = critic diff --git a/applications/Chat/coati/experience_maker/naive.py b/applications/Chat/coati/experience_maker/naive.py index 496f8ab445fc..30dfd8e0b9bc 100644 --- a/applications/Chat/coati/experience_maker/naive.py +++ b/applications/Chat/coati/experience_maker/naive.py @@ -23,22 +23,21 @@ def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experie # calculate auxiliary tensors attention_mask = None - pad_token_id = generate_kwargs.get('pad_token_id', None) + pad_token_id = generate_kwargs.get("pad_token_id", None) if pad_token_id is not None: - attention_mask = sequences.not_equal(pad_token_id)\ - .to(dtype=torch.long, device=sequences.device) + attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device) input_len = input_ids.size(1) - eos_token_id = generate_kwargs.get('eos_token_id', None) + eos_token_id = generate_kwargs.get("eos_token_id", None) if eos_token_id is None: action_mask = torch.ones_like(sequences, dtype=torch.bool) else: # left padding may be applied, only mask action action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0 - action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input + action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input action_mask[:, :input_len] = False action_mask = action_mask[:, 1:] - action_mask = action_mask[:, -(sequences.size(1) - input_len):] + action_mask = action_mask[:, -(sequences.size(1) - input_len) :] num_actions = action_mask.size(1) actor_output = self.actor(sequences, attention_mask) diff --git a/applications/Chat/coati/kernels/__init__.py b/applications/Chat/coati/kernels/__init__.py index 230eedf7ecba..96d40c7c4709 100644 --- a/applications/Chat/coati/kernels/__init__.py +++ b/applications/Chat/coati/kernels/__init__.py @@ -1,6 +1,6 @@ from .wrapper import convert_to_xformer_model, recover_from_xformer_model __all__ = [ - 'convert_to_xformer_model', - 'recover_from_xformer_model', + "convert_to_xformer_model", + "recover_from_xformer_model", ] diff --git a/applications/Chat/coati/kernels/opt_attn.py b/applications/Chat/coati/kernels/opt_attn.py index e99f9c2247d1..d1eb139187f3 100644 --- a/applications/Chat/coati/kernels/opt_attn.py +++ b/applications/Chat/coati/kernels/opt_attn.py @@ -21,11 +21,12 @@ def forward( output_attentions: bool = False, ) -> Tuple[Tensor, Optional[Tensor], Optional[Tuple[Tensor]]]: if not self.training: - return super().forward(hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask, - output_attentions) + return super().forward( + hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask, output_attentions + ) """Input shape: Batch x Time x Channel""" - assert layer_head_mask is None, 'Xformers attention does not support layer_head_mask' - assert not output_attentions, 'Xformers attention does not support output_attentions' + assert layer_head_mask is None, 
"Xformers attention does not support layer_head_mask" + assert not output_attentions, "Xformers attention does not support output_attentions" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder @@ -69,12 +70,14 @@ def forward( key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - attn_output = xops.memory_efficient_attention(query_states, - key_states, - value_states, - attn_bias=xops.LowerTriangularMask(), - p=self.dropout if self.training else 0.0, - scale=self.scaling) + attn_output = xops.memory_efficient_attention( + query_states, + key_states, + value_states, + attn_bias=xops.LowerTriangularMask(), + p=self.dropout if self.training else 0.0, + scale=self.scaling, + ) # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # partitioned across GPUs when using tensor-parallelism. diff --git a/applications/Chat/coati/models/__init__.py b/applications/Chat/coati/models/__init__.py index 0a296a863756..ad4a525b4af2 100644 --- a/applications/Chat/coati/models/__init__.py +++ b/applications/Chat/coati/models/__init__.py @@ -3,6 +3,13 @@ from .loss import LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss __all__ = [ - 'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'LogSigLoss', 'LogExpLoss', - 'LoRAModule', 'convert_to_lora_module' + "Actor", + "Critic", + "RewardModel", + "PolicyLoss", + "ValueLoss", + "LogSigLoss", + "LogExpLoss", + "LoRAModule", + "convert_to_lora_module", ] diff --git a/applications/Chat/coati/models/base/__init__.py b/applications/Chat/coati/models/base/__init__.py index c5f748a0c85a..5c9905bb2224 100644 --- a/applications/Chat/coati/models/base/__init__.py +++ b/applications/Chat/coati/models/base/__init__.py @@ -9,7 +9,7 @@ def get_base_model(model: Union[Actor, Critic, RewardModel]) -> nn.Module: """Get the base model of our wrapper classes. - For Actor, Critic and RewardModel, return ``model.model``, + For Actor, Critic and RewardModel, return ``model.model``, it's usually a ``transformers.PreTrainedModel``. Args: @@ -18,9 +18,10 @@ def get_base_model(model: Union[Actor, Critic, RewardModel]) -> nn.Module: Returns: nn.Module: the base model """ - assert isinstance(model, (Actor, Critic, RewardModel)), \ - f'Expect Actor, Critic or RewardModel, got {type(model)}, use unwrap_model first.' + assert isinstance( + model, (Actor, Critic, RewardModel) + ), f"Expect Actor, Critic or RewardModel, got {type(model)}, use unwrap_model first." return model.model -__all__ = ['Actor', 'Critic', 'RewardModel', 'get_base_model'] +__all__ = ["Actor", "Critic", "RewardModel", "get_base_model"] diff --git a/applications/Chat/coati/models/base/actor.py b/applications/Chat/coati/models/base/actor.py index 6842f81d9b87..979f9318be50 100644 --- a/applications/Chat/coati/models/base/actor.py +++ b/applications/Chat/coati/models/base/actor.py @@ -16,18 +16,17 @@ class Actor(LoRAModule): lora_train_bias (str): LoRA bias training mode. 
""" - def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = 'none') -> None: + def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = "none") -> None: super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias) self.model = model self.convert_to_lora() def forward( - self, - input_ids: torch.LongTensor, - attention_mask: Optional[torch.Tensor] = None, - **model_kwargs, # HACK: `generate` method may pass more kwargs + self, + input_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + **model_kwargs, # HACK: `generate` method may pass more kwargs ) -> torch.Tensor: - """Returns model output. - """ + """Returns model output.""" output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs) return output diff --git a/applications/Chat/coati/models/base/critic.py b/applications/Chat/coati/models/base/critic.py index e68a743a7762..54ab7fa47d48 100644 --- a/applications/Chat/coati/models/base/critic.py +++ b/applications/Chat/coati/models/base/critic.py @@ -23,22 +23,23 @@ def __init__( model: nn.Module, value_head: nn.Module, lora_rank: int = 0, - lora_train_bias: str = 'none', + lora_train_bias: str = "none", use_action_mask: bool = False, ) -> None: - super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias) self.model = model self.value_head = value_head self.use_action_mask = use_action_mask self.convert_to_lora() - def forward(self, - sequences: torch.LongTensor, - action_mask: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, + sequences: torch.LongTensor, + action_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: outputs = self.model(sequences, attention_mask=attention_mask) - last_hidden_states = outputs['last_hidden_state'] + last_hidden_states = outputs["last_hidden_state"] values = self.value_head(last_hidden_states).squeeze(-1) diff --git a/applications/Chat/coati/models/base/reward_model.py b/applications/Chat/coati/models/base/reward_model.py index ce8c0a1d3568..1a70c6cc12bb 100644 --- a/applications/Chat/coati/models/base/reward_model.py +++ b/applications/Chat/coati/models/base/reward_model.py @@ -17,11 +17,13 @@ class RewardModel(LoRAModule): lora_train_bias (str): LoRA bias training mode. 
""" - def __init__(self, - model: nn.Module, - value_head: Optional[nn.Module] = None, - lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: + def __init__( + self, + model: nn.Module, + value_head: Optional[nn.Module] = None, + lora_rank: int = 0, + lora_train_bias: str = "none", + ) -> None: super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias) self.model = model self.convert_to_lora() @@ -35,7 +37,7 @@ def __init__(self, def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: outputs = self.model(sequences, attention_mask=attention_mask) - last_hidden_states = outputs['last_hidden_state'] + last_hidden_states = outputs["last_hidden_state"] values = self.value_head(last_hidden_states)[:, :-1] - value = values.mean(dim=1).squeeze(1) # ensure shape is (B) + value = values.mean(dim=1).squeeze(1) # ensure shape is (B) return value diff --git a/applications/Chat/coati/models/bloom/__init__.py b/applications/Chat/coati/models/bloom/__init__.py index d0e7f7b1ef94..7af199a67d3b 100644 --- a/applications/Chat/coati/models/bloom/__init__.py +++ b/applications/Chat/coati/models/bloom/__init__.py @@ -2,4 +2,4 @@ from .bloom_critic import BLOOMCritic from .bloom_rm import BLOOMRM -__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM'] +__all__ = ["BLOOMActor", "BLOOMCritic", "BLOOMRM"] diff --git a/applications/Chat/coati/models/bloom/bloom_actor.py b/applications/Chat/coati/models/bloom/bloom_actor.py index d7577f096493..73855a2245e7 100644 --- a/applications/Chat/coati/models/bloom/bloom_actor.py +++ b/applications/Chat/coati/models/bloom/bloom_actor.py @@ -1,7 +1,6 @@ from typing import Optional -import torch -from transformers import BloomConfig, BloomForCausalLM, BloomModel +from transformers import BloomConfig, BloomForCausalLM from ..base import Actor @@ -18,12 +17,14 @@ class BLOOMActor(Actor): lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: str = None, - config: Optional[BloomConfig] = None, - checkpoint: bool = False, - lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: + def __init__( + self, + pretrained: str = None, + config: Optional[BloomConfig] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = "none", + ) -> None: if pretrained is not None: model = BloomForCausalLM.from_pretrained(pretrained) elif config is not None: diff --git a/applications/Chat/coati/models/bloom/bloom_critic.py b/applications/Chat/coati/models/bloom/bloom_critic.py index a3716ca94138..b2d838f7ffc5 100644 --- a/applications/Chat/coati/models/bloom/bloom_critic.py +++ b/applications/Chat/coati/models/bloom/bloom_critic.py @@ -1,8 +1,7 @@ from typing import Optional -import torch import torch.nn as nn -from transformers import BloomConfig, BloomForCausalLM, BloomModel +from transformers import BloomConfig, BloomModel from ..base import Critic @@ -18,12 +17,14 @@ class BLOOMCritic(Critic): lora_train_bias (str): LoRA bias training mode. 
""" - def __init__(self, - pretrained: str = None, - config: Optional[BloomConfig] = None, - lora_rank: int = 0, - lora_train_bias: str = 'none', - **kwargs) -> None: + def __init__( + self, + pretrained: str = None, + config: Optional[BloomConfig] = None, + lora_rank: int = 0, + lora_train_bias: str = "none", + **kwargs, + ) -> None: if pretrained is not None: model = BloomModel.from_pretrained(pretrained) elif config is not None: diff --git a/applications/Chat/coati/models/bloom/bloom_rm.py b/applications/Chat/coati/models/bloom/bloom_rm.py index e6ca9b1d4851..c09457ddc8c7 100644 --- a/applications/Chat/coati/models/bloom/bloom_rm.py +++ b/applications/Chat/coati/models/bloom/bloom_rm.py @@ -1,7 +1,7 @@ from typing import Optional import torch.nn as nn -from transformers import BloomConfig, BloomForCausalLM, BloomModel +from transformers import BloomConfig, BloomModel from ..base import RewardModel @@ -17,11 +17,13 @@ class BLOOMRM(RewardModel): lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: str = None, - config: Optional[BloomConfig] = None, - lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: + def __init__( + self, + pretrained: str = None, + config: Optional[BloomConfig] = None, + lora_rank: int = 0, + lora_train_bias: str = "none", + ) -> None: if pretrained is not None: model = BloomModel.from_pretrained(pretrained) elif config is not None: diff --git a/applications/Chat/coati/models/chatglm/__init__.py b/applications/Chat/coati/models/chatglm/__init__.py index 373f19553fdc..5956f5a8e91b 100644 --- a/applications/Chat/coati/models/chatglm/__init__.py +++ b/applications/Chat/coati/models/chatglm/__init__.py @@ -1,3 +1,3 @@ from .chatglm_actor import ChatGLMActor -__all__ = ['ChatGLMActor'] \ No newline at end of file +__all__ = ["ChatGLMActor"] diff --git a/applications/Chat/coati/models/chatglm/chatglm_actor.py b/applications/Chat/coati/models/chatglm/chatglm_actor.py index c35d994e9319..00a61561ee47 100644 --- a/applications/Chat/coati/models/chatglm/chatglm_actor.py +++ b/applications/Chat/coati/models/chatglm/chatglm_actor.py @@ -1,11 +1,9 @@ from typing import Optional -import torch +from ..base import Actor from .configuration_chatglm import ChatGLMConfig from .modeling_chatglm import ChatGLMForConditionalGeneration -from ..base import Actor - class ChatGLMActor(Actor): """ @@ -19,10 +17,9 @@ class ChatGLMActor(Actor): do not support lora for now. 
""" - def __init__(self, - pretrained: str = None, - config: Optional[ChatGLMConfig] = None, - checkpoint: bool = False) -> None: + def __init__( + self, pretrained: str = None, config: Optional[ChatGLMConfig] = None, checkpoint: bool = False + ) -> None: if pretrained is not None: model = ChatGLMForConditionalGeneration.from_pretrained(pretrained) elif config is not None: @@ -31,4 +28,4 @@ def __init__(self, model = ChatGLMForConditionalGeneration(ChatGLMConfig()) if checkpoint: model.gradient_checkpointing_enable() - super().__init__(model, lora_rank=0, lora_train_bias='none') + super().__init__(model, lora_rank=0, lora_train_bias="none") diff --git a/applications/Chat/coati/models/chatglm/chatglm_tokenizer.py b/applications/Chat/coati/models/chatglm/chatglm_tokenizer.py index f7717f7e68b6..221ef044b470 100644 --- a/applications/Chat/coati/models/chatglm/chatglm_tokenizer.py +++ b/applications/Chat/coati/models/chatglm/chatglm_tokenizer.py @@ -2,15 +2,14 @@ This code is copied from https://huggingface.co/THUDM/chatglm-6b/blob/main/tokenization_chatglm.py """ """Tokenization classes for ChatGLM.""" -from typing import List, Optional, Union import os +from typing import Dict, List, Optional, Union -from transformers.tokenization_utils import PreTrainedTokenizer -from transformers.utils import logging, PaddingStrategy -from transformers.tokenization_utils_base import EncodedInput, BatchEncoding -from typing import Dict -import sentencepiece as spm import numpy as np +import sentencepiece as spm +from transformers.tokenization_utils import PreTrainedTokenizer +from transformers.tokenization_utils_base import BatchEncoding, EncodedInput +from transformers.utils import PaddingStrategy, logging logger = logging.get_logger(__name__) @@ -52,11 +51,11 @@ def __len__(self): class SPTokenizer: def __init__( - self, - vocab_file, - num_image_tokens=20000, - max_blank_length=80, - byte_fallback=True, + self, + vocab_file, + num_image_tokens=20000, + max_blank_length=80, + byte_fallback=True, ): assert vocab_file is not None self.vocab_file = vocab_file @@ -100,9 +99,7 @@ def _preprocess(self, text: str, linebreak=True, whitespaces=True): text = self._encode_whitespaces(text, max_len=self.max_blank_length) return text - def encode( - self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True - ) -> List[int]: + def encode(self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True) -> List[int]: """ @param text: Text to encode. @param linebreak: Whether to encode newline (\n) in text. @@ -136,9 +133,7 @@ def decode_tokens(self, tokens: List[str]) -> str: text = self.postprocess(text) return text - def tokenize( - self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True - ) -> List[str]: + def tokenize(self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True) -> List[str]: """ @param text: Text to encode. @param linebreak: Whether to encode newline (\n) in text. 
@@ -181,20 +176,20 @@ class ChatGLMTokenizer(PreTrainedTokenizer): model_input_names = ["input_ids", "attention_mask", "position_ids"] def __init__( - self, - vocab_file, - do_lower_case=False, - remove_space=False, - bos_token='', - eos_token='', - end_token='', - mask_token='[MASK]', - gmask_token='[gMASK]', - padding_side="left", - pad_token="", - unk_token="", - num_image_tokens=20000, - **kwargs + self, + vocab_file, + do_lower_case=False, + remove_space=False, + bos_token="", + eos_token="", + end_token="", + mask_token="[MASK]", + gmask_token="[gMASK]", + padding_side="left", + pad_token="", + unk_token="", + num_image_tokens=20000, + **kwargs, ) -> None: super().__init__( do_lower_case=do_lower_case, @@ -208,7 +203,7 @@ def __init__( pad_token=pad_token, unk_token=unk_token, num_image_tokens=num_image_tokens, - **kwargs + **kwargs, ) self.do_lower_case = do_lower_case @@ -243,11 +238,11 @@ def end_token_id(self) -> Optional[int]: @property def vocab_size(self): - """ Returns vocab size """ + """Returns vocab size""" return self.sp_tokenizer.num_tokens def get_vocab(self): - """ Returns vocab as a dict """ + """Returns vocab as a dict""" vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)} vocab.update(self.added_tokens_encoder) return vocab @@ -264,7 +259,7 @@ def preprocess_text(self, inputs): return outputs def _tokenize(self, text, **kwargs): - """ Returns a tokenized string. """ + """Returns a tokenized string.""" text = self.preprocess_text(text) seq = self.sp_tokenizer.tokenize(text) @@ -274,11 +269,7 @@ def _tokenize(self, text, **kwargs): def convert_tokens_to_string(self, tokens: List[str]) -> str: return self.sp_tokenizer.decode_tokens(tokens) - def _decode( - self, - token_ids: Union[int, List[int]], - **kwargs - ) -> str: + def _decode(self, token_ids: Union[int, List[int]], **kwargs) -> str: if isinstance(token_ids, int): token_ids = [token_ids] if len(token_ids) == 0: @@ -288,7 +279,7 @@ def _decode( return super()._decode(token_ids, **kwargs) def _convert_token_to_id(self, token): - """ Converts a token (str) in an id using the vocab. """ + """Converts a token (str) in an id using the vocab.""" return self.sp_tokenizer[token] def _convert_id_to_token(self, index): @@ -309,13 +300,11 @@ def save_vocabulary(self, save_directory, filename_prefix=None): `Tuple(str)`: Paths to the files saved. """ if os.path.isdir(save_directory): - vocab_file = os.path.join( - save_directory, self.vocab_files_names["vocab_file"] - ) + vocab_file = os.path.join(save_directory, self.vocab_files_names["vocab_file"]) else: vocab_file = save_directory - with open(self.vocab_file, 'rb') as fin: + with open(self.vocab_file, "rb") as fin: proto_str = fin.read() with open(vocab_file, "wb") as writer: @@ -324,7 +313,7 @@ def save_vocabulary(self, save_directory, filename_prefix=None): return (vocab_file,) def build_inputs_with_special_tokens( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None ) -> List[int]: """ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and @@ -343,19 +332,19 @@ def build_inputs_with_special_tokens( `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. 
""" gmask_id = self.sp_tokenizer[self.gmask_token] - eos_id = self.sp_tokenizer[self.eos_token] + self.sp_tokenizer[self.eos_token] token_ids_0 = token_ids_0 + [gmask_id, self.sp_tokenizer[self.bos_token]] if token_ids_1 is not None: token_ids_0 = token_ids_0 + token_ids_1 return token_ids_0 def _pad( - self, - encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], - max_length: Optional[int] = None, - padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, - pad_to_multiple_of: Optional[int] = None, - return_attention_mask: Optional[bool] = None, + self, + encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, ) -> dict: """ Pad encoded inputs (on left/right and up to predefined length or max length in the batch) @@ -421,17 +410,23 @@ def _pad( mask_position = required_input.index(mask_token) position_ids[context_length:] = mask_position block_position_ids = np.concatenate( - [np.zeros(context_length, dtype=np.int64), - np.arange(1, seq_length - context_length + 1, dtype=np.int64)]) + [ + np.zeros(context_length, dtype=np.int64), + np.arange(1, seq_length - context_length + 1, dtype=np.int64), + ] + ) encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0) if needs_to_be_padded: difference = max_length - len(required_input) if "attention_mask" in encoded_inputs: - encoded_inputs["attention_mask"] = np.pad(encoded_inputs["attention_mask"], - pad_width=[(0, 0), (difference, 0), (difference, 0)], - mode='constant', constant_values=True) + encoded_inputs["attention_mask"] = np.pad( + encoded_inputs["attention_mask"], + pad_width=[(0, 0), (difference, 0), (difference, 0)], + mode="constant", + constant_values=True, + ) if "token_type_ids" in encoded_inputs: encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[ "token_type_ids" @@ -439,8 +434,9 @@ def _pad( if "special_tokens_mask" in encoded_inputs: encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] if "position_ids" in encoded_inputs: - encoded_inputs["position_ids"] = np.pad(encoded_inputs["position_ids"], - pad_width=[(0, 0), (difference, 0)]) + encoded_inputs["position_ids"] = np.pad( + encoded_inputs["position_ids"], pad_width=[(0, 0), (difference, 0)] + ) encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input - return encoded_inputs \ No newline at end of file + return encoded_inputs diff --git a/applications/Chat/coati/models/chatglm/configuration_chatglm.py b/applications/Chat/coati/models/chatglm/configuration_chatglm.py index d0e3f6cc63d7..a6d2ccd18715 100644 --- a/applications/Chat/coati/models/chatglm/configuration_chatglm.py +++ b/applications/Chat/coati/models/chatglm/configuration_chatglm.py @@ -56,30 +56,29 @@ class ChatGLMConfig(PretrainedConfig): >>> # Accessing the model configuration >>> configuration = model.config - ``` -""" + ```""" model_type = "chatglm" def __init__( - self, - vocab_size=130528, - hidden_size=4096, - num_layers=28, - num_attention_heads=32, - layernorm_epsilon=1e-5, - use_cache=True, - bos_token_id=130004, - eos_token_id=130005, - mask_token_id=130000, - gmask_token_id=130001, - pad_token_id=3, - max_sequence_length=2048, - inner_hidden_size=16384, - position_encoding_2d=True, - quantization_bit=0, - pre_seq_len=None, - 
prefix_projection=False, - **kwargs + self, + vocab_size=130528, + hidden_size=4096, + num_layers=28, + num_attention_heads=32, + layernorm_epsilon=1e-5, + use_cache=True, + bos_token_id=130004, + eos_token_id=130005, + mask_token_id=130000, + gmask_token_id=130001, + pad_token_id=3, + max_sequence_length=2048, + inner_hidden_size=16384, + position_encoding_2d=True, + quantization_bit=0, + pre_seq_len=None, + prefix_projection=False, + **kwargs, ): self.num_layers = num_layers self.vocab_size = vocab_size @@ -99,9 +98,4 @@ def __init__( self.pre_seq_len = pre_seq_len self.prefix_projection = prefix_projection - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - **kwargs - ) \ No newline at end of file + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) diff --git a/applications/Chat/coati/models/chatglm/modeling_chatglm.py b/applications/Chat/coati/models/chatglm/modeling_chatglm.py index 77e7d0d8ea09..d1d15c68ffd8 100644 --- a/applications/Chat/coati/models/chatglm/modeling_chatglm.py +++ b/applications/Chat/coati/models/chatglm/modeling_chatglm.py @@ -4,41 +4,40 @@ """ PyTorch ChatGLM model. """ -import math import copy +import math import os -import warnings import re import sys +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch -import torch.utils.checkpoint import torch.nn.functional as F +import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss, LayerNorm from torch.nn.utils import skip_init -from typing import Optional, Tuple, Union, List, Callable, Dict, Any - -from transformers.utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, -) +from transformers.generation.logits_process import LogitsProcessor +from transformers.generation.utils import GenerationConfig, LogitsProcessorList, ModelOutput, StoppingCriteriaList from transformers.modeling_outputs import ( BaseModelOutputWithPast, - CausalLMOutputWithPast, BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithPast, ) from transformers.modeling_utils import PreTrainedModel -from transformers.utils import logging -from transformers.generation.logits_process import LogitsProcessor -from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) from .configuration_chatglm import ChatGLMConfig # flags required to enable jit fusion kernels -if sys.platform != 'darwin': +if sys.platform != "darwin": torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_executor(False) torch._C._jit_override_can_fuse_on_cpu(True) @@ -93,8 +92,8 @@ def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path): # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v # which are not required for using pretrained model if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] - for n in name + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name ): logger.info(f"Skipping {'/'.join(name)}") continue @@ -127,7 +126,7 @@ def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path): array = np.transpose(array) try: assert ( - pointer.shape == 
array.shape + pointer.shape == array.shape ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" except AssertionError as e: e.args += (pointer.shape, array.shape) @@ -153,7 +152,7 @@ def __init__(self, config): self.trans = torch.nn.Sequential( torch.nn.Linear(config.hidden_size, config.hidden_size), torch.nn.Tanh(), - torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2) + torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2), ) else: self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2) @@ -170,8 +169,7 @@ def forward(self, prefix: torch.Tensor): @torch.jit.script def gelu_impl(x): """OpenAI's gelu implementation.""" - return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * - (1.0 + 0.044715 * x * x))) + return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x))) def gelu(x): @@ -181,21 +179,22 @@ def gelu(x): class RotaryEmbedding(torch.nn.Module): def __init__(self, dim, base=10000, precision=torch.half, learnable=False): super().__init__() - inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) inv_freq = inv_freq.half() self.learnable = learnable if learnable: self.inv_freq = torch.nn.Parameter(inv_freq) self.max_seq_len_cached = None else: - self.register_buffer('inv_freq', inv_freq) + self.register_buffer("inv_freq", inv_freq) self.max_seq_len_cached = None self.cos_cached = None self.sin_cached = None self.precision = precision - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs): + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): pass def forward(self, x, seq_dim=1, seq_len=None): @@ -204,7 +203,7 @@ def forward(self, x, seq_dim=1, seq_len=None): if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached): self.max_seq_len_cached = None if self.learnable else seq_len t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) - freqs = torch.einsum('i,j->ij', t, self.inv_freq) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1).to(x.device) if self.precision == torch.bfloat16: @@ -230,30 +229,31 @@ def _apply(self, fn): def rotate_half(x): - x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions @torch.jit.script def apply_rotary_pos_emb_index(q, k, cos, sin, position_id): # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn] - cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \ - F.embedding(position_id, sin.squeeze(1)).unsqueeze(2) + cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), F.embedding( + position_id, sin.squeeze(1) + ).unsqueeze(2) q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) return q, k def attention_fn( - self, - query_layer, - key_layer, - value_layer, - attention_mask, - hidden_size_per_partition, - layer_id, - layer_past=None, - scaling_attention_score=True, - use_cache=False, + self, + query_layer, + key_layer, + value_layer, + attention_mask, + hidden_size_per_partition, 
+ layer_id, + layer_past=None, + scaling_attention_score=True, + use_cache=False, ): if layer_past is not None: past_key, past_value = layer_past[0], layer_past[1] @@ -285,7 +285,9 @@ def attention_fn( key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) matmul_result = torch.zeros( - 1, 1, 1, + 1, + 1, + 1, dtype=query_layer.dtype, device=query_layer.device, ) @@ -355,9 +357,17 @@ def default_init(cls, *args, **kwargs): class SelfAttention(torch.nn.Module): - def __init__(self, hidden_size, num_attention_heads, - layer_id, hidden_size_per_attention_head=None, bias=True, - params_dtype=torch.float, position_encoding_2d=True, empty_init=True): + def __init__( + self, + hidden_size, + num_attention_heads, + layer_id, + hidden_size_per_attention_head=None, + bias=True, + params_dtype=torch.float, + position_encoding_2d=True, + empty_init=True, + ): if empty_init: init_method = skip_init else: @@ -410,8 +420,7 @@ def attention_mask_func(attention_scores, attention_mask): attention_scores.masked_fill_(attention_mask, -10000.0) return attention_scores - def split_tensor_along_last_dim(self, tensor, num_partitions, - contiguous_split_chunks=False): + def split_tensor_along_last_dim(self, tensor, num_partitions, contiguous_split_chunks=False): """Split a tensor along its last dimension. Arguments: tensor: input tensor. @@ -431,14 +440,14 @@ def split_tensor_along_last_dim(self, tensor, num_partitions, return tensor_list def forward( - self, - hidden_states: torch.Tensor, - position_ids, - attention_mask: torch.Tensor, - layer_id, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - use_cache: bool = False, - output_attentions: bool = False, + self, + hidden_states: torch.Tensor, + position_ids, + attention_mask: torch.Tensor, + layer_id, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + output_attentions: bool = False, ): """ hidden_states: [seq_len, batch, hidden_size] @@ -462,8 +471,10 @@ def forward( q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1)) k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1)) cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1) - position_ids, block_position_ids = position_ids[:, 0, :].transpose(0, 1).contiguous(), \ - position_ids[:, 1, :].transpose(0, 1).contiguous() + position_ids, block_position_ids = ( + position_ids[:, 0, :].transpose(0, 1).contiguous(), + position_ids[:, 1, :].transpose(0, 1).contiguous(), + ) q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids) q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids) query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1)) @@ -484,7 +495,7 @@ def forward( hidden_size_per_partition=self.hidden_size_per_partition, layer_id=layer_id, layer_past=layer_past, - use_cache=use_cache + use_cache=use_cache, ) output = self.dense(context_layer) @@ -509,8 +520,16 @@ def forward(self, x): class GLU(torch.nn.Module): - def __init__(self, hidden_size, inner_hidden_size=None, - layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float, empty_init=True): + def __init__( + self, + hidden_size, + inner_hidden_size=None, + layer_id=None, + bias=True, + activation_func=gelu, + params_dtype=torch.float, + empty_init=True, + ): super(GLU, self).__init__() if empty_init: init_method = skip_init @@ -557,19 +576,19 @@ def forward(self, hidden_states): class GLMBlock(torch.nn.Module): def __init__( - self, - hidden_size, - num_attention_heads, - layernorm_epsilon, - layer_id, - 
inner_hidden_size=None, - hidden_size_per_attention_head=None, - layernorm=LayerNorm, - use_bias=True, - params_dtype=torch.float, - num_layers=28, - position_encoding_2d=True, - empty_init=True + self, + hidden_size, + num_attention_heads, + layernorm_epsilon, + layer_id, + inner_hidden_size=None, + hidden_size_per_attention_head=None, + layernorm=LayerNorm, + use_bias=True, + params_dtype=torch.float, + num_layers=28, + position_encoding_2d=True, + empty_init=True, ): super(GLMBlock, self).__init__() # Set output layer initialization if not provided. @@ -590,7 +609,7 @@ def __init__( bias=use_bias, params_dtype=params_dtype, position_encoding_2d=self.position_encoding_2d, - empty_init=empty_init + empty_init=empty_init, ) # Layernorm on the input data. @@ -605,18 +624,18 @@ def __init__( bias=use_bias, layer_id=layer_id, params_dtype=params_dtype, - empty_init=empty_init + empty_init=empty_init, ) def forward( - self, - hidden_states: torch.Tensor, - position_ids, - attention_mask: torch.Tensor, - layer_id, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - use_cache: bool = False, - output_attentions: bool = False, + self, + hidden_states: torch.Tensor, + position_ids, + attention_mask: torch.Tensor, + layer_id, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + output_attentions: bool = False, ): """ hidden_states: [seq_len, batch, hidden_size] @@ -635,7 +654,7 @@ def forward( layer_id=layer_id, layer_past=layer_past, use_cache=use_cache, - output_attentions=output_attentions + output_attentions=output_attentions, ) attention_output = attention_outputs[0] @@ -702,10 +721,15 @@ def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None): position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) for i, context_length in enumerate(context_lengths): position_ids[i, context_length:] = mask_positions[i] - block_position_ids = [torch.cat(( - torch.zeros(context_length, dtype=torch.long, device=device), - torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1 - )) for context_length in context_lengths] + block_position_ids = [ + torch.cat( + ( + torch.zeros(context_length, dtype=torch.long, device=device), + torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1, + ) + ) + for context_length in context_lengths + ] block_position_ids = torch.stack(block_position_ids, dim=0) position_ids = torch.stack((position_ids, block_position_ids), dim=1) else: @@ -823,9 +847,7 @@ def __init__(self, config: ChatGLMConfig, empty_init=True): self.prefix_projection = config.prefix_projection self.word_embeddings = init_method( - torch.nn.Embedding, - num_embeddings=self.vocab_size, embedding_dim=self.hidden_size, - dtype=self.params_dtype + torch.nn.Embedding, num_embeddings=self.vocab_size, embedding_dim=self.hidden_size, dtype=self.params_dtype ) self.gradient_checkpointing = False @@ -841,12 +863,10 @@ def get_layer(layer_id): use_bias=True, params_dtype=self.params_dtype, position_encoding_2d=self.position_encoding_2d, - empty_init=empty_init + empty_init=empty_init, ) - self.layers = torch.nn.ModuleList( - [get_layer(layer_id) for layer_id in range(self.num_layers)] - ) + self.layers = torch.nn.ModuleList([get_layer(layer_id) for layer_id in range(self.num_layers)]) # Final layer norm before output. 
self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon) @@ -876,7 +896,7 @@ def get_prompt(self, batch_size, device, dtype=torch.half): self.pre_seq_len, self.num_layers * 2, self.num_attention_heads, - self.hidden_size // self.num_attention_heads + self.hidden_size // self.num_attention_heads, ) # seq_len, b, nh, hidden_size past_key_values = self.dropout(past_key_values) @@ -891,18 +911,17 @@ def get_prompt(self, batch_size, device, dtype=torch.half): config_class=_CONFIG_FOR_DOC, ) def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - inputs_embeds: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -931,17 +950,14 @@ def forward( if past_key_values is None: if self.pre_seq_len is not None: - past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device, - dtype=inputs_embeds.dtype) + past_key_values = self.get_prompt( + batch_size=input_ids.shape[0], device=input_ids.device, dtype=inputs_embeds.dtype + ) else: past_key_values = tuple([None] * len(self.layers)) if attention_mask is None: - attention_mask = self.get_masks( - input_ids, - device=input_ids.device - ) - + attention_mask = self.get_masks(input_ids, device=input_ids.device) if position_ids is None: MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id @@ -955,15 +971,13 @@ def forward( use_gmasks.append(use_gmask) position_ids = self.get_position_ids( - input_ids, - mask_positions=mask_positions, - device=input_ids.device, - use_gmasks=use_gmasks + input_ids, mask_positions=mask_positions, device=input_ids.device, use_gmasks=use_gmasks ) if self.pre_seq_len is not None and attention_mask is not None: prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to( - attention_mask.device) + attention_mask.device + ) prefix_attention_mask = (prefix_attention_mask < 0.5).bool() attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3) @@ -980,7 +994,6 @@ def forward( attention_mask = attention_mask.to(hidden_states.device) for i, layer in enumerate(self.layers): - if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_past = past_key_values[i] @@ -994,7 +1007,7 @@ def forward( torch.tensor(i), layer_past, use_cache, - output_attentions + output_attentions, ) else: layer_ret = layer( @@ -1004,7 +1017,7 @@ def forward( layer_id=torch.tensor(i), layer_past=layer_past, use_cache=use_cache, - 
output_attentions=output_attentions + output_attentions=output_attentions, ) hidden_states = layer_ret[0] @@ -1049,13 +1062,7 @@ def __init__(self, config: ChatGLMConfig, empty_init=True): self.transformer = ChatGLMModel(config, empty_init=empty_init) - self.lm_head = init_method( - nn.Linear, - config.hidden_size, - config.vocab_size, - bias=False, - dtype=torch.half - ) + self.lm_head = init_method(nn.Linear, config.hidden_size, config.vocab_size, bias=False, dtype=torch.half) self.config = config @@ -1087,32 +1094,29 @@ def _update_model_kwargs_for_generation( attention_mask = model_kwargs["attention_mask"] if attention_mask is not None and attention_mask.dtype == torch.bool: attention_mask = torch.cat( - [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3) + [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3 + ) new_attention_mask = attention_mask[:, :, -1:].clone() new_attention_mask[..., -1] = False - model_kwargs["attention_mask"] = torch.cat( - [attention_mask, new_attention_mask], dim=2 - ) + model_kwargs["attention_mask"] = torch.cat([attention_mask, new_attention_mask], dim=2) # update position ids if "position_ids" in model_kwargs: position_ids = model_kwargs["position_ids"] new_position_id = position_ids[..., -1:].clone() new_position_id[:, 1, :] += 1 - model_kwargs["position_ids"] = torch.cat( - [position_ids, new_position_id], dim=-1 - ) + model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1) return model_kwargs def prepare_inputs_for_generation( - self, - input_ids: torch.LongTensor, - past: Optional[torch.Tensor] = None, - past_key_values: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - **kwargs + self, + input_ids: torch.LongTensor, + past: Optional[torch.Tensor] = None, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + **kwargs, ) -> dict: batch_size, seq_length = input_ids.shape MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id @@ -1137,11 +1141,17 @@ def prepare_inputs_for_generation( context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs] if self.position_encoding_2d: position_ids = torch.tensor( - [[mask_position, seq_length - context_length] for mask_position, context_length in - zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1) + [ + [mask_position, seq_length - context_length] + for mask_position, context_length in zip(mask_positions, context_lengths) + ], + dtype=torch.long, + device=input_ids.device, + ).unsqueeze(-1) else: - position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long, - device=input_ids.device).unsqueeze(-1) + position_ids = torch.tensor( + [mask_position for mask_position in mask_positions], dtype=torch.long, device=input_ids.device + ).unsqueeze(-1) if past is None: past = past_key_values @@ -1149,44 +1159,38 @@ def prepare_inputs_for_generation( "input_ids": last_token, "past_key_values": past, "position_ids": position_ids, - "attention_mask": attention_mask + "attention_mask": attention_mask, } else: if attention_mask is not None and attention_mask.dtype != torch.bool: logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool") attention_mask = None if attention_mask is None: - attention_mask = self.get_masks( - input_ids, - 
device=input_ids.device - ) + attention_mask = self.get_masks(input_ids, device=input_ids.device) if position_ids is None: position_ids = self.get_position_ids( - input_ids, - device=input_ids.device, - mask_positions=mask_positions, - use_gmasks=use_gmasks + input_ids, device=input_ids.device, mask_positions=mask_positions, use_gmasks=use_gmasks ) return { "input_ids": input_ids, "past_key_values": past, "position_ids": position_ids, - "attention_mask": attention_mask + "attention_mask": attention_mask, } def forward( - self, - input_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + self, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, ): use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -1235,7 +1239,7 @@ def forward( @staticmethod def _reorder_cache( - past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor + past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: """ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or @@ -1268,15 +1272,33 @@ def process_response(self, response): return response @torch.no_grad() - def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1, - do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs): + def chat( + self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None, + max_length: int = 2048, + num_beams=1, + do_sample=True, + top_p=0.7, + temperature=0.95, + logits_processor=None, + **kwargs, + ): if history is None: history = [] if logits_processor is None: logits_processor = LogitsProcessorList() logits_processor.append(InvalidScoreLogitsProcessor()) - gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, - "temperature": temperature, "logits_processor": logits_processor, **kwargs} + gen_kwargs = { + "max_length": max_length, + "num_beams": num_beams, + "do_sample": do_sample, + "top_p": top_p, + "temperature": temperature, + "logits_processor": logits_processor, + **kwargs, + } if not history: prompt = query else: @@ -1287,22 +1309,38 @@ def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max inputs = tokenizer([prompt], return_tensors="pt") inputs = inputs.to(self.device) outputs = self.generate(**inputs, **gen_kwargs) - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :] response = tokenizer.decode(outputs) response = self.process_response(response) history = 
history + [(query, response)] return response, history @torch.no_grad() - def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, - do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs): + def stream_chat( + self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None, + max_length: int = 2048, + do_sample=True, + top_p=0.7, + temperature=0.95, + logits_processor=None, + **kwargs, + ): if history is None: history = [] if logits_processor is None: logits_processor = LogitsProcessorList() logits_processor.append(InvalidScoreLogitsProcessor()) - gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p, - "temperature": temperature, "logits_processor": logits_processor, **kwargs} + gen_kwargs = { + "max_length": max_length, + "do_sample": do_sample, + "top_p": top_p, + "temperature": temperature, + "logits_processor": logits_processor, + **kwargs, + } if not history: prompt = query else: @@ -1313,7 +1351,7 @@ def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = No inputs = tokenizer([prompt], return_tensors="pt") inputs = inputs.to(self.device) for outputs in self.stream_generate(**inputs, **gen_kwargs): - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :] response = tokenizer.decode(outputs) response = self.process_response(response) new_history = history + [(query, response)] @@ -1321,13 +1359,13 @@ def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = No @torch.no_grad() def stream_generate( - self, - input_ids, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, - **kwargs, + self, + input_ids, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + **kwargs, ): batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] diff --git a/applications/Chat/coati/models/generation.py b/applications/Chat/coati/models/generation.py index de0d63f95f50..e3afac88c7a7 100644 --- a/applications/Chat/coati/models/generation.py +++ b/applications/Chat/coati/models/generation.py @@ -16,9 +16,9 @@ from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper -def _prepare_logits_processor(top_k: Optional[int] = None, - top_p: Optional[float] = None, - temperature: Optional[float] = None) -> LogitsProcessorList: +def _prepare_logits_processor( + top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None +) -> LogitsProcessorList: processor_list = LogitsProcessorList() if temperature is not None and temperature != 1.0: processor_list.append(TemperatureLogitsWarper(temperature)) @@ -37,18 +37,20 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool: return unfinished_sequences.max() == 0 -def _sample(model: Actor, - input_ids: torch.Tensor, - max_length: int, - early_stopping: bool = False, - eos_token_id: Optional[int] = None, - pad_token_id: Optional[int] = None, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - temperature: 
Optional[float] = None, - prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, - update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, - **model_kwargs) -> torch.Tensor: +def _sample( + model: Actor, + input_ids: torch.Tensor, + max_length: int, + early_stopping: bool = False, + eos_token_id: Optional[int] = None, + pad_token_id: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, + update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, + **model_kwargs, +) -> torch.Tensor: if input_ids.size(1) >= max_length: return input_ids @@ -56,11 +58,12 @@ def _sample(model: Actor, unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) for _ in range(input_ids.size(1), max_length): - model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) \ - if prepare_inputs_fn is not None else {'input_ids': input_ids} + model_inputs = ( + prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {"input_ids": input_ids} + ) outputs = model(**model_inputs) - next_token_logits = outputs['logits'][:, -1, :] + next_token_logits = outputs["logits"][:, -1, :] # pre-process distribution next_token_logits = logits_processor(input_ids, next_token_logits) # sample @@ -90,20 +93,22 @@ def _sample(model: Actor, @torch.no_grad() -def generate(model: Actor, - input_ids: torch.Tensor, - max_length: int, - num_beams: int = 1, - do_sample: bool = True, - early_stopping: bool = False, - eos_token_id: Optional[int] = None, - pad_token_id: Optional[int] = None, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - temperature: Optional[float] = None, - prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, - update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, - **model_kwargs) -> torch.Tensor: +def generate( + model: Actor, + input_ids: torch.Tensor, + max_length: int, + num_beams: int = 1, + do_sample: bool = True, + early_stopping: bool = False, + eos_token_id: Optional[int] = None, + pad_token_id: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, + update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, + **model_kwargs, +) -> torch.Tensor: """Generate token sequence. The returned sequence is input_ids + generated_tokens. Args: @@ -121,26 +126,28 @@ def generate(model: Actor, prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None. update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None. 
""" - is_greedy_gen_mode = ((num_beams == 1) and do_sample is False) - is_sample_gen_mode = ((num_beams == 1) and do_sample is True) - is_beam_gen_mode = ((num_beams > 1) and do_sample is False) + is_greedy_gen_mode = (num_beams == 1) and do_sample is False + is_sample_gen_mode = (num_beams == 1) and do_sample is True + is_beam_gen_mode = (num_beams > 1) and do_sample is False if is_greedy_gen_mode: # run greedy search raise NotImplementedError elif is_sample_gen_mode: # run sample - return _sample(model, - input_ids, - max_length, - early_stopping=early_stopping, - eos_token_id=eos_token_id, - pad_token_id=pad_token_id, - top_k=top_k, - top_p=top_p, - temperature=temperature, - prepare_inputs_fn=prepare_inputs_fn, - update_model_kwargs_fn=update_model_kwargs_fn, - **model_kwargs) + return _sample( + model, + input_ids, + max_length, + early_stopping=early_stopping, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + top_k=top_k, + top_p=top_p, + temperature=temperature, + prepare_inputs_fn=prepare_inputs_fn, + update_model_kwargs_fn=update_model_kwargs_fn, + **model_kwargs, + ) elif is_beam_gen_mode: raise NotImplementedError else: diff --git a/applications/Chat/coati/models/gpt/__init__.py b/applications/Chat/coati/models/gpt/__init__.py index 63dc5ab0f5ea..823cf4a75e0d 100644 --- a/applications/Chat/coati/models/gpt/__init__.py +++ b/applications/Chat/coati/models/gpt/__init__.py @@ -2,4 +2,4 @@ from .gpt_critic import GPTCritic from .gpt_rm import GPTRM -__all__ = ['GPTActor', 'GPTCritic', 'GPTRM'] +__all__ = ["GPTActor", "GPTCritic", "GPTRM"] diff --git a/applications/Chat/coati/models/gpt/gpt_actor.py b/applications/Chat/coati/models/gpt/gpt_actor.py index ae9d669f1f56..a7e4b9bc3e22 100644 --- a/applications/Chat/coati/models/gpt/gpt_actor.py +++ b/applications/Chat/coati/models/gpt/gpt_actor.py @@ -18,13 +18,15 @@ class GPTActor(Actor): lora_train_bias (str): Bias training strategy for the LoRa layer. """ - def __init__(self, - pretrained: Optional[str] = None, - config: Optional[GPT2Config] = None, - checkpoint: bool = False, - lora_rank: int = 0, - lora_train_bias: str = 'none', - **kwargs) -> None: + def __init__( + self, + pretrained: Optional[str] = None, + config: Optional[GPT2Config] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = "none", + **kwargs, + ) -> None: if pretrained is not None: model = GPT2LMHeadModel.from_pretrained(pretrained) elif config is not None: diff --git a/applications/Chat/coati/models/gpt/gpt_critic.py b/applications/Chat/coati/models/gpt/gpt_critic.py index 01e1cd10ef57..22ab36dea276 100644 --- a/applications/Chat/coati/models/gpt/gpt_critic.py +++ b/applications/Chat/coati/models/gpt/gpt_critic.py @@ -18,12 +18,14 @@ class GPTCritic(Critic): lora_train_bias (str): LoRA bias training mode. 
""" - def __init__(self, - pretrained: Optional[str] = None, - config: Optional[GPT2Config] = None, - lora_rank: int = 0, - lora_train_bias: str = 'none', - **kwargs) -> None: + def __init__( + self, + pretrained: Optional[str] = None, + config: Optional[GPT2Config] = None, + lora_rank: int = 0, + lora_train_bias: str = "none", + **kwargs, + ) -> None: if pretrained is not None: model = GPT2Model.from_pretrained(pretrained) elif config is not None: diff --git a/applications/Chat/coati/models/gpt/gpt_rm.py b/applications/Chat/coati/models/gpt/gpt_rm.py index e52a5a14c1da..8edfc4008466 100644 --- a/applications/Chat/coati/models/gpt/gpt_rm.py +++ b/applications/Chat/coati/models/gpt/gpt_rm.py @@ -18,11 +18,13 @@ class GPTRM(RewardModel): lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: Optional[str] = None, - config: Optional[GPT2Config] = None, - lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: + def __init__( + self, + pretrained: Optional[str] = None, + config: Optional[GPT2Config] = None, + lora_rank: int = 0, + lora_train_bias: str = "none", + ) -> None: if pretrained is not None: model = GPT2Model.from_pretrained(pretrained) elif config is not None: diff --git a/applications/Chat/coati/models/llama/__init__.py b/applications/Chat/coati/models/llama/__init__.py index 9b2a024afdb2..c87d732538a9 100644 --- a/applications/Chat/coati/models/llama/__init__.py +++ b/applications/Chat/coati/models/llama/__init__.py @@ -2,4 +2,4 @@ from .llama_critic import LlamaCritic from .llama_rm import LlamaRM -__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM'] +__all__ = ["LlamaActor", "LlamaCritic", "LlamaRM"] diff --git a/applications/Chat/coati/models/llama/llama_actor.py b/applications/Chat/coati/models/llama/llama_actor.py index 2c7adb390d8b..f1d9406835ca 100644 --- a/applications/Chat/coati/models/llama/llama_actor.py +++ b/applications/Chat/coati/models/llama/llama_actor.py @@ -1,7 +1,6 @@ from typing import Optional -import torch -from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM +from transformers import LlamaConfig, LlamaForCausalLM from ..base import Actor @@ -18,13 +17,14 @@ class LlamaActor(Actor): lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: Optional[str] = None, - config: Optional[LlamaConfig] = None, - checkpoint: bool = False, - lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: - + def __init__( + self, + pretrained: Optional[str] = None, + config: Optional[LlamaConfig] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = "none", + ) -> None: if pretrained is not None: model = LlamaForCausalLM.from_pretrained(pretrained) elif config is not None: diff --git a/applications/Chat/coati/models/llama/llama_critic.py b/applications/Chat/coati/models/llama/llama_critic.py index a67e5de5def6..000dce17ccf0 100644 --- a/applications/Chat/coati/models/llama/llama_critic.py +++ b/applications/Chat/coati/models/llama/llama_critic.py @@ -17,13 +17,14 @@ class LlamaCritic(Critic): lora_train_bias (str): LoRA bias training mode. 
""" - def __init__(self, - pretrained: Optional[str] = None, - config: Optional[LlamaConfig] = None, - lora_rank: int = 0, - lora_train_bias: str = 'none', - **kwargs) -> None: - + def __init__( + self, + pretrained: Optional[str] = None, + config: Optional[LlamaConfig] = None, + lora_rank: int = 0, + lora_train_bias: str = "none", + **kwargs, + ) -> None: if pretrained is not None: model = LlamaModel.from_pretrained(pretrained) elif config is not None: diff --git a/applications/Chat/coati/models/llama/llama_rm.py b/applications/Chat/coati/models/llama/llama_rm.py index d6b62922686e..43bc9e638dc7 100644 --- a/applications/Chat/coati/models/llama/llama_rm.py +++ b/applications/Chat/coati/models/llama/llama_rm.py @@ -1,7 +1,7 @@ from typing import Optional import torch.nn as nn -from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel +from transformers import LlamaConfig, LlamaModel from ..base import RewardModel @@ -17,12 +17,13 @@ class LlamaRM(RewardModel): lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: Optional[str] = None, - config: Optional[LlamaConfig] = None, - lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: - + def __init__( + self, + pretrained: Optional[str] = None, + config: Optional[LlamaConfig] = None, + lora_rank: int = 0, + lora_train_bias: str = "none", + ) -> None: if pretrained is not None: model = LlamaModel.from_pretrained(pretrained) elif config is not None: diff --git a/applications/Chat/coati/models/lora.py b/applications/Chat/coati/models/lora.py index f1597da540a7..2114913e107b 100644 --- a/applications/Chat/coati/models/lora.py +++ b/applications/Chat/coati/models/lora.py @@ -8,8 +8,7 @@ class LoraLinear(lora.LoRALayer, nn.Module): - """Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear. - """ + """Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear.""" def __init__( self, @@ -17,16 +16,14 @@ def __init__( bias: Optional[nn.Parameter], r: int = 0, lora_alpha: int = 1, - lora_dropout: float = 0., - fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + lora_dropout: float = 0.0, + fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) merge_weights: bool = True, ): nn.Module.__init__(self) - lora.LoRALayer.__init__(self, - r=r, - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - merge_weights=merge_weights) + lora.LoRALayer.__init__( + self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights + ) self.weight = weight self.bias = bias @@ -47,13 +44,12 @@ def __init__( self.weight.data = self.weight.data.T def reset_parameters(self): - if hasattr(self, 'lora_A'): + if hasattr(self, "lora_A"): # Initialize A with the default values for nn.Linear and set B to zero. 
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) nn.init.zeros_(self.lora_B) def train(self, mode: bool = True): - def T(w): return w.T if self.fan_in_fan_out else w @@ -71,7 +67,6 @@ def T(w): self.merged = False def eval(self): - def T(w): return w.T if self.fan_in_fan_out else w @@ -80,12 +75,11 @@ def T(w): # Merge the weights and mark it if self.r > 0: self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling - delattr(self, 'lora_A') - delattr(self, 'lora_B') + delattr(self, "lora_A") + delattr(self, "lora_B") self.merged = True def forward(self, x: torch.Tensor): - def T(w): return w.T if self.fan_in_fan_out else w @@ -99,7 +93,9 @@ def T(w): def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear: - assert lora_rank <= linear.in_features, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})' + assert ( + lora_rank <= linear.in_features + ), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})" lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False) return lora_linear @@ -112,7 +108,7 @@ def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None: _convert_to_lora_recursively(child, lora_rank) -def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = 'none') -> nn.Module: +def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = "none") -> nn.Module: """Convert a torch.nn.Module to a LoRA module. Args: @@ -140,7 +136,7 @@ class LoRAModule(nn.Module): Defaults to 'none'. """ - def __init__(self, lora_rank: int = 0, lora_train_bias: str = 'none') -> None: + def __init__(self, lora_rank: int = 0, lora_train_bias: str = "none") -> None: super().__init__() self.lora_rank = lora_rank self.lora_train_bias = lora_train_bias diff --git a/applications/Chat/coati/models/loss.py b/applications/Chat/coati/models/loss.py index 05a0b4821797..4ad4f4dcd275 100644 --- a/applications/Chat/coati/models/loss.py +++ b/applications/Chat/coati/models/loss.py @@ -31,11 +31,13 @@ def __init__(self, clip_eps: float = 0.2) -> None: super().__init__() self.clip_eps = clip_eps - def forward(self, - log_probs: torch.Tensor, - old_log_probs: torch.Tensor, - advantages: torch.Tensor, - action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, + log_probs: torch.Tensor, + old_log_probs: torch.Tensor, + advantages: torch.Tensor, + action_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: ratio = (log_probs - old_log_probs).exp() surr1 = ratio * advantages surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages @@ -55,14 +57,16 @@ def __init__(self, clip_eps: float = 0.4) -> None: super().__init__() self.clip_eps = clip_eps - def forward(self, - values: torch.Tensor, - old_values: torch.Tensor, - reward: torch.Tensor, - action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, + values: torch.Tensor, + old_values: torch.Tensor, + reward: torch.Tensor, + action_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps) - surr1 = (values_clipped - reward)**2 - surr2 = (values - reward)**2 + surr1 = (values_clipped - reward) ** 2 + surr2 = (values - reward) ** 2 loss = torch.max(surr1, surr2) loss = loss.mean() return 0.5 * loss diff --git a/applications/Chat/coati/models/opt/__init__.py b/applications/Chat/coati/models/opt/__init__.py index 
334f4df0032a..e37d6e45c8fc 100644 --- a/applications/Chat/coati/models/opt/__init__.py +++ b/applications/Chat/coati/models/opt/__init__.py @@ -2,4 +2,4 @@ from .opt_critic import OPTCritic from .opt_rm import OPTRM -__all__ = ['OPTActor', 'OPTCritic', 'OPTRM'] +__all__ = ["OPTActor", "OPTCritic", "OPTRM"] diff --git a/applications/Chat/coati/models/opt/opt_actor.py b/applications/Chat/coati/models/opt/opt_actor.py index c14e4377ffb2..cd8908e13fb8 100644 --- a/applications/Chat/coati/models/opt/opt_actor.py +++ b/applications/Chat/coati/models/opt/opt_actor.py @@ -18,12 +18,14 @@ class OPTActor(Actor): lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: Optional[str] = None, - config: Optional[OPTConfig] = None, - checkpoint: bool = False, - lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: + def __init__( + self, + pretrained: Optional[str] = None, + config: Optional[OPTConfig] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = "none", + ) -> None: if pretrained is not None: model = OPTForCausalLM.from_pretrained(pretrained) elif config is not None: diff --git a/applications/Chat/coati/models/opt/opt_critic.py b/applications/Chat/coati/models/opt/opt_critic.py index f66c4173fa52..f37d28812c27 100644 --- a/applications/Chat/coati/models/opt/opt_critic.py +++ b/applications/Chat/coati/models/opt/opt_critic.py @@ -18,12 +18,14 @@ class OPTCritic(Critic): lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: Optional[str] = None, - config: Optional[OPTConfig] = None, - lora_rank: int = 0, - lora_train_bias: str = 'none', - **kwargs) -> None: + def __init__( + self, + pretrained: Optional[str] = None, + config: Optional[OPTConfig] = None, + lora_rank: int = 0, + lora_train_bias: str = "none", + **kwargs, + ) -> None: if pretrained is not None: model = OPTModel.from_pretrained(pretrained) elif config is not None: diff --git a/applications/Chat/coati/models/opt/opt_rm.py b/applications/Chat/coati/models/opt/opt_rm.py index 6f75344e6aae..893708344ad4 100644 --- a/applications/Chat/coati/models/opt/opt_rm.py +++ b/applications/Chat/coati/models/opt/opt_rm.py @@ -17,11 +17,13 @@ class OPTRM(RewardModel): lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: Optional[str] = None, - config: Optional[OPTConfig] = None, - lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: + def __init__( + self, + pretrained: Optional[str] = None, + config: Optional[OPTConfig] = None, + lora_rank: int = 0, + lora_train_bias: str = "none", + ) -> None: if pretrained is not None: model = OPTModel.from_pretrained(pretrained) elif config is not None: diff --git a/applications/Chat/coati/models/utils.py b/applications/Chat/coati/models/utils.py index 97637d3523b0..def6190dd71c 100644 --- a/applications/Chat/coati/models/utils.py +++ b/applications/Chat/coati/models/utils.py @@ -4,9 +4,9 @@ import torch.nn.functional as F -def _compute_approx_kl(log_probs: torch.Tensor, - log_probs_base: torch.Tensor, - action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: +def _compute_approx_kl( + log_probs: torch.Tensor, log_probs_base: torch.Tensor, action_mask: Optional[torch.Tensor] = None +) -> torch.Tensor: """ Compute the approximate KL divergence between two distributions. 
Schulman blog: http://joschu.net/blog/kl-approx.html @@ -26,11 +26,13 @@ def _compute_approx_kl(log_probs: torch.Tensor, return approx_kl -def compute_reward(r: Union[torch.Tensor, float], - kl_coef: float, - log_probs: torch.Tensor, - log_probs_base: torch.Tensor, - action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: +def compute_reward( + r: Union[torch.Tensor, float], + kl_coef: float, + log_probs: torch.Tensor, + log_probs_base: torch.Tensor, + action_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: if kl_coef <= 0.0: return r kl = _compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask) @@ -55,7 +57,7 @@ def calc_action_log_probs(output: torch.Tensor, sequences: torch.LongTensor, num Returns: torch.Tensor: Action log probs. """ - logits = output['logits'] + logits = output["logits"] log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) return log_probs[:, -num_actions:] diff --git a/applications/Chat/coati/quant/__init__.py b/applications/Chat/coati/quant/__init__.py index a65a78d07bb8..1765b8091bc3 100644 --- a/applications/Chat/coati/quant/__init__.py +++ b/applications/Chat/coati/quant/__init__.py @@ -2,6 +2,6 @@ from .utils import low_resource_init __all__ = [ - 'llama_load_quant', - 'low_resource_init', + "llama_load_quant", + "low_resource_init", ] diff --git a/applications/Chat/coati/quant/llama_gptq/__init__.py b/applications/Chat/coati/quant/llama_gptq/__init__.py index 51c8d6316290..51d5233586ad 100644 --- a/applications/Chat/coati/quant/llama_gptq/__init__.py +++ b/applications/Chat/coati/quant/llama_gptq/__init__.py @@ -1,5 +1,5 @@ from .loader import load_quant __all__ = [ - 'load_quant', + "load_quant", ] diff --git a/applications/Chat/coati/quant/llama_gptq/loader.py b/applications/Chat/coati/quant/llama_gptq/loader.py index 5353dc8a2ea3..50486337a7ab 100644 --- a/applications/Chat/coati/quant/llama_gptq/loader.py +++ b/applications/Chat/coati/quant/llama_gptq/loader.py @@ -11,14 +11,15 @@ def load_quant(model: nn.Module, checkpoint: str, wbits: int, groupsize: int): # ignore lm head layers = find_layers(model) - for name in ['lm_head']: + for name in ["lm_head"]: if name in layers: del layers[name] make_quant(model, layers, wbits, groupsize) - if checkpoint.endswith('.safetensors'): + if checkpoint.endswith(".safetensors"): from safetensors.torch import load_file as safe_load + model.load_state_dict(safe_load(checkpoint)) else: model.load_state_dict(torch.load(checkpoint)) diff --git a/applications/Chat/coati/quant/llama_gptq/model_utils.py b/applications/Chat/coati/quant/llama_gptq/model_utils.py index 62db171abb52..18e4e4761500 100644 --- a/applications/Chat/coati/quant/llama_gptq/model_utils.py +++ b/applications/Chat/coati/quant/llama_gptq/model_utils.py @@ -1,13 +1,12 @@ # copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/modelutils.py -import torch import torch.nn as nn -def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): +def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""): if type(module) in layers: return {name: module} res = {} for name1, child in module.named_children(): - res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1)) + res.update(find_layers(child, layers=layers, name=name + "." 
+ name1 if name != "" else name1)) return res diff --git a/applications/Chat/coati/quant/llama_gptq/quant.py b/applications/Chat/coati/quant/llama_gptq/quant.py index f7d5b7ce4bd8..5a7e2e72dfc5 100644 --- a/applications/Chat/coati/quant/llama_gptq/quant.py +++ b/applications/Chat/coati/quant/llama_gptq/quant.py @@ -13,14 +13,13 @@ def quantize(x, scale, zero, maxq): class Quantizer(nn.Module): - def __init__(self, shape=1): super(Quantizer, self).__init__() - self.register_buffer('maxq', torch.tensor(0)) - self.register_buffer('scale', torch.zeros(shape)) - self.register_buffer('zero', torch.zeros(shape)) + self.register_buffer("maxq", torch.tensor(0)) + self.register_buffer("scale", torch.zeros(shape)) + self.register_buffer("zero", torch.zeros(shape)) - def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8): + def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=0.8): self.maxq = torch.tensor(2**bits - 1) self.perchannel = perchannel self.sym = sym @@ -68,7 +67,7 @@ def find_params(self, x, weight=False): self.zero = torch.round(-xmin / self.scale) if self.mse: - best = torch.full([x.shape[0]], float('inf'), device=dev) + best = torch.full([x.shape[0]], float("inf"), device=dev) for i in range(int(self.maxshrink * self.grid)): p = 1 - i / self.grid xmin1 = p * xmin @@ -123,13 +122,12 @@ def ready(self): try: import quant_cuda except: - print('CUDA extension not installed.') + print("CUDA extension not installed.") # Assumes layer is perfectly divisible into 256 * 256 blocks class QuantLinear(nn.Module): - def __init__(self, bits, groupsize, infeatures, outfeatures): super().__init__() if bits not in [2, 3, 4, 8]: @@ -142,11 +140,11 @@ def __init__(self, bits, groupsize, infeatures, outfeatures): groupsize = groupsize if groupsize != -1 else infeatures self.groupsize = groupsize self.register_buffer( - 'qzeros', torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)), - dtype=torch.int)) - self.register_buffer('scales', torch.zeros((math.ceil(infeatures / groupsize), outfeatures))) - self.register_buffer('bias', torch.zeros(outfeatures)) - self.register_buffer('qweight', torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int)) + "qzeros", torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)), dtype=torch.int) + ) + self.register_buffer("scales", torch.zeros((math.ceil(infeatures / groupsize), outfeatures))) + self.register_buffer("bias", torch.zeros(outfeatures)) + self.register_buffer("qweight", torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int)) self._initialized_quant_state = False def pack(self, linear, scales, zeros): @@ -161,8 +159,10 @@ def pack(self, linear, scales, zeros): for idx in range(self.infeatures): g_idx = idx // self.groupsize intweight.append( - torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[:, - None]) + torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[ + :, None + ] + ) intweight = torch.cat(intweight, dim=1) intweight = intweight.t().contiguous() intweight = intweight.numpy().astype(np.uint32) @@ -271,13 +271,13 @@ def forward(self, x): return y.reshape(outshape) -def make_quant(module, names, bits, groupsize, name=''): +def make_quant(module, names, bits, groupsize, name=""): if isinstance(module, QuantLinear): return for attr in dir(module): tmp = getattr(module, attr) - name1 = name + 
'.' + attr if name != '' else attr + name1 = name + "." + attr if name != "" else attr if name1 in names: setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features)) for name1, child in module.named_children(): - make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1) + make_quant(child, names, bits, groupsize, name + "." + name1 if name != "" else name1) diff --git a/applications/Chat/coati/quant/utils.py b/applications/Chat/coati/quant/utils.py index 01b8cff0add1..d102bb30f52d 100644 --- a/applications/Chat/coati/quant/utils.py +++ b/applications/Chat/coati/quant/utils.py @@ -9,8 +9,7 @@ def _noop(*args, **kwargs): @contextmanager def low_resource_init(): - """This context manager disables weight initialization and sets the default float dtype to half. - """ + """This context manager disables weight initialization and sets the default float dtype to half.""" old_kaiming_uniform_ = torch.nn.init.kaiming_uniform_ old_uniform_ = torch.nn.init.uniform_ old_normal_ = torch.nn.init.normal_ diff --git a/applications/Chat/coati/ray/callbacks/base.py b/applications/Chat/coati/ray/callbacks/base.py index 3306150a41ff..8c5bd8a67776 100644 --- a/applications/Chat/coati/ray/callbacks/base.py +++ b/applications/Chat/coati/ray/callbacks/base.py @@ -5,7 +5,7 @@ class TrainerCallback(ABC): """ - Base callback class. It defines the interface for callbacks. + Base callback class. It defines the interface for callbacks. """ def on_fit_start(self) -> None: @@ -40,7 +40,6 @@ def on_update_end(self) -> None: class MakerCallback(ABC): - def on_loop_start(self) -> None: pass diff --git a/applications/Chat/coati/ray/callbacks/performance_evaluator.py b/applications/Chat/coati/ray/callbacks/performance_evaluator.py index d3df8f9ae3e0..18798bce7dce 100644 --- a/applications/Chat/coati/ray/callbacks/performance_evaluator.py +++ b/applications/Chat/coati/ray/callbacks/performance_evaluator.py @@ -30,10 +30,9 @@ def all_reduce_mean(x: float, world_size: int) -> float: class Timer: - def __init__(self) -> None: self.start_time: Optional[float] = None - self.duration: float = 0. + self.duration: float = 0.0 def start(self) -> None: self.start_time = time() @@ -42,13 +41,13 @@ def end(self) -> None: self.duration += time() - self.start_time def reset(self) -> None: - self.duration = 0. 
+ self.duration = 0.0 class ExperienceMakerPerformanceEvaluator(MakerCallback): - - def __init__(self, actor_num_params: int, critic_num_params: int, initial_model_num_params: int, - reward_model_num_params: int) -> None: + def __init__( + self, actor_num_params: int, critic_num_params: int, initial_model_num_params: int, reward_model_num_params: int + ) -> None: super().__init__() self.world_size = get_world_size() self.actor_num_params = actor_num_params @@ -63,7 +62,7 @@ def __init__(self, actor_num_params: int, critic_num_params: int, initial_model_ self.make_experience_flop: int = 0 print_rank_0( - f'ExperienceMaker actor: {actor_num_params/1024**3:.2f}B, critic: {critic_num_params/1024**3:.2f}B, initial model: {initial_model_num_params/1024**3:.2f}B, reward model: {reward_model_num_params/1024**3:.2f}B, world size: {self.world_size}' + f"ExperienceMaker actor: {actor_num_params/1024**3:.2f}B, critic: {critic_num_params/1024**3:.2f}B, initial model: {initial_model_num_params/1024**3:.2f}B, reward model: {reward_model_num_params/1024**3:.2f}B, world size: {self.world_size}" ) def on_make_experience_start(self) -> None: @@ -110,27 +109,29 @@ def on_loop_end(self) -> None: avg_throughput = self.total_samples * self.world_size / (avg_overall_duration + 1e-12) avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12) avg_time_per_sample = (avg_overall_duration + 1e-12) / (self.total_samples * self.world_size) - avg_make_experience_time_per_sample = (avg_make_experience_duration + 1e-12) / \ - (self.total_samples * self.world_size) + avg_make_experience_time_per_sample = (avg_make_experience_duration + 1e-12) / ( + self.total_samples * self.world_size + ) avg_send_time_per_sample = (avg_send_duration + 1e-12) / (self.total_samples * self.world_size) print_rank_0( - 'Making Experience Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n' - + f'TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n' - + f'Sample time (overall): {avg_time_per_sample:.3f} s\n' - + f'Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n' - - + f'Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n' + "Making Experience Performance Summary:\n" + + f"Throughput: {avg_throughput:.3f} samples/sec\n" + + f"TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n" + + f"Sample time (overall): {avg_time_per_sample:.3f} s\n" + + f"Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n" + + f"Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n" ) class TrainerPerformanceEvaluator(TrainerCallback): - - def __init__(self, - actor_num_params: int, - critic_num_params: int, - enable_grad_checkpoint: bool = False, - ignore_first_episodes: int = 1) -> None: + def __init__( + self, + actor_num_params: int, + critic_num_params: int, + enable_grad_checkpoint: bool = False, + ignore_first_episodes: int = 1, + ) -> None: super().__init__() self.world_size = get_world_size() self.actor_num_params = actor_num_params @@ -146,7 +147,7 @@ def __init__(self, self.learn_flop: int = 0 print_rank_0( - f'Trainer actor: {self.actor_num_params/1024**3:.2f}B, critic: {self.critic_num_params/1024**3:.2f}B, world size: {self.world_size}' + f"Trainer actor: {self.actor_num_params/1024**3:.2f}B, 
critic: {self.critic_num_params/1024**3:.2f}B, world size: {self.world_size}" ) def on_episode_start(self, episodes: int) -> None: @@ -191,7 +192,7 @@ def on_update_end(self) -> None: def on_fit_end(self) -> None: if self.total_samples == 0: - print_rank_0('No samples are collected, skip trainer performance evaluation') + print_rank_0("No samples are collected, skip trainer performance evaluation") return avg_train_duration = all_reduce_mean(self.batch_timer.duration, self.world_size) avg_update_duration = all_reduce_mean(self.update_timer.duration, self.world_size) @@ -204,9 +205,10 @@ def on_fit_end(self) -> None: avg_update_time_per_sample = (avg_update_duration + 1e-12) / (self.total_samples * self.world_size) print_rank_0( - 'Learning Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n' - + f'TFLOPS per GPU: {avg_learn_tflops:.3f}\n' + f'Sample time (overall): {avg_time_per_sample:.3f} s\n' - + f'Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n' - - + f'Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n' + "Learning Performance Summary:\n" + + f"Throughput: {avg_throughput:.3f} samples/sec\n" + + f"TFLOPS per GPU: {avg_learn_tflops:.3f}\n" + + f"Sample time (overall): {avg_time_per_sample:.3f} s\n" + + f"Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n" + + f"Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n" ) diff --git a/applications/Chat/coati/ray/detached_replay_buffer.py b/applications/Chat/coati/ray/detached_replay_buffer.py index e04bf5ccb881..92dab17292f7 100644 --- a/applications/Chat/coati/ray/detached_replay_buffer.py +++ b/applications/Chat/coati/ray/detached_replay_buffer.py @@ -1,20 +1,15 @@ -import asyncio -import copy -import random -from threading import Lock -from typing import Any, List +from typing import List -import ray import torch -from coati.experience_buffer import ExperienceBuffer from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch from coati.experience_maker.base import Experience + # from torch.multiprocessing import Queue from ray.util.queue import Queue class DetachedReplayBuffer: - ''' + """ Detached replay buffer. Share Experience across workers on the same node. Therefore, a trainer node is expected to have only one instance. It is ExperienceMakerHolder's duty to call append(exp) method, remotely. @@ -24,7 +19,7 @@ class DetachedReplayBuffer: tp_world_size: Number of workers in the same tp group limit: Limit of number of experience sample BATCHs. A number <= 0 means unlimited. Defaults to 0. cpu_offload: Whether to offload experience to cpu when sampling. Defaults to True. - ''' + """ def __init__(self, sample_batch_size: int, limit: int = 0) -> None: self.sample_batch_size = sample_batch_size @@ -34,23 +29,23 @@ def __init__(self, sample_batch_size: int, limit: int = 0) -> None: @torch.no_grad() def append(self, experience: Experience) -> None: - ''' + """ Expected to be called remotely. - ''' + """ items = split_experience_batch(experience) self.extend(items) @torch.no_grad() def extend(self, items: List[BufferItem]) -> None: - ''' + """ Expected to be called remotely. 
- ''' + """ self.batch_collector.extend(items) while len(self.batch_collector) >= self.sample_batch_size: - items = self.batch_collector[:self.sample_batch_size] + items = self.batch_collector[: self.sample_batch_size] experience = make_experience_batch(items) self.items.put(experience, block=True) - self.batch_collector = self.batch_collector[self.sample_batch_size:] + self.batch_collector = self.batch_collector[self.sample_batch_size :] def clear(self) -> None: # self.items.close() diff --git a/applications/Chat/coati/ray/detached_trainer_base.py b/applications/Chat/coati/ray/detached_trainer_base.py index 90399781187a..fcf0a472df9e 100644 --- a/applications/Chat/coati/ray/detached_trainer_base.py +++ b/applications/Chat/coati/ray/detached_trainer_base.py @@ -1,6 +1,6 @@ import os from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Iterable, List, Optional, Union +from typing import Any, Dict, List import ray import torch @@ -15,7 +15,7 @@ class DetachedTrainer(ABC): - ''' + """ Base class for detached rlhf trainers. 'detach' means that the experience maker is detached compared to a normal Trainer. Please set name attribute during init: @@ -28,15 +28,17 @@ class DetachedTrainer(ABC): callbacks (List[Callback], defaults to []): the callbacks to call during training process generate_kwargs (dict, optional): the kwargs to use while model generating - ''' - - def __init__(self, - experience_maker_holder_name_list: List[str], - train_batch_size: int = 8, - buffer_limit: int = 0, - dataloader_pin_memory: bool = True, - callbacks: List[TrainerCallback] = [], - debug: bool = False) -> None: + """ + + def __init__( + self, + experience_maker_holder_name_list: List[str], + train_batch_size: int = 8, + buffer_limit: int = 0, + dataloader_pin_memory: bool = True, + callbacks: List[TrainerCallback] = [], + debug: bool = False, + ) -> None: super().__init__() self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit) self.dataloader_pin_memory = dataloader_pin_memory @@ -67,18 +69,16 @@ def training_step(self, experience: Experience) -> Dict[str, Any]: def _learn(self, update_steps: int, train_epochs: int) -> None: data = [] # warmup - pbar = tqdm(range(update_steps), desc=f'Train epoch [1/{train_epochs}]', disable=not is_rank_0()) + pbar = tqdm(range(update_steps), desc=f"Train epoch [1/{train_epochs}]", disable=not is_rank_0()) self._on_epoch_start(0) self._learn_epoch(pbar, data) self._on_epoch_end(0) # item is already a batch - dataloader = DataLoader(data, - batch_size=1, - shuffle=True, - pin_memory=self.dataloader_pin_memory, - collate_fn=lambda x: x[0]) + dataloader = DataLoader( + data, batch_size=1, shuffle=True, pin_memory=self.dataloader_pin_memory, collate_fn=lambda x: x[0] + ) for epoch in range(1, train_epochs): - pbar = tqdm(dataloader, desc=f'Train epoch [{epoch + 1}/{train_epochs}]', disable=not is_rank_0()) + pbar = tqdm(dataloader, desc=f"Train epoch [{epoch + 1}/{train_epochs}]", disable=not is_rank_0()) self._on_epoch_start(epoch) self._learn_epoch(pbar, data) self._on_epoch_end(epoch) @@ -104,7 +104,7 @@ def _learn_epoch(self, pbar: tqdm, data: List[Experience]) -> None: def fit(self, total_steps: int, update_steps: int, train_epochs: int = 1) -> None: self._on_fit_start() - for i in tqdm(range(total_steps // update_steps), desc='Trainer', disable=not is_rank_0()): + for i in tqdm(range(total_steps // update_steps), desc="Trainer", disable=not is_rank_0()): self._on_episode_start(i) self._learn(update_steps, train_epochs) 
self._on_update_start() diff --git a/applications/Chat/coati/ray/detached_trainer_ppo.py b/applications/Chat/coati/ray/detached_trainer_ppo.py index 2f2aa0e29579..ef84a1ddba48 100644 --- a/applications/Chat/coati/ray/detached_trainer_ppo.py +++ b/applications/Chat/coati/ray/detached_trainer_ppo.py @@ -1,12 +1,11 @@ -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Callable, Dict, List, Tuple import ray import torch -from coati.experience_maker import Experience, NaiveExperienceMaker +from coati.experience_maker import Experience from coati.models.base import Actor, Critic from coati.models.loss import PolicyLoss, ValueLoss -from coati.trainer.callbacks import Callback -from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy, Strategy +from coati.trainer.strategies import GeminiStrategy, LowLevelZeroStrategy, Strategy from torch.optim import Adam from colossalai.nn.optimizer import HybridAdam @@ -14,27 +13,14 @@ from .callbacks import TrainerCallback, TrainerPerformanceEvaluator from .detached_trainer_base import DetachedTrainer from .lora_constructor import LoRAConstructor -from .utils import ( - get_actor_from_args, - get_critic_from_args, - get_model_numel, - get_rank, - get_strategy_from_args, - is_rank_0, - set_dist_env, - state_dict_to, -) +from .utils import get_model_numel, get_rank, set_dist_env, state_dict_to -@ray.remote(concurrency_groups={ - "buffer_length": 1, - "buffer_append": 1, - "buffer_sample": 1, - "model_io": 1, - "compute": 1 -}) +@ray.remote( + concurrency_groups={"buffer_length": 1, "buffer_append": 1, "buffer_sample": 1, "model_io": 1, "compute": 1} +) class DetachedPPOTrainer(DetachedTrainer): - ''' + """ Detached Trainer for PPO algorithm Args: strategy (Strategy): the strategy to use for training @@ -52,7 +38,7 @@ class DetachedPPOTrainer(DetachedTrainer): dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader callbacks (List[Callback], defaults to []): the callbacks to call during training process generate_kwargs (dict, optional): the kwargs to use while model generating - ''' + """ def __init__( self, @@ -92,21 +78,24 @@ def __init__( self.actor_optim = Adam(self.actor.parameters(), lr=1e-7) self.critic_optim = Adam(self.critic.parameters(), lr=1e-7) - (self.actor, self.actor_optim), (self.critic, self.critic_optim) = \ - self.strategy.prepare((self.actor, self.actor_optim), (self.critic, self.critic_optim)) + (self.actor, self.actor_optim), (self.critic, self.critic_optim) = self.strategy.prepare( + (self.actor, self.actor_optim), (self.critic, self.critic_optim) + ) # configure trainer self.actor_loss_fn = PolicyLoss(eps_clip) self.critic_loss_fn = ValueLoss(value_clip) - super().__init__(experience_maker_holder_name_list, - train_batch_size=train_batch_size, - buffer_limit=buffer_limit, - dataloader_pin_memory=dataloader_pin_memory, - callbacks=callbacks, - debug=debug) + super().__init__( + experience_maker_holder_name_list, + train_batch_size=train_batch_size, + buffer_limit=buffer_limit, + dataloader_pin_memory=dataloader_pin_memory, + callbacks=callbacks, + debug=debug, + ) if self._debug: - print(f'[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}') + print(f"[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}") self._update_lora_weights = update_lora_weights @@ -115,7 +104,7 @@ def __init__( def _update_remote_makers(self, fully_update: bool = False, **config): # TODO: balance duties if not 
fully_update: - config['requires_grad_only'] = True + config["requires_grad_only"] = True self.update_target_holder_list() # mark start, ensure order tasks = [] @@ -131,7 +120,9 @@ def _update_remote_makers(self, fully_update: bool = False, **config): target_holder.update_experience_maker.remote( new_actor_state_dict=state_dict_shard, new_actor_lora_config_dict=self._get_model_lora_config_dict(self.actor), - fully_update=fully_update)) + fully_update=fully_update, + ) + ) # sending loop for state_dict_shard in self._get_model_state_dict_shard(self.critic, fully_update=fully_update, **config): for target_holder in self.target_holder_list: @@ -139,7 +130,9 @@ def _update_remote_makers(self, fully_update: bool = False, **config): target_holder.update_experience_maker.remote( new_critic_state_dict=state_dict_shard, new_critic_lora_config_dict=self._get_model_lora_config_dict(self.critic), - fully_update=fully_update)) + fully_update=fully_update, + ) + ) ray.get(tasks) # mark end for target_holder in self.target_holder_list: @@ -152,26 +145,24 @@ def training_step(self, experience: Experience) -> Dict[str, float]: num_actions = experience.action_mask.size(1) action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask) - actor_loss = self.actor_loss_fn(action_log_probs, - experience.action_log_probs, - experience.advantages, - action_mask=experience.action_mask) + actor_loss = self.actor_loss_fn( + action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask + ) self.strategy.backward(actor_loss, self.actor, self.actor_optim) self.strategy.optimizer_step(self.actor_optim) self.actor_optim.zero_grad() - values = self.critic(experience.sequences, - action_mask=experience.action_mask, - attention_mask=experience.attention_mask) - critic_loss = self.critic_loss_fn(values, - experience.values, - experience.reward, - action_mask=experience.action_mask) + values = self.critic( + experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask + ) + critic_loss = self.critic_loss_fn( + values, experience.values, experience.reward, action_mask=experience.action_mask + ) self.strategy.backward(critic_loss, self.critic, self.critic_optim) self.strategy.optimizer_step(self.critic_optim) self.critic_optim.zero_grad() - return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()} + return {"actor_loss": actor_loss.item(), "critic_loss": critic_loss.item()} def strategy_save_actor(self, path: str, only_rank0: bool = False) -> None: self.strategy.save_model(self.actor, path, only_rank0) diff --git a/applications/Chat/coati/ray/experience_maker_holder.py b/applications/Chat/coati/ray/experience_maker_holder.py index 13314bdafd5f..4d290f4aba88 100644 --- a/applications/Chat/coati/ray/experience_maker_holder.py +++ b/applications/Chat/coati/ray/experience_maker_holder.py @@ -1,53 +1,49 @@ import os import time import tracemalloc -from copy import deepcopy from threading import Lock -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Tuple, Union import ray import torch -import torch.nn as nn -from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch -from coati.experience_maker import Experience, ExperienceMaker, NaiveExperienceMaker +from coati.experience_buffer.utils import split_experience_batch +from coati.experience_maker import Experience, 
NaiveExperienceMaker from coati.models.base import Actor, Critic, RewardModel -from coati.trainer.callbacks import Callback from coati.trainer.strategies import Strategy -from coati.trainer.strategies.sampler import DistributedSampler -from ray.exceptions import GetTimeoutError from torch import Tensor from tqdm import tqdm from .callbacks import ExperienceMakerPerformanceEvaluator, MakerCallback from .lora_constructor import LoRAConstructor -from .utils import get_model_numel, get_rank, get_world_size, is_rank_0, set_dist_env, state_dict_to +from .utils import get_model_numel, get_rank, is_rank_0, set_dist_env, state_dict_to @ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1}) class ExperienceMakerHolder: - ''' + """ Args: detached_trainer_name_list: str list to get ray actor handles strategy: kl_coef: the coefficient of kl divergence loss sync_models_from_trainers: whether to sync models from trainers. If True, you must call sync_models_to_remote_makers() in trainers to sync models. - ''' + """ def __init__( - self, - detached_trainer_name_list: List[str], - strategy_fn: Callable[[], Strategy], + self, + detached_trainer_name_list: List[str], + strategy_fn: Callable[[], Strategy], # a function returns (actor, critic, reward_model, initial_model) - model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]], - env_info: Dict[str, str] = None, - sync_models_from_trainers: bool = False, - buffer_cpu_offload: bool = True, - kl_coef: float = 0.1, - callbacks: List[MakerCallback] = [], - eval_performance: bool = False, - debug: bool = False, - update_lora_weights: bool = False, - **generate_kwargs): + model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]], + env_info: Dict[str, str] = None, + sync_models_from_trainers: bool = False, + buffer_cpu_offload: bool = True, + kl_coef: float = 0.1, + callbacks: List[MakerCallback] = [], + eval_performance: bool = False, + debug: bool = False, + update_lora_weights: bool = False, + **generate_kwargs, + ): # set environment variables if env_info: set_dist_env(env_info=env_info) @@ -66,8 +62,9 @@ def __init__( critic_numel = get_model_numel(critic) initial_model_numel = get_model_numel(initial_model) reward_model_numel = get_model_numel(reward_model) - evaluator = ExperienceMakerPerformanceEvaluator(actor_numel, critic_numel, initial_model_numel, - reward_model_numel) + evaluator = ExperienceMakerPerformanceEvaluator( + actor_numel, critic_numel, initial_model_numel, reward_model_numel + ) callbacks = callbacks + [evaluator] actor, critic, reward_model, initial_model = self.strategy.prepare(actor, critic, reward_model, initial_model) @@ -89,9 +86,9 @@ def __init__( self._target_idx = 0 if self._debug: - print(f'[maker{get_rank()}] will send items to {self._detached_trainer_name_list}') + print(f"[maker{get_rank()}] will send items to {self._detached_trainer_name_list}") if not self._is_fully_initialized: - print(f'[maker{get_rank()}] Waiting for INIT') + print(f"[maker{get_rank()}] Waiting for INIT") def _get_ready(self): while not self._fully_initialized(): @@ -136,7 +133,7 @@ def _inference_step(self, batch) -> None: self._on_make_experience_end(experience) self._on_send_start() if self.buffer_cpu_offload: - experience.to_device('cpu') + experience.to_device("cpu") self._send_items(experience) self._on_send_end() self._on_batch_end() @@ -155,7 +152,7 @@ def workingloop(self, dataloader_fn: Callable[[], Iterable], num_epochs: int = 1 if num_steps > 0: # ignore num epochs it = iter(dataloader) - for _ in 
tqdm(range(num_steps), desc='ExperienceMaker', disable=not is_rank_0()): + for _ in tqdm(range(num_steps), desc="ExperienceMaker", disable=not is_rank_0()): try: batch = next(it) except StopIteration: @@ -163,7 +160,7 @@ def workingloop(self, dataloader_fn: Callable[[], Iterable], num_epochs: int = 1 batch = next(it) self._inference_step(batch) else: - with tqdm(total=num_epochs * len(dataloader), desc='ExperienceMaker', disable=not is_rank_0()) as pbar: + with tqdm(total=num_epochs * len(dataloader), desc="ExperienceMaker", disable=not is_rank_0()) as pbar: for _ in range(num_epochs): for batch in dataloader: self._inference_step(batch) @@ -171,22 +168,24 @@ def workingloop(self, dataloader_fn: Callable[[], Iterable], num_epochs: int = 1 self._on_loop_end() @ray.method(concurrency_group="model_io") - def update_experience_maker(self, - new_actor_state_dict: Dict[str, Any] = None, - new_actor_lora_config_dict: Dict[str, Any] = None, - new_critic_state_dict: Dict[str, Any] = None, - new_critic_lora_config_dict: Dict[str, Any] = None, - fully_update: bool = False, - chunk_start: bool = None, - chunk_end: bool = None): - ''' - called by trainer - chunk_start: Set True at the first call. Before sending state_dict calls - chunk_end: Set True at the last call. After sending state_dict calls. - fully_update: Set True if you want to sync models when initializing - - TODO: load_state_dict integrate with model-sharding strategy - ''' + def update_experience_maker( + self, + new_actor_state_dict: Dict[str, Any] = None, + new_actor_lora_config_dict: Dict[str, Any] = None, + new_critic_state_dict: Dict[str, Any] = None, + new_critic_lora_config_dict: Dict[str, Any] = None, + fully_update: bool = False, + chunk_start: bool = None, + chunk_end: bool = None, + ): + """ + called by trainer + chunk_start: Set True at the first call. Before sending state_dict calls + chunk_end: Set True at the last call. After sending state_dict calls. 
+ fully_update: Set True if you want to sync models when initializing + + TODO: load_state_dict integrate with model-sharding strategy + """ _watch_memory = self._debug if chunk_start: if self._debug: @@ -202,18 +201,22 @@ def update_experience_maker(self, else: new_actor_state_dict = state_dict_to(new_actor_state_dict, device=torch.cuda.current_device()) state_dict_increase = self.actor_lora_constructor.reconstruct_increase( - new_actor_state_dict, new_actor_lora_config_dict) + new_actor_state_dict, new_actor_lora_config_dict + ) self.actor_lora_constructor.load_state_dict_increase( - self.experience_maker.actor.model, state_dict_increase) + self.experience_maker.actor.model, state_dict_increase + ) if new_critic_state_dict is not None: if not self._update_lora_weights or fully_update: self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False) else: new_critic_state_dict = state_dict_to(new_critic_state_dict, device=torch.cuda.current_device()) state_dict_increase = self.critic_lora_constructor.reconstruct_increase( - new_critic_state_dict, new_critic_lora_config_dict) + new_critic_state_dict, new_critic_lora_config_dict + ) self.critic_lora_constructor.load_state_dict_increase( - self.experience_maker.critic, state_dict_increase) + self.experience_maker.critic, state_dict_increase + ) # the lock must be released after both actor and critic being updated if chunk_end: @@ -262,10 +265,10 @@ def _set_default_generate_kwargs(generate_kwargs: dict, actor: Actor) -> None: origin_model = actor.model new_kwargs = {**generate_kwargs} # use huggingface models method directly - if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'): - new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation + if "prepare_inputs_fn" not in generate_kwargs and hasattr(origin_model, "prepare_inputs_for_generation"): + new_kwargs["prepare_inputs_fn"] = origin_model.prepare_inputs_for_generation - if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(origin_model, '_update_model_kwargs_for_generation'): - new_kwargs['update_model_kwargs_fn'] = origin_model._update_model_kwargs_for_generation + if "update_model_kwargs_fn" not in generate_kwargs and hasattr(origin_model, "_update_model_kwargs_for_generation"): + new_kwargs["update_model_kwargs_fn"] = origin_model._update_model_kwargs_for_generation return new_kwargs diff --git a/applications/Chat/coati/ray/lora_constructor.py b/applications/Chat/coati/ray/lora_constructor.py index a98545d4d751..8e9f78700e29 100644 --- a/applications/Chat/coati/ray/lora_constructor.py +++ b/applications/Chat/coati/ray/lora_constructor.py @@ -1,11 +1,9 @@ from collections import OrderedDict from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict -import torch import torch.nn as nn from coati.models.lora import LoraLinear -from loralib.layers import LoRALayer @dataclass @@ -17,7 +15,7 @@ class LoRAConfig: class LoRAConstructor: - ''' + """ Tools for reconstructing a model from a remote LoRA model. (Transferring only LoRA data costs much less!) 
Usage: @@ -36,7 +34,7 @@ class LoRAConstructor: Step 5 (Receiver): load_state_dict_increase() - ''' + """ def __init__(self): self.lora_config_dict = None @@ -45,10 +43,10 @@ def register_lora_config(self, lora_config_dict: Dict[str, Any]): self.lora_config_dict = lora_config_dict def reconstruct_increase(self, state_dict_lora: Dict[str, Any], lora_config_dict: Dict[str, Any]): - ''' - xxx.lora_A, xxx.lora_B -->> xxx.weight - Warning: the xxx.weight here is the increment actually. - ''' + """ + xxx.lora_A, xxx.lora_B -->> xxx.weight + Warning: the xxx.weight here is the increment actually. + """ if lora_config_dict is not None: self.register_lora_config(lora_config_dict) @@ -56,24 +54,25 @@ def reconstruct_increase(self, state_dict_lora: Dict[str, Any], lora_config_dict config_iter = iter(self.lora_config_dict.items()) lora_A, lora_B, layer_prefix = None, None, None for k, v in state_dict_lora.items(): - if k.rpartition('.')[-1] == 'lora_A': + if k.rpartition(".")[-1] == "lora_A": lora_A = v - layer_prefix = k.rpartition('.')[0] - elif k.rpartition('.')[-1] == 'lora_B': - assert layer_prefix == k.rpartition('.')[0], "unmatched (lora_A, lora_B) pair" + layer_prefix = k.rpartition(".")[0] + elif k.rpartition(".")[-1] == "lora_B": + assert layer_prefix == k.rpartition(".")[0], "unmatched (lora_A, lora_B) pair" layer_prefix_2, config = next(config_iter) assert layer_prefix_2 == layer_prefix, "unmatched (state_dict, config_dict) pair" lora_B = v weight_data_increase = self._compute(lora_A, lora_B, config) - state_dict_increase[layer_prefix + '.weight'] = weight_data_increase + state_dict_increase[layer_prefix + ".weight"] = weight_data_increase lora_A, lora_B, layer_prefix = None, None, None else: - raise ValueError('unexpected key') + raise ValueError("unexpected key") return state_dict_increase def _compute(self, lora_A, lora_B, config=LoRAConfig()): def T(w): return w.T if config.fan_in_fan_out else w + if config.r > 0: scaling = config.lora_alpha / config.r weight_data_increase = T(lora_B @ lora_A) * scaling @@ -81,21 +80,21 @@ def T(w): return 0 def load_state_dict_increase(self, model: nn.Module, state_dict_increase: Dict[str, Any]): - ''' + """ The final reconstruction step - ''' + """ # naive approach model.load_state_dict({k: v + model.state_dict()[k] for k, v in state_dict_increase.items()}, strict=False) @staticmethod def filter_state_dict_lora(state_dict: Dict[str, Any], keep_non_lora=False): - ''' + """ if keep_non_lora, also return non_lora state_dict - ''' + """ state_dict_lora = OrderedDict() state_dict_non_lora = OrderedDict() for k, v in state_dict.items(): - if 'lora_A' in k or 'lora_B' in k: + if "lora_A" in k or "lora_B" in k: state_dict_lora[k] = v elif keep_non_lora: state_dict_non_lora[k] = v @@ -106,17 +105,19 @@ def filter_state_dict_lora(state_dict: Dict[str, Any], keep_non_lora=False): @staticmethod def extract_lora_config(model: nn.Module) -> Dict[str, LoRAConfig]: - ''' + """ extract LoraLinear model. 
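As a quick sanity check of the increment formula in `_compute` above: the reconstructed delta is `lora_B @ lora_A` scaled by `lora_alpha / r`, and applying it is an element-wise add onto the frozen base weight. A minimal numeric sketch, assuming loralib-style shapes (`lora_A: (r, in)`, `lora_B: (out, r)`) and `fan_in_fan_out=False`:

```python
import torch

r, lora_alpha = 4, 16
in_features, out_features = 8, 6

base_weight = torch.randn(out_features, in_features)
lora_A = torch.randn(r, in_features)
lora_B = torch.randn(out_features, r)

# Increment reconstructed on the receiver side from the transferred LoRA pair.
delta = (lora_B @ lora_A) * (lora_alpha / r)

# Loading the increment is just an add onto the existing weight for each key,
# which is what load_state_dict_increase does.
merged = base_weight + delta
assert merged.shape == base_weight.shape
```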
return OrderedDict(): name -> LoRAConfig - ''' + """ lora_config_dict = OrderedDict() for name, child in model.named_modules(): if isinstance(child, LoraLinear): - lora_config_dict[name] = LoRAConfig(r=child.r, - lora_alpha=child.lora_alpha, - lora_dropout=child.lora_dropout, - fan_in_fan_out=child.fan_in_fan_out) + lora_config_dict[name] = LoRAConfig( + r=child.r, + lora_alpha=child.lora_alpha, + lora_dropout=child.lora_dropout, + fan_in_fan_out=child.fan_in_fan_out, + ) return lora_config_dict diff --git a/applications/Chat/coati/ray/utils.py b/applications/Chat/coati/ray/utils.py index 391ffe7a91a9..036dd145dddb 100644 --- a/applications/Chat/coati/ray/utils.py +++ b/applications/Chat/coati/ray/utils.py @@ -1,6 +1,6 @@ import os from collections import OrderedDict -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict import torch import torch.distributed as dist @@ -10,7 +10,7 @@ from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM from coati.models.opt import OPTRM, OPTActor, OPTCritic from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy -from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer +from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer def is_rank_0() -> bool: @@ -26,13 +26,13 @@ def get_world_size() -> int: def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_rank=0): - if model == 'gpt2': + if model == "gpt2": actor = GPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank) - elif model == 'bloom': + elif model == "bloom": actor = BLOOMActor(pretrained=pretrained, config=config, lora_rank=lora_rank) - elif model == 'opt': + elif model == "opt": actor = OPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank) - elif model == 'llama': + elif model == "llama": actor = LlamaActor(pretrained=pretrained, config=config, lora_rank=lora_rank) else: raise ValueError(f'Unsupported actor model "{model}"') @@ -40,13 +40,13 @@ def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_ra def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_rank=0): - if model == 'gpt2': + if model == "gpt2": critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True) - elif model == 'bloom': + elif model == "bloom": critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True) - elif model == 'opt': + elif model == "opt": critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True) - elif model == 'llama': + elif model == "llama": critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True) else: raise ValueError(f'Unsupported reward model "{model}"') @@ -54,13 +54,13 @@ def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_r def get_reward_model_from_args(model: str, pretrained: str = None, config=None): - if model == 'gpt2': + if model == "gpt2": reward_model = GPTRM(pretrained=pretrained, config=config) - elif model == 'bloom': + elif model == "bloom": reward_model = BLOOMRM(pretrained=pretrained, config=config) - elif model == 'opt': + elif model == "opt": reward_model = OPTRM(pretrained=pretrained, config=config) - elif model == 'llama': + elif model == "llama": reward_model = LlamaRM(pretrained=pretrained, config=config) else: raise ValueError(f'Unsupported reward model "{model}"') @@ -68,29 
+68,29 @@ def get_reward_model_from_args(model: str, pretrained: str = None, config=None): def get_strategy_from_args(strategy: str): - if strategy == 'ddp': + if strategy == "ddp": strategy_ = DDPStrategy() - elif strategy == 'colossalai_gemini': - strategy_ = GeminiStrategy(placement_policy='cuda', initial_scale=2**5) - elif strategy == 'colossalai_zero2': - strategy_ = LowLevelZeroStrategy(stage=2, placement_policy='cuda') - elif strategy == 'colossalai_gemini_cpu': - strategy_ = GeminiStrategy(placement_policy='cpu', initial_scale=2**5) - elif strategy == 'colossalai_zero2_cpu': - strategy_ = LowLevelZeroStrategy(stage=2, placement_policy='cpu') + elif strategy == "colossalai_gemini": + strategy_ = GeminiStrategy(placement_policy="cuda", initial_scale=2**5) + elif strategy == "colossalai_zero2": + strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cuda") + elif strategy == "colossalai_gemini_cpu": + strategy_ = GeminiStrategy(placement_policy="cpu", initial_scale=2**5) + elif strategy == "colossalai_zero2_cpu": + strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cpu") else: raise ValueError(f'Unsupported strategy "{strategy}"') return strategy_ def get_tokenizer_from_args(model: str, **kwargs): - if model == 'gpt2': - tokenizer = GPT2Tokenizer.from_pretrained('gpt2') - elif model == 'bloom': - tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m') - elif model == 'opt': + if model == "gpt2": + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + elif model == "bloom": + tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m") + elif model == "opt": tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") - elif model == 'llama': + elif model == "llama": pretrain_path = kwargs["pretrain"] tokenizer = AutoTokenizer.from_pretrained(pretrain_path) else: @@ -101,11 +101,11 @@ def get_tokenizer_from_args(model: str, **kwargs): def set_dist_env(env_info: Dict[str, str]): - os.environ["RANK"] = env_info['rank'] - os.environ["LOCAL_RANK"] = env_info['local_rank'] - os.environ["WORLD_SIZE"] = env_info['world_size'] - os.environ['MASTER_PORT'] = env_info['master_port'] - os.environ['MASTER_ADDR'] = env_info['master_addr'] + os.environ["RANK"] = env_info["rank"] + os.environ["LOCAL_RANK"] = env_info["local_rank"] + os.environ["WORLD_SIZE"] = env_info["world_size"] + os.environ["MASTER_PORT"] = env_info["master_port"] + os.environ["MASTER_ADDR"] = env_info["master_addr"] def get_model_numel(model: nn.Module) -> int: @@ -128,12 +128,12 @@ def get_receivers_per_sender(sender_idx: int, num_senders: int, num_receivers: i return target_receivers -def state_dict_to(state_dict: Dict[str, Any], - dtype: torch.dtype = torch.float16, - device: torch.device = torch.device('cpu')): - ''' - keep state_dict intact - ''' +def state_dict_to( + state_dict: Dict[str, Any], dtype: torch.dtype = torch.float16, device: torch.device = torch.device("cpu") +): + """ + keep state_dict intact + """ new_state_dict = OrderedDict() for k, v in state_dict.items(): new_state_dict[k] = v.to(dtype=dtype, device=device) diff --git a/applications/Chat/coati/trainer/__init__.py b/applications/Chat/coati/trainer/__init__.py index 86142361f3ff..4be5d27f93b1 100644 --- a/applications/Chat/coati/trainer/__init__.py +++ b/applications/Chat/coati/trainer/__init__.py @@ -3,8 +3,4 @@ from .rm import RewardModelTrainer from .sft import SFTTrainer -__all__ = [ - 'SLTrainer', 'OnPolicyTrainer', - 'RewardModelTrainer', 'SFTTrainer', - 'PPOTrainer' -] +__all__ = ["SLTrainer", 
"OnPolicyTrainer", "RewardModelTrainer", "SFTTrainer", "PPOTrainer"] diff --git a/applications/Chat/coati/trainer/base.py b/applications/Chat/coati/trainer/base.py index 0629c9c00cca..ca450edee0c3 100644 --- a/applications/Chat/coati/trainer/base.py +++ b/applications/Chat/coati/trainer/base.py @@ -68,12 +68,14 @@ class OnPolicyTrainer(ABC): callbacks (List[Callback], defaults to []): the callbacks to call during training process """ - def __init__(self, - strategy: Strategy, - data_buffer: NaiveExperienceBuffer, - sample_buffer: bool, - dataloader_pin_memory: bool, - callbacks: List[Callback] = []) -> None: + def __init__( + self, + strategy: Strategy, + data_buffer: NaiveExperienceBuffer, + sample_buffer: bool, + dataloader_pin_memory: bool, + callbacks: List[Callback] = [], + ) -> None: super().__init__() self.strategy = strategy self.data_buffer = data_buffer diff --git a/applications/Chat/coati/trainer/callbacks/__init__.py b/applications/Chat/coati/trainer/callbacks/__init__.py index 9ed0ee6f7640..29c8c4f00a5c 100644 --- a/applications/Chat/coati/trainer/callbacks/__init__.py +++ b/applications/Chat/coati/trainer/callbacks/__init__.py @@ -2,4 +2,4 @@ from .performance_evaluator import PerformanceEvaluator from .save_checkpoint import SaveCheckpoint -__all__ = ['Callback', 'PerformanceEvaluator', 'SaveCheckpoint'] +__all__ = ["Callback", "PerformanceEvaluator", "SaveCheckpoint"] diff --git a/applications/Chat/coati/trainer/callbacks/base.py b/applications/Chat/coati/trainer/callbacks/base.py index f5616048855b..d5181175b324 100644 --- a/applications/Chat/coati/trainer/callbacks/base.py +++ b/applications/Chat/coati/trainer/callbacks/base.py @@ -5,7 +5,7 @@ class Callback(ABC): """ - Base callback class. It defines the interface for callbacks. + Base callback class. It defines the interface for callbacks. """ def on_fit_start(self) -> None: diff --git a/applications/Chat/coati/trainer/callbacks/performance_evaluator.py b/applications/Chat/coati/trainer/callbacks/performance_evaluator.py index 9b44dafa7eaa..c2eda92cc165 100644 --- a/applications/Chat/coati/trainer/callbacks/performance_evaluator.py +++ b/applications/Chat/coati/trainer/callbacks/performance_evaluator.py @@ -21,9 +21,9 @@ def print_rank_0(*args, **kwargs) -> None: def divide(x: float, y: float) -> float: if y == 0: - return float('inf') - elif y == float('inf'): - return float('nan') + return float("inf") + elif y == float("inf"): + return float("nan") return x / y @@ -38,10 +38,9 @@ def all_reduce_mean(x: float, world_size: int) -> float: class Timer: - def __init__(self) -> None: self.start_time: Optional[float] = None - self.duration: float = 0. + self.duration: float = 0.0 def start(self) -> None: self.start_time = time() @@ -52,7 +51,7 @@ def end(self) -> None: self.start_time = None def reset(self) -> None: - self.duration = 0. + self.duration = 0.0 class PerformanceEvaluator(Callback): @@ -67,13 +66,15 @@ class PerformanceEvaluator(Callback): ignore_episodes: The number of episodes to ignore when calculating the performance. 
""" - def __init__(self, - actor_num_params: int, - critic_num_params: int, - initial_model_num_params: int, - reward_model_num_params: int, - enable_grad_checkpoint: bool = False, - ignore_episodes: int = 0) -> None: + def __init__( + self, + actor_num_params: int, + critic_num_params: int, + initial_model_num_params: int, + reward_model_num_params: int, + enable_grad_checkpoint: bool = False, + ignore_episodes: int = 0, + ) -> None: super().__init__() self.world_size = get_world_size() self.actor_num_params = actor_num_params @@ -155,8 +156,9 @@ def on_fit_end(self) -> None: avg_learn_duration = all_reduce_mean(self.learn_timer.duration, self.world_size) avg_overall_duration = all_reduce_mean(self.overall_timer.duration, self.world_size) - avg_make_experience_throughput = self.make_experience_num_samples * \ - self.world_size / (avg_make_experience_duration + 1e-12) + avg_make_experience_throughput = ( + self.make_experience_num_samples * self.world_size / (avg_make_experience_duration + 1e-12) + ) avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12) avg_learn_throughput = self.learn_num_samples * self.world_size / (avg_learn_duration + 1e-12) @@ -171,13 +173,11 @@ def on_fit_end(self) -> None: learn_time_per_sample = divide(avg_learn_duration, num_effective_samples) print_rank_0( - f'Performance summary:\n' - + f'Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n' - - + f'Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n' - + f'Overall throughput: {avg_overall_throughput:.2f} samples/s\n' - + f'Overall time per sample: {overall_time_per_sample:.2f} s\n' - + f'Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n' - - + f'Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%' + f"Performance summary:\n" + + f"Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n" + + f"Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n" + + f"Overall throughput: {avg_overall_throughput:.2f} samples/s\n" + + f"Overall time per sample: {overall_time_per_sample:.2f} s\n" + + f"Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n" + + f"Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%" ) diff --git a/applications/Chat/coati/trainer/callbacks/save_checkpoint.py b/applications/Chat/coati/trainer/callbacks/save_checkpoint.py index f0d77a191a88..0d70b6c53073 100644 --- a/applications/Chat/coati/trainer/callbacks/save_checkpoint.py +++ b/applications/Chat/coati/trainer/callbacks/save_checkpoint.py @@ -36,34 +36,35 @@ class SaveCheckpoint(Callback): """ - def __init__(self, - path: str, - interval: int, - strategy: Strategy, - actor: nn.Module = None, - critic: nn.Module = None, - actor_optim: Optimizer = None, - critic_optim: Optimizer = None) -> None: + def __init__( + self, + path: str, + interval: int, + strategy: 
Strategy, + actor: nn.Module = None, + critic: nn.Module = None, + actor_optim: Optimizer = None, + critic_optim: Optimizer = None, + ) -> None: super().__init__() - self.path = os.path.join(path, 'checkpoint') + self.path = os.path.join(path, "checkpoint") self.interval = interval self.strategy = strategy - self.model_dict = {'actor': [actor, actor_optim], 'critic': [critic, critic_optim]} + self.model_dict = {"actor": [actor, actor_optim], "critic": [critic, critic_optim]} def on_episode_end(self, episode: int) -> None: if (episode + 1) % self.interval != 0: return - base_path = os.path.join(self.path, f'episode_{episode}') + base_path = os.path.join(self.path, f"episode_{episode}") if not os.path.exists(base_path): os.makedirs(base_path) for model in self.model_dict.keys(): - # save model if self.model_dict[model][0] is None: # saving only optimizer states is meaningless, so it would be skipped continue - model_path = os.path.join(base_path, f'{model}.pt') + model_path = os.path.join(base_path, f"{model}.pt") self.strategy.save_model(model=self.model_dict[model][0], path=model_path, only_rank0=True) # save optimizer @@ -71,5 +72,5 @@ def on_episode_end(self, episode: int) -> None: continue only_rank0 = not isinstance(self.strategy, (LowLevelZeroStrategy, GeminiStrategy)) rank = 0 if is_rank_0() else dist.get_rank() - optim_path = os.path.join(base_path, f'{model}-optim-rank-{rank}.pt') + optim_path = os.path.join(base_path, f"{model}-optim-rank-{rank}.pt") self.strategy.save_optimizer(optimizer=self.model_dict[model][1], path=optim_path, only_rank0=only_rank0) diff --git a/applications/Chat/coati/trainer/ppo.py b/applications/Chat/coati/trainer/ppo.py index ef625a1c1b3d..6f255a935d91 100644 --- a/applications/Chat/coati/trainer/ppo.py +++ b/applications/Chat/coati/trainer/ppo.py @@ -8,7 +8,7 @@ from coati.models.utils import calc_action_log_probs from torch import Tensor from torch.optim import Optimizer -from torch.utils.data import DataLoader, DistributedSampler +from torch.utils.data import DistributedSampler from tqdm import tqdm from colossalai.utils import get_current_device @@ -24,11 +24,11 @@ def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, acto hf_model = get_base_model(unwrapper_model) new_kwargs = {**generate_kwargs} # use huggingface models method directly - if 'prepare_inputs_fn' not in generate_kwargs and hasattr(hf_model, 'prepare_inputs_for_generation'): - new_kwargs['prepare_inputs_fn'] = hf_model.prepare_inputs_for_generation + if "prepare_inputs_fn" not in generate_kwargs and hasattr(hf_model, "prepare_inputs_for_generation"): + new_kwargs["prepare_inputs_fn"] = hf_model.prepare_inputs_for_generation - if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(hf_model, '_update_model_kwargs_for_generation'): - new_kwargs['update_model_kwargs_fn'] = hf_model._update_model_kwargs_for_generation + if "update_model_kwargs_fn" not in generate_kwargs and hasattr(hf_model, "_update_model_kwargs_for_generation"): + new_kwargs["update_model_kwargs_fn"] = hf_model._update_model_kwargs_for_generation return new_kwargs @@ -60,38 +60,34 @@ class PPOTrainer(OnPolicyTrainer): generate_kwargs (dict, optional): the kwargs to use while model generating """ - def __init__(self, - strategy: Strategy, - actor: Actor, - critic: Critic, - reward_model: nn.Module, - initial_model: Actor, - actor_optim: Optimizer, - critic_optim: Optimizer, - kl_coef: float = 0.1, - ptx_coef: float = 0.9, - train_batch_size: int = 8, - buffer_limit: int = 0, - 
buffer_cpu_offload: bool = True, - eps_clip: float = 0.2, - vf_coef: float = 1.0, - value_clip: float = 0.4, - sample_buffer: bool = False, - dataloader_pin_memory: bool = True, - offload_inference_models: bool = True, - callbacks: List[Callback] = [], - **generate_kwargs - ) -> None: + def __init__( + self, + strategy: Strategy, + actor: Actor, + critic: Critic, + reward_model: nn.Module, + initial_model: Actor, + actor_optim: Optimizer, + critic_optim: Optimizer, + kl_coef: float = 0.1, + ptx_coef: float = 0.9, + train_batch_size: int = 8, + buffer_limit: int = 0, + buffer_cpu_offload: bool = True, + eps_clip: float = 0.2, + vf_coef: float = 1.0, + value_clip: float = 0.4, + sample_buffer: bool = False, + dataloader_pin_memory: bool = True, + offload_inference_models: bool = True, + callbacks: List[Callback] = [], + **generate_kwargs, + ) -> None: if isinstance(strategy, GeminiStrategy): - assert not offload_inference_models, \ - "GeminiPlugin is not compatible with manual model.to('cpu')" + assert not offload_inference_models, "GeminiPlugin is not compatible with manual model.to('cpu')" data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload) - super().__init__( - strategy, data_buffer, - sample_buffer, dataloader_pin_memory, - callbacks - ) + super().__init__(strategy, data_buffer, sample_buffer, dataloader_pin_memory, callbacks) self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor) self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef) @@ -130,18 +126,16 @@ def _training_step(self, experience: Experience) -> Dict[str, float]: num_actions = experience.action_mask.size(1) actor_output = self.actor(experience.sequences, attention_mask=experience.attention_mask) action_log_probs = calc_action_log_probs(actor_output, experience.sequences, num_actions) - actor_loss = self.actor_loss_fn(action_log_probs, - experience.action_log_probs, - experience.advantages, - action_mask=experience.action_mask) + actor_loss = self.actor_loss_fn( + action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask + ) # ptx loss if self.ptx_coef != 0: batch = self.pretrain_dataloader.next() batch = to_device(batch, self.device) - ptx_log_probs = self.actor(batch['input_ids'], - attention_mask=batch['attention_mask'])['logits'] - ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch['labels']) + ptx_log_probs = self.actor(batch["input_ids"], attention_mask=batch["attention_mask"])["logits"] + ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch["labels"]) actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef) self.strategy.backward(actor_loss, self.actor, self.actor_optim) @@ -149,24 +143,23 @@ def _training_step(self, experience: Experience) -> Dict[str, float]: self.actor_optim.zero_grad() # value loss - values = self.critic(experience.sequences, - action_mask=experience.action_mask, - attention_mask=experience.attention_mask) - critic_loss = self.critic_loss_fn(values, - experience.values, - experience.reward, - action_mask=experience.action_mask) + values = self.critic( + experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask + ) + critic_loss = self.critic_loss_fn( + values, experience.values, experience.reward, action_mask=experience.action_mask + ) critic_loss = critic_loss * self.vf_coef self.strategy.backward(critic_loss, self.critic, self.critic_optim) self.strategy.optimizer_step(self.critic_optim) 
self.critic_optim.zero_grad() - return {'reward': experience.reward.mean().item()} + return {"reward": experience.reward.mean().item()} def _learn(self, update_step: int): if self.offload_inference_models: - self.experience_maker.initial_model.to('cpu') - self.experience_maker.reward_model.to('cpu') + self.experience_maker.initial_model.to("cpu") + self.experience_maker.reward_model.to("cpu") # buffer may be empty at first, we should rebuild at each training if self.sample_buffer: @@ -178,11 +171,7 @@ def _learn(self, update_step: int): else: if isinstance(self.dataloader.sampler, DistributedSampler): self.dataloader.sampler.set_epoch(update_step) - pbar = tqdm( - self.dataloader, - desc=f'Train epoch [{update_step + 1}]', - disable=not is_rank_0() - ) + pbar = tqdm(self.dataloader, desc=f"Train epoch [{update_step + 1}]", disable=not is_rank_0()) for experience in pbar: self._on_learn_batch_start() experience.to_device(self.device) diff --git a/applications/Chat/coati/trainer/rm.py b/applications/Chat/coati/trainer/rm.py index 54a5d0f40dea..a5d6974b3238 100644 --- a/applications/Chat/coati/trainer/rm.py +++ b/applications/Chat/coati/trainer/rm.py @@ -62,18 +62,15 @@ def _eval(self, epoch): if is_rank_0(): log = pd.DataFrame( - [[(epoch + 1) * len(self.train_dataloader), - self.loss.item(), self.dist, self.acc]], - columns=['step', 'loss', 'dist', 'acc'] + [[(epoch + 1) * len(self.train_dataloader), self.loss.item(), self.dist, self.acc]], + columns=["step", "loss", "dist", "acc"], ) - log.to_csv('log.csv', mode='a', header=False, index=False) + log.to_csv("log.csv", mode="a", header=False, index=False) def _train(self, epoch): self.model.train() step_bar = tqdm.trange( - len(self.train_dataloader), - desc='Train step of epoch %d' % epoch, - disable=not is_rank_0() + len(self.train_dataloader), desc="Train step of epoch %d" % epoch, disable=not is_rank_0() ) cnt = 0 for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader: @@ -93,10 +90,7 @@ def _train(self, epoch): step_bar.update() step_bar.close() - def _before_fit(self, - train_dataloader: DataLoader, - valid_dataloader: DataLoader, - eval_dataloader: DataLoader): + def _before_fit(self, train_dataloader: DataLoader, valid_dataloader: DataLoader, eval_dataloader: DataLoader): """ Args: train_dataloader (DataLoader): the dataloader to use for training @@ -104,7 +98,7 @@ def _before_fit(self, eval_dataloader (DataLoader): the dataloader to use for evaluation """ super()._before_fit() - self.datetime = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + self.datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") self.train_dataloader = train_dataloader self.valid_dataloader = valid_dataloader diff --git a/applications/Chat/coati/trainer/sft.py b/applications/Chat/coati/trainer/sft.py index e4d0a970740d..8deefc2c484e 100644 --- a/applications/Chat/coati/trainer/sft.py +++ b/applications/Chat/coati/trainer/sft.py @@ -39,8 +39,9 @@ def __init__( accumulation_steps: int = 8, ) -> None: if accumulation_steps > 1: - assert not isinstance(strategy, GeminiStrategy), \ - "Accumulation steps are not supported in stage 3 of ColossalAI" + assert not isinstance( + strategy, GeminiStrategy + ), "Accumulation steps are not supported in stage 3 of ColossalAI" super().__init__(strategy, max_epochs, model, optim) @@ -50,15 +51,11 @@ def __init__( def _train(self, epoch: int): self.model.train() for batch_id, batch in enumerate(self.train_dataloader): - batch = to_device(batch, torch.cuda.current_device()) if "attention_mask" in batch: - outputs = 
self.model(batch["input_ids"], - attention_mask=batch["attention_mask"], - labels=batch["labels"]) + outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]) else: - outputs = self.model(batch["input_ids"], - labels=batch["labels"]) + outputs = self.model(batch["input_ids"], labels=batch["labels"]) loss = outputs.loss loss = loss / self.accumulation_steps @@ -73,12 +70,14 @@ def _train(self, epoch: int): self.optimizer.zero_grad() self.scheduler.step() if is_rank_0() and self.use_wandb: - wandb.log({ - "loss": self.total_loss / self.accumulation_steps, - "lr": self.scheduler.get_last_lr()[0], - "epoch": epoch, - "batch_id": batch_id - }) + wandb.log( + { + "loss": self.total_loss / self.accumulation_steps, + "lr": self.scheduler.get_last_lr()[0], + "epoch": epoch, + "batch_id": batch_id, + } + ) self.total_loss = 0 self.step_bar.update() @@ -89,9 +88,9 @@ def _eval(self, epoch: int): loss_sum, num_seen = 0, 0 for batch in self.eval_dataloader: batch = to_device(batch, torch.cuda.current_device()) - outputs = self.model(batch["input_ids"], - attention_mask=batch["attention_mask"], - labels=batch["labels"]) + outputs = self.model( + batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"] + ) loss = outputs.loss loss_sum += loss.item() @@ -99,13 +98,15 @@ def _eval(self, epoch: int): loss_mean = loss_sum / num_seen if dist.get_rank() == 0: - self.logger.info(f'Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}') + self.logger.info(f"Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}") - def _before_fit(self, - train_dataloader: DataLoader, - eval_dataloader: Optional[DataLoader] = None, - logger: Optional[DistributedLogger] = None, - use_wandb: bool = False): + def _before_fit( + self, + train_dataloader: DataLoader, + eval_dataloader: Optional[DataLoader] = None, + logger: Optional[DistributedLogger] = None, + use_wandb: bool = False, + ): """ Args: train_dataloader: the dataloader to use for training @@ -124,6 +125,6 @@ def _before_fit(self, self.no_epoch_bar = True self.step_bar = tqdm.trange( len(self.train_dataloader) // self.accumulation_steps * self.max_epochs, - desc=f'steps', - disable=not is_rank_0() + desc=f"steps", + disable=not is_rank_0(), ) diff --git a/applications/Chat/coati/trainer/strategies/__init__.py b/applications/Chat/coati/trainer/strategies/__init__.py index b49a2c742db3..521dcb5855b1 100644 --- a/applications/Chat/coati/trainer/strategies/__init__.py +++ b/applications/Chat/coati/trainer/strategies/__init__.py @@ -2,7 +2,4 @@ from .colossalai import GeminiStrategy, LowLevelZeroStrategy from .ddp import DDPStrategy -__all__ = [ - 'Strategy', 'DDPStrategy', - 'LowLevelZeroStrategy', 'GeminiStrategy' -] +__all__ = ["Strategy", "DDPStrategy", "LowLevelZeroStrategy", "GeminiStrategy"] diff --git a/applications/Chat/coati/trainer/strategies/base.py b/applications/Chat/coati/trainer/strategies/base.py index c20b2b16e396..303d4bc220a6 100644 --- a/applications/Chat/coati/trainer/strategies/base.py +++ b/applications/Chat/coati/trainer/strategies/base.py @@ -19,7 +19,7 @@ class Strategy(ABC): """ - Base class for training strategies. + Base class for training strategies. 
""" def __init__(self, plugin_initializer: Callable[..., Optional[Plugin]] = lambda: None) -> None: @@ -83,16 +83,18 @@ def prepare(self, *boost_args: _BoostArgSpec) -> Union[List[_BoostArgSpec], _Boo rets.append((model, optimizer)) elif isinstance(arg, Dict): model, optimizer, criterion, dataloader, lr_scheduler = self.booster.boost(**arg) - boost_result = dict(model=model, - optimizer=optimizer, - criterion=criterion, - dataloader=dataloader, - lr_scheduler=lr_scheduler) + boost_result = dict( + model=model, + optimizer=optimizer, + criterion=criterion, + dataloader=dataloader, + lr_scheduler=lr_scheduler, + ) # remove None values boost_result = {key: value for key, value in boost_result.items() if value is not None} rets.append(boost_result) else: - raise RuntimeError(f'Type {type(arg)} is not supported') + raise RuntimeError(f"Type {type(arg)} is not supported") return rets[0] if len(rets) == 1 else rets @@ -125,11 +127,9 @@ def setup_sampler(self, dataset) -> DistributedSampler: return DistributedSampler(dataset, 1, 0) @abstractmethod - def save_pretrained(self, - model: nn.Module, - path: str, - only_rank0: bool = True, - tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None: + def save_pretrained( + self, model: nn.Module, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None + ) -> None: pass @abstractmethod diff --git a/applications/Chat/coati/trainer/strategies/colossalai.py b/applications/Chat/coati/trainer/strategies/colossalai.py index fa55f97ad661..4706f9699c91 100644 --- a/applications/Chat/coati/trainer/strategies/colossalai.py +++ b/applications/Chat/coati/trainer/strategies/colossalai.py @@ -42,27 +42,27 @@ class LowLevelZeroStrategy(DDPStrategy): """ - def __init__(self, - stage: int = 2, - precision: str = 'fp16', - seed: int = 42, - placement_policy: str = 'cuda', - reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2 - overlap_communication: bool = True, # only for stage 1&2 - initial_scale: float = 2**16, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - min_scale: float = 1, - max_scale: float = 2**32, - max_norm: float = 0.0, - norm_type: float = 2.0 - ) -> None: - + def __init__( + self, + stage: int = 2, + precision: str = "fp16", + seed: int = 42, + placement_policy: str = "cuda", + reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2 + overlap_communication: bool = True, # only for stage 1&2 + initial_scale: float = 2**16, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + min_scale: float = 1, + max_scale: float = 2**32, + max_norm: float = 0.0, + norm_type: float = 2.0, + ) -> None: assert stage in (1, 2), f'Unsupported stage "{stage}"' - assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"' - assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"' + assert placement_policy in ("cpu", "cuda"), f'Unsupported placement policy "{placement_policy}"' + assert precision in ("fp32", "fp16"), f'Unsupported precision "{precision}"' plugin_initializer = lambda: LowLevelZeroPlugin( # zero_config @@ -71,7 +71,7 @@ def __init__(self, # zero_optim_config reduce_bucket_size_in_m=reduce_bucket_size, overlap_communication=overlap_communication, - cpu_offload=(placement_policy == 'cpu'), + cpu_offload=(placement_policy == "cpu"), # optim_config initial_scale=initial_scale, growth_factor=growth_factor, @@ -81,14 +81,15 @@ def 
__init__(self, min_scale=min_scale, max_scale=max_scale, max_norm=max_norm, - norm_type=norm_type + norm_type=norm_type, ) super().__init__(seed, plugin_initializer) def _post_init(self) -> None: - assert isinstance(self.plugin, LowLevelZeroPlugin), \ - f'{type(self).__name__}\'s plugin is not initialized properly.' + assert isinstance( + self.plugin, LowLevelZeroPlugin + ), f"{type(self).__name__}'s plugin is not initialized properly." def setup_distributed(self) -> None: colossalai.launch_from_torch({}, seed=self.seed) @@ -131,45 +132,45 @@ class GeminiStrategy(DDPStrategy): """ - def __init__(self, - seed: int = 42, - shard_init: bool = False, # only for stage 3 - placement_policy: str = 'cuda', - pin_memory: bool = True, # only for stage 3 - force_outputs_fp32: bool = False, # only for stage 3 - search_range_m: int = 32, # only for stage 3 - hidden_dim: Optional[int] = None, # only for stage 3 - min_chunk_size_m: float = 32, # only for stage 3 - gpu_margin_mem_ratio: float = 0.0, # only for stage 3 - initial_scale: float = 2**16, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - min_scale: float = 1, - max_scale: float = 2**32, - max_norm: float = 0.0, - norm_type: float = 2.0 - ) -> None: - - assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"' + def __init__( + self, + seed: int = 42, + shard_init: bool = False, # only for stage 3 + placement_policy: str = "cuda", + pin_memory: bool = True, # only for stage 3 + force_outputs_fp32: bool = False, # only for stage 3 + search_range_m: int = 32, # only for stage 3 + hidden_dim: Optional[int] = None, # only for stage 3 + min_chunk_size_m: float = 32, # only for stage 3 + gpu_margin_mem_ratio: float = 0.0, # only for stage 3 + initial_scale: float = 2**16, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + min_scale: float = 1, + max_scale: float = 2**32, + max_norm: float = 0.0, + norm_type: float = 2.0, + ) -> None: + assert placement_policy in ("cpu", "cuda"), f'Unsupported placement policy "{placement_policy}"' # TODO(ver217): support shard_init when using from_pretrained() if shard_init: warnings.warn( - f'Shard init is not supported model.from_pretrained() yet. ' - 'Please load weights after strategy.prepare()' + f"Shard init is not supported model.from_pretrained() yet. " + "Please load weights after strategy.prepare()" ) self.shard_init = shard_init - warnings.warn(f'Stage 3 only supports fp16. Precision is set to fp16.') + warnings.warn(f"Stage 3 only supports fp16. Precision is set to fp16.") # NOTE: dist should be initialized before calling get_current_device() plugin_initializer = lambda: GeminiPlugin( # gemini_config device=get_current_device(), placement_policy=placement_policy, - precision='fp16', + precision="fp16", pin_memory=pin_memory, force_outputs_fp32=force_outputs_fp32, strict_ddp_mode=shard_init, @@ -187,14 +188,13 @@ def __init__(self, min_scale=min_scale, max_scale=max_scale, max_norm=max_norm, - norm_type=norm_type + norm_type=norm_type, ) super().__init__(seed, plugin_initializer) def _post_init(self) -> None: - assert isinstance(self.plugin, GeminiPlugin), \ - f'{type(self).__name__}\'s plugin is not initialized properly.' + assert isinstance(self.plugin, GeminiPlugin), f"{type(self).__name__}'s plugin is not initialized properly." 
def setup_distributed(self) -> None: colossalai.launch_from_torch({}, seed=self.seed) @@ -203,10 +203,9 @@ def model_init_context(self): world_size = dist.get_world_size() shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None - return ColoInitContext(device=get_current_device(), - dtype=torch.half, - default_pg=shard_pg, - default_dist_spec=default_dist_spec) + return ColoInitContext( + device=get_current_device(), dtype=torch.half, default_pg=shard_pg, default_dist_spec=default_dist_spec + ) def unwrap_model(self, model: nn.Module) -> nn.Module: assert isinstance(model, GeminiModel) diff --git a/applications/Chat/coati/trainer/strategies/ddp.py b/applications/Chat/coati/trainer/strategies/ddp.py index a52b0460daa8..66ff6703da4d 100644 --- a/applications/Chat/coati/trainer/strategies/ddp.py +++ b/applications/Chat/coati/trainer/strategies/ddp.py @@ -31,24 +31,21 @@ def get_grad_required_state_dict(model: nn.Module): class DDPStrategy(Strategy): """ - Strategy for distributed training using torch.distributed. + Strategy for distributed training using torch.distributed. """ - def __init__(self, - seed: int = 42, - plugin_initializer: Callable = TorchDDPPlugin - ) -> None: + def __init__(self, seed: int = 42, plugin_initializer: Callable = TorchDDPPlugin) -> None: self.seed = seed super().__init__(plugin_initializer) def _try_init_dist(self, force: bool = False) -> None: try: - rank = int(os.environ['RANK']) - local_rank = int(os.environ['LOCAL_RANK']) - world_size = int(os.environ['WORLD_SIZE']) - host = os.environ['MASTER_ADDR'] - port = int(os.environ['MASTER_PORT']) - dist.init_process_group('nccl', init_method=f'tcp://[{host}]:{port}', world_size=world_size, rank=rank) + rank = int(os.environ["RANK"]) + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + host = os.environ["MASTER_ADDR"] + port = int(os.environ["MASTER_PORT"]) + dist.init_process_group("nccl", init_method=f"tcp://[{host}]:{port}", world_size=world_size, rank=rank) torch.cuda.set_device(local_rank) except KeyError as e: if force: @@ -60,8 +57,7 @@ def _try_init_dist(self, force: bool = False) -> None: raise e def _post_init(self) -> None: - assert isinstance(self.plugin, TorchDDPPlugin), \ - f'{type(self).__name__}\'s plugin is not initialized properly.' + assert isinstance(self.plugin, TorchDDPPlugin), f"{type(self).__name__}'s plugin is not initialized properly." def setup_distributed(self) -> None: self._try_init_dist(force=True) @@ -73,12 +69,14 @@ def set_seed(self, seed: int) -> None: torch.manual_seed(seed) def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader: - return self.plugin.prepare_dataloader(data_buffer, - batch_size=data_buffer.sample_batch_size, - shuffle=True, - drop_last=True, - pin_memory=pin_memory, - collate_fn=data_buffer.collate_fn) + return self.plugin.prepare_dataloader( + data_buffer, + batch_size=data_buffer.sample_batch_size, + shuffle=True, + drop_last=True, + pin_memory=pin_memory, + collate_fn=data_buffer.collate_fn, + ) def setup_sampler(self, dataset) -> DistributedSampler: # FIXME(cwher): this is only invoked in train_on_ray, not tested after adapt Boost API. @@ -88,11 +86,9 @@ def unwrap_model(self, model: nn.Module) -> nn.Module: assert isinstance(model, TorchDDPModel), "model is not wrapped by TorchDDPModel." 
return model.unwrap() - def save_pretrained(self, - model: nn.Module, - path: str, - only_rank0: bool = True, - tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None: + def save_pretrained( + self, model: nn.Module, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None + ) -> None: if not only_rank0 or dist.get_rank() == 0: unwrapped_model = self.unwrap_model(model) assert isinstance(unwrapped_model, (Actor, Critic, RewardModel)) @@ -103,17 +99,11 @@ def save_pretrained(self, if tokenizer is not None: tokenizer.save_pretrained(path) model_path = os.path.join(path, "pytorch_model.bin") - self.save_model(model, - model_path, - only_rank0=only_rank0) + self.save_model(model, model_path, only_rank0=only_rank0) - def _replace_keys(model_path: str, - replace_fn: Callable): + def _replace_keys(model_path: str, replace_fn: Callable): state_dict = torch.load(model_path, map_location="cpu") - state_dict = { - replace_fn(k): v - for k, v in state_dict.items() - } + state_dict = {replace_fn(k): v for k, v in state_dict.items()} torch.save(state_dict, model_path) # FIXME: save_model would add "model." prefix to keys of pytorch_model.bin @@ -124,13 +114,13 @@ def _replace_keys(model_path: str, def get_model_state_dict_shard(self, model: nn.Module, **config): # TODO: implement sharding on naive strategy model = self.unwrap_model(model) - if 'requires_grad_only' in config and config['requires_grad_only'] == True: + if "requires_grad_only" in config and config["requires_grad_only"] == True: state_dict = get_grad_required_state_dict(model) else: state_dict = model.state_dict() - if 'shard_size' in config: - shard_size = config['shard_size'] + if "shard_size" in config: + shard_size = config["shard_size"] accumulate_size = 0 state_dict_shard = OrderedDict() for name, param in state_dict.items(): diff --git a/applications/Chat/coati/trainer/strategies/sampler.py b/applications/Chat/coati/trainer/strategies/sampler.py index d726fa640fa2..6e811bef11a5 100644 --- a/applications/Chat/coati/trainer/strategies/sampler.py +++ b/applications/Chat/coati/trainer/strategies/sampler.py @@ -4,7 +4,6 @@ class DistributedSampler: - def __init__(self, dataset, num_replicas: int, rank: int) -> None: self.dataset = dataset self.num_replicas = num_replicas @@ -12,7 +11,7 @@ def __init__(self, dataset, num_replicas: int, rank: int) -> None: if len(self.dataset) % self.num_replicas != 0: self.num_samples = math.ceil( - (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] + (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] ) else: self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) @@ -20,10 +19,10 @@ def __init__(self, dataset, num_replicas: int, rank: int) -> None: self.total_size = self.num_samples * self.num_replicas indices = list(range(len(self.dataset))) - indices = indices[:self.total_size] + indices = indices[: self.total_size] assert len(indices) == self.total_size # subsample - indices = indices[self.rank:self.total_size:self.num_replicas] + indices = indices[self.rank : self.total_size : self.num_replicas] assert len(indices) == self.num_samples self.indices = indices diff --git a/applications/Chat/coati/trainer/utils.py b/applications/Chat/coati/trainer/utils.py index 7e2cb9c634f7..7811e7365eeb 100644 --- a/applications/Chat/coati/trainer/utils.py +++ b/applications/Chat/coati/trainer/utils.py @@ -42,7 +42,6 @@ def is_rank_0() -> bool: def to_device(x: Any, device: torch.device) -> Any: - def 
_to(t: Any): if isinstance(t, torch.Tensor): return t.to(device) diff --git a/applications/Chat/evaluate/config/config_cn.json b/applications/Chat/evaluate/config/config_cn.json index 023f16bef31c..4d30d005df30 100644 --- a/applications/Chat/evaluate/config/config_cn.json +++ b/applications/Chat/evaluate/config/config_cn.json @@ -70,7 +70,7 @@ "BLEU", "ROUGE", "BERTScore" - ] + ] }, "logical_reasoning": { "GPT": [ @@ -83,7 +83,7 @@ "ROUGE", "BERTScore", "CHRF" - ] + ] }, "open_qa": { "GPT": [ @@ -126,7 +126,7 @@ "conciseness" ], "Metrics": [ - ] + ] }, "Finance": { "GPT": [ @@ -134,7 +134,7 @@ "correctness" ], "Metrics": [ - ] + ] }, "Law": { "GPT": [ @@ -142,7 +142,7 @@ "correctness" ], "Metrics": [ - ] + ] }, "Education": { "GPT": [ @@ -150,7 +150,7 @@ "correctness" ], "Metrics": [ - ] + ] }, "Medical": { "GPT": [ @@ -158,7 +158,7 @@ "correctness" ], "Metrics": [ - ] + ] }, "STEM": { "GPT": [ @@ -166,7 +166,7 @@ "correctness" ], "Metrics": [ - ] + ] }, "SocialScience": { "GPT": [ @@ -174,7 +174,7 @@ "correctness" ], "Metrics": [ - ] + ] }, "Humanity": { "GPT": [ @@ -182,7 +182,7 @@ "correctness" ], "Metrics": [ - ] + ] }, "Other": { "GPT": [ @@ -190,7 +190,7 @@ "correctness" ], "Metrics": [ - ] + ] }, "ethics": { "GPT": [ @@ -198,7 +198,7 @@ "correctness" ], "Metrics": [ - ] + ] } } } diff --git a/applications/Chat/evaluate/eval.py b/applications/Chat/evaluate/eval.py index e3fe0e9e091b..16ef31a94175 100644 --- a/applications/Chat/evaluate/eval.py +++ b/applications/Chat/evaluate/eval.py @@ -1,5 +1,4 @@ import argparse -import json import os import openai @@ -9,7 +8,8 @@ def main(args): assert len(args.answer_file_list) == len( - args.model_name_list), "The number of answer files and model names should be equal!" + args.model_name_list + ), "The number of answer files and model names should be equal!" # load config config = jload(args.config_file) @@ -36,7 +36,8 @@ def main(args): if len(args.model_name_list) == 1 and not gpt_evaluation_prompt: raise Exception( - "No prompt file for gpt evaluation provided. Please specify the prompt file for gpt evaluation!") + "No prompt file for gpt evaluation provided. Please specify the prompt file for gpt evaluation!" 
+ ) if args.gpt_model == "text-davinci-003" and args.gpt_with_reference: raise Exception( @@ -44,8 +45,15 @@ def main(args): ) # initialize evaluator - evaluator = Evaluator(metrics_per_category, battle_prompt, gpt_evaluation_prompt, args.gpt_model, - config["language"], config.get("path_for_UniEval", None), args.gpt_with_reference) + evaluator = Evaluator( + metrics_per_category, + battle_prompt, + gpt_evaluation_prompt, + args.gpt_model, + config["language"], + config.get("path_for_UniEval", None), + args.gpt_with_reference, + ) if len(args.model_name_list) == 2: answers1 = jload(args.answer_file_list[0]) answers2 = jload(args.answer_file_list[1]) @@ -68,41 +76,41 @@ def main(args): raise ValueError(f'Unsupported language {config["language"]}!') -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='ColossalAI LLM evaluation pipeline.') - parser.add_argument('--config_file', - type=str, - default=None, - required=True, - help='path to the file of target results') - parser.add_argument('--battle_prompt_file', type=str, default=None, help='path to the prompt file for battle') - parser.add_argument('--gpt_evaluation_prompt_file', - type=str, - default=None, - help='path to the prompt file for gpt evaluation') - parser.add_argument('--target_file', type=str, default=None, help='path to the target answer (ground truth) file') - parser.add_argument('--answer_file_list', - type=str, - nargs='+', - default=[], - required=True, - help='path to the answer files of at most 2 models') - parser.add_argument('--model_name_list', - type=str, - nargs='+', - default=[], - required=True, - help='the names of at most 2 models') - parser.add_argument('--gpt_model', - default="gpt-3.5-turbo", - choices=["text-davinci-003", "gpt-3.5-turbo", "gpt-4"], - help='which GPT model to use for evaluation') - parser.add_argument('--gpt_with_reference', - default=False, - action="store_true", - help='whether to include reference answer in gpt evaluation') - parser.add_argument('--save_path', type=str, default="results", help='path to save evaluation results') - parser.add_argument('--openai_key', type=str, default=None, required=True, help='Your openai key') +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="ColossalAI LLM evaluation pipeline.") + parser.add_argument( + "--config_file", type=str, default=None, required=True, help="path to the file of target results" + ) + parser.add_argument("--battle_prompt_file", type=str, default=None, help="path to the prompt file for battle") + parser.add_argument( + "--gpt_evaluation_prompt_file", type=str, default=None, help="path to the prompt file for gpt evaluation" + ) + parser.add_argument("--target_file", type=str, default=None, help="path to the target answer (ground truth) file") + parser.add_argument( + "--answer_file_list", + type=str, + nargs="+", + default=[], + required=True, + help="path to the answer files of at most 2 models", + ) + parser.add_argument( + "--model_name_list", type=str, nargs="+", default=[], required=True, help="the names of at most 2 models" + ) + parser.add_argument( + "--gpt_model", + default="gpt-3.5-turbo", + choices=["text-davinci-003", "gpt-3.5-turbo", "gpt-4"], + help="which GPT model to use for evaluation", + ) + parser.add_argument( + "--gpt_with_reference", + default=False, + action="store_true", + help="whether to include reference answer in gpt evaluation", + ) + parser.add_argument("--save_path", type=str, default="results", help="path to save evaluation results") + 
parser.add_argument("--openai_key", type=str, default=None, required=True, help="Your openai key") args = parser.parse_args() if args.openai_key is not None: diff --git a/applications/Chat/evaluate/evaluator.py b/applications/Chat/evaluate/evaluator.py index 3dd5fd6f2f23..1d998cd2d09c 100644 --- a/applications/Chat/evaluate/evaluator.py +++ b/applications/Chat/evaluate/evaluator.py @@ -3,20 +3,27 @@ import gpt_evaluate import metrics -import pandas as pd import unieval from utils import analyze_automatic_results, get_data_per_category, save_automatic_results class Evaluator(object): """ - A class named Evaluator includes GPT-3.5/GPT-4 evaluation - and automatic evaluation + A class named Evaluator includes GPT-3.5/GPT-4 evaluation + and automatic evaluation """ - def __init__(self, params: Dict[str, Any], battle_prompt: Dict[str, Any], gpt_evaluation_prompt: Dict[str, Any], - gpt_model: str, language: str, path_for_UniEval: Dict[str, str], gpt_with_reference: bool) -> None: + def __init__( + self, + params: Dict[str, Any], + battle_prompt: Dict[str, Any], + gpt_evaluation_prompt: Dict[str, Any], + gpt_model: str, + language: str, + path_for_UniEval: Dict[str, str], + gpt_with_reference: bool, + ) -> None: self.params = params self.battle_prompt = battle_prompt self.gpt_evaluation_prompt = gpt_evaluation_prompt @@ -103,7 +110,8 @@ def switch(metric, language): if self.params[category]["UniEval"] and self.language == "cn": raise Exception( - "UniEval doesn't support Chinese! Please remove UniEval config in your Chinese config file.") + "UniEval doesn't support Chinese! Please remove UniEval config in your Chinese config file." + ) category_metrics = self.params[category]["UniEval"] @@ -134,10 +142,9 @@ def switch(metric, language): sources_list = [answer["instruction"] + answer["input"] for answer in answers_per_category[category]] data = unieval.convert_data_to_unieval_format(predicts_list, sources_list, targets_list) - scores = uni_evaluator.evaluate(data, - category, - dims=list(self.unieval_metric_stats[task][category].keys()), - overall=False) + scores = uni_evaluator.evaluate( + data, category, dims=list(self.unieval_metric_stats[task][category].keys()), overall=False + ) avg_scores = unieval.calculate_average_score(scores) self.unieval_metric_stats[task][category].update(avg_scores) @@ -165,7 +172,8 @@ def switch(metric, language): category, self.gpt_model, self.language, - references=targets_per_category[category] if self.gpt_with_reference else None) + references=targets_per_category[category] if self.gpt_with_reference else None, + ) def save(self, path: str, model_name_list: List[str]) -> None: """ @@ -204,16 +212,18 @@ def save(self, path: str, model_name_list: List[str]) -> None: gpt_base_save_path = os.path.join(path, "gpt_evaluate", "gpt_evaluate_results") gpt_evaluation_results_save_path = os.path.join(gpt_base_save_path, "evaluation_results") - all_evaluations = gpt_evaluate.save_gpt_evaluation_results(model_name_list[0], - self.gpt_evaluation_results, - gpt_evaluation_results_save_path) + all_evaluations = gpt_evaluate.save_gpt_evaluation_results( + model_name_list[0], self.gpt_evaluation_results, gpt_evaluation_results_save_path + ) # Start to calculate scores and save statistics. 
gpt_evaluation_statistics_save_path = os.path.join(gpt_base_save_path, "evaluation_statistics") - gpt_evaluate.save_gpt_evaluation_statistics(model_name_list[0], all_evaluations, - gpt_evaluation_statistics_save_path) + gpt_evaluate.save_gpt_evaluation_statistics( + model_name_list[0], all_evaluations, gpt_evaluation_statistics_save_path + ) # Save charts and csv. gpt_evaluation_analyses_save_path = os.path.join(gpt_base_save_path, "evaluation_analyses") - gpt_evaluate.analyze_gpt_evaluation_statistics(gpt_evaluation_statistics_save_path, - gpt_evaluation_analyses_save_path) + gpt_evaluate.analyze_gpt_evaluation_statistics( + gpt_evaluation_statistics_save_path, gpt_evaluation_analyses_save_path + ) diff --git a/applications/Chat/evaluate/gpt_evaluate.py b/applications/Chat/evaluate/gpt_evaluate.py index 6fcbe63d0253..ad908f4ba48c 100644 --- a/applications/Chat/evaluate/gpt_evaluate.py +++ b/applications/Chat/evaluate/gpt_evaluate.py @@ -14,20 +14,18 @@ from utils import jdump, jload ref_step_template = { - "en": - "Now please compare the answer with the {adjective} answer, determine whether the answer is able to achieve the same level of {metric}.\n\n", - "cn": - "请比较答案与上面的{adjective}答案,确定答案是否可以达到与该{adjective}答案同样水平的{metric}。\n\n" + "en": "Now please compare the answer with the {adjective} answer, determine whether the answer is able to achieve the same level of {metric}.\n\n", + "cn": "请比较答案与上面的{adjective}答案,确定答案是否可以达到与该{adjective}答案同样水平的{metric}。\n\n", } ref_answer_template_general = { "en": "\nAn example answer with good quality is as follows:\n\n{answer}\n\n", - "cn": "\n一个优质的示例答案如下:\n\n{answer}\n\n" + "cn": "\n一个优质的示例答案如下:\n\n{answer}\n\n", } ref_answer_template_correctness = { "en": "\nA correct answer is as follows:\n\n{answer}\n\n", - "cn": "\n标准答案如下:\n\n{answer}\n\n" + "cn": "\n标准答案如下:\n\n{answer}\n\n", } @@ -51,10 +49,7 @@ def get_battle_result(sys_prompt: str, user_prompt: str, id: int, max_tokens: in response = openai.ChatCompletion.create( model="gpt-4", messages=[ - { - "role": "system", - "content": sys_prompt - }, + {"role": "system", "content": sys_prompt}, { "role": "user", "content": user_prompt, @@ -106,7 +101,7 @@ def parse_battle_score(evaluation: str) -> List[float]: return [float(sp[0]), float(sp[1])] else: raise Exception(f"Invalid score pair. 
Got {evaluation}.") - except Exception as e: + except Exception: return [-1, -1] @@ -125,9 +120,6 @@ def battle(answer1: List[Dict], answer2: List[Dict], prompt_dict: Dict[str, Any] assert len(answer1) == len(answer2) - handles = [] - evaluation_file = [] - total_len = len(answer1) question_idx_list = list(range(total_len)) @@ -140,9 +132,12 @@ def battle(answer1: List[Dict], answer2: List[Dict], prompt_dict: Dict[str, Any] assert answer1[i]["id"] == answer2[i]["id"] answer_id = answer1[i]["id"] - ques = answer1[i]["instruction"] if answer1[i][ - "input"] == "" else answer1[i]["instruction"] + " " + answer1[i]["input"] - cat = answer1[i]["category"] + ques = ( + answer1[i]["instruction"] + if answer1[i]["input"] == "" + else answer1[i]["instruction"] + " " + answer1[i]["input"] + ) + answer1[i]["category"] ans1 = answer1[i]["output"] ans2 = answer2[i]["output"] @@ -267,7 +262,11 @@ def reference_template(metric: str, language: str, reference: Dict[str, Any]) -> step_to_add = ref_step_template[language] - for_the_given_answer = "{metric} (1-5) (directly give the score for the given answer):" if language == "en" else "{metric} (1-5) (直接对给定答案打分)" + for_the_given_answer = ( + "{metric} (1-5) (directly give the score for the given answer):" + if language == "en" + else "{metric} (1-5) (直接对给定答案打分)" + ) # adjective is used to describe the word "answer" in the prompt. adjective = "example" if language == "en" else "示例" @@ -280,8 +279,9 @@ def reference_template(metric: str, language: str, reference: Dict[str, Any]) -> answer_to_add = ref_answer_template_correctness[language] answer_to_add = answer_to_add.format(answer=reference["target"] if reference["target"] else reference["output"]) - step_to_add = step_to_add.format(metric=metric.lower(), - adjective=adjective) + for_the_given_answer.format(metric=metric) + step_to_add = step_to_add.format(metric=metric.lower(), adjective=adjective) + for_the_given_answer.format( + metric=metric + ) return answer_to_add + step_to_add @@ -329,7 +329,8 @@ def multiturn_chat_completion(user_messages: List[str], model: str, max_tokens: for j in range(i): messages_to_send.append(fill_in_message("user", user_messages[j])) messages_to_send.append( - fill_in_message("assistant", assistant_responses[j]["choices"][0]["message"]["content"])) + fill_in_message("assistant", assistant_responses[j]["choices"][0]["message"]["content"]) + ) # Length of user messages == Length of assistant messages + 1 # Because we always expect the api to response @@ -351,13 +352,15 @@ def multiturn_chat_completion(user_messages: List[str], model: str, max_tokens: return assistant_responses[-1] -def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any], - inst: Dict[str, Any], - metrics: List[str], - language: str, - reference: Dict[str, Any] = None, - model: str = "gpt-3.5-turbo", - max_tokens: int = 2048) -> Dict[str, Any]: +def get_gpt_evaluation_without_logprobs( + prompt: Dict[str, Any], + inst: Dict[str, Any], + metrics: List[str], + language: str, + reference: Dict[str, Any] = None, + model: str = "gpt-3.5-turbo", + max_tokens: int = 2048, +) -> Dict[str, Any]: """ Use chat models(gpt-3.5-turbo or gpt-4) to evaluate one model answer. 
@@ -378,7 +381,7 @@ def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any], MAX_API_RETRY = 3 - question = (inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"]) + question = inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"] answer = inst["output"] inst["evaluation"] = {} @@ -400,10 +403,9 @@ def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any], if prompt_reference: # Do a 2-round conversation - response = multiturn_chat_completion([prompt_1st_round, prompt_reference], - model, - max_tokens=max_tokens, - turns=2) + response = multiturn_chat_completion( + [prompt_1st_round, prompt_reference], model, max_tokens=max_tokens, turns=2 + ) else: response = multiturn_chat_completion([prompt_1st_round], model, max_tokens=max_tokens, turns=1) @@ -427,10 +429,9 @@ def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any], return inst -def get_gpt_evaluation_with_logprobs(prompt: Dict[str, Any], - inst: Dict[str, Any], - metrics: List[str], - max_tokens: int = 2048) -> Dict[str, Any]: +def get_gpt_evaluation_with_logprobs( + prompt: Dict[str, Any], inst: Dict[str, Any], metrics: List[str], max_tokens: int = 2048 +) -> Dict[str, Any]: """ Use completion model(text-davinci-003) to evaluate one model answer. Only completion models can return log probabilities. @@ -449,7 +450,7 @@ def get_gpt_evaluation_with_logprobs(prompt: Dict[str, Any], MAX_API_RETRY = 3 - question = (inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"]) + question = inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"] answer = inst["output"] inst["evaluation"] = {} @@ -492,13 +493,15 @@ def get_gpt_evaluation_with_logprobs(prompt: Dict[str, Any], return inst -def evaluate(answers: List[Dict], - prompt: Dict[str, Any], - metrics: List[str], - category: str, - model: str, - language: str, - references: List[Dict] = None) -> List[Dict]: +def evaluate( + answers: List[Dict], + prompt: Dict[str, Any], + metrics: List[str], + category: str, + model: str, + language: str, + references: List[Dict] = None, +) -> List[Dict]: """ Use GPT models to evaluate model answers and save evaluation results. @@ -529,21 +532,23 @@ def evaluate(answers: List[Dict], if model == "text-davinci-003": future = executor.submit(get_gpt_evaluation_with_logprobs, prompt, inst, metrics, 1) else: - future = executor.submit(get_gpt_evaluation_without_logprobs, - prompt, - inst, - metrics, - language, - reference=None if references is None else references[idx], - model=model, - max_tokens=1) + future = executor.submit( + get_gpt_evaluation_without_logprobs, + prompt, + inst, + metrics, + language, + reference=None if references is None else references[idx], + model=model, + max_tokens=1, + ) futures.append(future) for future in tqdm.tqdm( - concurrent.futures.as_completed(futures), - desc=f"{category}: ", - total=len(futures), + concurrent.futures.as_completed(futures), + desc=f"{category}: ", + total=len(futures), ): evaluations.append(future.result()) @@ -610,12 +615,13 @@ def calculate_scores_form_response(response: str, evaluation: Dict[str, Any]) -> return int(results[0]) else: raise Exception(f"Invalid score pair. 
Got {evaluation}.") - except Exception as e: + except Exception: return 0 -def save_gpt_evaluation_results(model_name: str, gpt_evaluation_results: Dict[str, Any], - save_path: str) -> Dict[str, Any]: +def save_gpt_evaluation_results( + model_name: str, gpt_evaluation_results: Dict[str, Any], save_path: str +) -> Dict[str, Any]: """ Save evaluation results for different categories for one model. @@ -667,10 +673,12 @@ def save_gpt_evaluation_statistics(model_name: str, evaluations: List[Dict], sav scores[metric].append(0) elif evaluation["evaluation"][metric]["logprobs"] is not None: scores[metric].append( - calculate_scores_form_logprobs(evaluation["evaluation"][metric]["logprobs"][0])) + calculate_scores_form_logprobs(evaluation["evaluation"][metric]["logprobs"][0]) + ) else: scores[metric].append( - calculate_scores_form_response(evaluation["evaluation"][metric]["response"], evaluation)) + calculate_scores_form_response(evaluation["evaluation"][metric]["response"], evaluation) + ) statistics = {} for metric in metrics: @@ -751,9 +759,9 @@ def analyze_gpt_evaluation_statistics(statistics_path: str, save_path: str) -> N frame_all.to_csv(os.path.join(save_path, "gpt_evaluation_statistics.csv")) for category in tqdm.tqdm( - frame_per_category.keys(), - desc=f"GPT evaluation: ", - total=len(frame_per_category.keys()), + frame_per_category.keys(), + desc=f"GPT evaluation: ", + total=len(frame_per_category.keys()), ): data = pd.DataFrame(frame_per_category[category]) diff --git a/applications/Chat/evaluate/metrics.py b/applications/Chat/evaluate/metrics.py index 77f9b6e98044..85ee4de53725 100644 --- a/applications/Chat/evaluate/metrics.py +++ b/applications/Chat/evaluate/metrics.py @@ -21,13 +21,17 @@ def bleu_score(preds: List[str], targets: List[str], language: str) -> Dict[str, """ bleu_scores = {"bleu1": 0, "bleu2": 0, "bleu3": 0, "bleu4": 0} cumulative_bleu = [0] * 4 - weights = [(1. / 1., 0., 0., 0.), (1. / 2., 1. / 2., 0., 0.), (1. / 3., 1. / 3., 1. / 3., 0.), - (1. / 4., 1. / 4., 1. / 4., 1. / 4.)] + weights = [ + (1.0 / 1.0, 0.0, 0.0, 0.0), + (1.0 / 2.0, 1.0 / 2.0, 0.0, 0.0), + (1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0, 0.0), + (1.0 / 4.0, 1.0 / 4.0, 1.0 / 4.0, 1.0 / 4.0), + ] for pred, target in zip(preds, targets): if language == "cn": - pred_list = ' '.join(jieba.cut(preprocessing_text(pred))).split() - target_list = [(' '.join(jieba.cut(preprocessing_text(target)))).split()] + pred_list = " ".join(jieba.cut(preprocessing_text(pred))).split() + target_list = [(" ".join(jieba.cut(preprocessing_text(target)))).split()] elif language == "en": pred_list = preprocessing_text(pred).split() target_list = [preprocessing_text(target).split()] @@ -42,15 +46,14 @@ def bleu_score(preds: List[str], targets: List[str], language: str) -> Dict[str, def chrf_score(preds: List[str], targets: List[str], language: str) -> Dict[str, float]: - """Calculate CHRF Score Metric in sentence level. 
- """ + """Calculate CHRF Score Metric in sentence level.""" chrf_score = {"chrf": 0} cumulative_chrf = [] for pred, target in zip(preds, targets): if language == "cn": - pred_list = ' '.join(jieba.cut(preprocessing_text(pred))).split() - target_list = ' '.join(jieba.cut(preprocessing_text(target))).split() + pred_list = " ".join(jieba.cut(preprocessing_text(pred))).split() + target_list = " ".join(jieba.cut(preprocessing_text(target))).split() elif language == "en": pred_list = preprocessing_text(pred).split() target_list = preprocessing_text(target).split() @@ -75,8 +78,8 @@ def rouge_cn_score(preds: List[str], targets: List[str]) -> Dict[str, float]: all_targets = [] for pred, target in zip(preds, targets): - pred_list = remove_redundant_space(' '.join(jieba.cut(preprocessing_text(pred)))) - target_list = remove_redundant_space(' '.join(jieba.cut(preprocessing_text(target)))) + pred_list = remove_redundant_space(" ".join(jieba.cut(preprocessing_text(pred)))) + target_list = remove_redundant_space(" ".join(jieba.cut(preprocessing_text(target)))) all_preds.append(pred_list) all_targets.append(target_list) @@ -99,16 +102,14 @@ def rouge_en_score(preds: List[str], targets: List[str]) -> Dict[str, float]: longest common subsequence (LCS) between preds and targets. """ rouge_scores = {"rouge1": 0, "rouge2": 0, "rougeL": 0} - all_preds = [] - all_targets = [] rouge_en = Rouge_en.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=False) for pred, target in zip(preds, targets): score = rouge_en.score(preprocessing_text(pred), preprocessing_text(target)) - rouge_scores["rouge1"] += score['rouge1'].fmeasure - rouge_scores["rouge2"] += score['rouge2'].fmeasure - rouge_scores["rougeL"] += score['rougeL'].fmeasure + rouge_scores["rouge1"] += score["rouge1"].fmeasure + rouge_scores["rouge2"] += score["rouge2"].fmeasure + rouge_scores["rougeL"] += score["rougeL"].fmeasure rouge_scores["rouge1"] = rouge_scores["rouge1"] / len(preds) rouge_scores["rouge2"] = rouge_scores["rouge2"] / len(preds) @@ -137,7 +138,7 @@ def distinct_score(preds: List[str], language: str) -> Dict[str, float]: for pred in preds: if language == "cn": - pred_seg_list = ' '.join(jieba.cut(pred)).split() + pred_seg_list = " ".join(jieba.cut(pred)).split() count_segs = len(pred_seg_list) unique_segs = set(pred_seg_list) count_unique_chars = len(unique_segs) @@ -151,7 +152,7 @@ def distinct_score(preds: List[str], language: str) -> Dict[str, float]: split_pred = preprocessing_text(pred).split() for n in range(0, 3): for i in range(0, len(split_pred) - n): - ngram = ' '.join(split_pred[i:i + n + 1]) + ngram = " ".join(split_pred[i : i + n + 1]) unique_ngram[n].add(ngram) all_ngram_count[n] += 1 @@ -203,8 +204,8 @@ def calculate_precision_recall_f1(preds: List[str], targets: List[str], language for pred, target in zip(preds, targets): if language == "cn": - pred_list = [char for char in ' '.join(jieba.cut(preprocessing_text(pred))).split()] - target_list = [char for char in ' '.join(jieba.cut(preprocessing_text(target))).split()] + pred_list = [char for char in " ".join(jieba.cut(preprocessing_text(pred))).split()] + target_list = [char for char in " ".join(jieba.cut(preprocessing_text(target))).split()] elif language == "en": pred_list = [char for char in preprocessing_text(pred).split()] target_list = [char for char in preprocessing_text(target).split()] diff --git a/applications/Chat/evaluate/unieval/__init__.py b/applications/Chat/evaluate/unieval/__init__.py index dad8d6ad09fa..6ffccdaa0819 100644 --- 
a/applications/Chat/evaluate/unieval/__init__.py +++ b/applications/Chat/evaluate/unieval/__init__.py @@ -7,6 +7,9 @@ ) __all__ = [ - 'get_evaluator', 'convert_data_to_unieval_format', 'calculate_average_score', 'save_unieval_results', - 'analyze_unieval_results' + "get_evaluator", + "convert_data_to_unieval_format", + "calculate_average_score", + "save_unieval_results", + "analyze_unieval_results", ] diff --git a/applications/Chat/evaluate/unieval/evaluator.py b/applications/Chat/evaluate/unieval/evaluator.py index 56cc6d2f9e41..bf2bc33a95c0 100644 --- a/applications/Chat/evaluate/unieval/evaluator.py +++ b/applications/Chat/evaluate/unieval/evaluator.py @@ -28,29 +28,29 @@ class SumEvaluator: - - def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None): - """ Set up evaluator for text summarization """ + def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None): + """Set up evaluator for text summarization""" self.scorer = UniEvaluator( - model_name_or_path='MingZhong/unieval-sum' if model_name_or_path == "" else model_name_or_path, + model_name_or_path="MingZhong/unieval-sum" if model_name_or_path == "" else model_name_or_path, max_length=max_length, device=device, - cache_dir=cache_dir) - self.task = 'summarization' - self.dimensions = ['coherence', 'consistency', 'fluency', 'relevance'] + cache_dir=cache_dir, + ) + self.task = "summarization" + self.dimensions = ["coherence", "consistency", "fluency", "relevance"] def evaluate(self, data, category, dims=None, overall=True): """ - Get the scores of all the given dimensions + Get the scores of all the given dimensions - category: The category to be evaluated. + category: The category to be evaluated. - dims: A list of dimensions to be evaluated. If dims is None, SumEvaluator will evaluate - four dimensions: coherence, consistency, fluency, relevance. + dims: A list of dimensions to be evaluated. If dims is None, SumEvaluator will evaluate + four dimensions: coherence, consistency, fluency, relevance. - overall: indicates whether the overall score is to be calculated. - Overall score can be customized to a combination of scores based on different - dimensions. The default here is the average score of all the given dimensions. + overall: indicates whether the overall score is to be calculated. + Overall score can be customized to a combination of scores based on different + dimensions. The default here is the average score of all the given dimensions. 
""" n_data = len(data) eval_scores = [{} for _ in range(n_data)] @@ -63,12 +63,12 @@ def evaluate(self, data, category, dims=None, overall=True): for dim in eval_dims: # Calculate average sentence-level scores for 'consistency' and 'fluency' - if dim == 'consistency' or dim == 'fluency': + if dim == "consistency" or dim == "fluency": src_list, output_list = [], [] - n_sents = [] # the number of sentences in each generated summary + n_sents = [] # the number of sentences in each generated summary for i in range(n_data): - source = data[i]['source'] - system_outputs = sent_tokenize(data[i]['system_output']) + source = data[i]["source"] + system_outputs = sent_tokenize(data[i]["system_output"]) n_sents.append(len(system_outputs)) for j in range(len(system_outputs)): src_list.append(source) @@ -81,24 +81,26 @@ def evaluate(self, data, category, dims=None, overall=True): score = [] for cur_n_sent in n_sents: # prevent denominator from being 0 - score.append(sum(sent_score[start_idx:start_idx + cur_n_sent]) / (cur_n_sent + 1e-6)) + score.append(sum(sent_score[start_idx : start_idx + cur_n_sent]) / (cur_n_sent + 1e-6)) start_idx += cur_n_sent # Calculate summary-level score for 'coherence' and 'relevance' - elif dim == 'coherence' or dim == 'relevance': + elif dim == "coherence" or dim == "relevance": src_list, output_list, ref_list = [], [], [] for i in range(n_data): - src_list.append(data[i]['source']) - output_list.append(data[i]['system_output']) - if dim == 'relevance': - ref_list.append(data[i]['reference']) + src_list.append(data[i]["source"]) + output_list.append(data[i]["system_output"]) + if dim == "relevance": + ref_list.append(data[i]["reference"]) input_list = add_question(dimension=dim, output=output_list, src=src_list, ref=ref_list, task=self.task) score = self.scorer.score(input_list, self.task, category, dim) # Please customize other dimensions here for summarization else: - raise NotImplementedError('The input format for this dimension is still undefined. \ - Please customize it first.') + raise NotImplementedError( + "The input format for this dimension is still undefined. \ + Please customize it first." + ) for i in range(n_data): eval_scores[i][dim] = score[i] @@ -106,35 +108,35 @@ def evaluate(self, data, category, dims=None, overall=True): # Customize your overall score here. 
if overall == True: for i in range(n_data): - eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values())) + eval_scores[i]["overall"] = np.mean(list(eval_scores[i].values())) return eval_scores class DialogEvaluator: - - def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None): - """ Set up evaluator for dialogues """ + def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None): + """Set up evaluator for dialogues""" self.scorer = UniEvaluator( - model_name_or_path='MingZhong/unieval-dialog' if model_name_or_path == "" else model_name_or_path, + model_name_or_path="MingZhong/unieval-dialog" if model_name_or_path == "" else model_name_or_path, max_length=max_length, device=device, - cache_dir=cache_dir) - self.task = 'dialogue' - self.dimensions = ['naturalness', 'coherence', 'engagingness', 'groundedness', 'understandability'] + cache_dir=cache_dir, + ) + self.task = "dialogue" + self.dimensions = ["naturalness", "coherence", "engagingness", "groundedness", "understandability"] def evaluate(self, data, category, dims=None, overall=True): """ - Get the scores of all the given dimensions + Get the scores of all the given dimensions - category: The category to be evaluated. + category: The category to be evaluated. - dims: A list of dimensions to be evaluated. If dims is None, DialogEvaluator will evaluate - five dimensions: naturalness, coherence, engagingness, groundedness and understandability. + dims: A list of dimensions to be evaluated. If dims is None, DialogEvaluator will evaluate + five dimensions: naturalness, coherence, engagingness, groundedness and understandability. - overall: indicates whether the overall score is to be calculated. - Overall score can be customized to a combination of scores based on different - dimensions. The default here is the average score of all the given dimensions. + overall: indicates whether the overall score is to be calculated. + Overall score can be customized to a combination of scores based on different + dimensions. The default here is the average score of all the given dimensions. 
""" n_data = len(data) eval_scores = [{} for _ in range(n_data)] @@ -147,50 +149,48 @@ def evaluate(self, data, category, dims=None, overall=True): for dim in eval_dims: # Calculate summation score for 'engagingness' - if dim == 'engagingness': + if dim == "engagingness": src_list, output_list, context_list = [], [], [] - n_sents = [] # the number of sentences in each generated response + n_sents = [] # the number of sentences in each generated response for i in range(n_data): - source = data[i]['source'] - context = data[i]['context'] - system_outputs = sent_tokenize(data[i]['system_output']) + source = data[i]["source"] + context = data[i]["context"] + system_outputs = sent_tokenize(data[i]["system_output"]) n_sents.append(len(system_outputs)) for j in range(len(system_outputs)): src_list.append(source) context_list.append(context) output_list.append(system_outputs[j]) - input_list = add_question(dimension=dim, - output=output_list, - src=src_list, - context=context_list, - task=self.task) + input_list = add_question( + dimension=dim, output=output_list, src=src_list, context=context_list, task=self.task + ) sent_score = self.scorer.score(input_list, self.task, category, dim) # Get the summation score for each sample start_idx = 0 score = [] for cur_n_sent in n_sents: - score.append(sum(sent_score[start_idx:start_idx + cur_n_sent])) + score.append(sum(sent_score[start_idx : start_idx + cur_n_sent])) start_idx += cur_n_sent # Calculate turn-level score for other dimensions - elif dim in ['naturalness', 'coherence', 'groundedness', 'understandability']: + elif dim in ["naturalness", "coherence", "groundedness", "understandability"]: src_list, output_list, context_list = [], [], [] for i in range(n_data): - src_list.append(data[i]['source']) - output_list.append(data[i]['system_output']) - context_list.append(data[i]['context']) - input_list = add_question(dimension=dim, - output=output_list, - src=src_list, - context=context_list, - task=self.task) + src_list.append(data[i]["source"]) + output_list.append(data[i]["system_output"]) + context_list.append(data[i]["context"]) + input_list = add_question( + dimension=dim, output=output_list, src=src_list, context=context_list, task=self.task + ) score = self.scorer.score(input_list, self.task, category, dim) # Please customize other dimensions here for summarization else: - raise NotImplementedError('The input format for this dimension is still undefined. \ - Please customize it first.') + raise NotImplementedError( + "The input format for this dimension is still undefined. \ + Please customize it first." + ) for i in range(n_data): eval_scores[i][dim] = score[i] @@ -198,35 +198,35 @@ def evaluate(self, data, category, dims=None, overall=True): # Customize your overall score here. 
if overall == True: for i in range(n_data): - eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values())) + eval_scores[i]["overall"] = np.mean(list(eval_scores[i].values())) return eval_scores class D2tEvaluator: - - def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None): - """ Set up evaluator for data-to-text """ + def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None): + """Set up evaluator for data-to-text""" self.scorer = UniEvaluator( - model_name_or_path='MingZhong/unieval-sum' if model_name_or_path == "" else model_name_or_path, + model_name_or_path="MingZhong/unieval-sum" if model_name_or_path == "" else model_name_or_path, max_length=max_length, device=device, - cache_dir=cache_dir) - self.task = 'data2text' - self.dimensions = ['naturalness', 'informativeness'] + cache_dir=cache_dir, + ) + self.task = "data2text" + self.dimensions = ["naturalness", "informativeness"] def evaluate(self, data, category, dims=None, overall=True): """ - Get the scores of all the given dimensions + Get the scores of all the given dimensions - category: The category to be evaluated. + category: The category to be evaluated. - dims: A list of dimensions to be evaluated. If dims is None, D2tEvaluator will evaluate - two dimensions: naturalness and informativeness. + dims: A list of dimensions to be evaluated. If dims is None, D2tEvaluator will evaluate + two dimensions: naturalness and informativeness. - overall: indicates whether the overall score is to be calculated. - Overall score can be customized to a combination of scores based on different - dimensions. The default here is the average score of all the given dimensions. + overall: indicates whether the overall score is to be calculated. + Overall score can be customized to a combination of scores based on different + dimensions. The default here is the average score of all the given dimensions. """ n_data = len(data) eval_scores = [{} for _ in range(n_data)] @@ -240,8 +240,8 @@ def evaluate(self, data, category, dims=None, overall=True): for dim in eval_dims: output_list, ref_list = [], [] for i in range(n_data): - output_list.append(data[i]['system_output']) - ref_list.append(data[i]['reference']) + output_list.append(data[i]["system_output"]) + ref_list.append(data[i]["reference"]) input_list = add_question(dimension=dim, output=output_list, ref=ref_list, task=self.task) score = self.scorer.score(input_list, self.task, category, dim) @@ -252,38 +252,38 @@ def evaluate(self, data, category, dims=None, overall=True): # Customize your overall score here. 
if overall == True: for i in range(n_data): - eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values())) + eval_scores[i]["overall"] = np.mean(list(eval_scores[i].values())) return eval_scores class FactEvaluator: - - def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None): - """ Set up evaluator for factual consistency detection """ + def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None): + """Set up evaluator for factual consistency detection""" self.scorer = UniEvaluator( - model_name_or_path='MingZhong/unieval-fact' if model_name_or_path == "" else model_name_or_path, + model_name_or_path="MingZhong/unieval-fact" if model_name_or_path == "" else model_name_or_path, max_length=max_length, device=device, - cache_dir=cache_dir) - self.task = 'fact' - self.dim = 'consistency' + cache_dir=cache_dir, + ) + self.task = "fact" + self.dim = "consistency" def evaluate(self, data, category): """ - Get the factual consistency score (only 1 dimension for this task) + Get the factual consistency score (only 1 dimension for this task) - category: The category to be evaluated. + category: The category to be evaluated. """ n_data = len(data) eval_scores = [{} for _ in range(n_data)] # Calculate average sentence-level scores for factual consistency src_list, output_list = [], [] - n_sents = [] # the number of sentences in the claim + n_sents = [] # the number of sentences in the claim for i in range(n_data): - source = data[i]['source'] - system_outputs = sent_tokenize(data[i]['system_output']) + source = data[i]["source"] + system_outputs = sent_tokenize(data[i]["system_output"]) n_sents.append(len(system_outputs)) for j in range(len(system_outputs)): src_list.append(source) @@ -295,7 +295,7 @@ def evaluate(self, data, category): start_idx = 0 score = [] for cur_n_sent in n_sents: - score.append(sum(sent_score[start_idx:start_idx + cur_n_sent]) / cur_n_sent) + score.append(sum(sent_score[start_idx : start_idx + cur_n_sent]) / cur_n_sent) start_idx += cur_n_sent for i in range(n_data): @@ -304,28 +304,26 @@ def evaluate(self, data, category): return eval_scores -def get_evaluator(task, model_name_or_path="", max_length=1024, device='cuda:0', cache_dir=None): - assert task in ['summarization', 'dialogue', 'data2text', 'fact'] - if task == 'summarization': - return SumEvaluator(model_name_or_path=model_name_or_path, - max_length=max_length, - device=device, - cache_dir=cache_dir) - elif task == 'dialogue': - return DialogEvaluator(model_name_or_path=model_name_or_path, - max_length=max_length, - device=device, - cache_dir=cache_dir) - elif task == 'data2text': - return D2tEvaluator(model_name_or_path=model_name_or_path, - max_length=max_length, - device=device, - cache_dir=cache_dir) - elif task == 'fact': - return FactEvaluator(model_name_or_path=model_name_or_path, - max_length=max_length, - device=device, - cache_dir=cache_dir) +def get_evaluator(task, model_name_or_path="", max_length=1024, device="cuda:0", cache_dir=None): + assert task in ["summarization", "dialogue", "data2text", "fact"] + if task == "summarization": + return SumEvaluator( + model_name_or_path=model_name_or_path, max_length=max_length, device=device, cache_dir=cache_dir + ) + elif task == "dialogue": + return DialogEvaluator( + model_name_or_path=model_name_or_path, max_length=max_length, device=device, cache_dir=cache_dir + ) + elif task == "data2text": + return D2tEvaluator( + model_name_or_path=model_name_or_path, max_length=max_length, device=device, 
cache_dir=cache_dir + ) + elif task == "fact": + return FactEvaluator( + model_name_or_path=model_name_or_path, max_length=max_length, device=device, cache_dir=cache_dir + ) else: - raise NotImplementedError('Other tasks are not implemented, \ - please customize specific tasks here.') + raise NotImplementedError( + "Other tasks are not implemented, \ + please customize specific tasks here." + ) diff --git a/applications/Chat/evaluate/unieval/scorer.py b/applications/Chat/evaluate/unieval/scorer.py index 2c70bb9f6ded..45706b833205 100644 --- a/applications/Chat/evaluate/unieval/scorer.py +++ b/applications/Chat/evaluate/unieval/scorer.py @@ -27,9 +27,8 @@ class UniEvaluator: - - def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None): - """ Set up model """ + def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None): + """Set up model""" self.device = device self.max_length = max_length @@ -47,8 +46,8 @@ def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_d def score(self, inputs, task, category, dim, batch_size=8): """ - Get scores for the given samples. - final_score = postive_score / (postive_score + negative_score) + Get scores for the given samples. + final_score = postive_score / (postive_score + negative_score) """ # The implementation of "forward" in T5 still requires decoder_input_ids. @@ -58,31 +57,27 @@ def score(self, inputs, task, category, dim, batch_size=8): pos_score_list, neg_score_list = [], [] for i in tqdm(range(0, len(inputs), batch_size), desc=f"{category}-({dim}-{task}): "): - src_list = inputs[i:i + batch_size] - tgt_list = tgts[i:i + batch_size] + src_list = inputs[i : i + batch_size] + tgt_list = tgts[i : i + batch_size] try: with torch.no_grad(): - encoded_src = self.tokenizer(src_list, - max_length=self.max_length, - truncation=True, - padding=True, - return_tensors='pt') - encoded_tgt = self.tokenizer(tgt_list, - max_length=self.max_length, - truncation=True, - padding=True, - return_tensors='pt') - - src_tokens = encoded_src['input_ids'].to(self.device) - src_mask = encoded_src['attention_mask'].to(self.device) - - tgt_tokens = encoded_tgt['input_ids'].to(self.device)[:, 0].unsqueeze(-1) + encoded_src = self.tokenizer( + src_list, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt" + ) + encoded_tgt = self.tokenizer( + tgt_list, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt" + ) + + src_tokens = encoded_src["input_ids"].to(self.device) + src_mask = encoded_src["attention_mask"].to(self.device) + + tgt_tokens = encoded_tgt["input_ids"].to(self.device)[:, 0].unsqueeze(-1) output = self.model(input_ids=src_tokens, attention_mask=src_mask, labels=tgt_tokens) logits = output.logits.view(-1, self.model.config.vocab_size) - pos_score = self.softmax(logits)[:, self.pos_id] # Yes - neg_score = self.softmax(logits)[:, self.neg_id] # No + pos_score = self.softmax(logits)[:, self.pos_id] # Yes + neg_score = self.softmax(logits)[:, self.neg_id] # No cur_pos_score = [x.item() for x in pos_score] cur_neg_score = [x.item() for x in neg_score] @@ -90,8 +85,8 @@ def score(self, inputs, task, category, dim, batch_size=8): neg_score_list += cur_neg_score except RuntimeError: - print(f'source: {src_list}') - print(f'target: {tgt_list}') + print(f"source: {src_list}") + print(f"target: {tgt_list}") exit(0) score_list = [] diff --git a/applications/Chat/evaluate/unieval/utils.py b/applications/Chat/evaluate/unieval/utils.py index 
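`UniEvaluator.score` above casts every dimension as a boolean-QA problem: the T5 evaluator's first decoded token is read out and the final score is P("Yes") / (P("Yes") + P("No")). A small numeric sketch of that normalization, with made-up logits and token ids standing in for the real LM head output:

```python
import torch
import torch.nn.functional as F

# Pretend logits over a tiny vocabulary where index 3 plays the role of "Yes"
# and index 5 the role of "No"; one row per evaluated sample.
logits = torch.tensor([[0.1, 0.0, 0.2, 2.5, 0.0, 0.5],
                       [0.3, 0.0, 0.1, 0.2, 0.0, 2.0]])
pos_id, neg_id = 3, 5

probs = F.softmax(logits, dim=-1)
pos_score, neg_score = probs[:, pos_id], probs[:, neg_id]
final_score = pos_score / (pos_score + neg_score)   # in (0, 1); higher means "more Yes"
print(final_score.tolist())
```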
a381e9e590b2..46b0f2907a30 100644 --- a/applications/Chat/evaluate/unieval/utils.py +++ b/applications/Chat/evaluate/unieval/utils.py @@ -31,105 +31,142 @@ def add_question(dimension, output, src=None, ref=None, context=None, task=None): """ - Add questions to generate input in Bool-QA format for UniEval. - - dimension: specific dimension to be evaluated - src: source input for different NLG tasks. For example, source document for summarization - and dialogue history for dialogue response generation. - output: output text generated by the models - ref: human-annotated groundtruth - context: the context needed to evaluate several specific dimension. For example, - additional factual information when evaluating engagingness and groundedness in dialogues. + Add questions to generate input in Bool-QA format for UniEval. + + dimension: specific dimension to be evaluated + src: source input for different NLG tasks. For example, source document for summarization + and dialogue history for dialogue response generation. + output: output text generated by the models + ref: human-annotated groundtruth + context: the context needed to evaluate several specific dimension. For example, + additional factual information when evaluating engagingness and groundedness in dialogues. """ input_with_question = [] for i in range(len(output)): # For summarization - if task == 'summarization': - if dimension == 'fluency': - cur_input = 'question: Is this a fluent paragraph? paragraph: ' + output[i] - elif dimension == 'coherence': - cur_input = 'question: Is this a coherent summary to the document? summary: ' + output[ - i] + ' document: ' + src[i] - elif dimension == 'consistency': - cur_input = 'question: Is this claim consistent with the document? claim: ' + output[ - i] + ' document: ' + src[i] - elif dimension == 'relevance': - cur_input = 'question: Is this summary relevant to the reference? summary: ' + output[ - i] + ' reference: ' + ref[i] + if task == "summarization": + if dimension == "fluency": + cur_input = "question: Is this a fluent paragraph? paragraph: " + output[i] + elif dimension == "coherence": + cur_input = ( + "question: Is this a coherent summary to the document? summary: " + + output[i] + + " document: " + + src[i] + ) + elif dimension == "consistency": + cur_input = ( + "question: Is this claim consistent with the document? claim: " + + output[i] + + " document: " + + src[i] + ) + elif dimension == "relevance": + cur_input = ( + "question: Is this summary relevant to the reference? summary: " + + output[i] + + " reference: " + + ref[i] + ) else: raise NotImplementedError( - 'The input format for this dimension is still undefined. Please customize it first.') + "The input format for this dimension is still undefined. Please customize it first." + ) # For dialogues - elif task == 'dialogue': - if dimension == 'naturalness': - cur_input = 'question: Is this a natural response in the dialogue? response: ' + output[i] - elif dimension == 'coherence': - cur_input = 'question: Is this a coherent response given the dialogue history? response: '\ - + output[i] + ' dialogue history: ' + src[i] - elif dimension == 'engagingness': - cur_input = 'question: Is this an engaging and informative response according to the dialogue history and fact? response: '\ - + output[i] + ' dialogue history: ' + src[i] + ' fact: ' + context[i] - elif dimension == 'groundedness': - cur_input = 'question: Is this response consistent with knowledge in the fact? 
response: '\ - + output[i] + ' fact: ' + context[i] - elif dimension == 'understandability': - cur_input = 'question: Is this an understandable response in the dialogue? response: ' + output[i] + elif task == "dialogue": + if dimension == "naturalness": + cur_input = "question: Is this a natural response in the dialogue? response: " + output[i] + elif dimension == "coherence": + cur_input = ( + "question: Is this a coherent response given the dialogue history? response: " + + output[i] + + " dialogue history: " + + src[i] + ) + elif dimension == "engagingness": + cur_input = ( + "question: Is this an engaging and informative response according to the dialogue history and fact? response: " + + output[i] + + " dialogue history: " + + src[i] + + " fact: " + + context[i] + ) + elif dimension == "groundedness": + cur_input = ( + "question: Is this response consistent with knowledge in the fact? response: " + + output[i] + + " fact: " + + context[i] + ) + elif dimension == "understandability": + cur_input = "question: Is this an understandable response in the dialogue? response: " + output[i] else: raise NotImplementedError( - 'The input format for this dimension is still undefined. Please customize it first.') + "The input format for this dimension is still undefined. Please customize it first." + ) # For data-to-text - elif task == 'data2text': - if dimension == 'naturalness': - cur_input = 'question: Is this a fluent utterance? utterance: ' + output[i] - elif dimension == 'informativeness': - cur_input = 'question: Is this sentence informative according to the reference? sentence: '\ - + output[i] + ' reference: ' + ref[i] + elif task == "data2text": + if dimension == "naturalness": + cur_input = "question: Is this a fluent utterance? utterance: " + output[i] + elif dimension == "informativeness": + cur_input = ( + "question: Is this sentence informative according to the reference? sentence: " + + output[i] + + " reference: " + + ref[i] + ) else: raise NotImplementedError( - 'The input format for this dimension is still undefined. Please customize it first.') + "The input format for this dimension is still undefined. Please customize it first." + ) # For factual consistency detection - elif task == 'fact': - if dimension == 'consistency': - cur_input = 'question: Is this claim consistent with the document? claim: ' + output[ - i] + ' document: ' + src[i] + elif task == "fact": + if dimension == "consistency": + cur_input = ( + "question: Is this claim consistent with the document? claim: " + + output[i] + + " document: " + + src[i] + ) else: - raise NotImplementedError('No other dimensions for the factual consistency detection task.') + raise NotImplementedError("No other dimensions for the factual consistency detection task.") # For new customized tasks else: - raise NotImplementedError('Other tasks are not implemented, please customize specific tasks here.') + raise NotImplementedError("Other tasks are not implemented, please customize specific tasks here.") input_with_question.append(cur_input) return input_with_question def convert_data_to_unieval_format(output_list, src_list=None, ref_list=None): """ - Convert the data into the unieval's format. + Convert the data into the unieval's format. - output_list: a list of model output + output_list: a list of model output - src_list: source input for different NLG tasks. 
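Putting the helpers above together: `convert_data_to_unieval_format` builds the `system_output`/`source`/`reference`/`context` dicts, `add_question` internally turns each dimension into a `question: ... document: ...` prompt, and `get_evaluator` selects the right evaluator. A hedged end-to-end sketch; the import paths assume `applications/Chat/evaluate/unieval` is importable as a package and may need adjusting to your layout:

```python
from unieval.evaluator import get_evaluator
from unieval.utils import convert_data_to_unieval_format

output_list = ["The cat sat on the mat."]   # model outputs to be judged
src_list = ["A cat was sitting on a mat."]  # source documents

data = convert_data_to_unieval_format(output_list, src_list=src_list)

# Downloads MingZhong/unieval-fact by default when model_name_or_path is "".
evaluator = get_evaluator("fact", device="cuda:0")
scores = evaluator.evaluate(data, category="example")
print(scores)  # e.g. [{"consistency": 0.93}]
```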
For example, source document for summarization - and dialogue history for dialogue response generation - ref_list: human-annotated groundtruth + src_list: source input for different NLG tasks. For example, source document for summarization + and dialogue history for dialogue response generation + ref_list: human-annotated groundtruth """ json_data = [] for i in range(len(output_list)): cur = {} - cur['system_output'] = output_list[i] + cur["system_output"] = output_list[i] if src_list is not None: - cur['source'] = src_list[i] + cur["source"] = src_list[i] if ref_list is not None: - cur['reference'] = ref_list[i] - cur['context'] = "" + cur["reference"] = ref_list[i] + cur["context"] = "" json_data.append(cur) return json_data def calculate_average_score(scores): """ - Calculate average scores for different metrics + Calculate average scores for different metrics - scores: a list of scores for different metrics for each answer + scores: a list of scores for different metrics for each answer """ metrics = {metric: 0 for metric in scores[0]} @@ -226,9 +263,9 @@ def analyze_unieval_results(results_path: str, save_path: str) -> None: frame_all.to_csv(os.path.join(save_path, "unieval_statistics.csv")) for metric in tqdm.tqdm( - frame_per_metric.keys(), - desc=f"UniEval metrics: ", - total=len(frame_per_metric.keys()), + frame_per_metric.keys(), + desc=f"UniEval metrics: ", + total=len(frame_per_metric.keys()), ): data = pd.DataFrame(frame_per_metric[metric]) diff --git a/applications/Chat/evaluate/utils.py b/applications/Chat/evaluate/utils.py index 406e43db99aa..10df455b69d7 100644 --- a/applications/Chat/evaluate/utils.py +++ b/applications/Chat/evaluate/utils.py @@ -1,7 +1,6 @@ import io import json import os -import re import string from typing import Dict @@ -55,7 +54,7 @@ def jload(f, mode="r"): def get_json_list(file_path): - with open(file_path, 'r') as f: + with open(file_path, "r") as f: json_list = [] for line in f: json_list.append(json.loads(line)) @@ -187,9 +186,9 @@ def analyze_automatic_results(results_path: str, save_path: str) -> None: frame_all.to_csv(os.path.join(save_path, "automatic_evaluation_statistics.csv")) for metric in tqdm.tqdm( - frame_per_metric.keys(), - desc=f"automatic metrics: ", - total=len(frame_per_metric.keys()), + frame_per_metric.keys(), + desc=f"automatic metrics: ", + total=len(frame_per_metric.keys()), ): data = pd.DataFrame(frame_per_metric[metric]) diff --git a/applications/Chat/examples/community/peft/easy_dataset.py b/applications/Chat/examples/community/peft/easy_dataset.py index 2fe293957079..d4b17689e9cb 100644 --- a/applications/Chat/examples/community/peft/easy_dataset.py +++ b/applications/Chat/examples/community/peft/easy_dataset.py @@ -3,7 +3,6 @@ from typing import Dict, Sequence import torch -from datasets import load_dataset from torch.utils.data import Dataset from tqdm import tqdm from transformers import AutoTokenizer @@ -20,7 +19,8 @@ def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer, max_length: i padding="longest", max_length=max_length, truncation=True, - ) for text in strings + ) + for text in strings ] input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] input_ids_lens = labels_lens = [ @@ -48,18 +48,17 @@ def preprocess(sources: Sequence[str], targets: Sequence[str], tokenizer: AutoTo class EasySupervisedDataset(Dataset): - def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 512) -> None: super(EasySupervisedDataset, self).__init__() with open(data_file, "r", 
encoding="UTF-8") as f: all_lines = f.readlines() - #split to source and target ,source the characters before "回答:" including "回答:", target the characters after "回答:" + # split to source and target ,source the characters before "回答:" including "回答:", target the characters after "回答:" sources, targets = [], [] for line in all_lines: if "回答:" in line: sep_index = line.index("回答:") - sources.append(line[:sep_index + 3]) - targets.append(line[sep_index + 3:] + tokenizer.eos_token) + sources.append(line[: sep_index + 3]) + targets.append(line[sep_index + 3 :] + tokenizer.eos_token) else: sources.append(line) targets.append("" + tokenizer.eos_token) @@ -83,15 +82,17 @@ def __str__(self): class EasyPromptsDataset(Dataset): - def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 96) -> None: super(EasyPromptsDataset, self).__init__() with open(data_file, "r", encoding="UTF-8") as f: all_lines = f.readlines() - all_lines = [line if "回答:" not in line else line[:line.index("回答:") + 3] for line in all_lines] + all_lines = [line if "回答:" not in line else line[: line.index("回答:") + 3] for line in all_lines] self.prompts = [ - tokenizer(line, return_tensors='pt', max_length=max_length, padding='max_length', - truncation=True)['input_ids'].to(torch.cuda.current_device()).squeeze(0) + tokenizer(line, return_tensors="pt", max_length=max_length, padding="max_length", truncation=True)[ + "input_ids" + ] + .to(torch.cuda.current_device()) + .squeeze(0) for line in tqdm(all_lines) ] self.data_file = data_file @@ -110,7 +111,6 @@ def __str__(self): class EasyRewardDataset(Dataset): - def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None, max_length=512) -> None: super(EasyRewardDataset, self).__init__() self.chosen = [] @@ -120,44 +120,42 @@ def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None else: self.end_token = special_token print(self.end_token) - #read all lines in the train_file to a list + # read all lines in the train_file to a list with open(train_file, "r", encoding="UTF-8") as f: all_lines = f.readlines() for line in tqdm(all_lines): data = json.loads(line) - prompt = "提问:" + data['prompt'] + " 回答:" - - chosen = prompt + data['chosen'] + self.end_token - chosen_token = tokenizer(chosen, - max_length=max_length, - padding="max_length", - truncation=True, - return_tensors="pt") - self.chosen.append({ - "input_ids": chosen_token['input_ids'], - "attention_mask": chosen_token['attention_mask'] - }) - - reject = prompt + data['rejected'] + self.end_token - reject_token = tokenizer(reject, - max_length=max_length, - padding="max_length", - truncation=True, - return_tensors="pt") - self.reject.append({ - "input_ids": reject_token['input_ids'], - "attention_mask": reject_token['attention_mask'] - }) + prompt = "提问:" + data["prompt"] + " 回答:" + + chosen = prompt + data["chosen"] + self.end_token + chosen_token = tokenizer( + chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + self.chosen.append( + {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]} + ) + + reject = prompt + data["rejected"] + self.end_token + reject_token = tokenizer( + reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + self.reject.append( + {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]} + ) def __len__(self): length = len(self.chosen) return length def __getitem__(self, idx): - return 
self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][ - "input_ids"], self.reject[idx]["attention_mask"] - - #python representation of the object and the string representation of the object + return ( + self.chosen[idx]["input_ids"], + self.chosen[idx]["attention_mask"], + self.reject[idx]["input_ids"], + self.reject[idx]["attention_mask"], + ) + + # python representation of the object and the string representation of the object def __repr__(self): return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})" @@ -165,26 +163,25 @@ def __str__(self): return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})" -''' +""" Easy SFT just accept a text file which can be read line by line. However the datasets will group texts together to max_length so LLM will learn the texts meaning better. If individual lines are not related, just set is_group_texts to False. -''' +""" class EasySFTDataset(Dataset): - def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_group_texts=True) -> None: super().__init__() - #read the data_file line by line + # read the data_file line by line with open(data_file, "r", encoding="UTF-8") as f: - #encode the text data line by line and put raw python list input_ids only to raw_input_ids list + # encode the text data line by line and put raw python list input_ids only to raw_input_ids list raw_input_ids = [] for line in f: encoded_ids = tokenizer.encode(line) - #if the encoded_ids is longer than max_length, then split it into several parts + # if the encoded_ids is longer than max_length, then split it into several parts if len(encoded_ids) > max_length: for i in range(0, len(encoded_ids), max_length): - raw_input_ids.append(encoded_ids[i:i + max_length]) + raw_input_ids.append(encoded_ids[i : i + max_length]) else: raw_input_ids.append(encoded_ids) @@ -196,12 +193,13 @@ def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_ if is_group_texts: for input_ids in raw_input_ids: if len(current_input_ids) + len(input_ids) > max_length: - #pad the current_input_ids to max_length with tokenizer.pad_token_id + # pad the current_input_ids to max_length with tokenizer.pad_token_id padded_length = max_length - len(current_input_ids) current_input_ids.extend([tokenizer.pad_token_id] * padded_length) grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long)) attention_mask.append( - torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)) + torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long) + ) current_input_ids = [] else: current_input_ids.extend(input_ids) @@ -210,14 +208,16 @@ def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_ current_input_ids.extend([tokenizer.pad_token_id] * padded_length) grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long)) attention_mask.append( - torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)) + torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long) + ) else: - #just append the raw_input_ids to max_length + # just append the raw_input_ids to max_length for input_ids in raw_input_ids: padded_length = max_length - len(input_ids) input_ids.extend([tokenizer.pad_token_id] * padded_length) attention_mask.append( - torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)) + 
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long) + ) grouped_input_ids.append(torch.tensor(input_ids, dtype=torch.long)) self.input_ids = grouped_input_ids self.labels = copy.deepcopy(self.input_ids) @@ -227,14 +227,14 @@ def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_ def __len__(self): return len(self.input_ids) - #get item from dataset + # get item from dataset def __getitem__(self, idx): return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx]) - #generate the dataset description to be printed by print in python + # generate the dataset description to be printed by print in python def __repr__(self): return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})" - #generate the dataset description to be printed by print in python + # generate the dataset description to be printed by print in python def __str__(self): return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})" diff --git a/applications/Chat/examples/community/peft/easy_models.py b/applications/Chat/examples/community/peft/easy_models.py index fe294868159d..db629e50ed94 100644 --- a/applications/Chat/examples/community/peft/easy_models.py +++ b/applications/Chat/examples/community/peft/easy_models.py @@ -4,7 +4,7 @@ import torch.nn as nn import torch.nn.functional as F from coati.models.generation import generate -from coati.models.utils import log_probs_from_logits, masked_mean +from coati.models.utils import log_probs_from_logits from peft import PeftModel from torch.nn.modules import Module from transformers import BloomConfig, BloomForCausalLM @@ -24,38 +24,33 @@ def __init__(self, model: nn.Module) -> None: @torch.no_grad() def generate( - self, - input_ids: torch.Tensor, - return_action_mask: bool = True, - **kwargs + self, input_ids: torch.Tensor, return_action_mask: bool = True, **kwargs ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]: sequences = generate(self.model, input_ids, **kwargs) attention_mask = None - pad_token_id = kwargs.get('pad_token_id', None) + pad_token_id = kwargs.get("pad_token_id", None) if pad_token_id is not None: attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device) if not return_action_mask: return sequences, attention_mask, None input_len = input_ids.size(1) - eos_token_id = kwargs.get('eos_token_id', None) + eos_token_id = kwargs.get("eos_token_id", None) if eos_token_id is None: action_mask = torch.ones_like(sequences, dtype=torch.bool) else: # left padding may be applied, only mask action action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0 - action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input + action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input action_mask[:, :input_len] = False action_mask = action_mask[:, 1:] - return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):] + return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len) :] - def forward(self, - sequences: torch.LongTensor, - num_actions: int, - attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - """Returns action log probs - """ + def forward( + self, sequences: torch.LongTensor, num_actions: int, attention_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Returns action log probs""" 
output = self.model(sequences, attention_mask=attention_mask) - logits = output['logits'] + logits = output["logits"] log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) return log_probs[:, -num_actions:] @@ -75,11 +70,13 @@ class BLOOMActor(Actor): lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: str = None, - config: Optional[BloomConfig] = None, - checkpoint: bool = False, - lora_path: str = None) -> None: + def __init__( + self, + pretrained: str = None, + config: Optional[BloomConfig] = None, + checkpoint: bool = False, + lora_path: str = None, + ) -> None: if pretrained is not None: model = BloomForCausalLM.from_pretrained(pretrained) elif config is not None: diff --git a/applications/Chat/examples/community/peft/train_peft_prompts.py b/applications/Chat/examples/community/peft/train_peft_prompts.py index 9385e457d852..e49db1d2bc1b 100644 --- a/applications/Chat/examples/community/peft/train_peft_prompts.py +++ b/applications/Chat/examples/community/peft/train_peft_prompts.py @@ -1,18 +1,16 @@ import argparse -import pandas as pd import torch import torch.distributed as dist -from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset +from coati.dataset import DataCollatorForSupervisedDataset from coati.models.bloom import BLOOMRM, BLOOMCritic -from coati.models.gpt import GPTRM, GPTActor, GPTCritic -from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM -from coati.models.opt import OPTRM, OPTActor, OPTCritic +from coati.models.gpt import GPTRM, GPTCritic +from coati.models.llama import LlamaCritic, LlamaRM +from coati.models.opt import OPTRM, OPTCritic from coati.trainer import PPOTrainer from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy from easy_dataset import EasyPromptsDataset, EasySupervisedDataset from easy_models import BLOOMActor -from peft import PeftModel from torch.optim import Adam from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler @@ -23,24 +21,24 @@ def main(args): # configure strategy - if args.strategy == 'ddp': + if args.strategy == "ddp": strategy = DDPStrategy() - elif args.strategy == 'colossalai_gemini': - strategy = GeminiStrategy(placement_policy='cpu', initial_scale=2**5) - elif args.strategy == 'colossalai_zero2': - strategy = LowLevelZeroStrategy(stage=2, placement_policy='cpu') + elif args.strategy == "colossalai_gemini": + strategy = GeminiStrategy(placement_policy="cpu", initial_scale=2**5) + elif args.strategy == "colossalai_zero2": + strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu") else: raise ValueError(f'Unsupported strategy "{args.strategy}"') if args.rm_path is not None: - state_dict = torch.load(args.rm_path, map_location='cpu') + state_dict = torch.load(args.rm_path, map_location="cpu") # configure model - if args.model == 'bloom': + if args.model == "bloom": # initial_model = BLOOMActor(pretrained=args.pretrain) - print('Using peft lora to load Bloom model as initial_model') + print("Using peft lora to load Bloom model as initial_model") initial_model = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path) - print('Using peft lora to load Bloom model as initial_model (Done)') + print("Using peft lora to load Bloom model as initial_model (Done)") else: raise ValueError(f'Unsupported actor model "{args.model}"') @@ -49,59 +47,59 @@ def main(args): else: rm_model_name = args.rm_model - if rm_model_name == 'gpt2': + if rm_model_name == 
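`Actor.forward` above turns the language-model logits into per-token log-probabilities of the generated (action) tokens. A self-contained sketch of the same computation; `log_probs_from_logits` is assumed here to be the usual log-softmax plus gather, which may differ in detail from the coati implementation:

```python
import torch
import torch.nn.functional as F

def log_probs_from_logits(logits, labels):
    # log-probability of each label token under the predicted distribution
    log_probs = F.log_softmax(logits, dim=-1)
    return log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)

batch, seq_len, vocab = 2, 6, 11
num_actions = 3                                   # tokens generated after the prompt
logits = torch.randn(batch, seq_len, vocab)       # stand-in for model(sequences).logits
sequences = torch.randint(vocab, (batch, seq_len))

# logits[:, t] predicts sequences[:, t + 1], hence the one-token shift.
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
action_log_probs = log_probs[:, -num_actions:]    # keep only the generated (action) tokens
print(action_log_probs.shape)                     # torch.Size([2, 3])
```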
"gpt2": reward_model = GPTRM(pretrained=args.rm_pretrain) - elif rm_model_name == 'bloom': + elif rm_model_name == "bloom": print("load bloom reward model ", args.rm_pretrain) reward_model = BLOOMRM(pretrained=args.rm_pretrain) - elif rm_model_name == 'opt': + elif rm_model_name == "opt": reward_model = OPTRM(pretrained=args.rm_pretrain) - elif rm_model_name == 'llama': + elif rm_model_name == "llama": reward_model = LlamaRM(pretrained=args.rm_pretrain) else: raise ValueError(f'Unsupported reward model "{rm_model_name}"') if args.rm_path is not None: - print('Loading reward model from', args.rm_path) + print("Loading reward model from", args.rm_path) reward_model.load_state_dict(state_dict) - if args.strategy != 'colossalai_gemini': + if args.strategy != "colossalai_gemini": initial_model.to(torch.float16).to(torch.cuda.current_device()) reward_model.to(torch.float16).to(torch.cuda.current_device()) with strategy.model_init_context(): - if args.model == 'bloom': + if args.model == "bloom": # actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank) - print('Using peft lora to load Bloom model as Actor') + print("Using peft lora to load Bloom model as Actor") actor = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path) - print('Using peft lora to load Bloom model as Actor (Done)') + print("Using peft lora to load Bloom model as Actor (Done)") else: raise ValueError(f'Unsupported actor model "{args.model}"') - if rm_model_name == 'gpt2': + if rm_model_name == "gpt2": critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) - elif rm_model_name == 'bloom': + elif rm_model_name == "bloom": print("load bloom critic ", args.rm_pretrain, " lora_rank ", args.lora_rank, " use_action_mask ", True) critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) print("load bloom critic (Done) ") - elif rm_model_name == 'opt': + elif rm_model_name == "opt": critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) - elif rm_model_name == 'llama': + elif rm_model_name == "llama": critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) else: raise ValueError(f'Unsupported reward model "{rm_model_name}"') if args.rm_path is not None: - print('Loading reward model from', args.rm_path) + print("Loading reward model from", args.rm_path) critic.load_state_dict(state_dict) del state_dict - if args.strategy != 'colossalai_gemini': + if args.strategy != "colossalai_gemini": critic.to(torch.float16).to(torch.cuda.current_device()) actor.to(torch.float16).to(torch.cuda.current_device()) # configure optimizer - if args.strategy.startswith('colossalai'): + if args.strategy.startswith("colossalai"): actor_optim = HybridAdam(actor.parameters(), lr=1e-7) critic_optim = HybridAdam(critic.parameters(), lr=1e-7) else: @@ -109,18 +107,18 @@ def main(args): critic_optim = Adam(critic.parameters(), lr=1e-7) # configure tokenizer - if args.model == 'gpt2': + if args.model == "gpt2": tokenizer = GPT2Tokenizer.from_pretrained(args.rm_pretrain) tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'bloom': + elif args.model == "bloom": tokenizer = BloomTokenizerFast.from_pretrained(args.rm_pretrain) tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'opt': + elif args.model == "opt": tokenizer = AutoTokenizer.from_pretrained(args.rm_pretrain) tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'llama': + elif args.model 
== "llama": tokenizer = LlamaTokenizer.from_pretrained(args.pretrain) - tokenizer.eos_token = '<\s>' + tokenizer.eos_token = "<\s>" tokenizer.pad_token = tokenizer.unk_token else: raise ValueError(f'Unsupported model "{args.model}"') @@ -132,26 +130,27 @@ def main(args): prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True) else: prompt_sampler = None - prompt_dataloader = DataLoader(prompt_dataset, - shuffle=(prompt_sampler is None), - sampler=prompt_sampler, - batch_size=args.train_batch_size) + prompt_dataloader = DataLoader( + prompt_dataset, shuffle=(prompt_sampler is None), sampler=prompt_sampler, batch_size=args.train_batch_size + ) pretrain_dataset = EasySupervisedDataset(args.pretrain_dataset, tokenizer) if dist.is_initialized() and dist.get_world_size() > 1: pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True) else: pretrain_sampler = None - pretrain_dataloader = DataLoader(pretrain_dataset, - shuffle=(pretrain_sampler is None), - sampler=pretrain_sampler, - batch_size=args.ptx_batch_size, - collate_fn=data_collator) + pretrain_dataloader = DataLoader( + pretrain_dataset, + shuffle=(pretrain_sampler is None), + sampler=pretrain_sampler, + batch_size=args.ptx_batch_size, + collate_fn=data_collator, + ) def tokenize_fn(texts): # MUST padding to max length to ensure inputs of all ranks have the same length # Different length may lead to hang when using gemini, as different generation steps - batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True) + batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True) return {k: v.to(torch.cuda.current_device()) for k, v in batch.items()} (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim)) @@ -178,45 +177,46 @@ def tokenize_fn(texts): eos_token_id=tokenizer.eos_token_id, ) - trainer.fit(prompt_dataloader=prompt_dataloader, - pretrain_dataloader=pretrain_dataloader, - num_episodes=args.num_episodes, - num_update_steps=args.num_update_steps, - num_collect_steps=args.num_collect_steps) + trainer.fit( + prompt_dataloader=prompt_dataloader, + pretrain_dataloader=pretrain_dataloader, + num_episodes=args.num_episodes, + num_update_steps=args.num_update_steps, + num_collect_steps=args.num_collect_steps, + ) # save model checkpoint after fitting trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer) # save optimizer checkpoint on all ranks if args.need_optim_ckpt: - strategy.save_optimizer(actor_optim, - 'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), - only_rank0=False) + strategy.save_optimizer( + actor_optim, "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()), only_rank0=False + ) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--prompt_path', type=str, default=None, help='path to the prompt dataset') - parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset') - parser.add_argument('--strategy', - choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], - default='ddp', - help='strategy to use') - parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) - parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--sft_lora_path', type=str, default=None) - parser.add_argument('--rm_model', default=None, 
choices=['gpt2', 'bloom', 'opt', 'llama']) - parser.add_argument('--rm_path', type=str, default=None) - parser.add_argument('--rm_pretrain', type=str, default=None) - parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts') - parser.add_argument('--need_optim_ckpt', type=bool, default=False) - parser.add_argument('--num_episodes', type=int, default=10) - parser.add_argument('--num_collect_steps', type=int, default=10) - parser.add_argument('--num_update_steps', type=int, default=5) - parser.add_argument('--train_batch_size', type=int, default=2) - parser.add_argument('--ptx_batch_size', type=int, default=1) - parser.add_argument('--experience_batch_size', type=int, default=8) - parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") - parser.add_argument('--kl_coef', type=float, default=0.1) - parser.add_argument('--ptx_coef', type=float, default=0.9) + parser.add_argument("--prompt_path", type=str, default=None, help="path to the prompt dataset") + parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset") + parser.add_argument( + "--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp", help="strategy to use" + ) + parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--sft_lora_path", type=str, default=None) + parser.add_argument("--rm_model", default=None, choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--rm_path", type=str, default=None) + parser.add_argument("--rm_pretrain", type=str, default=None) + parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts") + parser.add_argument("--need_optim_ckpt", type=bool, default=False) + parser.add_argument("--num_episodes", type=int, default=10) + parser.add_argument("--num_collect_steps", type=int, default=10) + parser.add_argument("--num_update_steps", type=int, default=5) + parser.add_argument("--train_batch_size", type=int, default=2) + parser.add_argument("--ptx_batch_size", type=int, default=1) + parser.add_argument("--experience_batch_size", type=int, default=8) + parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument("--kl_coef", type=float, default=0.1) + parser.add_argument("--ptx_coef", type=float, default=0.9) args = parser.parse_args() main(args) diff --git a/applications/Chat/examples/community/peft/train_peft_sft.py b/applications/Chat/examples/community/peft/train_peft_sft.py index 4af08e6d0141..0b62dd652adb 100644 --- a/applications/Chat/examples/community/peft/train_peft_sft.py +++ b/applications/Chat/examples/community/peft/train_peft_sft.py @@ -1,18 +1,10 @@ import argparse import os -import loralib as lora import torch import torch.distributed as dist -from coati.dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset -from coati.models.base import RewardModel -from coati.models.bloom import BLOOMLM -from coati.models.gpt import GPTLM -from coati.models.llama import LlamaLM -from coati.models.opt import OPTLM from coati.trainer import SFTTrainer from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy -from datasets import load_dataset from easy_dataset import EasyDataset from peft import LoraConfig, PeftModel, TaskType, get_peft_model from torch.optim import Adam @@ -29,75 +21,76 @@ def train(args): # configure 
strategy - if args.strategy == 'ddp': + if args.strategy == "ddp": strategy = DDPStrategy() - elif args.strategy == 'colossalai_gemini': - strategy = GeminiStrategy(placement_policy='cuda') - elif args.strategy == 'colossalai_zero2': - strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda') + elif args.strategy == "colossalai_gemini": + strategy = GeminiStrategy(placement_policy="cuda") + elif args.strategy == "colossalai_zero2": + strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") else: raise ValueError(f'Unsupported strategy "{args.strategy}"') # configure model with strategy.model_init_context(): - print('Warning: currently only bloom is tested, gpt2,llama and opt are not tested') + print("Warning: currently only bloom is tested, gpt2,llama and opt are not tested") model = AutoModelForCausalLM.from_pretrained(args.pretrain).to(torch.cuda.current_device()) # if the args.save_path exists and args.save_path+'/adapter_config.json' exists, we'll load the adapter_config.json - if os.path.exists(args.save_path) and os.path.exists(args.save_path + '/adapter_config.json') \ - and os.path.exists(args.save_path + '/adapter_model.bin'): + if ( + os.path.exists(args.save_path) + and os.path.exists(args.save_path + "/adapter_config.json") + and os.path.exists(args.save_path + "/adapter_model.bin") + ): print("loading from saved peft model ", args.save_path) model = PeftModel.from_pretrained(model, args.save_path) else: # we'll use peft lora library to do the lora lora_rank = args.lora_rank if args.lora_rank > 0 else 32 # config lora with rank of lora_rank - lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, - inference_mode=False, - r=lora_rank, - lora_alpha=32, - lora_dropout=0.1) + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, inference_mode=False, r=lora_rank, lora_alpha=32, lora_dropout=0.1 + ) model = get_peft_model(model, lora_config) model.print_trainable_parameters() # configure tokenizer - if args.model == 'gpt2': - tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + if args.model == "gpt2": + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'bloom': + elif args.model == "bloom": tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m") tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'opt': + elif args.model == "opt": tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'llama': + elif args.model == "llama": tokenizer = AutoTokenizer.from_pretrained( args.pretrain, padding_side="right", use_fast=False, ) - tokenizer.eos_token = '<\s>' + tokenizer.eos_token = "<\s>" tokenizer.pad_token = tokenizer.unk_token else: raise ValueError(f'Unsupported model "{args.model}"') - if args.model == 'llama' and args.strategy == 'colossalai_gemini': + if args.model == "llama" and args.strategy == "colossalai_gemini": # this is a hack to deal with the resized embedding # to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatibility for name, param in model.named_parameters(): if not isinstance(param, ColoParameter): - sub_module_name = '.'.join(name.split('.')[:-1]) - weight_name = name.split('.')[-1] + sub_module_name = ".".join(name.split(".")[:-1]) + weight_name = name.split(".")[-1] sub_module = model.get_submodule(sub_module_name) setattr(sub_module, weight_name, ColoParameter(param)) # configure optimizer - if args.strategy.startswith('colossalai'): + if 
args.strategy.startswith("colossalai"): optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0) else: optim = Adam(model.parameters(), lr=args.lr) logger = get_dist_logger() - logger.set_level('WARNING') + logger.set_level("WARNING") # configure dataset law_dataset = EasyDataset(args.dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text) @@ -108,47 +101,57 @@ def train(args): eval_dataset = EasyDataset(args.eval_dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text) data_collator = default_collate if dist.is_initialized() and dist.get_world_size() > 1: - train_sampler = DistributedSampler(train_dataset, - shuffle=True, - seed=42, - drop_last=True, - rank=dist.get_rank(), - num_replicas=dist.get_world_size()) + train_sampler = DistributedSampler( + train_dataset, + shuffle=True, + seed=42, + drop_last=True, + rank=dist.get_rank(), + num_replicas=dist.get_world_size(), + ) if eval_dataset is not None: - eval_sampler = DistributedSampler(eval_dataset, - shuffle=False, - seed=42, - drop_last=False, - rank=dist.get_rank(), - num_replicas=dist.get_world_size()) + eval_sampler = DistributedSampler( + eval_dataset, + shuffle=False, + seed=42, + drop_last=False, + rank=dist.get_rank(), + num_replicas=dist.get_world_size(), + ) else: train_sampler = None eval_sampler = None - train_dataloader = DataLoader(train_dataset, - shuffle=(train_sampler is None), - sampler=train_sampler, - batch_size=args.batch_size, - collate_fn=data_collator, - pin_memory=True) + train_dataloader = DataLoader( + train_dataset, + shuffle=(train_sampler is None), + sampler=train_sampler, + batch_size=args.batch_size, + collate_fn=data_collator, + pin_memory=True, + ) if eval_dataset is not None: - eval_dataloader = DataLoader(eval_dataset, - shuffle=(eval_sampler is None), - sampler=eval_sampler, - batch_size=args.batch_size, - collate_fn=data_collator, - pin_memory=True) + eval_dataloader = DataLoader( + eval_dataset, + shuffle=(eval_sampler is None), + sampler=eval_sampler, + batch_size=args.batch_size, + collate_fn=data_collator, + pin_memory=True, + ) else: eval_dataloader = None - trainer = SFTTrainer(model=model, - strategy=strategy, - optim=optim, - train_dataloader=train_dataloader, - eval_dataloader=eval_dataloader, - batch_size=args.batch_size, - max_epochs=args.max_epochs, - accumulation_steps=args.accumulation_steps) + trainer = SFTTrainer( + model=model, + strategy=strategy, + optim=optim, + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + batch_size=args.batch_size, + max_epochs=args.max_epochs, + accumulation_steps=args.accumulation_steps, + ) trainer.fit(logger=logger, log_interval=args.log_interval) @@ -156,29 +159,27 @@ def train(args): trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer) # save optimizer checkpoint on all ranks if args.need_optim_ckpt: - strategy.save_optimizer(trainer.optimizer, - 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), - only_rank0=False) + strategy.save_optimizer( + trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False + ) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--strategy', - choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], - default='ddp') - parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom') - parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--dataset', type=str, default=None) - 
parser.add_argument('--eval_dataset', type=str, default=None) - parser.add_argument('--save_path', type=str, default='output') - parser.add_argument('--need_optim_ckpt', type=bool, default=False) - parser.add_argument('--max_epochs', type=int, default=3) - parser.add_argument('--batch_size', type=int, default=4) - parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") - parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log") - parser.add_argument('--lr', type=float, default=5e-6) - parser.add_argument('--accumulation_steps', type=int, default=8) - parser.add_argument('--enable_peft_lora', action='store_true', default=False) - parser.add_argument("--is_short_text", action='store_true', default=False) + parser.add_argument("--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp") + parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama"], default="bloom") + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--dataset", type=str, default=None) + parser.add_argument("--eval_dataset", type=str, default=None) + parser.add_argument("--save_path", type=str, default="output") + parser.add_argument("--need_optim_ckpt", type=bool, default=False) + parser.add_argument("--max_epochs", type=int, default=3) + parser.add_argument("--batch_size", type=int, default=4) + parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument("--log_interval", type=int, default=100, help="how many steps to log") + parser.add_argument("--lr", type=float, default=5e-6) + parser.add_argument("--accumulation_steps", type=int, default=8) + parser.add_argument("--enable_peft_lora", action="store_true", default=False) + parser.add_argument("--is_short_text", action="store_true", default=False) args = parser.parse_args() train(args) diff --git a/applications/Chat/examples/community/ray/ray_job_script.py b/applications/Chat/examples/community/ray/ray_job_script.py index 53f304d379fe..e8a1175a9c32 100644 --- a/applications/Chat/examples/community/ray/ray_job_script.py +++ b/applications/Chat/examples/community/ray/ray_job_script.py @@ -6,16 +6,25 @@ def main(api_server_endpoint="http://127.0.0.1:8265"): client = JobSubmissionClient(api_server_endpoint) client.submit_job( - entrypoint= - "python experimental/ray/train_prompts_on_ray.py --strategy colossalai_zero2 --prompt_csv_url https://huggingface.co/datasets/fka/awesome-chatgpt-prompts/resolve/main/prompts.csv", + entrypoint="python experimental/ray/train_prompts_on_ray.py --strategy colossalai_zero2 --prompt_csv_url https://huggingface.co/datasets/fka/awesome-chatgpt-prompts/resolve/main/prompts.csv", runtime_env={ - "working_dir": - "applications/Chat", + "working_dir": "applications/Chat", "pip": [ - "torch==1.13.1", "transformers>=4.20.1", "datasets", "loralib", "colossalai>=0.2.4", "langchain", - "tokenizers", "fastapi", "sse_starlette", "wandb", "sentencepiece", "gpustat" - ] - }) + "torch==1.13.1", + "transformers>=4.20.1", + "datasets", + "loralib", + "colossalai>=0.2.4", + "langchain", + "tokenizers", + "fastapi", + "sse_starlette", + "wandb", + "sentencepiece", + "gpustat", + ], + }, + ) if __name__ == "__main__": diff --git a/applications/Chat/examples/community/ray/train_prompts_on_ray.py b/applications/Chat/examples/community/ray/train_prompts_on_ray.py index 1bba9ad66fbc..8abd83a8b249 100644 --- 
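`ray_job_script.py` above only submits the job and exits. A minimal sketch of checking on the submitted job afterwards with the same `JobSubmissionClient`; the polling loop is an assumption for illustration and not part of the patch:

```python
import time

from ray.job_submission import JobStatus, JobSubmissionClient

client = JobSubmissionClient("http://127.0.0.1:8265")
submission_id = client.submit_job(
    entrypoint="python experimental/ray/train_prompts_on_ray.py --strategy colossalai_zero2",
    runtime_env={"working_dir": "applications/Chat"},
)

# Poll until the job leaves the pending/running states, then dump its logs.
while client.get_job_status(submission_id) in (JobStatus.PENDING, JobStatus.RUNNING):
    time.sleep(10)

print(client.get_job_status(submission_id))
print(client.get_job_logs(submission_id))
```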
a/applications/Chat/examples/community/ray/train_prompts_on_ray.py +++ b/applications/Chat/examples/community/ray/train_prompts_on_ray.py @@ -26,9 +26,14 @@ class ExperienceCompositionRefs: - - def __init__(self, sequences_attention_mask_action_mask_ref: ray.ObjectRef, action_log_probs_ref: ray.ObjectRef, - base_action_log_probs_ref: ray.ObjectRef, value_ref: ray.ObjectRef, r_ref: ray.ObjectRef) -> None: + def __init__( + self, + sequences_attention_mask_action_mask_ref: ray.ObjectRef, + action_log_probs_ref: ray.ObjectRef, + base_action_log_probs_ref: ray.ObjectRef, + value_ref: ray.ObjectRef, + r_ref: ray.ObjectRef, + ) -> None: self.sequences_attention_mask_action_mask_ref = sequences_attention_mask_action_mask_ref self.action_log_probs_ref = action_log_probs_ref self.base_action_log_probs_ref = base_action_log_probs_ref @@ -37,14 +42,14 @@ def __init__(self, sequences_attention_mask_action_mask_ref: ray.ObjectRef, acti class ExperienceMaker: - def __init__(self, kl_coef) -> None: self.kl_coef = kl_coef @torch.no_grad() def make_experience(self, experiment_computation_refs: ExperienceCompositionRefs): sequences, attention_mask, action_mask = ray.get( - experiment_computation_refs.sequences_attention_mask_action_mask_ref) + experiment_computation_refs.sequences_attention_mask_action_mask_ref + ) action_log_probs = ray.get(experiment_computation_refs.action_log_probs_ref) base_action_log_probs = ray.get(experiment_computation_refs.base_action_log_probs_ref) r = ray.get(experiment_computation_refs.r_ref) @@ -58,11 +63,10 @@ def make_experience(self, experiment_computation_refs: ExperienceCompositionRefs class DistributedTorchRayActor: - def __init__(self, world_size, rank, local_rank, master_addr, master_port): - logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s', - level=logging.INFO, - datefmt='%Y-%m-%d %H:%M:%S') + logging.basicConfig( + format="%(asctime)s %(levelname)-8s %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S" + ) self._model = None self._world_size = world_size self._rank = rank @@ -82,7 +86,7 @@ def _get_current_node_ip(): @staticmethod def _get_free_port(): with socket.socket() as sock: - sock.bind(('', 0)) + sock.bind(("", 0)) return sock.getsockname()[1] def get_master_addr_port(self): @@ -90,7 +94,6 @@ def get_master_addr_port(self): class BasePPORole(DistributedTorchRayActor): - def add_experience_maker(self, kl_coef: float = 0.1): self._experience_maker = ExperienceMaker(kl_coef) @@ -99,12 +102,12 @@ def make_experience(self, experience_computation_ref: ExperienceCompositionRefs) def _init_strategy(self, strategy: str): # configure strategy - if strategy == 'ddp': + if strategy == "ddp": self._strategy = DDPStrategy() - elif strategy == 'colossalai_gemini': - self._strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5) - elif strategy == 'colossalai_zero2': - self._strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda') + elif strategy == "colossalai_gemini": + self._strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5) + elif strategy == "colossalai_zero2": + self._strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") else: raise ValueError(f'Unsupported strategy "{strategy}"') @@ -124,11 +127,9 @@ def _prepare_model_with_strategy(self, has_optimizer: bool): def _load_model_from_pretrained(self, model_class: Type[LoRAModule], pretrain: str): raise NotImplementedError() - def init_model_from_pretrained(self, - strategy: str, - model_class: Type[LoRAModule], - pretrain: str, - 
has_optimizer=False): + def init_model_from_pretrained( + self, strategy: str, model_class: Type[LoRAModule], pretrain: str, has_optimizer=False + ): self._init_strategy(strategy) self._load_model_from_pretrained(model_class, pretrain) self._prepare_model_with_strategy(has_optimizer) @@ -138,7 +139,6 @@ def eval(self): class TrainablePPORole(BasePPORole): - def _load_model_from_pretrained(self, model_class, pretrain): with self._strategy.model_init_context(): self._model = model_class(pretrain).to(torch.cuda.current_device()) @@ -161,38 +161,39 @@ def learn_on_experiences(self, experience_refs): @ray.remote(num_gpus=1) class RayPPOActor(TrainablePPORole): - def set_loss_function(self, eps_clip: float): self._actor_loss_fn = PolicyLoss(eps_clip) def load_tokenizer_from_pretrained(self, model_type: str, pretrained): - if model_type == 'gpt2': + if model_type == "gpt2": self._model_tokenizer = GPT2Tokenizer.from_pretrained(pretrained) self._model_tokenizer.pad_token = self._model_tokenizer.eos_token - elif model_type == 'bloom': + elif model_type == "bloom": self._model_tokenizer = BloomTokenizerFast.from_pretrained(pretrained) self._model_tokenizer.pad_token = self._model_tokenizer.eos_token - elif model_type == 'opt': + elif model_type == "opt": self._model_tokenizer = AutoTokenizer.from_pretrained(pretrained) else: raise ValueError(f'Unsupported model "{model_type}"') # Set tokenize function for sequence generation def _text_input_tokenize_fn(texts): - batch = self._model_tokenizer(texts, return_tensors='pt', max_length=96, padding=True, truncation=True) + batch = self._model_tokenizer(texts, return_tensors="pt", max_length=96, padding=True, truncation=True) return {k: v.cuda() for k, v in batch.items()} self._sample_tokenize_function = _text_input_tokenize_fn def setup_generate_kwargs(self, generate_kwargs: dict): from coati.trainer.ppo import _set_default_generate_kwargs + self._generate_kwargs = _set_default_generate_kwargs(self._strategy, generate_kwargs, self._model) - self._generate_kwargs['pad_token_id'] = self._model_tokenizer.pad_token_id - self._generate_kwargs['eos_token_id'] = self._model_tokenizer.eos_token_id + self._generate_kwargs["pad_token_id"] = self._model_tokenizer.pad_token_id + self._generate_kwargs["eos_token_id"] = self._model_tokenizer.eos_token_id def load_csv_prompt_file_from_url_to_sampler(self, prompt_url): import pandas as pd - prompts = pd.read_csv(prompt_url)['prompt'] + + prompts = pd.read_csv(prompt_url)["prompt"] self._sampler = self._strategy.setup_sampler(prompts) def _generate(self, input_ids, **generate_kwargs): @@ -214,10 +215,9 @@ def calculate_action_log_probs(self, sequence_attention_action_mask): def _training_step(self, experience): num_actions = experience.action_mask.size(1) action_log_probs = self._model(experience.sequences, num_actions, attention_mask=experience.attention_mask) - actor_loss = self._actor_loss_fn(action_log_probs, - experience.action_log_probs, - experience.advantages, - action_mask=experience.action_mask) + actor_loss = self._actor_loss_fn( + action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask + ) self._strategy.backward(actor_loss, self._model, self._optimizer) self._strategy.optimizer_step(self._optimizer) self._optimizer.zero_grad() @@ -229,17 +229,18 @@ def save_checkpoint(self, save_path, should_save_optimizer: bool): self._strategy.save_model(self._model, save_path, only_rank0=True) # save optimizer checkpoint on all ranks if should_save_optimizer: - 
self._strategy.save_optimizer(self._optimizer, - 'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), - only_rank0=False) + self._strategy.save_optimizer( + self._optimizer, + "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()), + only_rank0=False, + ) def generate_answer(self, prompt, max_length=30, num_return_sequences=5): - encoded_input = self._model_tokenizer(prompt, return_tensors='pt') + encoded_input = self._model_tokenizer(prompt, return_tensors="pt") input_ids = {k: v.cuda() for k, v in encoded_input.items()} - sequence, _ = self._model.generate(**input_ids, - max_length=max_length, - return_action_mask=False, - num_return_sequences=num_return_sequences) + sequence, _ = self._model.generate( + **input_ids, max_length=max_length, return_action_mask=False, num_return_sequences=num_return_sequences + ) token_list = list(sequence.data[0]) output = " ".join([self._model_tokenizer.decode(token) for token in token_list]) return output @@ -247,18 +248,16 @@ def generate_answer(self, prompt, max_length=30, num_return_sequences=5): @ray.remote(num_gpus=1) class RayPPOCritic(TrainablePPORole): - def set_loss_function(self, value_clip: float): self._critic_loss_fn = ValueLoss(value_clip) def _training_step(self, experience): - values = self._model(experience.sequences, - action_mask=experience.action_mask, - attention_mask=experience.attention_mask) - critic_loss = self._critic_loss_fn(values, - experience.values, - experience.reward, - action_mask=experience.action_mask) + values = self._model( + experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask + ) + critic_loss = self._critic_loss_fn( + values, experience.values, experience.reward, action_mask=experience.action_mask + ) self._strategy.backward(critic_loss, self._model, self._optimizer) self._strategy.optimizer_step(self._optimizer) self._optimizer.zero_grad() @@ -272,12 +271,12 @@ def calculate_value(self, sequence_attention_action_mask): @ray.remote(num_gpus=1) class RayPPORewardModel(BasePPORole): - def _load_model_from_pretrained(self, model_class, pretrain): with self._strategy.model_init_context(): critic = model_class(pretrained=pretrain).to(torch.cuda.current_device()) - self._model = RewardModel(deepcopy(critic.model), - deepcopy(critic.value_head)).to(torch.cuda.current_device()) + self._model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).to( + torch.cuda.current_device() + ) @torch.no_grad() def calculate_r(self, sequence_attention_action_mask): @@ -287,7 +286,6 @@ def calculate_r(self, sequence_attention_action_mask): @ray.remote(num_gpus=1) class RayPPOInitialModel(BasePPORole): - def _load_model_from_pretrained(self, model_class, pretrain): with self._strategy.model_init_context(): self._model = model_class(pretrain).to(torch.cuda.current_device()) @@ -300,8 +298,8 @@ def calculate_base_action_log_probs(self, sequence_attention_action_mask): class PPORayActorGroup: """ - A group of ray actors - Functions start with 'async' should return list of object refs + A group of ray actors + Functions start with 'async' should return list of object refs """ def __init__(self, num_nodes, num_gpus_per_node, ray_actor_type: Type[BasePPORole]) -> None: @@ -319,8 +317,9 @@ def _initiate_actors(self): pg = placement_group(bundles, strategy="STRICT_SPREAD") ray.get(pg.ready()) if pg: - master_actor = self.ray_actor_type.options(scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=pg, 
placement_group_bundle_index=0)).remote(world_size, 0, 0, None, None) + master_actor = self.ray_actor_type.options( + scheduling_strategy=PlacementGroupSchedulingStrategy(placement_group=pg, placement_group_bundle_index=0) + ).remote(world_size, 0, 0, None, None) else: master_actor = self.ray_actor_type.options(num_gpus=1).remote(world_size, 0, 0, None, None) self._actor_handlers = [master_actor] @@ -331,16 +330,20 @@ def _initiate_actors(self): for rank in range(1, world_size): local_rank = rank % self._num_gpus_per_node if pg: - worker_actor = self.ray_actor_type.options(scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=pg, placement_group_bundle_index=rank // self._num_gpus_per_node)).remote( - world_size, rank, local_rank, master_addr, master_port) + worker_actor = self.ray_actor_type.options( + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, placement_group_bundle_index=rank // self._num_gpus_per_node + ) + ).remote(world_size, rank, local_rank, master_addr, master_port) else: - worker_actor = self.ray_actor_type.options(num_gpus=1).remote(world_size, rank, local_rank, - master_addr, master_port) + worker_actor = self.ray_actor_type.options(num_gpus=1).remote( + world_size, rank, local_rank, master_addr, master_port + ) self._actor_handlers.append(worker_actor) - def async_init_model_from_pretrained(self, strategy: str, model_class: Type[LoRAModule], pretrain: str, - has_optimizer: bool): + def async_init_model_from_pretrained( + self, strategy: str, model_class: Type[LoRAModule], pretrain: str, has_optimizer: bool + ): return [ actor.init_model_from_pretrained.remote(strategy, model_class, pretrain, has_optimizer) for actor in self._actor_handlers @@ -348,7 +351,6 @@ def async_init_model_from_pretrained(self, strategy: str, model_class: Type[LoRA class TrainableModelRayActorGroup(PPORayActorGroup): - def async_learn_on_experiences(self, experience_refs): num_actors = len(self._actor_handlers) learn_result_refs = [] @@ -359,7 +361,6 @@ def async_learn_on_experiences(self, experience_refs): class PPOActorRayActorGroup(TrainableModelRayActorGroup): - def __init__(self, num_nodes, num_gpus_per_node) -> None: super().__init__(num_nodes, num_gpus_per_node, RayPPOActor) @@ -381,7 +382,8 @@ def async_calculate_action_log_probs(self, sequences_attention_mask_action_mask_ action_log_probs_refs = [] for i in range(len(sequences_attention_mask_action_mask_refs)): action_log_probs_ref = self._actor_handlers[i % num_actors].calculate_action_log_probs.remote( - sequences_attention_mask_action_mask_refs[i]) + sequences_attention_mask_action_mask_refs[i] + ) action_log_probs_refs.append(action_log_probs_ref) return action_log_probs_refs @@ -393,7 +395,6 @@ def save_checkpoint(self, save_path, should_save_optimizer): class PPOCriticRayActorGroup(TrainableModelRayActorGroup): - def __init__(self, num_nodes, num_gpus_per_node) -> None: super().__init__(num_nodes, num_gpus_per_node, RayPPOCritic) @@ -402,7 +403,8 @@ def async_calculate_value(self, sequences_attention_mask_action_mask_refs): value_refs = [] for i in range(len(sequences_attention_mask_action_mask_refs)): value_ref = self._actor_handlers[i % num_actors].calculate_value.remote( - sequences_attention_mask_action_mask_refs[i]) + sequences_attention_mask_action_mask_refs[i] + ) value_refs.append(value_ref) return value_refs @@ -411,7 +413,6 @@ def set_loss_function(self, value_clip: float = 0.4): class PPOInitialRayActorGroup(PPORayActorGroup): - def __init__(self, num_nodes, num_gpus_per_node) 
-> None: super().__init__(num_nodes, num_gpus_per_node, RayPPOInitialModel) @@ -420,13 +421,13 @@ def async_calculate_base_action_log_probs(self, sequences_attention_mask_action_ base_action_log_probs_refs = [] for i in range(len(sequences_attention_mask_action_mask_refs)): base_action_log_probs_ref = self._actor_handlers[i % num_actors].calculate_base_action_log_probs.remote( - sequences_attention_mask_action_mask_refs[i]) + sequences_attention_mask_action_mask_refs[i] + ) base_action_log_probs_refs.append(base_action_log_probs_ref) return base_action_log_probs_refs class PPORewardRayActorGroup(PPORayActorGroup): - def __init__(self, num_nodes, num_gpus_per_node) -> None: super().__init__(num_nodes, num_gpus_per_node, RayPPORewardModel) @@ -435,20 +436,21 @@ def async_calculate_r(self, sequences_attention_mask_action_mask_refs): r_refs = [] for i in range(len(sequences_attention_mask_action_mask_refs)): r_ref = self._actor_handlers[i % num_actors].calculate_r.remote( - sequences_attention_mask_action_mask_refs[i]) + sequences_attention_mask_action_mask_refs[i] + ) r_refs.append(r_ref) return r_refs def main(args): - logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s', - level=logging.INFO, - datefmt='%Y-%m-%d %H:%M:%S') - if args.model == 'gpt2': + logging.basicConfig( + format="%(asctime)s %(levelname)-8s %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S" + ) + if args.model == "gpt2": actor_model_class, critic_model_class = GPTActor, GPTCritic - elif args.model == 'bloom': + elif args.model == "bloom": actor_model_class, critic_model_class = BLOOMActor, BLOOMCritic - elif args.model == 'opt': + elif args.model == "opt": actor_model_class, critic_model_class = OPTActor, OPTCritic else: raise ValueError(f'Unsupported model "{args.model}"') @@ -462,13 +464,14 @@ def main(args): logging.info("Actors created") # Prepare model for training - generate_kwargs = {'max_length': 128, 'do_sample': True, 'temperature': 1.0, 'top_k': 50} + generate_kwargs = {"max_length": 128, "do_sample": True, "temperature": 1.0, "top_k": 50} ray.get( - actor_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, True) + - critic_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, True) + - initial_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, False) + - reward_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, False) + - actor_group.async_prepare_for_sequence_generation(args.model, args.pretrain, generate_kwargs)) + actor_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, True) + + critic_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, True) + + initial_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, False) + + reward_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, False) + + actor_group.async_prepare_for_sequence_generation(args.model, args.pretrain, generate_kwargs) + ) logging.info("Models prepared for training") # Prepare models for training @@ -483,8 +486,12 @@ def main(args): # Start training logging.info("Training start") # Set all models to eval and add experience maker - all_ray_actors = actor_group._actor_handlers + critic_group._actor_handlers + \ - initial_group._actor_handlers + reward_group._actor_handlers + all_ray_actors = ( + actor_group._actor_handlers + + 
critic_group._actor_handlers + + initial_group._actor_handlers + + reward_group._actor_handlers + ) num_ray_actors = len(all_ray_actors) ray.get([ray_actor.eval.remote() for ray_actor in all_ray_actors]) ray.get([ray_actor.add_experience_maker.remote() for ray_actor in all_ray_actors]) @@ -497,18 +504,28 @@ def main(args): time += 1 # Experience queueing stage sequences_attention_mask_action_mask_refs = actor_group.async_sample_prompts_and_make_sequence( - experience_batch_size) + experience_batch_size + ) base_action_log_probs_refs = initial_group.async_calculate_base_action_log_probs( - sequences_attention_mask_action_mask_refs) + sequences_attention_mask_action_mask_refs + ) values_refs = critic_group.async_calculate_value(sequences_attention_mask_action_mask_refs) r_refs = reward_group.async_calculate_r(sequences_attention_mask_action_mask_refs) action_log_probs_refs = actor_group.async_calculate_action_log_probs( - sequences_attention_mask_action_mask_refs) - experience_composition_refs.extend([ - ExperienceCompositionRefs(sequences_attention_mask_action_mask_refs[i], action_log_probs_refs[i], - base_action_log_probs_refs[i], values_refs[i], r_refs[i]) - for i in range(len(sequences_attention_mask_action_mask_refs)) - ]) + sequences_attention_mask_action_mask_refs + ) + experience_composition_refs.extend( + [ + ExperienceCompositionRefs( + sequences_attention_mask_action_mask_refs[i], + action_log_probs_refs[i], + base_action_log_probs_refs[i], + values_refs[i], + r_refs[i], + ) + for i in range(len(sequences_attention_mask_action_mask_refs)) + ] + ) # Learning stage if time % update_timesteps == 0: experience_refs = [] @@ -519,8 +536,9 @@ def main(args): experience_refs.append(selected_ray_actor.make_experience.remote(exp_composition_ref)) # backward ray.get( - actor_group.async_learn_on_experiences(experience_refs) + - critic_group.async_learn_on_experiences(experience_refs)) + actor_group.async_learn_on_experiences(experience_refs) + + critic_group.async_learn_on_experiences(experience_refs) + ) # clear refs queue experience_composition_refs.clear() logging.info("Training finished") @@ -528,26 +546,24 @@ def main(args): actor_group.save_checkpoint(args.save_path, args.need_optim_ckpt) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--prompt_csv_url', type=str) - parser.add_argument('--strategy', - choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], - default='ddp') - parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt']) - parser.add_argument('--pretrain', type=str, default='gpt2') - parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt') - parser.add_argument('--need_optim_ckpt', type=bool, default=False) - parser.add_argument('--num_episodes', type=int, default=10) - parser.add_argument('--max_timesteps', type=int, default=10) - parser.add_argument('--update_timesteps', type=int, default=10) - parser.add_argument('--train_batch_size', type=int, default=8) - parser.add_argument('--experience_batch_size', type=int, default=8) - parser.add_argument('--num_actor_nodes', type=int, help='num of nodes to use to host actor model', default=1) - parser.add_argument('--num_critic_nodes', type=int, help='num of nodes to use to host critic model', default=1) - parser.add_argument('--num_initial_nodes', type=int, help='num of nodes to use to host initial model', default=1) - parser.add_argument('--num_reward_nodes', type=int, help='num of nodes to use to host reward 
model', default=1) - parser.add_argument('--num_gpus_per_node', type=int, help='num of gpus on a ray node', default=1) + parser.add_argument("--prompt_csv_url", type=str) + parser.add_argument("--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp") + parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt"]) + parser.add_argument("--pretrain", type=str, default="gpt2") + parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts.pt") + parser.add_argument("--need_optim_ckpt", type=bool, default=False) + parser.add_argument("--num_episodes", type=int, default=10) + parser.add_argument("--max_timesteps", type=int, default=10) + parser.add_argument("--update_timesteps", type=int, default=10) + parser.add_argument("--train_batch_size", type=int, default=8) + parser.add_argument("--experience_batch_size", type=int, default=8) + parser.add_argument("--num_actor_nodes", type=int, help="num of nodes to use to host actor model", default=1) + parser.add_argument("--num_critic_nodes", type=int, help="num of nodes to use to host critic model", default=1) + parser.add_argument("--num_initial_nodes", type=int, help="num of nodes to use to host initial model", default=1) + parser.add_argument("--num_reward_nodes", type=int, help="num of nodes to use to host reward model", default=1) + parser.add_argument("--num_gpus_per_node", type=int, help="num of gpus on a ray node", default=1) args = parser.parse_args() ray.init() main(args) diff --git a/applications/Chat/examples/download_model.py b/applications/Chat/examples/download_model.py index c2b5f9a859a9..ec3482b5f789 100644 --- a/applications/Chat/examples/download_model.py +++ b/applications/Chat/examples/download_model.py @@ -22,7 +22,7 @@ def download(self, dir_path: str): file_path = hf_hub_download(self.repo_id, file, local_dir=dir_path) def download_all(self): - file_path = snapshot_download(self.repo_id) + snapshot_download(self.repo_id) def test_init(model: str, dir_path: str): @@ -31,19 +31,19 @@ def test_init(model: str, dir_path: str): actor = GPTActor(config=config) critic = GPTCritic(config=config) reward_model = GPTRM(config=config) - tokenizer = GPT2Tokenizer.from_pretrained(dir_path) + GPT2Tokenizer.from_pretrained(dir_path) elif model == "bloom": config = BloomConfig.from_pretrained(dir_path) actor = BLOOMActor(config=config) critic = BLOOMCritic(config=config) reward_model = BLOOMRM(config=config) - tokenizer = BloomTokenizerFast.from_pretrained(dir_path) + BloomTokenizerFast.from_pretrained(dir_path) elif model == "opt": config = AutoConfig.from_pretrained(dir_path) actor = OPTActor(config=config) critic = OPTCritic(config=config) reward_model = OPTRM(config=config) - tokenizer = AutoTokenizer.from_pretrained(dir_path) + AutoTokenizer.from_pretrained(dir_path) else: raise NotImplementedError(f"Model {model} not implemented") @@ -59,17 +59,12 @@ def test_init(model: str, dir_path: str): exit(0) repo_list = { - "gpt2": HFRepoFiles( - repo_id="gpt2", - files=["config.json", "tokenizer.json", "vocab.json", "merges.txt"] - ), + "gpt2": HFRepoFiles(repo_id="gpt2", files=["config.json", "tokenizer.json", "vocab.json", "merges.txt"]), "bloom": HFRepoFiles( - repo_id="bigscience/bloom-560m", - files=["config.json", "tokenizer.json", "tokenizer_config.json"] + repo_id="bigscience/bloom-560m", files=["config.json", "tokenizer.json", "tokenizer_config.json"] ), "opt": HFRepoFiles( - repo_id="facebook/opt-350m", - files=["config.json", "tokenizer_config.json", "vocab.json", 
"merges.txt"] + repo_id="facebook/opt-350m", files=["config.json", "tokenizer_config.json", "vocab.json", "merges.txt"] ), } diff --git a/applications/Chat/examples/generate_conversation_dataset.py b/applications/Chat/examples/generate_conversation_dataset.py index 8d2fbba955b8..7e03b2d54260 100644 --- a/applications/Chat/examples/generate_conversation_dataset.py +++ b/applications/Chat/examples/generate_conversation_dataset.py @@ -31,9 +31,11 @@ def generate_alpaca(): def generate_sharegpt(): # ShareGPT data requires less processing. conversation_dataset = [] - dataset = load_dataset("anon8231489123/ShareGPT_Vicuna_unfiltered", - data_files="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json", - split="train") + dataset = load_dataset( + "anon8231489123/ShareGPT_Vicuna_unfiltered", + data_files="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json", + split="train", + ) conversations = dataset["conversations"] @@ -43,23 +45,24 @@ def generate_sharegpt(): del conv["markdown"] del conv["text"] - conversation = dict(type="conversation", - language="Multilingual", - dataset="ShareGPT", - conversations=conversations[idx]) + conversation = dict( + type="conversation", language="Multilingual", dataset="ShareGPT", conversations=conversations[idx] + ) conversation_dataset.append(conversation) return conversation_dataset -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--dataset', - type=str, - default="All", - choices=["Alpaca", "ShareGPT", "All"], - help="which dataset to convert, All will combine Alpaca and ShareGPT") - parser.add_argument('--save_path', type=str, default="dataset.json", help="path to save the converted dataset") + parser.add_argument( + "--dataset", + type=str, + default="All", + choices=["Alpaca", "ShareGPT", "All"], + help="which dataset to convert, All will combine Alpaca and ShareGPT", + ) + parser.add_argument("--save_path", type=str, default="dataset.json", help="path to save the converted dataset") args = parser.parse_args() conversation_dataset = [] @@ -75,5 +78,5 @@ def generate_sharegpt(): for idx, sample in enumerate(conversation_dataset): sample["id"] = idx + 1 - with open(args.save_path, mode='w') as f: + with open(args.save_path, mode="w") as f: json.dump(conversation_dataset, f, indent=4, default=str, ensure_ascii=False) diff --git a/applications/Chat/examples/generate_prompt_dataset.py b/applications/Chat/examples/generate_prompt_dataset.py index 2abb31c09f82..4eec6feae505 100644 --- a/applications/Chat/examples/generate_prompt_dataset.py +++ b/applications/Chat/examples/generate_prompt_dataset.py @@ -6,7 +6,7 @@ def sample(args): - with open(args.dataset_path, mode='r') as f: + with open(args.dataset_path, mode="r") as f: dataset_list = json.load(f) sampled_dataset = [ @@ -14,18 +14,14 @@ def sample(args): for idx, sample in enumerate(random.sample(dataset_list, args.sample_size)) ] - with open(args.save_path, mode='w') as f: - json.dump(sampled_dataset, f, indent=4, - default=str, ensure_ascii=False) + with open(args.save_path, mode="w") as f: + json.dump(sampled_dataset, f, indent=4, default=str, ensure_ascii=False) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--dataset_path', type=str, default=None, - required=True, help="path to the pretrain dataset") - parser.add_argument('--save_path', type=str, default='prompt.json', - help="path to save the prompt dataset") - parser.add_argument('--sample_size', type=int, - 
default=16384, help="size of the prompt dataset") + parser.add_argument("--dataset_path", type=str, default=None, required=True, help="path to the pretrain dataset") + parser.add_argument("--save_path", type=str, default="prompt.json", help="path to save the prompt dataset") + parser.add_argument("--sample_size", type=int, default=16384, help="size of the prompt dataset") args = parser.parse_args() sample(args) diff --git a/applications/Chat/examples/inference.py b/applications/Chat/examples/inference.py index e1e57e3cd376..087c49564e43 100644 --- a/applications/Chat/examples/inference.py +++ b/applications/Chat/examples/inference.py @@ -11,13 +11,13 @@ def eval(args): # configure model - if args.model == 'gpt2': + if args.model == "gpt2": actor = GPTActor(pretrained=args.pretrain) - elif args.model == 'bloom': + elif args.model == "bloom": actor = BLOOMActor(pretrained=args.pretrain) - elif args.model == 'opt': + elif args.model == "opt": actor = OPTActor(pretrained=args.pretrain) - elif args.model == 'llama': + elif args.model == "llama": actor = LlamaActor(pretrained=args.pretrain) else: raise ValueError(f'Unsupported model "{args.model}"') @@ -28,45 +28,38 @@ def eval(args): actor.load_state_dict(state_dict) # configure tokenizer - if args.model == 'gpt2': - tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + if args.model == "gpt2": + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'bloom': - tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m') + elif args.model == "bloom": + tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m") tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'opt': + elif args.model == "opt": tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'llama': + elif args.model == "llama": tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") - tokenizer.eos_token = '<\s>' + tokenizer.eos_token = "<\s>" tokenizer.pad_token = tokenizer.unk_token else: raise ValueError(f'Unsupported model "{args.model}"') actor.eval() - input_ids = tokenizer.encode(args.input, - return_tensors='pt')\ - .to(torch.cuda.current_device()) - outputs = generate(actor, - input_ids, - max_length=args.max_length, - do_sample=True, - top_k=50, - top_p=0.95, - num_return_sequences=1) - output = tokenizer.batch_decode(outputs[0], - skip_special_tokens=True) + input_ids = tokenizer.encode(args.input, return_tensors="pt").to(torch.cuda.current_device()) + outputs = generate( + actor, input_ids, max_length=args.max_length, do_sample=True, top_k=50, top_p=0.95, num_return_sequences=1 + ) + output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True) print(f"[Output]: {''.join(output)}") -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) + parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) # We suggest to use the pretrained model from HuggingFace, use pretrain to configure model - parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--model_path', type=str, default=None) - parser.add_argument('--input', type=str, default='Question: How are you ? 
Answer:') - parser.add_argument('--max_length', type=int, default=100) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--model_path", type=str, default=None) + parser.add_argument("--input", type=str, default="Question: How are you ? Answer:") + parser.add_argument("--max_length", type=int, default=100) args = parser.parse_args() eval(args) diff --git a/applications/Chat/examples/ray/1mmt_prompt.py b/applications/Chat/examples/ray/1mmt_prompt.py index 5dd52f1790e6..8de6219ec4e9 100644 --- a/applications/Chat/examples/ray/1mmt_prompt.py +++ b/applications/Chat/examples/ray/1mmt_prompt.py @@ -5,7 +5,6 @@ import pandas as pd import ray -import torch from coati.quant import llama_load_quant, low_resource_init from coati.ray.detached_trainer_ppo import DetachedPPOTrainer from coati.ray.experience_maker_holder import ExperienceMakerHolder @@ -23,13 +22,13 @@ def get_free_port(): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(('', 0)) + s.bind(("", 0)) return s.getsockname()[1] def get_local_ip(): with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: - s.connect(('8.8.8.8', 80)) + s.connect(("8.8.8.8", 80)) return s.getsockname()[0] @@ -37,22 +36,25 @@ def main(args): master_addr = str(get_local_ip()) # trainer_env_info trainer_port = str(get_free_port()) - env_info_trainers = [{ - 'local_rank': '0', - 'rank': str(rank), - 'world_size': str(args.num_trainers), - 'master_port': trainer_port, - 'master_addr': master_addr - } for rank in range(args.num_trainers)] + env_info_trainers = [ + { + "local_rank": "0", + "rank": str(rank), + "world_size": str(args.num_trainers), + "master_port": trainer_port, + "master_addr": master_addr, + } + for rank in range(args.num_trainers) + ] # maker_env_info maker_port = str(get_free_port()) env_info_maker = { - 'local_rank': '0', - 'rank': '0', - 'world_size': '1', - 'master_port': maker_port, - 'master_addr': master_addr + "local_rank": "0", + "rank": "0", + "world_size": "1", + "master_port": maker_port, + "master_addr": master_addr, } # configure tokenizer @@ -75,27 +77,33 @@ def trainer_model_fn(): eval_performance=True, debug=args.debug, update_lora_weights=not (args.lora_rank == 0), - ) for i, env_info_trainer in enumerate(env_info_trainers) + ) + for i, env_info_trainer in enumerate(env_info_trainers) ] def model_fn(): actor = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda() critic = get_critic_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda() reward_model = get_reward_model_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda() - if args.initial_model_quant_ckpt is not None and args.model == 'llama': + if args.initial_model_quant_ckpt is not None and args.model == "llama": # quantize initial model actor_cfg = AutoConfig.from_pretrained(args.pretrain) with low_resource_init(), no_init_weights(): initial_model = get_actor_from_args(args.model, config=actor_cfg) - initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, - args.quant_group_size).cuda().requires_grad_(False) + initial_model.model = ( + llama_load_quant( + initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size + ) + .cuda() + .requires_grad_(False) + ) else: initial_model = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda() return actor, critic, reward_model, initial_model # configure Experience Maker experience_holder_ref = 
ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote( - detached_trainer_name_list=[f'trainer{i}' for i in range(args.num_trainers)], + detached_trainer_name_list=[f"trainer{i}" for i in range(args.num_trainers)], strategy_fn=partial(get_strategy_from_args, args.maker_strategy), model_fn=model_fn, env_info=env_info_maker, @@ -130,12 +138,11 @@ def model_fn(): dataset_size = args.experience_batch_size * 4 def build_dataloader(): - def tokenize_fn(texts): - batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True) + batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True) return {k: v.cuda() for k, v in batch.items()} - dataset = pd.read_csv(args.prompt_path)['prompt'] + dataset = pd.read_csv(args.prompt_path)["prompt"] dataloader = DataLoader(dataset=dataset, batch_size=dataset_size, shuffle=True, collate_fn=tokenize_fn) return dataloader @@ -144,32 +151,31 @@ def tokenize_fn(texts): ray.get(wait_tasks) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--prompt_path', type=str, default=None) - parser.add_argument('--num_trainers', type=int, default=1) - parser.add_argument('--trainer_strategy', - choices=[ - 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu', - 'colossalai_zero2_cpu' - ], - default='ddp') - parser.add_argument('--maker_strategy', choices=['naive'], default='naive') - parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) - parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) - parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--critic_pretrain', type=str, default=None) - parser.add_argument('--experience_steps', type=int, default=4) - parser.add_argument('--experience_batch_size', type=int, default=8) - parser.add_argument('--train_epochs', type=int, default=1) - parser.add_argument('--update_steps', type=int, default=2) - parser.add_argument('--train_batch_size', type=int, default=8) - parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") - - parser.add_argument('--initial_model_quant_ckpt', type=str, default=None) - parser.add_argument('--quant_bits', type=int, default=4) - parser.add_argument('--quant_group_size', type=int, default=128) - parser.add_argument('--debug', action='store_true') + parser.add_argument("--prompt_path", type=str, default=None) + parser.add_argument("--num_trainers", type=int, default=1) + parser.add_argument( + "--trainer_strategy", + choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"], + default="ddp", + ) + parser.add_argument("--maker_strategy", choices=["naive"], default="naive") + parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--critic_pretrain", type=str, default=None) + parser.add_argument("--experience_steps", type=int, default=4) + parser.add_argument("--experience_batch_size", type=int, default=8) + parser.add_argument("--train_epochs", type=int, default=1) + parser.add_argument("--update_steps", type=int, default=2) + parser.add_argument("--train_batch_size", type=int, default=8) + 
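The ray scripts above hand each trainer and maker an `env_info` dict carrying `rank`, `world_size`, `master_addr`, and `master_port`. The sketch below shows one way such a dict could be consumed to initialize a `torch.distributed` process group; the helper name is hypothetical and not part of the patched code.

```python
# Hypothetical helper, not part of the patch: turns an env_info dict into a process group.
import os

import torch.distributed as dist


def init_process_group_from_env_info(env_info: dict, backend: str = "gloo") -> None:
    # the default env:// init method reads the master address and port from the environment
    os.environ["MASTER_ADDR"] = env_info["master_addr"]
    os.environ["MASTER_PORT"] = env_info["master_port"]
    dist.init_process_group(
        backend=backend,
        rank=int(env_info["rank"]),
        world_size=int(env_info["world_size"]),
    )


if __name__ == "__main__":
    env_info = {
        "local_rank": "0",
        "rank": "0",
        "world_size": "1",
        "master_port": "29500",
        "master_addr": "127.0.0.1",
    }
    init_process_group_from_env_info(env_info)
    print(dist.get_rank(), dist.get_world_size())
    dist.destroy_process_group()
```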
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + + parser.add_argument("--initial_model_quant_ckpt", type=str, default=None) + parser.add_argument("--quant_bits", type=int, default=4) + parser.add_argument("--quant_group_size", type=int, default=128) + parser.add_argument("--debug", action="store_true") args = parser.parse_args() ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)}) main(args) diff --git a/applications/Chat/examples/ray/mmmt_prompt.py b/applications/Chat/examples/ray/mmmt_prompt.py index 76929c9d0144..7c03a0468b02 100644 --- a/applications/Chat/examples/ray/mmmt_prompt.py +++ b/applications/Chat/examples/ray/mmmt_prompt.py @@ -5,7 +5,6 @@ import pandas as pd import ray -import torch from coati.quant import llama_load_quant, low_resource_init from coati.ray.detached_trainer_ppo import DetachedPPOTrainer from coati.ray.experience_maker_holder import ExperienceMakerHolder @@ -23,13 +22,13 @@ def get_free_port(): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(('', 0)) + s.bind(("", 0)) return s.getsockname()[1] def get_local_ip(): with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: - s.connect(('8.8.8.8', 80)) + s.connect(("8.8.8.8", 80)) return s.getsockname()[0] @@ -37,23 +36,29 @@ def main(args): master_addr = str(get_local_ip()) # trainer_env_info trainer_port = str(get_free_port()) - env_info_trainers = [{ - 'local_rank': '0', - 'rank': str(rank), - 'world_size': str(args.num_trainers), - 'master_port': trainer_port, - 'master_addr': master_addr - } for rank in range(args.num_trainers)] + env_info_trainers = [ + { + "local_rank": "0", + "rank": str(rank), + "world_size": str(args.num_trainers), + "master_port": trainer_port, + "master_addr": master_addr, + } + for rank in range(args.num_trainers) + ] # maker_env_info maker_port = str(get_free_port()) - env_info_makers = [{ - 'local_rank': '0', - 'rank': str(rank), - 'world_size': str(args.num_makers), - 'master_port': maker_port, - 'master_addr': master_addr - } for rank in range(args.num_makers)] + env_info_makers = [ + { + "local_rank": "0", + "rank": str(rank), + "world_size": str(args.num_makers), + "master_port": maker_port, + "master_addr": master_addr, + } + for rank in range(args.num_makers) + ] # configure tokenizer tokenizer = AutoTokenizer.from_pretrained(args.pretrain) @@ -63,13 +68,18 @@ def model_fn(): actor = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda() critic = get_critic_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda() reward_model = get_reward_model_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda() - if args.initial_model_quant_ckpt is not None and args.model == 'llama': + if args.initial_model_quant_ckpt is not None and args.model == "llama": # quantize initial model actor_cfg = AutoConfig.from_pretrained(args.pretrain) with low_resource_init(), no_init_weights(): initial_model = get_actor_from_args(args.model, config=actor_cfg) - initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, - args.quant_group_size).cuda().requires_grad_(False) + initial_model.model = ( + llama_load_quant( + initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size + ) + .cuda() + .requires_grad_(False) + ) else: initial_model = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda() return actor, 
critic, reward_model, initial_model @@ -78,7 +88,7 @@ def model_fn(): experience_holder_refs = [ ExperienceMakerHolder.options(name=f"maker{i}", num_gpus=1, max_concurrency=2).remote( detached_trainer_name_list=[ - f'trainer{x}' + f"trainer{x}" for x in get_receivers_per_sender(i, args.num_makers, args.num_trainers, allow_idle_sender=False) ], strategy_fn=partial(get_strategy_from_args, args.maker_strategy), @@ -87,8 +97,8 @@ def model_fn(): kl_coef=0.1, debug=args.debug, update_lora_weights=not (args.lora_rank == 0), - # sync_models_from_trainers=True, - # generation kwargs: + # sync_models_from_trainers=True, + # generation kwargs: max_length=512, do_sample=True, temperature=1.0, @@ -128,12 +138,11 @@ def trainer_model_fn(): dataset_size = args.experience_batch_size * 4 def build_dataloader(): - def tokenize_fn(texts): - batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True) + batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True) return {k: v.cuda() for k, v in batch.items()} - dataset = pd.read_csv(args.prompt_path)['prompt'] + dataset = pd.read_csv(args.prompt_path)["prompt"] dataloader = DataLoader(dataset=dataset, batch_size=dataset_size, shuffle=True, collate_fn=tokenize_fn) return dataloader @@ -148,39 +157,44 @@ def tokenize_fn(texts): for experience_holder_ref in experience_holder_refs: wait_tasks.append(experience_holder_ref.workingloop.remote(build_dataloader, num_steps=args.experience_steps)) - total_steps = args.experience_batch_size * args.experience_steps * \ - args.num_makers // (args.num_trainers * args.train_batch_size) + total_steps = ( + args.experience_batch_size + * args.experience_steps + * args.num_makers + // (args.num_trainers * args.train_batch_size) + ) for trainer_ref in trainer_refs: wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs)) ray.get(wait_tasks) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--prompt_path', type=str, default=None) - parser.add_argument('--num_makers', type=int, default=1) - parser.add_argument('--num_trainers', type=int, default=1) + parser.add_argument("--prompt_path", type=str, default=None) + parser.add_argument("--num_makers", type=int, default=1) + parser.add_argument("--num_trainers", type=int, default=1) parser.add_argument( - '--trainer_strategy', - choices=['ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu', 'colossalai_zero2_cpu'], - default='ddp') - parser.add_argument('--maker_strategy', choices=['naive'], default='naive') - parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) - parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) - parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--critic_pretrain', type=str, default=None) - parser.add_argument('--experience_steps', type=int, default=4) - parser.add_argument('--experience_batch_size', type=int, default=8) - parser.add_argument('--train_epochs', type=int, default=1) - parser.add_argument('--update_steps', type=int, default=2) - parser.add_argument('--train_batch_size', type=int, default=8) - parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") - - parser.add_argument('--initial_model_quant_ckpt', type=str, default=None) - parser.add_argument('--quant_bits', type=int, default=4) - 
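The `total_steps` expression above fixes how many optimizer steps each trainer runs: every maker contributes `experience_steps` batches of `experience_batch_size` samples, and that pool is consumed across trainers in chunks of `train_batch_size`. A small worked example with the parser defaults:

```python
# Illustration of the total_steps formula used in mmmt_prompt.py above.
def total_trainer_steps(
    experience_batch_size: int,
    experience_steps: int,
    num_makers: int,
    num_trainers: int,
    train_batch_size: int,
) -> int:
    total_samples = experience_batch_size * experience_steps * num_makers
    return total_samples // (num_trainers * train_batch_size)


if __name__ == "__main__":
    # with the defaults from the argument parser: 8 * 4 * 1 // (1 * 8) = 4
    print(total_trainer_steps(8, 4, 1, 1, 8))
```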
parser.add_argument('--quant_group_size', type=int, default=128) - parser.add_argument('--debug', action='store_true') + "--trainer_strategy", + choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"], + default="ddp", + ) + parser.add_argument("--maker_strategy", choices=["naive"], default="naive") + parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--critic_pretrain", type=str, default=None) + parser.add_argument("--experience_steps", type=int, default=4) + parser.add_argument("--experience_batch_size", type=int, default=8) + parser.add_argument("--train_epochs", type=int, default=1) + parser.add_argument("--update_steps", type=int, default=2) + parser.add_argument("--train_batch_size", type=int, default=8) + parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + + parser.add_argument("--initial_model_quant_ckpt", type=str, default=None) + parser.add_argument("--quant_bits", type=int, default=4) + parser.add_argument("--quant_group_size", type=int, default=128) + parser.add_argument("--debug", action="store_true") args = parser.parse_args() ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)}) diff --git a/applications/Chat/examples/requirements.txt b/applications/Chat/examples/requirements.txt index 5d0f9f927d17..d3ea7b0c8142 100644 --- a/applications/Chat/examples/requirements.txt +++ b/applications/Chat/examples/requirements.txt @@ -1,3 +1,3 @@ pandas>=1.4.1 sentencepiece -colossalai==0.3.1 \ No newline at end of file +colossalai==0.3.1 diff --git a/applications/Chat/examples/train_prompts.py b/applications/Chat/examples/train_prompts.py index d27a70a3fef6..ad688b07a7f2 100644 --- a/applications/Chat/examples/train_prompts.py +++ b/applications/Chat/examples/train_prompts.py @@ -20,28 +20,28 @@ def main(args): # configure strategy - if args.strategy == 'ddp': + if args.strategy == "ddp": strategy = DDPStrategy() - elif args.strategy == 'colossalai_gemini': - strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5) - elif args.strategy == 'colossalai_zero2': - strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda') + elif args.strategy == "colossalai_gemini": + strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5) + elif args.strategy == "colossalai_zero2": + strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") else: raise ValueError(f'Unsupported strategy "{args.strategy}"') if args.rm_path is not None: - warnings.warn('LoRA weights should be merged with the model weights') - state_dict = torch.load(args.rm_path, map_location='cpu') + warnings.warn("LoRA weights should be merged with the model weights") + state_dict = torch.load(args.rm_path, map_location="cpu") with strategy.model_init_context(): # configure model - if args.model == 'gpt2': + if args.model == "gpt2": initial_model = GPTActor(pretrained=args.pretrain) - elif args.model == 'bloom': + elif args.model == "bloom": initial_model = BLOOMActor(pretrained=args.pretrain) - elif args.model == 'opt': + elif args.model == "opt": initial_model = OPTActor(pretrained=args.pretrain) - elif args.model == 'llama': + elif args.model == "llama": initial_model = LlamaActor(pretrained=args.pretrain) else: raise ValueError(f'Unsupported 
actor model "{args.model}"') @@ -51,13 +51,13 @@ def main(args): else: rm_model_name = args.rm_model - if rm_model_name == 'gpt2': + if rm_model_name == "gpt2": reward_model = GPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) - elif rm_model_name == 'bloom': + elif rm_model_name == "bloom": reward_model = BLOOMRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) - elif rm_model_name == 'opt': + elif rm_model_name == "opt": reward_model = OPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) - elif rm_model_name == 'llama': + elif rm_model_name == "llama": reward_model = LlamaRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) else: raise ValueError(f'Unsupported reward model "{rm_model_name}"') @@ -68,24 +68,24 @@ def main(args): initial_model.to(torch.float16).to(torch.cuda.current_device()) reward_model.to(torch.float16).to(torch.cuda.current_device()) - if args.model == 'gpt2': + if args.model == "gpt2": actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank) - elif args.model == 'bloom': + elif args.model == "bloom": actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank) - elif args.model == 'opt': + elif args.model == "opt": actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank) - elif args.model == 'llama': + elif args.model == "llama": actor = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank) else: raise ValueError(f'Unsupported actor model "{args.model}"') - if rm_model_name == 'gpt2': + if rm_model_name == "gpt2": critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) - elif rm_model_name == 'bloom': + elif rm_model_name == "bloom": critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) - elif rm_model_name == 'opt': + elif rm_model_name == "opt": critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) - elif rm_model_name == 'llama': + elif rm_model_name == "llama": critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) else: raise ValueError(f'Unsupported reward model "{rm_model_name}"') @@ -94,12 +94,12 @@ def main(args): critic.load_state_dict(state_dict, strict=False) del state_dict - if args.strategy != 'colossalai_gemini': + if args.strategy != "colossalai_gemini": critic.to(torch.float16).to(torch.cuda.current_device()) actor.to(torch.float16).to(torch.cuda.current_device()) # configure optimizer - if args.strategy.startswith('colossalai'): + if args.strategy.startswith("colossalai"): actor_optim = HybridAdam(actor.parameters(), lr=1e-7) critic_optim = HybridAdam(critic.parameters(), lr=1e-7) else: @@ -107,22 +107,22 @@ def main(args): critic_optim = Adam(critic.parameters(), lr=1e-7) # configure tokenizer - if args.model == 'gpt2': - tokenizer = GPT2Tokenizer.from_pretrained( - 'gpt2' if args.tokenizer is None else args.tokenizer) + if args.model == "gpt2": + tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer) tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'bloom': + elif args.model == "bloom": tokenizer = BloomTokenizerFast.from_pretrained( - 'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer) + "bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer + ) tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'opt': - tokenizer = AutoTokenizer.from_pretrained( - "facebook/opt-350m" if args.tokenizer is None else 
args.tokenizer) + elif args.model == "opt": + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer) tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'llama': + elif args.model == "llama": tokenizer = LlamaTokenizer.from_pretrained( - "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer) - tokenizer.eos_token = '<\s>' + "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer + ) + tokenizer.eos_token = "<\s>" tokenizer.pad_token = tokenizer.unk_token else: raise ValueError(f'Unsupported model "{args.model}"') @@ -132,27 +132,25 @@ def main(args): prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True) else: prompt_sampler = None - prompt_dataloader = DataLoader(prompt_dataset, - shuffle=(prompt_sampler is None), - sampler=prompt_sampler, - batch_size=args.experience_batch_size) - - pretrain_dataset = SupervisedDataset(tokenizer=tokenizer, - data_path=args.pretrain_dataset, - max_datasets_size=16384, - max_length=args.max_input_len) + prompt_dataloader = DataLoader( + prompt_dataset, shuffle=(prompt_sampler is None), sampler=prompt_sampler, batch_size=args.experience_batch_size + ) + + pretrain_dataset = SupervisedDataset( + tokenizer=tokenizer, data_path=args.pretrain_dataset, max_datasets_size=16384, max_length=args.max_input_len + ) if dist.is_initialized() and dist.get_world_size() > 1: pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True) else: pretrain_sampler = None - pretrain_dataloader = DataLoader(pretrain_dataset, - shuffle=(pretrain_sampler is None), - sampler=pretrain_sampler, - batch_size=args.ptx_batch_size) + pretrain_dataloader = DataLoader( + pretrain_dataset, shuffle=(pretrain_sampler is None), sampler=pretrain_sampler, batch_size=args.ptx_batch_size + ) # NOTE: For small models like opt-1.3b, reward model and initial model are not required to be parallelized. 
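To illustrate the NOTE above: the reward and initial models are only ever used for inference, so they can be handled with something as simple as freezing them and keeping an fp16 copy on the local GPU, rather than sharding them. A minimal sketch follows; the helper is hypothetical and not coati API.

```python
# Hypothetical helper illustrating the NOTE above; not part of the patch.
import torch
import torch.nn as nn


def prepare_inference_only(model: nn.Module) -> nn.Module:
    model.requires_grad_(False)  # no gradients are ever needed
    model.eval()                 # disable dropout etc.
    if torch.cuda.is_available():
        model = model.to(torch.float16).to(torch.cuda.current_device())
    return model


if __name__ == "__main__":
    print(prepare_inference_only(nn.Linear(4, 4)))
```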
- (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = \ - strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model) + (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare( + (actor, actor_optim), (critic, critic_optim), reward_model, initial_model + ) # configure trainer trainer = PPOTrainer( @@ -173,50 +171,54 @@ def main(args): top_k=50, pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id, - offload_inference_models=args.strategy != 'colossalai_gemini' + offload_inference_models=args.strategy != "colossalai_gemini", ) - trainer.fit(prompt_dataloader=prompt_dataloader, - pretrain_dataloader=pretrain_dataloader, - num_episodes=args.num_episodes, - num_collect_steps=args.num_collect_steps, - num_update_steps=args.num_update_steps) + trainer.fit( + prompt_dataloader=prompt_dataloader, + pretrain_dataloader=pretrain_dataloader, + num_episodes=args.num_episodes, + num_collect_steps=args.num_collect_steps, + num_update_steps=args.num_update_steps, + ) # save model checkpoint after fitting strategy.save_model(actor, args.save_path, only_rank0=True) # save optimizer checkpoint on all ranks if args.need_optim_ckpt: - strategy.save_optimizer(actor_optim, - 'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), - only_rank0=False) + strategy.save_optimizer( + actor_optim, "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()), only_rank0=False + ) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--prompt_dataset', type=str, default=None, help='path to the prompt dataset') - parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset') - parser.add_argument('--strategy', - choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], - default='colossalai_zero2', - help='strategy to use') - parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) - parser.add_argument('--tokenizer', type=str, default=None) - parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama']) - parser.add_argument('--rm_path', type=str, default=None) - parser.add_argument('--rm_pretrain', type=str, default=None) - parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts') - parser.add_argument('--need_optim_ckpt', type=bool, default=False) - parser.add_argument('--num_episodes', type=int, default=10) - parser.add_argument('--num_collect_steps', type=int, default=10) - parser.add_argument('--num_update_steps', type=int, default=5) - parser.add_argument('--train_batch_size', type=int, default=8) - parser.add_argument('--ptx_batch_size', type=int, default=1) - parser.add_argument('--experience_batch_size', type=int, default=8) - parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") - parser.add_argument('--kl_coef', type=float, default=0.1) - parser.add_argument('--ptx_coef', type=float, default=0.9) - parser.add_argument('--max_input_len', type=int, default=96) - parser.add_argument('--max_seq_len', type=int, default=128) + parser.add_argument("--prompt_dataset", type=str, default=None, help="path to the prompt dataset") + parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset") + parser.add_argument( + "--strategy", + choices=["ddp", 
"colossalai_gemini", "colossalai_zero2"], + default="colossalai_zero2", + help="strategy to use", + ) + parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--tokenizer", type=str, default=None) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--rm_model", default=None, choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--rm_path", type=str, default=None) + parser.add_argument("--rm_pretrain", type=str, default=None) + parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts") + parser.add_argument("--need_optim_ckpt", type=bool, default=False) + parser.add_argument("--num_episodes", type=int, default=10) + parser.add_argument("--num_collect_steps", type=int, default=10) + parser.add_argument("--num_update_steps", type=int, default=5) + parser.add_argument("--train_batch_size", type=int, default=8) + parser.add_argument("--ptx_batch_size", type=int, default=1) + parser.add_argument("--experience_batch_size", type=int, default=8) + parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument("--kl_coef", type=float, default=0.1) + parser.add_argument("--ptx_coef", type=float, default=0.9) + parser.add_argument("--max_input_len", type=int, default=96) + parser.add_argument("--max_seq_len", type=int, default=128) args = parser.parse_args() main(args) diff --git a/applications/Chat/examples/train_reward_model.py b/applications/Chat/examples/train_reward_model.py index 190460bc20f6..a07f4b5ca812 100644 --- a/applications/Chat/examples/train_reward_model.py +++ b/applications/Chat/examples/train_reward_model.py @@ -24,24 +24,24 @@ def train(args): # configure strategy - if args.strategy == 'ddp': + if args.strategy == "ddp": strategy = DDPStrategy() - elif args.strategy == 'colossalai_gemini': - strategy = GeminiStrategy(placement_policy='cuda') - elif args.strategy == 'colossalai_zero2': - strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda') + elif args.strategy == "colossalai_gemini": + strategy = GeminiStrategy(placement_policy="cuda") + elif args.strategy == "colossalai_zero2": + strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") else: raise ValueError(f'Unsupported strategy "{args.strategy}"') # configure model with strategy.model_init_context(): - if args.model == 'bloom': + if args.model == "bloom": model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank) - elif args.model == 'opt': + elif args.model == "opt": model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank) - elif args.model == 'gpt2': + elif args.model == "gpt2": model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank) - elif args.model == 'llama': + elif args.model == "llama": model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank) else: raise ValueError(f'Unsupported model "{args.model}"') @@ -53,36 +53,36 @@ def train(args): model.load_state_dict(state_dict) # configure tokenizer - if args.model == 'gpt2': - tokenizer = GPT2Tokenizer.from_pretrained( - 'gpt2' if args.tokenizer is None else args.tokenizer) + if args.model == "gpt2": + tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer) tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'bloom': + elif args.model == "bloom": tokenizer = BloomTokenizerFast.from_pretrained( - 'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer) + "bigscience/bloom-560m" 
if args.tokenizer is None else args.tokenizer + ) tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'opt': - tokenizer = AutoTokenizer.from_pretrained( - "facebook/opt-350m" if args.tokenizer is None else args.tokenizer) + elif args.model == "opt": + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer) tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'llama': + elif args.model == "llama": tokenizer = LlamaTokenizer.from_pretrained( - "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer) - tokenizer.eos_token = '<\s>' + "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer + ) + tokenizer.eos_token = "<\s>" tokenizer.pad_token = tokenizer.unk_token else: raise ValueError(f'Unsupported model "{args.model}"') # configure optimizer - if args.strategy.startswith('colossalai'): + if args.strategy.startswith("colossalai"): optim = HybridAdam(model.parameters(), lr=5e-6) else: optim = Adam(model.parameters(), lr=5e-6) # configure loss function - if args.loss_fn == 'log_sig': + if args.loss_fn == "log_sig": loss_fn = LogSigLoss() - elif args.loss_fn == 'log_exp': + elif args.loss_fn == "log_exp": loss_fn = LogExpLoss() else: raise ValueError(f'Unsupported loss function "{args.loss_fn}"') @@ -94,18 +94,18 @@ def train(args): data = load_dataset(args.dataset) if args.test: - train_data = data['train'].select(range(20)) - eval_data = data['test'].select(range(5)) + train_data = data["train"].select(range(20)) + eval_data = data["test"].select(range(5)) else: - train_data = data['train'] - eval_data = data['test'] - valid_data = data['test'].select((randint(0, len(eval_data) - 1) for _ in range(len(eval_data) // 5))) + train_data = data["train"] + eval_data = data["test"] + valid_data = data["test"].select((randint(0, len(eval_data) - 1) for _ in range(len(eval_data) // 5))) - if args.dataset == 'Dahoas/rm-static': + if args.dataset == "Dahoas/rm-static": train_dataset = RmStaticDataset(train_data, tokenizer, args.max_len) valid_dataset = RmStaticDataset(valid_data, tokenizer, args.max_len) eval_dataset = RmStaticDataset(eval_data, tokenizer, args.max_len) - elif args.dataset == 'Anthropic/hh-rlhf': + elif args.dataset == "Anthropic/hh-rlhf": train_dataset = HhRlhfDataset(train_data, tokenizer, args.max_len) valid_dataset = HhRlhfDataset(valid_data, tokenizer, args.max_len) eval_dataset = HhRlhfDataset(eval_data, tokenizer, args.max_len) @@ -113,90 +113,99 @@ def train(args): raise ValueError(f'Unsupported dataset "{args.dataset}"') if dist.is_initialized() and dist.get_world_size() > 1: - train_sampler = DistributedSampler(train_dataset, - shuffle=True, - seed=42, - drop_last=True, - rank=dist.get_rank(), - num_replicas=dist.get_world_size()) - valid_sampler = DistributedSampler(valid_dataset, - shuffle=True, - seed=42, - drop_last=True, - rank=dist.get_rank(), - num_replicas=dist.get_world_size()) - eval_sampler = DistributedSampler(eval_dataset, - shuffle=True, - seed=42, - drop_last=True, - rank=dist.get_rank(), - num_replicas=dist.get_world_size()) + train_sampler = DistributedSampler( + train_dataset, + shuffle=True, + seed=42, + drop_last=True, + rank=dist.get_rank(), + num_replicas=dist.get_world_size(), + ) + valid_sampler = DistributedSampler( + valid_dataset, + shuffle=True, + seed=42, + drop_last=True, + rank=dist.get_rank(), + num_replicas=dist.get_world_size(), + ) + eval_sampler = DistributedSampler( + eval_dataset, + shuffle=True, + seed=42, + 
drop_last=True, + rank=dist.get_rank(), + num_replicas=dist.get_world_size(), + ) else: train_sampler = None valid_sampler = None eval_sampler = None - train_dataloader = DataLoader(train_dataset, - shuffle=(train_sampler is None), - sampler=train_sampler, - batch_size=args.batch_size, - pin_memory=True) - - valid_dataloader = DataLoader(valid_dataset, - shuffle=(valid_sampler is None), - sampler=valid_sampler, - batch_size=args.batch_size, - pin_memory=True) - - eval_dataloader = DataLoader(eval_dataset, - shuffle=(eval_sampler is None), - sampler=eval_sampler, - batch_size=args.batch_size, - pin_memory=True) + train_dataloader = DataLoader( + train_dataset, + shuffle=(train_sampler is None), + sampler=train_sampler, + batch_size=args.batch_size, + pin_memory=True, + ) + + valid_dataloader = DataLoader( + valid_dataset, + shuffle=(valid_sampler is None), + sampler=valid_sampler, + batch_size=args.batch_size, + pin_memory=True, + ) + + eval_dataloader = DataLoader( + eval_dataset, shuffle=(eval_sampler is None), sampler=eval_sampler, batch_size=args.batch_size, pin_memory=True + ) lr_scheduler = CosineAnnealingLR(optim, train_dataloader.__len__() // 100) strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler)) - model = strategy_dict['model'] - optim = strategy_dict['optimizer'] - lr_scheduler = strategy_dict['lr_scheduler'] - trainer = RewardModelTrainer(model=model, - strategy=strategy, - optim=optim, - lr_scheduler=lr_scheduler, - loss_fn=loss_fn, - max_epochs=args.max_epochs) + model = strategy_dict["model"] + optim = strategy_dict["optimizer"] + lr_scheduler = strategy_dict["lr_scheduler"] + trainer = RewardModelTrainer( + model=model, + strategy=strategy, + optim=optim, + lr_scheduler=lr_scheduler, + loss_fn=loss_fn, + max_epochs=args.max_epochs, + ) trainer.fit(train_dataloader=train_dataloader, valid_dataloader=valid_dataloader, eval_dataloader=eval_dataloader) # save model checkpoint after fitting on only rank0 strategy.save_model(model, args.save_path, only_rank0=True) # save optimizer checkpoint on all ranks if args.need_optim_ckpt: - strategy.save_optimizer(trainer.optimizer, - 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), - only_rank0=False) + strategy.save_optimizer( + trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False + ) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--strategy', - choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], - default='colossalai_zero2') - parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom') - parser.add_argument('--tokenizer', type=str, default=None) - parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--model_path', type=str, default=None) - parser.add_argument('--need_optim_ckpt', type=bool, default=False) - parser.add_argument('--dataset', - type=str, - choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'], - default='Dahoas/rm-static') - parser.add_argument('--subset', type=lambda x: None if x == 'None' else x, default=None) - parser.add_argument('--save_path', type=str, default='rm_ckpt') - parser.add_argument('--max_epochs', type=int, default=1) - parser.add_argument('--batch_size', type=int, default=1) - parser.add_argument('--max_len', type=int, default=512) - parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") - parser.add_argument('--loss_fn', type=str, 
default='log_sig', choices=['log_sig', 'log_exp']) - parser.add_argument('--test', type=bool, default=False) + parser.add_argument( + "--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="colossalai_zero2" + ) + parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama"], default="bloom") + parser.add_argument("--tokenizer", type=str, default=None) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--model_path", type=str, default=None) + parser.add_argument("--need_optim_ckpt", type=bool, default=False) + parser.add_argument( + "--dataset", type=str, choices=["Anthropic/hh-rlhf", "Dahoas/rm-static"], default="Dahoas/rm-static" + ) + parser.add_argument("--subset", type=lambda x: None if x == "None" else x, default=None) + parser.add_argument("--save_path", type=str, default="rm_ckpt") + parser.add_argument("--max_epochs", type=int, default=1) + parser.add_argument("--batch_size", type=int, default=1) + parser.add_argument("--max_len", type=int, default=512) + parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"]) + parser.add_argument("--test", type=bool, default=False) args = parser.parse_args() train(args) diff --git a/applications/Chat/examples/train_sft.py b/applications/Chat/examples/train_sft.py index f068ea2bf5de..1729abb86a09 100644 --- a/applications/Chat/examples/train_sft.py +++ b/applications/Chat/examples/train_sft.py @@ -6,18 +6,18 @@ import torch.distributed as dist from coati.dataset import SFTDataset, SupervisedDataset from coati.models.bloom import BLOOMActor +from coati.models.chatglm import ChatGLMActor +from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer from coati.models.gpt import GPTActor from coati.models.llama import LlamaActor from coati.models.opt import OPTActor -from coati.models.chatglm import ChatGLMActor from coati.trainer import SFTTrainer from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy from datasets import load_dataset from torch.optim import Adam from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, AutoModel -from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer +from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer from transformers.trainer import get_scheduler @@ -28,14 +28,14 @@ def train(args): # configure strategy - if args.strategy == 'ddp': + if args.strategy == "ddp": strategy = DDPStrategy() - elif args.strategy == 'colossalai_gemini': - strategy = GeminiStrategy(placement_policy='cuda') - elif args.strategy == 'colossalai_zero2': - strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda') - elif args.strategy == 'colossalai_zero2_cpu': - strategy = LowLevelZeroStrategy(stage=2, placement_policy='cpu') + elif args.strategy == "colossalai_gemini": + strategy = GeminiStrategy(placement_policy="cuda") + elif args.strategy == "colossalai_zero2": + strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") + elif args.strategy == "colossalai_zero2_cpu": + strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu") else: raise ValueError(f'Unsupported strategy "{args.strategy}"') @@ -44,23 +44,15 @@ def train(args): warnings.warn("Gradient 
checkpoint is disabled when using LoRA") args.grad_checkpoint = False with strategy.model_init_context(): - if args.model == 'bloom': - model = BLOOMActor(pretrained=args.pretrain, - lora_rank=args.lora_rank, - checkpoint=args.grad_checkpoint) - elif args.model == 'opt': - model = OPTActor(pretrained=args.pretrain, - lora_rank=args.lora_rank, - checkpoint=args.grad_checkpoint) - elif args.model == 'gpt2': - model = GPTActor(pretrained=args.pretrain, - lora_rank=args.lora_rank, - checkpoint=args.grad_checkpoint) - elif args.model == 'llama': - model = LlamaActor(pretrained=args.pretrain, - lora_rank=args.lora_rank, - checkpoint=args.grad_checkpoint) - elif args.model == 'chatglm': + if args.model == "bloom": + model = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint) + elif args.model == "opt": + model = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint) + elif args.model == "gpt2": + model = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint) + elif args.model == "llama": + model = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint) + elif args.model == "chatglm": model = ChatGLMActor(pretrained=args.pretrain) else: raise ValueError(f'Unsupported model "{args.model}"') @@ -68,144 +60,157 @@ def train(args): model.to(torch.float16).to(torch.cuda.current_device()) # configure tokenizer - if args.model == 'gpt2': - tokenizer = GPT2Tokenizer.from_pretrained( - 'gpt2' if args.tokenizer is None else args.tokenizer) + if args.model == "gpt2": + tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer) tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'bloom': + elif args.model == "bloom": tokenizer = BloomTokenizerFast.from_pretrained( - 'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer) + "bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer + ) tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'opt': - tokenizer = AutoTokenizer.from_pretrained( - "facebook/opt-350m" if args.tokenizer is None else args.tokenizer) + elif args.model == "opt": + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer) tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'llama': + elif args.model == "llama": tokenizer = LlamaTokenizer.from_pretrained( - "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer) - tokenizer.eos_token = '<\s>' + "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer + ) + tokenizer.eos_token = "<\s>" tokenizer.pad_token = tokenizer.unk_token - elif args.model == 'chatglm': + elif args.model == "chatglm": tokenizer = ChatGLMTokenizer.from_pretrained( - "THUDM/chatglm-6b" if args.tokenizer is None else args.tokenizer, trust_remote_code=True) + "THUDM/chatglm-6b" if args.tokenizer is None else args.tokenizer, trust_remote_code=True + ) else: raise ValueError(f'Unsupported model "{args.model}"') - if args.model == 'llama' and args.strategy == 'colossalai_gemini': + if args.model == "llama" and args.strategy == "colossalai_gemini": # this is a hack to deal with the resized embedding # to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatibility for name, param in model.named_parameters(): if not isinstance(param, ColoParameter): - sub_module_name = '.'.join(name.split('.')[:-1]) - 
weight_name = name.split('.')[-1] + sub_module_name = ".".join(name.split(".")[:-1]) + weight_name = name.split(".")[-1] sub_module = model.get_submodule(sub_module_name) setattr(sub_module, weight_name, ColoParameter(param)) # configure optimizer - if args.strategy.startswith('colossalai'): + if args.strategy.startswith("colossalai"): optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0) else: optim = Adam(model.parameters(), lr=args.lr) logger = get_dist_logger() # configure dataset - if args.dataset == 'yizhongw/self_instruct': - train_data = load_dataset(args.dataset, 'super_natural_instructions', split='train') - eval_data = load_dataset(args.dataset, 'super_natural_instructions', split='test') + if args.dataset == "yizhongw/self_instruct": + train_data = load_dataset(args.dataset, "super_natural_instructions", split="train") + eval_data = load_dataset(args.dataset, "super_natural_instructions", split="test") train_dataset = SFTDataset(train_data, tokenizer, args.max_len) eval_dataset = SFTDataset(eval_data, tokenizer, args.max_len) else: - train_dataset = SupervisedDataset(tokenizer=tokenizer, - data_path=args.dataset, - max_datasets_size=args.max_datasets_size, - max_length=args.max_len) + train_dataset = SupervisedDataset( + tokenizer=tokenizer, + data_path=args.dataset, + max_datasets_size=args.max_datasets_size, + max_length=args.max_len, + ) eval_dataset = None if dist.is_initialized() and dist.get_world_size() > 1: - train_sampler = DistributedSampler(train_dataset, - shuffle=True, - seed=42, - drop_last=True, - rank=dist.get_rank(), - num_replicas=dist.get_world_size()) + train_sampler = DistributedSampler( + train_dataset, + shuffle=True, + seed=42, + drop_last=True, + rank=dist.get_rank(), + num_replicas=dist.get_world_size(), + ) if eval_dataset is not None: - eval_sampler = DistributedSampler(eval_dataset, - shuffle=False, - seed=42, - drop_last=False, - rank=dist.get_rank(), - num_replicas=dist.get_world_size()) + eval_sampler = DistributedSampler( + eval_dataset, + shuffle=False, + seed=42, + drop_last=False, + rank=dist.get_rank(), + num_replicas=dist.get_world_size(), + ) else: train_sampler = None eval_sampler = None - train_dataloader = DataLoader(train_dataset, - shuffle=(train_sampler is None), - sampler=train_sampler, - batch_size=args.batch_size, - pin_memory=True) + train_dataloader = DataLoader( + train_dataset, + shuffle=(train_sampler is None), + sampler=train_sampler, + batch_size=args.batch_size, + pin_memory=True, + ) if eval_dataset is not None: - eval_dataloader = DataLoader(eval_dataset, - shuffle=(eval_sampler is None), - sampler=eval_sampler, - batch_size=args.batch_size, - pin_memory=True) + eval_dataloader = DataLoader( + eval_dataset, + shuffle=(eval_sampler is None), + sampler=eval_sampler, + batch_size=args.batch_size, + pin_memory=True, + ) else: eval_dataloader = None num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch) - lr_scheduler = get_scheduler("cosine", - optim, - num_warmup_steps=math.ceil(max_steps * 0.03), - num_training_steps=max_steps) + lr_scheduler = get_scheduler( + "cosine", optim, num_warmup_steps=math.ceil(max_steps * 0.03), num_training_steps=max_steps + ) strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler)) - model = strategy_dict['model'] - optim = strategy_dict['optimizer'] - lr_scheduler = strategy_dict['lr_scheduler'] - trainer = SFTTrainer(model=model, - 
strategy=strategy, - optim=optim, - lr_scheduler=lr_scheduler, - max_epochs=args.max_epochs, - accumulation_steps=args.accumulation_steps) - - trainer.fit(train_dataloader=train_dataloader, - eval_dataloader=eval_dataloader, - logger=logger, - use_wandb=args.use_wandb) + model = strategy_dict["model"] + optim = strategy_dict["optimizer"] + lr_scheduler = strategy_dict["lr_scheduler"] + trainer = SFTTrainer( + model=model, + strategy=strategy, + optim=optim, + lr_scheduler=lr_scheduler, + max_epochs=args.max_epochs, + accumulation_steps=args.accumulation_steps, + ) + + trainer.fit( + train_dataloader=train_dataloader, eval_dataloader=eval_dataloader, logger=logger, use_wandb=args.use_wandb + ) # save model checkpoint after fitting on only rank0 strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer) # save optimizer checkpoint on all ranks if args.need_optim_ckpt: - strategy.save_optimizer(trainer.optimizer, - 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), - only_rank0=False) + strategy.save_optimizer( + trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False + ) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--strategy', - choices=['ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_zero2_cpu'], - default='colossalai_zero2') - parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama', 'chatglm'], default='bloom') - parser.add_argument('--tokenizer', type=str, default=None) - parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--dataset', type=str, default=None) - parser.add_argument('--max_datasets_size', type=int, default=None) - parser.add_argument('--save_path', type=str, default='output') - parser.add_argument('--need_optim_ckpt', type=bool, default=False) - parser.add_argument('--max_epochs', type=int, default=3) - parser.add_argument('--batch_size', type=int, default=4) - parser.add_argument('--max_len', type=int, default=512) - parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") - parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log") - parser.add_argument('--lr', type=float, default=5e-6) - parser.add_argument('--accumulation_steps', type=int, default=8) - parser.add_argument('--use_wandb', default=False, action='store_true') - parser.add_argument('--grad_checkpoint', default=False, action='store_true') + parser.add_argument( + "--strategy", + choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_zero2_cpu"], + default="colossalai_zero2", + ) + parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama", "chatglm"], default="bloom") + parser.add_argument("--tokenizer", type=str, default=None) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--dataset", type=str, default=None) + parser.add_argument("--max_datasets_size", type=int, default=None) + parser.add_argument("--save_path", type=str, default="output") + parser.add_argument("--need_optim_ckpt", type=bool, default=False) + parser.add_argument("--max_epochs", type=int, default=3) + parser.add_argument("--batch_size", type=int, default=4) + parser.add_argument("--max_len", type=int, default=512) + parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument("--log_interval", type=int, default=100, help="how many steps to log") 
+ parser.add_argument("--lr", type=float, default=5e-6) + parser.add_argument("--accumulation_steps", type=int, default=8) + parser.add_argument("--use_wandb", default=False, action="store_true") + parser.add_argument("--grad_checkpoint", default=False, action="store_true") args = parser.parse_args() train(args) diff --git a/applications/Chat/inference/benchmark.py b/applications/Chat/inference/benchmark.py index 438a1e3ef1c7..dbb5490a63dc 100644 --- a/applications/Chat/inference/benchmark.py +++ b/applications/Chat/inference/benchmark.py @@ -84,28 +84,34 @@ def evaluate( if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( - 'pretrained', - help='Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.') - parser.add_argument('--quant', - choices=['8bit', '4bit'], - default=None, - help='Quantization mode. Default: None (no quantization, fp16).') + "pretrained", + help="Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.", + ) + parser.add_argument( + "--quant", + choices=["8bit", "4bit"], + default=None, + help="Quantization mode. Default: None (no quantization, fp16).", + ) parser.add_argument( - '--gptq_checkpoint', + "--gptq_checkpoint", default=None, - help='Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.') - parser.add_argument('--gptq_group_size', - type=int, - default=128, - help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.') + help="Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.", + ) + parser.add_argument( + "--gptq_group_size", + type=int, + default=128, + help="Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.", + ) args = parser.parse_args() - if args.quant == '4bit': - assert args.gptq_checkpoint is not None, 'Please specify a GPTQ checkpoint.' + if args.quant == "4bit": + assert args.gptq_checkpoint is not None, "Please specify a GPTQ checkpoint." tokenizer = AutoTokenizer.from_pretrained(args.pretrained) - if args.quant == '4bit': + if args.quant == "4bit": with low_resource_init(): config = LlamaConfig.from_pretrained(args.pretrained) model = LlamaForCausalLM(config) @@ -114,12 +120,12 @@ def evaluate( else: model = LlamaForCausalLM.from_pretrained( args.pretrained, - load_in_8bit=(args.quant == '8bit'), + load_in_8bit=(args.quant == "8bit"), torch_dtype=torch.float16, device_map="auto", ) - if args.quant != '8bit': - model.half() # seems to fix bugs for some users. + if args.quant != "8bit": + model.half() # seems to fix bugs for some users. 
model.eval() total_tokens = 0 @@ -129,7 +135,7 @@ def evaluate( resp, tokens = evaluate(model, tokenizer, instruction, temperature=0.2, num_beams=1) total_tokens += tokens print(f"Response: {resp}") - print('\n----------------------------\n') + print("\n----------------------------\n") duration = time() - start - print(f'Total time: {duration:.3f} s, {total_tokens/duration:.3f} tokens/s') - print(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB') + print(f"Total time: {duration:.3f} s, {total_tokens/duration:.3f} tokens/s") + print(f"Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB") diff --git a/applications/Chat/inference/locustfile.py b/applications/Chat/inference/locustfile.py index 9443d4b99180..333262e538ac 100644 --- a/applications/Chat/inference/locustfile.py +++ b/applications/Chat/inference/locustfile.py @@ -1,26 +1,26 @@ -from json import JSONDecodeError - from locust import HttpUser, task -samples = [[ - dict( - instruction='Who is the best player in the history of NBA?', - response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1' - ), - dict(instruction='continue this talk', response=''), -], [ - dict(instruction='Who is the best player in the history of NBA?', response=''), -]] +samples = [ + [ + dict( + instruction="Who is the best player in the history of NBA?", + response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1", + ), + dict(instruction="continue this talk", response=""), + ], + [ + dict(instruction="Who is the best player in the history of NBA?", response=""), + ], +] class GenerationUser(HttpUser): - @task def generate(self): for sample in samples: - data = {'max_new_tokens': 64, 'history': sample} - with self.client.post('/generate', json=data, catch_response=True) as response: + data = {"max_new_tokens": 64, "history": sample} + with self.client.post("/generate", json=data, catch_response=True) as response: if response.status_code in (200, 406): response.success() else: - response.failure('Response wrong') + response.failure("Response wrong") diff --git a/applications/Chat/inference/server.py b/applications/Chat/inference/server.py index 9d6b7fabef54..7c6a61b9e7f2 100644 --- a/applications/Chat/inference/server.py +++ b/applications/Chat/inference/server.py @@ -16,7 +16,7 @@ from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM from utils import ChatPromptProcessor, Dialogue, LockedIterator, load_json, sample_streamingly, update_model_kwargs_fn -CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.' +CONTEXT = "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions." 
MAX_LEN = 512 running_lock = Lock() @@ -36,11 +36,11 @@ class GenerationTaskReq(BaseModel): app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) # set CORS -origin_spec_from_env = os.environ.get('CORS_ORIGIN', None) +origin_spec_from_env = os.environ.get("CORS_ORIGIN", None) if origin_spec_from_env is not None: # allow CORS from the specified origins - origins = os.environ['CORS_ORIGIN'].split(',') + origins = os.environ["CORS_ORIGIN"].split(",") else: # allow CORS from all origins origins = ["*"] @@ -58,13 +58,13 @@ def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature): inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()} # TODO(ver217): streaming generation does not support repetition_penalty now model_kwargs = { - 'max_generate_tokens': max_new_tokens, - 'early_stopping': True, - 'top_k': top_k, - 'top_p': top_p, - 'temperature': temperature, - 'prepare_inputs_fn': model.prepare_inputs_for_generation, - 'update_model_kwargs_fn': update_model_kwargs_fn, + "max_generate_tokens": max_new_tokens, + "early_stopping": True, + "top_k": top_k, + "top_p": top_p, + "temperature": temperature, + "prepare_inputs_fn": model.prepare_inputs_for_generation, + "update_model_kwargs_fn": update_model_kwargs_fn, } is_first_word = True generator = LockedIterator(sample_streamingly(model, **inputs, **model_kwargs), running_lock) @@ -81,9 +81,9 @@ def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature): if is_first_word: out_string = out_string.lstrip() is_first_word = False - elif current_sub_tokens[0].startswith('▁'): + elif current_sub_tokens[0].startswith("▁"): # whitespace will be ignored by the frontend - out_string = ' ' + out_string + out_string = " " + out_string yield out_string @@ -92,32 +92,33 @@ async def event_generator(request: Request, generator: Generator): if await request.is_disconnected(): break try: - yield {'event': 'generate', 'data': next(generator)} + yield {"event": "generate", "data": next(generator)} except StopIteration: - yield {'event': 'end', 'data': ''} + yield {"event": "end", "data": ""} break -@app.post('/generate/stream') -@limiter.limit('1/second') +@app.post("/generate/stream") +@limiter.limit("1/second") def generate(data: GenerationTaskReq, request: Request): prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens) event_source = event_generator( - request, generate_streamingly(prompt, data.max_new_tokens, data.top_k, data.top_p, data.temperature)) + request, generate_streamingly(prompt, data.max_new_tokens, data.top_k, data.top_p, data.temperature) + ) return EventSourceResponse(event_source) -@app.post('/generate') -@limiter.limit('1/second') +@app.post("/generate") +@limiter.limit("1/second") def generate_no_stream(data: GenerationTaskReq, request: Request): prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens) if prompt_processor.has_censored_words(prompt): return prompt_processor.SAFE_RESPONSE inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()} with running_lock: - output = model.generate(**inputs, **data.dict(exclude={'history'})) + output = model.generate(**inputs, **data.dict(exclude={"history"})) output = output.cpu() - prompt_len = inputs['input_ids'].size(1) + prompt_len = inputs["input_ids"].size(1) response = output[0, prompt_len:] out_string = tokenizer.decode(response, skip_special_tokens=True) out_string = prompt_processor.postprocess_output(out_string) @@ -126,32 +127,40 @@ def 
generate_no_stream(data: GenerationTaskReq, request: Request): return out_string -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( - 'pretrained', - help='Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.') - parser.add_argument('--quant', - choices=['8bit', '4bit'], - default=None, - help='Quantization mode. Default: None (no quantization, fp16).') + "pretrained", + help="Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.", + ) parser.add_argument( - '--gptq_checkpoint', + "--quant", + choices=["8bit", "4bit"], default=None, - help='Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.') - parser.add_argument('--gptq_group_size', - type=int, - default=128, - help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.') - parser.add_argument('--http_host', default='0.0.0.0') - parser.add_argument('--http_port', type=int, default=7070) - parser.add_argument('--profanity_file', - default=None, - help='Path to profanity words list. It should be a JSON file containing a list of words.') + help="Quantization mode. Default: None (no quantization, fp16).", + ) + parser.add_argument( + "--gptq_checkpoint", + default=None, + help="Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.", + ) + parser.add_argument( + "--gptq_group_size", + type=int, + default=128, + help="Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.", + ) + parser.add_argument("--http_host", default="0.0.0.0") + parser.add_argument("--http_port", type=int, default=7070) + parser.add_argument( + "--profanity_file", + default=None, + help="Path to profanity words list. It should be a JSON file containing a list of words.", + ) args = parser.parse_args() - if args.quant == '4bit': - assert args.gptq_checkpoint is not None, 'Please specify a GPTQ checkpoint.' + if args.quant == "4bit": + assert args.gptq_checkpoint is not None, "Please specify a GPTQ checkpoint." tokenizer = AutoTokenizer.from_pretrained(args.pretrained) @@ -161,7 +170,7 @@ def generate_no_stream(data: GenerationTaskReq, request: Request): censored_words = [] prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN, censored_words=censored_words) - if args.quant == '4bit': + if args.quant == "4bit": with low_resource_init(): config = LlamaConfig.from_pretrained(args.pretrained) model = LlamaForCausalLM(config) @@ -170,12 +179,12 @@ def generate_no_stream(data: GenerationTaskReq, request: Request): else: model = LlamaForCausalLM.from_pretrained( args.pretrained, - load_in_8bit=(args.quant == '8bit'), + load_in_8bit=(args.quant == "8bit"), torch_dtype=torch.float16, device_map="auto", ) - if args.quant != '8bit': - model.half() # seems to fix bugs for some users. + if args.quant != "8bit": + model.half() # seems to fix bugs for some users. model.eval() config = uvicorn.Config(app, host=args.http_host, port=args.http_port) diff --git a/applications/Chat/inference/tests/test_chat_prompt.py b/applications/Chat/inference/tests/test_chat_prompt.py index 23028d4959cb..9835e71894c6 100644 --- a/applications/Chat/inference/tests/test_chat_prompt.py +++ b/applications/Chat/inference/tests/test_chat_prompt.py @@ -3,41 +3,49 @@ from transformers import AutoTokenizer from utils import ChatPromptProcessor, Dialogue -CONTEXT = 'Below is an instruction that describes a task. 
Write a response that appropriately completes the request. Do not generate new instructions.' -tokenizer = AutoTokenizer.from_pretrained(os.environ['PRETRAINED_PATH']) +CONTEXT = "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions." +tokenizer = AutoTokenizer.from_pretrained(os.environ["PRETRAINED_PATH"]) samples = [ - ([ - Dialogue( - instruction='Who is the best player in the history of NBA?', - response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1' - ), - Dialogue(instruction='continue this talk', response=''), - ], 128, - 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\nThe best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1\n\n### Instruction:\ncontinue this talk\n\n### Response:\n' + ( + [ + Dialogue( + instruction="Who is the best player in the history of NBA?", + response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1", + ), + Dialogue(instruction="continue this talk", response=""), + ], + 128, + "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\nThe best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1\n\n### Instruction:\ncontinue this talk\n\n### Response:\n", ), - ([ - Dialogue( - instruction='Who is the best player in the history of NBA?', - response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1' - ), - Dialogue(instruction='continue this talk', response=''), - ], 200, - 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this talk\n\n### Response:\n' + ( + [ + Dialogue( + instruction="Who is the best player in the history of NBA?", + response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1", + ), + Dialogue(instruction="continue this talk", response=""), + ], + 200, + "Below is an instruction that describes a task. Write a response that appropriately completes the request. 
Do not generate new instructions.\n\n### Instruction:\ncontinue this talk\n\n### Response:\n", ), - ([ - Dialogue( - instruction='Who is the best player in the history of NBA?', - response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1' - ), - Dialogue(instruction='continue this talk', response=''), - ], 211, - 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this\n\n### Response:\n' + ( + [ + Dialogue( + instruction="Who is the best player in the history of NBA?", + response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1", + ), + Dialogue(instruction="continue this talk", response=""), + ], + 211, + "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this\n\n### Response:\n", ), - ([ - Dialogue(instruction='Who is the best player in the history of NBA?', response=''), - ], 128, - 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\n' + ( + [ + Dialogue(instruction="Who is the best player in the history of NBA?", response=""), + ], + 128, + "Below is an instruction that describes a task. Write a response that appropriately completes the request. 
Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\n", ), ] @@ -49,5 +57,5 @@ def test_chat_prompt_processor(): assert prompt == result -if __name__ == '__main__': +if __name__ == "__main__": test_chat_prompt_processor() diff --git a/applications/Chat/inference/utils.py b/applications/Chat/inference/utils.py index e8e7b05ac719..af018adf6e9d 100644 --- a/applications/Chat/inference/utils.py +++ b/applications/Chat/inference/utils.py @@ -20,9 +20,9 @@ from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper -def prepare_logits_processor(top_k: Optional[int] = None, - top_p: Optional[float] = None, - temperature: Optional[float] = None) -> LogitsProcessorList: +def prepare_logits_processor( + top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None +) -> LogitsProcessorList: processor_list = LogitsProcessorList() if temperature is not None and temperature != 1.0: processor_list.append(TemperatureLogitsWarper(temperature)) @@ -41,29 +41,30 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool: return unfinished_sequences.max() == 0 -def sample_streamingly(model: nn.Module, - input_ids: torch.Tensor, - max_generate_tokens: int, - early_stopping: bool = False, - eos_token_id: Optional[int] = None, - pad_token_id: Optional[int] = None, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - temperature: Optional[float] = None, - prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, - update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, - **model_kwargs) -> Generator: - +def sample_streamingly( + model: nn.Module, + input_ids: torch.Tensor, + max_generate_tokens: int, + early_stopping: bool = False, + eos_token_id: Optional[int] = None, + pad_token_id: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, + update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, + **model_kwargs, +) -> Generator: logits_processor = prepare_logits_processor(top_k, top_p, temperature) unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) for _ in range(max_generate_tokens): - model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else { - 'input_ids': input_ids - } + model_inputs = ( + prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {"input_ids": input_ids} + ) outputs = model(**model_inputs) - next_token_logits = outputs['logits'][:, -1, :] + next_token_logits = outputs["logits"][:, -1, :] # pre-process distribution next_token_logits = logits_processor(input_ids, next_token_logits) # sample @@ -107,25 +108,26 @@ def update_model_kwargs_fn(outputs: dict, **model_kwargs) -> dict: if "attention_mask" in model_kwargs: attention_mask = model_kwargs["attention_mask"] model_kwargs["attention_mask"] = torch.cat( - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1) + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) return model_kwargs class Dialogue(BaseModel): - instruction: str = Field(min_length=1, example='Count up from 1 to 500.') - response: str = Field(example='') + instruction: str = Field(min_length=1, example="Count up from 1 to 500.") + response: str = Field(example="") -def 
_format_dialogue(instruction: str, response: str = ''): - return f'\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}' +def _format_dialogue(instruction: str, response: str = ""): + return f"\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}" -STOP_PAT = re.compile(r'(###|instruction:).*', flags=(re.I | re.S)) +STOP_PAT = re.compile(r"(###|instruction:).*", flags=(re.I | re.S)) class ChatPromptProcessor: - SAFE_RESPONSE = 'The input/response contains inappropriate content, please rephrase your prompt.' + SAFE_RESPONSE = "The input/response contains inappropriate content, please rephrase your prompt." def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str] = []): self.tokenizer = tokenizer @@ -138,42 +140,48 @@ def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: def preprocess_prompt(self, history: List[Dialogue], max_new_tokens: int) -> str: if self.context_len is None: - self.context_len = len(self.tokenizer(self.context)['input_ids']) + self.context_len = len(self.tokenizer(self.context)["input_ids"]) if self.dialogue_placeholder_len is None: self.dialogue_placeholder_len = len( - self.tokenizer(_format_dialogue(''), add_special_tokens=False)['input_ids']) + self.tokenizer(_format_dialogue(""), add_special_tokens=False)["input_ids"] + ) prompt = self.context # the last dialogue must be in the prompt last_dialogue = history.pop() # the response of the last dialogue is empty - assert last_dialogue.response == '' - if len(self.tokenizer(_format_dialogue(last_dialogue.instruction), add_special_tokens=False) - ['input_ids']) + max_new_tokens + self.context_len >= self.max_len: + assert last_dialogue.response == "" + if ( + len(self.tokenizer(_format_dialogue(last_dialogue.instruction), add_special_tokens=False)["input_ids"]) + + max_new_tokens + + self.context_len + >= self.max_len + ): # to avoid truncate placeholder, apply truncate to the original instruction - instruction_truncated = self.tokenizer(last_dialogue.instruction, - add_special_tokens=False, - truncation=True, - max_length=(self.max_len - max_new_tokens - self.context_len - - self.dialogue_placeholder_len))['input_ids'] + instruction_truncated = self.tokenizer( + last_dialogue.instruction, + add_special_tokens=False, + truncation=True, + max_length=(self.max_len - max_new_tokens - self.context_len - self.dialogue_placeholder_len), + )["input_ids"] instruction_truncated = self.tokenizer.decode(instruction_truncated).lstrip() prompt += _format_dialogue(instruction_truncated) return prompt - res_len = self.max_len - max_new_tokens - len(self.tokenizer(prompt)['input_ids']) + res_len = self.max_len - max_new_tokens - len(self.tokenizer(prompt)["input_ids"]) rows = [] for dialogue in history[::-1]: text = _format_dialogue(dialogue.instruction, dialogue.response) - cur_len = len(self.tokenizer(text, add_special_tokens=False)['input_ids']) + cur_len = len(self.tokenizer(text, add_special_tokens=False)["input_ids"]) if res_len - cur_len < 0: break res_len -= cur_len rows.insert(0, text) - prompt += ''.join(rows) + _format_dialogue(last_dialogue.instruction) + prompt += "".join(rows) + _format_dialogue(last_dialogue.instruction) return prompt def postprocess_output(self, output: str) -> str: - output = STOP_PAT.sub('', output) + output = STOP_PAT.sub("", output) return output.strip() def has_censored_words(self, text: str) -> bool: @@ -184,7 +192,6 @@ def has_censored_words(self, text: str) -> bool: class LockedIterator: - def __init__(self, it, 
lock: Lock) -> None: self.lock = lock self.it = iter(it) diff --git a/applications/Chat/requirements-test.txt b/applications/Chat/requirements-test.txt index eb1a77875acb..809fbd4bb86b 100644 --- a/applications/Chat/requirements-test.txt +++ b/applications/Chat/requirements-test.txt @@ -1,2 +1,2 @@ pytest -colossalai==0.3.1 \ No newline at end of file +colossalai==0.3.1 diff --git a/applications/Chat/setup.py b/applications/Chat/setup.py index a285a6dff4bf..eb44b6203ef8 100644 --- a/applications/Chat/setup.py +++ b/applications/Chat/setup.py @@ -2,40 +2,42 @@ def fetch_requirements(path): - with open(path, 'r') as fd: + with open(path, "r") as fd: return [r.strip() for r in fd.readlines()] def fetch_readme(): - with open('README.md', encoding='utf-8') as f: + with open("README.md", encoding="utf-8") as f: return f.read() def fetch_version(): - with open('version.txt', 'r') as f: + with open("version.txt", "r") as f: return f.read().strip() setup( - name='coati', + name="coati", version=fetch_version(), - packages=find_packages(exclude=( - 'tests', - 'benchmarks', - '*.egg-info', - )), - description='Colossal-AI Talking Intelligence', + packages=find_packages( + exclude=( + "tests", + "benchmarks", + "*.egg-info", + ) + ), + description="Colossal-AI Talking Intelligence", long_description=fetch_readme(), - long_description_content_type='text/markdown', - license='Apache Software License 2.0', - url='https://github.com/hpcaitech/Coati', - install_requires=fetch_requirements('requirements.txt'), - python_requires='>=3.6', + long_description_content_type="text/markdown", + license="Apache Software License 2.0", + url="https://github.com/hpcaitech/Coati", + install_requires=fetch_requirements("requirements.txt"), + python_requires=">=3.6", classifiers=[ - 'Programming Language :: Python :: 3', - 'License :: OSI Approved :: Apache Software License', - 'Environment :: GPU :: NVIDIA CUDA', - 'Topic :: Scientific/Engineering :: Artificial Intelligence', - 'Topic :: System :: Distributed Computing', + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Environment :: GPU :: NVIDIA CUDA", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: System :: Distributed Computing", ], ) diff --git a/applications/Chat/tests/test_checkpoint.py b/applications/Chat/tests/test_checkpoint.py index 3a3bf5b19cb8..e3058be2e67c 100644 --- a/applications/Chat/tests/test_checkpoint.py +++ b/applications/Chat/tests/test_checkpoint.py @@ -22,10 +22,7 @@ def get_data(batch_size: int, seq_len: int = 10) -> dict: return dict(input_ids=input_ids, attention_mask=attention_mask) -def train_step(strategy: Strategy, - actor: GPTActor, - actor_optim: HybridAdam, - batch_size: int = 8): +def train_step(strategy: Strategy, actor: GPTActor, actor_optim: HybridAdam, batch_size: int = 8): data = get_data(batch_size) action_mask = torch.ones_like(data["attention_mask"], dtype=torch.bool) actor_output = actor(data["input_ids"], data["attention_mask"]) @@ -35,8 +32,7 @@ def train_step(strategy: Strategy, strategy.optimizer_step(actor_optim) -def run_test_checkpoint(strategy_name: str, - shard: bool): +def run_test_checkpoint(strategy_name: str, shard: bool): if strategy_name == "ddp": strategy = DDPStrategy() elif strategy_name == "colossalai_gemini": @@ -60,11 +56,9 @@ def run_test_checkpoint(strategy_name: str, dist.broadcast_object_list(rank0_dirname) rank0_dirname = rank0_dirname[0] - model_path = os.path.join( - rank0_dirname, "model" if shard else f"model.pt") + 
model_path = os.path.join(rank0_dirname, "model" if shard else f"model.pt") strategy.save_model(actor, model_path, only_rank0=not shard) - optim_path = os.path.join( - rank0_dirname, "optim" if shard else "optim.pt") + optim_path = os.path.join(rank0_dirname, "optim" if shard else "optim.pt") strategy.save_optimizer(actor_optim, optim_path, only_rank0=not shard) dist.barrier() @@ -75,11 +69,7 @@ def run_test_checkpoint(strategy_name: str, train_step(strategy, actor, actor_optim) -def run_dist(rank: int, - world_size: int, - port: int, - strategy_name: str, - shard: bool): +def run_dist(rank: int, world_size: int, port: int, strategy_name: str, shard: bool): os.environ["RANK"] = str(rank) os.environ["LOCAL_RANK"] = str(rank) os.environ["WORLD_SIZE"] = str(world_size) @@ -93,13 +83,8 @@ def run_dist(rank: int, @pytest.mark.parametrize("strategy_name", ["ddp", "colossalai_gemini", "colossalai_zero2"]) @pytest.mark.parametrize("shard", [False, True]) @rerun_if_address_is_in_use() -def test_checkpoint(world_size: int, - strategy_name: str, - shard: bool): - spawn(run_dist, - world_size, - strategy_name=strategy_name, - shard=shard) +def test_checkpoint(world_size: int, strategy_name: str, shard: bool): + spawn(run_dist, world_size, strategy_name=strategy_name, shard=shard) if __name__ == "__main__": diff --git a/applications/Chat/tests/test_dataset.py b/applications/Chat/tests/test_dataset.py index f9dee1bae935..3de2cc528967 100644 --- a/applications/Chat/tests/test_dataset.py +++ b/applications/Chat/tests/test_dataset.py @@ -8,62 +8,40 @@ from coati.dataset.prompt_dataset import PromptDataset from coati.dataset.reward_dataset import HhRlhfDataset, RmStaticDataset from coati.dataset.sft_dataset import IGNORE_INDEX, SFTDataset, SupervisedDataset +from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer from datasets import load_dataset from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, PreTrainedTokenizer from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer -from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer + SFT_DATASET = [ { - "instruction": - "Provide a list of the top 10 most popular mobile games in Asia", - "input": - "", - "output": - "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved", - "id": - 0 + "instruction": "Provide a list of the top 10 most popular mobile games in Asia", + "input": "", + "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved", + "id": 0, }, { - "instruction": - "Please provide an action plan for reducing carbon footprint on a corporate level", - "input": - "", - "output": - "An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.", - "id": - 1 + 
"instruction": "Please provide an action plan for reducing carbon footprint on a corporate level", + "input": "", + "output": "An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.", + "id": 1, }, { - "instruction": - "Write a persuasive email to your boss explaining why you should have a pay raise", - "input": - "", - "output": - "Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]", - "id": - 2 + "instruction": "Write a persuasive email to your boss explaining why you should have a pay raise", + "input": "", + "output": "Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]", + "id": 2, }, ] PROMPT_DATASET = [ { - "instruction": - "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"", - "id": - 0 - }, - { - "instruction": "Write a descriptive paragraph about a memorable vacation you went on", - "id": 1 - }, - { - "instruction": "Write a persuasive essay arguing why homework should be banned in schools", - "id": 2 - }, - { - "instruction": "Create a chart comparing the statistics on student debt in the United States.", - "id": 3 + "instruction": 'Edit this paragraph to make it more concise: "Yesterday, I went to the store and bought some things. Then, I came home and put them away. 
After that, I went for a walk and met some friends."', + "id": 0, }, + {"instruction": "Write a descriptive paragraph about a memorable vacation you went on", "id": 1}, + {"instruction": "Write a persuasive essay arguing why homework should be banned in schools", "id": 2}, + {"instruction": "Create a chart comparing the statistics on student debt in the United States.", "id": 3}, ] @@ -120,10 +98,12 @@ def test_prompt_dataset(model: str, max_datasets_size: int, max_length: int): json.dump(PROMPT_DATASET, f) tokenizer = make_tokenizer(model) assert tokenizer.padding_side in ("left", "right") - prompt_dataset = PromptDataset(data_path=os.path.join(tmp_dir, dataset_name), - tokenizer=tokenizer, - max_datasets_size=max_datasets_size, - max_length=max_length) + prompt_dataset = PromptDataset( + data_path=os.path.join(tmp_dir, dataset_name), + tokenizer=tokenizer, + max_datasets_size=max_datasets_size, + max_length=max_length, + ) assert len(prompt_dataset) == min(max_datasets_size, len(PROMPT_DATASET)) for i in range(len(prompt_dataset)): assert isinstance(prompt_dataset[i], dict) @@ -137,14 +117,14 @@ def test_prompt_dataset(model: str, max_datasets_size: int, max_length: int): @pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"]) -@pytest.mark.parametrize(["dataset_path", "subset"], [("Anthropic/hh-rlhf", "harmless-base"), - ("Dahoas/rm-static", None)]) +@pytest.mark.parametrize( + ["dataset_path", "subset"], [("Anthropic/hh-rlhf", "harmless-base"), ("Dahoas/rm-static", None)] +) @pytest.mark.parametrize("max_datasets_size", [32]) @pytest.mark.parametrize("max_length", [32, 1024]) def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], max_datasets_size: int, max_length: int): data = load_dataset(dataset_path, data_dir=subset) - assert max_datasets_size <= len(data["train"]) \ - and max_datasets_size <= len(data["test"]) + assert max_datasets_size <= len(data["train"]) and max_datasets_size <= len(data["test"]) train_data = data["train"].select(range(max_datasets_size)) test_data = data["test"].select(range(max_datasets_size)) tokenizer = make_tokenizer(model) @@ -162,8 +142,7 @@ def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], ma assert len(train_dataset) == len(test_dataset) == max_datasets_size for i in range(max_datasets_size): chosen_ids, c_mask, reject_ids, r_mask = train_dataset[i] - assert chosen_ids.shape == c_mask.shape == \ - reject_ids.shape == r_mask.shape == torch.Size([max_length]) + assert chosen_ids.shape == c_mask.shape == reject_ids.shape == r_mask.shape == torch.Size([max_length]) c_mask = c_mask.to(torch.bool) r_mask = r_mask.to(torch.bool) if chosen_ids.masked_select(c_mask)[-1] == tokenizer.eos_token_id: @@ -180,8 +159,7 @@ def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], ma assert torch.all(r_mask) chosen_ids, c_mask, reject_ids, r_mask = test_dataset[i] - assert chosen_ids.shape == c_mask.shape == \ - reject_ids.shape == r_mask.shape == torch.Size([max_length]) + assert chosen_ids.shape == c_mask.shape == reject_ids.shape == r_mask.shape == torch.Size([max_length]) c_mask = c_mask.to(torch.bool) r_mask = r_mask.to(torch.bool) if chosen_ids.masked_select(c_mask)[-1] == tokenizer.eos_token_id: @@ -198,7 +176,6 @@ def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], ma assert torch.all(r_mask) - @pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama", "chatglm"]) @pytest.mark.parametrize("dataset_path", ["yizhongw/self_instruct", None]) 
@pytest.mark.parametrize("max_dataset_size", [2]) @@ -214,10 +191,12 @@ def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size: dataset_name = "sft_dataset.json" with open(os.path.join(tmp_dir, dataset_name), "w") as f: json.dump(SFT_DATASET, f) - sft_dataset = SupervisedDataset(tokenizer=tokenizer, - data_path=os.path.join(tmp_dir, dataset_name), - max_datasets_size=max_dataset_size, - max_length=max_length) + sft_dataset = SupervisedDataset( + tokenizer=tokenizer, + data_path=os.path.join(tmp_dir, dataset_name), + max_datasets_size=max_dataset_size, + max_length=max_length, + ) assert len(sft_dataset) == min(max_dataset_size, len(SFT_DATASET)) if isinstance(tokenizer, ChatGLMTokenizer): @@ -227,20 +206,19 @@ def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size: input_ids = sft_dataset[i]["input_ids"] labels = sft_dataset[i]["labels"] assert input_ids.shape == labels.shape == torch.Size([max_length]) - + ignore_mask = labels == IGNORE_INDEX assert input_ids.masked_select(torch.logical_not(ignore_mask))[0] == tokenizer.bos_token_id check_content(input_ids.masked_select(torch.logical_not(ignore_mask)), tokenizer, model) return - + for i in range(max_dataset_size): assert isinstance(sft_dataset[i], dict) assert list(sft_dataset[i].keys()) == ["input_ids", "labels", "attention_mask"] input_ids = sft_dataset[i]["input_ids"] labels = sft_dataset[i]["labels"] attention_mask = sft_dataset[i]["attention_mask"].to(torch.bool) - assert input_ids.shape == labels.shape == \ - attention_mask.shape == torch.Size([max_length]) + assert input_ids.shape == labels.shape == attention_mask.shape == torch.Size([max_length]) if input_ids.masked_select(attention_mask)[-1] == tokenizer.eos_token_id: check_content(input_ids.masked_select(attention_mask)[:-1], tokenizer, model) assert torch.all(input_ids.masked_select(torch.logical_not(attention_mask)) == tokenizer.pad_token_id) @@ -254,13 +232,8 @@ def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size: if __name__ == "__main__": test_sft_dataset(model="bloom", dataset_path="yizhongw/self_instruct", max_dataset_size=2, max_length=256) - test_reward_dataset(model="gpt2", - dataset_path="Anthropic/hh-rlhf", - subset="harmless-base", - max_datasets_size=8, - max_length=256) - - test_prompt_dataset(model="opt", - max_datasets_size=2, - max_length=128) + test_reward_dataset( + model="gpt2", dataset_path="Anthropic/hh-rlhf", subset="harmless-base", max_datasets_size=8, max_length=256 + ) + test_prompt_dataset(model="opt", max_datasets_size=2, max_length=128) diff --git a/applications/Chat/tests/test_experience.py b/applications/Chat/tests/test_experience.py index 071e50b90e8e..d0ea3bbd2ff5 100644 --- a/applications/Chat/tests/test_experience.py +++ b/applications/Chat/tests/test_experience.py @@ -18,7 +18,7 @@ def get_data(batch_size: int, seq_len: int = 10) -> dict: - input_ids = torch.randint(0, 50257, (batch_size, seq_len), device='cuda') + input_ids = torch.randint(0, 50257, (batch_size, seq_len), device="cuda") attention_mask = torch.ones_like(input_ids) return dict(input_ids=input_ids, attention_mask=attention_mask) @@ -37,12 +37,12 @@ def make_and_consume_experience(strategy): EXPERIENCE_BATCH_SIZE = 4 SAMPLE_BATCH_SIZE = 2 - if strategy == 'ddp': + if strategy == "ddp": strategy = DDPStrategy() - elif strategy == 'colossalai-zero2': + elif strategy == "colossalai-zero2": strategy = LowLevelZeroStrategy() - elif strategy == 'colossalai-gemini': - strategy = 
GeminiStrategy(placement_policy='cuda') + elif strategy == "colossalai-gemini": + strategy = GeminiStrategy(placement_policy="cuda") else: raise ValueError(f'Unsupported strategy "{strategy}"') @@ -58,13 +58,11 @@ def make_and_consume_experience(strategy): # experience of all ranks should be the same for _ in range(2): data = get_data(EXPERIENCE_BATCH_SIZE) - assert gather_and_equal(data['input_ids']) - assert gather_and_equal(data['attention_mask']) - experience = experience_maker.make_experience(**data, - do_sample=True, - max_length=16, - eos_token_id=50256, - pad_token_id=50256) + assert gather_and_equal(data["input_ids"]) + assert gather_and_equal(data["attention_mask"]) + experience = experience_maker.make_experience( + **data, do_sample=True, max_length=16, eos_token_id=50256, pad_token_id=50256 + ) assert gather_and_equal(experience.sequences) assert gather_and_equal(experience.action_log_probs) assert gather_and_equal(experience.values) @@ -75,7 +73,7 @@ def make_and_consume_experience(strategy): data_buffer.append(experience) # data buffer's data should be the same - buffer_size = torch.tensor([len(data_buffer)], device='cuda') + buffer_size = torch.tensor([len(data_buffer)], device="cuda") assert gather_and_equal(buffer_size) for item in data_buffer.items: assert gather_and_equal(item.sequences) @@ -88,7 +86,7 @@ def make_and_consume_experience(strategy): # dataloader of each rank should have the same size and different batch dataloader = strategy.setup_dataloader(data_buffer) - dataloader_size = torch.tensor([len(dataloader)], device='cuda') + dataloader_size = torch.tensor([len(dataloader)], device="cuda") assert gather_and_equal(dataloader_size) for experience in dataloader: assert not gather_and_equal(experience.sequences) @@ -100,21 +98,21 @@ def make_and_consume_experience(strategy): def run_dist(rank, world_size, port, strategy): - os.environ['RANK'] = str(rank) - os.environ['LOCAL_RANK'] = str(rank) - os.environ['WORLD_SIZE'] = str(world_size) - os.environ['MASTER_ADDR'] = 'localhost' - os.environ['MASTER_PORT'] = str(port) + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) make_and_consume_experience(strategy) @pytest.mark.dist -@pytest.mark.parametrize('world_size', [2]) -@pytest.mark.parametrize('strategy', ['ddp', 'colossalai-zero2', 'colossalai-gemini']) +@pytest.mark.parametrize("world_size", [2]) +@pytest.mark.parametrize("strategy", ["ddp", "colossalai-zero2", "colossalai-gemini"]) @rerun_if_address_is_in_use() def test_experience(world_size, strategy): spawn(run_dist, world_size, strategy=strategy) -if __name__ == '__main__': - test_experience(2, 'colossalai') +if __name__ == "__main__": + test_experience(2, "colossalai") diff --git a/applications/Chat/tests/test_models.py b/applications/Chat/tests/test_models.py index b98b3615cd28..b2551ff5c0de 100644 --- a/applications/Chat/tests/test_models.py +++ b/applications/Chat/tests/test_models.py @@ -6,15 +6,16 @@ import torch.nn as nn from coati.models.base import Actor, Critic, RewardModel, get_base_model from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic +from coati.models.chatglm import ChatGLMActor +from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer from coati.models.generation import generate from coati.models.gpt import GPTRM, GPTActor, GPTCritic -from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM -from coati.models.chatglm 
import ChatGLMActor +from coati.models.llama import LlamaActor from coati.models.lora import LoraLinear, convert_to_lora_module from coati.models.loss import GPTLMLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss from coati.models.opt import OPTRM, OPTActor, OPTCritic from coati.models.utils import calc_action_log_probs, compute_reward, masked_mean -from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer + @pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("seq_len", [32]) @@ -23,19 +24,24 @@ [ lambda: BLOOMActor(), lambda: GPTActor(), - # HACK: skip llama due to long execution time - # lambda: LlamaActor(), - lambda: OPTActor(), - # lambda: ChatGLMActor(), -]) - -@pytest.mark.parametrize("generate_kwargs", [{ - "max_length": 64, - "use_cache": True, - "do_sample": True, - "temperature": 1.0, - "top_k": 50, -}]) + # HACK: skip llama due to long execution time + # lambda: LlamaActor(), + lambda: OPTActor(), + # lambda: ChatGLMActor(), + ], +) +@pytest.mark.parametrize( + "generate_kwargs", + [ + { + "max_length": 64, + "use_cache": True, + "do_sample": True, + "temperature": 1.0, + "top_k": 50, + } + ], +) def test_generation(actor_maker: Callable[[], Actor], batch_size: int, seq_len: int, generate_kwargs: Dict[str, Any]): actor = actor_maker() input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda() @@ -56,7 +62,7 @@ def test_utils(): "kl_coef": 1.0, "log_probs": torch.randn((batch_size, num_labels)), "log_probs_base": torch.randn((batch_size, num_labels)), - "action_mask": torch.randint(0, 2, (batch_size, num_labels)) + "action_mask": torch.randint(0, 2, (batch_size, num_labels)), } fn_output = compute_reward(**fn_input) assert fn_output.shape == (batch_size,) @@ -66,9 +72,7 @@ def test_utils(): num_labels = 10 num_actions = 2 fn_input = { - "output": { - "logits": torch.randn((batch_size, seq_len, num_labels)) - }, + "output": {"logits": torch.randn((batch_size, seq_len, num_labels))}, "sequences": torch.randint(0, num_labels, (batch_size, seq_len)), "num_actions": num_actions, } @@ -105,8 +109,9 @@ def test_lora(lora_rank: int, num_dim: int, num_layers: int): assert isinstance(lora_model[i], LoraLinear) assert torch.allclose(old_model[i].weight, lora_model[i].weight) assert torch.allclose(old_model[i].bias, lora_model[i].bias) - assert not torch.allclose(old_model[i].lora_B @ old_model[i].lora_A, - lora_model[i].lora_B @ lora_model[i].lora_A) + assert not torch.allclose( + old_model[i].lora_B @ old_model[i].lora_A, lora_model[i].lora_B @ lora_model[i].lora_A + ) @pytest.mark.parametrize("batch_size", [8]) @@ -116,54 +121,60 @@ def test_lora(lora_rank: int, num_dim: int, num_layers: int): [ lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), lambda: (GPTActor(), GPTCritic(), GPTRM()), - # HACK: skip llama due to long execution time - # lambda: (LlamaActor(), LlamaCritic(), LlamaRM()), - lambda: (OPTActor(), OPTCritic(), OPTRM()), - lambda: (ChatGLMActor(), None, None), -]) + # HACK: skip llama due to long execution time + # lambda: (LlamaActor(), LlamaCritic(), LlamaRM()), + lambda: (OPTActor(), OPTCritic(), OPTRM()), + lambda: (ChatGLMActor(), None, None), + ], +) @torch.no_grad() -def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], - batch_size: int, - seq_len: int): +def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], batch_size: int, seq_len: int): actor_input = { "input_ids": torch.randint(0, 100, (batch_size, seq_len)), - "attention_mask": torch.randint(0, 2, (batch_size, seq_len)) + 
"attention_mask": torch.randint(0, 2, (batch_size, seq_len)), } critic_input = { "sequences": torch.randint(0, 100, (batch_size, seq_len)), "action_mask": torch.randint(0, 2, (batch_size, seq_len)), - "attention_mask": torch.randint(0, 2, (batch_size, seq_len)) + "attention_mask": torch.randint(0, 2, (batch_size, seq_len)), } rm_input = { "sequences": torch.randint(0, 100, (batch_size, seq_len)), - "attention_mask": torch.randint(0, 2, (batch_size, seq_len)) + "attention_mask": torch.randint(0, 2, (batch_size, seq_len)), } actor, critic, rm = models_maker() if isinstance(actor, ChatGLMActor): actor = actor.float() - tokenizer = ChatGLMTokenizer.from_pretrained( "THUDM/chatglm-6b", trust_remote_code=True) + tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) chatglm_special_token = torch.tensor([tokenizer.gmask_token_id, tokenizer.bos_token_id]).repeat(batch_size, 1) - actor_input ={ - "input_ids": torch.cat((torch.randint(0, 100, (batch_size, seq_len//2)), chatglm_special_token, torch.randint(0, 100, (batch_size, seq_len//2 - 2))), dim=1), - "attention_mask": torch.randint(0, 2, (batch_size, 1, seq_len, seq_len)) - } + actor_input = { + "input_ids": torch.cat( + ( + torch.randint(0, 100, (batch_size, seq_len // 2)), + chatglm_special_token, + torch.randint(0, 100, (batch_size, seq_len // 2 - 2)), + ), + dim=1, + ), + "attention_mask": torch.randint(0, 2, (batch_size, 1, seq_len, seq_len)), + } assert isinstance(actor, Actor) - base_actor_model = get_base_model(actor) + get_base_model(actor) actor_output = actor(**actor_input) assert actor_output.logits.shape[:2] == (batch_size, seq_len) if critic: assert isinstance(critic, Critic) - base_critic_model = get_base_model(critic) + get_base_model(critic) critic_output = critic(**critic_input) - assert critic_output.shape == (batch_size, ) - + assert critic_output.shape == (batch_size,) + if rm: assert isinstance(rm, RewardModel) - base_rm_model = get_base_model(rm) + get_base_model(rm) rm_output = rm(**rm_input) - assert rm_output.shape == (batch_size, ) + assert rm_output.shape == (batch_size,) @pytest.mark.parametrize("batch_size", [16]) @@ -173,39 +184,59 @@ def test_loss(batch_size: int, seq_len: int, num_labels: int): loss = GPTLMLoss() loss_input = { "logits": torch.randn(batch_size, seq_len, num_labels), - "labels": torch.randint(0, num_labels, (batch_size, seq_len)) + "labels": torch.randint(0, num_labels, (batch_size, seq_len)), } - loss_output = loss(**loss_input) + loss(**loss_input) loss = PolicyLoss() loss_input = { - "log_probs": torch.randn(batch_size,), - "old_log_probs": torch.randn(batch_size,), - "advantages": torch.randn(batch_size,) + "log_probs": torch.randn( + batch_size, + ), + "old_log_probs": torch.randn( + batch_size, + ), + "advantages": torch.randn( + batch_size, + ), } - loss_output = loss(**loss_input) + loss(**loss_input) loss = ValueLoss() loss_input = { - "values": torch.randn(batch_size,), - "old_values": torch.randn(batch_size,), - "reward": torch.randn(batch_size,) + "values": torch.randn( + batch_size, + ), + "old_values": torch.randn( + batch_size, + ), + "reward": torch.randn( + batch_size, + ), } - loss_output = loss(**loss_input) + loss(**loss_input) loss = LogSigLoss() loss_input = { - "chosen_reward": torch.randn(batch_size,), - "reject_reward": torch.randn(batch_size,), + "chosen_reward": torch.randn( + batch_size, + ), + "reject_reward": torch.randn( + batch_size, + ), } - loss_output = loss(**loss_input) + loss(**loss_input) loss = LogExpLoss() loss_input = { - 
"chosen_reward": torch.randn(batch_size,), - "reject_reward": torch.randn(batch_size,), + "chosen_reward": torch.randn( + batch_size, + ), + "reject_reward": torch.randn( + batch_size, + ), } - loss_output = loss(**loss_input) + loss(**loss_input) if __name__ == "__main__": @@ -218,4 +249,4 @@ def test_loss(batch_size: int, seq_len: int, num_labels: int): test_models(models_maker=lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), batch_size=8, seq_len=128) - test_loss(batch_size=8, seq_len=128, num_labels=100) \ No newline at end of file + test_loss(batch_size=8, seq_len=128, num_labels=100) diff --git a/colossalai/__init__.py b/colossalai/__init__.py index fa6f72a605c0..7da55590305b 100644 --- a/colossalai/__init__.py +++ b/colossalai/__init__.py @@ -6,7 +6,7 @@ except ModuleNotFoundError: # this will only happen if the user did not run `pip install` # and directly set PYTHONPATH to use Colossal-AI which is a bad practice - __version__ = '0.0.0' - print('please install Colossal-AI from https://www.colossalai.org/download or from source') + __version__ = "0.0.0" + print("please install Colossal-AI from https://www.colossalai.org/download or from source") -__all__ = ['launch', 'launch_from_openmpi', 'launch_from_slurm', 'launch_from_torch', '__version__'] +__all__ = ["launch", "launch_from_openmpi", "launch_from_slurm", "launch_from_torch", "__version__"] diff --git a/colossalai/_analyzer/_subclasses/_meta_registration.py b/colossalai/_analyzer/_subclasses/_meta_registration.py index 4049be79c70f..e8ba88b0406d 100644 --- a/colossalai/_analyzer/_subclasses/_meta_registration.py +++ b/colossalai/_analyzer/_subclasses/_meta_registration.py @@ -3,7 +3,7 @@ # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml # for more meta_registrations -from typing import Callable, List, Optional, Tuple, Union +from typing import List, Optional, Union import torch from packaging import version @@ -24,25 +24,23 @@ def new(*args, **kwargs): - return orig_empty(*args, **kwargs, device=torch.device('meta')) + return orig_empty(*args, **kwargs, device=torch.device("meta")) def new_strided(*args, **kwargs): - return orig_empty_strided(*args, **kwargs, device=torch.device('meta')) + return orig_empty_strided(*args, **kwargs, device=torch.device("meta")) def new_like(*args, **kwargs): - return orig_empty_like(*args, **kwargs, device=torch.device('meta')) + return orig_empty_like(*args, **kwargs, device=torch.device("meta")) def register_meta(op, register_dispatcher=True): - def wrapper(f): - def add_func(op): meta_table[op] = f if register_dispatcher: - name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__) + name = op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__ try: meta_lib.impl(name, f) except: @@ -54,7 +52,7 @@ def add_func(op): return wrapper -if version.parse(torch.__version__) >= version.parse('1.12.0'): +if version.parse(torch.__version__) >= version.parse("1.12.0"): # ============================== Convolutions ====================================== # https://github.com/pytorch/pytorch/pull/79834 @register_meta(aten.convolution.default) @@ -69,7 +67,6 @@ def meta_conv( output_padding: List[int], groups: int, ): - def _formula(ln: int, p: int, d: int, k: int, s: int) -> int: """ Formula to apply to calculate the length of some dimension of the output @@ -146,7 +143,8 @@ def calc_conv_nd_return_shape( kernel_size[i], stride[i], output_padding_list[i], - )) + ) + ) else: ret_shape.append(_formula(dims[i], 
padding[i], dilation[i], kernel_size[i], stride[i])) return ret_shape @@ -180,19 +178,39 @@ def pick_memory_format(): shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation) out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out)) mem_fmt = pick_memory_format() - out = out.to(memory_format=mem_fmt) # type: ignore[call-overload] + out = out.to(memory_format=mem_fmt) # type: ignore[call-overload] return out @register_meta(aten._convolution.default) - def meta__conv(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int], - padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int, - *extra_args): + def meta__conv( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: List[int], + padding: List[int], + dilation: List[int], + is_transposed: bool, + output_padding: List[int], + groups: int, + *extra_args, + ): out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups) return out @register_meta(aten.convolution_backward.default) - def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride, - padding, dilation, transposed, output_padding, groups, output_mask): + def meta_conv_backward( + grad_output: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + bias_sizes, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + output_mask, + ): return new_like(input), new_like(weight), new((bias_sizes)) # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp @@ -224,7 +242,6 @@ def meta_cuda_rnn( batch_sizes, dropout_state, ): - is_input_packed = len(batch_sizes) != 0 if is_input_packed: seq_length = len(batch_sizes) @@ -240,8 +257,11 @@ def meta_cuda_rnn( if is_input_packed: out_shape = [batch_sizes_sum, out_size * num_directions] else: - out_shape = ([mini_batch, seq_length, out_size * - num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions]) + out_shape = ( + [mini_batch, seq_length, out_size * num_directions] + if batch_first + else [seq_length, mini_batch, out_size * num_directions] + ) output = input.new_empty(out_shape) cell_shape = [num_layers * num_directions, mini_batch, hidden_size] @@ -257,15 +277,21 @@ def meta_cuda_rnn( # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp @register_meta(aten._cudnn_rnn_backward.default) - def meta_cudnn_rnn_backward(input: torch.Tensor, - weight: torch.Tensor, - weight_stride0: int, - hx: torch.Tensor, - cx: Optional[torch.Tensor] = None, - *args, - **kwargs): - return new_like(input), new_like(weight), new_like(hx), new_like(cx) if cx is not None else new( - ()) # (grad_input, grad_weight, grad_hx, grad_cx) + def meta_cudnn_rnn_backward( + input: torch.Tensor, + weight: torch.Tensor, + weight_stride0: int, + hx: torch.Tensor, + cx: Optional[torch.Tensor] = None, + *args, + **kwargs, + ): + return ( + new_like(input), + new_like(weight), + new_like(hx), + new_like(cx) if cx is not None else new(()), + ) # (grad_input, grad_weight, grad_hx, grad_cx) # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp # ============================== Activations ======================================= @@ -278,7 +304,7 @@ def meta_cudnn_rnn_backward(input: torch.Tensor, aten.hardtanh_backward.default, ] - if version.parse(torch.__version__) < 
version.parse('2.0.0'): + if version.parse(torch.__version__) < version.parse("2.0.0"): _unregistered_ewise += [ aten.prelu_backward.default, ] @@ -296,37 +322,61 @@ def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, traini # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp @register_meta(aten.native_batch_norm_backward.default) - def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, - save_mean, save_invstd, train, eps, output_mask): - return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta) + def meta_bn_backward( + dY: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + running_mean, + running_var, + save_mean, + save_invstd, + train, + eps, + output_mask, + ): + return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta) # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp @register_meta(aten.cudnn_batch_norm.default) def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps): n_input = input.size(1) - return new_like(input), new((n_input)), new((n_input)), new( - (0), dtype=torch.uint8) # (output, running_mean, running_var, reserve) + return ( + new_like(input), + new((n_input)), + new((n_input)), + new((0), dtype=torch.uint8), + ) # (output, running_mean, running_var, reserve) # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp # NB: CuDNN only implements the backward algorithm for batchnorm # in training mode (evaluation mode batchnorm has a different algorithm), # which is why this doesn't accept a 'training' parameter. @register_meta(aten.cudnn_batch_norm_backward.default) - def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, - save_mean, save_invstd, eps, reserve): - return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta) + def meta_cudnn_bn_backward( + dY: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + running_mean, + running_var, + save_mean, + save_invstd, + eps, + reserve, + ): + return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta) # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp @register_meta(aten.native_layer_norm.default) def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps): bs, n_input = input.size(0), input.size(1) - return new_like(input), new((bs, n_input, 1)), new((bs, n_input, 1)) # (output, running_mean, running_var) + return new_like(input), new((bs, n_input, 1)), new((bs, n_input, 1)) # (output, running_mean, running_var) # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp @register_meta(aten.native_layer_norm_backward.default) - def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, - grad_input_mask): - return new_like(input), new_like(weight), new_like(bias) # (dX, dgamma, dbeta) + def meta_ln_backward( + dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, grad_input_mask + ): + return new_like(input), new_like(weight), new_like(bias) # (dX, dgamma, dbeta) # ================================== Misc ========================================== # Maybe incorrect @@ -355,8 +405,9 @@ def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Te # 
https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp @register_meta(aten.embedding_dense_backward.default) - def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, - scale_grad_by_freq): + def meta_embedding_dense_backward( + grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, scale_grad_by_freq + ): return new((num_weights, grad_output.size(-1)), dtype=grad_output.dtype, layout=grad_output.layout) # ============================== Dropout =========================================== @@ -364,14 +415,14 @@ def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tens @register_meta(aten.native_dropout.default) def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False): # notice that mask is bool - return new_like(input), new_like(input, dtype=torch.bool) # (output, mask) + return new_like(input), new_like(input, dtype=torch.bool) # (output, mask) # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp @register_meta(aten.native_dropout_backward.default) def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float): - return new_like(grad) # (grad_in) + return new_like(grad) # (grad_in) - if version.parse(torch.__version__) < version.parse('1.13.0'): + if version.parse(torch.__version__) < version.parse("1.13.0"): # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml @register_meta(aten.eye.m_out) def meta_eye(n: int, m: int, out: torch.Tensor): @@ -385,24 +436,28 @@ def meta_index_Tensor(self, indices): result: List[Optional[torch.Tensor]] = [] for i, index in enumerate(indices): if index is not None: - assert index.dtype in [torch.long, torch.int8, torch.bool],\ - "tensors used as indices must be long, byte or bool tensors" + assert index.dtype in [ + torch.long, + torch.int8, + torch.bool, + ], "tensors used as indices must be long, byte or bool tensors" if index.dtype in [torch.int8, torch.bool]: nonzero = index.nonzero() k = len(result) assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}" for j in range(index.ndim): - assert index.shape[j] == self.shape[ - k + - j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}" + assert ( + index.shape[j] == self.shape[k + j] + ), f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}" result.append(nonzero.select(1, j)) else: result.append(index) else: result.append(index) indices = result - assert len( - indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})" + assert ( + len(indices) <= self.ndim + ), f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})" # expand_outplace import torch._refs as refs diff --git a/colossalai/_analyzer/_subclasses/_monkey_patch.py b/colossalai/_analyzer/_subclasses/_monkey_patch.py index b3ec98f0811f..503981409cca 100644 --- a/colossalai/_analyzer/_subclasses/_monkey_patch.py +++ b/colossalai/_analyzer/_subclasses/_monkey_patch.py @@ -1,5 +1,4 @@ import torch -import torch.distributed as dist from packaging import version __all__ = [ @@ -48,7 +47,7 @@ "scatter", ] -if version.parse(torch.__version__) >= version.parse('1.12.0'): +if version.parse(torch.__version__) >= version.parse("1.12.0"): aten = torch.ops.aten # TODO: 
dive deep here # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp diff --git a/colossalai/_analyzer/_subclasses/flop_tensor.py b/colossalai/_analyzer/_subclasses/flop_tensor.py index 59991dc50912..9d52c5593bb8 100644 --- a/colossalai/_analyzer/_subclasses/flop_tensor.py +++ b/colossalai/_analyzer/_subclasses/flop_tensor.py @@ -8,7 +8,7 @@ from enum import Enum, auto from functools import partial, reduce from numbers import Number -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, List, Union import torch from packaging import version @@ -36,15 +36,15 @@ def _format_flops(flop): B = 1e9 T = 1e12 if flop < K: - return f'{flop:.2f}' + return f"{flop:.2f}" elif flop < M: - return f'{flop / K:.2f}K' + return f"{flop / K:.2f}K" elif flop < B: - return f'{flop / M:.2f}M' + return f"{flop / M:.2f}M" elif flop < T: - return f'{flop / B:.2f}B' + return f"{flop / B:.2f}B" else: - return f'{flop / T:.2f}T' + return f"{flop / T:.2f}T" def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: bool = False, **kwargs) -> Number: @@ -59,11 +59,13 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: Returns: Number: The total number of floating point operations (FWD + BWD). """ - maybe_inplace = (getattr(module, 'inplace', False) or kwargs.get('inplace', False) - or getattr(module, '__name__', None) in ('add_', 'mul_', 'div_', 'sub_')) + maybe_inplace = ( + getattr(module, "inplace", False) + or kwargs.get("inplace", False) + or getattr(module, "__name__", None) in ("add_", "mul_", "div_", "sub_") + ) class DummyModule(torch.nn.Module): - def __init__(self, func): super().__init__() self.func = func @@ -74,21 +76,20 @@ def forward(self, *args, **kwargs): total_flop_count = {Phase.FWD: 0, Phase.BWD: 0} flop_counts = defaultdict(lambda: defaultdict(int)) - parents = ['Global'] + parents = ["Global"] module = module if isinstance(module, torch.nn.Module) else DummyModule(module) class FlopTensor(MetaTensor): _tensor: torch.Tensor def __repr__(self): - name = 'FlopParameter' if getattr(self, '_is_param', False) else 'FlopTensor' + name = "FlopParameter" if getattr(self, "_is_param", False) else "FlopTensor" if self.grad_fn: return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})" return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})" @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - # no_dispatch is only needed if you use enable_python_mode. # It prevents infinite recursion. 
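A minimal usage sketch of the `flop_count` entry point defined in the `flop_tensor.py` hunk above, assuming it is importable from `colossalai._analyzer._subclasses.flop_tensor` as diffed; the model, input shape and use of `verbose` are arbitrary choices for illustration, not part of the patch.

```python
# Sketch: count forward + backward FLOPs of a small module with flop_count.
# The import path follows the file being diffed above; model and shapes are illustrative.
import torch
import torch.nn as nn

from colossalai._analyzer._subclasses.flop_tensor import flop_count

model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))
sample = torch.randn(8, 128)

# Returns the total number of floating point operations (FWD + BWD) as a plain number;
# verbose=True asks for a printed per-module breakdown.
total_flops = flop_count(model, sample, verbose=True)
print(total_flops)
```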
rs = super().__torch_dispatch__(func, types, args, kwargs) @@ -115,9 +116,7 @@ def is_autogradable(x): return isinstance(x, torch.Tensor) and x.is_floating_point() def create_backwards_push(name): - class PushState(torch.autograd.Function): - @staticmethod def forward(ctx, *args): args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args) @@ -134,9 +133,7 @@ def backward(ctx, *grad_outs): return PushState.apply def create_backwards_pop(name): - class PopState(torch.autograd.Function): - @staticmethod def forward(ctx, *args): args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args) @@ -147,14 +144,13 @@ def forward(ctx, *args): @staticmethod def backward(ctx, *grad_outs): nonlocal parents - assert (parents[-1] == name) + assert parents[-1] == name parents.pop() return grad_outs return PopState.apply def enter_module(name): - def f(module, inputs): nonlocal parents parents.append(name) @@ -165,10 +161,9 @@ def f(module, inputs): return f def exit_module(name): - def f(module, inputs, outputs): nonlocal parents - assert (parents[-1] == name) + assert parents[-1] == name parents.pop() outputs = normalize_tuple(outputs) return create_backwards_push(name)(*outputs) @@ -189,7 +184,7 @@ def display_flops(): for mod in flop_counts.keys(): print(f"Module: ", mod) for k, v in flop_counts[mod].items(): - print('\t', k, _format_flops(v)) + print("\t", k, _format_flops(v)) print() def detach_variables(r): @@ -201,7 +196,7 @@ def detach_variables(r): def wrap(r): if isinstance(r, torch.Tensor): - data_ptr_fn = getattr(r, '_tensor', r).data_ptr + data_ptr_fn = getattr(r, "_tensor", r).data_ptr r = FlopTensor(detach_variables(r)) if maybe_inplace: r = r + 0 @@ -375,8 +370,11 @@ def norm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: # Inputs[0] contains the shape of the input. input_shape = inputs[input_arg_index].shape - has_affine = inputs[affine_arg_index].shape is not None if hasattr(inputs[affine_arg_index], - 'shape') else inputs[affine_arg_index] + has_affine = ( + inputs[affine_arg_index].shape is not None + if hasattr(inputs[affine_arg_index], "shape") + else inputs[affine_arg_index] + ) assert 2 <= len(input_shape) <= 5, input_shape # 5 is just a rough estimate flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4) @@ -390,7 +388,7 @@ def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = N training = inputs[-3] assert isinstance(training, bool), "Signature of aten::batch_norm has changed!" if training: - return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore + return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore has_affine = inputs[1].shape is not None input_shape = reduce(operator.mul, inputs[0].shape) return input_shape * (2 if has_affine else 1) @@ -420,33 +418,30 @@ def ewise_flop(inputs: List[Any], outputs: List[Any]) -> Number: def zero_flop_jit(*args): """ - Count flops for zero flop layers. + Count flops for zero flop layers. 
""" return 0 -if version.parse(torch.__version__) >= version.parse('1.12.0'): +if version.parse(torch.__version__) >= version.parse("1.12.0"): flop_mapping = { - # gemm + # gemm aten.mm.default: matmul_flop_jit, aten.matmul.default: matmul_flop_jit, aten.addmm.default: addmm_flop_jit, aten.bmm.default: bmm_flop_jit, - - # convolution + # convolution aten.convolution.default: conv_flop_jit, aten._convolution.default: conv_flop_jit, aten.convolution_backward.default: conv_backward_flop_jit, - - # normalization + # normalization aten.native_batch_norm.default: batchnorm_flop_jit, aten.native_batch_norm_backward.default: batchnorm_flop_jit, aten.cudnn_batch_norm.default: batchnorm_flop_jit, aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True), aten.native_layer_norm.default: norm_flop_counter(2, 0), aten.native_layer_norm_backward.default: norm_flop_counter(2, 0), - - # pooling + # pooling aten.avg_pool1d.default: ewise_flop_counter(1, 0), aten.avg_pool2d.default: ewise_flop_counter(1, 0), aten.avg_pool2d_backward.default: ewise_flop_counter(0, 1), @@ -469,7 +464,7 @@ def zero_flop_jit(*args): } ewise_flop_aten = [ - # basic op + # basic op aten.add.Tensor, aten.add_.Tensor, aten.div.Tensor, @@ -485,8 +480,7 @@ def zero_flop_jit(*args): aten.sum.default, aten.sum.dim_IntList, aten.mean.dim, - - # activation op + # activation op aten.hardswish.default, aten.hardswish_.default, aten.hardswish_backward.default, @@ -509,15 +503,12 @@ def zero_flop_jit(*args): aten.tanh.default, aten.tanh_backward.default, aten.threshold_backward.default, - - # dropout + # dropout aten.native_dropout.default, aten.native_dropout_backward.default, - - # distribution + # distribution aten.bernoulli_.float, - - # where + # where aten.where.self, ] for op in ewise_flop_aten: diff --git a/colossalai/_analyzer/_subclasses/meta_tensor.py b/colossalai/_analyzer/_subclasses/meta_tensor.py index 2bc212938ee0..8be97d01343e 100644 --- a/colossalai/_analyzer/_subclasses/meta_tensor.py +++ b/colossalai/_analyzer/_subclasses/meta_tensor.py @@ -3,12 +3,12 @@ import torch import torch.distributed as dist -from torch.types import _bool, _device, _dtype -from torch.utils._pytree import tree_flatten, tree_map +from torch.types import _device +from torch.utils._pytree import tree_map from ._monkey_patch import _AliasATen, _DistCommMethod, _InplaceATen, _MaybeInplaceATen, _TorchOverrideableFactoryMethod -__all__ = ['MetaTensor', 'MetaTensorMode'] +__all__ = ["MetaTensor", "MetaTensorMode"] def register_storage(r, data_ptr_fn=None): @@ -28,8 +28,7 @@ def _normalize_tuple(x): # a hack of inplace execution in PyTorch def _assert_alias(func): - return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen # TODO: check if should be this aggressive - ) + return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen) # TODO: check if should be this aggressive class MetaTensor(torch.Tensor): @@ -65,14 +64,15 @@ def __new__(cls, elem, device=None, data_ptr_fn=None): storage_offset=elem.storage_offset(), dtype=elem.dtype, layout=elem.layout, - device=device or (elem.device if elem.device.type != 'meta' else torch.device('cpu')), - requires_grad=requires_grad) # deceive the frontend for aten selections + device=device or (elem.device if elem.device.type != "meta" else torch.device("cpu")), + requires_grad=requires_grad, + ) # deceive the frontend for aten selections r._tensor = elem # ...the real tensor is held as an element on the tensor. 
if not r._tensor.is_meta: val = elem.data_ptr() data_ptr_fn = lambda: val - r._tensor = r._tensor.to(torch.device('meta')) + r._tensor = r._tensor.to(torch.device("meta")) # only tensor not on `meta` should be copied to `meta` register_storage(r._tensor, data_ptr_fn) @@ -81,7 +81,7 @@ def __new__(cls, elem, device=None, data_ptr_fn=None): return r def __repr__(self): - name = 'MetaParameter' if getattr(self, '_is_param', False) else 'MetaTensor' + name = "MetaParameter" if getattr(self, "_is_param", False) else "MetaTensor" if self.grad_fn: return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})" return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})" @@ -97,15 +97,15 @@ def unwrap(x): x = x._tensor elif isinstance(x, torch.Tensor): device = x.device - x = x.to(torch.device('meta')) + x = x.to(torch.device("meta")) return x args = tree_map(unwrap, args) kwargs = tree_map(unwrap, kwargs) - if 'device' in kwargs: - device = kwargs['device'] - kwargs['device'] = torch.device('meta') + if "device" in kwargs: + device = kwargs["device"] + kwargs["device"] = torch.device("meta") # run aten for backend=CPU but actually on backend=Meta # here we detect whether or not the execution generates a physical copy @@ -143,21 +143,21 @@ def replace(x): nonlocal device if isinstance(x, str) or isinstance(x, _device): device = x - return torch.device('meta') + return torch.device("meta") return x elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs)) return MetaTensor(elem, device=device) def cpu(self, *args, **kwargs): - if self.device.type == 'cpu': + if self.device.type == "cpu": return self.to(*args, **kwargs) - return self.to(*args, device='cpu', **kwargs) + return self.to(*args, device="cpu", **kwargs) def cuda(self, device=None, non_blocking=False): if device is not None: return self.to(device=device, non_blocking=non_blocking) - return self.to(device='cuda:0', non_blocking=non_blocking) + return self.to(device="cuda:0", non_blocking=non_blocking) def data_ptr(self): return self._tensor.data_ptr() @@ -177,19 +177,17 @@ class MetaTensorMode(object): """ def __init__(self): - self.torch_overrides = {} # override torch.xxx - self.dist_overrides = {} # override torch.distributed.xxx + self.torch_overrides = {} # override torch.xxx + self.dist_overrides = {} # override torch.distributed.xxx def __enter__(self): - def _dummy(*args, **kwargs): pass def _new(*args, orig_new=torch.empty, **kwargs): - return MetaTensor(orig_new(*args, **{ - **kwargs, 'device': 'meta' - }), - device=kwargs.get('device', torch.device('cpu'))) + return MetaTensor( + orig_new(*args, **{**kwargs, "device": "meta"}), device=kwargs.get("device", torch.device("cpu")) + ) for func in _TorchOverrideableFactoryMethod: self.torch_overrides[func] = getattr(torch, func) diff --git a/colossalai/_analyzer/fx/codegen.py b/colossalai/_analyzer/fx/codegen.py index 41d74f2e3719..cd244b22cac0 100644 --- a/colossalai/_analyzer/fx/codegen.py +++ b/colossalai/_analyzer/fx/codegen.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable, List, Tuple +from typing import Any, Dict, List, Tuple import torch @@ -22,7 +22,7 @@ import colossalai from colossalai.fx._compatibility import compatibility -_register_custom_builtin('colossalai', 'import colossalai', colossalai) +_register_custom_builtin("colossalai", "import colossalai", colossalai) def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str: @@ -43,17 +43,17 @@ def 
_gen_ckpt_usage(label, input_vars, output_vars, use_reentrant=True): """ Generate the checkpoint function call code text """ - outputs = ', '.join(output_vars) - inputs = ', '.join(input_vars) - return f'{outputs} = torch.utils.checkpoint.checkpoint(self.checkpoint_{label}, {inputs}, use_reentrant={use_reentrant})' + outputs = ", ".join(output_vars) + inputs = ", ".join(input_vars) + return f"{outputs} = torch.utils.checkpoint.checkpoint(self.checkpoint_{label}, {inputs}, use_reentrant={use_reentrant})" def _end_of_ckpt(node: Node, ckpt_level: int) -> bool: """ Check if the node could end the ckpt region at `ckpt_level` """ - if len(node.meta['info'].activation_checkpoint) > ckpt_level: - return node.meta['info'].activation_checkpoint[ckpt_level] is not None + if len(node.meta["info"].activation_checkpoint) > ckpt_level: + return node.meta["info"].activation_checkpoint[ckpt_level] is not None return True @@ -94,8 +94,8 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0): current_region = None for idx, node in enumerate(node_list): - if len(node.meta['info'].activation_checkpoint) > ckpt_level: - act_ckpt_label = node.meta['info'].activation_checkpoint[ckpt_level] + if len(node.meta["info"].activation_checkpoint) > ckpt_level: + act_ckpt_label = node.meta["info"].activation_checkpoint[ckpt_level] # this activation checkpoint label is not set yet # meaning this is the first node of the activation ckpt region @@ -131,13 +131,9 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0): return ckpt_regions -def emit_ckpt_func(body, - ckpt_func, - node_list: List[Node], - emit_node_func, - delete_unused_value_func, - ckpt_level=0, - in_ckpt=False): +def emit_ckpt_func( + body, ckpt_func, node_list: List[Node], emit_node_func, delete_unused_value_func, ckpt_level=0, in_ckpt=False +): """Emit ckpt function in nested way Args: @@ -156,12 +152,12 @@ def emit_ckpt_func(body, # label given by each layer, e.g. 
if you are currently at level (0, 1, 1) # the label will be '0_1_1' - label = "_".join([str(idx) for idx in node_list[0].meta['info'].activation_checkpoint[:ckpt_level + 1]]) + label = "_".join([str(idx) for idx in node_list[0].meta["info"].activation_checkpoint[: ckpt_level + 1]]) ckpt_fn_def = _gen_ckpt_fn_def(label, inputs) - ckpt_func.append(f'{ckpt_fn_def}\n') + ckpt_func.append(f"{ckpt_fn_def}\n") # if there is more level to fetch - if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].activation_checkpoint), node_list)): + if ckpt_level + 1 < max(map(lambda node: len(node.meta["info"].activation_checkpoint), node_list)): ckpt_regions = _find_nested_ckpt_regions(node_list, ckpt_level + 1) start_idx = [item[0] for item in ckpt_regions] end_idx = [item[1] for item in ckpt_regions] @@ -174,33 +170,40 @@ def emit_ckpt_func(body, break if node_idx in start_idx: - ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1] - emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func, delete_unused_value_func, - ckpt_level + 1, True) + ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1] + emit_ckpt_func( + ckpt_func, + ckpt_func_buffer, + ckpt_node_list, + emit_node_func, + delete_unused_value_func, + ckpt_level + 1, + True, + ) node_idx += len(ckpt_node_list) else: node = node_list[node_idx] emit_node_func(node, ckpt_func) - ckpt_func[-1] = ' ' + ckpt_func[-1] + ckpt_func[-1] = " " + ckpt_func[-1] delete_unused_value_func(node, ckpt_func) node_idx += 1 - ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') + ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n") ckpt_func += ckpt_func_buffer # last level else: for node in node_list: emit_node_func(node, ckpt_func) - ckpt_func[-1] = ' ' + ckpt_func[-1] + ckpt_func[-1] = " " + ckpt_func[-1] delete_unused_value_func(node, ckpt_func) - ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') + ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n") - usage = _gen_ckpt_usage(label, inputs, outputs, False) + '\n' + usage = _gen_ckpt_usage(label, inputs, outputs, False) + "\n" if in_ckpt: - usage = ' ' + usage + usage = " " + usage body.append(usage) @@ -229,7 +232,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, # process ckpt_regions if node_idx in start_idx: - ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1] + ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1] emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func) node_idx += len(ckpt_node_list) @@ -243,7 +246,6 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, @compatibility(is_backward_compatible=True) class ActivationCheckpointCodeGen(CodeGen): - def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode: free_vars: List[str] = [] body: List[str] = [] @@ -251,7 +253,7 @@ def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> Py wrapped_fns: Dict[str, None] = {} # Wrap string in list to pass by reference - maybe_return_annotation: List[str] = [''] + maybe_return_annotation: List[str] = [""] def add_global(name_hint: str, obj: Any): """Add an obj to be tracked as a global. @@ -259,7 +261,7 @@ def add_global(name_hint: str, obj: Any): Graph, like functions or types. Returns: the global name that should be used to reference 'obj' in generated source. 
""" - if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device + if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device # HACK: workaround for how torch custom ops are registered. We # can't import them like normal modules so they must retain their # fully qualified name. @@ -281,16 +283,16 @@ def add_global(name_hint: str, obj: Any): def type_repr(o: Any): if o == (): # Empty tuple is used for empty tuple type annotation Tuple[()] - return '()' + return "()" typename = _type_repr(o) - if hasattr(o, '__origin__'): + if hasattr(o, "__origin__"): # This is a generic type, e.g. typing.List[torch.Tensor] origin_type = _origin_type_map.get(o.__origin__, o.__origin__) origin_typename = add_global(_type_repr(origin_type), origin_type) - if hasattr(o, '__args__'): + if hasattr(o, "__args__"): # Assign global names for each of the inner type variables. args = [type_repr(arg) for arg in o.__args__] @@ -309,19 +311,18 @@ def type_repr(o: Any): return add_global(typename, o) def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str: - def _get_repr(arg): # Handle NamedTuples (if it has `_fields`) via add_global. - if isinstance(arg, tuple) and hasattr(arg, '_fields'): + if isinstance(arg, tuple) and hasattr(arg, "_fields"): qualified_name = _get_qualified_name(type(arg)) global_name = add_global(qualified_name, type(arg)) return f"{global_name}{repr(tuple(arg))}" return repr(arg) - args_s = ', '.join(_get_repr(a) for a in args) - kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items()) + args_s = ", ".join(_get_repr(a) for a in args) + kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items()) if args_s and kwargs_s: - return f'{args_s}, {kwargs_s}' + return f"{args_s}, {kwargs_s}" return args_s or kwargs_s # Run through reverse nodes and record the first instance of a use @@ -347,82 +348,94 @@ def delete_unused_values(user: Node, body): not used in the remainder of the code are freed and the memory usage of the code is optimal. 
""" - if user.op == 'placeholder': + if user.op == "placeholder": return - if user.op == 'output': - body.append('\n') + if user.op == "output": + body.append("\n") return nodes_to_delete = user_to_last_uses.get(user, []) if len(nodes_to_delete): - to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None']) - body.append(f'; {to_delete_str}\n') + to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"]) + body.append(f"; {to_delete_str}\n") else: - body.append('\n') + body.append("\n") # NOTE: we add a variable to distinguish body and ckpt_func def emit_node(node: Node, body): - maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' - if node.op == 'placeholder': + maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}" + if node.op == "placeholder": assert isinstance(node.target, str) - maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}' - free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}') - raw_name = node.target.replace('*', '') + maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}" + free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}") + raw_name = node.target.replace("*", "") if raw_name != repr(node): - body.append(f'{repr(node)} = {raw_name}\n') + body.append(f"{repr(node)} = {raw_name}\n") return - elif node.op == 'call_method': + elif node.op == "call_method": assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}' - f'({_format_args(node.args[1:], node.kwargs)})') + body.append( + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}" + f"({_format_args(node.args[1:], node.kwargs)})" + ) return - elif node.op == 'call_function': + elif node.op == "call_function": assert callable(node.target) # pretty print operators - if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods: + if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods: assert isinstance(node.args, tuple) - body.append(f'{repr(node)}{maybe_type_annotation} = ' - f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}') + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}" + ) return # pretty print inplace operators; required for jit.script to work properly # not currently supported in normal FX graphs, but generated by torchdynamo - if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods: - body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; ' - f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}') + if node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods: + body.append( + f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; " + f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}" + ) return qualified_name = _get_qualified_name(node.target) global_name = add_global(qualified_name, node.target) # special case for getattr: node.args could be 2-argument or 3-argument # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value - if global_name == 'getattr' and \ - isinstance(node.args, tuple) and \ - isinstance(node.args[1], str) and \ - node.args[1].isidentifier() and \ - len(node.args) == 2: 
+ if ( + global_name == "getattr" + and isinstance(node.args, tuple) + and isinstance(node.args[1], str) + and node.args[1].isidentifier() + and len(node.args) == 2 + ): body.append( - f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}') + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}" + ) return body.append( - f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})') - if node.meta.get('is_wrapped', False): + f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})" + ) + if node.meta.get("is_wrapped", False): wrapped_fns.setdefault(global_name) return - elif node.op == 'call_module': + elif node.op == "call_module": assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = ' - f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})') + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})" + ) return - elif node.op == 'get_attr': + elif node.op == "get_attr": assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}') + body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}") return - elif node.op == 'output': + elif node.op == "output": if node.type is not None: maybe_return_annotation[0] = f" -> {type_repr(node.type)}" body.append(self.generate_output(node.args[0])) return - raise NotImplementedError(f'node: {node.op} {node.target}') + raise NotImplementedError(f"node: {node.op} {node.target}") # Modified for activation checkpointing ckpt_func = [] @@ -432,13 +445,13 @@ def emit_node(node: Node, body): # If the Graph has no non-placeholder nodes, no lines for the body # have been emitted. To continue to have valid Python code, emit a # single pass statement - body.append('pass\n') + body.append("pass\n") if len(wrapped_fns) > 0: - wrap_name = add_global('wrap', torch.fx.wrap) - wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns]) + wrap_name = add_global("wrap", torch.fx.wrap) + wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns]) else: - wrap_stmts = '' + wrap_stmts = "" if self._body_transformer: body = self._body_transformer(body) @@ -447,11 +460,11 @@ def emit_node(node: Node, body): add_global(name, value) prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) - prologue = ''.join(ckpt_func) + prologue + prologue = "".join(ckpt_func) + prologue prologue = prologue - code = ''.join(body) - code = '\n'.join(' ' + line for line in code.split('\n')) + code = "".join(body) + code = "\n".join(" " + line for line in code.split("\n")) fn_code = f""" {wrap_stmts} {prologue} diff --git a/colossalai/_analyzer/fx/graph_module.py b/colossalai/_analyzer/fx/graph_module.py index 1fdedd758c01..9d3999e322b9 100644 --- a/colossalai/_analyzer/fx/graph_module.py +++ b/colossalai/_analyzer/fx/graph_module.py @@ -13,6 +13,7 @@ try: from torch.fx.graph import _PyTreeCodeGen + SUPPORT_PT_CODEGEN = True except ImportError: SUPPORT_PT_CODEGEN = False @@ -24,7 +25,6 @@ # This is a copy of torch.fx.graph_module._WrappedCall. # It should be removed when we stop supporting torch < 1.12.0. 
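For orientation, a hand-written equivalent of what the `ActivationCheckpointCodeGen` hunk above emits for a single checkpoint region: a `checkpoint_<label>` helper plus a `torch.utils.checkpoint.checkpoint(..., use_reentrant=False)` call in the forward body, as produced by `_gen_ckpt_fn_def` and `_gen_ckpt_usage`. The module, layer names and region label here are made up; the real output is generated source text bound onto the graph module.

```python
# Hand-written sketch mirroring the generated code: one activation-checkpoint region
# becomes a `checkpoint_0` helper, and the forward body calls it through
# torch.utils.checkpoint.checkpoint with use_reentrant=False (as in _gen_ckpt_usage above).
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_1 = nn.Linear(16, 16)
        self.linear_2 = nn.Linear(16, 16)

    def checkpoint_0(self, x):  # corresponds to the emitted checkpoint function
        return torch.relu(self.linear_1(x))

    def forward(self, x):       # corresponds to the emitted forward body
        relu = torch.utils.checkpoint.checkpoint(self.checkpoint_0, x, use_reentrant=False)
        return self.linear_2(relu)

out = TinyModel()(torch.randn(2, 16, requires_grad=True))
```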
class _WrappedCall: - def __init__(self, cls, cls_call): self.cls = cls self.cls_call = cls_call @@ -50,12 +50,14 @@ def _generate_error_message(frame_summary: traceback.FrameSummary) -> str: # constituent substrings of the error message tb_repr = traceback.format_exc() - custom_msg = ("Call using an FX-traced Module, " - f"line {err_lineno} of the traced Module's " - "generated forward function:") - before_err = "".join(all_src_lines[err_lineno - 2:err_lineno]) + custom_msg = ( + "Call using an FX-traced Module, " + f"line {err_lineno} of the traced Module's " + "generated forward function:" + ) + before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno]) marker = "~" * err_line_len + "~~~ <--- HERE" - err_and_after_err = "\n".join(all_src_lines[err_lineno:err_lineno + 2]) + err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2]) # joined message return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err]) @@ -65,11 +67,14 @@ def __call__(self, obj, *args, **kwargs): if self.cls_call is not None: return self.cls_call(obj, *args, **kwargs) else: - return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc] + return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc] except Exception as e: assert e.__traceback__ - topmost_framesummary: traceback.FrameSummary = \ - traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1] # type: ignore[arg-type] + topmost_framesummary: traceback.FrameSummary = traceback.StackSummary.extract( + traceback.walk_tb(e.__traceback__) + )[ + -1 + ] # type: ignore[arg-type] if "eval_with_key" in topmost_framesummary.filename: print(_WrappedCall._generate_error_message(topmost_framesummary), file=sys.stderr) raise e.with_traceback(None) @@ -99,10 +104,9 @@ class ColoGraphModule(torch.fx.GraphModule): code. """ - def __init__(self, - root: Union[torch.nn.Module, Dict[str, Any]], - graph: torch.fx.Graph, - class_name: str = 'GraphModule'): + def __init__( + self, root: Union[torch.nn.Module, Dict[str, Any]], graph: torch.fx.Graph, class_name: str = "GraphModule" + ): super().__init__(root, graph, class_name) def bind(self, ckpt_def, globals): @@ -134,7 +138,7 @@ def recompile(self) -> PythonCode: if SUPPORT_PT_CODEGEN and isinstance(self._graph._codegen, _PyTreeCodeGen): self._in_spec = self._graph._codegen.pytree_info.in_spec self._out_spec = self._graph._codegen.pytree_info.out_spec - python_code = self._graph.python_code(root_module='self') + python_code = self._graph.python_code(root_module="self") self._code = python_code.src # To split ckpt functions code and forward code @@ -157,8 +161,8 @@ def recompile(self) -> PythonCode: # bypass patching of torch.nn.Module.__call__ done while symbolic tracing. 
cls_call = cls.__call__ if "__call__" in vars(cls) else None - if '_wrapped_call' not in vars(cls): - cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined] + if "_wrapped_call" not in vars(cls): + cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined] def call_wrapped(self, *args, **kwargs): return self._wrapped_call(self, *args, **kwargs) @@ -182,7 +186,7 @@ def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModul """ folder = Path(folder) Path(folder).mkdir(exist_ok=True) - torch.save(self.state_dict(), folder / 'state_dict.pt') + torch.save(self.state_dict(), folder / "state_dict.pt") tab = " " * 4 # we add import colossalai here @@ -208,10 +212,10 @@ def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: for module_name, module in self.named_children(): module_str = _gen_model_repr(module_name, module) if module_str is None: - module_file = folder / f'{module_name}.pt' + module_file = folder / f"{module_name}.pt" torch.save(module, module_file) blobified_modules.append(module_name) - module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ') + module_repr = module.__repr__().replace("\r", " ").replace("\n", " ") module_str = f"torch.load(r'{module_file}') # {module_repr}" model_str += f"{tab*2}self.{module_name} = {module_str}\n" @@ -228,12 +232,14 @@ def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n" model_str += f"{_addindent(self.code, 4)}\n" - module_file = folder / 'module.py' + module_file = folder / "module.py" module_file.write_text(model_str) - init_file = folder / '__init__.py' - init_file.write_text('from .module import *') + init_file = folder / "__init__.py" + init_file.write_text("from .module import *") if len(blobified_modules) > 0: - warnings.warn("Was not able to save the following children modules as reprs -" - f"saved as pickled files instead: {blobified_modules}") + warnings.warn( + "Was not able to save the following children modules as reprs -" + f"saved as pickled files instead: {blobified_modules}" + ) diff --git a/colossalai/_analyzer/fx/node_util.py b/colossalai/_analyzer/fx/node_util.py index fbe8400a437e..d2671787ea63 100644 --- a/colossalai/_analyzer/fx/node_util.py +++ b/colossalai/_analyzer/fx/node_util.py @@ -1,9 +1,9 @@ from dataclasses import dataclass, field -from typing import Callable, ClassVar, Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torch -from torch.autograd.profiler_util import _format_memory, _format_time -from torch.fx import Graph, GraphModule, Node +from torch.autograd.profiler_util import _format_memory +from torch.fx import Node from colossalai._analyzer.envs import MeshConfig @@ -85,12 +85,12 @@ class MetaInfo: node: Node # directory - mod_dir: str = '' + mod_dir: str = "" # ctx[data_ptr] = Tensor # mark the storage for ctx.save_for_backward - global_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # globally shared - curr_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # global_ctx till this node + global_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # globally shared + curr_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # global_ctx till this node # should be updated after each graph manipulation # ============================== Update ==================================== @@ -100,7 +100,7 @@ 
class MetaInfo: inputs: Tuple[torch.Tensor] = () outputs: Tuple[torch.Tensor] = () - is_alias: Tuple[bool] = () # whether the output is an alias of input + is_alias: Tuple[bool] = () # whether the output is an alias of input # compute cost fwd_flop: Optional[int] = 0 @@ -112,29 +112,29 @@ class MetaInfo: # should keep the same whenever manipulated # ============================= Invariant ================================== - activation_checkpoint: Tuple[torch.Tensor] = () # (region_0, region_1, ...) support nested codegen + activation_checkpoint: Tuple[torch.Tensor] = () # (region_0, region_1, ...) support nested codegen to_offload: Optional[bool] = False - sharding_spec: str = 'RR' + sharding_spec: str = "RR" def __new__(cls, node: Node, **kwargs): orig_init = cls.__init__ # if initialized, return the existing one # should disable the __init__ function - if node.meta.get('info', None) is not None: + if node.meta.get("info", None) is not None: def _dummy(self, *args, **kwargs): - if getattr(self, '_is_init', False): + if getattr(self, "_is_init", False): self._is_init = True orig_init(self, *args, **kwargs) cls.__init__ = orig_init cls.__init__ = _dummy - return node.meta['info'] + return node.meta["info"] return super().__new__(cls) def __post_init__(self): - self.node.meta['info'] = self + self.node.meta["info"] = self @property def fwd_time(self, tflops: float = MeshConfig.TFLOPS, bandwidth: float = MeshConfig.BANDWIDTH): @@ -188,24 +188,26 @@ def backward_size(self): return compute_size_in_bytes(self.inputs) def __repr__(self): - s = f'Node {self.node.name}' + s = f"Node {self.node.name}" if self.parameters: - s += f'\n\thas parameter of size {_format_memory(self.param_size)}' + s += f"\n\thas parameter of size {_format_memory(self.param_size)}" if self.buffers: - s += f'\n\thas buffer of size {_format_memory(self.buffer_size)}' + s += f"\n\thas buffer of size {_format_memory(self.buffer_size)}" if self.output_size: - s += f'\n\thas output activation of size {_format_memory(self.output_size)}' + s += f"\n\thas output activation of size {_format_memory(self.output_size)}" # if self.total_size: # s += f'\n\thas total activation of size {_format_memory(self.total_size)}' if self.temp_size: - s += f'\n\thas temp activation of size {_format_memory(self.temp_size)}' + s += f"\n\thas temp activation of size {_format_memory(self.temp_size)}" if self.backward_size: - s += f'\n\thas backward activation of size {_format_memory(self.backward_size)}' - s += f'\n\tfwd_flop = {self.fwd_flop}'\ - f'\n\tbwd_flop = {self.bwd_flop}'\ - f'\n\tfwd_comm = {self.fwd_comm}'\ - f'\n\tbwd_comm = {self.bwd_comm}'\ - f'\n\tto_recompute = {self.to_recompute}'\ - f'\n\tto_offload = {self.to_offload}'\ - f'\n\tsharding_spec = {self.sharding_spec}' + s += f"\n\thas backward activation of size {_format_memory(self.backward_size)}" + s += ( + f"\n\tfwd_flop = {self.fwd_flop}" + f"\n\tbwd_flop = {self.bwd_flop}" + f"\n\tfwd_comm = {self.fwd_comm}" + f"\n\tbwd_comm = {self.bwd_comm}" + f"\n\tto_recompute = {self.to_recompute}" + f"\n\tto_offload = {self.to_offload}" + f"\n\tsharding_spec = {self.sharding_spec}" + ) return s diff --git a/colossalai/_analyzer/fx/passes/graph_profile.py b/colossalai/_analyzer/fx/passes/graph_profile.py index c3e760b31e96..158ebce219cd 100644 --- a/colossalai/_analyzer/fx/passes/graph_profile.py +++ b/colossalai/_analyzer/fx/passes/graph_profile.py @@ -1,8 +1,8 @@ -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, 
Tuple import torch import torch.fx -from torch.autograd.profiler_util import _format_memory, _format_time +from torch.autograd.profiler_util import _format_memory from torch.fx import GraphModule from torch.fx.node import Argument, Node, Target @@ -13,14 +13,14 @@ def _format_flops(flops: float) -> str: """Returns a formatted FLOP size string""" if flops > 1e12: - return f'{flops / 1e12:.2f} TFLOPs' + return f"{flops / 1e12:.2f} TFLOPs" elif flops > 1e9: - return f'{flops / 1e9:.2f} GFLOPs' + return f"{flops / 1e9:.2f} GFLOPs" elif flops > 1e6: - return f'{flops / 1e6:.2f} MFLOPs' + return f"{flops / 1e6:.2f} MFLOPs" elif flops > 1e3: - return f'{flops / 1e3:.2f} kFLOPs' - return f'{flops} FLOPs' + return f"{flops / 1e3:.2f} kFLOPs" + return f"{flops} FLOPs" def _denormalize_tuple(t: Tuple[int, ...]) -> Tuple[int, ...]: @@ -42,10 +42,11 @@ class GraphProfiler(torch.fx.Interpreter): Fetch shape argument from ``ShapeProp`` without re-executing the ``GraphModule`` from scratch. """ + _profileable = [ - 'call_function', - 'call_module', - 'call_method', + "call_function", + "call_module", + "call_method", ] def __init__(self, module: GraphModule, garbage_collect_values: bool = True): @@ -77,14 +78,13 @@ def run(self, *args, initial_env: Optional[Dict[Node, Any]] = None, enable_io_pr self.args_iter: Iterator[Any] = iter(args) for node in self.module.graph.nodes: - - self.run_node(node) # No need to store. + self.run_node(node) # No need to store. if self.garbage_collect_values: for to_delete in self.user_to_last_uses.get(node, []): del self.env[to_delete] - if node.op == 'output': + if node.op == "output": output_val = self.env[node] return self.module.graph.process_outputs(output_val) if enable_io_processing else output_val @@ -133,9 +133,11 @@ def summary(self) -> str: try: from tabulate import tabulate except ImportError: - print("`summary` relies on the library `tabulate`, " - "which could not be found on this machine. Run `pip " - "install tabulate` to install the library.") + print( + "`summary` relies on the library `tabulate`, " + "which could not be found on this machine. Run `pip " + "install tabulate` to install the library." 
+ ) # Build up a list of summary information for each node node_summaries: List[List[Any]] = [] @@ -145,36 +147,38 @@ def summary(self) -> str: node: Node n_info = MetaInfo(node) last_n_info = last_n_info or n_info - node_summaries.append([ - node.op, - str(node), - _format_memory(n_info.accumulate_size), - _format_memory(n_info.accumulate_size - last_n_info.accumulate_size), - _format_memory(n_info.output_size), - _format_memory(n_info.temp_size), - _format_memory(n_info.param_size), - _format_memory(n_info.backward_size), - _format_flops(n_info.fwd_flop), - _format_flops(n_info.bwd_flop), - ]) + node_summaries.append( + [ + node.op, + str(node), + _format_memory(n_info.accumulate_size), + _format_memory(n_info.accumulate_size - last_n_info.accumulate_size), + _format_memory(n_info.output_size), + _format_memory(n_info.temp_size), + _format_memory(n_info.param_size), + _format_memory(n_info.backward_size), + _format_flops(n_info.fwd_flop), + _format_flops(n_info.bwd_flop), + ] + ) last_n_info = n_info # Use the ``tabulate`` library to create a well-formatted table # presenting our summary information headers: List[str] = [ - 'Op type', - 'Op', - 'Accumulate size', - 'Incremental size', - 'Output size', - 'Temp size', - 'Param size', - 'Backward size', - 'Fwd FLOPs', - 'Bwd FLOPs', + "Op type", + "Op", + "Accumulate size", + "Incremental size", + "Output size", + "Temp size", + "Param size", + "Backward size", + "Fwd FLOPs", + "Bwd FLOPs", ] - return tabulate(node_summaries, headers=headers, stralign='right') + return tabulate(node_summaries, headers=headers, stralign="right") class CommunicationProfiler(GraphProfiler): @@ -222,6 +226,7 @@ class with the ``@register_flop_count_impl`` decorator: >>> def my_fn_flop_count_impl(*args, **kwargs): >>> return 0, 0 """ + _custom_flop_count_impl = {} def run_node(self, n: torch.fx.Node) -> Any: @@ -246,11 +251,13 @@ def run_node(self, n: torch.fx.Node) -> Any: ( n_info.fwd_flop, n_info.bwd_flop, - ) = getattr(self, n.op)(n.target, args, kwargs) + ) = getattr( + self, n.op + )(n.target, args, kwargs) except Exception as e: raise RuntimeError( - f'Error {str(e)} occurred when profiling node {n}, node.target = {n.target}. ' - f'Please refer to function\'s docstring to register the relevant profile_impl for this node!' + f"Error {str(e)} occurred when profiling node {n}, node.target = {n.target}. " + f"Please refer to function's docstring to register the relevant profile_impl for this node!" ) from e # retain the autograd graph @@ -259,7 +266,7 @@ def run_node(self, n: torch.fx.Node) -> Any: return _denormalize_tuple(n_info.outputs) - def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``call_function`` node and return the profiling result. Dispatch to ``_custom_flop_count_impl`` if ``call_function`` should be @@ -283,7 +290,7 @@ def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Di else: return flop_count(target, *args, **kwargs) - def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``call_method`` node and return the profiling result. 
@@ -301,7 +308,7 @@ def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict assert isinstance(target, str) return flop_count(getattr(torch.Tensor, target), *args, **kwargs) - def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``call_module`` node and return the profiling result. @@ -336,9 +343,10 @@ def graph_profile_pass(module: GraphModule, *args, verbose=False) -> GraphModule Returns: GraphModule: The same GraphModule with profiling information """ - for profiler_cls in (FlopProfiler, - # CommunicationProfiler, # TODO: add communication profiling - ): + for profiler_cls in ( + FlopProfiler, + # CommunicationProfiler, # TODO: add communication profiling + ): profiler = profiler_cls(module) profiler.propagate(*args, device=_current_device(module)) diff --git a/colossalai/_analyzer/fx/passes/shape_prop.py b/colossalai/_analyzer/fx/passes/shape_prop.py index 23e83013e02f..8d44f1d4b59d 100644 --- a/colossalai/_analyzer/fx/passes/shape_prop.py +++ b/colossalai/_analyzer/fx/passes/shape_prop.py @@ -54,7 +54,7 @@ def _current_device(module): try: return next(module.parameters()).device except StopIteration: - return torch.device('cpu') + return torch.device("cpu") @compatibility(is_backward_compatible=False) @@ -90,6 +90,7 @@ class ShapeProp(torch.fx.Interpreter): >>> # do something here >>> return torch.empty(output_shape, device=output_device) """ + _custom_dispatch_func = {} _mode = MetaTensorMode() @@ -115,15 +116,14 @@ def run_node(self, n: torch.fx.Node) -> Any: r = getattr(self, n.op)(n.target, args, kwargs) def unwrap_fn(elem): - def _convert_meta(t: torch.Tensor): - if t.device == 'meta': + if t.device == "meta": return t else: - return t.to('meta') + return t.to("meta") if isinstance(elem, MetaTensor): - if getattr(self, '_is_param', False): + if getattr(self, "_is_param", False): return torch.nn.Parameter(_convert_meta(elem._tensor)) return _convert_meta(elem._tensor) @@ -139,21 +139,24 @@ def _convert_meta(t: torch.Tensor): n_info = MetaInfo(n) n_info.outputs = _normalize_tuple(r) - if n.op == 'call_module': + if n.op == "call_module": submod = self.fetch_attr(n.target) n_info.parameters.update({k: MetaTensor(v) for k, v in submod.named_parameters()}) n_info.buffers.update({k: MetaTensor(v) for k, v in submod.named_buffers()}) else: - n_info.parameters.update({ - k.name: MetaTensor(v) - for k, v in zip(n.args, args) - if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter) - }) + n_info.parameters.update( + { + k.name: MetaTensor(v) + for k, v in zip(n.args, args) + if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter) + } + ) n_info.parameters.update({k: MetaTensor(v) for k, v in kwargs.items() if isinstance(v, torch.nn.Parameter)}) - n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + \ - tuple(v for v in kwargs.values() if is_pure_tensor(v)) + n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + tuple( + v for v in kwargs.values() if is_pure_tensor(v) + ) # align with SPMD if isinstance(r, (tuple, list)): @@ -168,7 +171,7 @@ def _convert_meta(t: torch.Tensor): n_info.is_alias = _normalize_tuple(tree_map(crit, n_info.outputs)) return r - def call_function(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any: + def call_function(self, target: "Target", args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a 
``call_function`` node and return the result. If the target of ``Node`` is registered with ``@register_shape_impl``, @@ -197,7 +200,7 @@ def call_function(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[st else: return res - def call_method(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any: + def call_method(self, target: "Target", args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``call_method`` node and return the result. @@ -218,7 +221,8 @@ def call_method(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, convert_to_parameter = False if target_method in (torch.Tensor.view, torch.Tensor.transpose) and isinstance( - args[0], torch.nn.parameter.Parameter): + args[0], torch.nn.parameter.Parameter + ): convert_to_parameter = True # Execute the method and return the result assert isinstance(target, str) diff --git a/colossalai/_analyzer/fx/symbolic_profile.py b/colossalai/_analyzer/fx/symbolic_profile.py index dd7f22c6c98a..5732a6665f78 100644 --- a/colossalai/_analyzer/fx/symbolic_profile.py +++ b/colossalai/_analyzer/fx/symbolic_profile.py @@ -1,5 +1,3 @@ -import torch -import torch.fx from torch.fx import GraphModule from .passes import ShapeProp, graph_profile_pass, shape_prop_pass @@ -7,7 +5,6 @@ def register_flop_count_impl(func): - def wrapper(impl): FlopProfiler._custom_flop_count_impl[func] = impl return impl @@ -16,7 +13,6 @@ def wrapper(impl): def register_shape_impl(func): - def wrapper(impl): ShapeProp._custom_dispatch_func[func] = impl return impl diff --git a/colossalai/_analyzer/fx/tracer/bias_addition.py b/colossalai/_analyzer/fx/tracer/bias_addition.py index 1e75b47ca5b0..b8b83282b42c 100644 --- a/colossalai/_analyzer/fx/tracer/bias_addition.py +++ b/colossalai/_analyzer/fx/tracer/bias_addition.py @@ -12,7 +12,7 @@ __all__ = [] -@register_tracer_impl(F.linear, name='_bias_addition_impl') +@register_tracer_impl(F.linear, name="_bias_addition_impl") def linear_impl(input, weight, bias=None): if bias is None: return F.linear(input, weight) @@ -20,116 +20,130 @@ def linear_impl(input, weight, bias=None): return F.linear(input, weight) + bias -@register_tracer_impl(F.conv1d, name='_bias_addition_impl') +@register_tracer_impl(F.conv1d, name="_bias_addition_impl") def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1): if bias is None: return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) else: return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape( - (-1, 1)) + (-1, 1) + ) -@register_tracer_impl(F.conv2d, name='_bias_addition_impl') +@register_tracer_impl(F.conv2d, name="_bias_addition_impl") def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), groups=1): if bias is None: return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) else: return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape( - (-1, 1, 1)) + (-1, 1, 1) + ) -@register_tracer_impl(F.conv3d, name='_bias_addition_impl') +@register_tracer_impl(F.conv3d, name="_bias_addition_impl") def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1): if bias is None: return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) else: return F.conv3d(input, weight, stride=stride, 
padding=padding, dilation=dilation, groups=groups) + bias.reshape( - (-1, 1, 1, 1)) - - -@register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl') -def conv_transpose1d_impl(input, - weight, - bias=None, - stride=_single(1), - padding=_single(0), - output_padding=_single(0), - groups=1, - dilation=_single(1)): + (-1, 1, 1, 1) + ) + + +@register_tracer_impl(F.conv_transpose1d, name="_bias_addition_impl") +def conv_transpose1d_impl( + input, + weight, + bias=None, + stride=_single(1), + padding=_single(0), + output_padding=_single(0), + groups=1, + dilation=_single(1), +): if bias is None: - return F.conv_transpose1d(input, - weight, - stride=stride, - padding=padding, - output_padding=output_padding, - groups=groups, - dilation=dilation) + return F.conv_transpose1d( + input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation, + ) else: - return F.conv_transpose1d(input, - weight, - stride=stride, - padding=padding, - output_padding=output_padding, - groups=groups, - dilation=dilation) + bias.reshape((-1, 1)) - - -@register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl') -def conv_transpose2d_impl(input, - weight, - bias=None, - stride=_pair(1), - padding=_pair(0), - output_padding=_pair(0), - groups=1, - dilation=_pair(1)): + return F.conv_transpose1d( + input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation, + ) + bias.reshape((-1, 1)) + + +@register_tracer_impl(F.conv_transpose2d, name="_bias_addition_impl") +def conv_transpose2d_impl( + input, weight, bias=None, stride=_pair(1), padding=_pair(0), output_padding=_pair(0), groups=1, dilation=_pair(1) +): if bias is None: - return F.conv_transpose2d(input, - weight, - stride=stride, - padding=padding, - output_padding=output_padding, - groups=groups, - dilation=dilation) + return F.conv_transpose2d( + input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation, + ) else: - return F.conv_transpose2d(input, - weight, - stride=stride, - padding=padding, - output_padding=output_padding, - groups=groups, - dilation=dilation) + bias.reshape((-1, 1, 1)) - - -@register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl') -def conv_transpose3d_impl(input, - weight, - bias=None, - stride=_triple(1), - padding=_triple(0), - output_padding=_triple(0), - groups=1, - dilation=_triple(1)): + return F.conv_transpose2d( + input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation, + ) + bias.reshape((-1, 1, 1)) + + +@register_tracer_impl(F.conv_transpose3d, name="_bias_addition_impl") +def conv_transpose3d_impl( + input, + weight, + bias=None, + stride=_triple(1), + padding=_triple(0), + output_padding=_triple(0), + groups=1, + dilation=_triple(1), +): if bias is None: - return F.conv_transpose3d(input, - weight, - stride=stride, - padding=padding, - output_padding=output_padding, - groups=groups, - dilation=dilation) + return F.conv_transpose3d( + input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation, + ) else: - return F.conv_transpose3d(input, - weight, - stride=stride, - padding=padding, - output_padding=output_padding, - groups=groups, - dilation=dilation) + bias.reshape((-1, 1, 1, 1)) - - -@register_tracer_impl(torch.addmm, name='_bias_addition_impl') -@register_tracer_impl(torch.Tensor.addmm, 
name='_bias_addition_impl') + return F.conv_transpose3d( + input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation, + ) + bias.reshape((-1, 1, 1, 1)) + + +@register_tracer_impl(torch.addmm, name="_bias_addition_impl") +@register_tracer_impl(torch.Tensor.addmm, name="_bias_addition_impl") def addmm_impl(input, mat1, mat2, beta=1, alpha=1): if alpha != 1 and beta != 1: return F.linear(mat1, mat2.transpose(0, 1)) * alpha + input * beta @@ -141,8 +155,8 @@ def addmm_impl(input, mat1, mat2, beta=1, alpha=1): return F.linear(mat1, mat2.transpose(0, 1)) + input -@register_tracer_impl(torch.addbmm, name='_bias_addition_impl') -@register_tracer_impl(torch.Tensor.addbmm, name='_bias_addition_impl') +@register_tracer_impl(torch.addbmm, name="_bias_addition_impl") +@register_tracer_impl(torch.Tensor.addbmm, name="_bias_addition_impl") def addbmm_impl(input, batch1, batch2, beta=1, alpha=1): if alpha != 1 and beta != 1: return torch.bmm(batch1, batch2.transpose(1, 2)) * alpha + input * beta diff --git a/colossalai/_analyzer/fx/tracer/custom_leaf_module.py b/colossalai/_analyzer/fx/tracer/custom_leaf_module.py index 112c7c9637d2..ff6b55be5117 100644 --- a/colossalai/_analyzer/fx/tracer/custom_leaf_module.py +++ b/colossalai/_analyzer/fx/tracer/custom_leaf_module.py @@ -4,6 +4,7 @@ try: import apex + register_leaf_module(apex.normalization.FusedLayerNorm) register_leaf_module(apex.normalization.FusedRMSNorm) register_leaf_module(apex.normalization.MixedFusedLayerNorm) diff --git a/colossalai/_analyzer/fx/tracer/proxy.py b/colossalai/_analyzer/fx/tracer/proxy.py index ce379efdcf0d..e3e210e7d190 100644 --- a/colossalai/_analyzer/fx/tracer/proxy.py +++ b/colossalai/_analyzer/fx/tracer/proxy.py @@ -1,10 +1,8 @@ import operator -from typing import Any, Callable, Dict, Optional, Set, Union +from typing import Any, Callable, Dict, Optional, Union import torch -import torch.nn as nn -from torch.fx import Graph, Node, Proxy, Tracer -from torch.fx.graph import _Namespace +from torch.fx import Node, Proxy from torch.utils._pytree import tree_map from colossalai._analyzer._subclasses import MetaTensor @@ -32,7 +30,7 @@ def meta_data(self, args): def __torch_function__(cls, orig_method, types, args=(), kwargs=None): kwargs = {} if kwargs is None else kwargs if orig_method in cls._func_dispatch: - impl = cls._func_dispatch.pop(orig_method) # avoid recursion + impl = cls._func_dispatch.pop(orig_method) # avoid recursion proxy = impl(*args, **kwargs) cls._func_dispatch[orig_method] = impl return proxy @@ -72,7 +70,7 @@ def __getattr__(self, k): return ColoAttribute(self, k, getattr(self._meta_data, k, None)) def __setitem__(self, key, value): - proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {}) + proxy = self.tracer.create_proxy("call_function", operator.setitem, (self, key, value), {}) proxy.meta_data = self._meta_data return proxy @@ -89,7 +87,6 @@ def __isinstancecheck__(self, type): class ColoAttribute(ColoProxy): - def __init__(self, root, attr: str, data=None): self.root = root self.attr = attr @@ -102,11 +99,11 @@ def node(self): # the node for attributes is added lazily, since most will just be method calls # which do not rely on the getitem call if self._node is None: - self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node + self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node return self._node def __call__(self, *args, 
**kwargs): - return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs) + return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs) def __repr__(self): return f"ColoAttribute({self.node.name}, attr={self.attr})" diff --git a/colossalai/_analyzer/fx/tracer/symbolic_trace.py b/colossalai/_analyzer/fx/tracer/symbolic_trace.py index 2018863f6f5f..7884fd911c86 100644 --- a/colossalai/_analyzer/fx/tracer/symbolic_trace.py +++ b/colossalai/_analyzer/fx/tracer/symbolic_trace.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union +from typing import Any, Callable, Dict, Optional, Union import torch from torch.fx import Tracer @@ -8,6 +8,7 @@ try: from ..codegen import ActivationCheckpointCodeGen + SUPPORT_ACTIVATION = True except: SUPPORT_ACTIVATION = False @@ -16,7 +17,7 @@ def _default_device(): - return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') + return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") def _current_device(module: torch.nn.Module): @@ -144,10 +145,9 @@ def forward(self, x): if meta_args: device, orig_device = _default_device(), _current_device(root) wrap_fn = lambda elem: MetaTensor(elem, device=device) if isinstance(elem, torch.Tensor) else elem - graph = ColoTracer(trace_act_ckpt=trace_act_ckpt, - bias_addition_split=bias_addition_split).trace(root.to(device), - concrete_args=concrete_args, - meta_args=tree_map(wrap_fn, meta_args)) + graph = ColoTracer(trace_act_ckpt=trace_act_ckpt, bias_addition_split=bias_addition_split).trace( + root.to(device), concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args) + ) if trace_act_ckpt and SUPPORT_ACTIVATION: graph.set_codegen(ActivationCheckpointCodeGen()) root.to(orig_device) diff --git a/colossalai/_analyzer/fx/tracer/tracer.py b/colossalai/_analyzer/fx/tracer/tracer.py index 6958a00a6a72..17dce767269d 100644 --- a/colossalai/_analyzer/fx/tracer/tracer.py +++ b/colossalai/_analyzer/fx/tracer/tracer.py @@ -20,11 +20,10 @@ def _truncate_suffix(s: str): import re # FIXME: don't know why but torch.fx always gets a suffix like '_1' in the name - return re.sub(r'_\d+$', '', s) + return re.sub(r"_\d+$", "", s) -def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = '_custom_impl'): - +def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = "_custom_impl"): def wrapper(impl): assert hasattr(ColoTracer, name), f"Cannot register {func.__name__} in ColoTracer.{name}" getattr(ColoTracer, name)[func] = impl @@ -34,7 +33,6 @@ def wrapper(impl): def register_leaf_module_impl(module: nn.Module): - def wrapper(impl): ColoTracer._custom_leaf_module_impl[module] = impl return impl @@ -76,7 +74,7 @@ def __init__(self, trace_act_ckpt: bool = False, bias_addition_split: bool = Fal self.ckpt_regions = [] self.ckpt_idx = 0 - self.mod_dir = '' + self.mod_dir = "" # whether the tracer should split the bias_add ops into two ops self.bias_addition_split = bias_addition_split @@ -87,35 +85,41 @@ def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool: if self.bias_addition_split and type(m) in self._bias_addition_module and m.bias is not None: return False # user can specify which modules are leaf modules and which are not - return (type(m) not in self._custom_non_leaf_module - and (type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name))) + return type(m) not in self._custom_non_leaf_module and ( 
+ type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name) + ) - def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], - kwargs: Dict[str, Any]) -> Any: + def call_module( + self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any] + ) -> Any: curr_dir = self.mod_dir - self.mod_dir = 'self.' + self.path_of_module(m) + self.mod_dir = "self." + self.path_of_module(m) rst = super().call_module(m, forward, args, kwargs) self.mod_dir = curr_dir return rst - def proxy(self, node: Node) -> 'ColoProxy': + def proxy(self, node: Node) -> "ColoProxy": return ColoProxy(node, self) - def create_proxy(self, - kind: str, - target: Target, - args: Tuple[Any, ...], - kwargs: Dict[str, Any], - name: Optional[str] = None, - type_expr: Optional[Any] = None, - proxy_factory_fn: Callable[[Node], 'Proxy'] = None): - + def create_proxy( + self, + kind: str, + target: Target, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + name: Optional[str] = None, + type_expr: Optional[Any] = None, + proxy_factory_fn: Callable[[Node], "Proxy"] = None, + ): proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn) unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p - if kind == 'placeholder': - proxy.meta_data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get( - _truncate_suffix(target), None) - elif kind == 'get_attr': + if kind == "placeholder": + proxy.meta_data = ( + self.meta_args[target] + if target in self.meta_args + else self.concrete_args.get(_truncate_suffix(target), None) + ) + elif kind == "get_attr": self.disable_module_getattr = True try: attr_itr = self.root @@ -125,20 +129,21 @@ def create_proxy(self, proxy.meta_data = attr_itr finally: self.disable_module_getattr = False - elif kind == 'call_function': + elif kind == "call_function": proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs)) - elif kind == 'call_method': + elif kind == "call_method": self.disable_module_getattr = True try: - if target == '__call__': + if target == "__call__": proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)) else: if target not in _TensorPropertyMethod: - proxy._meta_data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]), - **tree_map(unwrap_fn, kwargs)) + proxy._meta_data = getattr(unwrap_fn(args[0]), target)( + *tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs) + ) finally: self.disable_module_getattr = False - elif kind == 'call_module': + elif kind == "call_module": mod = self.root.get_submodule(target) self.disable_module_getattr = True try: @@ -158,11 +163,12 @@ def create_node(self, *args, **kwargs) -> Node: n_info = MetaInfo(node, mod_dir=self.mod_dir, activation_checkpoint=tuple(self.ckpt_regions)) return node - def trace(self, - root: torch.nn.Module, - concrete_args: Optional[Dict[str, torch.Tensor]] = None, - meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph: - + def trace( + self, + root: torch.nn.Module, + concrete_args: Optional[Dict[str, torch.Tensor]] = None, + meta_args: Optional[Dict[str, torch.Tensor]] = None, + ) -> Graph: if meta_args is None: meta_args = {} @@ -177,9 +183,7 @@ def trace(self, non_concrete_arg_names = sig_names - concrete_arg_names # update concrete args with default values for k, v in sig.parameters.items(): - if k in sig_names - meta_arg_names and \ - k not in 
concrete_args and \ - v.default is not inspect.Parameter.empty: + if k in sig_names - meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty: concrete_args[k] = v.default def _check_arg_name_valid(names: Iterable[str]): @@ -194,9 +198,9 @@ def _check_arg_name_valid(names: Iterable[str]): self.meta_args = meta_args with self._torch_factory_override(), self._tracer_override(), torch.no_grad(): - self.mod_dir = 'self' + self.mod_dir = "self" self.graph = super().trace(root, concrete_args=concrete_args) - self.mod_dir = '' + self.mod_dir = "" self.graph.lint() for node in self.graph.nodes: @@ -266,17 +270,17 @@ def _torch_factory_override(self): # override the torch factory functions to create a proxy when the method # is called during ``symbolic_trace()``. def wrap_factory_method(target): - @functools.wraps(target) def wrapper(*args, **kwargs): is_proxy = any(isinstance(p, ColoProxy) for p in args) | any( - isinstance(p, ColoProxy) for p in kwargs.values()) + isinstance(p, ColoProxy) for p in kwargs.values() + ) if is_proxy: # if the arg is a proxy, then need to record this function called on this proxy # e.g. torch.ones(size) where size is an input proxy self.disable_module_getattr = True try: - proxy = self.create_proxy('call_function', target, args, kwargs) + proxy = self.create_proxy("call_function", target, args, kwargs) finally: self.disable_module_getattr = False return proxy @@ -341,10 +345,13 @@ def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cac if attr_val is p: if n not in parameter_proxy_cache: kwargs = {} - if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters: - kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else - lambda node: ColoProxy(self, node, n, attr_val)) - val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs) # type: ignore[arg-type] + if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters: + kwargs["proxy_factory_fn"] = ( + None + if not self.param_shapes_constant + else lambda node: ColoProxy(self, node, n, attr_val) + ) + val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type] parameter_proxy_cache[n] = val_proxy return parameter_proxy_cache[n] return None @@ -355,8 +362,9 @@ def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cac return maybe_buffer_proxy if isinstance(attr_val, torch.nn.Parameter): - maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(), - parameter_proxy_cache) + maybe_parameter_proxy = maybe_get_proxy_for_attr( + attr_val, self.root.named_parameters(), parameter_proxy_cache + ) if maybe_parameter_proxy is not None: return maybe_parameter_proxy diff --git a/colossalai/amp/naive_amp/grad_scaler/__init__.py b/colossalai/amp/naive_amp/grad_scaler/__init__.py index dc8499d877e1..34a20e8d67d6 100644 --- a/colossalai/amp/naive_amp/grad_scaler/__init__.py +++ b/colossalai/amp/naive_amp/grad_scaler/__init__.py @@ -2,4 +2,4 @@ from .constant_grad_scaler import ConstantGradScaler from .dynamic_grad_scaler import DynamicGradScaler -__all__ = ['BaseGradScaler', 'ConstantGradScaler', 'DynamicGradScaler'] +__all__ = ["BaseGradScaler", "ConstantGradScaler", "DynamicGradScaler"] diff --git a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py index 0d84384a7f67..79661a44424f 100644 --- a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py +++ 
b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py @@ -9,7 +9,7 @@ from colossalai.logging import get_dist_logger -__all__ = ['BaseGradScaler'] +__all__ = ["BaseGradScaler"] class BaseGradScaler(ABC): @@ -30,24 +30,21 @@ def __init__(self, initial_scale: float, verbose: bool): @property def scale(self) -> Tensor: - """Returns the loss scale. - """ + """Returns the loss scale.""" return self._scale @property def inv_scale(self) -> Tensor: - """Returns the inverse of the loss scale. - """ + """Returns the inverse of the loss scale.""" return self._scale.double().reciprocal().float() def state_dict(self) -> Dict: - """Returns the states of the gradient scaler as a dict object. - """ + """Returns the states of the gradient scaler as a dict object.""" state_dict = dict() - state_dict['scale'] = self.scale + state_dict["scale"] = self.scale return state_dict def load_state_dict(self, state_dict: Dict) -> None: @@ -57,7 +54,7 @@ def load_state_dict(self, state_dict: Dict) -> None: state_dict (dict): the states of the gradient scaler """ - self._scale = state_dict['scale'] + self._scale = state_dict["scale"] @abstractmethod def update(self, overflow: bool) -> None: @@ -67,8 +64,6 @@ def update(self, overflow: bool) -> None: overflow (bool): whether overflow occurs """ - pass - def log(self, message, *args, **kwargs): """Log messages. diff --git a/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py index a2f518c5dd28..2ad8b51ac22c 100644 --- a/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py +++ b/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py @@ -2,7 +2,7 @@ # -*- encoding: utf-8 -*- from .base_grad_scaler import BaseGradScaler -__all__ = ['ConstantGradScaler'] +__all__ = ["ConstantGradScaler"] class ConstantGradScaler(BaseGradScaler): @@ -23,4 +23,3 @@ def update(self, overflow: bool) -> None: Args: overflow (bool): whether overflow occurs """ - pass diff --git a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py index e899b9ca4c89..65133a4b3712 100644 --- a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py +++ b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py @@ -7,7 +7,7 @@ from .base_grad_scaler import BaseGradScaler -__all__ = ['DynamicGradScaler'] +__all__ = ["DynamicGradScaler"] class DynamicGradScaler(BaseGradScaler): @@ -24,15 +24,17 @@ class DynamicGradScaler(BaseGradScaler): verbose (bool): whether to log messages, defaults to False """ - def __init__(self, - initial_scale: float = 2**16, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - min_scale: Optional[float] = None, - max_scale: Optional[float] = None, - hysteresis: int = 2, - verbose: bool = False): + def __init__( + self, + initial_scale: float = 2**16, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + min_scale: Optional[float] = None, + max_scale: Optional[float] = None, + hysteresis: int = 2, + verbose: bool = False, + ): super().__init__(initial_scale, verbose) if min_scale: self._min_scale = torch.cuda.FloatTensor([min_scale]) @@ -53,18 +55,17 @@ def __init__(self, self._sanity_checks() def _sanity_checks(self) -> None: - """Check if the arguments are correct. 
- """ + """Check if the arguments are correct.""" if self._min_scale: - assert self._min_scale > 0, 'The minimum gradient scale cannot be zero or negative' - assert self._min_scale <= self._scale, 'The minimum gradient scale cannot be greater than the current scale' + assert self._min_scale > 0, "The minimum gradient scale cannot be zero or negative" + assert self._min_scale <= self._scale, "The minimum gradient scale cannot be greater than the current scale" if self._max_scale: - assert self._max_scale > 0, 'The maximum gradient scale cannot be zero or negative' - assert self._max_scale >= self._scale, 'The maximum gradient scale cannot be smaller than the current scale' - assert self._growth_factor > 1, 'The growth factor cannot be equal or smaller than 1' - assert 0 < self._backoff_factor < 1, 'The backoff factor must be between 0 and 1' - assert self._hysteresis >= 0, 'The hysteresis cannot be negative' + assert self._max_scale > 0, "The maximum gradient scale cannot be zero or negative" + assert self._max_scale >= self._scale, "The maximum gradient scale cannot be smaller than the current scale" + assert self._growth_factor > 1, "The growth factor cannot be equal or smaller than 1" + assert 0 < self._backoff_factor < 1, "The backoff factor must be between 0 and 1" + assert self._hysteresis >= 0, "The hysteresis cannot be negative" def update(self, overflow: bool) -> None: """Update the loss scale. @@ -88,19 +89,18 @@ def update(self, overflow: bool) -> None: self.log( f"No overflow for consecutive {self._growth_interval} steps, " f"the loss scale is adjusted to {self.scale.item()}", - ranks=[0]) + ranks=[0], + ) def _backoff_scale(self) -> None: - """Decrease the loss scale - """ + """Decrease the loss scale""" self._scale = self._scale * self._backoff_factor if self._min_scale: self._scale = torch.max(self._scale, self._min_scale) def _grow_scale(self) -> None: - """Increase the loss scale - """ + """Increase the loss scale""" self._scale = self._scale * self._growth_factor if self._max_scale: @@ -108,14 +108,14 @@ def _grow_scale(self) -> None: def state_dict(self): state_dict = dict() - state_dict['scale'] = self._scale - state_dict['growth_factor'] = self._growth_factor - state_dict['backoff_factor'] = self._backoff_factor - state_dict['hysteresis'] = self._hysteresis + state_dict["scale"] = self._scale + state_dict["growth_factor"] = self._growth_factor + state_dict["backoff_factor"] = self._backoff_factor + state_dict["hysteresis"] = self._hysteresis return state_dict def load_state_dict(self, state_dict): - self._scale = state_dict['scale'].cuda(torch.cuda.current_device()) - self._growth_factor = state_dict['growth_factor'] - self._backoff_factor = state_dict['backoff_factor'] - self._hysteresis = state_dict['hysteresis'] + self._scale = state_dict["scale"].cuda(torch.cuda.current_device()) + self._growth_factor = state_dict["growth_factor"] + self._backoff_factor = state_dict["backoff_factor"] + self._hysteresis = state_dict["hysteresis"] diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/__init__.py b/colossalai/amp/naive_amp/mixed_precision_mixin/__init__.py index b0348e1477bb..a31811e4a567 100644 --- a/colossalai/amp/naive_amp/mixed_precision_mixin/__init__.py +++ b/colossalai/amp/naive_amp/mixed_precision_mixin/__init__.py @@ -3,7 +3,7 @@ from .fp16 import FP16MixedPrecisionMixin __all__ = [ - 'MixedPrecisionMixin', - 'FP16MixedPrecisionMixin', - 'BF16MixedPrecisionMixin', + "MixedPrecisionMixin", + "FP16MixedPrecisionMixin", + "BF16MixedPrecisionMixin", ] diff 
--git a/colossalai/amp/naive_amp/mixed_precision_mixin/base.py b/colossalai/amp/naive_amp/mixed_precision_mixin/base.py index a52a9747ad1e..fc7e0b74179a 100644 --- a/colossalai/amp/naive_amp/mixed_precision_mixin/base.py +++ b/colossalai/amp/naive_amp/mixed_precision_mixin/base.py @@ -39,6 +39,7 @@ def zero_grad(self): return self.optim.zero_grad() ``` """ + dtype: torch.dtype @abstractmethod @@ -51,7 +52,6 @@ def pre_backward(self, loss: Tensor) -> Tensor: Returns: Tensor: Loss value (possibly scaled). """ - pass @abstractmethod def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor: @@ -64,7 +64,6 @@ def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor: Returns: Tensor: Gradient of the tensor (possibly scaled). """ - pass @abstractmethod def should_skip_step(self) -> bool: @@ -73,13 +72,10 @@ def should_skip_step(self) -> bool: Returns: bool: Whether to skip the step. """ - pass @abstractmethod def pre_zero_grad(self) -> None: - """Called before zero_grad. - """ - pass + """Called before zero_grad.""" @abstractmethod def get_grad_div_scale(self) -> float: @@ -88,4 +84,3 @@ def get_grad_div_scale(self) -> float: Returns: float: A divisor for gradient clipping or step. """ - pass diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py b/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py index 1ce8e42eb3ed..9ce272356797 100644 --- a/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py +++ b/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py @@ -19,22 +19,26 @@ class OptimState(Enum): class FP16MixedPrecisionMixin(MixedPrecisionMixin): dtype = torch.float16 - def __init__(self, - initial_scale: float = 2**16, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: float = 2**32) -> None: + def __init__( + self, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + ) -> None: super().__init__() - self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale, - min_scale=min_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - max_scale=max_scale) + self.grad_scaler = DynamicGradScaler( + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + ) self.optim_state = OptimState.UNSCALED self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_current_device()) @@ -49,7 +53,6 @@ def check_local_overflow(self) -> bool: Returns: bool: Whether there is overflow in the local process. 
""" - pass def check_overflow(self) -> bool: # clear previous overflow record @@ -79,6 +82,6 @@ def pre_zero_grad(self) -> None: pass def get_grad_div_scale(self) -> float: - assert self.optim_state == OptimState.SCALED, 'grads should be scaled before clipping' + assert self.optim_state == OptimState.SCALED, "grads should be scaled before clipping" self.optim_state = OptimState.UNSCALED return self.loss_scale diff --git a/colossalai/amp/naive_amp/mixed_precision_optimizer.py b/colossalai/amp/naive_amp/mixed_precision_optimizer.py index 626a00c96d04..6a192cc5cb83 100644 --- a/colossalai/amp/naive_amp/mixed_precision_optimizer.py +++ b/colossalai/amp/naive_amp/mixed_precision_optimizer.py @@ -11,18 +11,20 @@ class NaiveFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): - - def __init__(self, - working_params: List[Parameter], - initial_scale: float = 2**16, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: float = 2**32) -> None: - super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, - max_scale) + def __init__( + self, + working_params: List[Parameter], + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + ) -> None: + super().__init__( + initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale + ) self.params = working_params def check_local_overflow(self) -> bool: @@ -33,38 +35,41 @@ def check_local_overflow(self) -> bool: class MixedPrecisionOptimizer(OptimizerWrapper): - - def __init__(self, - optim: Optimizer, - precision: str = 'fp16', - initial_scale: float = 2**16, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: float = 2**32, - max_norm: float = 0.0): + def __init__( + self, + optim: Optimizer, + precision: str = "fp16", + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0.0, + ): super().__init__(optim) - if precision == 'fp16': + if precision == "fp16": working_params = [] for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: working_params.append(p) - self.mixed_precision = NaiveFP16MixedPrecisionMixin(working_params, - initial_scale=initial_scale, - min_scale=min_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - max_scale=max_scale) - elif precision == 'bf16': + self.mixed_precision = NaiveFP16MixedPrecisionMixin( + working_params, + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + ) + elif precision == "bf16": self.mixed_precision = BF16MixedPrecisionMixin() else: - raise ValueError(f'Unsupported precision: {precision}') + raise ValueError(f"Unsupported precision: {precision}") if max_norm > 0.0: - raise NotImplementedError('max_norm is not supported yet.') + raise NotImplementedError("max_norm is not supported yet.") self.max_norm = max_norm self.working_to_master_map: Dict[Parameter, Tensor] = {} 
self.master_to_working_map: Dict[Tensor, Parameter] = {} @@ -72,7 +77,7 @@ def __init__(self, # create master weights for group in self.optim.param_groups: master_params = [] - for p in group['params']: + for p in group["params"]: if p.requires_grad: master_p = p if p.dtype != torch.float: @@ -80,7 +85,7 @@ def __init__(self, self.working_to_master_map[p] = master_p self.master_to_working_map[master_p] = p master_params.append(master_p) - group['params'] = master_params + group["params"] = master_params def backward(self, loss: Tensor, *args, **kwargs): loss = self.mixed_precision.pre_backward(loss) @@ -101,24 +106,24 @@ def _unscale_and_clip_grads(self, total_norm: float) -> None: if self.mixed_precision is not None: div_scale = self.mixed_precision.get_grad_div_scale() - if self.max_norm > 0.: + if self.max_norm > 0.0: # norm is in fact norm*scale clip = ((total_norm / div_scale) + 1e-6) / self.max_norm if clip > 1: div_scale = clip * div_scale for group in self.param_groups: - for p in group['params']: + for p in group["params"]: if p.grad is None: continue - p.grad.data.mul_(1. / div_scale) + p.grad.data.mul_(1.0 / div_scale) def _compute_grad_norm(self) -> float: - if self.max_norm <= 0.: - return 0. - grads = [p.grad for group in self.param_groups for p in group['params'] if p.grad is not None] + if self.max_norm <= 0.0: + return 0.0 + grads = [p.grad for group in self.param_groups for p in group["params"] if p.grad is not None] if len(grads) == 0: - return 0. + return 0.0 device = grads[0].device # TODO(ver217): support tp total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2) @@ -130,7 +135,7 @@ def step(self, *args, **kwargs): return # prepare grads for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: working_param = self.master_to_working_map[p] if p is working_param: continue @@ -142,7 +147,7 @@ def step(self, *args, **kwargs): self.optim.step(*args, **kwargs) # update working params for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: working_param = self.master_to_working_map[p] if p is working_param: continue diff --git a/colossalai/auto_parallel/checkpoint/build_c_ext.py b/colossalai/auto_parallel/checkpoint/build_c_ext.py index af4349865a7b..7de56f80525a 100644 --- a/colossalai/auto_parallel/checkpoint/build_c_ext.py +++ b/colossalai/auto_parallel/checkpoint/build_c_ext.py @@ -3,14 +3,16 @@ from setuptools import Extension, setup this_dir = os.path.dirname(os.path.abspath(__file__)) -ext_modules = [Extension( - 'rotorc', - sources=[os.path.join(this_dir, 'ckpt_solver_rotor.c')], -)] +ext_modules = [ + Extension( + "rotorc", + sources=[os.path.join(this_dir, "ckpt_solver_rotor.c")], + ) +] setup( - name='rotor c extension', - version='0.1', - description='rotor c extension for faster dp computing', + name="rotor c extension", + version="0.1", + description="rotor c extension for faster dp computing", ext_modules=ext_modules, ) diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py index b388d00ac553..8aaa690b333c 100644 --- a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py @@ -12,13 +12,13 @@ ) from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen -__all___ = ['CheckpointSolverBase'] +__all___ = ["CheckpointSolverBase"] def _copy_output(src: Graph, dst: Graph): """Copy the 
output node from src to dst""" for n_src, n_dst in zip(src.nodes, dst.nodes): - if n_src.op == 'output': + if n_src.op == "output": n_dst.meta = n_src.meta @@ -28,7 +28,6 @@ def _get_param_size(module: torch.nn.Module): class CheckpointSolverBase(ABC): - def __init__( self, graph: Graph, @@ -81,13 +80,10 @@ def __init__( @abstractmethod def solve(self): - """Solve the checkpointing problem and return the solution. - """ - pass + """Solve the checkpointing problem and return the solution.""" def get_node_list(self): - """Get the node list. - """ + """Get the node list.""" return [[node] for node in self.graph.nodes] def _linearize_graph(self) -> List[List[Node]]: @@ -140,8 +136,7 @@ def _is_sink() -> bool: """ def _is_inplace(n: Node): - """Get the inplace argument from ``torch.fx.Node`` - """ + """Get the inplace argument from ``torch.fx.Node``""" inplace = False if n.op == "call_function": inplace = n.kwargs.get("inplace", False) @@ -150,19 +145,22 @@ def _is_inplace(n: Node): return inplace def _is_shape_consistency(n: Node): - """Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``) - """ + """Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``)""" return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply] - return not sum([v for _, v in deps.items()]) and not any(map(_is_inplace, n.users)) and not any( - map(_is_shape_consistency, n.users)) + return ( + not sum([v for _, v in deps.items()]) + and not any(map(_is_inplace, n.users)) + and not any(map(_is_shape_consistency, n.users)) + ) # make sure that item in cnode is valid if self.cnode: for name in self.cnode: try: - assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \ - f"Common node {name} is not an input of the model." + assert ( + next(node for node in self.graph.nodes if node.name == name).op == "placeholder" + ), f"Common node {name} is not an input of the model." except StopIteration: raise ValueError(f"Common node name {name} not in graph.") @@ -187,8 +185,9 @@ def _is_shape_consistency(n: Node): region = [] # propagate common node attr if possible - if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode - ]) or _is_cop(n.target): + if len(n.all_input_nodes) == len( + [node for node in n.all_input_nodes if node.name in self.cnode] + ) or _is_cop(n.target): self.cnode.append(n.name) else: deps[n] = len([user for user in n.users if user.op != "output"]) diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py index 19b2ef5987c9..ab16cc04b730 100644 --- a/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py @@ -8,11 +8,10 @@ from .ckpt_solver_base import CheckpointSolverBase -__all__ = ['CheckpointSolverChen'] +__all__ = ["CheckpointSolverChen"] class CheckpointSolverChen(CheckpointSolverBase): - def __init__(self, graph: Graph, cnode: List[str] = None, num_grids: int = 6): """ This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174. @@ -40,14 +39,14 @@ def solve(self) -> Graph: Returns: graph (Graph): The optimized graph, should be a copy of the original graph. 
""" - checkpointable_op = ['call_module', 'call_method', 'call_function', 'get_attr'] + checkpointable_op = ["call_module", "call_method", "call_function", "get_attr"] ckpt = self.grid_search() for i, seg in enumerate(ckpt): for idx in range(*seg): nodes = self.node_list[idx] for n in nodes: if n.op in checkpointable_op: - n.meta['activation_checkpoint'] = i + n.meta["activation_checkpoint"] = i return deepcopy(self.graph) def run_chen_greedy(self, b: int = 0) -> Tuple[Set, int]: diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py index 21c3bf0da758..d10c41ae2b96 100644 --- a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Any, Dict, List, Tuple +from typing import Any, List, Tuple from torch import Tensor from torch.fx import Graph, Node @@ -18,17 +18,18 @@ from .ckpt_solver_base import CheckpointSolverBase from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Loss, Sequence -__all__ = ['CheckpointSolverRotor'] +__all__ = ["CheckpointSolverRotor"] class CheckpointSolverRotor(CheckpointSolverBase): - - def __init__(self, - graph: Graph, - free_memory: float = -1, - cnode: List[str] = None, - memory_slots: int = 500, - optim_multiplier: float = 1.0): + def __init__( + self, + graph: Graph, + free_memory: float = -1, + cnode: List[str] = None, + memory_slots: int = 500, + optim_multiplier: float = 1.0, + ): """This is the simple implementation of dynamic programming algorithm rotor in https://hal.inria.fr/hal-02352969. Some code are adapted from https://gitlab.inria.fr/hiepacs/rotor. @@ -85,13 +86,14 @@ def solve(self, force_python: bool = False, verbose: bool = False) -> Graph: # backtrack try: - self.sequence = self._backtrack(chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table, - self.back_ptr) + self.sequence = self._backtrack( + chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table, self.back_ptr + ) self._annotate_from_sequence(self.sequence, self.node_list) except ValueError as e: # using logger to annonce that the solver is failed logger = get_dist_logger() - logger.warning(f'Checkpoint solver failed: {e}') + logger.warning(f"Checkpoint solver failed: {e}") raise ValueError if verbose: @@ -100,14 +102,19 @@ def solve(self, force_python: bool = False, verbose: bool = False) -> Graph: return deepcopy(self.graph) def print_chain(self): - print('[input]', self.chain.x[0], self.chain.xbar[0], self.chain.ftmp[0], self.chain.btmp[0]) + print("[input]", self.chain.x[0], self.chain.xbar[0], self.chain.ftmp[0], self.chain.btmp[0]) for idx in range(len(self.node_list) - 1): - print(self.node_list[idx], self.chain.x[idx + 1], self.chain.xbar[idx + 1], self.chain.ftmp[idx], - self.chain.btmp[idx]) - print(f'Chain = {self.chain}') + print( + self.node_list[idx], + self.chain.x[idx + 1], + self.chain.xbar[idx + 1], + self.chain.ftmp[idx], + self.chain.btmp[idx], + ) + print(f"Chain = {self.chain}") def print_sequence(self): - print(f'Sequence = {self.sequence}') + print(f"Sequence = {self.sequence}") @classmethod def _construct_chain(cls, graph: Graph, node_list: List[List[Node]]) -> Chain: @@ -138,14 +145,14 @@ def _extract_node_info(cls, node: List[Node]) -> Tuple[int, ...]: btime = 0 fwd_mem_peak = 0 for n in node: - assert isinstance(n, Node), f'{n} is not a Node' + assert isinstance(n, Node), f"{n} is not a Node" if 
n.target == runtime_apply or n.target == runtime_comm_spec_apply: # in this case we need to calculate memory usage directly based on the statics that hooked in node.meta - xbar += n.meta['fwd_mem_out'] - fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp']) + xbar += n.meta["fwd_mem_out"] + fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta["fwd_mem_tmp"]) else: xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n) - fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n)) + fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta["fwd_mem_tmp"] + cls._extract_unused_output(n)) # minimum flop count is required ftime += max(calculate_fwd_time(n), 1.0) @@ -162,14 +169,14 @@ def _extract_input(graph: Graph) -> Tuple[Tensor, ...]: """Extract input tensors from a Graph""" input_tensors = [] for node in graph.nodes: - if node.op == 'placeholder': - input_tensors.append(node.meta['fwd_out']) + if node.op == "placeholder": + input_tensors.append(node.meta["fwd_out"]) return input_tensors @staticmethod def _extract_unused_output(node: Node) -> int: """Extract unused output from `torch.fx.Node`""" - return activation_size(node.meta['fwd_out']) - calculate_fwd_out(node) + return activation_size(node.meta["fwd_out"]) - calculate_fwd_out(node) @staticmethod def _extract_btmp(node: List[Node]) -> int: @@ -180,8 +187,8 @@ def _extract_deps_size(): for k, v in deps.items(): k: Node if v > 0: - deps_size += k.meta['bwd_mem_out'] - if v == float('-inf'): + deps_size += k.meta["bwd_mem_out"] + if v == float("-inf"): deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k) return deps_size @@ -190,12 +197,12 @@ def _extract_deps_size(): deps = {} for n in reversed(node): deps[n] = len(n.all_input_nodes) - btmp = max(btmp, _extract_deps_size() + n.meta['bwd_mem_tmp']) + btmp = max(btmp, _extract_deps_size() + n.meta["bwd_mem_tmp"]) for child in n.users: if child in deps: deps[child] -= 1 if deps[child] <= 0: - deps[child] = float('-inf') # free + deps[child] = float("-inf") # free return btmp @staticmethod @@ -244,10 +251,11 @@ def _compute_table(chain: Chain, mmax: int) -> Tuple: if m < mmin: cost_table[m][i][idx] = float("inf") else: - leaf_checkpoints = [(j, - sum(ftime[i:j]) + cost_table[m - x[j]][j][idx] + cost_table[m][i][j - 1]) - for j in range(i + 1, idx + 1) - if m >= x[j]] + leaf_checkpoints = [ + (j, sum(ftime[i:j]) + cost_table[m - x[j]][j][idx] + cost_table[m][i][j - 1]) + for j in range(i + 1, idx + 1) + if m >= x[j] + ] if leaf_checkpoints: best_leaf = min(leaf_checkpoints, key=lambda t: t[1]) else: @@ -274,13 +282,16 @@ def _compute_table_c(chain: Chain, mmax: int) -> Tuple: import os import subprocess import sys + logger = get_dist_logger() logger.info("rotorc hasn't been built! 
Building library...", ranks=[0]) this_dir = os.path.dirname(os.path.abspath(__file__)) result = subprocess.Popen( [ - f"{sys.executable}", f"{os.path.join(this_dir, 'build_c_ext.py')}", "build_ext", - f"--build-lib={this_dir}" + f"{sys.executable}", + f"{os.path.join(this_dir, 'build_c_ext.py')}", + "build_ext", + f"--build-lib={this_dir}", ], stdout=subprocess.PIPE, stderr=subprocess.PIPE, @@ -294,8 +305,9 @@ def _compute_table_c(chain: Chain, mmax: int) -> Tuple: return compute_table(chain, mmax) @staticmethod - def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any], - back_ptr: List[Any]) -> "Sequence": + def _backtrack( + chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any], back_ptr: List[Any] + ) -> "Sequence": """Backtrack the cost table and retrieve the optimal checkpointing strategy. Args: @@ -328,8 +340,9 @@ def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[A if back_ptr[budget][lhs][rhs][0]: sequence += [ ForwardEnable(lhs), - CheckpointSolverRotor._backtrack(chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table, - back_ptr), + CheckpointSolverRotor._backtrack( + chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table, back_ptr + ), Backward(lhs), ] else: @@ -337,8 +350,9 @@ def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[A sequence += [ForwardCheck(lhs)] sequence += [ForwardNograd(k) for k in range(lhs + 1, best_leaf)] sequence += [ - CheckpointSolverRotor._backtrack(chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table, - back_ptr), + CheckpointSolverRotor._backtrack( + chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table, back_ptr + ), CheckpointSolverRotor._backtrack(chain, lhs, best_leaf - 1, budget, cost_table, back_ptr), ] return sequence @@ -353,8 +367,8 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]): """ op_list = sequence.list_operations() loss_op = next(op for op in op_list if isinstance(op, Loss)) - fwd_list = op_list[:op_list.index(loss_op)] - bwd_list = op_list[op_list.index(loss_op) + 1:] + fwd_list = op_list[: op_list.index(loss_op)] + bwd_list = op_list[op_list.index(loss_op) + 1 :] ckpt_idx = 0 in_ckpt = False ckpt_region = [] @@ -369,7 +383,7 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]): in_ckpt = False for node_idx in ckpt_region: for n in node_list[node_idx]: - n.meta['activation_checkpoint'] = [ckpt_idx] + n.meta["activation_checkpoint"] = [ckpt_idx] ckpt_idx += 1 ckpt_region = [] @@ -377,7 +391,7 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]): elif isinstance(op, ForwardCheck): for node_idx in ckpt_region: for n in node_list[node_idx]: - n.meta['activation_checkpoint'] = [ckpt_idx] + n.meta["activation_checkpoint"] = [ckpt_idx] ckpt_idx += 1 ckpt_region = [idx] @@ -397,7 +411,7 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]): elif isinstance(op, ForwardEnable): for node_idx in ckpt_region: for n in node_list[node_idx]: - n.meta['activation_checkpoint'].append(ckpt_idx) + n.meta["activation_checkpoint"].append(ckpt_idx) ckpt_idx += 1 ckpt_region = [] @@ -405,7 +419,7 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]): elif isinstance(op, ForwardCheck): for node_idx in ckpt_region: for n in node_list[node_idx]: - n.meta['activation_checkpoint'].append(ckpt_idx) + n.meta["activation_checkpoint"].append(ckpt_idx) ckpt_idx += 1 ckpt_region = [op.index] @@ -413,7 
+427,7 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]): elif isinstance(op, Backward): for node_idx in ckpt_region: for n in node_list[node_idx]: - n.meta['activation_checkpoint'].append(ckpt_idx) + n.meta["activation_checkpoint"].append(ckpt_idx) in_recompute = False @@ -431,9 +445,11 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]): for node in node_list: op_list += node ckpt_regions = _find_nested_ckpt_regions(op_list) - for (start_idx, end_idx) in ckpt_regions: + for start_idx, end_idx in ckpt_regions: nested_length = max( - len(op_list[idx].meta['activation_checkpoint']) for idx in range(start_idx, end_idx + 1)) + len(op_list[idx].meta["activation_checkpoint"]) for idx in range(start_idx, end_idx + 1) + ) for idx in range(start_idx, end_idx + 1): - op_list[idx].meta['activation_checkpoint'] += [None] * (nested_length - - len(op_list[idx].meta['activation_checkpoint'])) + op_list[idx].meta["activation_checkpoint"] += [None] * ( + nested_length - len(op_list[idx].meta["activation_checkpoint"]) + ) diff --git a/colossalai/auto_parallel/checkpoint/operation.py b/colossalai/auto_parallel/checkpoint/operation.py index ab0c6c5ad38d..5f8077916433 100644 --- a/colossalai/auto_parallel/checkpoint/operation.py +++ b/colossalai/auto_parallel/checkpoint/operation.py @@ -1,20 +1,21 @@ import math from abc import ABC -from typing import Any, Iterable, List +from typing import List from torch.utils._pytree import tree_map class Chain: - - def __init__(self, - ftime: List[float], - btime: List[float], - x: List[int], - xbar: List[int], - ftmp: List[int], - btmp: List[int], - check_consistency: bool = True): + def __init__( + self, + ftime: List[float], + btime: List[float], + x: List[int], + xbar: List[int], + ftmp: List[int], + btmp: List[int], + check_consistency: bool = True, + ): """The chain is a basic linearized structure for solving the dynamic programming problem for activation checkpoint. See paper https://hal.inria.fr/hal-02352969 for details. 
@@ -37,9 +38,14 @@ def __init__(self, raise AttributeError("In Chain, input lists do not have consistent lengths") def check_lengths(self): - return ((len(self.ftime) == len(self)) and (len(self.btime) == len(self) + 1) and (len(self.x) == len(self) + 1) - and (len(self.ftmp) == len(self)) and (len(self.btmp) == len(self) + 1) - and (len(self.xbar) == len(self) + 1)) + return ( + (len(self.ftime) == len(self)) + and (len(self.btime) == len(self) + 1) + and (len(self.x) == len(self) + 1) + and (len(self.ftmp) == len(self)) + and (len(self.btmp) == len(self) + 1) + and (len(self.xbar) == len(self) + 1) + ) def __repr__(self): chain_list = [] @@ -100,7 +106,6 @@ class ForwardCheck(Forward): class Forwards(Operation): - def __init__(self, start, end): self.index = (start, end) @@ -109,9 +114,9 @@ def __repr__(self): def cost(self, chain: Chain): if chain is not None: - return sum(chain.ftime[self.index[0]:self.index[1] + 1]) + return sum(chain.ftime[self.index[0] : self.index[1] + 1]) else: - return (self.index[1] - self.index[0] + 1) + return self.index[1] - self.index[0] + 1 def isForward(op): @@ -132,7 +137,6 @@ def cost(self, chain: Chain): class Loss(Operation): - def __init__(self): pass @@ -166,7 +170,6 @@ class DiscardMemory(MemoryAccess): class Sequence(list): - def __init__(self): super().__init__() diff --git a/colossalai/auto_parallel/meta_profiler/constants.py b/colossalai/auto_parallel/meta_profiler/constants.py index 35b8c13ee8ff..2f638fa919e4 100644 --- a/colossalai/auto_parallel/meta_profiler/constants.py +++ b/colossalai/auto_parallel/meta_profiler/constants.py @@ -3,8 +3,6 @@ import torch import torch.nn as nn -from ..tensor_shard.constants import * - # list of inplace module INPLACE_MODULE = [nn.ReLU] diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py index 0f2e9e44f91c..4234481ae2ca 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py @@ -25,28 +25,32 @@ def elementwise_meta_info(temp_mem_scale: float = 0, buffer_mem_scale: float = 0 def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: input_tensor = next( filter( - lambda x: - (x.type == OperationDataType.ARG or x.type == OperationDataType.PARAM) and x.name != 'softmax_dim', - args)).data + lambda x: (x.type == OperationDataType.ARG or x.type == OperationDataType.PARAM) + and x.name != "softmax_dim", + args, + ) + ).data output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data - is_inplace = 1 if kwargs.get('inplace', False) else 0 + is_inplace = 1 if kwargs.get("inplace", False) else 0 flop_counter = elementwise_flop_counter(1, 0) # calculate compute cost fwd_compute_cost = flop_counter([input_tensor], [output_tensor]) bwd_compute_cost = flop_counter([output_tensor], [input_tensor]) - compute_cost = TrainCycleItem(fwd=fwd_compute_cost, - bwd=bwd_compute_cost, - total=fwd_compute_cost + bwd_compute_cost) + compute_cost = TrainCycleItem( + fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost + ) # calculate memory cost # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward # NOTE: if in_place is True, we will not create a new tensor in forward - fwd_memory_cost = MemoryCost(activation=activation_size(input_tensor) * (2 - is_inplace), - parameter=0, - temp=0, - 
buffer=activation_size(input_tensor) * buffer_mem_scale) + fwd_memory_cost = MemoryCost( + activation=activation_size(input_tensor) * (2 - is_inplace), + parameter=0, + temp=0, + buffer=activation_size(input_tensor) * buffer_mem_scale, + ) # temp_mem_scale is for situation like softmax backward # the buffer will be removed during backward phase @@ -54,20 +58,23 @@ def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[tor activation=activation_size(input_tensor) - activation_size(input_tensor) * buffer_mem_scale, parameter=0, temp=activation_size(input_tensor) * temp_mem_scale + activation_size(input_tensor) * buffer_mem_scale, - buffer=0) + buffer=0, + ) # total cost is the sum of forward and backward cost - total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, - parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, - temp=fwd_memory_cost.temp + bwd_memory_cost.temp, - buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer) + total_cost = MemoryCost( + activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, + temp=fwd_memory_cost.temp + bwd_memory_cost.temp, + buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer, + ) memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) # store fwd_in, fwd_buffer, fwd_out fwd_in = [] - fwd_buffer = [torch.zeros_like(output_tensor, device='meta')] - fwd_out = [torch.zeros_like(output_tensor, device='meta')] + fwd_buffer = [torch.zeros_like(output_tensor, device="meta")] + fwd_out = [torch.zeros_like(output_tensor, device="meta")] return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py index e451748512b9..0b7b51a71955 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py @@ -6,10 +6,10 @@ from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem -from ..constants import BCAST_FUNC_OP, NO_SAVE_ACTIVATION +from ..constants import BCAST_FUNC_OP from ..registry import meta_register -__all__ = ['binary_elementwise_meta_info'] +__all__ = ["binary_elementwise_meta_info"] @meta_register.register(BCAST_FUNC_OP) @@ -61,6 +61,6 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train # store fwd_in, fwd_buffer, fwd_out fwd_in = [] fwd_buffer = [] - fwd_out = [torch.zeros_like(output_op_data.data, device='meta')] + fwd_out = [torch.zeros_like(output_op_data.data, device="meta")] return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py index 4336bf68363c..2f630995cdbc 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py @@ -1,22 +1,14 @@ -from typing import Callable, Dict, List, Tuple, Union +from typing import List, Tuple import torch from colossalai._analyzer._subclasses.flop_tensor import flop_mapping from colossalai._analyzer.fx.node_util import compute_size_in_bytes -from 
colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - MemoryCost, - OperationData, - OperationDataType, - ShardingStrategy, - StrategiesVector, - TrainCycleItem, -) -from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem from ..registry import meta_register -__all__ = ['convnd_meta_info'] +__all__ = ["convnd_meta_info"] @meta_register.register(torch.nn.Conv1d) @@ -103,35 +95,47 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L # calculate compute cost fwd_compute_cost = flop_mapping[torch.ops.aten.convolution.default](fwd_args, (output_tensor,)) - bwd_compute_cost = flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor, bias_tensor)) if has_bias else \ - flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor)) + bwd_compute_cost = ( + flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor, bias_tensor)) + if has_bias + else flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor)) + ) compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) # calculate memory cost # TODO: use profiler to check conv temp memory # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward - fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]), - parameter=compute_size_in_bytes([weight_tensor, bias_tensor]) - if has_bias else compute_size_in_bytes(weight_tensor), - temp=0, - buffer=0) - - bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]) - if has_bias else compute_size_in_bytes([input_tensor, weight_tensor]), - parameter=compute_size_in_bytes([weight_tensor, bias_tensor]) - if has_bias else compute_size_in_bytes(weight_tensor), - temp=0, - buffer=0) + fwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor, output_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]) + if has_bias + else compute_size_in_bytes(weight_tensor), + temp=0, + buffer=0, + ) + + bwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]) + if has_bias + else compute_size_in_bytes([input_tensor, weight_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]) + if has_bias + else compute_size_in_bytes(weight_tensor), + temp=0, + buffer=0, + ) # total cost is the sum of forward and backward cost - total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, - parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) + total_cost = MemoryCost( + activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, + ) memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) # store fwd_in, fwd_buffer, fwd_out - fwd_in = [torch.zeros_like(input_tensor, device='meta')] + fwd_in = [torch.zeros_like(input_tensor, device="meta")] fwd_buffer = [] - fwd_out = [torch.zeros_like(output_tensor, device='meta')] + fwd_out = [torch.zeros_like(output_tensor, device="meta")] return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out diff --git 
a/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py b/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py index d5d80f5b3700..7c9add810fd8 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py @@ -24,8 +24,9 @@ def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem # compute cost fwd_compute_cost = flop_mapping[torch.ops.aten.embedding.default]([weight_tensor, input_tensor], [output_tensor]) - bwd_compute_cost = flop_mapping[torch.ops.aten.embedding_dense_backward.default]([output_tensor, weight_tensor], - [weight_tensor]) + bwd_compute_cost = flop_mapping[torch.ops.aten.embedding_dense_backward.default]( + [output_tensor, weight_tensor], [weight_tensor] + ) compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) @@ -34,10 +35,9 @@ def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem # NOTE: during the backward phase of torch.nn.Embedding, it seems when the input is large enough, it will # have a temp memory which is kind of weird and we don't know the reason yet, so currently we just assume # that there will be no temp memory, as the temp memory is significantly smaller than the gradient memory - fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]), - parameter=0, - temp=0, - buffer=0) + fwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor, output_tensor]), parameter=0, temp=0, buffer=0 + ) bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([weight_tensor]), parameter=0, temp=0, buffer=0) total_memory_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation) diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py index 94dd9143e0ae..d731f9cb4436 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py @@ -1,23 +1,15 @@ from functools import reduce -from typing import Callable, Dict, List, Tuple, Union +from typing import List, Tuple import torch from colossalai._analyzer._subclasses.flop_tensor import flop_mapping from colossalai._analyzer.fx.node_util import compute_size_in_bytes -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - MemoryCost, - OperationData, - OperationDataType, - ShardingStrategy, - StrategiesVector, - TrainCycleItem, -) -from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem from ..registry import meta_register -__all__ = ['linear_meta_info', 'matmul_meta_info'] +__all__ = ["linear_meta_info", "matmul_meta_info"] @meta_register.register(torch.nn.functional.linear) @@ -100,32 +92,43 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L # calculate compute cost fwd_compute_cost = flop_mapping[torch.ops.aten.addmm.default]( - [bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,)) - bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \ - flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)) + \ - flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,)) 
- compute_cost = TrainCycleItem(fwd=fwd_compute_cost, - bwd=bwd_compute_cost, - total=fwd_compute_cost + bwd_compute_cost) + [bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,) + ) + bwd_compute_cost = ( + flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + + flop_mapping[torch.ops.aten.mm.default]( + [torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,) + ) + + flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,)) + ) + compute_cost = TrainCycleItem( + fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost + ) # calculate memory cost # NOTE: Linear don't have buffer and temp in forward and backward phase # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor and bias_tensor # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward - fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]), - parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), - temp=0, - buffer=0) + fwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor, output_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), + temp=0, + buffer=0, + ) # the backward activation cost is the size of input_tensor, weight_tensor and bias_tensor, parameter cost is 0 - bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]), - parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), - temp=0, - buffer=0) + bwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), + temp=0, + buffer=0, + ) # total cost is to sum the forward and backward cost - total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, - parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) + total_cost = MemoryCost( + activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, + ) memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) @@ -136,39 +139,49 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L # calculate compute cost fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]( - [input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,)) - bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \ - flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)) + [input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,) + ) + bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]( + [output_tensor, weight_tensor], (input_tensor,) + ) + flop_mapping[torch.ops.aten.mm.default]( + [torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,) + ) - compute_cost = TrainCycleItem(fwd=fwd_compute_cost, - bwd=bwd_compute_cost, - total=fwd_compute_cost + bwd_compute_cost) + compute_cost = TrainCycleItem( + fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost + ) # calculate memory cost # NOTE: Linear don't have buffer and temp in forward and backward phase # the forward activation cost is the size of output_tensor, parameter cost 
is the size of weight_tensor # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward - fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]), - parameter=compute_size_in_bytes(weight_tensor), - temp=0, - buffer=0) + fwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor, output_tensor]), + parameter=compute_size_in_bytes(weight_tensor), + temp=0, + buffer=0, + ) # the backward activation cost is the size of input_tensor and weight_tensor, parameter cost is 0 - bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor]), - parameter=compute_size_in_bytes(weight_tensor), - temp=0, - buffer=0) + bwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor, weight_tensor]), + parameter=compute_size_in_bytes(weight_tensor), + temp=0, + buffer=0, + ) # total cost is to sum the forward and backward cost - total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, - parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) + total_cost = MemoryCost( + activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, + ) memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) # store fwd_in, fwd_buffer, fwd_out - fwd_in = [torch.zeros_like(input_tensor, device='meta')] + fwd_in = [torch.zeros_like(input_tensor, device="meta")] fwd_buffer = [] - fwd_out = [torch.zeros_like(output_tensor, device='meta')] + fwd_out = [torch.zeros_like(output_tensor, device="meta")] return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out @@ -222,15 +235,16 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L # batched gemv case 1: batched matrix-vector multiplication fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default]( - [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors) + [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors + ) # combine the dimensions of output bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]( - [output_tensors[0].reshape(-1), input_tensors[1]], - output_tensors) + \ - flop_mapping[torch.ops.aten.matmul.default]( - [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)], - output_tensors) + [output_tensors[0].reshape(-1), input_tensors[1]], output_tensors + ) + flop_mapping[torch.ops.aten.matmul.default]( + [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)], + output_tensors, + ) fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0) bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0) @@ -239,86 +253,104 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L # gemv case 2: vector-matrix multiplication fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](input_tensors, output_tensors) - bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]([output_tensors[0], input_tensors[0]], output_tensors) + \ - flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors) + bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]( + [output_tensors[0], input_tensors[0]], output_tensors + ) 
+ flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors) fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0) - bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), - parameter=0, - temp=compute_size_in_bytes(input_tensors[1]), - buffer=0) + bwd_mem_cost = MemoryCost( + activation=compute_size_in_bytes(input_tensors), + parameter=0, + temp=compute_size_in_bytes(input_tensors[1]), + buffer=0, + ) elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) >= 3: # batched gemv case 2: vector-batched matrix multiplication fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default]( [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0]], - [output_tensors[0].reshape(-1)]) + [output_tensors[0].reshape(-1)], + ) # combine the dimensions of output bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]( - [output_tensors[0].reshape(-1), input_tensors[0]], - output_tensors - ) + \ - flop_mapping[torch.ops.aten.matmul.default]( - [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1), output_tensors[0].reshape(-1)], - output_tensors - ) + [output_tensors[0].reshape(-1), input_tensors[0]], output_tensors + ) + flop_mapping[torch.ops.aten.matmul.default]( + [ + input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1), + output_tensors[0].reshape(-1), + ], + output_tensors, + ) fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors + [input_tensors[1]])) - bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]), - parameter=0, - temp=compute_size_in_bytes(input_tensors[1]), - buffer=0) + bwd_mem_cost = MemoryCost( + activation=compute_size_in_bytes(input_tensors[0]), + parameter=0, + temp=compute_size_in_bytes(input_tensors[1]), + buffer=0, + ) elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 2: # gemm & batched gemm case 1: batched matrix-matrix multiplication fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]( [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], - [output_tensors[0].reshape(-1, output_tensors[0].shape[-1])]) + [output_tensors[0].reshape(-1, output_tensors[0].shape[-1])], + ) bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]( - [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1, output_tensors[0].shape[-1])], - [input_tensors[1]] - ) + \ - flop_mapping[torch.ops.aten.mm.default]( - [output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), input_tensors[1].transpose(0, 1)], - [input_tensors[0].reshape(-1, input_tensors[0].shape[-1])] - ) + [ + input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), + output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), + ], + [input_tensors[1]], + ) + flop_mapping[torch.ops.aten.mm.default]( + [output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), input_tensors[1].transpose(0, 1)], + [input_tensors[0].reshape(-1, input_tensors[0].shape[-1])], + ) fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0) bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0) elif len(input_tensors[0].shape) == 2 and len(input_tensors[1].shape) >= 3: # batched gemm case 2: matrix-batched matrix multiplication - fwd_compute_cost = 
flop_mapping[torch.ops.aten.mm.default]([ - input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0].transpose( - 0, 1) - ], [output_tensors[0].transpose(-2, -1)]) + fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]( + [ + input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), + input_tensors[0].transpose(0, 1), + ], + [output_tensors[0].transpose(-2, -1)], + ) bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]( - [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]).transpose(0, 1), input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])], - [input_tensors[0]] - ) + \ - flop_mapping[torch.ops.aten.mm.default]( - [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]), input_tensors[0]], - [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])] - ) - - fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors) + - compute_size_in_bytes(input_tensors[1]), - temp=compute_size_in_bytes(output_tensors)) - bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]), - parameter=0, - temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors)) + [ + output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]).transpose(0, 1), + input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), + ], + [input_tensors[0]], + ) + flop_mapping[torch.ops.aten.mm.default]( + [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]), input_tensors[0]], + [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])], + ) + + fwd_mem_cost = MemoryCost( + activation=compute_size_in_bytes(output_tensors) + compute_size_in_bytes(input_tensors[1]), + temp=compute_size_in_bytes(output_tensors), + ) + bwd_mem_cost = MemoryCost( + activation=compute_size_in_bytes(input_tensors[0]), + parameter=0, + temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors), + ) elif all(len(tensor.shape) >= 3 for tensor in input_tensors): # Batched matrix-batched matrix multiplication # Fetch shape of the two inputs and see if the batch dimensions are the same _is_batch_dims_same = True if len(input_tensors[0].shape) == len(input_tensors[1].shape): - for (shape_0, shape_1) in zip(input_tensors[0].shape[:-2], input_tensors[1].shape[:-2]): + for shape_0, shape_1 in zip(input_tensors[0].shape[:-2], input_tensors[1].shape[:-2]): if shape_0 != shape_1: _is_batch_dims_same = False break @@ -337,20 +369,28 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L # Case 1: batch dimensions are the same # Forward compute cost: C = A * B - fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]([ - input_tensors[0].reshape(-1, input_dim_00, input_dim_01), input_tensors[1].reshape( - -1, input_dim_10, input_dim_11) - ], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)]) + fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]( + [ + input_tensors[0].reshape(-1, input_dim_00, input_dim_01), + input_tensors[1].reshape(-1, input_dim_10, input_dim_11), + ], + [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)], + ) # Backward compute cost: dB = A^T * dC, dA = dC * B^T bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]( - [input_tensors[0].transpose(-2, -1).reshape(-1, input_dim_01, input_dim_00), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)], - 
[input_tensors[1].reshape(-1, input_dim_11, input_dim_10)] - ) + \ - flop_mapping[torch.ops.aten.bmm.default]( - [output_tensors[0].reshape(-1, output_dim_0, output_dim_1), input_tensors[1].transpose(-2, -1).reshape(-1, input_dim_11, input_dim_10)], - [input_tensors[0].reshape(-1, input_dim_00, input_dim_01)] - ) + [ + input_tensors[0].transpose(-2, -1).reshape(-1, input_dim_01, input_dim_00), + output_tensors[0].reshape(-1, output_dim_0, output_dim_1), + ], + [input_tensors[1].reshape(-1, input_dim_11, input_dim_10)], + ) + flop_mapping[torch.ops.aten.bmm.default]( + [ + output_tensors[0].reshape(-1, output_dim_0, output_dim_1), + input_tensors[1].transpose(-2, -1).reshape(-1, input_dim_11, input_dim_10), + ], + [input_tensors[0].reshape(-1, input_dim_00, input_dim_01)], + ) fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors)) bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors)) @@ -358,43 +398,46 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L else: # Case 2: batch dimensions are different batch_dims = output_tensors[0].shape[:-2] - extended_input_0 = torch.rand(reduce(lambda x, y: x * y, batch_dims), - input_dim_00, - input_dim_01, - device="meta") - extended_input_1 = torch.rand(reduce(lambda x, y: x * y, batch_dims), - input_dim_10, - input_dim_11, - device="meta") + extended_input_0 = torch.rand( + reduce(lambda x, y: x * y, batch_dims), input_dim_00, input_dim_01, device="meta" + ) + extended_input_1 = torch.rand( + reduce(lambda x, y: x * y, batch_dims), input_dim_10, input_dim_11, device="meta" + ) # Forward compute cost: C = A * B fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]( - [extended_input_0, extended_input_1], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)]) + [extended_input_0, extended_input_1], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)] + ) # Backward compute cost: dB = A^T * dC, dA = dC * B^T bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]( - [extended_input_0.transpose(-2, -1), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)], - [extended_input_1] - ) + \ - flop_mapping[torch.ops.aten.bmm.default]( - [output_tensors[0].reshape(-1, output_dim_0, output_dim_1), extended_input_1.transpose(-2, -1)], - [extended_input_0] - ) + [extended_input_0.transpose(-2, -1), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)], + [extended_input_1], + ) + flop_mapping[torch.ops.aten.bmm.default]( + [output_tensors[0].reshape(-1, output_dim_0, output_dim_1), extended_input_1.transpose(-2, -1)], + [extended_input_0], + ) fwd_mem_cost = MemoryCost( - activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1])) - bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors) - - compute_size_in_bytes([extended_input_0, extended_input_1]), - temp=compute_size_in_bytes([extended_input_0, extended_input_1])) + activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1]) + ) + bwd_mem_cost = MemoryCost( + activation=compute_size_in_bytes(input_tensors) + - compute_size_in_bytes([extended_input_0, extended_input_1]), + temp=compute_size_in_bytes([extended_input_0, extended_input_1]), + ) # compute cost compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) # memory cost - total_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, - parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter, 
- temp=fwd_mem_cost.temp + bwd_mem_cost.temp, - buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer) + total_cost = MemoryCost( + activation=fwd_mem_cost.activation + bwd_mem_cost.activation, + parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter, + temp=fwd_mem_cost.temp + bwd_mem_cost.temp, + buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer, + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_cost) diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py b/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py index 12874810b13e..b1bb1d872c35 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py @@ -3,7 +3,7 @@ import torch -from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem from ..registry import meta_register diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py b/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py index b872fdc8bdcd..99aaa752d0a1 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py @@ -1,22 +1,14 @@ -from typing import Callable, Dict, List, Tuple, Union +from typing import List, Tuple import torch from colossalai._analyzer._subclasses.flop_tensor import flop_mapping from colossalai._analyzer.fx.node_util import compute_size_in_bytes -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - MemoryCost, - OperationData, - OperationDataType, - ShardingStrategy, - StrategiesVector, - TrainCycleItem, -) -from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem from ..registry import meta_register -__all__ = ['batchnormnd_meta_info', 'layernorm_meta_info'] +__all__ = ["batchnormnd_meta_info", "layernorm_meta_info"] @meta_register.register(torch.nn.BatchNorm1d) @@ -65,7 +57,15 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt # saved inv std and some other args indicating the status of the module # the bwd outputs are input grad, weight grad and bias grad bwd_in_args = [ - output_tensor, output_tensor, weight_tensor, mean_tensor, var_tensor, mean_tensor, var_tensor, 1e-5, num_batch + output_tensor, + output_tensor, + weight_tensor, + mean_tensor, + var_tensor, + mean_tensor, + var_tensor, + 1e-5, + num_batch, ] bwd_out_args = [input_tensor, weight_tensor, bias_tensor] @@ -77,29 +77,34 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt # calculate memory cost # the fwd activation cost is output plus saved mean and saved inv std # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward - fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes( - [input_tensor, output_tensor, mean_tensor, var_tensor]), - parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), - temp=0, - buffer=compute_size_in_bytes([mean_tensor, var_tensor])) + fwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor, output_tensor, mean_tensor, var_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), + temp=0, + buffer=compute_size_in_bytes([mean_tensor, var_tensor]), + ) # the bwd memory 
cost is quite tricky here, BatchNorm will remove saved mean # and saved inv std during backward phase - bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor]), - parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), - temp=compute_size_in_bytes([mean_tensor, var_tensor]), - buffer=compute_size_in_bytes([mean_tensor, var_tensor])) + bwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), + temp=compute_size_in_bytes([mean_tensor, var_tensor]), + buffer=compute_size_in_bytes([mean_tensor, var_tensor]), + ) # total cost is the sum of forward and backward cost - total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, - parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) + total_cost = MemoryCost( + activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, + ) memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) # store fwd_in, fwd_buffer, fwd_out - fwd_in = [torch.zeros_like(input_tensor, device='meta')] - fwd_buffer = [torch.zeros_like(mean_tensor, device='meta'), torch.zeros_like(var_tensor, device='meta')] - fwd_out = [torch.zeros_like(output_tensor, device='meta')] + fwd_in = [torch.zeros_like(input_tensor, device="meta")] + fwd_buffer = [torch.zeros_like(mean_tensor, device="meta"), torch.zeros_like(var_tensor, device="meta")] + fwd_out = [torch.zeros_like(output_tensor, device="meta")] return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out @@ -116,8 +121,8 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data weight_tensor = next(filter(lambda x: x.name == "weight", args)).data bias_tensor = next(filter(lambda x: x.name == "bias", args)).data - running_mean = torch.rand(input_tensor.shape[0], 1, device='meta') - running_var = torch.rand(input_tensor.shape[0], 1, device='meta') + running_mean = torch.rand(input_tensor.shape[0], 1, device="meta") + running_var = torch.rand(input_tensor.shape[0], 1, device="meta") # construct args fwd_in_args = [input_tensor, [input_tensor.shape[0]], weight_tensor] @@ -132,27 +137,32 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem # memory cost # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward - fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes( - [input_tensor, output_tensor, weight_tensor, bias_tensor]), - parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), - temp=0, - buffer=compute_size_in_bytes([running_mean, running_var])) - - bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]), - parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), - temp=compute_size_in_bytes([running_mean, running_var]), - buffer=compute_size_in_bytes([running_mean, running_var])) - - total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, - parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, - temp=fwd_memory_cost.temp + bwd_memory_cost.temp, - buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer) + fwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor, output_tensor, weight_tensor, bias_tensor]), + 
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), + temp=0, + buffer=compute_size_in_bytes([running_mean, running_var]), + ) + + bwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), + temp=compute_size_in_bytes([running_mean, running_var]), + buffer=compute_size_in_bytes([running_mean, running_var]), + ) + + total_cost = MemoryCost( + activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, + temp=fwd_memory_cost.temp + bwd_memory_cost.temp, + buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer, + ) memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) # store fwd_in, fwd_buffer, fwd_out - fwd_in = [torch.zeros_like(input_tensor, device='meta')] - fwd_buffer = [torch.zeros_like(running_mean, device='meta'), torch.zeros_like(running_var, device='meta')] - fwd_out = [torch.zeros_like(output_tensor, device='meta')] + fwd_in = [torch.zeros_like(input_tensor, device="meta")] + fwd_buffer = [torch.zeros_like(running_mean, device="meta"), torch.zeros_like(running_var, device="meta")] + fwd_out = [torch.zeros_like(output_tensor, device="meta")] return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py b/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py index d785dfcca9ba..21aa524bed08 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py @@ -63,7 +63,7 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, # store fwd_in, fwd_buffer, fwd_out fwd_in = [] fwd_buffer = [] - fwd_out = [torch.zeros_like(output_tensor, device='meta')] + fwd_out = [torch.zeros_like(output_tensor, device="meta")] return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out @@ -117,8 +117,10 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor, index_matrix])) # temp memory for backward is the index matrix to be discarded - bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensor) - compute_size_in_bytes(index_matrix), - temp=compute_size_in_bytes(index_matrix)) + bwd_mem_cost = MemoryCost( + activation=compute_size_in_bytes(input_tensor) - compute_size_in_bytes(index_matrix), + temp=compute_size_in_bytes(index_matrix), + ) # total cost total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, temp=bwd_mem_cost.temp) @@ -126,8 +128,8 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) # store fwd_in, fwd_buffer, fwd_out - fwd_in = [torch.zeros_like(input_tensor, device='meta')] - fwd_buffer = [torch.zeros_like(index_matrix, device='meta')] - fwd_out = [torch.zeros_like(output_tensor, device='meta')] + fwd_in = [torch.zeros_like(input_tensor, device="meta")] + fwd_buffer = [torch.zeros_like(index_matrix, device="meta")] + fwd_out = [torch.zeros_like(output_tensor, device="meta")] return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py b/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py index 
97fe3c6196f5..9a2df1bd7c87 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py @@ -2,7 +2,6 @@ import torch -from colossalai._analyzer._subclasses.flop_tensor import flop_mapping from colossalai._analyzer.fx.node_util import compute_size_in_bytes from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem @@ -37,15 +36,19 @@ def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[tor # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * 2, parameter=0, temp=0, buffer=0) - bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * bwd_mem_out_factor, - parameter=0, - temp=compute_size_in_bytes(outputs) * bwd_mem_tmp_factor, - buffer=0) + bwd_mem_cost = MemoryCost( + activation=compute_size_in_bytes(outputs) * bwd_mem_out_factor, + parameter=0, + temp=compute_size_in_bytes(outputs) * bwd_mem_tmp_factor, + buffer=0, + ) - total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, - parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter, - temp=fwd_mem_cost.temp + bwd_mem_cost.temp, - buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer) + total_mem_cost = MemoryCost( + activation=fwd_mem_cost.activation + bwd_mem_cost.activation, + parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter, + temp=fwd_mem_cost.temp + bwd_mem_cost.temp, + buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer, + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) @@ -66,14 +69,24 @@ def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[tor # register torch.Tensor related metainfo # (0, 0) -meta_register.register([torch.tensor, torch.Tensor.to, torch.Tensor.unsqueeze, torch.unsqueeze, - torch.arange])(tensor_related_metainfo(0, 0)) +meta_register.register([torch.tensor, torch.Tensor.to, torch.Tensor.unsqueeze, torch.unsqueeze, torch.arange])( + tensor_related_metainfo(0, 0) +) # (1, 0) -meta_register.register([ - torch.Tensor.flatten, torch.flatten, torch.Tensor.transpose, torch.transpose, torch.Tensor.permute, torch.permute, - torch.Tensor.split, torch.split, torch.Tensor.view -])(tensor_related_metainfo(1, 0)) +meta_register.register( + [ + torch.Tensor.flatten, + torch.flatten, + torch.Tensor.transpose, + torch.transpose, + torch.Tensor.permute, + torch.permute, + torch.Tensor.split, + torch.split, + torch.Tensor.view, + ] +)(tensor_related_metainfo(1, 0)) # (1, 1) meta_register.register([torch.Tensor.type, torch.Tensor.contiguous])(tensor_related_metainfo(1, 1)) diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/where.py b/colossalai/auto_parallel/meta_profiler/meta_registry/where.py index 5cba1b5b6e2b..107851b80d7c 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/where.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/where.py @@ -4,7 +4,7 @@ from colossalai._analyzer._subclasses.flop_tensor import flop_mapping from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size -from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem from ..registry import meta_register @@ -39,16 +39,21 @@ def where_meta_info(*args, 
**kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, Li # gradient matrix for input x and input y, remove the temp memory and condition tensor generated in forward phase # NOTE: currently in SPMD solver we always believe that there will be a new input tensor created in forward fwd_mem_cost = MemoryCost(activation=activation_size([condition_tensor, x_tensor, y_tensor, output_tensor])) - bwd_mem_cost = MemoryCost(activation=activation_size([x_tensor, y_tensor]) - activation_size([condition_tensor]), - parameter=0, - temp=activation_size([output_tensor]) * 3 + activation_size([condition_tensor]) - - activation_size([x_tensor, y_tensor]), - buffer=0) - - total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, - parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter, - temp=fwd_mem_cost.temp + bwd_mem_cost.temp, - buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer) + bwd_mem_cost = MemoryCost( + activation=activation_size([x_tensor, y_tensor]) - activation_size([condition_tensor]), + parameter=0, + temp=activation_size([output_tensor]) * 3 + + activation_size([condition_tensor]) + - activation_size([x_tensor, y_tensor]), + buffer=0, + ) + + total_mem_cost = MemoryCost( + activation=fwd_mem_cost.activation + bwd_mem_cost.activation, + parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter, + temp=fwd_mem_cost.temp + bwd_mem_cost.temp, + buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer, + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) diff --git a/colossalai/auto_parallel/meta_profiler/registry.py b/colossalai/auto_parallel/meta_profiler/registry.py index 46350c4dd406..c29086f7f9d1 100644 --- a/colossalai/auto_parallel/meta_profiler/registry.py +++ b/colossalai/auto_parallel/meta_profiler/registry.py @@ -1,14 +1,12 @@ -__all__ = ['Registry'] +__all__ = ["Registry"] class Registry: - def __init__(self, name): self.name = name self.store = {} def register(self, source): - def wrapper(func): if isinstance(source, (list, tuple)): # support register a list of items for this func @@ -21,7 +19,7 @@ def wrapper(func): return wrapper def get(self, source): - assert source in self.store, f'{source} not found in the {self.name} registry' + assert source in self.store, f"{source} not found in the {self.name} registry" target = self.store[source] return target @@ -29,4 +27,4 @@ def has(self, source): return source in self.store -meta_register = Registry('meta') +meta_register = Registry("meta") diff --git a/colossalai/auto_parallel/meta_profiler/shard_metainfo.py b/colossalai/auto_parallel/meta_profiler/shard_metainfo.py index 0eee908b48b7..109b8a220ac7 100644 --- a/colossalai/auto_parallel/meta_profiler/shard_metainfo.py +++ b/colossalai/auto_parallel/meta_profiler/shard_metainfo.py @@ -2,20 +2,13 @@ import torch -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - MemoryCost, - OperationData, - OperationDataType, - ShardingStrategy, - StrategiesVector, - TrainCycleItem, -) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, ShardingStrategy, TrainCycleItem from colossalai.tensor.sharding_spec import ShardingSpec from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION from .registry import meta_register -__all__ = ['ShardMetaInfo'] +__all__ = ["ShardMetaInfo"] class ShardMetaInfo: @@ -76,10 +69,12 @@ def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: S """ if isinstance(sharding_spec, ShardingSpec): - op_data = 
OperationData(name=operation_data.name, - data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"), - type=operation_data.type, - logical_shape=operation_data.logical_shape) + op_data = OperationData( + name=operation_data.name, + data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"), + type=operation_data.type, + logical_shape=operation_data.logical_shape, + ) elif isinstance(sharding_spec, (list, tuple)): data = operation_data.data assert isinstance(data, (list, tuple)), f"Data Should be list or tuple, but got {type(data)}." @@ -97,8 +92,9 @@ def compute_shard_metainfo(self): """ Compute meta info based on sharding strategy and the given target function. """ - assert meta_register.has(self._target.__class__) or meta_register.has(self._target), \ - f"Meta info for {self._target} is not registered." + assert meta_register.has(self._target.__class__) or meta_register.has( + self._target + ), f"Meta info for {self._target} is not registered." if meta_register.has(self._target.__class__): # module meta_func = meta_register.get(self._target.__class__) @@ -117,11 +113,11 @@ def compute_shard_metainfo(self): # construct kwargs if self.target in INPLACE_MODULE: - kwargs = {'inplace': self.target.inplace} + kwargs = {"inplace": self.target.inplace} elif self.target in INPLACE_OPS: - kwargs = {'inplace': True} + kwargs = {"inplace": True} else: - kwargs = {'inplace': False} + kwargs = {"inplace": False} # compute metainfo with meta_func self.compute_cost, self.memory_cost, self.fwd_in, self.fwd_buffer, self.fwd_out = meta_func(*args, **kwargs) diff --git a/colossalai/auto_parallel/offload/amp_optimizer.py b/colossalai/auto_parallel/offload/amp_optimizer.py index 353133bd6f2d..601bf2926d99 100644 --- a/colossalai/auto_parallel/offload/amp_optimizer.py +++ b/colossalai/auto_parallel/offload/amp_optimizer.py @@ -37,19 +37,20 @@ class AMPOptimizer(OptimizerWrapper): norm_type (float, optional): norm_type used for `clip_grad_norm`. 
""" - def __init__(self, - optimizer: Optimizer, - module: BaseOffloadModule, - initial_scale: float = 2**16, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - min_scale: float = 1, - max_scale: float = 2**32, - clipping_norm: float = 0.0, - norm_type: float = 2.0): - + def __init__( + self, + optimizer: Optimizer, + module: BaseOffloadModule, + initial_scale: float = 2**16, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + min_scale: float = 1, + max_scale: float = 2**32, + clipping_norm: float = 0.0, + norm_type: float = 2.0, + ): super().__init__(optimizer) self.module = module @@ -69,19 +70,21 @@ def __init__(self, self.__init__optimizer() # Grad scaler - self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale, - min_scale=min_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - max_scale=max_scale) + self.grad_scaler = DynamicGradScaler( + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + ) self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device()) self._logger = get_dist_logger() def _set_grad_ptr(self): for group in self.param_groups: - for fake_param in group['params']: + for fake_param in group["params"]: region = self.param_to_region[fake_param] begin, end = self.param_to_range[fake_param] @@ -92,7 +95,7 @@ def _set_grad_ptr(self): def _update_fp16_params(self): none_tensor = torch.empty([0]) for group in self.param_groups: - for fake_param in group['params']: + for fake_param in group["params"]: assert fake_param.grad is None fake_param.data = none_tensor self.param_to_region[fake_param].cpu_grad = None @@ -130,10 +133,10 @@ def step(self, *args, **kwargs): found_inf = self._check_overflow() if found_inf: - self.optim_state = OptimState.UNSCALED # no need to unscale grad - self.grad_scaler.update(found_inf) # update gradient scaler - self._logger.info(f'Found overflow. Skip step') - self.zero_grad() # reset all gradients + self.optim_state = OptimState.UNSCALED # no need to unscale grad + self.grad_scaler.update(found_inf) # update gradient scaler + self._logger.info(f"Found overflow. 
Skip step") + self.zero_grad() # reset all gradients self._update_fp16_params() return @@ -156,11 +159,10 @@ def backward(self, loss: torch.Tensor): self.module.backward(loss) def __init__optimizer(self): - for group in self.optim.param_groups: fake_params_list = list() - for param in group['params']: + for param in group["params"]: region = self.region_manager.get_region(param) fake_param = torch.nn.Parameter(torch.empty([0])) self.param_to_range[fake_param] = region.param_to_range[param] @@ -171,7 +173,7 @@ def __init__optimizer(self): if param in self.optim.state: self.optim.state[fake_param] = self.optim.state.pop(param) - group['params'] = fake_params_list + group["params"] = fake_params_list # Leverage state_dict() and load_state_dict() to # recast preexisting per-param state tensors diff --git a/colossalai/auto_parallel/offload/base_offload_module.py b/colossalai/auto_parallel/offload/base_offload_module.py index 5b9f74b132f3..f5e8e31f5e97 100644 --- a/colossalai/auto_parallel/offload/base_offload_module.py +++ b/colossalai/auto_parallel/offload/base_offload_module.py @@ -22,7 +22,6 @@ class BaseOffloadModule: """ def __init__(self, model: nn.Module, region_manager: RegionManager, is_sync=True): - self.model = model self.region_manager = region_manager self.grad_hook_list = [] @@ -91,17 +90,16 @@ def _cast_buffers(self): def parameters(self, recurse: bool = True): return self.model.parameters(recurse) - def named_parameters(self, prefix: str = '', recurse: bool = True): + def named_parameters(self, prefix: str = "", recurse: bool = True): return self.model.named_parameters(prefix, recurse) - def named_buffers(self, prefix: str = '', recurse: bool = True): + def named_buffers(self, prefix: str = "", recurse: bool = True): return self.model.named_buffers(prefix, recurse) def named_children(self): return self.model.named_children() - def named_modules(self, - memo: Optional[Set[torch.nn.Module]] = None, - prefix: str = '', - remove_duplicate: bool = True): + def named_modules( + self, memo: Optional[Set[torch.nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True + ): return self.model.named_modules(memo, prefix, remove_duplicate) diff --git a/colossalai/auto_parallel/offload/mem_optimize.py b/colossalai/auto_parallel/offload/mem_optimize.py index d56166dea982..74501c184518 100644 --- a/colossalai/auto_parallel/offload/mem_optimize.py +++ b/colossalai/auto_parallel/offload/mem_optimize.py @@ -14,11 +14,9 @@ from .util import GlobalRuntimeInfo, compute_act_peak_mem, compute_max_param_mem, compute_total_param_mem -def memory_optimize(model: torch.nn.Module, - inps: Dict[str, torch.Tensor], - memory_budget: float = -1.0, - solver_name: str = 'asyn'): - +def memory_optimize( + model: torch.nn.Module, inps: Dict[str, torch.Tensor], memory_budget: float = -1.0, solver_name: str = "asyn" +): model = model.cpu().half() tracer = ColoTracer() assert is_compatible_with_meta() @@ -40,13 +38,13 @@ def memory_optimize(model: torch.nn.Module, f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}" ) - if solver_name == 'syn': + if solver_name == "syn": gm = runtime_syn_offload_apply_pass(gm, region_manager.region_list) - elif solver_name == 'asyn': + elif solver_name == "asyn": gm = runtime_asyn_offload_apply_pass(gm, region_manager.region_list) else: raise TypeError(f"Unknown solver name {solver_name}!") gm.recompile() - optimized_model = BaseOffloadModule(gm, region_manager, solver_name == 'syn') + optimized_model = 
BaseOffloadModule(gm, region_manager, solver_name == "syn") return optimized_model diff --git a/colossalai/auto_parallel/offload/region.py b/colossalai/auto_parallel/offload/region.py index 819ffbd96eb1..ea92c714ce31 100644 --- a/colossalai/auto_parallel/offload/region.py +++ b/colossalai/auto_parallel/offload/region.py @@ -55,13 +55,13 @@ def init_param_data(self, pre_alloc_tensor: torch.Tensor = None): Map the parameters in the region to a contiguous memory space. """ - self.fp16_data = torch.zeros(self.param_num, dtype=torch.half, device='cuda') + self.fp16_data = torch.zeros(self.param_num, dtype=torch.half, device="cuda") offset = 0 for param in self.fp16_params: param.data = param.data.cuda() p_num = param.data.numel() - self.fp16_data[offset:offset + p_num].copy_(param.data.flatten()) - param.data = self.fp16_data[offset:offset + p_num].view(param.data.shape) + self.fp16_data[offset : offset + p_num].copy_(param.data.flatten()) + param.data = self.fp16_data[offset : offset + p_num].view(param.data.shape) self.param_to_range[param] = (offset, offset + p_num) offset += p_num @@ -83,7 +83,7 @@ def move_param_to_cuda(self): self.temp_fp32_data.record_stream(torch.cuda.current_stream()) if not self.in_mem_pool_flag: alloc_storage(self.fp16_data) - self.fp16_data[:self.param_num].copy_(self.temp_fp32_data) + self.fp16_data[: self.param_num].copy_(self.temp_fp32_data) self.fp16_data.record_stream(torch.cuda.current_stream()) self.__update_params_ptr() @@ -94,7 +94,7 @@ def move_grad_to_cpu(self): """ self.cpu_grad = torch.empty(self.param_num, dtype=torch.half, pin_memory=True) - self.cpu_grad.copy_(self.fp16_data[:self.param_num], non_blocking=True) + self.cpu_grad.copy_(self.fp16_data[: self.param_num], non_blocking=True) self.fp16_data.record_stream(torch.cuda.current_stream()) if not self.in_mem_pool_flag: self.free_cuda_data() diff --git a/colossalai/auto_parallel/offload/region_manager.py b/colossalai/auto_parallel/offload/region_manager.py index 30bfaf00d493..146dd267967d 100644 --- a/colossalai/auto_parallel/offload/region_manager.py +++ b/colossalai/auto_parallel/offload/region_manager.py @@ -1,10 +1,11 @@ -from typing import List, Any, Dict, Tuple +from typing import Any, Dict, List, Tuple + import torch from torch.fx import Graph, Node +from .region import Region from .solver import SolverFactory from .training_simulator import TrainingSimulator -from .region import Region from .util import NodeInfo @@ -19,14 +20,9 @@ class RegionManager: cnode (List[str], optional): Common node List, should be the subset of input. 
""" - def __init__(self, - graph: Graph, - solver_name: str = 'asyn', - memory_budget: float = -1.0, - cnode: List[str] = None): - + def __init__(self, graph: Graph, solver_name: str = "asyn", memory_budget: float = -1.0, cnode: List[str] = None): self.graph = graph - assert graph.owning_module is not None, 'The given graph is not associated with a owning_module' + assert graph.owning_module is not None, "The given graph is not associated with a owning_module" self.root_module = self.graph.owning_module self.nodes = list(graph.nodes) self.cnode = cnode @@ -39,7 +35,7 @@ def __init__(self, self.memory_budget = memory_budget self.solver_name = solver_name - self.require_pool: bool = solver_name == 'asyn' + self.require_pool: bool = solver_name == "asyn" self.reg_to_block: Dict[int, int] = dict() @@ -61,22 +57,19 @@ def _build_regions(self): self._post_process(solver.best_ts) def _pre_process(self): - init_region_list = self._linearize_graph() if len(self.shared_region_pairs) > 1: - raise NotImplementedError( - 'The current version only considers at most one pair of parameter sharing.') + raise NotImplementedError("The current version only considers at most one pair of parameter sharing.") elif len(self.shared_region_pairs) == 1: shared_regs = self.shared_region_pairs[0] - assert shared_regs[0].shared_rid == shared_regs[1].r_id \ - and shared_regs[1].shared_rid == shared_regs[0].r_id + assert shared_regs[0].shared_rid == shared_regs[1].r_id and shared_regs[1].shared_rid == shared_regs[0].r_id fst_id = shared_regs[0].r_id lst_id = shared_regs[1].r_id - regs_left_out = init_region_list[:fst_id + 1] + regs_left_out = init_region_list[: fst_id + 1] regs_right_out = init_region_list[lst_id:] - hold_regs = init_region_list[fst_id + 1:lst_id] + hold_regs = init_region_list[fst_id + 1 : lst_id] else: regs_left_out = [] regs_right_out = [] @@ -122,12 +115,9 @@ def _early_region_placement(self, ts: TrainingSimulator): it may not find a suitable region placement strategy for the given execution flow. 
""" - reg_flow = torch.cat( - [ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0) - mem_block_num = torch.max( - torch.sum(reg_flow[:, self.rid_in_pool], dim=1)) - coexist_matrix = torch.logical_or( - ts.fwd_reg_flow, ts.bwd_reg_flow) + reg_flow = torch.cat([ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0) + mem_block_num = torch.max(torch.sum(reg_flow[:, self.rid_in_pool], dim=1)) + coexist_matrix = torch.logical_or(ts.fwd_reg_flow, ts.bwd_reg_flow) block_to_regs = {} for block_idx in range(mem_block_num): @@ -135,8 +125,7 @@ def _early_region_placement(self, ts: TrainingSimulator): for reg in self.region_list: if reg.r_id in self.rid_in_pool: cur_reg_appears = coexist_matrix[:, reg.r_id] - cur_reg_coexists = torch.sum( - coexist_matrix[cur_reg_appears], dim=0).bool() + cur_reg_coexists = torch.sum(coexist_matrix[cur_reg_appears], dim=0).bool() for block_idx in range(mem_block_num): if not any(cur_reg_coexists[block_to_regs[block_idx]]): block_to_regs[block_idx].append(reg.r_id) @@ -145,9 +134,12 @@ def _early_region_placement(self, ts: TrainingSimulator): if reg.r_id not in self.reg_to_block: raise NotImplementedError( - f'can not find a block from the memory pool to store parameters of the region') - self.memory_pool = torch.chunk(torch.zeros(int( - mem_block_num * self.mem_block_size / 2), dtype=torch.half, device='cuda'), chunks=int(mem_block_num)) + f"can not find a block from the memory pool to store parameters of the region" + ) + self.memory_pool = torch.chunk( + torch.zeros(int(mem_block_num * self.mem_block_size / 2), dtype=torch.half, device="cuda"), + chunks=int(mem_block_num), + ) def _merge_small_regions(self, orig_reg_list: List[Region]) -> List[Region]: """ @@ -178,10 +170,9 @@ def _merge_small_regions(self, orig_reg_list: List[Region]) -> List[Region]: return region_list - def _search_block_size(self, - region_list: List[Region], - search_interval_byte: int = 1024, - search_range_byte: int = 128 * 1024 ** 2) -> int: + def _search_block_size( + self, region_list: List[Region], search_interval_byte: int = 1024, search_range_byte: int = 128 * 1024**2 + ) -> int: """ Search for a suitable memory block size. @@ -208,11 +199,10 @@ def _get_wasted_mem(size_list: List[int], blk_size: int): acc_wasted += blk_size - left return acc_wasted - param_size_list = [ - region.param_size for region in region_list if region.r_id == region.shared_rid] + param_size_list = [region.param_size for region in region_list if region.r_id == region.shared_rid] start_size = max(param_size_list) - min_mem_waste = float('+inf') + min_mem_waste = float("+inf") best_block_size = start_size for block_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte): @@ -229,7 +219,7 @@ def _init_region_data(self): Initialize region data, which maps the parameters in the region to a contiguous memory space. 
""" - self.temp_fp32_data = torch.zeros(self.max_param_num, device='cuda', dtype=torch.float32) + self.temp_fp32_data = torch.zeros(self.max_param_num, device="cuda", dtype=torch.float32) for region in self.region_list: pre_alloc_tensor = None @@ -244,8 +234,7 @@ def _init_region_data(self): region.fp16_data = shared_region.fp16_data region.fp32_data = shared_region.fp32_data region.param_to_range = shared_region.param_to_range - region.temp_fp32_data = self.temp_fp32_data[:region.param_num].detach( - ) + region.temp_fp32_data = self.temp_fp32_data[: region.param_num].detach() torch.cuda.empty_cache() @@ -259,13 +248,14 @@ def _process_shared_region(self): former_reg, latter_reg = self.shared_region_pairs[0] assert latter_reg.param_num >= former_reg.param_num embedding_node = former_reg.nodes[-1] - assert embedding_node.op == 'call_module' and isinstance( - self.root_module.get_submodule(embedding_node.target), torch.nn.Embedding) + assert embedding_node.op == "call_module" and isinstance( + self.root_module.get_submodule(embedding_node.target), torch.nn.Embedding + ) if latter_reg.param_num > former_reg.param_num: for idx, n in enumerate(latter_reg.nodes): - if (n.op == 'call_module' and isinstance(self.root_module.get_submodule(n.target), - torch.nn.Linear)) or \ - (n.op == 'call_function' and n.target is torch.nn.functional.linear): + if ( + n.op == "call_module" and isinstance(self.root_module.get_submodule(n.target), torch.nn.Linear) + ) or (n.op == "call_function" and n.target is torch.nn.functional.linear): cut_node_idx = idx + 1 break assert len(latter_reg.fp16_params) == 2 @@ -273,7 +263,7 @@ def _process_shared_region(self): for p in new_reg.fp16_params: self.param_region_map[p] = new_reg self.region_list.insert(new_reg.r_id, new_reg) - for reg in self.region_list[new_reg.r_id + 1:]: + for reg in self.region_list[new_reg.r_id + 1 :]: reg.r_id += 1 latter_reg.shared_rid = former_reg.r_id former_reg.shared_rid = latter_reg.r_id @@ -344,8 +334,8 @@ def _maybe_param_comp_start() -> bool: target = n.target submod = self.root_module.get_submodule(target) if ( - len(list(submod.named_parameters(recurse=False))) != 0 - or len(list(submod.named_buffers(recurse=False))) != 0 + len(list(submod.named_parameters(recurse=False))) != 0 + or len(list(submod.named_buffers(recurse=False))) != 0 ): label = True @@ -362,14 +352,12 @@ def _is_param_comp_end() -> bool: """ def _is_inplace(n: Node): - """Get the inplace argument from ``torch.fx.Node`` - """ + """Get the inplace argument from ``torch.fx.Node``""" inplace = False if n.op == "call_function": inplace = n.kwargs.get("inplace", False) elif n.op == "call_module": - inplace = getattr(n.graph.owning_module.get_submodule( - n.target), "inplace", False) + inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False) return inplace label = False @@ -378,28 +366,30 @@ def _is_inplace(n: Node): target = n.target submod = self.root_module.get_submodule(target) if ( - len(list(submod.named_parameters(recurse=False))) != 0 - or len(list(submod.named_buffers(recurse=False))) != 0 + len(list(submod.named_parameters(recurse=False))) != 0 + or len(list(submod.named_buffers(recurse=False))) != 0 ): label = True elif n.op == "call_function": label = any(map(lambda x: x.name in self.only_param_ops, n.all_input_nodes)) and any( - map(lambda x: x.name not in self.only_param_ops and not _is_cop(n.target), n.all_input_nodes)) + map(lambda x: x.name not in self.only_param_ops and not _is_cop(n.target), n.all_input_nodes) + ) return label and 
not sum([v for _, v in param_op_deps.items()]) and not any(map(_is_inplace, n.users)) def _exception_node_handling(): # TODO meta info prop bug - if n.name.__contains__("transpose") and n.meta['fwd_out'][0].dim() <= 2: - n.meta['fwd_out'] = [] + if n.name.__contains__("transpose") and n.meta["fwd_out"][0].dim() <= 2: + n.meta["fwd_out"] = [] # make sure that item in cnode is valid if self.cnode: for name in self.cnode: try: - assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \ - f"Common node {name} is not an input of the model." + assert ( + next(node for node in self.graph.nodes if node.name == name).op == "placeholder" + ), f"Common node {name} is not an input of the model." except StopIteration: raise ValueError(f"Common node name {name} not in graph.") else: @@ -428,8 +418,8 @@ def _exception_node_handling(): ns = [] border_n_idx = region.nodes.index(act_n) if border_n_idx < len(region.nodes): - ns = region.nodes[border_n_idx + 1:] - region.nodes = region.nodes[:border_n_idx + 1] + ns = region.nodes[border_n_idx + 1 :] + region.nodes = region.nodes[: border_n_idx + 1] region_list.append(region) region_id += 1 region = Region(r_id=region_id) @@ -448,19 +438,21 @@ def _exception_node_handling(): region = Region(r_id=region_id) # propagate common node attr if possible - if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode - ]) or _is_cop(n.target): + if len(n.all_input_nodes) == len( + [node for node in n.all_input_nodes if node.name in self.cnode] + ) or _is_cop(n.target): self.cnode.append(n.name) else: - deps[n] = len( - [user for user in n.users if user.op != "output"]) + deps[n] = len([user for user in n.users if user.op != "output"]) # propagate param node attr if possible - if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.only_param_ops - ]) or n.op == "get_attr": + if ( + len(n.all_input_nodes) + == len([node for node in n.all_input_nodes if node.name in self.only_param_ops]) + or n.op == "get_attr" + ): self.only_param_ops.append(n.name) - param_op_deps[n] = len( - [user for user in n.users if user.op != "output"]) + param_op_deps[n] = len([user for user in n.users if user.op != "output"]) # record last activation node if _is_act(n._meta_data): @@ -472,19 +464,16 @@ def _exception_node_handling(): return region_list def _set_node_and_region_info(self, node_id: int, cur_n: Node, cur_reg: Region): - cur_n.node_info = NodeInfo(node_id) - if cur_n.op == 'call_module': + if cur_n.op == "call_module": target = cur_n.target submod = self.root_module.get_submodule(target) for p in list(submod.parameters(recurse=False)): - if p in self.param_region_map: cur_reg.shared_rid = self.param_region_map[p].r_id self.param_region_map[p].shared_rid = cur_reg.r_id - self.shared_region_pairs.append( - (self.param_region_map[p], cur_reg)) + self.shared_region_pairs.append((self.param_region_map[p], cur_reg)) else: self.param_region_map[p] = cur_reg @@ -499,12 +488,10 @@ def _set_node_and_region_info(self, node_id: int, cur_n: Node, cur_reg: Region): attr_itr = getattr(attr_itr, atom) if isinstance(attr_itr, torch.nn.Parameter): - if attr_itr in self.param_region_map: cur_reg.shared_rid = self.param_region_map[attr_itr].r_id self.param_region_map[attr_itr].shared_rid = cur_reg.r_id - self.shared_region_pairs.append( - (self.param_region_map[attr_itr], cur_reg)) + self.shared_region_pairs.append((self.param_region_map[attr_itr], cur_reg)) else: self.param_region_map[attr_itr] = 
cur_reg diff --git a/colossalai/auto_parallel/offload/runtime.py b/colossalai/auto_parallel/offload/runtime.py index 764ac608826b..cc790dfb0891 100644 --- a/colossalai/auto_parallel/offload/runtime.py +++ b/colossalai/auto_parallel/offload/runtime.py @@ -22,13 +22,13 @@ class SynPreFwdPostBwdOP(torch.autograd.Function): @staticmethod def forward(ctx, input_, fwd_info, bwd_info): ctx.bwd_info = bwd_info - d2h_rid = fwd_info.get('d2h_rid', None) + d2h_rid = fwd_info.get("d2h_rid", None) if d2h_rid is not None: free_region = GlobalRuntimeInfo().region_list[d2h_rid] assert isinstance(free_region, Region) free_region.free_cuda_data() - h2d_rid = fwd_info.get('h2d_rid', None) + h2d_rid = fwd_info.get("h2d_rid", None) if h2d_rid is not None: h2d_region = GlobalRuntimeInfo().region_list[h2d_rid] assert isinstance(h2d_region, Region) @@ -38,8 +38,7 @@ def forward(ctx, input_, fwd_info, bwd_info): @staticmethod def backward(ctx, grad_output): - - h2d_rid = ctx.bwd_info.get('h2d_rid', None) + h2d_rid = ctx.bwd_info.get("h2d_rid", None) if h2d_rid is not None: pref_region = GlobalRuntimeInfo().region_list[h2d_rid] assert isinstance(pref_region, Region) @@ -64,13 +63,13 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function): def forward(ctx, input_, fwd_info, bwd_info): ctx.bwd_info = bwd_info - sync_rid = fwd_info.get('sync_rid', None) + sync_rid = fwd_info.get("sync_rid", None) if sync_rid is not None: prefetch_event = GlobalRuntimeInfo().fwd_prefetch_event_map.get(sync_rid, None) if prefetch_event: prefetch_event.wait() - h2d_rid = fwd_info.get('h2d_rid', None) + h2d_rid = fwd_info.get("h2d_rid", None) if h2d_rid is not None: pref_region = GlobalRuntimeInfo().region_list[h2d_rid] assert isinstance(pref_region, Region) @@ -87,8 +86,7 @@ def forward(ctx, input_, fwd_info, bwd_info): @staticmethod def backward(ctx, grad_output): - - sync_rid = ctx.bwd_info.get('sync_rid', None) + sync_rid = ctx.bwd_info.get("sync_rid", None) if sync_rid is not None: wait_region = GlobalRuntimeInfo().region_list[sync_rid] assert isinstance(wait_region, Region) @@ -98,7 +96,7 @@ def backward(ctx, grad_output): else: wait_region.move_param_to_cuda() - h2d_rid = ctx.bwd_info.get('h2d_rid', None) + h2d_rid = ctx.bwd_info.get("h2d_rid", None) if h2d_rid is not None: pref_region = GlobalRuntimeInfo().region_list[h2d_rid] assert isinstance(pref_region, Region) @@ -114,7 +112,7 @@ def backward(ctx, grad_output): def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info): - ''' + """ Convert Upload and Offload operation into runtime action. Argument: @@ -123,14 +121,14 @@ def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info): that need to be uploaded, or freed during forward pass. bwd_info(dict): information dict, which contains region indices that need to be uploaded during backward pass. - ''' + """ with torch._C.DisableTorchFunction(): ret = SynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info) return ret def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info): - ''' + """ Convert Prefetch and Offload operation into runtime action. Argument: @@ -139,7 +137,7 @@ def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info): that need to be prefetched, waited, or freed during forward pass. bwd_info(dict): information dict, which contains region indices that need to be prefetched or waited during backward pass. 
- ''' + """ with torch._C.DisableTorchFunction(): ret = AsynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info) return ret @@ -176,22 +174,22 @@ def runtime_syn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[R # forward upload fwd_info = {} if requires_upload_p_in_fwd(region_list[region.shared_rid]): - fwd_info['h2d_rid'] = region.r_id + fwd_info["h2d_rid"] = region.r_id # forward offload if r_idx > 0 and region_list[r_idx - 1].need_offload: - fwd_info['d2h_rid'] = r_idx - 1 + fwd_info["d2h_rid"] = r_idx - 1 bwd_info = {} # backward upload if r_idx > 0 and region_list[r_idx - 1].need_offload: - bwd_info['h2d_rid'] = region_list[r_idx - 1].r_id + bwd_info["h2d_rid"] = region_list[r_idx - 1].r_id if fwd_info or bwd_info: with mod_graph.inserting_after(last_inp_node): - new_node = mod_graph.create_node('call_function', - convert_fwd_upload_bwd_offload_to_action, - args=(last_inp_node, fwd_info, bwd_info)) + new_node = mod_graph.create_node( + "call_function", convert_fwd_upload_bwd_offload_to_action, args=(last_inp_node, fwd_info, bwd_info) + ) replace_node_users(last_inp_node, new_node) last_inp_node = region.nodes[-1] @@ -210,9 +208,9 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[ first_region_with_p = [region for region in region_list if region.param_size][0] fwd_info = {"h2d_rid": first_region_with_p.r_id} with mod_graph.inserting_after(last_inp_node): - upload_apply_node = mod_graph.create_node('call_function', - convert_fwd_upload_bwd_offload_to_action, - args=(last_inp_node, fwd_info, {})) + upload_apply_node = mod_graph.create_node( + "call_function", convert_fwd_upload_bwd_offload_to_action, args=(last_inp_node, fwd_info, {}) + ) replace_node_users(last_inp_node, upload_apply_node) last_inp_node = upload_apply_node @@ -220,37 +218,39 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[ # forward prefetch fwd_info = {} if region.param_size: - fwd_info['sync_rid'] = region.r_id + fwd_info["sync_rid"] = region.r_id fwd_prefetch_region = region.fwd_prefetch_region if fwd_prefetch_region and requires_upload_p_in_fwd(region_list[fwd_prefetch_region.shared_rid]): - fwd_info['h2d_rid'] = fwd_prefetch_region.r_id + fwd_info["h2d_rid"] = fwd_prefetch_region.r_id # forward offload if r_idx > 0 and region_list[r_idx - 1].need_offload: - fwd_info['d2h_rid'] = r_idx - 1 + fwd_info["d2h_rid"] = r_idx - 1 bwd_info = {} # backward prefetch if r_idx > 0 and region_list[r_idx - 1].need_offload: - bwd_info['sync_rid'] = r_idx - 1 + bwd_info["sync_rid"] = r_idx - 1 if r_idx > 0 and region_list[r_idx - 1].bwd_prefetch_region: - bwd_info['h2d_rid'] = region_list[r_idx - 1].bwd_prefetch_region.r_id + bwd_info["h2d_rid"] = region_list[r_idx - 1].bwd_prefetch_region.r_id if fwd_info or bwd_info: with mod_graph.inserting_after(last_inp_node): - new_node = mod_graph.create_node('call_function', - convert_fwd_prefetch_bwd_offload_to_action, - args=(last_inp_node, fwd_info, bwd_info)) + new_node = mod_graph.create_node( + "call_function", + convert_fwd_prefetch_bwd_offload_to_action, + args=(last_inp_node, fwd_info, bwd_info), + ) replace_node_users(last_inp_node, new_node) last_inp_node = region.nodes[-1] if region.bwd_prefetch_region: - bwd_info = {'h2d_rid': region.bwd_prefetch_region.r_id} + bwd_info = {"h2d_rid": region.bwd_prefetch_region.r_id} with mod_graph.inserting_after(last_inp_node): - new_node = mod_graph.create_node('call_function', - convert_fwd_prefetch_bwd_offload_to_action, - args=(last_inp_node, {}, bwd_info)) + 
new_node = mod_graph.create_node( + "call_function", convert_fwd_prefetch_bwd_offload_to_action, args=(last_inp_node, {}, bwd_info) + ) replace_node_users(last_inp_node, new_node) # gm.graph.print_tabular() return gm diff --git a/colossalai/auto_parallel/offload/solver.py b/colossalai/auto_parallel/offload/solver.py index 161f7ff86898..a6b4904f2617 100644 --- a/colossalai/auto_parallel/offload/solver.py +++ b/colossalai/auto_parallel/offload/solver.py @@ -1,6 +1,6 @@ import time -from typing import List, Dict, Type from abc import ABC, abstractmethod +from typing import Dict, List, Type NOT_NVML = False try: @@ -10,10 +10,11 @@ import torch from torch.fx.node import Node + from colossalai.utils.cuda import get_current_device -from .training_simulator import TrainingSimulator, SynTrainingSimulator, AsynTrainingSimulator from .region import Region +from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator from .util import NodeInfo, NvDevicePower @@ -49,19 +50,14 @@ class Solver(ABC): It is used to reduce the memory budget. Due to some errors in the estimation of peak memory and execution time. """ - def __init__(self, - region_list: List[Region], - memory_budget: float = -1.0, - error_factor: float = 0.95) -> None: - + def __init__(self, region_list: List[Region], memory_budget: float = -1.0, error_factor: float = 0.95) -> None: self.region_list = region_list self.error_factor: float = error_factor if memory_budget > 0: self.memory_budget = memory_budget * self.error_factor else: - self.memory_budget = torch.cuda.get_device_properties( - get_current_device()).total_memory * self.error_factor + self.memory_budget = torch.cuda.get_device_properties(get_current_device()).total_memory * self.error_factor self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth() self.comp_power: float = self._extract_computing_power() @@ -94,7 +90,7 @@ def _compute_offload_profit(self, total_mem_saving: float, peak_mem_saving: floa if extra_cost == 0: # means data transfer overhead can be completely overlapped - return (float('inf'), total_mem_saving, peak_mem_saving) + return (float("inf"), total_mem_saving, peak_mem_saving) return (total_mem_saving / extra_cost, total_mem_saving, peak_mem_saving) def _compare_profit(self, profit_a: tuple, profit_b: tuple) -> bool: @@ -122,9 +118,7 @@ def _update_state(self, best_ts: TrainingSimulator): self.best_ts = best_ts self._update_node_mem_info(best_ts.fwd_node_mem, best_ts.bwd_node_mem) - def _update_node_mem_info(self, - fwd_mem_info: Dict[Node, float], - bwd_mem_info: Dict[Node, float]): + def _update_node_mem_info(self, fwd_mem_info: Dict[Node, float], bwd_mem_info: Dict[Node, float]): """ Update the runtime memory information of the node. 
@@ -134,12 +128,10 @@ def _update_node_mem_info(self, """ for node, mem in fwd_mem_info.items(): - assert hasattr(node, 'node_info') and isinstance( - node.node_info, NodeInfo) + assert hasattr(node, "node_info") and isinstance(node.node_info, NodeInfo) node.node_info.runtime_fwd_mem = mem for node, mem in bwd_mem_info.items(): - assert hasattr(node, 'node_info') and isinstance( - node.node_info, NodeInfo) + assert hasattr(node, "node_info") and isinstance(node.node_info, NodeInfo) node.node_info.runtime_bwd_mem = mem def _extract_computing_power(self): @@ -159,12 +151,12 @@ def _extract_computing_power(self): return NvDevicePower.RTX3080_FP16 * units elif device_name.__contains__("RTX 3090"): return NvDevicePower.RTX3090_FP16 * units - elif device_name.__contains__('V100'): + elif device_name.__contains__("V100"): return NvDevicePower.V100_FP16 * units elif device_name.__contains__("A100"): return NvDevicePower.A100_FP16 * units else: - raise TypeError(f'Unknown NVIDIA GPU device name {device_name}') + raise TypeError(f"Unknown NVIDIA GPU device name {device_name}") def _profile_bandwidth(self): """ @@ -172,9 +164,9 @@ def _profile_bandwidth(self): using data volumes ranging from 1KB to 1GB. """ - print('profiling bandwidth ......') + print("profiling bandwidth ......") link_to_bandwidth = {} - links = ['h2d', 'd2h'] + links = ["h2d", "d2h"] for link in links: t_size = 1024 @@ -182,24 +174,22 @@ def _profile_bandwidth(self): # from 1KB to 1GB for i in range(21): - if link == 'h2d': - src_tensor = torch.ones( - int(t_size), dtype=torch.int8, pin_memory=True) - dst_tensor = torch.ones( - (int(t_size)), dtype=torch.int8, device='cuda') - elif link == 'd2h': - src_tensor = torch.ones( - int(t_size), dtype=torch.int8, device='cuda') - dst_tensor = torch.ones( - (int(t_size)), dtype=torch.int8, pin_memory=True) + if link == "h2d": + src_tensor = torch.ones(int(t_size), dtype=torch.int8, pin_memory=True) + dst_tensor = torch.ones((int(t_size)), dtype=torch.int8, device="cuda") + elif link == "d2h": + src_tensor = torch.ones(int(t_size), dtype=torch.int8, device="cuda") + dst_tensor = torch.ones((int(t_size)), dtype=torch.int8, pin_memory=True) def func(): dst_tensor.copy_(src_tensor) size_to_bandwidth[t_size] = t_size / benchmark_func(func, number=5, repeat=3) - print(f'size: {t_size / 1024 ** 2:.3f} MB, ' - f'{src_tensor.device.type}-to-{dst_tensor.device.type} ' - f'bandwidth: {size_to_bandwidth[t_size] / 1024 ** 3:.3f} GB/s') + print( + f"size: {t_size / 1024 ** 2:.3f} MB, " + f"{src_tensor.device.type}-to-{dst_tensor.device.type} " + f"bandwidth: {size_to_bandwidth[t_size] / 1024 ** 3:.3f} GB/s" + ) t_size *= 2 @@ -208,10 +198,7 @@ def func(): class SynGreedySolver(Solver): - - def __init__(self, - region_list: List[Region], - memory_budget: float = -1.0) -> None: + def __init__(self, region_list: List[Region], memory_budget: float = -1.0) -> None: super().__init__(region_list, memory_budget) self.best_ts: SynTrainingSimulator = None @@ -258,7 +245,8 @@ def _call_solver(self): else: raise NotImplementedError( f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, " - f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!") + f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!" 
+ ) def _call_solver_l2l(self): """ @@ -270,7 +258,6 @@ def _call_solver_l2l(self): region.is_syn = True def _try_to_offload(self, offload_region: Region): - # record previous information orig_need_offload = offload_region.need_offload assert not orig_need_offload @@ -297,23 +284,17 @@ def _eval_one_choice(self, offload_region: Region): ts = SynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth) ts.execute() - extra_comm_cost = 2.0 * \ - ts._get_communication_overhead('h2d', offload_region.param_size) + extra_comm_cost = 2.0 * ts._get_communication_overhead("h2d", offload_region.param_size) # the shared region needs to be moved twice if offload_region.r_id < offload_region.shared_rid: extra_comm_cost *= 2.0 - profit = self._compute_offload_profit( - ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost) + profit = self._compute_offload_profit(ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost) return ts, profit class AsynGreedySolver(Solver): - - def __init__(self, - region_list: List[Region], - memory_budget: float = -1.0, - search_window_size: int = 3): + def __init__(self, region_list: List[Region], memory_budget: float = -1.0, search_window_size: int = 3): super().__init__(region_list, memory_budget) self.search_window_size = search_window_size @@ -331,7 +312,7 @@ def _init_state(self): ts = AsynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth) ts.execute() self._update_state(ts) - print("init peak memory", self.best_ts.peak_mem / 1024 ** 2, "MB") + print("init peak memory", self.best_ts.peak_mem / 1024**2, "MB") def _call_solver(self): """ @@ -358,18 +339,17 @@ def _call_solver(self): best_pref_ts = None # search when to prefetch the region offloaded - for host_region in self.region_list[region.r_id + 1:region.r_id + 1 + self.search_window_size]: + for host_region in self.region_list[region.r_id + 1 : region.r_id + 1 + self.search_window_size]: if host_region.bwd_prefetch_region is not None: continue - temp_ts, profit = self._try_to_offload( - host_region, region) + temp_ts, profit = self._try_to_offload(host_region, region) if self._compare_profit(profit, max_prefetch_profit): region_to_region_map[region.r_id] = host_region max_prefetch_profit = profit best_pref_ts = temp_ts - if profit[0] == float('inf'): + if profit[0] == float("inf"): break if self._compare_profit(max_prefetch_profit, max_offload_profit): @@ -392,7 +372,8 @@ def _call_solver(self): else: raise NotImplementedError( f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, " - f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!") + f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!" 
+ ) region_to_region_map.clear() @@ -452,7 +433,6 @@ def _repair_strategy(self): peak_mem_saving = 0 while len(self.region_to_region_map) and peak_mem_saving <= 0: - max_profit = (0,) best_ts = None undo_host_region = None @@ -464,8 +444,7 @@ def _repair_strategy(self): assert offload_region.need_offload assert not offload_region.is_syn - ts, profit = self._try_convert_to_syn_upload(host_region, - offload_region) + ts, profit = self._try_convert_to_syn_upload(host_region, offload_region) if self._compare_profit(profit, max_profit): undo_host_region = host_region @@ -474,7 +453,7 @@ def _repair_strategy(self): best_ts = ts if best_ts is None: - raise NotImplementedError('repair error!') + raise NotImplementedError("repair error!") assert not undo_offload_region.is_syn undo_offload_region.is_syn = True @@ -500,17 +479,13 @@ def _eval_one_choice(self): ts.execute() extra_comm_cost = max(ts.iter_end_time - self.best_ts.iter_end_time, 0) - profit = self._compute_offload_profit( - ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost) + profit = self._compute_offload_profit(ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost) return ts, profit class SolverFactory: - solvers: Dict[str, Type[Solver]] = { - 'syn': SynGreedySolver, - 'asyn': AsynGreedySolver - } + solvers: Dict[str, Type[Solver]] = {"syn": SynGreedySolver, "asyn": AsynGreedySolver} @staticmethod def create(solver_name: str) -> Type[Solver]: diff --git a/colossalai/auto_parallel/offload/training_simulator.py b/colossalai/auto_parallel/offload/training_simulator.py index de58023ec2d6..728d8daf9a46 100644 --- a/colossalai/auto_parallel/offload/training_simulator.py +++ b/colossalai/auto_parallel/offload/training_simulator.py @@ -1,7 +1,7 @@ import bisect -from typing import List, Dict -from collections import OrderedDict from abc import ABC, abstractmethod +from collections import OrderedDict +from typing import Dict, List from torch.fx.node import Node @@ -26,10 +26,7 @@ class TrainingSimulator(ABC): link_to_bw (Dict[str, Dict[float, float]]): communication links and the corresponding bandwidth. 
""" - def __init__(self, - region_list: List[Region], - comp_power: float, - link_to_bw: Dict[str, Dict[float, float]]) -> None: + def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None: self.region_list = region_list self.region_num = len(region_list) @@ -87,11 +84,7 @@ def _get_computing_overhead(self, flop: float) -> float: class SynTrainingSimulator(TrainingSimulator): - - def __init__(self, - region_list: List[Region], - comp_power: float, - link_to_bw: Dict[str, Dict[float, float]]) -> None: + def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None: super().__init__(region_list, comp_power, link_to_bw) def execute(self): @@ -115,8 +108,7 @@ def _eval_fwd_mem_per_region(self, region: Region): self.runtime_mem += region.param_size for node in region.nodes: - self.runtime_mem += calculate_fwd_tmp(node) + \ - calculate_fwd_out(node) + self.runtime_mem += calculate_fwd_tmp(node) + calculate_fwd_out(node) self.fwd_node_mem[node] = self.runtime_mem self.peak_mem = max(self.runtime_mem, self.peak_mem) self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem @@ -141,18 +133,15 @@ def _eval_bwd_mem_per_region(self, region: Region): self.runtime_mem += region.param_size for node in region.nodes.__reversed__(): - self.runtime_mem -= calculate_fwd_out(node) - self.runtime_mem += node.meta['bwd_mem_tmp'] + \ - node.meta['bwd_mem_out'] + self.runtime_mem += node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"] self.peak_mem = max(self.runtime_mem, self.peak_mem) # The memory savings of a node may be negative due to parameter prefetch. self.total_mem_saving += node.node_info.runtime_bwd_mem - self.runtime_mem self.bwd_node_mem[node] = self.runtime_mem - self.runtime_mem -= (node.meta['bwd_mem_tmp'] + - calculate_fwd_tmp(node)) + self.runtime_mem -= node.meta["bwd_mem_tmp"] + calculate_fwd_tmp(node) # free bwd_mem_out self.bwd_node_deps[node] = len(node.all_input_nodes) @@ -160,12 +149,14 @@ def _eval_bwd_mem_per_region(self, region: Region): if user_node in self.bwd_node_deps: self.bwd_node_deps[user_node] -= 1 if self.bwd_node_deps[user_node] <= 0: - self.runtime_mem -= user_node.meta['bwd_mem_out'] + self.runtime_mem -= user_node.meta["bwd_mem_out"] if self.runtime_mem < 0: - raise ValueError(f"region id: {region.r_id}, node name: {node.name}, " - f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---" - f"runtime memory computed less than 0, which is miscalculated!") + raise ValueError( + f"region id: {region.r_id}, node name: {node.name}, " + f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---" + f"runtime memory computed less than 0, which is miscalculated!" 
+ ) # release parameter and offload gradient in region if region.r_id == region.shared_rid: @@ -177,23 +168,16 @@ def _eval_bwd_mem_per_region(self, region: Region): class AsynTrainingSimulator(TrainingSimulator): - - def __init__(self, - region_list: List[Region], - comp_power: float, - link_to_bw: Dict[str, Dict[float, float]]) -> None: + def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None: super().__init__(region_list, comp_power, link_to_bw) self.iter_end_time: int = 0 # the last computation execution period - self.last_comp: ExecutionPeriod = ExecutionPeriod( - start_time=0, end_time=0) + self.last_comp: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0) # the last parameter prefetch execution period - self.last_h2d: ExecutionPeriod = ExecutionPeriod( - start_time=0, end_time=0) + self.last_h2d: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0) # the last gradient offload execution period - self.last_d2h: ExecutionPeriod = ExecutionPeriod( - start_time=0, end_time=0) + self.last_d2h: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0) # the forward computation execution period of the region self.fwd_reg_to_comp: OrderedDict[int, ExecutionPeriod] = OrderedDict() # the forward parameter prefetch execution period of the region @@ -204,10 +188,8 @@ def __init__(self, self.bwd_reg_to_pref: OrderedDict[int, ExecutionPeriod] = OrderedDict() # the gradient offload execution period of the region # which is divided into those that are waiting and those that have been released - self.bwd_reg_to_offl_waiting: OrderedDict[int, - ExecutionPeriod] = OrderedDict() - self.bwd_reg_to_offl_freed: OrderedDict[int, - ExecutionPeriod] = OrderedDict() + self.bwd_reg_to_offl_waiting: OrderedDict[int, ExecutionPeriod] = OrderedDict() + self.bwd_reg_to_offl_freed: OrderedDict[int, ExecutionPeriod] = OrderedDict() # the region buffer, which records regions that are offloaded but not released self.reg_buffer_to_free: List[int] = [] @@ -217,10 +199,8 @@ def __init__(self, # the region execution flow, # where fwd_reg_flow[i,j] denotes whether the parameters of j-th region are in the GPU # when the execution reaches the i-th region. 
- self.fwd_reg_flow = torch.zeros( - (self.region_num, self.region_num)).bool() - self.bwd_reg_flow = torch.zeros( - (self.region_num, self.region_num)).bool() + self.fwd_reg_flow = torch.zeros((self.region_num, self.region_num)).bool() + self.bwd_reg_flow = torch.zeros((self.region_num, self.region_num)).bool() def execute(self): """ @@ -232,7 +212,7 @@ def execute(self): for reg in self.region_list: if reg.param_size and reg.r_id < self.region_num - 1: - for nr in self.region_list[reg.r_id + 1:]: + for nr in self.region_list[reg.r_id + 1 :]: if nr.param_size and requires_upload_p_in_fwd(self.region_list[nr.shared_rid]): reg.fwd_prefetch_region = nr break @@ -249,8 +229,7 @@ def execute(self): self.runtime_mem -= self.region_list[reg_id].param_size self.bwd_reg_to_offl_waiting.clear() - self.iter_end_time = max( - self.last_comp.end_time, self.last_d2h.end_time) + self.iter_end_time = max(self.last_comp.end_time, self.last_d2h.end_time) def _insert_h2d_exec(self, region: Region, is_fwd: bool = True): """ @@ -258,10 +237,8 @@ def _insert_h2d_exec(self, region: Region, is_fwd: bool = True): """ pref_start_time = max(self.last_h2d.end_time, self.last_comp.end_time) - pref_end_time = pref_start_time + \ - 2.0 * self._get_communication_overhead('h2d', region.param_size) - pref_ep = ExecutionPeriod( - start_time=pref_start_time, end_time=pref_end_time) + pref_end_time = pref_start_time + 2.0 * self._get_communication_overhead("h2d", region.param_size) + pref_ep = ExecutionPeriod(start_time=pref_start_time, end_time=pref_end_time) if is_fwd: self.fwd_reg_to_pref[region.r_id] = pref_ep else: @@ -276,18 +253,16 @@ def _insert_comp_exec(self, region: Region, is_fwd: bool = True): if is_fwd: reg_to_comp = self.fwd_reg_to_comp reg_to_pref = self.fwd_reg_to_pref - flop_key = 'fwd_flop' + flop_key = "fwd_flop" else: reg_to_comp = self.bwd_reg_to_comp reg_to_pref = self.bwd_reg_to_pref - flop_key = 'bwd_flop' - comp_start_time = max(self.last_comp.end_time, reg_to_pref.get( - region.r_id, ExecutionPeriod(0, 0)).end_time) - comp_end_time = comp_start_time + \ - sum([self._get_computing_overhead(node.meta.get(flop_key, 0)) - for node in region.nodes]) - comp_ep = ExecutionPeriod( - start_time=comp_start_time, end_time=comp_end_time) + flop_key = "bwd_flop" + comp_start_time = max(self.last_comp.end_time, reg_to_pref.get(region.r_id, ExecutionPeriod(0, 0)).end_time) + comp_end_time = comp_start_time + sum( + [self._get_computing_overhead(node.meta.get(flop_key, 0)) for node in region.nodes] + ) + comp_ep = ExecutionPeriod(start_time=comp_start_time, end_time=comp_end_time) reg_to_comp[region.r_id] = comp_ep self.last_comp = comp_ep @@ -297,10 +272,8 @@ def _insert_d2h_exec(self, region: Region): """ offl_start_time = max(self.last_d2h.end_time, self.last_comp.end_time) - offl_end_time = offl_start_time + \ - self._get_communication_overhead('d2h', region.param_size) - offl_ep = ExecutionPeriod( - start_time=offl_start_time, end_time=offl_end_time) + offl_end_time = offl_start_time + self._get_communication_overhead("d2h", region.param_size) + offl_ep = ExecutionPeriod(start_time=offl_start_time, end_time=offl_end_time) self.bwd_reg_to_offl_waiting[region.r_id] = offl_ep self.last_d2h = offl_ep @@ -332,20 +305,17 @@ def _eval_fwd_mem_per_region(self, region: Region): self.fwd_reg_flow[region.r_id, region.r_id] = True else: self.fwd_reg_flow[region.r_id] = self.fwd_reg_flow[region.r_id - 1] - self.fwd_reg_flow[region.r_id, - self.reg_buffer_to_free] = False + self.fwd_reg_flow[region.r_id, 
self.reg_buffer_to_free] = False self.reg_buffer_to_free.clear() # prefetch parameters of the next region fwd_prefetch_region = region.fwd_prefetch_region if fwd_prefetch_region and requires_upload_p_in_fwd(self.region_list[fwd_prefetch_region.shared_rid]): self.runtime_mem += fwd_prefetch_region.param_size - self.fwd_reg_flow[region.r_id, - fwd_prefetch_region.r_id] = True + self.fwd_reg_flow[region.r_id, fwd_prefetch_region.r_id] = True for node in region.nodes: - self.runtime_mem += calculate_fwd_tmp(node) + \ - calculate_fwd_out(node) + self.runtime_mem += calculate_fwd_tmp(node) + calculate_fwd_out(node) self.peak_mem = max(self.runtime_mem, self.peak_mem) self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem @@ -354,8 +324,7 @@ def _eval_fwd_mem_per_region(self, region: Region): if region.need_offload: self.runtime_mem -= region.param_size - assert len( - self.reg_buffer_to_free) <= 1, f'{len(self.reg_buffer_to_free)}' + assert len(self.reg_buffer_to_free) <= 1, f"{len(self.reg_buffer_to_free)}" self.reg_buffer_to_free.append(region.r_id) def _eval_bwd_cost_per_region(self, region: Region): @@ -398,8 +367,7 @@ def _eval_bwd_mem_per_region(self, region: Region): self.bwd_reg_flow[region.r_id] = self.bwd_reg_flow[region.r_id + 1] else: self.bwd_reg_flow[region.r_id] = self.fwd_reg_flow[-1] - self.bwd_reg_flow[region.r_id, - self.reg_buffer_to_free] = False + self.bwd_reg_flow[region.r_id, self.reg_buffer_to_free] = False # free gradients in the buffer while len(self.reg_buffer_to_free): @@ -415,8 +383,7 @@ def _eval_bwd_mem_per_region(self, region: Region): bwd_prefetch_region = region.bwd_prefetch_region if bwd_prefetch_region: self.runtime_mem += bwd_prefetch_region.param_size - self.bwd_reg_flow[region.r_id, - bwd_prefetch_region.r_id] = True + self.bwd_reg_flow[region.r_id, bwd_prefetch_region.r_id] = True # add the gradient of the parameter if region.r_id < region.shared_rid: @@ -426,10 +393,8 @@ def _eval_bwd_mem_per_region(self, region: Region): self.runtime_mem += region.param_size for node in region.nodes.__reversed__(): - self.runtime_mem -= calculate_fwd_out(node) - self.runtime_mem += node.meta['bwd_mem_tmp'] + \ - node.meta['bwd_mem_out'] + self.runtime_mem += node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"] self.peak_mem = max(self.runtime_mem, self.peak_mem) # The memory savings of a node may be negative due to parameter prefetch. @@ -437,8 +402,7 @@ def _eval_bwd_mem_per_region(self, region: Region): self.bwd_node_mem[node] = self.runtime_mem - self.runtime_mem -= (node.meta['bwd_mem_tmp'] + - calculate_fwd_tmp(node)) + self.runtime_mem -= node.meta["bwd_mem_tmp"] + calculate_fwd_tmp(node) # free bwd_mem_out self.bwd_node_deps[node] = len(node.all_input_nodes) @@ -446,12 +410,14 @@ def _eval_bwd_mem_per_region(self, region: Region): if user_node in self.bwd_node_deps: self.bwd_node_deps[user_node] -= 1 if self.bwd_node_deps[user_node] <= 0: - self.runtime_mem -= user_node.meta['bwd_mem_out'] + self.runtime_mem -= user_node.meta["bwd_mem_out"] if self.runtime_mem < 0: - raise ValueError(f"region id: {region.r_id}, node name: {node.name}, " - f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---" - f"runtime memory computed less than 0, which is miscalculated!") + raise ValueError( + f"region id: {region.r_id}, node name: {node.name}, " + f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---" + f"runtime memory computed less than 0, which is miscalculated!" 
+ ) # release parameters of the region if requires_release_p_in_bwd(self.region_list[region.shared_rid]): diff --git a/colossalai/auto_parallel/offload/util.py b/colossalai/auto_parallel/offload/util.py index 6b010512cc9c..cb65da79c5a2 100644 --- a/colossalai/auto_parallel/offload/util.py +++ b/colossalai/auto_parallel/offload/util.py @@ -35,7 +35,6 @@ class NvDevicePower: class GlobalRuntimeInfo(metaclass=SingletonMeta): - def __init__(self): self.h2d_stream = torch.cuda.Stream() self.d2h_stream = torch.cuda.Stream() @@ -50,21 +49,18 @@ def compute_act_peak_mem(region_list: List[Region]) -> float: # forward for region in region_list: for node in region.nodes: - runtime_mem = runtime_mem + \ - calculate_fwd_tmp(node) + calculate_fwd_out(node) + runtime_mem = runtime_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node) act_peak_mem = max(runtime_mem, act_peak_mem) # backward bwd_deps = {} for region in region_list.__reversed__(): for node in region.nodes.__reversed__(): runtime_mem -= calculate_fwd_out(node) - runtime_mem = runtime_mem + \ - node.meta['bwd_mem_tmp'] + node.meta['bwd_mem_out'] + runtime_mem = runtime_mem + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"] act_peak_mem = max(runtime_mem, act_peak_mem) - runtime_mem = runtime_mem - \ - node.meta['bwd_mem_tmp'] - calculate_fwd_tmp(node) + runtime_mem = runtime_mem - node.meta["bwd_mem_tmp"] - calculate_fwd_tmp(node) # free bwd_mem_out bwd_deps[node] = len(node.all_input_nodes) @@ -72,7 +68,7 @@ def compute_act_peak_mem(region_list: List[Region]) -> float: if user_node in bwd_deps: bwd_deps[user_node] -= 1 if bwd_deps[user_node] <= 0: - runtime_mem -= user_node.meta['bwd_mem_out'] + runtime_mem -= user_node.meta["bwd_mem_out"] return act_peak_mem @@ -86,13 +82,15 @@ def compute_total_param_mem(region_list: List[Region]) -> float: def requires_upload_p_in_fwd(shared_reg: Region): - return (shared_reg.r_id >= shared_reg.shared_rid) or (shared_reg.r_id < shared_reg.shared_rid - and shared_reg.need_offload) + return (shared_reg.r_id >= shared_reg.shared_rid) or ( + shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload + ) def requires_release_p_in_bwd(shared_reg: Region): - return (shared_reg.r_id >= shared_reg.shared_rid) or (shared_reg.r_id < shared_reg.shared_rid - and shared_reg.need_offload) + return (shared_reg.r_id >= shared_reg.shared_rid) or ( + shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload + ) def requires_offload_g_in_bwd(region: Region): diff --git a/colossalai/auto_parallel/passes/comm_metainfo_pass.py b/colossalai/auto_parallel/passes/comm_metainfo_pass.py index ffda58e0689f..ba290ee839d8 100644 --- a/colossalai/auto_parallel/passes/comm_metainfo_pass.py +++ b/colossalai/auto_parallel/passes/comm_metainfo_pass.py @@ -14,18 +14,20 @@ shape_consistency_manager = ShapeConsistencyManager() -def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec, - target_sharding_spec: ShardingSpec) -> ShardMetaInfo: +def _construct_shard_meta_info( + node: Node, origin_sharding_spec: ShardingSpec, target_sharding_spec: ShardingSpec +) -> ShardMetaInfo: # get comm_action_sequence and total_cost from shape_consistency_manager _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency( - origin_sharding_spec, target_sharding_spec) + origin_sharding_spec, target_sharding_spec + ) meta_info = ShardMetaInfo() # NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel # get mem cost for ShardMetaInfo mem_cost = 
shape_consistency_manager.mem_cost(comm_action_sequence) # extract user that has _meta_data and extract element length - input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data')) + input_node = next(n for n in node._input_nodes if hasattr(n, "_meta_data")) element_length = input_node._meta_data.element_size() mem_cost.fwd.activation *= element_length @@ -37,9 +39,11 @@ def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec, meta_info.memory_cost = mem_cost # get computation cost for ShardMetaInfo - meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length, - total_cost['backward'] * element_length, - total_cost['total'] * element_length) + meta_info.compute_cost = TrainCycleItem( + total_cost["forward"] * element_length, + total_cost["backward"] * element_length, + total_cost["total"] * element_length, + ) # get tensor shape for ShardMetaInfo origin_sharding_spec: ShardingSpec @@ -47,9 +51,9 @@ def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec, input_shape = origin_sharding_spec.get_sharded_shape_per_device() output_shape = target_sharding_spec.get_sharded_shape_per_device() - meta_info.fwd_in = [torch.rand(input_shape, device='meta')] + meta_info.fwd_in = [torch.rand(input_shape, device="meta")] meta_info.fwd_buffer = [] - meta_info.fwd_out = [torch.rand(output_shape, device='meta')] + meta_info.fwd_out = [torch.rand(output_shape, device="meta")] return meta_info @@ -62,8 +66,10 @@ def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) - # extract node index and user node index args = node.args node_index, user_node_index = args[3], args[4] - origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][ - user_node_index] + origin_sharding_spec, target_sharding_spec = ( + origin_spec_dict[node_index], + sharding_spec_dict[node_index][user_node_index], + ) return _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec) @@ -77,37 +83,42 @@ def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> S # this case is for all_reduce, there will be no memory cost meta_info = ShardMetaInfo() meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost) - output_node = next(n for n in node.users if hasattr(n, '_meta_data')) + output_node = next(n for n in node.users if hasattr(n, "_meta_data")) element_length = output_node._meta_data.element_size() total_cost = comm_action.comm_spec.get_comm_cost() - meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length, - total_cost['backward'] * element_length, - total_cost['total'] * element_length) + meta_info.compute_cost = TrainCycleItem( + total_cost["forward"] * element_length, + total_cost["backward"] * element_length, + total_cost["total"] * element_length, + ) input_shape = output_shape = comm_action.comm_spec.sharding_spec.get_sharded_shape_per_device() - meta_info.fwd_in = [torch.rand(input_shape, device='meta')] + meta_info.fwd_in = [torch.rand(input_shape, device="meta")] meta_info.fwd_buffer = [] - meta_info.fwd_out = [torch.rand(output_shape, device='meta')] + meta_info.fwd_out = [torch.rand(output_shape, device="meta")] else: # this case will be handled by shape consistency manager - origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], comm_action.comm_spec[ - 'tgt_spec'] + origin_sharding_spec, target_sharding_spec = ( + comm_action.comm_spec["src_spec"], + 
comm_action.comm_spec["tgt_spec"], + ) meta_info = _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec) return meta_info -def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict, - comm_actions_dict: Dict) -> GraphModule: +def comm_metainfo_pass( + gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict, comm_actions_dict: Dict +) -> GraphModule: """ The method manages all the metainfo of the communication node (run_time_apply, runtime_comm_spec_apply) in the graph. """ for node in gm.graph.nodes: if node.target == runtime_apply: - setattr(node, 'best_strategy_info', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict)) + setattr(node, "best_strategy_info", _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict)) elif node.target == runtime_comm_spec_apply: - setattr(node, 'best_strategy_info', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict)) + setattr(node, "best_strategy_info", _runtime_comm_spec_apply_meta_info(node, comm_actions_dict)) else: pass return gm diff --git a/colossalai/auto_parallel/passes/meta_info_prop.py b/colossalai/auto_parallel/passes/meta_info_prop.py index 0673b767de7b..9b000549de6c 100644 --- a/colossalai/auto_parallel/passes/meta_info_prop.py +++ b/colossalai/auto_parallel/passes/meta_info_prop.py @@ -21,16 +21,15 @@ def _normalize_tuple(x): @compatibility(is_backward_compatible=False) class MetaInfoProp: - def __init__(self, module: GraphModule) -> None: self.module = module self.func_dict = { - 'placeholder': self.placeholder_handler, - 'get_attr': self.get_attr_handler, - 'output': self.output_handler, - 'call_function': self.node_handler, - 'call_module': self.node_handler, - 'call_method': self.node_handler, + "placeholder": self.placeholder_handler, + "get_attr": self.get_attr_handler, + "output": self.output_handler, + "call_function": self.node_handler, + "call_module": self.node_handler, + "call_method": self.node_handler, } def _set_data_ptr(self, x): @@ -46,7 +45,7 @@ def _is_inplace(self, node: Node): """ Check if the node is inplace operation. """ - if node.op == 'call_module': + if node.op == "call_module": return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD elif node.op == "call_function": return node.target in OUTPUT_SAVED_OPS @@ -66,7 +65,7 @@ def placeholder_handler(self, node: Node) -> None: Handle the placeholder node. 
""" graph_info = GraphInfo() - out = _normalize_tuple(getattr(node, '_meta_data', None)) + out = _normalize_tuple(getattr(node, "_meta_data", None)) graph_info.fwd_out = list(out) if out[0] is not None else [] node.meta = {**asdict(graph_info)} @@ -96,7 +95,7 @@ def node_handler(self, node: Node) -> None: """ Handle other kind of nodes """ - assert hasattr(node, 'best_strategy_info'), f"Cannot find best_strategy_info in node {node}, {node.op}" + assert hasattr(node, "best_strategy_info"), f"Cannot find best_strategy_info in node {node}, {node.op}" graph_info = GraphInfo() meta_info = node.best_strategy_info meta_info: ShardMetaInfo @@ -126,7 +125,8 @@ def node_handler(self, node: Node) -> None: for tensor in par.meta.get("fwd_out", []): tensor: torch.Tensor target_input_tensor = next( - (x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None) + (x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None + ) if target_input_tensor is not None: target_input_tensor.data_ptr = tensor.data_ptr diff --git a/colossalai/auto_parallel/passes/runtime_apply_pass.py b/colossalai/auto_parallel/passes/runtime_apply_pass.py index 2049a06187d2..27afe72c0db8 100644 --- a/colossalai/auto_parallel/passes/runtime_apply_pass.py +++ b/colossalai/auto_parallel/passes/runtime_apply_pass.py @@ -1,18 +1,10 @@ -from copy import deepcopy from typing import Dict, List import torch from torch.fx.node import Node from colossalai._analyzer.fx.node_util import MetaInfo -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - CommAction, - CommType, - OperationData, - OperationDataType, - TrainCycleItem, -) -from colossalai.device.device_mesh import DeviceMesh +from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommType, OperationDataType from colossalai.tensor.comm_spec import CommSpec from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.sharding_spec import ShardingSpec @@ -30,19 +22,22 @@ def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: i return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec) -def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict: Dict, node_index: int, - user_node_index: int): +def runtime_apply_for_iterable_object( + node: Node, origin_dict: Dict, input_dict: Dict, node_index: int, user_node_index: int +): """ This method will be invoked during runtime to do the shape consistency, which makes sure the activations in type of tuple or list is converted into the user node expected form. 
""" rst = [] - for index, (origin_sharding_spec, - target_sharding_spec) in enumerate(zip(origin_dict[node_index], - input_dict[node_index][user_node_index])): + for index, (origin_sharding_spec, target_sharding_spec) in enumerate( + zip(origin_dict[node_index], input_dict[node_index][user_node_index]) + ): rst.append( - shape_consistency_manager.apply_for_autoparallel_runtime(node[index], origin_sharding_spec, - target_sharding_spec)) + shape_consistency_manager.apply_for_autoparallel_runtime( + node[index], origin_sharding_spec, target_sharding_spec + ) + ) rst = type(node)(rst) return rst @@ -55,8 +50,8 @@ def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_ if isinstance(comm_action.comm_spec, CommSpec): rst = comm_action.comm_spec.covert_spec_to_action(tensor) else: - origin_sharding_spec = comm_action.comm_spec['src_spec'] - tgt_sharding_spec = comm_action.comm_spec['tgt_spec'] + origin_sharding_spec = comm_action.comm_spec["src_spec"] + tgt_sharding_spec = comm_action.comm_spec["tgt_spec"] rst = shape_consistency_manager.apply_for_autoparallel_runtime(tensor, origin_sharding_spec, tgt_sharding_spec) return rst @@ -70,16 +65,16 @@ def _preprocess_graph(nodes: List[Node]): node_to_index_dict = {} index = 0 for node in nodes: - if node.target == 'sharding_spec_convert_dict': + if node.target == "sharding_spec_convert_dict": input_dict_node = node continue - if node.target == 'origin_node_sharding_spec_dict': + if node.target == "origin_node_sharding_spec_dict": origin_dict_node = node continue - if node.target == 'comm_actions_dict': + if node.target == "comm_actions_dict": comm_actions_dict_node = node continue - if not hasattr(node, 'best_strategy'): + if not hasattr(node, "best_strategy"): continue node_to_index_dict[node] = index index += 1 @@ -97,41 +92,46 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule): input_dict_node, origin_dict_node, _, node_to_index_dict = _preprocess_graph(nodes) for node in nodes: - if not hasattr(node, 'best_strategy') or node.op == 'output': + if not hasattr(node, "best_strategy") or node.op == "output": continue for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes): if isinstance(node.sharding_spec, (list, tuple)): assert isinstance( - node.target_sharding_specs, - (list, - tuple)), 'target sharding specs should be tuple or list when node.sharding_spec is tuple or list' + node.target_sharding_specs, (list, tuple) + ), "target sharding specs should be tuple or list when node.sharding_spec is tuple or list" total_difference = 0 - for sharding_spec, target_sharding_spec in zip(node.sharding_spec, - node.target_sharding_specs[user_node_index]): + for sharding_spec, target_sharding_spec in zip( + node.sharding_spec, node.target_sharding_specs[user_node_index] + ): total_difference += sharding_spec.sharding_sequence_difference(target_sharding_spec) if total_difference == 0: continue with mod_graph.inserting_before(user_node): - shape_consistency_node = mod_graph.create_node('call_function', - runtime_apply_for_iterable_object, - args=(node, origin_dict_node, input_dict_node, - node_to_index_dict[node], user_node_index)) + shape_consistency_node = mod_graph.create_node( + "call_function", + runtime_apply_for_iterable_object, + args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index), + ) else: - assert isinstance(node.sharding_spec, - ShardingSpec), 'node.sharding_spec should be type of ShardingSpec, tuple or list.' 
+ assert isinstance( + node.sharding_spec, ShardingSpec + ), "node.sharding_spec should be type of ShardingSpec, tuple or list." if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0: continue with mod_graph.inserting_before(user_node): - shape_consistency_node = mod_graph.create_node('call_function', - runtime_apply, - args=(node, origin_dict_node, input_dict_node, - node_to_index_dict[node], user_node_index)) - if hasattr(user_node.meta['info'], 'activation_checkpoint'): - MetaInfo(shape_consistency_node, - mod_dir=user_node.meta['info'].mod_dir, - activation_checkpoint=tuple(user_node.meta['info'].activation_checkpoint)) + shape_consistency_node = mod_graph.create_node( + "call_function", + runtime_apply, + args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index), + ) + if hasattr(user_node.meta["info"], "activation_checkpoint"): + MetaInfo( + shape_consistency_node, + mod_dir=user_node.meta["info"].mod_dir, + activation_checkpoint=tuple(user_node.meta["info"].activation_checkpoint), + ) new_args = list(user_node.args) new_kwargs = dict(user_node.kwargs) # the origin node may be a positional argument or key word argument of user node @@ -158,12 +158,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule): _, _, comm_actions_dict_node, node_to_index_dict = _preprocess_graph(nodes) for node in nodes: - if not hasattr(node, 'best_strategy') or node.op == 'output': + if not hasattr(node, "best_strategy") or node.op == "output": continue comm_actions = node.best_strategy.communication_actions for op_data, comm_action in comm_actions.items(): - if comm_action.comm_type == CommType.HOOK: continue if comm_action.comm_type == CommType.BEFORE: @@ -174,10 +173,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule): else: comm_object = node.args[comm_action.arg_index] with mod_graph.inserting_before(node): - comm_spec_apply_node = mod_graph.create_node('call_function', - runtime_comm_spec_apply, - args=(comm_object, comm_actions_dict_node, - node_to_index_dict[node], op_data.name)) + comm_spec_apply_node = mod_graph.create_node( + "call_function", + runtime_comm_spec_apply, + args=(comm_object, comm_actions_dict_node, node_to_index_dict[node], op_data.name), + ) # the origin node may be a positional argument or key word argument of user node if comm_action.key_for_kwarg is not None: # substitute the origin node with comm_spec_apply_node @@ -192,10 +192,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule): elif comm_action.comm_type == CommType.AFTER: with mod_graph.inserting_after(node): - comm_spec_apply_node = mod_graph.create_node('call_function', - runtime_comm_spec_apply, - args=(node, comm_actions_dict_node, - node_to_index_dict[node], op_data.name)) + comm_spec_apply_node = mod_graph.create_node( + "call_function", + runtime_comm_spec_apply, + args=(node, comm_actions_dict_node, node_to_index_dict[node], op_data.name), + ) user_list = list(node.users.keys()) for user in user_list: if user == comm_spec_apply_node: @@ -211,10 +212,12 @@ def _comm_spec_apply(gm: torch.fx.GraphModule): # substitute the origin node with comm_spec_apply_node new_kwargs[str(node)] = comm_spec_apply_node user.kwargs = new_kwargs - if hasattr(node.meta['info'], 'activation_checkpoint'): - MetaInfo(comm_spec_apply_node, - mod_dir=node.meta['info'].mod_dir, - activation_checkpoint=tuple(node.meta['info'].activation_checkpoint)) + if hasattr(node.meta["info"], "activation_checkpoint"): + MetaInfo( + comm_spec_apply_node, + 
mod_dir=node.meta["info"].mod_dir, + activation_checkpoint=tuple(node.meta["info"].activation_checkpoint), + ) return gm @@ -227,21 +230,21 @@ def _act_annotation_pass(gm: torch.fx.GraphModule): nodes = tuple(mod_graph.nodes) for node in nodes: - if not hasattr(node.meta, 'activation_checkpoint'): - from .runtime_preparation_pass import size_processing + if not hasattr(node.meta, "activation_checkpoint"): + pass user_act_annotation = -1 input_act_annotation = -1 for user_node in node.users.keys(): - if 'activation_checkpoint' in user_node.meta: - user_act_annotation = user_node.meta['activation_checkpoint'] + if "activation_checkpoint" in user_node.meta: + user_act_annotation = user_node.meta["activation_checkpoint"] break for input_node in node._input_nodes.keys(): - if 'activation_checkpoint' in input_node.meta: - input_act_annotation = input_node.meta['activation_checkpoint'] + if "activation_checkpoint" in input_node.meta: + input_act_annotation = input_node.meta["activation_checkpoint"] break if user_act_annotation == input_act_annotation and user_act_annotation != -1: - node.meta['activation_checkpoint'] = user_act_annotation + node.meta["activation_checkpoint"] = user_act_annotation return gm diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py index 0ed0742ee57e..65c3d8e0cbeb 100644 --- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py +++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py @@ -1,19 +1,12 @@ import operator -from copy import deepcopy from typing import Dict, List, Union import torch -from torch.fx import symbolic_trace from torch.fx.node import Node from colossalai._analyzer.fx.node_util import MetaInfo from colossalai.auto_parallel.tensor_shard.constants import RESHAPE_FUNC_OP -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - CommAction, - CommType, - OperationDataType, - ShardingStrategy, -) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommType, OperationDataType from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.comm_spec import _all_reduce @@ -25,11 +18,13 @@ shape_consistency_manager = ShapeConsistencyManager() -def size_processing(size: Union[int, torch.Size], - dim_partition_dict: Dict[int, List[int]], - device_mesh_info: Dict[int, int], - target_dim: int = None, - node_name: str = None): +def size_processing( + size: Union[int, torch.Size], + dim_partition_dict: Dict[int, List[int]], + device_mesh_info: Dict[int, int], + target_dim: int = None, + node_name: str = None, +): """ This method will be invoked during runtime to convert size node value depending on distributed information. """ @@ -54,8 +49,9 @@ def size_processing(size: Union[int, torch.Size], return size -def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int], - strategies_constructor: StrategiesConstructor): +def solution_annotation_pass( + gm: torch.fx.GraphModule, solution: List[int], strategies_constructor: StrategiesConstructor +): """ This method is used to stick the solution strategy to the nodes and add the information required in runtime into graph as placeholder nodes. 
@@ -70,14 +66,15 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int], for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)): strategies_vector = node.strategies_vector # stick the solution strategy to the corresponding node - setattr(node, 'best_strategy', strategies_vector[strategy_index]) - setattr(node, 'sharding_spec', strategies_vector[strategy_index].get_sharding_spec_by_name(str(node))) + setattr(node, "best_strategy", strategies_vector[strategy_index]) + setattr(node, "sharding_spec", strategies_vector[strategy_index].get_sharding_spec_by_name(str(node))) origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name( - str(node)) + str(node) + ) # attach the corresponding metainfo if node has the attribute `strategies_info` - if hasattr(node, 'strategies_info'): - setattr(node, 'best_strategy_info', node.strategies_info[strategy_index]) + if hasattr(node, "strategies_info"): + setattr(node, "best_strategy_info", node.strategies_info[strategy_index]) # the dict to get input sharding specs of user node sharding_spec_convert_dict = {} @@ -92,15 +89,15 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int], target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name)) target_sharding_specs.append(target_sharding_spec) sharding_spec_convert_dict[index] = target_sharding_specs - setattr(node, 'target_sharding_specs', target_sharding_specs) + setattr(node, "target_sharding_specs", target_sharding_specs) # the get_attr node strategy is kind of pending strategy, which means we will change it # to the same strategy of the user node. - if node.op == 'get_attr': - assert len(target_sharding_specs) == 1, f'sharing weight is not supported in current version.' + if node.op == "get_attr": + assert len(target_sharding_specs) == 1, f"sharing weight is not supported in current version." target_node = node.strategies_vector.successor_nodes[0] node_name = str(node) - if target_node.op == 'call_function' and target_node.target in RESHAPE_FUNC_OP: + if target_node.op == "call_function" and target_node.target in RESHAPE_FUNC_OP: node_name = str(target_node) target_node = target_node.strategies_vector.successor_nodes[0] user_strategy = target_node.best_strategy @@ -122,11 +119,11 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int], # add above dicts into graph for node in nodes: - if node.op != 'placeholder': + if node.op != "placeholder": with mod_graph.inserting_before(node): - input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict') - origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict') - comm_actions_dict_node = mod_graph.create_node('placeholder', target='comm_actions_dict') + input_specs_node = mod_graph.create_node("placeholder", target="sharding_spec_convert_dict") + origin_specs_node = mod_graph.create_node("placeholder", target="origin_node_sharding_spec_dict") + comm_actions_dict_node = mod_graph.create_node("placeholder", target="comm_actions_dict") break return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict @@ -148,7 +145,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh device_mesh_info[dim] = dim_size def _extract_target_dim(node): - ''' + """ A helper function to extract the target dimension from size node. There are two usages of torch.Tensor.size: 1. 
tensor.size() @@ -156,7 +153,7 @@ def _extract_target_dim(node): If a target_dim is assigned, then the output will be in type of int, instead of torch.Size. Otherwise, the output will be in type of torch.Size and this function will return None. - ''' + """ target_dim = None if len(node.args) > 1: target_dim = node.args[1] @@ -165,19 +162,21 @@ def _extract_target_dim(node): return target_dim def _post_processing(node, size_processing_node): - ''' + """ This function is used to process the dependency between the size node and its users after inserting the size_process_node. - ''' + """ # store original node and processing node pair in node_pairs dictionary # It will be used to replace the original node with processing node in slice object node_pairs[node] = size_processing_node size_processing_node._meta_data = node._meta_data - if hasattr(node.meta['info'], 'activation_checkpoint'): - MetaInfo(size_processing_node, - mod_dir=node.meta['info'].mod_dir, - activation_checkpoint=tuple(node.meta['info'].activation_checkpoint)) + if hasattr(node.meta["info"], "activation_checkpoint"): + MetaInfo( + size_processing_node, + mod_dir=node.meta["info"].mod_dir, + activation_checkpoint=tuple(node.meta["info"].activation_checkpoint), + ) user_list = list(node.users.keys()) for user in user_list: @@ -196,10 +195,10 @@ def _post_processing(node, size_processing_node): user.kwargs = new_kwargs def _update_slice_object_args(slice_object): - ''' + """ This function is used to update the slice object argument list. If the slice object contains the Node argument, then the size node will be replaced with - ''' + """ if isinstance(slice_object, slice): start = slice_object.start stop = slice_object.stop @@ -220,8 +219,7 @@ def _update_slice_object_args(slice_object): raise RuntimeError(f"Unsupported slice object type: {type(slice_object)}") for node in nodes: - - if node.op == 'call_method' and node.target == 'size': + if node.op == "call_method" and node.target == "size": # extract useful information from size node # dim_partition_dict will instruct the size value on which # dimension should be enlarged. @@ -232,14 +230,14 @@ def _update_slice_object_args(slice_object): # insert size_processing node with mod_graph.inserting_after(node): - size_processing_node = mod_graph.create_node('call_function', - size_processing, - args=(node, dim_partition_dict, device_mesh_info, - target_dim, node.name)) + size_processing_node = mod_graph.create_node( + "call_function", + size_processing, + args=(node, dim_partition_dict, device_mesh_info, target_dim, node.name), + ) _post_processing(node, size_processing_node) - if node.op == 'call_function' and node.target == operator.getitem: - + if node.op == "call_function" and node.target == operator.getitem: getitem_index = node.args[1] # slice object is quite special in torch.fx graph, # On one side, we treat slice object same as type of int, @@ -287,18 +285,19 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh) nodes = tuple(mod_graph.nodes) def _extract_info_from_sharding_spec(sharding_spec): - ''' + """ This function is used to extract the dim_partition_dict and device_mesh from sharding spec instance or a list of sharding spec. 
- ''' + """ if isinstance(sharding_spec, ShardingSpec): dim_partition_dict = sharding_spec.dim_partition_dict device_mesh = sharding_spec.device_mesh return dim_partition_dict, device_mesh if sharding_spec is None: return None, None - assert isinstance(sharding_spec, - (tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None' + assert isinstance( + sharding_spec, (tuple, list) + ), "sharding_spec should be type of ShardingSpec, tuple, list or None" device_mesh = sharding_spec[0].device_mesh dim_partition_dict = [] @@ -322,8 +321,9 @@ def _process_node_arguments(node): else: new_args.append(arg) else: - assert isinstance(arg, - (int, tuple, list)), 'The argument in view node should be either type of Node or int.' + assert isinstance( + arg, (int, tuple, list) + ), "The argument in view node should be either type of Node or int." if isinstance(arg, (tuple, list)): new_args.extend(arg) else: @@ -332,7 +332,7 @@ def _process_node_arguments(node): def _scale_args_adapt_sharding_spec(dim_partition_dict, device_mesh, node): new_args = _process_node_arguments(node) - if node.op == 'call_method': + if node.op == "call_method": args_to_process = list(new_args[1:]) else: args_to_process = list(new_args) @@ -350,7 +350,7 @@ def _scale_args_adapt_sharding_spec(dim_partition_dict, device_mesh, node): args_to_process = tuple(args_to_process) - if node.op == 'call_method': + if node.op == "call_method": new_args = (new_args[0],) + args_to_process else: new_args = args_to_process @@ -358,9 +358,9 @@ def _scale_args_adapt_sharding_spec(dim_partition_dict, device_mesh, node): node.args = new_args def _filter_node_with_shape_args(node): - if node.op == 'call_method': + if node.op == "call_method": target = getattr(node.args[0]._meta_data.__class__, node.target) - elif node.op == 'call_function': + elif node.op == "call_function": target = node.target else: target = None @@ -371,7 +371,7 @@ def _filter_node_with_shape_args(node): for node in nodes: # skip the placeholder node added in _solution_annotation pass - if not hasattr(node, 'sharding_spec'): + if not hasattr(node, "sharding_spec"): continue output_dim_partition_dict, device_mesh = _extract_info_from_sharding_spec(node.sharding_spec) @@ -392,15 +392,21 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes reduction_stream = torch.cuda.Stream() def _add_hook_for_grad_communication(node, param, name=None): - comm_actions = node.best_strategy.communication_actions def _filter_param_to_hook(node, op_data, comm_action, name): - - if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == name and comm_action.comm_type == CommType.HOOK: + if ( + node.op == "call_module" + and op_data.type == OperationDataType.PARAM + and op_data.name == name + and comm_action.comm_type == CommType.HOOK + ): return True - if node.op == 'get_attr' and isinstance( - node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK: + if ( + node.op == "get_attr" + and isinstance(node._meta_data, torch.nn.parameter.Parameter) + and comm_action.comm_type == CommType.HOOK + ): return True return False @@ -410,7 +416,6 @@ def _filter_param_to_hook(node, op_data, comm_action, name): if _filter_param_to_hook(node, operation_data, comm_action, name=name): def wrapper(param, comm_spec, stream, overlap): - def hook_fn(grad): if overlap: with torch.cuda.stream(stream): @@ -426,22 +431,26 @@ def _shard_param(param, target_sharding_spec): # apply the sharding spec of parameters 
if target_sharding_spec.dim_partition_dict != {}: origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {}) - setattr(param, 'sharding_spec', origin_sharding_spec) + setattr(param, "sharding_spec", origin_sharding_spec) # TODO: build a ColoParameter class to manager the distributed parameters # we could use .data here, because all the operations just happen before the real training # loop, so we don't need to track these operations in the autograd graph. param = torch.nn.Parameter( - shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec, - target_sharding_spec).detach().clone()) + shape_consistency_manager.apply_for_autoparallel_runtime( + param.data, param.sharding_spec, target_sharding_spec + ) + .detach() + .clone() + ) return param for node in nodes: - if node.op == 'call_module': + if node.op == "call_module": target_module = node.graph.owning_module.get_submodule(node.target) # TODO: we need to do more actions to take care of the shared parameters. - if hasattr(target_module, 'processed') and target_module.processed: + if hasattr(target_module, "processed") and target_module.processed: continue - setattr(target_module, 'processed', True) + setattr(target_module, "processed", True) for name, param in target_module.named_parameters(): target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name) param = _shard_param(param, target_sharding_spec) @@ -453,7 +462,7 @@ def _shard_param(param, target_sharding_spec): # apply the sharding spec of buffers for name, buffer in target_module.named_buffers(): origin_sharding_spec = ShardingSpec(device_mesh, buffer.shape, {}) - setattr(buffer, 'sharding_spec', origin_sharding_spec) + setattr(buffer, "sharding_spec", origin_sharding_spec) target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name) buffer_sharded = shape_consistency_manager.apply(buffer, target_sharding_spec) sharded_buffer_dict[name] = buffer_sharded @@ -461,7 +470,7 @@ def _shard_param(param, target_sharding_spec): for name, buffer_sharded in sharded_buffer_dict.items(): setattr(target_module, name, buffer_sharded.detach().clone()) - if node.op == 'get_attr': + if node.op == "get_attr": root = node.graph.owning_module atoms = node.target.split(".") attr_len = len(atoms) @@ -488,16 +497,18 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule): """ replace the origin kernel into kernel with implicit communication inside. """ - pass -def runtime_preparation_pass(gm: torch.fx.GraphModule, - solution: List[int], - device_mesh: DeviceMesh, - strategies_constructor: StrategiesConstructor, - overlap=False): +def runtime_preparation_pass( + gm: torch.fx.GraphModule, + solution: List[int], + device_mesh: DeviceMesh, + strategies_constructor: StrategiesConstructor, + overlap=False, +): gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotation_pass( - gm, solution, strategies_constructor) + gm, solution, strategies_constructor + ) gm = size_value_converting_pass(gm, device_mesh) gm = node_args_converting_pass(gm, device_mesh) # TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed. 
diff --git a/colossalai/auto_parallel/tensor_shard/constants.py b/colossalai/auto_parallel/tensor_shard/constants.py index 99c124934060..e9c2c8664a61 100644 --- a/colossalai/auto_parallel/tensor_shard/constants.py +++ b/colossalai/auto_parallel/tensor_shard/constants.py @@ -3,9 +3,22 @@ import torch __all__ = [ - 'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP', - 'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP', - 'EMBEDDING_MODULE_OP', 'LAYERNORM_MODULE_OP', 'ELEMENTWISE_METHOD_OP', 'RESHAPE_METHOD_OP', 'INFINITY_COST' + "ELEMENTWISE_MODULE_OP", + "ELEMENTWISE_FUNC_OP", + "RESHAPE_FUNC_OP", + "CONV_MODULE_OP", + "CONV_FUNC_OP", + "LINEAR_MODULE_OP", + "LINEAR_FUNC_OP", + "BATCHNORM_MODULE_OP", + "POOL_MODULE_OP", + "NON_PARAM_FUNC_OP", + "BCAST_FUNC_OP", + "EMBEDDING_MODULE_OP", + "LAYERNORM_MODULE_OP", + "ELEMENTWISE_METHOD_OP", + "RESHAPE_METHOD_OP", + "INFINITY_COST", ] ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU] @@ -18,13 +31,13 @@ torch.nn.functional.relu, torch.nn.functional.dropout, # softmax should not be here - torch.nn.functional.softmax + torch.nn.functional.softmax, ] ELEMENTWISE_METHOD_OP = [ torch.Tensor.to, torch.Tensor.type, # TODO: contiguous maybe need some extra processes. - torch.Tensor.contiguous + torch.Tensor.contiguous, ] RESHAPE_FUNC_OP = [ torch.flatten, @@ -42,15 +55,36 @@ torch.Tensor.transpose, ] BCAST_FUNC_OP = [ - torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub, - operator.mul, operator.floordiv, operator.truediv, torch.matmul, operator.pow, torch.pow + torch.add, + torch.sub, + torch.mul, + torch.div, + torch.floor_divide, + torch.true_divide, + operator.add, + operator.sub, + operator.mul, + operator.floordiv, + operator.truediv, + torch.matmul, + operator.pow, + torch.pow, ] CONV_MODULE_OP = [ - torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, - torch.nn.ConvTranspose3d + torch.nn.Conv1d, + torch.nn.Conv2d, + torch.nn.Conv3d, + torch.nn.ConvTranspose1d, + torch.nn.ConvTranspose2d, + torch.nn.ConvTranspose3d, ] CONV_FUNC_OP = [ - torch.conv1d, torch.conv2d, torch.conv3d, torch.conv_transpose1d, torch.conv_transpose2d, torch.conv_transpose3d + torch.conv1d, + torch.conv2d, + torch.conv3d, + torch.conv_transpose1d, + torch.conv_transpose2d, + torch.conv_transpose3d, ] EMBEDDING_MODULE_OP = [torch.nn.modules.sparse.Embedding] LINEAR_MODULE_OP = [torch.nn.Linear] @@ -85,7 +119,7 @@ operator.floordiv, operator.truediv, # softmax should not be here - torch.nn.functional.softmax + torch.nn.functional.softmax, ] INFINITY_COST = 1e13 diff --git a/colossalai/auto_parallel/tensor_shard/initialize.py b/colossalai/auto_parallel/tensor_shard/initialize.py index b406ca6fb7e0..d82f0ef53f66 100644 --- a/colossalai/auto_parallel/tensor_shard/initialize.py +++ b/colossalai/auto_parallel/tensor_shard/initialize.py @@ -3,7 +3,6 @@ import torch import torch.distributed as dist import torch.nn as nn -from torch.fx import GraphModule from torch.fx.graph import Graph from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen @@ -14,27 +13,32 @@ from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass from colossalai.auto_parallel.tensor_shard.options import DataloaderOption, ShardOption, SolverOptions, SolverPerference from colossalai.auto_parallel.tensor_shard.sharding_strategy import 
CommAction -from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor +from colossalai.auto_parallel.tensor_shard.solver import CostGraph, Solver, StrategiesConstructor from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.sharding_spec import ShardingSpec class ModuleWrapper(nn.Module): - ''' + """ This class is used to wrap the original module, and add the sharding_spec_dict, origin_spec_dict, comm_actions_dict into the forward function. - ''' - - def __init__(self, module: ColoGraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]], - origin_spec_dict: Dict[int, ShardingSpec], comm_actions_dict: Dict[int, Dict[str, CommAction]]): - ''' + """ + + def __init__( + self, + module: ColoGraphModule, + sharding_spec_dict: Dict[int, List[ShardingSpec]], + origin_spec_dict: Dict[int, ShardingSpec], + comm_actions_dict: Dict[int, Dict[str, CommAction]], + ): + """ Args: module: the original module sharding_spec_dict: The sharding_spec_dict is used to record the target sharding specs of each tensor required in user node. origin_spec_dict: The origin_spec_dict is used to record the original sharding spec of each tensor. comm_actions_dict: The comm_actions_dict is used to record the communication actions of each tensor. - ''' + """ super(ModuleWrapper, self).__init__() self.module = module self.sharding_spec_dict = sharding_spec_dict @@ -42,67 +46,68 @@ def __init__(self, module: ColoGraphModule, sharding_spec_dict: Dict[int, List[S self.comm_actions_dict = comm_actions_dict def forward(self, *args, **kwargs): - return self.module(*args, - sharding_spec_convert_dict=self.sharding_spec_dict, - origin_node_sharding_spec_dict=self.origin_spec_dict, - comm_actions_dict=self.comm_actions_dict, - **kwargs) + return self.module( + *args, + sharding_spec_convert_dict=self.sharding_spec_dict, + origin_node_sharding_spec_dict=self.origin_spec_dict, + comm_actions_dict=self.comm_actions_dict, + **kwargs, + ) def extract_meta_args_from_dataloader(data_loader: torch.utils.data.DataLoader, data_process_func: callable): - ''' + """ This method is used to extract the meta_args from the dataloader under the instruction of the data_process_func. - ''' + """ # TODO: implement this function - pass def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[float]], logical_mesh_shape: Tuple[int]): - ''' + """ This method is used to extract the mesh_alpha and mesh_beta for the given logical_mesh_shape from the alpha_beta_dict. These two values will be used to estimate the communication cost. - ''' + """ # TODO: implement this function - pass -def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh, solver_preference: str, dataloader_option: str, - shard_option: str): - ''' +def build_strategy_constructor( + graph: Graph, device_mesh: DeviceMesh, solver_preference: str, dataloader_option: str, shard_option: str +): + """ This method is used to build the strategy_constructor for the given graph. After this method, each node in the graph will have a strategies_vector which is constructed by the related node handler. 
- ''' - if solver_preference == 'standard': + """ + if solver_preference == "standard": solver_preference = SolverPerference.STANDARD - elif solver_preference == 'tp': + elif solver_preference == "tp": solver_preference = SolverPerference.TP - elif solver_preference == 'dp': + elif solver_preference == "dp": solver_preference = SolverPerference.DP else: - raise ValueError(f'Invalid solver_preference: {solver_preference}') + raise ValueError(f"Invalid solver_preference: {solver_preference}") - if dataloader_option == 'replicated': + if dataloader_option == "replicated": dataloader_option = DataloaderOption.REPLICATED - elif dataloader_option == 'distributed': + elif dataloader_option == "distributed": dataloader_option = DataloaderOption.DISTRIBUTED else: - raise ValueError(f'Invalid dataloader_option: {dataloader_option}') + raise ValueError(f"Invalid dataloader_option: {dataloader_option}") - if shard_option == 'standard': + if shard_option == "standard": shard_option = ShardOption.STANDARD - elif shard_option == 'shard': + elif shard_option == "shard": shard_option = ShardOption.SHARD - elif shard_option == 'shard_last_axis': + elif shard_option == "shard_last_axis": shard_option = ShardOption.SHARD_LAST_AXIS - elif shard_option == 'full_shard': + elif shard_option == "full_shard": shard_option = ShardOption.FULL_SHARD else: - raise ValueError(f'Invalid shard_option: {shard_option}') + raise ValueError(f"Invalid shard_option: {shard_option}") - solver_options = SolverOptions(solver_perference=solver_preference, - dataloader_option=dataloader_option, - shard_option=shard_option) + solver_options = SolverOptions( + solver_perference=solver_preference, dataloader_option=dataloader_option, shard_option=shard_option + ) strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) strategies_constructor.build_strategies_and_cost() @@ -110,10 +115,10 @@ def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh, solver_pre def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0): - ''' + """ This method is used to solve the best solution for the given graph. The solution is a list of integers, each integer represents the best strategy index of the corresponding node. - ''' + """ # temporarily we use all nodes as liveness list, we count the backward memory cost together with # forward memory cost into the node memory cost, and no activation checkpoint is used in this phase. # graph_analyser = GraphAnalyser(gm) @@ -127,23 +132,23 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstruc return solution -def transform_to_sharded_model(gm: ColoGraphModule, - meta_args: Dict, - solution: List[int], - device_mesh: DeviceMesh, - strategies_constructor: StrategiesConstructor, - overlap: bool = False): - ''' +def transform_to_sharded_model( + gm: ColoGraphModule, + meta_args: Dict, + solution: List[int], + device_mesh: DeviceMesh, + strategies_constructor: StrategiesConstructor, + overlap: bool = False, +): + """ This method is used to transform the original graph to the sharded graph. The model parameters will be sharded according to the solution and the grad hooks will be added to the sharded graph using the runtime_preparation_pass. The communication node will be added into the graph using the runtime_apply_pass. 
- ''' - gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, - solution, - device_mesh, - strategies_constructor, - overlap=overlap) + """ + gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass( + gm, solution, device_mesh, strategies_constructor, overlap=overlap + ) gm = runtime_apply_pass(gm) shape_prop_pass(gm, *meta_args.values(), sharding_spec_dict, origin_spec_dict, comm_actions_dict) gm.recompile() @@ -152,12 +157,14 @@ def transform_to_sharded_model(gm: ColoGraphModule, return gm, sharding_spec_dicts -def initialize_device_mesh(world_size: int = -1, - physical_devices: List[int] = None, - alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None, - logical_mesh_shape: Tuple[int] = None, - logical_mesh_id: torch.Tensor = None): - ''' +def initialize_device_mesh( + world_size: int = -1, + physical_devices: List[int] = None, + alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None, + logical_mesh_shape: Tuple[int] = None, + logical_mesh_id: torch.Tensor = None, +): + """ This method is used to initialize the device mesh. Args: @@ -170,7 +177,7 @@ def initialize_device_mesh(world_size: int = -1, logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical mesh shape. logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id. - ''' + """ # if world_size is not set, use the world size from torch.distributed if world_size == -1: world_size = dist.get_world_size() @@ -201,27 +208,31 @@ def initialize_device_mesh(world_size: int = -1, # extract alpha and beta values for the chosen logical mesh shape mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_id) - device_mesh = DeviceMesh(physical_mesh_id=physical_mesh, - logical_mesh_id=logical_mesh_id, - mesh_alpha=mesh_alpha, - mesh_beta=mesh_beta, - init_process_group=True) + device_mesh = DeviceMesh( + physical_mesh_id=physical_mesh, + logical_mesh_id=logical_mesh_id, + mesh_alpha=mesh_alpha, + mesh_beta=mesh_beta, + init_process_group=True, + ) return device_mesh -def initialize_model(model: nn.Module, - meta_args: Dict[str, torch.Tensor], - device_mesh: DeviceMesh, - memory_budget: float = -1.0, - overlap: bool = False, - solver_preference: str = 'standard', - dataloader_option: str = 'replicated', - shard_option: str = 'standard', - save_solver_solution: bool = False, - load_solver_solution: bool = False, - solution_path: str = None, - return_solution: bool = False): - ''' +def initialize_model( + model: nn.Module, + meta_args: Dict[str, torch.Tensor], + device_mesh: DeviceMesh, + memory_budget: float = -1.0, + overlap: bool = False, + solver_preference: str = "standard", + dataloader_option: str = "replicated", + shard_option: str = "standard", + save_solver_solution: bool = False, + load_solver_solution: bool = False, + solution_path: str = None, + return_solution: bool = False, +): + """ This method is used to initialize the sharded model which could be used as normal pytorch model. Args: @@ -246,7 +257,7 @@ def initialize_model(model: nn.Module, return_solution(optional): if the return_solution is True, the solution will be returned. The returned solution will be used to debug or help to analyze the sharding result. Therefore, we will not just return a series of integers, but return the best strategies. 
- ''' + """ tracer = ColoTracer(trace_act_ckpt=True, bias_addition_split=True) graph = tracer.trace(root=model, meta_args=meta_args) @@ -256,11 +267,13 @@ def initialize_model(model: nn.Module, shape_prop_pass(gm, *meta_args.values()) gm.recompile() - strategies_constructor = build_strategy_constructor(graph, - device_mesh, - solver_preference=solver_preference, - dataloader_option=dataloader_option, - shard_option=shard_option) + strategies_constructor = build_strategy_constructor( + graph, + device_mesh, + solver_preference=solver_preference, + dataloader_option=dataloader_option, + shard_option=shard_option, + ) if load_solver_solution: solution = torch.load(solution_path) else: @@ -268,8 +281,9 @@ def initialize_model(model: nn.Module, if save_solver_solution: torch.save(solution, solution_path) - gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_args, solution, device_mesh, strategies_constructor, - overlap) + gm, sharding_spec_dicts = transform_to_sharded_model( + gm, meta_args, solution, device_mesh, strategies_constructor, overlap + ) model_to_return = ModuleWrapper(gm, *sharding_spec_dicts) @@ -277,28 +291,30 @@ def initialize_model(model: nn.Module, solution_to_return = [] nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] for index, node in enumerate(nodes): - solution_to_return.append(f'{node.name} {node.strategies_vector[solution[index]].name}') + solution_to_return.append(f"{node.name} {node.strategies_vector[solution[index]].name}") return model_to_return, solution_to_return else: return model_to_return -def autoparallelize(model: nn.Module, - meta_args: Dict[str, torch.Tensor] = None, - data_loader: torch.utils.data.DataLoader = None, - data_process_func: callable = None, - alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None, - logical_mesh_shape: Tuple[int] = None, - logical_mesh_id: torch.Tensor = None, - solver_preference: str = 'standard', - dataloader_option: str = 'replicated', - shard_option: str = 'standard', - save_solver_solution: bool = False, - load_solver_solution: bool = False, - solver_solution_path: str = None, - return_solution: bool = False, - memory_budget: float = -1.0): - ''' +def autoparallelize( + model: nn.Module, + meta_args: Dict[str, torch.Tensor] = None, + data_loader: torch.utils.data.DataLoader = None, + data_process_func: callable = None, + alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None, + logical_mesh_shape: Tuple[int] = None, + logical_mesh_id: torch.Tensor = None, + solver_preference: str = "standard", + dataloader_option: str = "replicated", + shard_option: str = "standard", + save_solver_solution: bool = False, + load_solver_solution: bool = False, + solver_solution_path: str = None, + return_solution: bool = False, + memory_budget: float = -1.0, +): + """ This method is used to initialize the device mesh, extract the meta_args, and use them to create a sharded model. @@ -329,24 +345,26 @@ def autoparallelize(model: nn.Module, return_solution(optional): if the return_solution is True, the solution will be returned. memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0, the memory budget will be infinity. 
- ''' - device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict, - logical_mesh_shape=logical_mesh_shape, - logical_mesh_id=logical_mesh_id) + """ + device_mesh = initialize_device_mesh( + alpha_beta_dict=alpha_beta_dict, logical_mesh_shape=logical_mesh_shape, logical_mesh_id=logical_mesh_id + ) if meta_args is None: meta_args = extract_meta_args_from_dataloader(data_loader, data_process_func) - rst_to_unpack = initialize_model(model, - meta_args, - device_mesh, - solver_preference=solver_preference, - dataloader_option=dataloader_option, - shard_option=shard_option, - save_solver_solution=save_solver_solution, - load_solver_solution=load_solver_solution, - solution_path=solver_solution_path, - return_solution=return_solution, - memory_budget=memory_budget) + rst_to_unpack = initialize_model( + model, + meta_args, + device_mesh, + solver_preference=solver_preference, + dataloader_option=dataloader_option, + shard_option=shard_option, + save_solver_solution=save_solver_solution, + load_solver_solution=load_solver_solution, + solution_path=solver_solution_path, + return_solution=return_solution, + memory_budget=memory_budget, + ) if return_solution: model, solution = rst_to_unpack diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py index 9903ca54e52c..aa2e5e9c40c0 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py @@ -25,11 +25,33 @@ from .where_handler import WhereHandler __all__ = [ - 'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler', - 'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler', - 'UnaryElementwiseHandler', 'DefaultReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler', - 'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler', - 'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler', - 'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'TransposeHandler', - 'SplitHandler' + "LinearFunctionHandler", + "LinearModuleHandler", + "BMMFunctionHandler", + "AddBMMFunctionHandler", + "LayerNormModuleHandler", + "BatchNormModuleHandler", + "ConvModuleHandler", + "ConvFunctionHandler", + "UnaryElementwiseHandler", + "DefaultReshapeHandler", + "PlaceholderHandler", + "OutputHandler", + "WhereHandler", + "NormPoolingHandler", + "BinaryElementwiseHandler", + "MatMulHandler", + "operator_registry", + "ADDMMFunctionHandler", + "GetItemHandler", + "GetattrHandler", + "ViewHandler", + "PermuteHandler", + "TensorConstructorHandler", + "EmbeddingModuleHandler", + "EmbeddingFunctionHandler", + "SumHandler", + "SoftmaxHandler", + "TransposeHandler", + "SplitHandler", ] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py index da0d199c5e05..47c654d6aa43 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py @@ -2,15 +2,13 @@ import torch -from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager - -from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy +from ..sharding_strategy import 
OperationData, OperationDataType, ShardingStrategy from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape from .node_handler import NodeHandler from .registry import operator_registry from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator -__all__ = ['ADDMMFunctionHandler'] +__all__ = ["ADDMMFunctionHandler"] @operator_registry.register(torch.addmm) @@ -30,25 +28,26 @@ def _infer_op_data_type(self, tensor: torch.Tensor) -> OperationDataType: return data_type def get_operation_data_mapping(self) -> Dict[str, OperationData]: - # input operand input_data = self.node.args[1]._meta_data - physical_input_operand = OperationData(name=str(self.node.args[1]), - type=self._infer_op_data_type(input_data), - data=input_data) + physical_input_operand = OperationData( + name=str(self.node.args[1]), type=self._infer_op_data_type(input_data), data=input_data + ) # other operand other_data = self.node.args[2]._meta_data - physical_other_operand = OperationData(name=str(self.node.args[2]), - type=self._infer_op_data_type(other_data), - data=other_data) + physical_other_operand = OperationData( + name=str(self.node.args[2]), type=self._infer_op_data_type(other_data), data=other_data + ) # bias physical shape bias_logical_shape = self.node._meta_data.shape bias_data = self.node.args[0]._meta_data - physical_bias_operand = OperationData(name=str(self.node.args[0]), - type=self._infer_op_data_type(bias_data), - data=bias_data, - logical_shape=bias_logical_shape) + physical_bias_operand = OperationData( + name=str(self.node.args[0]), + type=self._infer_op_data_type(bias_data), + data=bias_data, + logical_shape=bias_logical_shape, + ) # output physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) @@ -57,7 +56,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: "input": physical_input_operand, "other": physical_other_operand, "output": physical_output, - 'bias': physical_bias_operand + "bias": physical_bias_operand, } return mapping @@ -66,26 +65,27 @@ def get_strategy_generator(self) -> List[StrategyGenerator]: op_data_mapping = self.get_operation_data_mapping() generators = [] generators.append( - LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='addmm')) + LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type="addmm") + ) return generators def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: # convert bias from its logical sharding spec to its physical sharding spec op_data_mapping = self.get_operation_data_mapping() - bias_op_data = op_data_mapping['bias'] + bias_op_data = op_data_mapping["bias"] bias_physical_shape = bias_op_data.data.shape bias_logical_shape = bias_op_data.logical_shape bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name) bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( - bias_sharding_spec, bias_logical_shape, bias_physical_shape) + bias_sharding_spec, bias_logical_shape, bias_physical_shape + ) strategy.sharding_specs[bias_op_data] = bias_sharding_spec if len(removed_dims) > 0: - comm_action = comm_actions_for_oprands(node=self.node, - removed_dims=removed_dims, - op_data=bias_op_data, - sharding_spec=bias_sharding_spec) + comm_action = comm_actions_for_oprands( + node=self.node, removed_dims=removed_dims, op_data=bias_op_data, sharding_spec=bias_sharding_spec + ) 
strategy.communication_actions[bias_op_data] = comm_action return strategy diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py index cb1bb36b7879..df4b1d6cef3f 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py @@ -2,12 +2,12 @@ import torch -from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector -from .node_handler import MetaInfoModuleHandler, ModuleHandler +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import MetaInfoModuleHandler from .registry import operator_registry from .strategy import BatchNormStrategyGenerator, StrategyGenerator -__all__ = ['BatchNormModuleHandler'] +__all__ = ["BatchNormModuleHandler"] @operator_registry.register(torch.nn.BatchNorm1d) @@ -27,30 +27,37 @@ def get_strategy_generator(self) -> List[StrategyGenerator]: def get_operation_data_mapping(self) -> Dict[str, OperationData]: # use transposed shape for strategies # the strategies will be transformed back to its original shape in self.post_process - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data) - physical_other_operand = OperationData(name="weight", - type=OperationDataType.PARAM, - data=self.named_parameters['weight'], - logical_shape=self.named_parameters['weight'].shape) + physical_input_operand = OperationData( + name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data + ) + physical_other_operand = OperationData( + name="weight", + type=OperationDataType.PARAM, + data=self.named_parameters["weight"], + logical_shape=self.named_parameters["weight"].shape, + ) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) - physical_running_mean_operand = OperationData(name="running_mean", - type=OperationDataType.BUFFER, - data=self.named_buffers['running_mean'], - logical_shape=self.named_buffers['running_mean'].shape) + physical_running_mean_operand = OperationData( + name="running_mean", + type=OperationDataType.BUFFER, + data=self.named_buffers["running_mean"], + logical_shape=self.named_buffers["running_mean"].shape, + ) - physical_running_var_operand = OperationData(name="running_var", - type=OperationDataType.BUFFER, - data=self.named_buffers['running_var'], - logical_shape=self.named_buffers['running_var'].shape) + physical_running_var_operand = OperationData( + name="running_var", + type=OperationDataType.BUFFER, + data=self.named_buffers["running_var"], + logical_shape=self.named_buffers["running_var"].shape, + ) physical_num_batches_tracked_operand = OperationData( name="num_batches_tracked", type=OperationDataType.BUFFER, - data=self.named_buffers['num_batches_tracked'], - logical_shape=self.named_buffers['num_batches_tracked'].shape) + data=self.named_buffers["num_batches_tracked"], + logical_shape=self.named_buffers["num_batches_tracked"].shape, + ) mapping = { "input": physical_input_operand, @@ -58,12 +65,12 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: "output": physical_output, "running_mean": physical_running_mean_operand, "running_var": physical_running_var_operand, - "num_batches_tracked": physical_num_batches_tracked_operand + "num_batches_tracked": physical_num_batches_tracked_operand, } - if 
self.named_parameters['bias'] is not None: - physical_bias_operand = OperationData(name="bias", - type=OperationDataType.PARAM, - data=self.named_parameters['bias']) - mapping['bias'] = physical_bias_operand + if self.named_parameters["bias"] is not None: + physical_bias_operand = OperationData( + name="bias", type=OperationDataType.PARAM, data=self.named_parameters["bias"] + ) + mapping["bias"] = physical_bias_operand return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py index db8f0b54ddee..f8c137348353 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py @@ -4,15 +4,14 @@ from torch.fx.node import Node from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, ShardingStrategy -from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager from ..constants import BCAST_FUNC_OP from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape -from .node_handler import MetaInfoNodeHandler, NodeHandler +from .node_handler import MetaInfoNodeHandler from .registry import operator_registry from .strategy import BinaryElementwiseStrategyGenerator, StrategyGenerator -__all__ = ['BinaryElementwiseHandler'] +__all__ = ["BinaryElementwiseHandler"] @operator_registry.register(BCAST_FUNC_OP) @@ -38,7 +37,7 @@ def _get_arg_value(idx): # The meta_data of node type argument could also possibly be a non-tensor object. if not isinstance(meta_data, torch.Tensor): assert isinstance(meta_data, (int, float)) - meta_data = torch.Tensor([meta_data]).to('meta') + meta_data = torch.Tensor([meta_data]).to("meta") non_tensor = True else: @@ -46,7 +45,7 @@ def _get_arg_value(idx): # but we can deem it as meta data # as it won't affect the strategy generation assert isinstance(self.node.args[idx], (int, float)) - meta_data = torch.Tensor([self.node.args[idx]]).to('meta') + meta_data = torch.Tensor([self.node.args[idx]]).to("meta") non_tensor = True return meta_data, non_tensor @@ -58,24 +57,27 @@ def _get_arg_value(idx): # and filter the non-tensor op_data in post_process. 
self.non_tensor_list = [] # assert False - input_op_data = OperationData(name=str(self.node.args[0]), - type=_get_op_data_type(input_meta_data), - data=input_meta_data, - logical_shape=bcast_shape) - other_op_data = OperationData(name=str(self.node.args[1]), - type=_get_op_data_type(other_meta_data), - data=other_meta_data, - logical_shape=bcast_shape) - output_op_data = OperationData(name=str(self.node), - type=OperationDataType.OUTPUT, - data=output_meta_data, - logical_shape=bcast_shape) + input_op_data = OperationData( + name=str(self.node.args[0]), + type=_get_op_data_type(input_meta_data), + data=input_meta_data, + logical_shape=bcast_shape, + ) + other_op_data = OperationData( + name=str(self.node.args[1]), + type=_get_op_data_type(other_meta_data), + data=other_meta_data, + logical_shape=bcast_shape, + ) + output_op_data = OperationData( + name=str(self.node), type=OperationDataType.OUTPUT, data=output_meta_data, logical_shape=bcast_shape + ) if non_tensor_input: self.non_tensor_list.append(input_op_data) if non_tensor_other: self.non_tensor_list.append(other_op_data) - mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data} + mapping = {"input": input_op_data, "other": other_op_data, "output": output_op_data} return mapping def get_strategy_generator(self) -> List[StrategyGenerator]: @@ -100,14 +102,14 @@ def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, Li logical_shape = op_data.logical_shape sharding_spec = strategy.get_sharding_spec_by_name(op_data.name) sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( - sharding_spec, logical_shape, physical_shape) + sharding_spec, logical_shape, physical_shape + ) strategy.sharding_specs[op_data] = sharding_spec if len(removed_dims) > 0: - comm_action = comm_actions_for_oprands(node=self.node, - removed_dims=removed_dims, - op_data=op_data, - sharding_spec=sharding_spec) + comm_action = comm_actions_for_oprands( + node=self.node, removed_dims=removed_dims, op_data=op_data, sharding_spec=sharding_spec + ) strategy.communication_actions[op_data] = comm_action return strategy diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py index da2b733c9f7a..5c22ac7bef11 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py @@ -2,15 +2,13 @@ import torch -from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager - -from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape from .node_handler import NodeHandler from .registry import operator_registry from .strategy import BatchedMatMulStrategyGenerator, StrategyGenerator -__all__ = ['BMMFunctionHandler', 'AddBMMFunctionHandler'] +__all__ = ["BMMFunctionHandler", "AddBMMFunctionHandler"] def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None): @@ -19,14 +17,14 @@ def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None): node handler to reduce code redundancy. 
""" # input operand - physical_input_operand = OperationData(name=str(node.args[input_idx]), - type=OperationDataType.ARG, - data=node.args[input_idx]._meta_data) + physical_input_operand = OperationData( + name=str(node.args[input_idx]), type=OperationDataType.ARG, data=node.args[input_idx]._meta_data + ) # other operand - physical_other_operand = OperationData(name=str(node.args[other_idx]), - type=OperationDataType.ARG, - data=node.args[other_idx]._meta_data) + physical_other_operand = OperationData( + name=str(node.args[other_idx]), type=OperationDataType.ARG, data=node.args[other_idx]._meta_data + ) # output physical_output = OperationData(name=str(node), type=OperationDataType.OUTPUT, data=node._meta_data) @@ -35,11 +33,13 @@ def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None): if bias_idx is not None: # bias physical shape bias_logical_shape = node._meta_data.shape - physical_bias_operand = OperationData(name=str(node.args[bias_idx]), - type=OperationDataType.ARG, - data=node.args[bias_idx]._meta_data, - logical_shape=bias_logical_shape) - mapping['bias'] = physical_bias_operand + physical_bias_operand = OperationData( + name=str(node.args[bias_idx]), + type=OperationDataType.ARG, + data=node.args[bias_idx]._meta_data, + logical_shape=bias_logical_shape, + ) + mapping["bias"] = physical_bias_operand return mapping @@ -91,20 +91,20 @@ def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, Li # convert bias from its logical sharding spec to its physical sharding spec op_data_mapping = self.get_operation_data_mapping() - if 'bias' in op_data_mapping: - bias_op_data = op_data_mapping['bias'] + if "bias" in op_data_mapping: + bias_op_data = op_data_mapping["bias"] bias_physical_shape = bias_op_data.data.shape bias_logical_shape = bias_op_data.logical_shape bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name) bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( - bias_sharding_spec, bias_logical_shape, bias_physical_shape) + bias_sharding_spec, bias_logical_shape, bias_physical_shape + ) strategy.sharding_specs[bias_op_data] = bias_sharding_spec if len(removed_dims) > 0: - comm_action = comm_actions_for_oprands(node=self.node, - removed_dims=removed_dims, - op_data=bias_op_data, - sharding_spec=bias_sharding_spec) + comm_action = comm_actions_for_oprands( + node=self.node, removed_dims=removed_dims, op_data=bias_op_data, sharding_spec=bias_sharding_spec + ) strategy.communication_actions[bias_op_data] = comm_action return strategy diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py index 272b1c85630a..fd7c1f837a5a 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py @@ -3,13 +3,13 @@ import torch import torch.nn.functional as F -from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy from ..utils import transpose_partition_dim -from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler +from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler from .registry import operator_registry from .strategy import ConvStrategyGenerator, StrategyGenerator -__all__ = ['ConvModuleHandler', 'ConvFunctionHandler'] +__all__ = 
["ConvModuleHandler", "ConvFunctionHandler"] @operator_registry.register(torch.nn.Conv1d) @@ -29,25 +29,29 @@ def get_strategy_generator(self) -> List[StrategyGenerator]: def get_operation_data_mapping(self) -> Dict[str, OperationData]: # use transposed shape for strategies # the strategies will be transformed back to its original shape in self.post_process - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data) + physical_input_operand = OperationData( + name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data + ) logical_shape_for_weight = list(self.named_parameters["weight"].shape) - logical_shape_for_weight[0], logical_shape_for_weight[1] = logical_shape_for_weight[ - 1], logical_shape_for_weight[0] - physical_other_operand = OperationData(name="weight", - type=OperationDataType.PARAM, - data=self.named_parameters['weight'], - logical_shape=torch.Size(logical_shape_for_weight)) + logical_shape_for_weight[0], logical_shape_for_weight[1] = ( + logical_shape_for_weight[1], + logical_shape_for_weight[0], + ) + physical_other_operand = OperationData( + name="weight", + type=OperationDataType.PARAM, + data=self.named_parameters["weight"], + logical_shape=torch.Size(logical_shape_for_weight), + ) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} if "bias" in self.named_parameters: - physical_bias_operand = OperationData(name="bias", - type=OperationDataType.PARAM, - data=self.named_parameters['bias']) - mapping['bias'] = physical_bias_operand + physical_bias_operand = OperationData( + name="bias", type=OperationDataType.PARAM, data=self.named_parameters["bias"] + ) + mapping["bias"] = physical_bias_operand return mapping def post_process(self, strategy: ShardingStrategy): @@ -77,9 +81,9 @@ def get_strategy_generator(self) -> List[StrategyGenerator]: def get_operation_data_mapping(self) -> Dict[str, OperationData]: # use transposed shape for strategies # the strategies will be transformed back to its original shape in self.post_process - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data) + physical_input_operand = OperationData( + name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data + ) # check if the other operand is a parameter if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter): @@ -88,26 +92,30 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: data_type = OperationDataType.ARG logical_shape_for_weight = list(self.node.args[1]._meta_data.shape) - logical_shape_for_weight[0], logical_shape_for_weight[1] = logical_shape_for_weight[ - 1], logical_shape_for_weight[0] - physical_other_operand = OperationData(name=str(self.node.args[1]), - type=data_type, - data=self.node.args[1]._meta_data, - logical_shape=torch.Size(logical_shape_for_weight)) + logical_shape_for_weight[0], logical_shape_for_weight[1] = ( + logical_shape_for_weight[1], + logical_shape_for_weight[0], + ) + physical_other_operand = OperationData( + name=str(self.node.args[1]), + type=data_type, + data=self.node.args[1]._meta_data, + logical_shape=torch.Size(logical_shape_for_weight), + ) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) 
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} - if "bias" in self.node.kwargs and self.node.kwargs['bias'] is not None: + if "bias" in self.node.kwargs and self.node.kwargs["bias"] is not None: # check if the other operand is a parameter if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter): data_type = OperationDataType.PARAM else: data_type = OperationDataType.ARG - physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]), - type=data_type, - data=self.node.kwargs["bias"]._meta_data) - mapping['bias'] = physical_bias_operand + physical_bias_operand = OperationData( + name=str(self.node.kwargs["bias"]), type=data_type, data=self.node.kwargs["bias"]._meta_data + ) + mapping["bias"] = physical_bias_operand return mapping def post_process(self, strategy: ShardingStrategy): diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py index 0c5b9f39e1fb..feb1032a6c0f 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py @@ -3,11 +3,11 @@ import torch from ..sharding_strategy import OperationData, OperationDataType -from .node_handler import MetaInfoNodeHandler, NodeHandler +from .node_handler import MetaInfoNodeHandler from .registry import operator_registry from .strategy import DefaultReshapeGenerator, StrategyGenerator -__all__ = ['DefaultReshapeHandler'] +__all__ = ["DefaultReshapeHandler"] @operator_registry.register(torch.flatten) @@ -54,17 +54,15 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: input_data = self.node.args[0]._meta_data input_logical_shape = self.infer_logical_shape(input_data) - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=data_type, - data=input_data, - logical_shape=input_logical_shape) + physical_input_operand = OperationData( + name=str(self.node.args[0]), type=data_type, data=input_data, logical_shape=input_logical_shape + ) output_data = self.node._meta_data output_logical_shape = self.infer_logical_shape(output_data) - physical_output = OperationData(name=str(self.node), - type=OperationDataType.OUTPUT, - data=output_data, - logical_shape=output_logical_shape) + physical_output = OperationData( + name=str(self.node), type=OperationDataType.OUTPUT, data=output_data, logical_shape=output_logical_shape + ) mapping = {"input": physical_input_operand, "output": physical_output} diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py index 112ee194b4ec..f29c3a0b7d5d 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py @@ -12,11 +12,12 @@ from .registry import operator_registry from .strategy import EmbeddingStrategyGenerator, StrategyGenerator -__all__ = ['EmbeddingModuleHandler', 'EmbeddingFunctionHandler'] +__all__ = ["EmbeddingModuleHandler", "EmbeddingFunctionHandler"] -def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy: ShardingStrategy, input_name: str, - output_name: str) -> List[ShardingStrategy]: +def _convert_logical_sharding_to_physical_sharding_spec_for_embedding( + strategy: ShardingStrategy, input_name: str, output_name: str +) -> 
List[ShardingStrategy]: """ This function converts the logical sharding spec to the physical sharding spec for both the input and output of the embedding operation. @@ -56,27 +57,31 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy: output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name) try: # replace the 0th dimension in the logical sharding with ith dimension in the physical sharding - update_partition_dim(sharding_spec=input_sharding_spec, - dim_mapping={0: i}, - physical_shape=input_op_data.data.shape, - inplace=True) + update_partition_dim( + sharding_spec=input_sharding_spec, + dim_mapping={0: i}, + physical_shape=input_op_data.data.shape, + inplace=True, + ) if last_logical_output_dims in output_sharding_spec.dim_partition_dict: dim_mapping = {0: i, last_logical_output_dims: last_physical_output_dims} else: dim_mapping = {0: i} - update_partition_dim(sharding_spec=output_sharding_spec, - dim_mapping=dim_mapping, - physical_shape=output_op_data.data.shape, - inplace=True) + update_partition_dim( + sharding_spec=output_sharding_spec, + dim_mapping=dim_mapping, + physical_shape=output_op_data.data.shape, + inplace=True, + ) - strategy_copy.name = f'{strategy.name}_{i}' + strategy_copy.name = f"{strategy.name}_{i}" sharding_strategies.append(strategy_copy) except ShardingNotDivisibleError as e: logger.debug( - f'Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}' + f"Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}" ) else: # the generated sharding strategy does not shard the non-matrix dimension, @@ -87,20 +92,21 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy: output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name) # after updating, the logical shape will be replaced by the physical shape - update_partition_dim(sharding_spec=input_sharding_spec, - dim_mapping={}, - physical_shape=input_op_data.data.shape, - inplace=True) + update_partition_dim( + sharding_spec=input_sharding_spec, dim_mapping={}, physical_shape=input_op_data.data.shape, inplace=True + ) if last_logical_output_dims in output_sharding_spec.dim_partition_dict: dim_mapping = {last_logical_output_dims: last_physical_output_dims} else: dim_mapping = {} - update_partition_dim(sharding_spec=output_sharding_spec, - dim_mapping=dim_mapping, - physical_shape=output_op_data.data.shape, - inplace=True) + update_partition_dim( + sharding_spec=output_sharding_spec, + dim_mapping=dim_mapping, + physical_shape=output_op_data.data.shape, + inplace=True, + ) sharding_strategies.append(strategy_copy) return sharding_strategies @@ -125,14 +131,16 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: # Finally, the input will be transformed back to its original shape in self.post_process input_meta_data = self.node.args[0]._meta_data input_logical_shape = input_meta_data.view(-1).shape - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=input_meta_data, - logical_shape=input_logical_shape) + physical_input_operand = OperationData( + name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=input_meta_data, + logical_shape=input_logical_shape, + ) - physical_other_operand = OperationData(name="weight", - type=OperationDataType.PARAM, - data=self.named_parameters['weight']) + physical_other_operand = OperationData( + name="weight", 
type=OperationDataType.PARAM, data=self.named_parameters["weight"] + ) # Same as input, in nn.Embedding operation, all the dimensions of output will be treated as # (batch dimension, embedding dimension), and then the sharding spec will be generated based @@ -141,10 +149,12 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: # Finally, the output will be transformed back to its original shape in self.post_process output_meta_data = self.node._meta_data output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape - physical_output = OperationData(name=str(self.node), - type=OperationDataType.OUTPUT, - data=output_meta_data, - logical_shape=output_logical_shape) + physical_output = OperationData( + name=str(self.node), + type=OperationDataType.OUTPUT, + data=output_meta_data, + logical_shape=output_logical_shape, + ) mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} @@ -157,10 +167,9 @@ def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, Li # create multiple sharding strategies for the inputs # as input can be multi-dimensional and the partition dim is only 2D, # we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output - strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy, - input_name=str( - self.node.args[0]), - output_name=str(self.node)) + strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding( + strategy=strategy, input_name=str(self.node.args[0]), output_name=str(self.node) + ) return strategies @@ -183,10 +192,12 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: # Finally, the input will be transformed back to its original shape in self.post_process input_meta_data = self.node.args[0]._meta_data input_logical_shape = input_meta_data.view(-1).shape - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data, - logical_shape=input_logical_shape) + physical_input_operand = OperationData( + name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.node.args[0]._meta_data, + logical_shape=input_logical_shape, + ) # check if the other operand is a parameter if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter): @@ -194,9 +205,9 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: else: data_type = OperationDataType.ARG - physical_other_operand = OperationData(name=str(self.node.args[1]), - type=data_type, - data=self.node.args[1]._meta_data) + physical_other_operand = OperationData( + name=str(self.node.args[1]), type=data_type, data=self.node.args[1]._meta_data + ) # Same as input, in F.embedding operation, all the dimensions of output will be treated as # (batch dimension, embedding dimension), and then the sharding spec will be generated based @@ -223,8 +234,7 @@ def post_process(self, strategy: ShardingStrategy): # create multiple sharding strategies for the inputs # as input can be multi-dimensional and the partition dim is only 2D, # we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output - strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy, - input_name=str( - self.node.args[0]), - output_name=str(self.node)) + strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding( + strategy=strategy, 
input_name=str(self.node.args[0]), output_name=str(self.node) + ) return strategies diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py index 53addb873d1d..dcf0a1760a2c 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py @@ -4,7 +4,7 @@ from .node_handler import NodeHandler from .strategy import GetattrGenerator, StrategyGenerator -__all__ = ['GetattrHandler'] +__all__ = ["GetattrHandler"] class GetattrHandler(NodeHandler): diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py index 3466e9dd9940..bd342c12eda9 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py @@ -8,7 +8,7 @@ from .registry import operator_registry from .strategy import StrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator -__all__ = ['GetItemHandler'] +__all__ = ["GetItemHandler"] @operator_registry.register(operator.getitem) @@ -30,9 +30,9 @@ def get_strategy_generator(self) -> List[StrategyGenerator]: def get_operation_data_mapping(self) -> Dict[str, OperationData]: # use transposed shape for strategies # the strategies will be transformed back to its original shape in self.post_process - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data) + physical_input_operand = OperationData( + name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data + ) physical_other_operand = OperationData(name="index", type=OperationDataType.ARG, data=self.node.args[1]) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py index 452381169b74..ce6b20fa1d24 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py @@ -3,11 +3,11 @@ import torch from ..sharding_strategy import OperationData, OperationDataType -from .node_handler import MetaInfoModuleHandler, ModuleHandler +from .node_handler import MetaInfoModuleHandler from .registry import operator_registry from .strategy import LayerNormGenerator, StrategyGenerator -__all__ = ['LayerNormModuleHandler'] +__all__ = ["LayerNormModuleHandler"] @operator_registry.register(torch.nn.LayerNorm) @@ -25,20 +25,22 @@ def get_strategy_generator(self) -> List[StrategyGenerator]: def get_operation_data_mapping(self) -> Dict[str, OperationData]: # use transposed shape for strategies # the strategies will be transformed back to its original shape in self.post_process - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data) - physical_other_operand = OperationData(name="weight", - type=OperationDataType.PARAM, - data=self.named_parameters['weight'], - logical_shape=self.named_parameters['weight'].shape) + physical_input_operand = OperationData( + name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data + ) + physical_other_operand = 
OperationData( + name="weight", + type=OperationDataType.PARAM, + data=self.named_parameters["weight"], + logical_shape=self.named_parameters["weight"].shape, + ) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} - if self.named_parameters['bias'] is not None: - physical_bias_operand = OperationData(name="bias", - type=OperationDataType.PARAM, - data=self.named_parameters['bias']) - mapping['bias'] = physical_bias_operand + if self.named_parameters["bias"] is not None: + physical_bias_operand = OperationData( + name="bias", type=OperationDataType.PARAM, data=self.named_parameters["bias"] + ) + mapping["bias"] = physical_bias_operand return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py index ea541e434009..4177af4eaf71 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py @@ -3,24 +3,21 @@ import torch import torch.nn.functional as F -from colossalai.auto_parallel.tensor_shard.utils import ( - check_sharding_spec_validity, - transpose_partition_dim, - update_partition_dim, -) +from colossalai.auto_parallel.tensor_shard.utils import transpose_partition_dim, update_partition_dim from colossalai.logging import get_dist_logger from colossalai.tensor.sharding_spec import ShardingNotDivisibleError -from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector -from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy +from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler from .registry import operator_registry from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator -__all__ = ['LinearModuleHandler', 'LinearFunctionHandler'] +__all__ = ["LinearModuleHandler", "LinearFunctionHandler"] -def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStrategy, - weight_name: str) -> ShardingStrategy: +def _update_sharding_spec_for_transposed_weight_for_linear( + strategy: ShardingStrategy, weight_name: str +) -> ShardingStrategy: """ This function is a helper function used by both module node handler and function node handler. This function will convert the sharding spec for the transposed weight to the correct partition spec. 
@@ -32,16 +29,17 @@ def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStr # switch the dimensions of the transposed weight sharding_spec = strategy.get_sharding_spec_by_name(weight_name) op_data = strategy.get_op_data_by_name(weight_name) - assert op_data.logical_shape[0] == op_data.data.shape[1] and \ - op_data.logical_shape[1] == op_data.data.shape[0], \ - "Expected the logical shape of the linear operator's weight is equal to transposed physical shape" + assert ( + op_data.logical_shape[0] == op_data.data.shape[1] and op_data.logical_shape[1] == op_data.data.shape[0] + ), "Expected the logical shape of the linear operator's weight is equal to transposed physical shape" dim_size = len(op_data.logical_shape) transpose_partition_dim(sharding_spec, 0, dim_size - 1) return strategy -def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: ShardingStrategy, input_name: str, - output_name: str) -> List[ShardingStrategy]: +def _convert_logical_sharding_to_physical_sharding_spec_for_linear( + strategy: ShardingStrategy, input_name: str, output_name: str +) -> List[ShardingStrategy]: """ This function converts the logical sharding spec to the physical sharding spec for both the input and output of the linear operation. The input and output should have the same sharding spec. @@ -99,22 +97,26 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha input_dim_mapping = {0: i} input_dim_mapping.update(input_last_dim_mapping) - update_partition_dim(sharding_spec=input_sharding_spec, - dim_mapping=input_dim_mapping, - physical_shape=input_op_data.data.shape, - inplace=True) + update_partition_dim( + sharding_spec=input_sharding_spec, + dim_mapping=input_dim_mapping, + physical_shape=input_op_data.data.shape, + inplace=True, + ) output_dim_mapping = {0: i} output_dim_mapping.update(output_last_dim_mapping) - update_partition_dim(sharding_spec=output_sharding_spec, - dim_mapping=output_dim_mapping, - physical_shape=output_op_data.data.shape, - inplace=True) - strategy_copy.name = f'{strategy.name}_{i}' + update_partition_dim( + sharding_spec=output_sharding_spec, + dim_mapping=output_dim_mapping, + physical_shape=output_op_data.data.shape, + inplace=True, + ) + strategy_copy.name = f"{strategy.name}_{i}" sharding_strategies.append(strategy_copy) except ShardingNotDivisibleError as e: logger.debug( - f'Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}' + f"Errored occurred when converting the logical sharding spec to the physical one. 
Error details: {e}" ) else: # the generated sharding strategy does not shard the non-matrix dimension, @@ -127,17 +129,21 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha # after updating, the logical shape will be replaced by the physical shape input_dim_mapping = {} input_dim_mapping.update(input_last_dim_mapping) - update_partition_dim(sharding_spec=input_sharding_spec, - dim_mapping=input_dim_mapping, - physical_shape=input_op_data.data.shape, - inplace=True) + update_partition_dim( + sharding_spec=input_sharding_spec, + dim_mapping=input_dim_mapping, + physical_shape=input_op_data.data.shape, + inplace=True, + ) output_dim_mapping = {} output_dim_mapping.update(output_last_dim_mapping) - update_partition_dim(sharding_spec=output_sharding_spec, - dim_mapping=output_dim_mapping, - physical_shape=output_op_data.data.shape, - inplace=True) + update_partition_dim( + sharding_spec=output_sharding_spec, + dim_mapping=output_dim_mapping, + physical_shape=output_op_data.data.shape, + inplace=True, + ) sharding_strategies.append(strategy_copy) return sharding_strategies @@ -152,10 +158,13 @@ def get_strategy_generator(self) -> List[StrategyGenerator]: op_data_mapping = self.get_operation_data_mapping() generators = [] generators.append( - LinearProjectionStrategyGenerator(op_data_mapping, - self.device_mesh, - linear_projection_type='linear', - solver_perference=self.solver_perference)) + LinearProjectionStrategyGenerator( + op_data_mapping, + self.device_mesh, + linear_projection_type="linear", + solver_perference=self.solver_perference, + ) + ) return generators def get_operation_data_mapping(self) -> Dict[str, OperationData]: @@ -163,28 +172,34 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: # the strategies will be transformed back to its original shape in self.post_process input_meta_data = self.node.args[0]._meta_data input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=input_meta_data, - logical_shape=input_logical_shape) - physical_other_operand = OperationData(name="weight", - type=OperationDataType.PARAM, - data=self.named_parameters['weight'], - logical_shape=self.named_parameters['weight'].shape[::-1]) + physical_input_operand = OperationData( + name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=input_meta_data, + logical_shape=input_logical_shape, + ) + physical_other_operand = OperationData( + name="weight", + type=OperationDataType.PARAM, + data=self.named_parameters["weight"], + logical_shape=self.named_parameters["weight"].shape[::-1], + ) output_meta_data = self.node._meta_data output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape - physical_output = OperationData(name=str(self.node), - type=OperationDataType.OUTPUT, - data=output_meta_data, - logical_shape=output_logical_shape) + physical_output = OperationData( + name=str(self.node), + type=OperationDataType.OUTPUT, + data=output_meta_data, + logical_shape=output_logical_shape, + ) mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} - if 'bias' in self.named_parameters is not None: - physical_bias_operand = OperationData(name="bias", - type=OperationDataType.PARAM, - data=self.named_parameters['bias']) - mapping['bias'] = physical_bias_operand + if "bias" in self.named_parameters is not None: + physical_bias_operand = OperationData( + 
name="bias", type=OperationDataType.PARAM, data=self.named_parameters["bias"] + ) + mapping["bias"] = physical_bias_operand return mapping def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: @@ -194,14 +209,14 @@ def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, Li 2. the input and output sharding specs are updated to physical shape. """ # switch the dimensions of the transposed weight - strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy, weight_name='weight') + strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy, weight_name="weight") # create multiple sharding strategies for the inputs # as input can be multi-dimensional and the partition dim is only 2D, # we need to map the partition at dim 0 to one of the first few dimensions of the input - strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy=strategy, - input_name=str(self.node.args[0]), - output_name=str(self.node)) + strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear( + strategy=strategy, input_name=str(self.node.args[0]), output_name=str(self.node) + ) return strategies @@ -215,7 +230,8 @@ def get_strategy_generator(self) -> List[StrategyGenerator]: op_data_mapping = self.get_operation_data_mapping() generators = [] generators.append( - LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear')) + LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type="linear") + ) return generators def get_operation_data_mapping(self) -> Dict[str, OperationData]: @@ -223,10 +239,12 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: # the strategies will be transformed back to its original shape in self.post_process input_meta_data = self.node.args[0]._meta_data input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data, - logical_shape=input_logical_shape) + physical_input_operand = OperationData( + name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.node.args[0]._meta_data, + logical_shape=input_logical_shape, + ) # check if the other operand is a parameter if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter): @@ -234,10 +252,12 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: else: data_type = OperationDataType.ARG - physical_other_operand = OperationData(name=str(self.node.args[1]), - type=data_type, - data=self.node.args[1]._meta_data, - logical_shape=self.node.args[1]._meta_data.shape[::-1]) + physical_other_operand = OperationData( + name=str(self.node.args[1]), + type=data_type, + data=self.node.args[1]._meta_data, + logical_shape=self.node.args[1]._meta_data.shape[::-1], + ) output_meta_data = self.node._meta_data output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape physical_output = OperationData( @@ -249,27 +269,28 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} - if 'bias' in self.node.kwargs and self.node.kwargs['bias'] is not None: + if "bias" in self.node.kwargs and self.node.kwargs["bias"] is not None: # check if the other operand is a parameter if 
isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter): data_type = OperationDataType.PARAM else: data_type = OperationDataType.ARG - physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]), - type=data_type, - data=self.node.kwargs["bias"]._meta_data) - mapping['bias'] = physical_bias_operand + physical_bias_operand = OperationData( + name=str(self.node.kwargs["bias"]), type=data_type, data=self.node.kwargs["bias"]._meta_data + ) + mapping["bias"] = physical_bias_operand return mapping def post_process(self, strategy: ShardingStrategy): # switch the dimensions of the transposed weight - strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy, - weight_name=str(self.node.args[1])) + strategy = _update_sharding_spec_for_transposed_weight_for_linear( + strategy=strategy, weight_name=str(self.node.args[1]) + ) # create multiple sharding strategies for the inputs # as input can be multi-dimensional and the partition dim is only 2D, # we need to map the partition at dim 0 to one of the first few dimensions of the input - strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy=strategy, - input_name=str(self.node.args[0]), - output_name=str(self.node)) + strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear( + strategy=strategy, input_name=str(self.node.args[0]), output_name=str(self.node) + ) return strategies diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py index fa51114a5c94..4fab5f7f05eb 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py @@ -16,7 +16,7 @@ from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy from ..utils import recover_sharding_spec_for_broadcast_shape -from .node_handler import MetaInfoNodeHandler, NodeHandler +from .node_handler import MetaInfoNodeHandler from .registry import operator_registry from .strategy import ( BatchedMatMulStrategyGenerator, @@ -37,6 +37,7 @@ class MatMulType(Enum): MV: matrix-vector product: the 1st tensor is 2D and the 2nd tensor is 1D BMM: batched matrix-matrix multiplication, one tensor is at least 1D and the other is at least 3D """ + DOT = 0 MM = 1 MV = 2 @@ -92,26 +93,26 @@ def __init__(self) -> None: def apply(self, shape_mapping: Dict[str, List[int]]): mapping_copy = deepcopy(shape_mapping) - input_shape = mapping_copy['input'] - other_shape = mapping_copy['other'] + input_shape = mapping_copy["input"] + other_shape = mapping_copy["other"] if len(input_shape) == 1: # if the input is a 1D tensor, 1 is prepended to its shape # and it will be removed afterwards input_shape.insert(0, 1) - self.padded_dim_mapping['input'] = -2 - self.padded_dim_mapping['output'] = -2 + self.padded_dim_mapping["input"] = -2 + self.padded_dim_mapping["output"] = -2 elif len(other_shape) == 1: # if the other is a 1D tensor, 1 is appended to its shape # and it will be removed afterwards other_shape = other_shape.append(1) - self.padded_dim_mapping['other'] = -1 - self.padded_dim_mapping['output'] = -1 + self.padded_dim_mapping["other"] = -1 + self.padded_dim_mapping["output"] = -1 return mapping_copy def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy): - input_op_data = op_data_mapping['input'] - other_op_data = op_data_mapping['other'] + op_data_mapping["input"] + 
op_data_mapping["other"] def _remove_padded_dim(key, strategy): op_data = op_data_mapping[key] @@ -131,7 +132,7 @@ def _remove_padded_dim(key, strategy): # compute unpadded tensor shape tensor_shape.pop(padded_dim) - assert tensor_shape == list(op_data.data.shape), f'{tensor_shape} vs {list(op_data.data.shape)}' + assert tensor_shape == list(op_data.data.shape), f"{tensor_shape} vs {list(op_data.data.shape)}" # update sharding spec sharding_spec.__init__(sharding_spec.device_mesh, tensor_shape, unpadded_dim_partition_list) @@ -142,15 +143,15 @@ def _remove_padded_dim(key, strategy): strategy_copy = strategy.clone() # only one of input and other will be padded - if 'input' in self.padded_dim_mapping: - _remove_padded_dim('input', strategy_copy) - _remove_padded_dim('output', strategy_copy) - elif 'other' in self.padded_dim_mapping: - _remove_padded_dim('other', strategy_copy) - _remove_padded_dim('output', strategy_copy) + if "input" in self.padded_dim_mapping: + _remove_padded_dim("input", strategy_copy) + _remove_padded_dim("output", strategy_copy) + elif "other" in self.padded_dim_mapping: + _remove_padded_dim("other", strategy_copy) + _remove_padded_dim("output", strategy_copy) strategies.append(strategy_copy) - except ShardingSpecException as e: + except ShardingSpecException: pass return strategies @@ -167,8 +168,8 @@ def apply(self, shape_mapping: Dict[str, List[int]]): mapping_copy = shape_mapping.copy() # get shapes - input_shape = mapping_copy['input'] - other_shape = mapping_copy['other'] + input_shape = mapping_copy["input"] + other_shape = mapping_copy["other"] # sanity check assert len(input_shape) > 1 and len(other_shape) > 1 @@ -179,16 +180,16 @@ def apply(self, shape_mapping: Dict[str, List[int]]): # store the broadcast dim info input_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, input_shape[:-2]) other_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, other_shape[:-2]) - self.broadcast_dim_info['input'] = input_broadcast_dim_info - self.broadcast_dim_info['other'] = other_broadcast_dim_info + self.broadcast_dim_info["input"] = input_broadcast_dim_info + self.broadcast_dim_info["other"] = other_broadcast_dim_info # create the full logical shape input_shape = bcast_non_matrix_dims + input_shape[-2:] other_shape = bcast_non_matrix_dims + other_shape[-2:] assert len(input_shape) == len(other_shape) - mapping_copy['input'] = input_shape - mapping_copy['other'] = other_shape + mapping_copy["input"] = input_shape + mapping_copy["other"] = other_shape return mapping_copy @@ -216,17 +217,18 @@ def _remove_sharding_on_broadcast_dim(key, strategy): physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( logical_sharding_spec=sharding_spec, logical_shape=sharding_spec.entire_shape, - physical_shape=tensor_shape_before_broadcast) + physical_shape=tensor_shape_before_broadcast, + ) strategy.sharding_specs[op_data] = physical_sharding_spec # enumerate all sharding strategies strategies = [] try: strategy_copy = strategy.clone() - _remove_sharding_on_broadcast_dim('input', strategy_copy) - _remove_sharding_on_broadcast_dim('other', strategy_copy) + _remove_sharding_on_broadcast_dim("input", strategy_copy) + _remove_sharding_on_broadcast_dim("other", strategy_copy) strategies.append(strategy_copy) - except ShardingSpecException as e: + except ShardingSpecException: pass return strategies @@ -241,20 +243,20 @@ def __init__(self) -> None: def apply(self, shape_mapping: Dict[str, List[int]]): mapping_copy = shape_mapping.copy() - 
self.batch_dims_before_view = list(mapping_copy['input'][:-2]) + self.batch_dims_before_view = list(mapping_copy["input"][:-2]) # get shapes - input_shape = shape_mapping['input'] - other_shape = shape_mapping['other'] + input_shape = shape_mapping["input"] + other_shape = shape_mapping["other"] # view to 3d tensor assert len(input_shape) >= 3 and len(other_shape) >= 3 input_shape = [reduce(operator.mul, input_shape[:-2])] + input_shape[-2:] other_shape = [reduce(operator.mul, other_shape[:-2])] + other_shape[-2:] output_shape = input_shape[:2] + other_shape[2:] - mapping_copy['input'] = input_shape - mapping_copy['other'] = other_shape - mapping_copy['output'] = output_shape + mapping_copy["input"] = input_shape + mapping_copy["other"] = other_shape + mapping_copy["output"] = output_shape return mapping_copy def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy): @@ -291,11 +293,11 @@ def _update_sharding_spec(key, strategy, physical_batch_dim): # create a new strategy strategy_copy = strategy.clone() try: - _update_sharding_spec('input', strategy_copy, i) - _update_sharding_spec('other', strategy_copy, i) - _update_sharding_spec('output', strategy_copy, i) + _update_sharding_spec("input", strategy_copy, i) + _update_sharding_spec("other", strategy_copy, i) + _update_sharding_spec("output", strategy_copy, i) strategies.append(strategy_copy) - except ShardingSpecException as e: + except ShardingSpecException: continue return strategies @@ -312,14 +314,14 @@ def _get_bmm_logical_shape(input_shape, other_shape, transforms): 3. reshape to 3 dimensions """ - shape_mapping = {'input': input_shape, 'other': other_shape} + shape_mapping = {"input": input_shape, "other": other_shape} for transform in transforms: shape_mapping = transform.apply(shape_mapping) - input_shape = shape_mapping.get('input', None) - other_shape = shape_mapping.get('other', None) - output_shape = shape_mapping.get('output', None) + input_shape = shape_mapping.get("input", None) + other_shape = shape_mapping.get("other", None) + output_shape = shape_mapping.get("output", None) return input_shape, other_shape, output_shape @@ -364,7 +366,8 @@ def get_strategy_generator(self) -> List[StrategyGenerator]: generators.append(MatVecStrategyGenerator(op_data_mapping, self.device_mesh)) elif self.matmul_type == MatMulType.MM: generators.append( - LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear')) + LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type="linear") + ) return generators def get_operation_data_mapping(self) -> Dict[str, OperationData]: @@ -372,7 +375,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: MatMulType.DOT: self._get_logical_shape_for_dot, MatMulType.MM: self._get_logical_shape_for_mm, MatMulType.MV: self._get_logical_shape_for_mv, - MatMulType.BMM: self._get_logical_shape_for_bmm + MatMulType.BMM: self._get_logical_shape_for_bmm, } logical_shapes = logical_shape_func[self.matmul_type]() op_data_mapping = self._get_op_data_mapping(*logical_shapes) @@ -390,20 +393,26 @@ def _get_op_data_mapping(self, input_logical_shape, other_logical_shape, output_ output_logical_shape = torch.Size(output_logical_shape) # create op data - input_op_data = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.input_meta_data, - logical_shape=input_logical_shape) - other_op_data = OperationData(name=str(self.node.args[1]), - type=OperationDataType.ARG, - 
data=self.other_meta_data, - logical_shape=other_logical_shape) - output_op_data = OperationData(name=str(self.node), - type=OperationDataType.OUTPUT, - data=self.output_meta_data, - logical_shape=output_logical_shape) - - mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data} + input_op_data = OperationData( + name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.input_meta_data, + logical_shape=input_logical_shape, + ) + other_op_data = OperationData( + name=str(self.node.args[1]), + type=OperationDataType.ARG, + data=self.other_meta_data, + logical_shape=other_logical_shape, + ) + output_op_data = OperationData( + name=str(self.node), + type=OperationDataType.OUTPUT, + data=self.output_meta_data, + logical_shape=output_logical_shape, + ) + + mapping = {"input": input_op_data, "other": other_op_data, "output": output_op_data} return mapping def _get_logical_shape_for_dot(self): @@ -460,9 +469,11 @@ def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, Li dim_partition_dict[0] = shard # re-init the sharding spec - input_sharding_spec.__init__(input_sharding_spec.device_mesh, - entire_shape=input_physical_shape, - dim_partition_dict=dim_partition_dict) + input_sharding_spec.__init__( + input_sharding_spec.device_mesh, + entire_shape=input_physical_shape, + dim_partition_dict=dim_partition_dict, + ) return strategy else: return strategy @@ -481,7 +492,8 @@ def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, Li recovered_stragies.extend(output) else: raise TypeError( - f"Found unexpected output type {type(output)} from the recover method of BmmTransform") + f"Found unexpected output type {type(output)} from the recover method of BmmTransform" + ) strategies = recovered_stragies for index, strategies in enumerate(strategies): strategies.name = f"{strategies.name}_{index}" diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py index b4b7b0e794d1..d2bad39dcbb9 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py @@ -8,7 +8,6 @@ from colossalai.auto_parallel.tensor_shard.options import ShardOption, SolverPerference from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( OperationData, - OperationDataType, ShardingSpec, ShardingStrategy, StrategiesVector, @@ -23,21 +22,23 @@ class NodeHandler(ABC): - ''' + """ The NodeHandler is an abstract class used to generate every possible strategies for an operator node. Args: node (Node): the input node in node argument list. device_mesh (DeviceMesh): A logical view of a physical mesh. strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector. 
- ''' - - def __init__(self, - node: Node, - device_mesh: DeviceMesh, - strategies_vector: StrategiesVector, - shard_option: ShardOption = ShardOption.STANDARD, - solver_perference: SolverPerference = SolverPerference.STANDARD) -> None: + """ + + def __init__( + self, + node: Node, + device_mesh: DeviceMesh, + strategies_vector: StrategiesVector, + shard_option: ShardOption = ShardOption.STANDARD, + solver_perference: SolverPerference = SolverPerference.STANDARD, + ) -> None: self.node = node self.predecessor_node = list(node._input_nodes.keys()) self.successor_node = list(node.users.keys()) @@ -68,8 +69,9 @@ def update_resharding_cost(self, strategy: ShardingStrategy) -> None: current_sharding_spec = strategy.sharding_specs[op_data] # get the sharding specs for this node generated # in its own node handler - assert hasattr(node, 'strategies_vector'), \ - f'The predecessor node {node_name} has no strategy vector to compute the resharding cost.' + assert hasattr( + node, "strategies_vector" + ), f"The predecessor node {node_name} has no strategy vector to compute the resharding cost." prev_strategy_vector = node.strategies_vector prev_sharding_specs = [ prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in prev_strategy_vector @@ -80,10 +82,10 @@ def update_resharding_cost(self, strategy: ShardingStrategy) -> None: resharding_costs[node] = [] def _compute_resharding_cost( - prev_sharding_spec: Union[ShardingSpec, - List[ShardingSpec]], current_sharding_spec: Union[ShardingSpec, - List[ShardingSpec]], - data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]) -> TrainCycleItem: + prev_sharding_spec: Union[ShardingSpec, List[ShardingSpec]], + current_sharding_spec: Union[ShardingSpec, List[ShardingSpec]], + data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]], + ) -> TrainCycleItem: """ This is a helper function to compute the resharding cost for a specific strategy of a node. """ @@ -94,30 +96,35 @@ def _compute_resharding_cost( dtype = data.dtype size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() _, _, consistency_cost = shape_consistency_manager.shape_consistency( - prev_sharding_spec, current_sharding_spec) - - resharding_cost = TrainCycleItem(fwd=consistency_cost["forward"] * size_per_elem_bytes, - bwd=consistency_cost["backward"] * size_per_elem_bytes, - total=consistency_cost["total"] * size_per_elem_bytes) + prev_sharding_spec, current_sharding_spec + ) + + resharding_cost = TrainCycleItem( + fwd=consistency_cost["forward"] * size_per_elem_bytes, + bwd=consistency_cost["backward"] * size_per_elem_bytes, + total=consistency_cost["total"] * size_per_elem_bytes, + ) return resharding_cost else: # This raise is used to check if we have missed any type of data. # It could be merged into Parameter branch, which means we won't handle # non-tensor arguments. 
- raise ValueError(f'Unsupported data type {type(data)}') + raise ValueError(f"Unsupported data type {type(data)}") else: - assert isinstance(prev_sharding_spec, (tuple, list)), \ - f'prev_sharding_spec should be in type of ShardingSpec, List[ShardingSpec], \ - or Tuple[ShardingSpec], but got {type(prev_sharding_spec)}' + assert isinstance( + prev_sharding_spec, (tuple, list) + ), f"prev_sharding_spec should be in type of ShardingSpec, List[ShardingSpec], \ + or Tuple[ShardingSpec], but got {type(prev_sharding_spec)}" fwd_cost = 0 bwd_cost = 0 total_cost = 0 - for index, (prev_sharding_spec_item, - current_sharding_spec_item) in enumerate(zip(prev_sharding_spec, - current_sharding_spec)): - item_cost = _compute_resharding_cost(prev_sharding_spec_item, current_sharding_spec_item, - data[index]) + for index, (prev_sharding_spec_item, current_sharding_spec_item) in enumerate( + zip(prev_sharding_spec, current_sharding_spec) + ): + item_cost = _compute_resharding_cost( + prev_sharding_spec_item, current_sharding_spec_item, data[index] + ) fwd_cost += item_cost.fwd bwd_cost += item_cost.bwd total_cost += item_cost.total @@ -138,17 +145,17 @@ def get_target_function(self) -> callable: This function is used to get the target function for the node handler. The target function is used to analyze the costs of strategies. """ - if self.node.op in ('placeholder', 'get_attr', 'output'): + if self.node.op in ("placeholder", "get_attr", "output"): return None - if self.node.op == 'call_module': + if self.node.op == "call_module": target = self.node.graph.owning_module.get_submodule(self.node.target) - elif self.node.op == 'call_function': + elif self.node.op == "call_function": target = self.node.target - elif self.node.op == 'call_method': + elif self.node.op == "call_method": target = getattr(self.node.args[0]._meta_data.__class__, self.node.target) else: - raise ValueError(f'Unsupported node type: {self.node.op}') + raise ValueError(f"Unsupported node type: {self.node.op}") return target @@ -221,7 +228,6 @@ def get_strategy_generator(self) -> List[StrategyGenerator]: """ Define which generators should be used by this NodeHandler object. """ - pass @abstractmethod def get_operation_data_mapping(self) -> Dict[str, OperationData]: @@ -244,7 +250,6 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: "output": Operand(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data), } """ - pass class MetaInfoNodeHandler(NodeHandler): @@ -278,19 +283,19 @@ def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesV else: logger = get_dist_logger() - logger.warning(f'The target function {target} is not patched yet, ') + logger.warning(f"The target function {target} is not patched yet, ") return self.strategies_vector class ModuleHandler(NodeHandler): - def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) # set attributes to access module parameters for convenience - assert self.node.graph.owning_module is not None, \ - f'The graph is not associated with a module, please make sure it can be used to instantiate a GraphModule object.' + assert ( + self.node.graph.owning_module is not None + ), f"The graph is not associated with a module, please make sure it can be used to instantiate a GraphModule object." 
module = self.node.graph.owning_module.get_submodule(self.node.target) named_parameters = list(module.named_parameters(recurse=False)) named_buffers = list(module.named_buffers(recurse=False)) @@ -333,6 +338,6 @@ def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesV else: logger = get_dist_logger() - logger.warning(f'The target function {target} is not patched yet') + logger.warning(f"The target function {target} is not patched yet") return self.strategies_vector diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py index 4e71ccba95a7..facf19560596 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py @@ -3,11 +3,11 @@ import torch from ..sharding_strategy import OperationData, OperationDataType -from .node_handler import MetaInfoModuleHandler, ModuleHandler +from .node_handler import MetaInfoModuleHandler from .registry import operator_registry from .strategy import NormalPoolStrategyGenerator, StrategyGenerator -__all__ = ['NormPoolingHandler'] +__all__ = ["NormPoolingHandler"] @operator_registry.register(torch.nn.MaxPool1d) @@ -30,9 +30,9 @@ def get_strategy_generator(self) -> List[StrategyGenerator]: def get_operation_data_mapping(self) -> Dict[str, OperationData]: # use transposed shape for strategies # the strategies will be transformed back to its original shape in self.post_process - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data) + physical_input_operand = OperationData( + name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data + ) physical_weight_operand = OperationData(name="kernel", type=OperationDataType.ARG, data=self.module.kernel_size) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py index ed120a8c3d6d..89906a205e87 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py @@ -8,7 +8,7 @@ from .node_handler import NodeHandler from .strategy import OutputGenerator, StrategyGenerator -__all__ = ['OutputHandler'] +__all__ = ["OutputHandler"] class OutputHandler(NodeHandler): @@ -16,8 +16,9 @@ class OutputHandler(NodeHandler): A OutputHandler which deals with the sharding strategies for Output Node. 
""" - def __init__(self, node: torch.fx.node.Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, - output_option: str) -> None: + def __init__( + self, node: torch.fx.node.Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, output_option: str + ) -> None: super().__init__(node, device_mesh, strategies_vector) self.output_option = output_option @@ -35,11 +36,11 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: for index, input_node in enumerate(self.predecessor_node): input_meta_data = input_node._meta_data physical_inputs = OperationData(name=str(input_node), type=OperationDataType.ARG, data=input_meta_data) - name_key = f'input_{index}' + name_key = f"input_{index}" mapping[name_key] = physical_inputs output_meta_data.append(input_meta_data) - assert len(output_meta_data) > 0, f'Output node {self.node} has no input node.' + assert len(output_meta_data) > 0, f"Output node {self.node} has no input node." if len(output_meta_data) == 1: output_meta_data = output_meta_data[0] else: diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py index 91e4a5105a08..75f07168e47b 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py @@ -7,7 +7,7 @@ from .registry import operator_registry from .strategy import PermuteGenerator, StrategyGenerator -__all__ = ['PermuteHandler'] +__all__ = ["PermuteHandler"] @operator_registry.register(torch.Tensor.permute) @@ -34,14 +34,14 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) permute_dims = [] - if self.node.op == 'call_method': + if self.node.op == "call_method": # torch.Tensor.permute (input, *dims) for arg in self.node.args: if isinstance(arg, torch.fx.Node): if isinstance(arg._meta_data, int): permute_dims.append(arg._meta_data) else: - assert isinstance(arg, int), 'The argument in permute node should be either type of Node or int.' + assert isinstance(arg, int), "The argument in permute node should be either type of Node or int." permute_dims.append(arg) else: # torch.permute (input, dims) @@ -51,8 +51,8 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: permute_dims.extend(arg._meta_data) else: assert isinstance( - arg, - (tuple, list)), 'The argument in permute node should be type of Node, Tuple[int] or List[int].' + arg, (tuple, list) + ), "The argument in permute node should be type of Node, Tuple[int] or List[int]." 
permute_dims.extend(arg) num_dims = self.node._meta_data.dim() @@ -61,7 +61,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: if permute_dims[i] < 0: permute_dims[i] += num_dims - physical_shape_operand = OperationData(name='permute_dims', type=OperationDataType.ARG, data=list(permute_dims)) + physical_shape_operand = OperationData(name="permute_dims", type=OperationDataType.ARG, data=list(permute_dims)) output_data = self.node._meta_data physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) @@ -69,7 +69,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: mapping = { "input": physical_input_operand, "permute_dims": physical_shape_operand, - "output": physical_output_operand + "output": physical_output_operand, } return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py index e4f40fc935a4..461bc2935780 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py @@ -8,7 +8,7 @@ from .node_handler import NodeHandler from .strategy import PlaceholderGenerator, StrategyGenerator -__all__ = ['PlaceholderHandler'] +__all__ = ["PlaceholderHandler"] class PlaceholderHandler(NodeHandler): @@ -16,8 +16,9 @@ class PlaceholderHandler(NodeHandler): A PlaceholderHandler which deals with the sharding strategies for Placeholder Node. """ - def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, - placeholder_option: str) -> None: + def __init__( + self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, placeholder_option: str + ) -> None: super().__init__(node, device_mesh, strategies_vector) self.placeholder_option = placeholder_option @@ -25,7 +26,8 @@ def get_strategy_generator(self) -> List[StrategyGenerator]: op_data_mapping = self.get_operation_data_mapping() generators = [] generators.append( - PlaceholderGenerator(op_data_mapping, self.device_mesh, placeholder_option=self.placeholder_option)) + PlaceholderGenerator(op_data_mapping, self.device_mesh, placeholder_option=self.placeholder_option) + ) return generators def get_operation_data_mapping(self) -> Dict[str, OperationData]: diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/registry.py b/colossalai/auto_parallel/tensor_shard/node_handler/registry.py index 730a90d74cf8..f663fc9695d3 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/registry.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/registry.py @@ -1,11 +1,9 @@ class Registry: - def __init__(self, name): self.name = name self.store = {} def register(self, source): - def wrapper(func): if isinstance(source, (list, tuple)): # support register a list of items for this func @@ -18,7 +16,7 @@ def wrapper(func): return wrapper def get(self, source): - assert source in self.store, f'{source} not found in the {self.name} registry' + assert source in self.store, f"{source} not found in the {self.name} registry" target = self.store[source] return target @@ -26,4 +24,4 @@ def has(self, source): return source in self.store -operator_registry = Registry('operator') +operator_registry = Registry("operator") diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py index 
743a1f90eaaf..6e883ea64736 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py @@ -7,7 +7,7 @@ from .registry import operator_registry from .strategy import SoftmaxGenerator, StrategyGenerator -__all__ = ['SoftmaxHandler'] +__all__ = ["SoftmaxHandler"] @operator_registry.register(torch.nn.Softmax) @@ -34,14 +34,14 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: input_data = self.node.args[0]._meta_data physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) - softmax_dim = self.node.kwargs['dim'] + softmax_dim = self.node.kwargs["dim"] num_dims = self.node.args[0]._meta_data.dim() # recover negative value to positive if softmax_dim < 0: softmax_dim += num_dims - physical_dim_operand = OperationData(name='softmax_dim', type=OperationDataType.ARG, data=softmax_dim) + physical_dim_operand = OperationData(name="softmax_dim", type=OperationDataType.ARG, data=softmax_dim) output_data = self.node._meta_data physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) @@ -49,7 +49,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: mapping = { "input": physical_input_operand, "softmax_dim": physical_dim_operand, - "output": physical_output_operand + "output": physical_output_operand, } return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py index 653d158b7c36..4c32529a5d5b 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py @@ -7,7 +7,7 @@ from .registry import operator_registry from .strategy import SplitGenerator, StrategyGenerator -__all__ = ['SplitHandler'] +__all__ = ["SplitHandler"] @operator_registry.register(torch.Tensor.split) @@ -38,7 +38,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: split_dim = self.node.args[2] else: if self.node.kwargs: - split_dim = self.node.kwargs['dim'] + split_dim = self.node.kwargs["dim"] else: split_dim = 0 @@ -48,7 +48,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: split_dim += num_dims split_info = (split_size, split_dim) - physical_shape_operand = OperationData(name='split_info', type=OperationDataType.ARG, data=split_info) + physical_shape_operand = OperationData(name="split_info", type=OperationDataType.ARG, data=split_info) output_data = self.node._meta_data physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) @@ -56,7 +56,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: mapping = { "input": physical_input_operand, "split_info": physical_shape_operand, - "output": physical_output_operand + "output": physical_output_operand, } return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py index db1f31521c86..1fc7f613716b 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py @@ -29,11 +29,31 @@ from .where_generator import WhereGenerator __all__ = [ - 'StrategyGenerator', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator', 'LinearProjectionStrategyGenerator', - 
'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator', 'UnaryElementwiseGenerator', - 'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator', - 'LayerNormGenerator', 'PlaceholderGenerator', 'OutputGenerator', 'WhereGenerator', 'NormalPoolStrategyGenerator', - 'BinaryElementwiseStrategyGenerator', 'GetattrGenerator', 'TensorConstructorGenerator', - 'EmbeddingStrategyGenerator', 'SumGenerator', 'SoftmaxGenerator', 'ViewGenerator', 'PermuteGenerator', - 'TransposeGenerator', 'SplitGenerator', 'DefaultReshapeGenerator' + "StrategyGenerator", + "DotProductStrategyGenerator", + "MatVecStrategyGenerator", + "LinearProjectionStrategyGenerator", + "BatchedMatMulStrategyGenerator", + "ConvStrategyGenerator", + "UnaryElementwiseGenerator", + "BatchNormStrategyGenerator", + "GetItemStrategyGenerator", + "TensorStrategyGenerator", + "TensorTupleStrategyGenerator", + "LayerNormGenerator", + "PlaceholderGenerator", + "OutputGenerator", + "WhereGenerator", + "NormalPoolStrategyGenerator", + "BinaryElementwiseStrategyGenerator", + "GetattrGenerator", + "TensorConstructorGenerator", + "EmbeddingStrategyGenerator", + "SumGenerator", + "SoftmaxGenerator", + "ViewGenerator", + "PermuteGenerator", + "TransposeGenerator", + "SplitGenerator", + "DefaultReshapeGenerator", ] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py index 416dc9c29cad..9c766b1014c8 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py @@ -14,7 +14,7 @@ from .strategy_generator import StrategyGenerator -__all__ = ['BatchNormStrategyGenerator'] +__all__ = ["BatchNormStrategyGenerator"] class BatchNormStrategyGenerator(StrategyGenerator): @@ -30,28 +30,31 @@ class BatchNormStrategyGenerator(StrategyGenerator): """ def validate(self) -> bool: - ''' + """ In sanity check, we need make sure the input data having correct dimension size. For BatchNorm1d, the dim of input data should be 3([N, C, L]). For BatchNorm2d, the dim of input data should be 4([N, C, H, W]). For BatchNorm3d, the dim of input data should be 5([N, C, H, W, D]). - ''' - input_op_data = self.op_data['input'] + """ + input_op_data = self.op_data["input"] assert input_op_data.data.dim() in ( - 3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].' + 3, + 4, + 5, + ), f"We suppose the dim of input fed into conv op should in range of [3, 5]." def update_compute_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the computation cost per device with this specific strategy. Note: compute_cost need to be divided by TFLOPS, now it just shows the computation size. - ''' + """ # TODO: a constant coefficient need to be added. # 1D: (L) * N * Cin # 2D: (H * W) * N * Cin # 3D: (H * W * D) * N * Cin - sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() - sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() + sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() + sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device() if self.has_bias: # bias add is an element wise operation, so the cost is equal to product of output shape. 
bias_compute_cost = reduce(operator.mul, sharded_output_shape) @@ -69,23 +72,24 @@ def update_compute_cost(self, strategy: ShardingStrategy): def update_memory_cost(self, strategy: ShardingStrategy): forward_size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'other': self._compute_size_in_bytes(strategy, "other"), - 'output': self._compute_size_in_bytes(strategy, "output"), - 'running_mean': self._compute_size_in_bytes(strategy, "running_mean"), - 'running_var': self._compute_size_in_bytes(strategy, "running_var"), + "input": self._compute_size_in_bytes(strategy, "input"), + "other": self._compute_size_in_bytes(strategy, "other"), + "output": self._compute_size_in_bytes(strategy, "output"), + "running_mean": self._compute_size_in_bytes(strategy, "running_mean"), + "running_var": self._compute_size_in_bytes(strategy, "running_var"), } if self.has_bias: bias_size = self._compute_size_in_bytes(strategy, "bias") - forward_size_mapping['bias'] = bias_size + forward_size_mapping["bias"] = bias_size backward_size_mapping = copy.deepcopy(forward_size_mapping) backward_size_mapping.pop("output") # compute fwd cost incurred # fwd_cost = input + other + bias + output fwd_activation_cost = sum( - [v for k, v in forward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)]) + [v for k, v in forward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)] + ) fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) fwd_buffer_cost = sum([v for k, v in forward_size_mapping.items() if self.is_buffer(k)]) fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost, buffer=fwd_buffer_cost) @@ -93,36 +97,29 @@ def update_memory_cost(self, strategy: ShardingStrategy): # compute bwd cost incurred # bwd_cost = input_grad + other_grad + bias_grad bwd_activation_cost = sum( - [v for k, v in backward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)]) + [v for k, v in backward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)] + ) bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) # compute total cost - total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, - parameter=fwd_parameter_cost + bwd_parameter_cost, - buffer=fwd_buffer_cost) + total_mem_cost = MemoryCost( + activation=fwd_activation_cost + bwd_activation_cost, + parameter=fwd_parameter_cost + bwd_parameter_cost, + buffer=fwd_buffer_cost, + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost @ignore_sharding_exception def split_input_channel(self, mesh_dim_0): - name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}' + name = f"RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}" dim_partition_dict_mapping = { - "input": { - 1: [mesh_dim_0] - }, - "other": { - 0: [mesh_dim_0] - }, - "output": { - 1: [mesh_dim_0] - }, - "running_mean": { - 0: [mesh_dim_0] - }, - "running_var": { - 0: [mesh_dim_0] - }, + "input": {1: [mesh_dim_0]}, + "other": {0: [mesh_dim_0]}, + "output": {1: [mesh_dim_0]}, + "running_mean": {0: [mesh_dim_0]}, + "running_var": {0: [mesh_dim_0]}, "num_batches_tracked": {}, } if self.has_bias: @@ -132,29 +129,21 @@ def split_input_channel(self, mesh_dim_0): communication_action_mapping = {} - return self.get_sharding_strategy(name=name, - 
sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1): - name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}' + name = f"RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}" dim_partition_dict_mapping = { - "input": { - 1: [mesh_dim_0, mesh_dim_1] - }, - "other": { - 0: [mesh_dim_0, mesh_dim_1] - }, - "output": { - 1: [mesh_dim_0, mesh_dim_1] - }, - "running_mean": { - 0: [mesh_dim_0, mesh_dim_1] - }, - "running_var": { - 0: [mesh_dim_0, mesh_dim_1] - }, + "input": {1: [mesh_dim_0, mesh_dim_1]}, + "other": {0: [mesh_dim_0, mesh_dim_1]}, + "output": {1: [mesh_dim_0, mesh_dim_1]}, + "running_mean": {0: [mesh_dim_0, mesh_dim_1]}, + "running_var": {0: [mesh_dim_0, mesh_dim_1]}, "num_batches_tracked": {}, } if self.has_bias: @@ -164,13 +153,15 @@ def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1): communication_action_mapping = {} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def non_split(self): - name = f'RR = RR x R' + name = f"RR = RR x R" dim_partition_dict_mapping = { "input": {}, "other": {}, @@ -186,21 +177,19 @@ def non_split(self): communication_action_mapping = {} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_input_batch(self, mesh_dim_0): - name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN' + name = f"S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN" dim_partition_dict_mapping = { - "input": { - 0: [mesh_dim_0] - }, + "input": {0: [mesh_dim_0]}, "other": {}, - "output": { - 0: [mesh_dim_0] - }, + "output": {0: [mesh_dim_0]}, "running_mean": {}, "running_var": {}, "num_batches_tracked": {}, @@ -218,27 +207,26 @@ def split_input_batch(self, mesh_dim_0): sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.IMPLICIT) + comm_type=CommType.IMPLICIT, + ) # TODO: Temporary solution has no communication cost, # above action should be added after the SyncBN replace pass completed. 
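            # Reading a strategy name such as "S0R = S0R x R WITH SYNC_BN" (describing the
            # convention as it appears across these generators): the left-hand side is the
            # output layout and the right-hand side lists the operand layouts, one letter per
            # tensor dimension, where S{axis} means sharded over that device-mesh axis and R
            # means replicated. The dim_partition_dict carries the same information explicitly,
            # e.g. {"input": {0: [mesh_dim_0]}} shards only dim 0 of the input over mesh axis 0,
            # and any dimension left out of the dict stays replicated.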
communication_action_mapping = {} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1): - name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN' + name = f"S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN" dim_partition_dict_mapping = { - "input": { - 0: [mesh_dim_0, mesh_dim_1] - }, + "input": {0: [mesh_dim_0, mesh_dim_1]}, "other": {}, - "output": { - 0: [mesh_dim_0, mesh_dim_1] - }, + "output": {0: [mesh_dim_0, mesh_dim_1]}, "running_mean": {}, "running_var": {}, "num_batches_tracked": {}, @@ -256,19 +244,22 @@ def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1): sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], - comm_type=CommType.IMPLICIT) + comm_type=CommType.IMPLICIT, + ) # TODO: Temporary solution has no communication cost, # above action should be added after the SyncBN replace pass completed. communication_action_mapping = {} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_input_both_dim(self, mesh_dim_0, mesh_dim_1): - name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN' + name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN" dim_partition_dict_mapping = { "input": { 0: [mesh_dim_0], @@ -304,20 +295,23 @@ def split_input_both_dim(self, mesh_dim_0, mesh_dim_1): sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=[mesh_dim_0], - comm_type=CommType.IMPLICIT) + comm_type=CommType.IMPLICIT, + ) # TODO: Temporary solution has no communication cost, # above action should be added after the SyncBN replace pass completed. communication_action_mapping = {} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) def collate_strategies(self) -> List[ShardingStrategy]: - ''' + """ Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector. 
- ''' + """ strategy_list = [] # RS = RS x S diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py index d27cc046eaf3..c7da0034ec3b 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py @@ -14,7 +14,7 @@ from .strategy_generator import StrategyGenerator -__all__ = ['BinaryElementwiseStrategyGenerator'] +__all__ = ["BinaryElementwiseStrategyGenerator"] class BinaryElementwiseStrategyGenerator(StrategyGenerator): @@ -26,36 +26,37 @@ class BinaryElementwiseStrategyGenerator(StrategyGenerator): """ def validate(self) -> bool: - assert len(self.op_data) == 3, \ - f'BinaryElementwiseStrategyGenerator only accepts three operation data (input, other and output), but got {len(self.op_data)}' + assert ( + len(self.op_data) == 3 + ), f"BinaryElementwiseStrategyGenerator only accepts three operation data (input, other and output), but got {len(self.op_data)}" for name, op_data in self.op_data.items(): if not isinstance(op_data.data, (torch.Tensor, int, float)): - raise TypeError(f'The operation data {name} is not a torch.Tensor/int/float.') + raise TypeError(f"The operation data {name} is not a torch.Tensor/int/float.") def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: - shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() # since elementwise ops are not compute-intensive, # we approximate the backward compute cost # to be twice the fwd compute cost fwd_compute_cost = reduce(operator.mul, shape) bwd_compute_cost = fwd_compute_cost * 2 - compute_cost = TrainCycleItem(fwd=fwd_compute_cost, - bwd=bwd_compute_cost, - total=fwd_compute_cost + bwd_compute_cost) + compute_cost = TrainCycleItem( + fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost + ) strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: # all input, output and outputs have the same shape - shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() # compute fwd memory cost in bytes # as the elementwise ops are not memory-intensive # we approximate the fwd memory cost to be the output # and the backward memory cost to be grad of input and other - input_bytes = self._compute_size_in_bytes(strategy, 'input') - other_bytes = self._compute_size_in_bytes(strategy, 'other') - output_bytes = self._compute_size_in_bytes(strategy, 'output') + input_bytes = self._compute_size_in_bytes(strategy, "input") + other_bytes = self._compute_size_in_bytes(strategy, "other") + output_bytes = self._compute_size_in_bytes(strategy, "output") fwd_memory_cost = MemoryCost(activation=output_bytes) bwd_memory_cost = MemoryCost(activation=input_bytes + other_bytes) total_memory_cost = MemoryCost(activation=input_bytes + other_bytes + output_bytes) @@ -66,7 +67,7 @@ def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1): # we check for the output logical shape to get the number of dimensions dim_partition_list = [] - dim_size = 
len(self.op_data['output'].logical_shape) + dim_size = len(self.op_data["output"].logical_shape) # enumerate all the 2D sharding cases sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size) @@ -86,21 +87,22 @@ def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1): # convert these dim partition dict to sharding strategy for dim_partition_dict in dim_partition_list: - dim_partition_dict_mapping = dict(input=dim_partition_dict, - other=dim_partition_dict, - output=dim_partition_dict) + dim_partition_dict_mapping = dict( + input=dim_partition_dict, other=dim_partition_dict, output=dim_partition_dict + ) try: sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) communication_action_mapping = {} # get name - sharding_seq = sharding_spec_mapping['input'].sharding_sequence - name = f'{sharding_seq} = {sharding_seq} {sharding_seq}' + sharding_seq = sharding_spec_mapping["input"].sharding_sequence + name = f"{sharding_seq} = {sharding_seq} {sharding_seq}" sharding_strategy = self.get_sharding_strategy( name=name, sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(sharding_strategy) except ShardingSpecException: continue diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py index e605a68a326b..5208f61543bb 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py @@ -1,11 +1,9 @@ import copy import operator -import warnings from functools import reduce from typing import List from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - CommAction, CommType, MemoryCost, ShardingStrategy, @@ -24,29 +22,32 @@ class ConvStrategyGenerator(StrategyGenerator): """ def validate(self) -> bool: - ''' + """ In sanity check, we need make sure the input data having correct dimension size. For Conv1d, the dim of input data should be 3([N, C, L]). For Conv2d, the dim of input data should be 4([N, C, H, W]). For Conv3d, the dim of input data should be 5([N, C, H, W, D]). - ''' - input_op_data = self.op_data['input'] + """ + input_op_data = self.op_data["input"] assert input_op_data.data.dim() in ( - 3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].' + 3, + 4, + 5, + ), f"We suppose the dim of input fed into conv op should in range of [3, 5]." def update_compute_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the computation cost per device with this specific strategy. Note: compute_cost need to be divided by TFLOPS, now it just shows the computation size. - ''' + """ # TODO: compute_cost need to be divided by TFLOPS, now it just shows the computation size. 
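        # Illustrative magnitude check for the 2D formula listed just below, using made-up
        # sizes rather than anything from this repository: with N=8, Cin=64, Cout=128,
        # H=W=56 and a 3x3 kernel (kernel = 9), the forward term
        # (H * W) * N * Cout * Cin * kernel = 56*56 * 8 * 128 * 64 * 9 ≈ 1.85e9
        # multiply-accumulates per device, before sharding shrinks the per-device shapes.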
# 1D: (L) * N * Cout * Cin * kernel # 2D: (H * W) * N * Cout * Cin * kernel # 3D: (H * W * D) * N * Cout * Cin * kernel - sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() - sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device() - sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() + sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() + sharded_other_shape = strategy.sharding_specs[self.op_data["other"]].get_sharded_shape_per_device() + sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device() if self.has_bias: # bias add is an element wise operation, so the cost is equal to product of output shape. bias_compute_cost = reduce(operator.mul, sharded_output_shape) @@ -76,14 +77,14 @@ def update_compute_cost(self, strategy: ShardingStrategy): def update_memory_cost(self, strategy: ShardingStrategy): forward_size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'other': self._compute_size_in_bytes(strategy, "other"), - 'output': self._compute_size_in_bytes(strategy, "output") + "input": self._compute_size_in_bytes(strategy, "input"), + "other": self._compute_size_in_bytes(strategy, "other"), + "output": self._compute_size_in_bytes(strategy, "output"), } if self.has_bias: bias_size = self._compute_size_in_bytes(strategy, "bias") - forward_size_mapping['bias'] = bias_size + forward_size_mapping["bias"] = bias_size backward_size_mapping = copy.deepcopy(forward_size_mapping) backward_size_mapping.pop("output") @@ -100,26 +101,20 @@ def update_memory_cost(self, strategy: ShardingStrategy): bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) # compute total cost - total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, - parameter=fwd_parameter_cost + bwd_parameter_cost) + total_mem_cost = MemoryCost( + activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost @ignore_sharding_exception def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1): - name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' + name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}" dim_partition_dict_mapping = { - "input": { - 0: [mesh_dim_0] - }, - "other": { - 1: [mesh_dim_1] - }, - "output": { - 0: [mesh_dim_0], - 1: [mesh_dim_1] - }, + "input": {0: [mesh_dim_0]}, + "other": {1: [mesh_dim_1]}, + "output": {0: [mesh_dim_0], 1: [mesh_dim_1]}, } if self.has_bias: dim_partition_dict_mapping["bias"] = {0: [mesh_dim_1]} @@ -132,7 +127,8 @@ def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_1, comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) communication_action_mapping = {"input": input_comm_action} if self.is_param("other"): @@ -140,7 +136,8 @@ def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1): sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( @@ 
-148,38 +145,41 @@ def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - arg_index=1) + arg_index=1, + ) communication_action_mapping["other"] = other_comm_action if self.has_bias: - if self.is_param('bias'): + if self.is_param("bias"): bias_comm_action = self.get_communication_action( sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: bias_comm_action = self.get_communication_action( sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - key_for_kwarg='bias') + key_for_kwarg="bias", + ) communication_action_mapping["bias"] = bias_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_input_batch(self, mesh_dim_0): - name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR' + name = f"S{mesh_dim_0}R = S{mesh_dim_0}R x RR" dim_partition_dict_mapping = { - "input": { - 0: [mesh_dim_0] - }, + "input": {0: [mesh_dim_0]}, "other": {}, "output": { 0: [mesh_dim_0], @@ -196,7 +196,8 @@ def split_input_batch(self, mesh_dim_0): sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( @@ -204,42 +205,45 @@ def split_input_batch(self, mesh_dim_0): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - arg_index=1) + arg_index=1, + ) communication_action_mapping["other"] = other_comm_action if self.has_bias: - if self.is_param('bias'): + if self.is_param("bias"): bias_comm_action = self.get_communication_action( sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: bias_comm_action = self.get_communication_action( sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - key_for_kwarg='bias') + key_for_kwarg="bias", + ) communication_action_mapping["bias"] = bias_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1): - name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' + name = f"S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R" dim_partition_dict_mapping = { "input": { 0: [mesh_dim_0], 1: [mesh_dim_1], }, - "other": { - 0: [mesh_dim_1] - }, + "other": {0: [mesh_dim_1]}, "output": { 0: 
[mesh_dim_0], }, @@ -254,7 +258,8 @@ def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1): sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim_1, - comm_type=CommType.AFTER) + comm_type=CommType.AFTER, + ) communication_action_mapping = {"output": output_comm_action} @@ -263,7 +268,8 @@ def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1): sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( @@ -271,7 +277,8 @@ def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - arg_index=1) + arg_index=1, + ) communication_action_mapping["other"] = other_comm_action if self.has_bias: if self.is_param("bias"): @@ -279,23 +286,27 @@ def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1): sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: bias_comm_action = self.get_communication_action( sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - key_for_kwarg='bias') + key_for_kwarg="bias", + ) communication_action_mapping["bias"] = bias_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1): - name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' + name = f"RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}" dim_partition_dict_mapping = { "input": { @@ -322,23 +333,27 @@ def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1): sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.AFTER) + comm_type=CommType.AFTER, + ) input_comm_action = self.get_communication_action( sharding_spec_mapping["input"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_1, comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) communication_action_mapping = {"output": output_comm_action, "input": input_comm_action} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_input_in_channel_weight_in_channel(self, mesh_dim_0): - name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R' + name = f"RR = RS{mesh_dim_0} x S{mesh_dim_0}R" dim_partition_dict_mapping = { "input": { @@ -360,17 +375,20 @@ def 
split_input_in_channel_weight_in_channel(self, mesh_dim_0): sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.AFTER) + comm_type=CommType.AFTER, + ) communication_action_mapping = {"output": output_comm_action} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_weight_out_channel(self, mesh_dim_0): - name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}' + name = f"RS{mesh_dim_0} = RR x RS{mesh_dim_0}" dim_partition_dict_mapping = { "input": {}, @@ -395,17 +413,20 @@ def split_weight_out_channel(self, mesh_dim_0): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) communication_action_mapping = {"input": input_comm_action} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def non_split(self): - name = f'RR = RR x RR' + name = f"RR = RR x RR" dim_partition_dict_mapping = { "input": {}, @@ -418,13 +439,13 @@ def non_split(self): sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping={}) + return self.get_sharding_strategy( + name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping={} + ) @ignore_sharding_exception def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1): - name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' + name = f"S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR" dim_partition_dict_mapping = { "input": { @@ -447,14 +468,16 @@ def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1): sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], comm_type=CommType.BEFORE, - arg_index=1) + arg_index=1, + ) communication_action_mapping["other"] = other_comm_action @@ -464,23 +487,27 @@ def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1): sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: bias_comm_action = self.get_communication_action( sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], comm_type=CommType.BEFORE, - key_for_kwarg='bias') + key_for_kwarg="bias", + ) communication_action_mapping["bias"] = bias_comm_action - return 
self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1): - name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' + name = f"RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R" dim_partition_dict_mapping = { "input": { 1: [mesh_dim_0, mesh_dim_1], @@ -501,17 +528,20 @@ def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1): sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], - comm_type=CommType.AFTER) + comm_type=CommType.AFTER, + ) communication_action_mapping = {"output": output_comm_action} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_1d_parallel_on_out_channel(self, mesh_dim_0, mesh_dim_1): - name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}' + name = f"RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}" dim_partition_dict_mapping = { "input": {}, "other": { @@ -535,13 +565,16 @@ def split_1d_parallel_on_out_channel(self, mesh_dim_0, mesh_dim_1): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) communication_action_mapping = {"input": input_comm_action} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) def collate_strategies(self) -> List[ShardingStrategy]: strategies = [] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py index 82a04ab52e73..385a8886f231 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py @@ -1,11 +1,9 @@ import copy import operator -import warnings from functools import reduce from typing import List from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - CommAction, CommType, MemoryCost, ShardingStrategy, @@ -27,16 +25,16 @@ def validate(self) -> bool: return super().validate() def update_compute_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the computation cost per device with this specific strategy. Note: The computation cost for the embedding handler is estimated as dense computing now. It may not be accurate. 
- ''' + """ # TODO: estimate the embedding computation cost as sparse operation - sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() - sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device() - sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() + sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() + sharded_other_shape = strategy.sharding_specs[self.op_data["other"]].get_sharded_shape_per_device() + sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device() input_size_product = reduce(operator.mul, sharded_input_shape) other_size_product = reduce(operator.mul, sharded_other_shape) @@ -55,9 +53,9 @@ def update_compute_cost(self, strategy: ShardingStrategy): def update_memory_cost(self, strategy: ShardingStrategy): forward_size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'other': self._compute_size_in_bytes(strategy, "other"), - 'output': self._compute_size_in_bytes(strategy, "output") + "input": self._compute_size_in_bytes(strategy, "input"), + "other": self._compute_size_in_bytes(strategy, "other"), + "output": self._compute_size_in_bytes(strategy, "output"), } backward_size_mapping = copy.deepcopy(forward_size_mapping) @@ -75,14 +73,15 @@ def update_memory_cost(self, strategy: ShardingStrategy): bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) # compute total cost - total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, - parameter=fwd_parameter_cost + bwd_parameter_cost) + total_mem_cost = MemoryCost( + activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost @ignore_sharding_exception def non_split(self): - name = f'RR = R x RR' + name = f"RR = R x RR" dim_partition_dict_mapping = { "input": {}, @@ -92,18 +91,16 @@ def non_split(self): sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping={}) + return self.get_sharding_strategy( + name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping={} + ) @ignore_sharding_exception def split_input(self, mesh_dim_0): - name = f'S{mesh_dim_0}R = S{mesh_dim_0} x RR' + name = f"S{mesh_dim_0}R = S{mesh_dim_0} x RR" dim_partition_dict_mapping = { - "input": { - 0: [mesh_dim_0] - }, + "input": {0: [mesh_dim_0]}, "other": {}, "output": { 0: [mesh_dim_0], @@ -118,7 +115,8 @@ def split_input(self, mesh_dim_0): sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( @@ -126,17 +124,20 @@ def split_input(self, mesh_dim_0): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - arg_index=1) + arg_index=1, + ) communication_action_mapping["other"] = other_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + 
return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_input_and_embedding_dim(self, mesh_dim_0, mesh_dim_1): - name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0} x RS{mesh_dim_1}' + name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0} x RS{mesh_dim_1}" dim_partition_dict_mapping = { "input": { @@ -159,7 +160,8 @@ def split_input_and_embedding_dim(self, mesh_dim_0, mesh_dim_1): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_1, comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) communication_action_mapping = {"input": input_comm_action} if self.is_param("other"): @@ -167,7 +169,8 @@ def split_input_and_embedding_dim(self, mesh_dim_0, mesh_dim_1): sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( @@ -175,22 +178,23 @@ def split_input_and_embedding_dim(self, mesh_dim_0, mesh_dim_1): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - arg_index=1) + arg_index=1, + ) communication_action_mapping["other"] = other_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_1d_parallel_on_input(self, mesh_dim_0, mesh_dim_1): - name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1} x RR' + name = f"S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1} x RR" dim_partition_dict_mapping = { - "input": { - 0: [mesh_dim_0, mesh_dim_1] - }, + "input": {0: [mesh_dim_0, mesh_dim_1]}, "other": {}, "output": { 0: [mesh_dim_0, mesh_dim_1], @@ -207,7 +211,8 @@ def split_1d_parallel_on_input(self, mesh_dim_0, mesh_dim_1): sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( @@ -215,17 +220,20 @@ def split_1d_parallel_on_input(self, mesh_dim_0, mesh_dim_1): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], comm_type=CommType.BEFORE, - arg_index=1) + arg_index=1, + ) communication_action_mapping["other"] = other_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_embedding_dim(self, mesh_dim_0): - name = f'RS{mesh_dim_0} = R x RS{mesh_dim_0}' + name = f"RS{mesh_dim_0} = R x RS{mesh_dim_0}" dim_partition_dict_mapping = { "input": {}, @@ -245,17 +253,20 @@ def split_embedding_dim(self, mesh_dim_0): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - 
arg_index=0) + arg_index=0, + ) communication_action_mapping = {"input": input_comm_action} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_1d_parallel_on_embedding_dim(self, mesh_dim_0, mesh_dim_1): - name = f'RS{mesh_dim_0}{mesh_dim_1} = R x RS{mesh_dim_0}{mesh_dim_1}' + name = f"RS{mesh_dim_0}{mesh_dim_1} = R x RS{mesh_dim_0}{mesh_dim_1}" dim_partition_dict_mapping = { "input": {}, @@ -275,13 +286,16 @@ def split_1d_parallel_on_embedding_dim(self, mesh_dim_0, mesh_dim_1): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) communication_action_mapping = {"input": input_comm_action} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) def collate_strategies(self) -> List[ShardingStrategy]: strategies = [] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py index bbeb9a639c83..cc8d5771f28e 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py @@ -10,7 +10,7 @@ from .strategy_generator import StrategyGenerator -__all__ = ['GetattrGenerator'] +__all__ = ["GetattrGenerator"] class GetattrGenerator(StrategyGenerator): @@ -26,10 +26,10 @@ def update_compute_cost(self, strategy: ShardingStrategy): strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. 
- ''' - forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")} + """ + forward_size_mapping = {"output": self._compute_size_in_bytes(strategy, "output")} # compute fwd cost incurred # fwd_cost = output @@ -47,7 +47,7 @@ def update_memory_cost(self, strategy: ShardingStrategy): def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1): # we check for the output logical shape to get the number of dimensions dim_partition_list = [] - dim_size = len(self.op_data['output'].logical_shape) + dim_size = len(self.op_data["output"].logical_shape) # enumerate all the 2D sharding cases sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size) @@ -78,7 +78,8 @@ def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1): sharding_strategy = self.get_sharding_strategy( name=name, sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(sharding_strategy) except ShardingSpecException: continue diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py index 0aeb2e0d4079..6f01d9cc7f8e 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py @@ -1,19 +1,13 @@ import copy from typing import List -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - CommType, - MemoryCost, - ShardingStrategy, - TrainCycleItem, -) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem from colossalai.logging import get_dist_logger -from colossalai.tensor.shape_consistency import CollectiveCommPattern from colossalai.tensor.sharding_spec import ShardingSpecException from .strategy_generator import FollowingStrategyGenerator -__all__ = ['GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator'] +__all__ = ["GetItemStrategyGenerator", "TensorStrategyGenerator", "TensorTupleStrategyGenerator"] class GetItemStrategyGenerator(FollowingStrategyGenerator): @@ -35,12 +29,12 @@ def update_compute_cost(self, strategy: ShardingStrategy): strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. 
- ''' + """ forward_size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'output': self._compute_size_in_bytes(strategy, "output") + "input": self._compute_size_in_bytes(strategy, "input"), + "output": self._compute_size_in_bytes(strategy, "output"), } backward_size_mapping = copy.deepcopy(forward_size_mapping) @@ -58,27 +52,29 @@ def update_memory_cost(self, strategy: ShardingStrategy): bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) # compute total cost - total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, - parameter=fwd_parameter_cost + bwd_parameter_cost) + total_mem_cost = MemoryCost( + activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost class TensorStrategyGenerator(GetItemStrategyGenerator): - ''' + """ Deal with case 1 and 2. - ''' + """ def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] - getitem_index = self.op_data['index'].data + getitem_index = self.op_data["index"].data for index, strategy in enumerate(self.predecessor_node.strategies_vector): try: logger = get_dist_logger() dim_partition_dict_mapping = {} communication_action_mapping = {} dim_partition_dict_for_input = copy.deepcopy( - strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict) + strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict + ) int_index = False if isinstance(getitem_index, int): @@ -120,9 +116,11 @@ def collate_strategies(self) -> List[ShardingStrategy]: name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}_{index}' - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) except ShardingSpecException as e: logger.debug(e) continue @@ -137,9 +135,9 @@ def collate_strategies(self) -> List[ShardingStrategy]: class TensorTupleStrategyGenerator(GetItemStrategyGenerator): - ''' + """ Deal with case 3. 
- ''' + """ def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] @@ -158,13 +156,15 @@ def collate_strategies(self) -> List[ShardingStrategy]: sharding_spec_mapping["input"] = sharding_spec_for_input input_sharding_info = f"get the {index} element from (" for sharding_spec in sharding_spec_for_input: - input_sharding_info += f'{sharding_spec.sharding_sequence}, ' + input_sharding_info += f"{sharding_spec.sharding_sequence}, " input_sharding_info += ")" name = f'{sharding_spec_mapping["output"].sharding_sequence} = {input_sharding_info}_{strategy_index}' - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(strategy) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py index 65b173bbf65d..e5b7e6f25d4d 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py @@ -18,7 +18,7 @@ from .strategy_generator import StrategyGenerator -__all__ = ['LayerNormGenerator'] +__all__ = ["LayerNormGenerator"] class LayerNormGenerator(StrategyGenerator): @@ -31,21 +31,21 @@ def validate(self) -> bool: return super().validate() def update_compute_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the computation cost per device with this specific strategy. Note: compute_cost need to be divided by TFLOPS, now it just shows the computation size. - ''' + """ # TODO: compute_cost need to be divided by TFLOPS, now it just shows the computation size. # TODO: a constant coefficient need to be added. - sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() - sharded_weight_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device() + sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() + sharded_weight_shape = strategy.sharding_specs[self.op_data["other"]].get_sharded_shape_per_device() if self.has_bias: # bias add is an element wise operation, so the cost is equal to product of output shape. bias_compute_cost = reduce(operator.mul, sharded_weight_shape) # in LayerNorm context, batch dimensions mean all the dimensions do not join the normalization. - input_batch_shape = sharded_input_shape[:-len(sharded_weight_shape)] + input_batch_shape = sharded_input_shape[: -len(sharded_weight_shape)] input_batch_product = reduce(operator.mul, input_batch_shape, 1) norm_kernel_product = reduce(operator.mul, sharded_weight_shape, 1) forward_compute_cost = input_batch_product * norm_kernel_product @@ -62,18 +62,18 @@ def update_compute_cost(self, strategy: ShardingStrategy): strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. 
- ''' + """ forward_size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'other': self._compute_size_in_bytes(strategy, "other"), - 'output': self._compute_size_in_bytes(strategy, "output") + "input": self._compute_size_in_bytes(strategy, "input"), + "other": self._compute_size_in_bytes(strategy, "other"), + "output": self._compute_size_in_bytes(strategy, "output"), } if self.has_bias: bias_size = self._compute_size_in_bytes(strategy, "bias") - forward_size_mapping['bias'] = bias_size + forward_size_mapping["bias"] = bias_size backward_size_mapping = copy.deepcopy(forward_size_mapping) backward_size_mapping.pop("output") @@ -90,8 +90,9 @@ def update_memory_cost(self, strategy: ShardingStrategy): bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) # compute total cost - total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, - parameter=fwd_parameter_cost + bwd_parameter_cost) + total_mem_cost = MemoryCost( + activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost @@ -120,7 +121,8 @@ def _generate_strategy_with_dim_partition(self, dim_partition): sharding_spec=sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=total_mesh_dim_list, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) communication_action_mapping["other"] = other_comm_action if self.has_bias: @@ -128,12 +130,15 @@ def _generate_strategy_with_dim_partition(self, dim_partition): sharding_spec=sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=total_mesh_dim_list, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) communication_action_mapping["bias"] = bias_comm_action - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) return strategy @@ -155,7 +160,7 @@ def split_input_batch_both_mesh_dim(self, mesh_dim_0, mesh_dim_1, batch_dimensio @ignore_sharding_exception def non_split(self): - name = f'RR = RR x R' + name = f"RR = RR x R" dim_partition_dict_mapping = { "input": {}, "other": {}, @@ -168,14 +173,16 @@ def non_split(self): communication_action_mapping = {} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) def collate_strategies(self) -> List[ShardingStrategy]: - ''' + """ Generate every possible strategies for a LayerNorm node, and record all strategies into the strategies_vector. 
- ''' + """ strategy_list = [] input_data_dim = len(self.op_data["input"].logical_shape) weight_data_dim = len(self.op_data["other"].logical_shape) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py index aa1581b99e0f..fb182afb9175 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py @@ -1,5 +1,4 @@ import operator -from ast import arg from functools import reduce from typing import List @@ -24,14 +23,14 @@ class MatMulStrategyGenerator(StrategyGenerator): def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'other': self._compute_size_in_bytes(strategy, "other"), - 'output': self._compute_size_in_bytes(strategy, "output") + "input": self._compute_size_in_bytes(strategy, "input"), + "other": self._compute_size_in_bytes(strategy, "other"), + "output": self._compute_size_in_bytes(strategy, "output"), } if self.has_bias: bias_size = self._compute_size_in_bytes(strategy, "bias") - size_mapping['bias'] = bias_size + size_mapping["bias"] = bias_size # compute fwd cost incurred # fwd_cost = input + other + bias + output @@ -41,45 +40,47 @@ def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: # compute bwd cost incurred # bwd_cost = input_grad + bias_grad - bwd_activation_cost = sum([v for k, v in size_mapping.items() if k in ['input', 'other', 'bias']]) + bwd_activation_cost = sum([v for k, v in size_mapping.items() if k in ["input", "other", "bias"]]) bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=0) # compute total cost - total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, - parameter=fwd_parameter_cost + 0) + total_mem_cost = MemoryCost( + activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + 0 + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost class DotProductStrategyGenerator(MatMulStrategyGenerator): - def validate(self) -> bool: - input_op_data = self.op_data['input'] - other_op_data = self.op_data['other'] + input_op_data = self.op_data["input"] + other_op_data = self.op_data["other"] assert input_op_data.data.dim() == 1 and other_op_data.data.dim() == 1 def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: - sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() fwd_compute_cost = sharded_input_shape[0] bwd_compute_cost = fwd_compute_cost * 2 - compute_cost = TrainCycleItem(fwd=fwd_compute_cost, - bwd=bwd_compute_cost, - total=fwd_compute_cost + bwd_compute_cost) + compute_cost = TrainCycleItem( + fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost + ) return compute_cost @ignore_sharding_exception def no_split(self): - name = f'R = R dot R' - dim_partition_dict = {"input": {}, "other": {}, "output": {}, 'bias': {}} + name = f"R = R dot R" + dim_partition_dict = {"input": {}, "other": {}, "output": {}, "bias": {}} sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) communication_action_mapping = {} - 
return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_one_dim(self, mesh_dim): - name = f'R = S{mesh_dim} dot S{mesh_dim}' + name = f"R = S{mesh_dim} dot S{mesh_dim}" # get sharding spec dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "output": {}, "bias": {0: [mesh_dim]}} @@ -87,14 +88,17 @@ def split_one_dim(self, mesh_dim): # get communication action output_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['output'], + sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim, - comm_type=CommType.AFTER) + comm_type=CommType.AFTER, + ) communication_action_mapping = {"output": output_comm_action} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] @@ -112,19 +116,18 @@ def collate_strategies(self) -> List[ShardingStrategy]: class MatVecStrategyGenerator(MatMulStrategyGenerator): - def validate(self) -> bool: - input_op_data = self.op_data['input'] - other_op_data = self.op_data['other'] + input_op_data = self.op_data["input"] + other_op_data = self.op_data["other"] assert input_op_data.data.dim() == 2 and other_op_data.data.dim() == 1 def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: - sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() fwd_compute_cost = sharded_input_shape[0] bwd_compute_cost = fwd_compute_cost * 2 - compute_cost = TrainCycleItem(fwd=fwd_compute_cost, - bwd=bwd_compute_cost, - total=fwd_compute_cost + bwd_compute_cost) + compute_cost = TrainCycleItem( + fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost + ) return compute_cost @ignore_sharding_exception @@ -133,67 +136,69 @@ def no_split(self): dim_partition_dict = {"input": {}, "other": {}, "output": {}} if self.has_bias: - dim_partition_dict['bias'] = {} + dim_partition_dict["bias"] = {} sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping={}) + return self.get_sharding_strategy( + name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping={} + ) @ignore_sharding_exception def split_input_batch(self, mesh_dim): - name = f'S{mesh_dim}R = S{mesh_dim}R x R' + name = f"S{mesh_dim}R = S{mesh_dim}R x R" # get sharding spec dim_partition_dict = { - "input": { - 0: [mesh_dim] - }, + "input": {0: [mesh_dim]}, "other": {}, - "output": { - 0: [mesh_dim] - }, + "output": {0: [mesh_dim]}, } if self.has_bias: - dim_partition_dict['bias'] = {} + dim_partition_dict["bias"] = {} sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) # get communication action 
communication_action_mapping = {} - if self.is_param('other'): + if self.is_param("other"): other_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['other'], + sharding_spec=sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['other'], + sharding_spec=sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim, comm_type=CommType.BEFORE, - arg_index=1) - communication_action_mapping['other'] = other_comm_action + arg_index=1, + ) + communication_action_mapping["other"] = other_comm_action if self.has_bias: - if self.is_param('bias'): + if self.is_param("bias"): bias_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['bias'], + sharding_spec=sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: bias_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['bias'], + sharding_spec=sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim, comm_type=CommType.BEFORE, - arg_index=2) - communication_action_mapping['bias'] = bias_comm_action + arg_index=2, + ) + communication_action_mapping["bias"] = bias_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] @@ -209,12 +214,13 @@ def collate_strategies(self) -> List[ShardingStrategy]: class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): - - def __init__(self, - operation_data_mapping, - device_mesh, - linear_projection_type='linear', - solver_perference=SolverPerference.STANDARD): + def __init__( + self, + operation_data_mapping, + device_mesh, + linear_projection_type="linear", + solver_perference=SolverPerference.STANDARD, + ): super().__init__(operation_data_mapping, device_mesh) self.linear_projection_type = linear_projection_type self.solver_perference = solver_perference @@ -224,17 +230,17 @@ def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: # C: [M, N], A: [M, P], B: [P, N] # fwd cost = MNP (only count mul) # bwd: 2 x fwd_cost - sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() - sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device() + sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() + sharded_other_shape = strategy.sharding_specs[self.op_data["other"]].get_sharded_shape_per_device() dim_m_val = reduce(operator.mul, sharded_input_shape[:-1]) dim_n_val = sharded_other_shape[-1] dim_p_val = sharded_other_shape[0] fwd_compute_cost = dim_m_val * dim_n_val * dim_p_val bwd_compute_cost = fwd_compute_cost * 2 - compute_cost = TrainCycleItem(fwd=bwd_compute_cost, - bwd=bwd_compute_cost, - total=fwd_compute_cost 
+ bwd_compute_cost) + compute_cost = TrainCycleItem( + fwd=bwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost + ) strategy.compute_cost = compute_cost def dp_strategies(self) -> List[ShardingStrategy]: @@ -301,28 +307,21 @@ def collate_strategies(self) -> List[ShardingStrategy]: @ignore_sharding_exception def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1): # handle case SS = SR x RS - name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' + name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}" dim_partition_dict_mapping = { - "input": { - 0: [mesh_dim_0] - }, - "other": { - -1: [mesh_dim_1] - }, - "output": { - 0: [mesh_dim_0], - -1: [mesh_dim_1] - }, + "input": {0: [mesh_dim_0]}, + "other": {-1: [mesh_dim_1]}, + "output": {0: [mesh_dim_0], -1: [mesh_dim_1]}, } # linear bias only has one dimension, but addmm bias has same dimensions # as the output logically. - if self.linear_projection_type == 'linear': - dim_partition_dict_mapping['bias'] = {-1: [mesh_dim_1]} - elif self.linear_projection_type == 'addmm': - dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0], -1: [mesh_dim_1]} + if self.linear_projection_type == "linear": + dim_partition_dict_mapping["bias"] = {-1: [mesh_dim_1]} + elif self.linear_projection_type == "addmm": + dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0], -1: [mesh_dim_1]} else: - raise ('Unsupported linear projection type') + raise ("Unsupported linear projection type") sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) @@ -333,75 +332,75 @@ def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_1, comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) - if self.is_param('other'): + if self.is_param("other"): other_comm_action = self.get_communication_action( sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - arg_index=1) + arg_index=1, + ) - communication_action_mapping['input'] = input_comm_action - communication_action_mapping['other'] = other_comm_action + communication_action_mapping["input"] = input_comm_action + communication_action_mapping["other"] = other_comm_action # we only add allreduce comm action for linear bias, because # allreduce comm action for addmm bias will be considered in post processing - if self.has_bias and self.linear_projection_type == 'linear': - if self.is_param('bias'): + if self.has_bias and self.linear_projection_type == "linear": + if self.is_param("bias"): bias_comm_action = self.get_communication_action( sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: bias_comm_action = self.get_communication_action( sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - key_for_kwarg='bias') - communication_action_mapping['bias'] = bias_comm_action + key_for_kwarg="bias", + ) + 
communication_action_mapping["bias"] = bias_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): # handle the case SR = SS x SR - name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' + name = f"S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R" # get sharding spec mapping dim_partition_dict_mapping = { - "input": { - 0: [mesh_dim_0], - -1: [mesh_dim_1] - }, - "other": { - 0: [mesh_dim_1] - }, + "input": {0: [mesh_dim_0], -1: [mesh_dim_1]}, + "other": {0: [mesh_dim_1]}, "bias": {}, - "output": { - 0: [mesh_dim_0] - }, + "output": {0: [mesh_dim_0]}, } # linear bias only has one dimension, but addmm bias has same dimensions # as the output logically. - if self.linear_projection_type == 'linear': - dim_partition_dict_mapping['bias'] = {} - elif self.linear_projection_type == 'addmm': - dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0]} + if self.linear_projection_type == "linear": + dim_partition_dict_mapping["bias"] = {} + elif self.linear_projection_type == "addmm": + dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0]} else: - raise ('Unsupported linear projection type') + raise ("Unsupported linear projection type") sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) @@ -412,66 +411,64 @@ def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim_1, - comm_type=CommType.AFTER) + comm_type=CommType.AFTER, + ) - if self.is_param('other'): + if self.is_param("other"): other_comm_action = self.get_communication_action( sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - arg_index=1) + arg_index=1, + ) - communication_action_mapping['other'] = other_comm_action - communication_action_mapping['output'] = output_comm_action + communication_action_mapping["other"] = other_comm_action + communication_action_mapping["output"] = output_comm_action # we only add allreduce comm action for linear bias, because # allreduce comm action for addmm bias will be considered in post processing - if self.has_bias and self.linear_projection_type == 'linear': - if self.is_param('bias'): + if self.has_bias and self.linear_projection_type == "linear": + if self.is_param("bias"): bias_comm_action = self.get_communication_action( sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: bias_comm_action = self.get_communication_action( sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - key_for_kwarg='bias') - 
communication_action_mapping['bias'] = bias_comm_action + key_for_kwarg="bias", + ) + communication_action_mapping["bias"] = bias_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): - name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' + name = f"RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}" # get sharding specs dim_partition_dict_mapping = { - "input": { - -1: [mesh_dim_0] - }, - "other": { - 0: [mesh_dim_0], - -1: [mesh_dim_1] - }, - "bias": { - -1: [mesh_dim_1] - }, - "output": { - -1: [mesh_dim_1] - }, + "input": {-1: [mesh_dim_0]}, + "other": {0: [mesh_dim_0], -1: [mesh_dim_1]}, + "bias": {-1: [mesh_dim_1]}, + "output": {-1: [mesh_dim_1]}, } # We don't have to do anything special for bias here, because @@ -482,34 +479,34 @@ def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): # get communication actions communication_action_mapping = {} output_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['output'], + sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.AFTER) + comm_type=CommType.AFTER, + ) input_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['input'], + sharding_spec=sharding_spec_mapping["input"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_1, comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) communication_action_mapping["input"] = input_comm_action - communication_action_mapping['output'] = output_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + communication_action_mapping["output"] = output_comm_action + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def recompute_split_both_contract(self, mesh_dim): - name = f'RR = RS{mesh_dim} x S{mesh_dim}R' + name = f"RR = RS{mesh_dim} x S{mesh_dim}R" # get sharding spec dim_partition_dict_mapping = { - "input": { - -1: [mesh_dim] - }, - "other": { - 0: [mesh_dim] - }, + "input": {-1: [mesh_dim]}, + "other": {0: [mesh_dim]}, "bias": {}, "output": {}, } @@ -520,32 +517,29 @@ def recompute_split_both_contract(self, mesh_dim): # get communication action communication_action_mapping = {} output_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['output'], + sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim, - comm_type=CommType.AFTER) + comm_type=CommType.AFTER, + ) - communication_action_mapping['output'] = output_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + communication_action_mapping["output"] = output_comm_action + return self.get_sharding_strategy( + name=name, + 
sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_rhs_space_only(self, mesh_dim): - name = f'RS{mesh_dim} = RR x RS{mesh_dim}' + name = f"RS{mesh_dim} = RR x RS{mesh_dim}" # get sharding spec dim_partition_dict_mapping = { "input": {}, - "other": { - -1: [mesh_dim] - }, - "bias": { - -1: [mesh_dim] - }, - "output": { - -1: [mesh_dim] - }, + "other": {-1: [mesh_dim]}, + "bias": {-1: [mesh_dim]}, + "output": {-1: [mesh_dim]}, } # We don't have to do anything special for bias here, because # the bias is already the same sharding spec as the output. @@ -554,93 +548,94 @@ def split_rhs_space_only(self, mesh_dim): # get communication actions communication_action_mapping = {} input_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['input'], + sharding_spec=sharding_spec_mapping["input"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim, comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) - communication_action_mapping['input'] = input_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + communication_action_mapping["input"] = input_comm_action + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1): - name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' + name = f"S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR" # get sharding spec dim_partition_dict_mapping = { - "input": { - 0: [mesh_dim_0, mesh_dim_1] - }, + "input": {0: [mesh_dim_0, mesh_dim_1]}, "other": {}, "bias": {}, - "output": { - 0: [mesh_dim_0, mesh_dim_1] - }, + "output": {0: [mesh_dim_0, mesh_dim_1]}, } # linear bias only has one dimension, but addmm bias has same dimensions # as the output logically. 
- if self.linear_projection_type == 'linear': - dim_partition_dict_mapping['bias'] = {} - elif self.linear_projection_type == 'addmm': - dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0, mesh_dim_1]} + if self.linear_projection_type == "linear": + dim_partition_dict_mapping["bias"] = {} + elif self.linear_projection_type == "addmm": + dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0, mesh_dim_1]} else: - raise ('Unsupported linear projection type') + raise ("Unsupported linear projection type") sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) # get communication action communication_action_mapping = {} - if self.is_param('other'): + if self.is_param("other"): other_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['other'], + sharding_spec=sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['other'], + sharding_spec=sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], comm_type=CommType.BEFORE, - arg_index=1) - communication_action_mapping['other'] = other_comm_action + arg_index=1, + ) + communication_action_mapping["other"] = other_comm_action # we only add allreduce comm action for linear bias, because # allreduce comm action for addmm bias will be considered in post processing - if self.has_bias and self.linear_projection_type == 'linear': - if self.is_param('bias'): + if self.has_bias and self.linear_projection_type == "linear": + if self.is_param("bias"): bias_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['bias'], + sharding_spec=sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: bias_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['bias'], + sharding_spec=sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], comm_type=CommType.BEFORE, - key_for_kwarg='bias') - communication_action_mapping['bias'] = bias_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + key_for_kwarg="bias", + ) + communication_action_mapping["bias"] = bias_comm_action + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): - name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' + name = f"RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R" # get sharding spec dim_partition_dict_mapping = { - "input": { - -1: [mesh_dim_0, mesh_dim_1] - }, - "other": { - 0: [mesh_dim_0, mesh_dim_1] - }, + "input": {-1: [mesh_dim_0, mesh_dim_1]}, + "other": {0: [mesh_dim_0, mesh_dim_1]}, "bias": {}, "output": {}, } @@ -652,32 +647,29 @@ def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): # get communication action communication_action_mapping = {} 
output_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['output'], + sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], - comm_type=CommType.AFTER) - communication_action_mapping['output'] = output_comm_action + comm_type=CommType.AFTER, + ) + communication_action_mapping["output"] = output_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): - name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}' + name = f"RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}" # get sharding spec dim_partition_dict_mapping = { "input": {}, - "other": { - -1: [mesh_dim_0, mesh_dim_1] - }, - "bias": { - -1: [mesh_dim_0, mesh_dim_1] - }, - "output": { - -1: [mesh_dim_0, mesh_dim_1] - }, + "other": {-1: [mesh_dim_0, mesh_dim_1]}, + "bias": {-1: [mesh_dim_0, mesh_dim_1]}, + "output": {-1: [mesh_dim_0, mesh_dim_1]}, } # We don't have to do anything special for bias here, because @@ -687,20 +679,23 @@ def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): # get communication action communication_action_mapping = {} input_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['input'], + sharding_spec=sharding_spec_mapping["input"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], comm_type=CommType.BEFORE, - arg_index=0) - communication_action_mapping['input'] = input_comm_action + arg_index=0, + ) + communication_action_mapping["input"] = input_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def non_split(self): - name = f'RR = RR x RR' + name = f"RR = RR x RR" # get sharding spec dim_partition_dict_mapping = { @@ -717,22 +712,24 @@ def non_split(self): # get communication action communication_action_mapping = {} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) def validate(self) -> bool: assert "input" in self.op_data assert "other" in self.op_data # make sure the other has 2 dim - input_data = self.op_data['input'] - other_data = self.op_data['other'] + input_data = self.op_data["input"] + other_data = self.op_data["other"] assert input_data.data.dim() > 0 and other_data.data.dim() == 2 assert other_data.logical_shape[0] == input_data.logical_shape[-1] if self.has_bias: - bias_data = self.op_data['bias'] + bias_data = self.op_data["bias"] assert bias_data.logical_shape[-1] == other_data.logical_shape[-1] @@ -757,37 +754,38 @@ def __init__(self, *args, **kwargs): def _pop_batch_dim_sharding_for_output(self, 
dim_partition_dict): # remove partition dict for dim 0 - dim_partition_dict['output'].pop(0, None) + dim_partition_dict["output"].pop(0, None) # decrease the remaining dim index by 1 temp_dim_partition = {} - keys = list(dim_partition_dict['output'].keys()) + keys = list(dim_partition_dict["output"].keys()) for key in keys: - val = dim_partition_dict['output'].pop(key) + val = dim_partition_dict["output"].pop(key) temp_dim_partition[key - 1] = val - dim_partition_dict['output'].update(temp_dim_partition) + dim_partition_dict["output"].update(temp_dim_partition) def validate(self) -> bool: - input_op_data = self.op_data['input'] - other_op_data = self.op_data['other'] + input_op_data = self.op_data["input"] + other_op_data = self.op_data["other"] assert len(input_op_data.logical_shape) == 3 or len(other_op_data.logical_shape) == 3 - if 'bias' in self.op_data: - bias_op_data = self.op_data['bias'] + if "bias" in self.op_data: + bias_op_data = self.op_data["bias"] assert bias_op_data.data.dim() < 3 and len(bias_op_data.logical_shape) == 2 def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: - fwd_compute_cost = self.op_data['input'].data.shape[-1] * reduce(operator.mul, - self.op_data['output'].data.shape) + fwd_compute_cost = self.op_data["input"].data.shape[-1] * reduce( + operator.mul, self.op_data["output"].data.shape + ) bwd_compute_cost = fwd_compute_cost * 2 - compute_cost = TrainCycleItem(fwd=fwd_compute_cost, - bwd=bwd_compute_cost, - total=fwd_compute_cost + bwd_compute_cost) + compute_cost = TrainCycleItem( + fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost + ) strategy.compute_cost = compute_cost @ignore_sharding_exception def split_one_batch_dim(self, mesh_dim): - name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}' + name = f"Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}" # get sharding_spec dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "bias": {}, "output": {0: [mesh_dim]}} @@ -799,30 +797,27 @@ def split_one_batch_dim(self, mesh_dim): communication_action_mapping = {} if self.has_bias: bias_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['bias'], + sharding_spec=sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim, comm_type=CommType.BEFORE, - arg_index=0) - communication_action_mapping['bias'] = bias_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + arg_index=0, + ) + communication_action_mapping["bias"] = bias_comm_action + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1): - name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}' + name = f"Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}" dim_partition_dict = { - "input": { - 0: [mesh_dim_0, mesh_dim_1] - }, - "other": { - 0: [mesh_dim_0, mesh_dim_1] - }, + "input": {0: [mesh_dim_0, mesh_dim_1]}, + "other": {0: [mesh_dim_0, mesh_dim_1]}, "bias": {}, - "output": { - 0: [mesh_dim_0, mesh_dim_1] - } + "output": {0: [mesh_dim_0, mesh_dim_1]}, } if self.squeeze_batch_dim: self._pop_batch_dim_sharding_for_output(dim_partition_dict) @@ -832,35 +827,28 
@@ def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1): communication_action_mapping = {} if self.has_bias: bias_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['bias'], + sharding_spec=sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], comm_type=CommType.BEFORE, - arg_index=0) - communication_action_mapping['bias'] = bias_comm_action + arg_index=0, + ) + communication_action_mapping["bias"] = bias_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1): - name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}' + name = f"Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}" dim_partition_dict = { - "input": { - 0: [mesh_dim_0], - 1: [mesh_dim_1] - }, - "other": { - 0: [mesh_dim_0] - }, - "bias": { - 0: [mesh_dim_1] - }, - "output": { - 0: [mesh_dim_0], - 1: [mesh_dim_1] - } + "input": {0: [mesh_dim_0], 1: [mesh_dim_1]}, + "other": {0: [mesh_dim_0]}, + "bias": {0: [mesh_dim_1]}, + "output": {0: [mesh_dim_0], 1: [mesh_dim_1]}, } if self.squeeze_batch_dim: self._pop_batch_dim_sharding_for_output(dim_partition_dict) @@ -869,46 +857,40 @@ def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1): # get communication actions communication_action_mapping = {} other_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['other'], + sharding_spec=sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_1, comm_type=CommType.BEFORE, - arg_index=1) - communication_action_mapping['other'] = other_comm_action + arg_index=1, + ) + communication_action_mapping["other"] = other_comm_action if self.has_bias: bias_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['bias'], + sharding_spec=sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], comm_type=CommType.BEFORE, - arg_index=0) - communication_action_mapping['bias'] = bias_comm_action + arg_index=0, + ) + communication_action_mapping["bias"] = bias_comm_action # for addbmm case, other is the third argument instead of second. 
- communication_action_mapping['other'].arg_index += 1 + communication_action_mapping["other"].arg_index += 1 - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1): - name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}' + name = f"Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}" dim_partition_dict = { - "input": { - 0: [mesh_dim_0] - }, - "other": { - 0: [mesh_dim_0], - 2: [mesh_dim_1] - }, - "bias": { - 1: [mesh_dim_1] - }, - "output": { - 0: [mesh_dim_0], - 2: [mesh_dim_1] - } + "input": {0: [mesh_dim_0]}, + "other": {0: [mesh_dim_0], 2: [mesh_dim_1]}, + "bias": {1: [mesh_dim_1]}, + "output": {0: [mesh_dim_0], 2: [mesh_dim_1]}, } if self.squeeze_batch_dim: self._pop_batch_dim_sharding_for_output(dim_partition_dict) @@ -917,43 +899,41 @@ def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1): # get communication actions communication_action_mapping = {} input_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['input'], + sharding_spec=sharding_spec_mapping["input"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_1, comm_type=CommType.BEFORE, - arg_index=0) - communication_action_mapping['input'] = input_comm_action + arg_index=0, + ) + communication_action_mapping["input"] = input_comm_action if self.has_bias: bias_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['bias'], + sharding_spec=sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.BEFORE) - communication_action_mapping['bias'] = bias_comm_action + comm_type=CommType.BEFORE, + ) + communication_action_mapping["bias"] = bias_comm_action # for addbmm case, other is the second argument instead of first. 
- communication_action_mapping['input'].arg_index += 1 + communication_action_mapping["input"].arg_index += 1 - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1): - name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}' + name = f"Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}" dim_partition_dict = { - "input": { - 0: [mesh_dim_0], - 2: [mesh_dim_1] - }, - "other": { - 0: [mesh_dim_0], - 1: [mesh_dim_1] - }, + "input": {0: [mesh_dim_0], 2: [mesh_dim_1]}, + "other": {0: [mesh_dim_0], 1: [mesh_dim_1]}, "bias": {}, "output": { 0: [mesh_dim_0], - } + }, } if self.squeeze_batch_dim: self._pop_batch_dim_sharding_for_output(dim_partition_dict) @@ -962,24 +942,28 @@ def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1): # get communication actions communication_action_mapping = {} output_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['output'], + sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim_1, - comm_type=CommType.AFTER) - communication_action_mapping['output'] = output_comm_action + comm_type=CommType.AFTER, + ) + communication_action_mapping["output"] = output_comm_action if self.has_bias: bias_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['bias'], + sharding_spec=sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - arg_index=0) - communication_action_mapping['bias'] = bias_comm_action - - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + arg_index=0, + ) + communication_action_mapping["bias"] = bias_comm_action + + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py index b7db42f8f67e..b307e38b5b6d 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py @@ -21,28 +21,31 @@ class NormalPoolStrategyGenerator(StrategyGenerator): """ def validate(self) -> bool: - ''' + """ In sanity check, we need make sure the input data having correct dimension size. For Pool1d, the dim of input data should be 3([N, C, L]). For Pool2d, the dim of input data should be 4([N, C, H, W]). For Pool3d, the dim of input data should be 5([N, C, H, W, D]). - ''' - input_op_data = self.op_data['input'] + """ + input_op_data = self.op_data["input"] assert input_op_data.data.dim() in ( - 3, 4, 5), f'We suppose the dim of input fed into Pool op should in range of [3, 5].' 
+ 3, + 4, + 5, + ), f"We suppose the dim of input fed into Pool op should in range of [3, 5]." def update_compute_cost(self, strategy: ShardingStrategy) -> TrainCycleItem: - ''' + """ Compute the computation cost per device with this specific strategy. Note: compute_cost need to be divided by TFLOPS, now it just shows the computation size. - ''' + """ # TODO: compute_cost need to be divided by TFLOPS, now it just shows the computation size. # 1D: (Lout) * N * C * kernel # 2D: (H * W) * N * Cout * Cin * kernel # 3D: (H * W * D) * N * Cout * Cin * kernel - sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() - sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device() + sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() kernel_size = self.op_data["other"].data if isinstance(kernel_size, int): @@ -61,8 +64,8 @@ def update_compute_cost(self, strategy: ShardingStrategy) -> TrainCycleItem: def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: forward_size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'output': self._compute_size_in_bytes(strategy, "output") + "input": self._compute_size_in_bytes(strategy, "input"), + "output": self._compute_size_in_bytes(strategy, "output"), } backward_size_mapping = copy.deepcopy(forward_size_mapping) @@ -88,12 +91,16 @@ def _generate_strategy_with_dim_partition(self, dim_partition): sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) - name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}' + name = ( + f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}' + ) communication_action_mapping = {} - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) return strategy diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py index 69d1642d4f80..33fb1ac5c5be 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py @@ -12,7 +12,7 @@ from .strategy_generator import OutputStrategyGenerator -__all__ = ['OutputGenerator'] +__all__ = ["OutputGenerator"] class OutputGenerator(OutputStrategyGenerator): @@ -20,8 +20,13 @@ class OutputGenerator(OutputStrategyGenerator): OutputGenerator is a generic class to generate strategies for Output Node. 
""" - def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, - predecessor_nodes: List[Node], output_option: str): + def __init__( + self, + operation_data_mapping: Dict[str, OperationData], + device_mesh: DeviceMesh, + predecessor_nodes: List[Node], + output_option: str, + ): super().__init__(operation_data_mapping, device_mesh, predecessor_nodes) self.output_option = output_option @@ -33,9 +38,9 @@ def update_compute_cost(self, strategy: ShardingStrategy): strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. - ''' + """ fwd_mem_cost = MemoryCost(activation=0, parameter=0) bwd_mem_cost = MemoryCost(activation=0, parameter=0) @@ -65,16 +70,18 @@ def replica_strategy(self) -> List[ShardingStrategy]: else: dim_partition_dict_for_output = tuple(dim_partition_dict_for_output) - dim_partition_dict_mapping['output'] = dim_partition_dict_for_output + dim_partition_dict_mapping["output"] = dim_partition_dict_for_output communication_action_mapping = {} sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) - name = 'Replica Output' + name = "Replica Output" - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) return strategy def distributed_strategy(self, mesh_list: List[List[int]] = None) -> List[ShardingStrategy]: @@ -82,19 +89,15 @@ def distributed_strategy(self, mesh_list: List[List[int]] = None) -> List[Shardi Generate distributed strategy for output node. """ # TODO: need to take care of the case when the first element of output only need to be sharded. 
- output_op_data = self.op_data['output'] + output_op_data = self.op_data["output"] if isinstance(output_op_data.data, tuple): length = len(output_op_data.data) dim_partition_dict_mapping = { - "output": [{ - 0: mesh_list - }] * length, + "output": [{0: mesh_list}] * length, } else: dim_partition_dict_mapping = { - "output": { - 0: mesh_list - }, + "output": {0: mesh_list}, } for index, _ in enumerate(self.predecessor_nodes): mapping_name = f"input_{index}" @@ -103,19 +106,21 @@ def distributed_strategy(self, mesh_list: List[List[int]] = None) -> List[Shardi communication_action_mapping = {} sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) - name = 'Distributed Output' + name = "Distributed Output" - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) return strategy def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] mesh_list = [0, 1] - if self.output_option == 'replicated': + if self.output_option == "replicated": strategy_list.append(self.replica_strategy()) - elif self.output_option == 'distributed': + elif self.output_option == "distributed": strategy_list.append(self.distributed_strategy(mesh_list)) return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py index 779a7ced93bb..df0862a396d2 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py @@ -10,7 +10,7 @@ from .strategy_generator import StrategyGenerator -__all__ = ['PlaceholderGenerator'] +__all__ = ["PlaceholderGenerator"] class PlaceholderGenerator(StrategyGenerator): @@ -18,8 +18,9 @@ class PlaceholderGenerator(StrategyGenerator): PlaceholderGenerator is a generic class to generate strategies for placeholder node. """ - def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, - placeholder_option: str): + def __init__( + self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, placeholder_option: str + ): super().__init__(operation_data_mapping, device_mesh) self.placeholder_option = placeholder_option @@ -31,10 +32,10 @@ def update_compute_cost(self, strategy: ShardingStrategy): strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. 
- ''' - forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")} + """ + forward_size_mapping = {"output": self._compute_size_in_bytes(strategy, "output")} # compute fwd cost incurred # fwd_cost = output @@ -58,11 +59,13 @@ def replica_placeholder(self) -> ShardingStrategy: communication_action_mapping = {} sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) - name = 'Replica Placeholder' + name = "Replica Placeholder" - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) return strategy @@ -71,29 +74,31 @@ def distributed_placeholder(self, mesh_list) -> ShardingStrategy: Generate distributed strategy for placeholder node. """ dim_partition_dict_mapping = { - "output": { - 0: mesh_list - }, + "output": {0: mesh_list}, } communication_action_mapping = {} sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) - name = 'Distributed Placeholder' + name = "Distributed Placeholder" - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) return strategy def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] - if self.placeholder_option == 'distributed': + if self.placeholder_option == "distributed": mesh_list = [0, 1] distributed_strategy = self.distributed_placeholder(mesh_list) strategy_list.append(distributed_strategy) else: - assert self.placeholder_option == 'replicated', f'placeholder_option {self.placeholder_option} is not supported' + assert ( + self.placeholder_option == "replicated" + ), f"placeholder_option {self.placeholder_option} is not supported" replicated_strategy = self.replica_placeholder() strategy_list.append(replicated_strategy) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py index 24f75e352935..48f454553ac7 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py @@ -17,7 +17,7 @@ from colossalai.tensor.shape_consistency import CollectiveCommPattern from colossalai.tensor.sharding_spec import ShardingSpec -__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator', 'SplitGenerator'] +__all__ = ["ReshapeGenerator", "ViewGenerator", "PermuteGenerator", "TransposeGenerator", "SplitGenerator"] class ReshapeGenerator(FollowingStrategyGenerator): @@ -33,12 +33,12 @@ def update_compute_cost(self, strategy: ShardingStrategy): strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. 
- ''' + """ forward_size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'output': self._compute_size_in_bytes(strategy, "output") + "input": self._compute_size_in_bytes(strategy, "input"), + "output": self._compute_size_in_bytes(strategy, "output"), } backward_size_mapping = copy.deepcopy(forward_size_mapping) @@ -56,8 +56,9 @@ def update_memory_cost(self, strategy: ShardingStrategy): bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) # compute total cost - total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, - parameter=fwd_parameter_cost + bwd_parameter_cost) + total_mem_cost = MemoryCost( + activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost @@ -77,8 +78,8 @@ def collate_strategies(self) -> List[ShardingStrategy]: communication_action_mapping = {} input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] - origin_shape = self.op_data['input'].data.shape - tgt_shape = self.op_data['tgt_shape'].data + origin_shape = self.op_data["input"].data.shape + tgt_shape = self.op_data["tgt_shape"].data reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape) @@ -86,8 +87,9 @@ def collate_strategies(self) -> List[ShardingStrategy]: keep_sharding_status = check_keep_sharding_status(dim_partition_dict_for_input, reshape_mapping_dict) if keep_sharding_status: - dim_partition_dict_for_output = infer_output_dim_partition_dict(dim_partition_dict_for_input, - reshape_mapping_dict) + dim_partition_dict_for_output = infer_output_dim_partition_dict( + dim_partition_dict_for_input, reshape_mapping_dict + ) else: dim_partition_dict_for_output = {} @@ -119,7 +121,8 @@ def collate_strategies(self) -> List[ShardingStrategy]: communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, logical_process_axis=total_mesh_dim_list, comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) # it will gather the input through gather_dim during forward phase. input_comm_action.comm_spec.gather_dim = shard_dim # it will split the input activation grad through shard_dim during backward phase. 
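Each `update_memory_cost` in these generators follows the same bookkeeping: size the sharded operands in bytes, wrap the forward and backward charges in `MemoryCost` objects, and sum them into the `total` field of a `TrainCycleItem`. A minimal stand-alone sketch of that accounting, using simplified stand-in dataclasses rather than the library's own types (the real generators also track parameter and buffer terms):

```python
from dataclasses import dataclass

@dataclass
class MemoryCost:        # simplified stand-in for the library's MemoryCost
    activation: int = 0
    parameter: int = 0

@dataclass
class TrainCycleItem:    # simplified stand-in for the library's TrainCycleItem
    fwd: MemoryCost
    bwd: MemoryCost
    total: MemoryCost

def memory_cost(fwd_activation: int, bwd_activation: int) -> TrainCycleItem:
    fwd = MemoryCost(activation=fwd_activation)
    bwd = MemoryCost(activation=bwd_activation)
    total = MemoryCost(activation=fwd.activation + bwd.activation,
                       parameter=fwd.parameter + bwd.parameter)
    return TrainCycleItem(fwd=fwd, bwd=bwd, total=total)

# e.g. a reshape whose sharded input and output are both 4 KiB per device
print(memory_cost(fwd_activation=4096, bwd_activation=4096))
```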
@@ -127,10 +130,10 @@ def collate_strategies(self) -> List[ShardingStrategy]: elif len(total_mesh_dim_list) >= 2: source_spec = sharding_spec_mapping["input"] - target_spec = ShardingSpec(device_mesh=self.device_mesh, - entire_shape=source_spec.entire_shape, - dim_partition_dict={}) - comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec} + target_spec = ShardingSpec( + device_mesh=self.device_mesh, entire_shape=source_spec.entire_shape, dim_partition_dict={} + ) + comm_spec = {"src_spec": source_spec, "tgt_spec": target_spec} input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0) else: @@ -139,9 +142,11 @@ def collate_strategies(self) -> List[ShardingStrategy]: if input_comm_action is not None: communication_action_mapping["input"] = input_comm_action - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(strategy) return strategy_list @@ -159,7 +164,7 @@ def collate_strategies(self) -> List[ShardingStrategy]: communication_action_mapping = {} input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] - permute_dims = self.op_data['permute_dims'].data + permute_dims = self.op_data["permute_dims"].data dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict dim_partition_dict_for_output = {} for dim_index, permute_dim in enumerate(permute_dims): @@ -177,9 +182,11 @@ def collate_strategies(self) -> List[ShardingStrategy]: # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(strategy) return strategy_list @@ -199,7 +206,7 @@ def collate_strategies(self) -> List[ShardingStrategy]: dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict dim_partition_dict_for_output = {} - transpose_dims = self.op_data['transpose_dims'].data + transpose_dims = self.op_data["transpose_dims"].data dim_0 = transpose_dims[0] dim_1 = transpose_dims[1] for dim, sharded_dims in dim_partition_dict_for_input.items(): @@ -221,9 +228,11 @@ def collate_strategies(self) -> List[ShardingStrategy]: # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. 
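`PermuteGenerator` carries each sharded input dimension to the position it occupies after the permutation. A stand-alone sketch of that mapping (an illustration, not the library code):

```python
def permute_dim_partition(dim_partition_dict, permute_dims):
    """Move each sharded input dim to the position it occupies after the permute."""
    out = {}
    for out_dim, in_dim in enumerate(permute_dims):
        if in_dim in dim_partition_dict:
            out[out_dim] = dim_partition_dict[in_dim]
    return out

# input sharded on dim 0 over mesh axis 0; permutation (2, 0, 1)
print(permute_dim_partition({0: [0]}, (2, 0, 1)))   # -> {1: [0]}
```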
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(strategy) return strategy_list @@ -242,7 +251,7 @@ def collate_strategies(self) -> List[ShardingStrategy]: communication_action_mapping = {} input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict) - split_size, split_dim = self.op_data['split_info'].data + split_size, split_dim = self.op_data["split_info"].data if split_dim in dim_partition_dict_for_input: recover_dims = dim_partition_dict_for_input.pop(split_dim) @@ -271,7 +280,8 @@ def collate_strategies(self) -> List[ShardingStrategy]: communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, logical_process_axis=recover_dims, comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) # it will gather the input through gather_dim during forward phase. input_comm_action.comm_spec.gather_dim = split_dim # it will split the input activation grad through split_dim during backward phase. @@ -282,7 +292,7 @@ def collate_strategies(self) -> List[ShardingStrategy]: source_spec = input_sharding_spec # target sharding spec target_spec = sharding_spec_mapping["input"] - comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec} + comm_spec = {"src_spec": source_spec, "tgt_spec": target_spec} input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0) else: @@ -291,9 +301,11 @@ def collate_strategies(self) -> List[ShardingStrategy]: if input_comm_action is not None: communication_action_mapping["input"] = input_comm_action - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(strategy) return strategy_list @@ -341,16 +353,17 @@ def collate_strategies(self) -> List[ShardingStrategy]: communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, logical_process_axis=total_mesh_dim_list, comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) input_comm_action.comm_spec.gather_dim = total_mesh_dim_list input_comm_action.comm_spec.shard_dim = total_mesh_dim_list elif len(total_mesh_dim_list) >= 2: source_spec = sharding_spec_mapping["input"] - target_spec = ShardingSpec(device_mesh=self.device_mesh, - entire_shape=source_spec.entire_shape, - dim_partition_dict={}) - comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec} + target_spec = ShardingSpec( + device_mesh=self.device_mesh, entire_shape=source_spec.entire_shape, dim_partition_dict={} + ) + comm_spec = {"src_spec": source_spec, "tgt_spec": target_spec} input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0) else: @@ -358,9 +371,11 @@ def collate_strategies(self) -> List[ShardingStrategy]: if input_comm_action is not None: communication_action_mapping["input"] = input_comm_action - strategy = self.get_sharding_strategy(name=name, - 
sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(strategy) return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py index a1ebadd043e2..d4382f9941d2 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py @@ -4,21 +4,9 @@ from typing import List from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - CommAction, - CommType, - MemoryCost, - ShardingStrategy, - TrainCycleItem, -) -from colossalai.auto_parallel.tensor_shard.utils import ( - check_keep_sharding_status, - detect_reshape_mapping, - infer_output_dim_partition_dict, -) -from colossalai.tensor.shape_consistency import CollectiveCommPattern - -__all__ = ['SoftmaxGenerator'] +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem + +__all__ = ["SoftmaxGenerator"] class SoftmaxGenerator(FollowingStrategyGenerator): @@ -30,11 +18,11 @@ def validate(self) -> bool: return super().validate() def update_compute_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the computation cost per device with this specific strategy. - ''' - sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() - sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() + """ + sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() + sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device() input_size_product = reduce(operator.mul, sharded_input_shape) output_size_product = reduce(operator.mul, sharded_output_shape) @@ -45,12 +33,12 @@ def update_compute_cost(self, strategy: ShardingStrategy): strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. 
- ''' + """ forward_size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'output': self._compute_size_in_bytes(strategy, "output") + "input": self._compute_size_in_bytes(strategy, "input"), + "output": self._compute_size_in_bytes(strategy, "output"), } backward_size_mapping = copy.deepcopy(forward_size_mapping) @@ -68,8 +56,9 @@ def update_memory_cost(self, strategy: ShardingStrategy): bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) # compute total cost - total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, - parameter=fwd_parameter_cost + bwd_parameter_cost) + total_mem_cost = MemoryCost( + activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost @@ -80,10 +69,10 @@ def collate_strategies(self) -> List[ShardingStrategy]: communication_action_mapping = {} input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict) - softmax_dim = self.op_data['softmax_dim'].data + softmax_dim = self.op_data["softmax_dim"].data if softmax_dim in dim_partition_dict_for_input: - recover_dims = dim_partition_dict_for_input.pop(softmax_dim) + dim_partition_dict_for_input.pop(softmax_dim) dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input) dim_partition_dict_mapping = { @@ -96,9 +85,11 @@ def collate_strategies(self) -> List[ShardingStrategy]: # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(strategy) return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py index d42429745c61..7bf2c8cc12a3 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py @@ -39,7 +39,7 @@ def has_bias(self): """ A utility method to check for the existence of bias operand for convenience. """ - return 'bias' in self.op_data + return "bias" in self.op_data def is_param(self, op_data_name): other_data = self.op_data[op_data_name] @@ -49,8 +49,12 @@ def is_buffer(self, op_data_name): other_data = self.op_data[op_data_name] return other_data.type == OperationDataType.BUFFER - def get_sharding_strategy(self, name: str, sharding_spec_mapping: Dict[str, ShardingSpec], - communication_action_mapping: Dict[str, CommSpec]): + def get_sharding_strategy( + self, + name: str, + sharding_spec_mapping: Dict[str, ShardingSpec], + communication_action_mapping: Dict[str, CommSpec], + ): """ A factory method to produce a ShardingStrategy object. 
@@ -80,24 +84,28 @@ def to_sharding_spec_mapping(self, mapping: Dict[str, Dict[int, List[int]]]): op_data = self.op_data[op_data_name] def _to_sharding_spec( - data: any, logical_shape: any, - dim_partition_dict: Dict[int, List[int]]) -> Union[ShardingSpec, List[ShardingSpec], None]: + data: any, logical_shape: any, dim_partition_dict: Dict[int, List[int]] + ) -> Union[ShardingSpec, List[ShardingSpec], None]: """ This is a recursive function to convert the dim partition dict to a ShardingSpec object. """ if isinstance(data, torch.Tensor): dim_size = len(logical_shape) dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict) - sharding_spec = ShardingSpec(device_mesh=self.device_mesh, - entire_shape=logical_shape, - dim_partition_dict=dim_partition_dict) + sharding_spec = ShardingSpec( + device_mesh=self.device_mesh, + entire_shape=logical_shape, + dim_partition_dict=dim_partition_dict, + ) return sharding_spec elif isinstance(data, (list, tuple)): sharding_spec = [] for data_element, logical_shape_element, dim_partition_dict_element in zip( - data, logical_shape, dim_partition_dict): + data, logical_shape, dim_partition_dict + ): sharding_spec.append( - _to_sharding_spec(data_element, logical_shape_element, dim_partition_dict_element)) + _to_sharding_spec(data_element, logical_shape_element, dim_partition_dict_element) + ) return sharding_spec else: return None @@ -116,31 +124,41 @@ def replace_op_name_with_op_data(self, mapping: Dict[str, Any]): results[op_data] = v return results - def get_communication_spec(self, sharding_spec: ShardingSpec, communication_pattern: CollectiveCommPattern, - logical_process_axis: Union[int, List[int]]): + def get_communication_spec( + self, + sharding_spec: ShardingSpec, + communication_pattern: CollectiveCommPattern, + logical_process_axis: Union[int, List[int]], + ): """ A factory method to produce a CommSpec object. """ - return CommSpec(comm_pattern=communication_pattern, - sharding_spec=sharding_spec, - logical_process_axis=logical_process_axis) - - def get_communication_action(self, - sharding_spec: ShardingSpec, - communication_pattern: CollectiveCommPattern, - logical_process_axis: Union[int, List[int]], - comm_type: CommType, - arg_index: int = -1, - key_for_kwarg: any = None) -> CommAction: + return CommSpec( + comm_pattern=communication_pattern, sharding_spec=sharding_spec, logical_process_axis=logical_process_axis + ) + + def get_communication_action( + self, + sharding_spec: ShardingSpec, + communication_pattern: CollectiveCommPattern, + logical_process_axis: Union[int, List[int]], + comm_type: CommType, + arg_index: int = -1, + key_for_kwarg: any = None, + ) -> CommAction: """ A factory method to produce a CommAction object. 
""" - return CommAction(comm_spec=self.get_communication_spec(sharding_spec=sharding_spec, - communication_pattern=communication_pattern, - logical_process_axis=logical_process_axis), - comm_type=comm_type, - arg_index=arg_index, - key_for_kwarg=key_for_kwarg) + return CommAction( + comm_spec=self.get_communication_spec( + sharding_spec=sharding_spec, + communication_pattern=communication_pattern, + logical_process_axis=logical_process_axis, + ), + comm_type=comm_type, + arg_index=arg_index, + key_for_kwarg=key_for_kwarg, + ) def update_communication_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: """ @@ -155,9 +173,9 @@ def _compute_and_add(op_data: OperationData, comm_spec: CommSpec): size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() for phase, cost in num_ele_in_comm.items(): num_ele_in_comm[phase] = num_ele_in_comm[phase] * size_per_elem_bytes - comm_cost.fwd += num_ele_in_comm['forward'] - comm_cost.bwd += num_ele_in_comm['backward'] - comm_cost.total += num_ele_in_comm['total'] + comm_cost.fwd += num_ele_in_comm["forward"] + comm_cost.bwd += num_ele_in_comm["backward"] + comm_cost.total += num_ele_in_comm["total"] # check if communication action exists # if so, loop over each action and compute the cost of each action @@ -169,8 +187,8 @@ def _compute_and_add(op_data: OperationData, comm_spec: CommSpec): # this condition branch will be removed after all the handler updated. comm_spec = comm_action if isinstance(comm_spec, dict): - src_spec = comm_spec['src_spec'] - tgt_spec = comm_spec['tgt_spec'] + src_spec = comm_spec["src_spec"] + tgt_spec = comm_spec["tgt_spec"] shape_consistency_manager = ShapeConsistencyManager() _, comm_action_sequence, _ = shape_consistency_manager.shape_consistency(src_spec, tgt_spec) for comm_spec_ in comm_action_sequence: @@ -187,14 +205,12 @@ def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: """ Customize this method to compute the computation flops. """ - pass @abstractmethod def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: """ Customize this method to compute the memory cost in bytes. """ - pass def _compute_size_in_bytes(self, strategy: ShardingStrategy, key: str): """ @@ -212,13 +228,14 @@ def _compute_size_in_bytes_helper(sharding_spec, meta_data): num_elements = 1 else: num_elements = reduce(operator.mul, sharded_shape) - dtype = getattr(meta_data, 'dtype') + dtype = getattr(meta_data, "dtype") size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() return num_elements * size_per_elem_bytes if isinstance(op_data.data, tuple): - assert isinstance(strategy.sharding_specs[op_data], list), \ - 'sharding_spec of op_data should be a list of sharding specs if op_data.data is a tuple.' + assert isinstance( + strategy.sharding_specs[op_data], list + ), "sharding_spec of op_data should be a list of sharding specs if op_data.data is a tuple." total_bytes = 0 for index, sharding_spec in enumerate(strategy.sharding_specs[op_data]): meta_data = op_data.data[index] @@ -270,7 +287,6 @@ def validate(self) -> bool: Validate if the operands are of desired shape. If True, means this generator can be used for the current operation. 
""" - pass class FollowingStrategyGenerator(StrategyGenerator): @@ -280,8 +296,9 @@ class FollowingStrategyGenerator(StrategyGenerator): TODO: remove the original strategy_generator.py after refactoring """ - def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, - predecessor_node: Node): + def __init__( + self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, predecessor_node: Node + ): self.op_data = operation_data_mapping self.device_mesh = device_mesh self.predecessor_node = predecessor_node @@ -292,7 +309,8 @@ class OutputStrategyGenerator(StrategyGenerator): OutputStrategyGenerator is used to generate the sharding strategies for Output Node. """ - def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, - predecessor_nodes: List[Node]): + def __init__( + self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, predecessor_nodes: List[Node] + ): super().__init__(operation_data_mapping, device_mesh) self.predecessor_nodes = predecessor_nodes diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py index a0fbc58d70c0..dcbf34cfd65b 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py @@ -4,22 +4,9 @@ from typing import List from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - CommAction, - CommType, - MemoryCost, - ShardingStrategy, - TrainCycleItem, -) -from colossalai.auto_parallel.tensor_shard.utils import ( - check_keep_sharding_status, - detect_reshape_mapping, - infer_output_dim_partition_dict, -) -from colossalai.tensor.shape_consistency import CollectiveCommPattern -from colossalai.tensor.sharding_spec import ShardingSpec - -__all__ = ['SumGenerator'] +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem + +__all__ = ["SumGenerator"] class SumGenerator(FollowingStrategyGenerator): @@ -31,24 +18,24 @@ def validate(self) -> bool: return super().validate() def update_compute_cost(self, strategy: ShardingStrategy): - sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() - sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() + sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() + sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device() input_size_product = reduce(operator.mul, sharded_input_shape) output_size_product = reduce(operator.mul, sharded_output_shape) - compute_cost = TrainCycleItem(fwd=input_size_product, - bwd=output_size_product, - total=input_size_product + output_size_product) + compute_cost = TrainCycleItem( + fwd=input_size_product, bwd=output_size_product, total=input_size_product + output_size_product + ) strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. 
- ''' + """ forward_size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'output': self._compute_size_in_bytes(strategy, "output") + "input": self._compute_size_in_bytes(strategy, "input"), + "output": self._compute_size_in_bytes(strategy, "output"), } backward_size_mapping = copy.deepcopy(forward_size_mapping) @@ -66,8 +53,9 @@ def update_memory_cost(self, strategy: ShardingStrategy): bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) # compute total cost - total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, - parameter=fwd_parameter_cost + bwd_parameter_cost) + total_mem_cost = MemoryCost( + activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost @@ -78,7 +66,7 @@ def collate_strategies(self) -> List[ShardingStrategy]: communication_action_mapping = {} input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict) - sum_dims, sum_mapping_dict = self.op_data['sum_info'].data + sum_dims, sum_mapping_dict = self.op_data["sum_info"].data # TODO: a better way to handle the distributed sum is sum all the data on chip and then do all reduce # among all the shard groups @@ -90,7 +78,7 @@ def collate_strategies(self) -> List[ShardingStrategy]: elif dim in sum_mapping_dict: dim_partition_dict_for_output[sum_mapping_dict[dim]] = dim_partition_dict_for_input[dim] else: - raise RuntimeError(f'dim {dim} is not in sum_mapping_dict or sum_dims') + raise RuntimeError(f"dim {dim} is not in sum_mapping_dict or sum_dims") for dim in recover_dims: dim_partition_dict_for_input.pop(dim) @@ -105,9 +93,11 @@ def collate_strategies(self) -> List[ShardingStrategy]: # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. 
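`SumGenerator.update_compute_cost` charges the forward pass with the element count of the sharded input and the backward pass with that of the sharded output, summing both for the total. Reproduced in isolation:

```python
import operator
from functools import reduce

def sum_compute_cost(sharded_input_shape, sharded_output_shape):
    fwd = reduce(operator.mul, sharded_input_shape, 1)
    bwd = reduce(operator.mul, sharded_output_shape, 1)
    return fwd, bwd, fwd + bwd

# summing a (4, 8, 16) shard over its last dimension
print(sum_compute_cost((4, 8, 16), (4, 8)))   # (512, 32, 544)
```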
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(strategy) return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py index 93cfc9eeea53..eea00c2fa064 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py @@ -1,19 +1,10 @@ -import copy from typing import List -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - CommAction, - CommType, - MemoryCost, - ShardingStrategy, - TrainCycleItem, -) -from colossalai.tensor.shape_consistency import CollectiveCommPattern -from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem from .strategy_generator import StrategyGenerator -__all__ = ['TensorConstructorGenerator'] +__all__ = ["TensorConstructorGenerator"] class TensorConstructorGenerator(StrategyGenerator): @@ -30,10 +21,10 @@ def update_compute_cost(self, strategy: ShardingStrategy): strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. 
- ''' - forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")} + """ + forward_size_mapping = {"output": self._compute_size_in_bytes(strategy, "output")} # compute fwd cost incurred # fwd_cost = input + output @@ -57,11 +48,13 @@ def collate_strategies(self) -> List[ShardingStrategy]: communication_action_mapping = {} sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) - name = 'Replica Tensor Constructor' + name = "Replica Tensor Constructor" - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(strategy) return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py index 39799a67c5a0..943cf3f1f50d 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py @@ -5,7 +5,7 @@ from .strategy_generator import FollowingStrategyGenerator -__all__ = ['UnaryElementwiseGenerator'] +__all__ = ["UnaryElementwiseGenerator"] class UnaryElementwiseGenerator(FollowingStrategyGenerator): @@ -21,12 +21,12 @@ def update_compute_cost(self, strategy: ShardingStrategy): strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. - ''' + """ forward_size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'output': self._compute_size_in_bytes(strategy, "output") + "input": self._compute_size_in_bytes(strategy, "input"), + "output": self._compute_size_in_bytes(strategy, "output"), } backward_size_mapping = copy.deepcopy(forward_size_mapping) @@ -44,8 +44,9 @@ def update_memory_cost(self, strategy: ShardingStrategy): bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) # compute total cost - total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, - parameter=fwd_parameter_cost + bwd_parameter_cost) + total_mem_cost = MemoryCost( + activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost @@ -69,9 +70,11 @@ def collate_strategies(self) -> List[ShardingStrategy]: # we keep same strategies with different name for node merging, and it will not increase the searching space, # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. 
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(strategy) return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py index fa941f2cc51d..b27b4f3d4056 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py @@ -10,7 +10,7 @@ from .strategy_generator import StrategyGenerator -__all__ = ['WhereGenerator'] +__all__ = ["WhereGenerator"] class WhereGenerator(StrategyGenerator): @@ -26,14 +26,14 @@ def update_compute_cost(self, strategy: ShardingStrategy): strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. - ''' + """ forward_size_mapping = { - 'condition': self._compute_size_in_bytes(strategy, "condition"), - 'x': self._compute_size_in_bytes(strategy, "x"), - 'y': self._compute_size_in_bytes(strategy, "y"), - 'output': self._compute_size_in_bytes(strategy, "output") + "condition": self._compute_size_in_bytes(strategy, "condition"), + "x": self._compute_size_in_bytes(strategy, "x"), + "y": self._compute_size_in_bytes(strategy, "y"), + "output": self._compute_size_in_bytes(strategy, "output"), } backward_size_mapping = copy.deepcopy(forward_size_mapping) @@ -59,7 +59,7 @@ def _generate_strategy_with_dim_partition(self, dim_partition): "condition": dim_partition, "x": dim_partition, "y": dim_partition, - "output": dim_partition + "output": dim_partition, } sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) @@ -67,9 +67,11 @@ def _generate_strategy_with_dim_partition(self, dim_partition): name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["condition"].sharding_sequence} x {sharding_spec_mapping["x"].sharding_sequence} x {sharding_spec_mapping["y"].sharding_sequence}' communication_action_mapping = {} - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) return strategy @@ -84,9 +86,9 @@ def enumerate_all_possible_output_spec(self, mesh_dim_0, mesh_dim_1, dimension_l return dim_partition_list def collate_strategies(self) -> List[ShardingStrategy]: - ''' + """ Generate every possible strategies for a where node, and record all strategies into the strategies_vector. 
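`enumerate_all_possible_output_spec` sweeps the ways the two device-mesh axes can be laid onto the output dimensions. The sketch below is only a rough illustration of what such an enumeration can produce for a 2-D tensor; the exact set and ordering generated by the library may differ.

```python
def enumerate_partitions(num_dims, mesh_dims=(0, 1)):
    """All ways to place zero, one, or both mesh axes onto tensor dimensions."""
    m0, m1 = mesh_dims
    results = [{}]                                    # fully replicated
    for d in range(num_dims):
        results.append({d: [m0]})                     # one axis shards dim d
        results.append({d: [m1]})
        results.append({d: [m0, m1]})                 # both axes shard the same dim
    for d0 in range(num_dims):
        for d1 in range(num_dims):
            if d0 != d1:
                results.append({d0: [m0], d1: [m1]})  # axes shard different dims
    return results

for partition in enumerate_partitions(2):
    print(partition)
```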
- ''' + """ strategy_list = [] dimension_length = len(self.op_data["output"].logical_shape) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py index 86f90694e060..5b4ea0afe5f8 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py @@ -7,7 +7,7 @@ from .registry import operator_registry from .strategy import StrategyGenerator, SumGenerator -__all__ = ['SumHandler'] +__all__ = ["SumHandler"] @operator_registry.register(torch.Tensor.sum) @@ -55,7 +55,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: # sum_mapping_dict[1] = 0 means the 0th dim of output is the 1st dim of input # sum_mapping_dict[3] = 1 means the 1st dim of output is the 3rd dim of input sum_mapping_dict = {} - if 'keepdim' in self.node.kwargs and self.node.kwargs['keepdim']: + if "keepdim" in self.node.kwargs and self.node.kwargs["keepdim"]: for i in range(num_dims): sum_mapping_dict.update({i: i}) else: @@ -67,7 +67,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: assert output_index == self.node._meta_data.dim() sum_info = (sum_dims, sum_mapping_dict) - physical_shape_operand = OperationData(name='sum_info', type=OperationDataType.ARG, data=sum_info) + physical_shape_operand = OperationData(name="sum_info", type=OperationDataType.ARG, data=sum_info) output_data = self.node._meta_data physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) @@ -75,7 +75,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: mapping = { "input": physical_input_operand, "sum_info": physical_shape_operand, - "output": physical_output_operand + "output": physical_output_operand, } return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py index 855a2e7612af..c2aa120e8a28 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py @@ -8,7 +8,7 @@ from .strategy import StrategyGenerator from .strategy.tensor_constructor_generator import TensorConstructorGenerator -__all__ = ['TensorConstructorHandler'] +__all__ = ["TensorConstructorHandler"] @operator_registry.register(torch.arange) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py index 7a9d37726490..b72d9812f406 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py @@ -7,7 +7,7 @@ from .registry import operator_registry from .strategy import StrategyGenerator, TransposeGenerator -__all__ = ['TransposeHandler'] +__all__ = ["TransposeHandler"] @operator_registry.register(torch.Tensor.transpose) @@ -48,9 +48,9 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: if transpose_dims[i] < 0: transpose_dims[i] += num_dims - physical_shape_operand = OperationData(name='transpose_dims', - type=OperationDataType.ARG, - data=list(transpose_dims)) + physical_shape_operand = OperationData( + name="transpose_dims", type=OperationDataType.ARG, data=list(transpose_dims) + ) output_data = self.node._meta_data physical_output_operand = 
OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) @@ -58,7 +58,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: mapping = { "input": physical_input_operand, "transpose_dims": physical_shape_operand, - "output": physical_output_operand + "output": physical_output_operand, } return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py index 0362de780d7a..cbc873de8223 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py @@ -3,11 +3,11 @@ import torch from ..sharding_strategy import OperationData, OperationDataType -from .node_handler import MetaInfoNodeHandler, NodeHandler +from .node_handler import MetaInfoNodeHandler from .registry import operator_registry from .strategy import StrategyGenerator, UnaryElementwiseGenerator -__all__ = ['UnaryElementwiseHandler'] +__all__ = ["UnaryElementwiseHandler"] @operator_registry.register(torch.Tensor.to) @@ -33,9 +33,9 @@ def get_strategy_generator(self) -> List[StrategyGenerator]: def get_operation_data_mapping(self) -> Dict[str, OperationData]: # use transposed shape for strategies # the strategies will be transformed back to its original shape in self.post_process - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data) + physical_input_operand = OperationData( + name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data + ) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) mapping = {"input": physical_input_operand, "output": physical_output} diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py index 7dff89d1d7a3..56c1d10a167e 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py @@ -7,7 +7,7 @@ from .registry import operator_registry from .strategy import StrategyGenerator, ViewGenerator -__all__ = ['ViewHandler'] +__all__ = ["ViewHandler"] @operator_registry.register(torch.Tensor.reshape) @@ -38,7 +38,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) target_shape = self.node._meta_data.shape - physical_shape_operand = OperationData(name='tgt_shape', type=OperationDataType.ARG, data=target_shape) + physical_shape_operand = OperationData(name="tgt_shape", type=OperationDataType.ARG, data=target_shape) output_data = self.node._meta_data physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) @@ -46,7 +46,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: mapping = { "input": physical_input_operand, "tgt_shape": physical_shape_operand, - "output": physical_output_operand + "output": physical_output_operand, } return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py index 6de2aaafdd01..1856a11100b0 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py +++ 
b/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py @@ -1,16 +1,15 @@ import copy -import operator from typing import Dict, List import torch -from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy from ..utils import recover_sharding_spec_for_broadcast_shape from .node_handler import NodeHandler from .registry import operator_registry from .strategy import StrategyGenerator, WhereGenerator -__all__ = ['WhereHandler'] +__all__ = ["WhereHandler"] @operator_registry.register(torch.where) @@ -28,27 +27,28 @@ def get_strategy_generator(self) -> List[StrategyGenerator]: def get_operation_data_mapping(self) -> Dict[str, OperationData]: # use transposed shape for strategies # the strategies will be transformed back to its original shape in self.post_process - physical_condition_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data) - physical_x_operand = OperationData(name=str(self.node.args[1]), - type=OperationDataType.ARG, - data=self.node.args[1]._meta_data) - physical_y_operand = OperationData(name=str(self.node.args[2]), - type=OperationDataType.ARG, - data=self.node.args[2]._meta_data) + physical_condition_operand = OperationData( + name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data + ) + physical_x_operand = OperationData( + name=str(self.node.args[1]), type=OperationDataType.ARG, data=self.node.args[1]._meta_data + ) + physical_y_operand = OperationData( + name=str(self.node.args[2]), type=OperationDataType.ARG, data=self.node.args[2]._meta_data + ) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) physical_mapping = { "condition": physical_condition_operand, "x": physical_x_operand, "y": physical_y_operand, - "output": physical_output + "output": physical_output, } logical_shape_for_all = self.node._meta_data.shape logical_mapping = {} for key, physical_operand in physical_mapping.items(): - logical_mapping[key] = self.convert_physical_operand_to_logical_operand(physical_operand, - logical_shape_for_all) + logical_mapping[key] = self.convert_physical_operand_to_logical_operand( + physical_operand, logical_shape_for_all + ) return logical_mapping, physical_mapping @@ -64,7 +64,8 @@ def post_process(self, strategy: ShardingStrategy): logical_shape = logical_op_data_mapping[key].logical_shape physical_shape = physical_op_data_mapping[key].logical_shape physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( - logical_sharding_spec, logical_shape, physical_shape) + logical_sharding_spec, logical_shape, physical_shape + ) strategy.sharding_specs.pop(logical_op_data_mapping[key]) strategy.sharding_specs[physical_op_data_mapping[key]] = physical_sharding_spec strategy.name = f"{strategy.sharding_specs[physical_op_data_mapping['output']].sharding_sequence} = {strategy.sharding_specs[physical_op_data_mapping['condition']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['x']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['y']].sharding_sequence}" diff --git a/colossalai/auto_parallel/tensor_shard/options.py b/colossalai/auto_parallel/tensor_shard/options.py index f0ea502a6f0e..e87872f39c10 100644 --- a/colossalai/auto_parallel/tensor_shard/options.py +++ b/colossalai/auto_parallel/tensor_shard/options.py @@ -1,13 
+1,14 @@ from dataclasses import dataclass from enum import Enum -__all__ = ['SolverOptions', 'SolverPerference', 'DataloaderOption', 'ShardOption'] +__all__ = ["SolverOptions", "SolverPerference", "DataloaderOption", "ShardOption"] class SolverPerference(Enum): """ This enum class is to define the solver preference. """ + STANDARD = 0 DP = 1 TP = 2 @@ -25,6 +26,7 @@ class ShardOption(Enum): TP_SHARD: We require the node to be shard using tensor parallel strategies on last device mesh axis. TP_FULL_SHARD: We require the node to be shard using tensor parallel strategies on all device mesh axes. """ + STANDARD = 0 SHARD = 1 SHARD_LAST_AXIS = 2 @@ -35,6 +37,7 @@ class DataloaderOption(Enum): """ This enum class is to define the dataloader option. """ + REPLICATED = 0 DISTRIBUTED = 1 @@ -44,6 +47,7 @@ class SolverOptions: """ SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search. """ + solver_perference: SolverPerference = SolverPerference.STANDARD dataloader_option: DataloaderOption = DataloaderOption.REPLICATED shard_option: ShardOption = ShardOption.STANDARD diff --git a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py index 6af927272437..8e22df64d868 100644 --- a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py +++ b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py @@ -10,7 +10,6 @@ from colossalai.tensor.sharding_spec import ShardingSpec from .constants import ( - BCAST_FUNC_OP, ELEMENTWISE_FUNC_OP, ELEMENTWISE_METHOD_OP, ELEMENTWISE_MODULE_OP, @@ -18,13 +17,14 @@ RESHAPE_METHOD_OP, ) -__all__ = ['OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'ShardingStrategy', 'StrategiesVector'] +__all__ = ["OperationDataType", "OperationData", "TrainCycleItem", "MemoryCost", "ShardingStrategy", "StrategiesVector"] class OperationDataType(Enum): """ An operation can come from the argument list of an operator or the parameter list of a module. """ + INPUT = 0 ARG = 1 PARAM = 2 @@ -43,6 +43,7 @@ class OperationData: data (Any): the value for this data, usually it is a meta tensor. logical_shape (Tuple[int]): the logical shape of the data, it can be different from the its actual shape in memory. """ + name: str type: OperationDataType data: Any @@ -69,13 +70,13 @@ def _infer_logical_shape(data: any): self.logical_shape = _infer_logical_shape(self.data) def __repr__(self) -> str: - return f'OperationData(name={self.name}, type={self.type})' + return f"OperationData(name={self.name}, type={self.type})" def __eq__(self, other) -> bool: return other.name == self.name def __hash__(self) -> int: - return hash(f'{self.name}') + return hash(f"{self.name}") @dataclass @@ -88,6 +89,7 @@ class TrainCycleItem: fwd (float): the item for the forward pass bwd (float): the item for the backward pass """ + fwd: Any bwd: Any total: Any @@ -104,6 +106,7 @@ class MemoryCost: temp (int): the memory cost incurred by the temporary tensors in bytes. buffer (int): the memory cost incurred by the module buffer in bytes. """ + activation: int = 0 parameter: int = 0 temp: int = 0 @@ -120,6 +123,7 @@ class CommType(Enum): HOOK: the communication action is used to do the grad all reduce. 
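`OperationData` compares and hashes by `name` alone, so dictionaries keyed by operands (such as `sharding_specs`) can later be queried by name, which is what `get_sharding_spec_by_name` relies on. A tiny illustration of that design choice with a simplified stand-in class:

```python
from dataclasses import dataclass
from typing import Any

@dataclass
class OpData:                        # simplified stand-in for OperationData
    name: str
    data: Any = None

    def __eq__(self, other) -> bool:
        return isinstance(other, OpData) and other.name == self.name

    def __hash__(self) -> int:
        return hash(self.name)

sharding_specs = {OpData("input", data=[1, 2, 3]): "S0R"}
# a freshly built object with the same name still hits the same entry
print(sharding_specs[OpData("input")])   # -> S0R
```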
IMPLICIT: the communication action happens during the kernel execution, such as SyncBatchNorm """ + BEFORE = 0 AFTER = 1 HOOK = 2 @@ -137,6 +141,7 @@ class CommAction: arg_index: record the location of tensor which join the communication, we cannot use name of node or op_data at runtime, because the args of node may be changed by graph transform passes. """ + comm_spec: CommSpec = None comm_type: CommType = None arg_index: int = -1 @@ -156,6 +161,7 @@ class ShardingStrategy: memory_cost (TrainCycleItem): Memory cost of the output node using this strategy. (default to None) input_sharding_specs (List(ShardingSpec)): The ShardingSpecs of the input nodes. """ + name: str sharding_specs: Dict[OperationData, Union[ShardingSpec, Tuple[ShardingSpec]]] = None compute_cost: TrainCycleItem = None @@ -200,7 +206,6 @@ def get_sharding_spec_by_name(self, name: str): raise KeyError(f"Could not find the ShardingSpec for OperationData with name {name}") def clone(self): - def _deepcopy_dict_vals(data: Dict): return {k: deepcopy(v) for k, v in data.items()} @@ -209,31 +214,34 @@ def _deepcopy_dict_vals(data: Dict): # Consider the examples below: # If self.communication_actions is an empty dictionary {}, then self.communication_actions is not None, but its __bool__ value is False. # In this case, if we set None to the new object, program will crash when we try to access the communication_actions.items. - communication_actions = _deepcopy_dict_vals( - self.communication_actions) if self.communication_actions is not None else None + communication_actions = ( + _deepcopy_dict_vals(self.communication_actions) if self.communication_actions is not None else None + ) # same reason as communication_actions resharding_costs = _deepcopy_dict_vals(self.resharding_costs) if self.resharding_costs is not None else None compute_cost = deepcopy(self.compute_cost) communication_cost = deepcopy(self.communication_cost) memory_cost = deepcopy(self.memory_cost) - return ShardingStrategy(name=self.name, - sharding_specs=sharding_specs, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - communication_actions=communication_actions, - resharding_costs=resharding_costs) + return ShardingStrategy( + name=self.name, + sharding_specs=sharding_specs, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + communication_actions=communication_actions, + resharding_costs=resharding_costs, + ) class StrategiesVector(list): - ''' + """ Each node in fx graph will have a corresponding StrategiesVector, to store all the possible strategies of the node. Argument: node (Node): node for which the list of sharding strategies are generated. - ''' + """ def __init__(self, node: Node): super().__init__() @@ -245,7 +253,7 @@ def __init__(self, node: Node): def check_merge(self): merge_label = False - if self.node.op == 'call_module': + if self.node.op == "call_module": target = self.node.target root_module = self.node.graph.owning_module submod = root_module.get_submodule(target) @@ -255,7 +263,7 @@ def check_merge(self): if submod_type in ELEMENTWISE_MODULE_OP: merge_label = True - if self.node.op == 'call_function': + if self.node.op == "call_function": # we could merge element-wise op, because the output sharding spec is always same as the input sharding spec. 
if self.node.target in ELEMENTWISE_FUNC_OP: merge_label = True @@ -267,7 +275,7 @@ def check_merge(self): if self.node.target in RESHAPE_FUNC_OP: merge_label = True - if self.node.op == 'call_method': + if self.node.op == "call_method": # we could merge reshape op, because their computation costs are negligible. method = getattr(self.node.args[0]._meta_data.__class__, self.node.target) if method in RESHAPE_METHOD_OP: diff --git a/colossalai/auto_parallel/tensor_shard/solver/__init__.py b/colossalai/auto_parallel/tensor_shard/solver/__init__.py index f9e6bd923921..b930ce80a9b9 100644 --- a/colossalai/auto_parallel/tensor_shard/solver/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/solver/__init__.py @@ -3,4 +3,4 @@ from .solver import Solver from .strategies_constructor import StrategiesConstructor -__all__ = ['GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph'] +__all__ = ["GraphAnalyser", "Solver", "StrategiesConstructor", "CostGraph"] diff --git a/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py b/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py index 1b2d3ad57407..4415d429b0c2 100644 --- a/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py +++ b/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py @@ -4,7 +4,7 @@ class CostGraph: - ''' + """ A graph data structure to simplify the edge cost graph. It has two main functions: 1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in CostGraph, and it stored every combinations of strategies for a src-dst node pair in an 1D list. @@ -15,7 +15,7 @@ class CostGraph: Argument: leaf_strategies(List[StrategiesVector]): It stores StrategiesVector of every nodes on the graph. simplify(bool, optional): The generated cost graph will be simplified if it is true. (default to True) - ''' + """ def __init__(self, leaf_strategies, simplify=True, forward_only=False): self.leaf_strategies = leaf_strategies @@ -39,10 +39,10 @@ def _remove_invalid_node(self, node, attr_name): target_node_list.remove(element) def _build_cost_graph(self): - ''' + """ This method will generate edge_cost for adjacent node pair. Additionally, 'parents' and 'children' attribute will be set to node. - ''' + """ self.edge_costs = {} if self.simplify: self.merge_pair = [] @@ -84,8 +84,8 @@ def _check_tensor_in_node(data): if _check_tensor_in_node(node._meta_data): children_nodes.append(node) - setattr(dst_node, 'parents', parent_nodes) - setattr(dst_node, 'children', children_nodes) + setattr(dst_node, "parents", parent_nodes) + setattr(dst_node, "children", children_nodes) if self.simplify and strategies_vector.check_merge(): for followed_node in strategies_vector.predecessor_nodes: @@ -99,7 +99,7 @@ def get_edge_cost(self, src_node, dst_node): return self.edge_costs[(src_node, dst_node)] def merge_node(self, src_node, dst_node): - ''' + """ To merge dst_node into src_node, we need to do it in following steps: 1. For each strategy in dst_node, we need to pick an appropriate strategy @@ -119,7 +119,7 @@ def merge_node(self, src_node, dst_node): Argument: src_node(Node): The node will be merged into dst_node. dst_node(Node): The node to integrate src_node. 
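The cost graph stores the resharding cost of every (source strategy, destination strategy) pair of an edge as a flat 1-D list so the solver can treat it linearly. A small sketch of one plausible flattening scheme (row-major indexing is an assumption here, not taken from the patch):

```python
def linearize_edge_cost(cost_matrix):
    """Flatten an m x n resharding-cost matrix into a 1-D list plus an index helper."""
    n = len(cost_matrix[0])
    flat = [cost_matrix[i][j] for i in range(len(cost_matrix)) for j in range(n)]
    return flat, lambda i, j: flat[i * n + j]

costs = [[0.0, 2.5],
         [1.0, 0.0],
         [3.0, 1.5]]          # 3 source strategies x 2 destination strategies
flat, lookup = linearize_edge_cost(costs)
assert lookup(2, 1) == costs[2][1]
print(flat)
```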
- ''' + """ # build merge_map merge_map = {} for src_index, _ in enumerate(src_node.strategies_vector): @@ -196,7 +196,7 @@ def simplify_graph(self): if not self.simplify: return self.merge_pair.reverse() - for (src_node, dst_node) in self.merge_pair: + for src_node, dst_node in self.merge_pair: self.merge_node(src_node, dst_node) self.merge_pair.reverse() reindexing_following_dict = {} diff --git a/colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py b/colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py index 171aa8b3399f..678965d663e4 100644 --- a/colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py +++ b/colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py @@ -7,7 +7,7 @@ from colossalai.fx.passes.utils import get_node_module -__all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser'] +__all__ = ["LiveVariable", "LiveVariableVector", "LiveStage", "GraphAnalyser"] @dataclass @@ -15,6 +15,7 @@ class LiveVariable: """ LiveVariable is a data structure to store the meta information of a variable for liveness analysis. """ + name: str node: Node is_inplace: bool @@ -55,6 +56,7 @@ class LiveStage: """ LiveStage is a data structure to record the living variables at this current node. """ + name: str node: Node all_live_vars: LiveVariableVector @@ -62,7 +64,6 @@ class LiveStage: class GraphAnalyser: - def __init__(self, gm: GraphModule): self._gm = gm self._graph = gm.graph @@ -105,18 +106,18 @@ def liveness_analysis(self) -> List[LiveStage]: # detect whether the current op is an in-place op # if it is an in-place op, we would deem it as a duplicate var is_inplace = False - if node.op == 'call_function': + if node.op == "call_function": # check if this is an inplace op such as torch.nn.functional.relu(x, inplace=True) - if node.kwargs.get('inplace', False): + if node.kwargs.get("inplace", False): is_inplace = True - elif node.op == 'call_module': + elif node.op == "call_module": # to check if this is an inplace op such as torch.nn.Relu(inplace=True) module = get_node_module(node) - if getattr(module, 'inplace', False): + if getattr(module, "inplace", False): is_inplace = True # add the output var - meta = getattr(node, '_meta_data', None) + getattr(node, "_meta_data", None) live_var = LiveVariable(name=node.name, node=node, is_inplace=is_inplace) if not is_inplace: unique_live_vars.append(live_var) @@ -138,10 +139,12 @@ def liveness_analysis(self) -> List[LiveStage]: # this should be completed if we are able to trace the backward compute graph # add this stage to liveness dict - stage = LiveStage(name=node.name, - node=node, - all_live_vars=all_live_variables.copy(), - unique_live_vars=unique_live_vars.copy()) + stage = LiveStage( + name=node.name, + node=node, + all_live_vars=all_live_variables.copy(), + unique_live_vars=unique_live_vars.copy(), + ) # if a LiveStage is covered by another LiveStage, we just keep the larger one. 
replace = False for index, prev_stage in enumerate(liveness_list): diff --git a/colossalai/auto_parallel/tensor_shard/solver/solver.py b/colossalai/auto_parallel/tensor_shard/solver/solver.py index 564c5f09220c..088d1acb5177 100644 --- a/colossalai/auto_parallel/tensor_shard/solver/solver.py +++ b/colossalai/auto_parallel/tensor_shard/solver/solver.py @@ -21,24 +21,25 @@ import pulp from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum except: - warnings.warn(f'please install the pulp') + warnings.warn(f"please install the pulp") -__all___ = ['Solver'] +__all___ = ["Solver"] class Solver: - - def __init__(self, - graph: Graph, - strategies_constructor: StrategiesConstructor, - cost_graph: CostGraph, - graph_analyser: GraphAnalyser = None, - memory_budget: float = -1.0, - solution_numbers: int = 1, - forward_only: bool = False, - memory_increasing_coefficient: float = 1.3, - verbose=False): - ''' + def __init__( + self, + graph: Graph, + strategies_constructor: StrategiesConstructor, + cost_graph: CostGraph, + graph_analyser: GraphAnalyser = None, + memory_budget: float = -1.0, + solution_numbers: int = 1, + forward_only: bool = False, + memory_increasing_coefficient: float = 1.3, + verbose=False, + ): + """ Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph. Argument: graph: The computing graph to be optimized. @@ -48,7 +49,7 @@ def __init__(self, memory_budget: Memory constraint for the solution. solution_numbers: If solution_numbers is larger than one, solver will us a serious of solutions based on different memory budget. memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate new memory budget. - ''' + """ self.graph = graph self.strategies_constructor = strategies_constructor self.cost_graph = cost_graph @@ -75,11 +76,11 @@ def __init__(self, self.verbose = verbose def _recover_merged_node_strategy(self): - ''' + """ During cost graph constructing, some nodes, such as unary element-wise node or ReshapeOp, were merged into the previous node. Therefore, the index of those strategies are copied from the previous node. This method is used to recover the strategy index of those merged node. - ''' + """ for node_index, node in enumerate(self.nodes): if node.strategies_vector.check_merge(): # the merged node has only one input, and its strategies follow the input sharding strategy @@ -98,9 +99,9 @@ def _generate_node_index_dict(self) -> Dict[Node, int]: return node_index_dict def _prepare_data_for_solver(self): - ''' + """ Extract information from components for solver. 
- ''' + """ node_nums = len(self.leaf_strategies) memory_budget = self.memory_budget @@ -190,23 +191,40 @@ def _prepare_data_for_solver(self): # omit initial value for nodes s_init_np = None - return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np, self.verbose - - def _call_solver_serialized_args(self, - node_nums, - memory_budget, - strategies_len, - following_nodes, - edge_pairs, - alias_set, - liveness_set, - compute_costs, - communication_costs, - memory_costs, - resharding_costs, - alias_convert_costs, - s_init_np=None, - verbose=True): + return ( + node_nums, + memory_budget, + strategies_len, + following_nodes, + edge_pairs, + alias_set, + liveness_set, + compute_costs, + communication_costs, + memory_costs, + resharding_costs, + alias_convert_costs, + s_init_np, + self.verbose, + ) + + def _call_solver_serialized_args( + self, + node_nums, + memory_budget, + strategies_len, + following_nodes, + edge_pairs, + alias_set, + liveness_set, + compute_costs, + communication_costs, + memory_costs, + resharding_costs, + alias_convert_costs, + s_init_np=None, + verbose=True, + ): """ Call the solver with serialized arguments. """ @@ -235,18 +253,18 @@ def get_non_zero_index(binary_vector): s_follow = following_nodes s_alias = alias_set - E = edge_pairs.reshape((-1, 2)) # noqa + E = edge_pairs.reshape((-1, 2)) # noqa r = [] pt = 0 edge_set = set() - for (i, j) in E: + for i, j in E: prod_length = strategies_len[i] * strategies_len[j] if (i, j) in edge_set: raise ValueError(f"Duplicated edges: {(i, j)}") edge_set.add((i, j)) - r.append(resharding_costs[pt:pt + prod_length]) + r.append(resharding_costs[pt : pt + prod_length]) pt += prod_length assert pt == len(resharding_costs) @@ -268,7 +286,6 @@ def get_non_zero_index(binary_vector): # L.append(liveness_set[pt:pt + length]) # pt += length # assert pt == len(liveness_set) - v = [] pt = 0 c = [] @@ -277,9 +294,9 @@ def get_non_zero_index(binary_vector): pt = 0 for i in range(node_nums): length = strategies_len[i] - c.append(compute_costs[pt:pt + length]) - d.append(communication_costs[pt:pt + length]) - m.append(memory_costs[pt:pt + length]) + c.append(compute_costs[pt : pt + length]) + d.append(communication_costs[pt : pt + length]) + m.append(memory_costs[pt : pt + length]) pt += length assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}" assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}" @@ -319,7 +336,7 @@ def get_non_zero_index(binary_vector): e = [] num_edges = 0 map_edge_to_idx = {} - for (idx, (i, j)) in enumerate(E): + for idx, (i, j) in enumerate(E): if len(s[i]) == 1: e.append(s[j]) elif len(s[j]) == 1: @@ -340,7 +357,7 @@ def get_non_zero_index(binary_vector): ###################################### if s_init_np is not None: s_init = s_init_np.reshape((-1, 3)) - for (idx, value, fix) in s_init: + for idx, value, fix in s_init: for i in range(len(s[idx])): s[idx][i].setInitialValue(i == value) if fix: @@ -393,7 +410,7 @@ def get_non_zero_index(binary_vector): # (d). 
specified by `cat="Binary"` - for (idx, (i, j)) in enumerate(E): + for idx, (i, j) in enumerate(E): if strategies_len[i] == 1 or strategies_len[j] == 1: continue @@ -402,13 +419,13 @@ def get_non_zero_index(binary_vector): # (f) for row in range(len(s[i])): - C = len(s[j]) # noqa + C = len(s[j]) # noqa prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row] # (g) for col in range(len(s[j])): - R = len(s[i]) # noqa - C = len(s[j]) # noqa + R = len(s[i]) # noqa + C = len(s[j]) # noqa prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col] # (h) @@ -434,7 +451,8 @@ def get_non_zero_index(binary_vector): msg = verbose time_limit = 600 assert "COIN_CMD" in pulp.listSolvers( - onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'") + onlyAvailable=True + ), "Please install ILP solvers by 'sudo apt install coinor-cbc'" solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count()) # solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit) @@ -444,13 +462,13 @@ def get_non_zero_index(binary_vector): objective = pulp.value(prob.objective) objective = float(objective) if objective is not None else -1.0 if verbose: - print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t" - f"Time: {time.time() - tic}") + print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t" f"Time: {time.time() - tic}") print(f"#nodes: {num_nodes}, #edges: {num_edges}") if prob.status in [pulp.LpStatusInfeasible]: - raise RuntimeError("Cannot run the function under the given memory budget. " - "Please increase the memory budget.") + raise RuntimeError( + "Cannot run the function under the given memory budget. " "Please increase the memory budget." + ) # Get and check results s_val = np.full((node_nums,), -1, dtype=np.int32) @@ -458,7 +476,7 @@ def get_non_zero_index(binary_vector): s_val[i] = get_non_zero_index(s[i]) e_val = np.full((len(E),), -1, dtype=np.int32) - for (idx, (i, j)) in enumerate(E): + for idx, (i, j) in enumerate(E): e_val[idx] = get_non_zero_index(e[idx]) i_spec_index = e_val[idx] // len(s[j]) j_spec_index = e_val[idx] % len(s[j]) diff --git a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py index 044a8ac847ea..aa87ee9bf3db 100644 --- a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py +++ b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py @@ -1,11 +1,5 @@ -import builtins -import math -import operator -from copy import deepcopy -from typing import Dict, List - import torch -from torch.fx import Graph, Node +from torch.fx import Graph from colossalai.auto_parallel.tensor_shard.node_handler import ( GetattrHandler, @@ -14,13 +8,12 @@ operator_registry, ) from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector -from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks from colossalai.device.device_mesh import DeviceMesh from ..options import DataloaderOption, SolverOptions -__all__ = ['StrategiesConstructor'] +__all__ = ["StrategiesConstructor"] class StrategiesConstructor: @@ -35,7 +28,7 @@ class StrategiesConstructor: def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions): self.graph = graph - assert graph.owning_module is not None, 'The given graph is not 
associated with a owning_module' + assert graph.owning_module is not None, "The given graph is not associated with a owning_module" self.root_module = self.graph.owning_module self.nodes = list(graph.nodes) self.device_mesh = device_mesh @@ -46,11 +39,11 @@ def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: Solver self.alias_set = None def remove_duplicated_strategy(self, strategies_vector): - ''' + """ In build_strategies_and_cost method, we may produce some duplicated strategies. In this method, we will remove the duplicated strategies depending on the strategies name. Note that this operation is in-place. - ''' + """ name_checklist = [] remove_list = [] for strategy in strategies_vector: @@ -62,7 +55,6 @@ def remove_duplicated_strategy(self, strategies_vector): strategies_vector.remove(strategy) def generate_alias_set(self): - node_list = [strategy_vector.node for strategy_vector in self.leaf_strategies] common_blocks = find_repeat_blocks(node_list, self.root_module, common_length_threshold=10) @@ -83,7 +75,7 @@ def build_strategies_and_cost(self): """ def _check_no_strategy_for_node(node): - if node.op in ('placeholder', 'get_attr', 'output'): + if node.op in ("placeholder", "get_attr", "output"): return False def _check_no_strategy_for_data(data): @@ -102,83 +94,93 @@ def _check_no_strategy_for_data(data): if _check_no_strategy_for_node(node): self.no_strategy_nodes.append(node) - pass # placeholder node - elif node.op == 'placeholder': + elif node.op == "placeholder": if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED: - placeholder_option = 'distributed' + placeholder_option = "distributed" else: - assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported' - placeholder_option = 'replicated' - placeholder_handler = PlaceholderHandler(node, - self.device_mesh, - strategies_vector, - placeholder_option=placeholder_option) + assert ( + self.solver_options.dataloader_option == DataloaderOption.REPLICATED + ), f"placeholder_option {self.solver_options.dataloader_option} is not supported" + placeholder_option = "replicated" + placeholder_handler = PlaceholderHandler( + node, self.device_mesh, strategies_vector, placeholder_option=placeholder_option + ) placeholder_handler.register_strategy() # get_attr node - elif node.op == 'get_attr': - getattr_handler = GetattrHandler(node, - self.device_mesh, - strategies_vector, - shard_option=self.solver_options.shard_option, - solver_perference=self.solver_options.solver_perference) + elif node.op == "get_attr": + getattr_handler = GetattrHandler( + node, + self.device_mesh, + strategies_vector, + shard_option=self.solver_options.shard_option, + solver_perference=self.solver_options.solver_perference, + ) getattr_handler.register_strategy() # call_module node - elif node.op == 'call_module': + elif node.op == "call_module": target = node.target submod = self.root_module.get_submodule(target) submod_type = type(submod) - handler = operator_registry.get(submod_type)(node, - self.device_mesh, - strategies_vector, - shard_option=self.solver_options.shard_option, - solver_perference=self.solver_options.solver_perference) + handler = operator_registry.get(submod_type)( + node, + self.device_mesh, + strategies_vector, + shard_option=self.solver_options.shard_option, + solver_perference=self.solver_options.solver_perference, + ) handler.register_strategy() # attach strategies_info to node - if hasattr(handler, 
'strategies_info'): - setattr(node, 'strategies_info', handler.strategies_info) + if hasattr(handler, "strategies_info"): + setattr(node, "strategies_info", handler.strategies_info) # call_function node - elif node.op == 'call_function': + elif node.op == "call_function": target = node.target - handler = operator_registry.get(target)(node, - self.device_mesh, - strategies_vector, - shard_option=self.solver_options.shard_option, - solver_perference=self.solver_options.solver_perference) + handler = operator_registry.get(target)( + node, + self.device_mesh, + strategies_vector, + shard_option=self.solver_options.shard_option, + solver_perference=self.solver_options.solver_perference, + ) handler.register_strategy() # attach strategies_info to node - if hasattr(handler, 'strategies_info'): - setattr(node, 'strategies_info', handler.strategies_info) + if hasattr(handler, "strategies_info"): + setattr(node, "strategies_info", handler.strategies_info) # call_method node - elif node.op == 'call_method': + elif node.op == "call_method": method = getattr(node.args[0]._meta_data.__class__, node.target) - handler = operator_registry.get(method)(node, - self.device_mesh, - strategies_vector, - shard_option=self.solver_options.shard_option, - solver_perference=self.solver_options.solver_perference) + handler = operator_registry.get(method)( + node, + self.device_mesh, + strategies_vector, + shard_option=self.solver_options.shard_option, + solver_perference=self.solver_options.solver_perference, + ) handler.register_strategy() # attach strategies_info to node - if hasattr(handler, 'strategies_info'): - setattr(node, 'strategies_info', handler.strategies_info) + if hasattr(handler, "strategies_info"): + setattr(node, "strategies_info", handler.strategies_info) # output node - elif node.op == 'output': + elif node.op == "output": if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED: - output_option = 'distributed' + output_option = "distributed" else: - assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported' - output_option = 'replicated' + assert ( + self.solver_options.dataloader_option == DataloaderOption.REPLICATED + ), f"placeholder_option {self.solver_options.dataloader_option} is not supported" + output_option = "replicated" output_handler = OutputHandler(node, self.device_mesh, strategies_vector, output_option=output_option) output_handler.register_strategy() self.remove_duplicated_strategy(strategies_vector) - setattr(node, 'strategies_vector', strategies_vector) + setattr(node, "strategies_vector", strategies_vector) self.leaf_strategies.append(strategies_vector) self.strategy_map[node] = strategies_vector diff --git a/colossalai/auto_parallel/tensor_shard/utils/__init__.py b/colossalai/auto_parallel/tensor_shard/utils/__init__.py index b7fe5430bf13..d61cfd2add15 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/utils/__init__.py @@ -17,9 +17,21 @@ ) __all__ = [ - 'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape', - 'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity' - 'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding', - 'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands', 'pytree_map', - 'detect_reshape_mapping', 
'check_keep_sharding_status', 'infer_output_dim_partition_dict' + "BroadcastType", + "get_broadcast_shape", + "is_broadcastable", + "recover_sharding_spec_for_broadcast_shape", + "generate_resharding_costs", + "generate_sharding_spec", + "ignore_sharding_exception", + "check_sharding_spec_validity" "transpose_partition_dim", + "update_partition_dim", + "enumerate_all_possible_1d_sharding", + "enumerate_all_possible_2d_sharding", + "generate_sharding_size", + "comm_actions_for_oprands", + "pytree_map", + "detect_reshape_mapping", + "check_keep_sharding_status", + "infer_output_dim_partition_dict", ] diff --git a/colossalai/auto_parallel/tensor_shard/utils/broadcast.py b/colossalai/auto_parallel/tensor_shard/utils/broadcast.py index 307348ea1eaf..99d5a0f2a942 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/broadcast.py +++ b/colossalai/auto_parallel/tensor_shard/utils/broadcast.py @@ -14,8 +14,11 @@ from colossalai.tensor.sharding_spec import ShardingSpec __all__ = [ - 'BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape', - 'comm_actions_for_oprands' + "BroadcastType", + "is_broadcastable", + "get_broadcast_shape", + "recover_sharding_spec_for_broadcast_shape", + "comm_actions_for_oprands", ] @@ -41,7 +44,7 @@ def get_broadcast_shape(shape1: torch.Size, shape2: torch.Size) -> List[int]: """ Compute the broadcast shape given two shapes. """ - assert is_broadcastable(shape1, shape2), f'{shape1} and {shape2} are not broadcastable' + assert is_broadcastable(shape1, shape2), f"{shape1} and {shape2} are not broadcastable" shape1_reverse = shape1[::-1] shape2_reverse = shape2[::-1] min_common_dim = min(len(shape1), len(shape2)) @@ -60,8 +63,9 @@ def get_broadcast_dim_info(logical_shape, physical_shape): logical_num_dims = len(logical_shape) physical_num_dims = len(physical_shape) - assert logical_num_dims >= physical_num_dims, \ - 'The number of dimensions in the logical shape is smaller than that of the physical shape, this tensor is not broadcast!' + assert ( + logical_num_dims >= physical_num_dims + ), "The number of dimensions in the logical shape is smaller than that of the physical shape, this tensor is not broadcast!" # track the dim and its broadcasting type logical_dim_broadcast_info = {} @@ -85,8 +89,9 @@ def get_broadcast_dim_info(logical_shape, physical_shape): return logical_dim_broadcast_info -def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size, - physical_shape: torch.Size) -> ShardingSpec: +def recover_sharding_spec_for_broadcast_shape( + logical_sharding_spec: ShardingSpec, logical_shape: torch.Size, physical_shape: torch.Size +) -> ShardingSpec: """ This function computes the sharding spec for the physical shape of a broadcast tensor. 
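A dependency-free sketch of the dim re-mapping that `recover_sharding_spec_for_broadcast_shape` performs (the `physical_dim = physical_num_dims - (logical_num_dims - shape_dim)` arithmetic in the hunk below). It is illustrative only: the names are made up, and it skips the size-1 expansion case that `get_broadcast_dim_info` classifies in the real helper.

```python
# Simplified sketch of mapping a logical (broadcast) dim_partition_dict back
# onto the physical tensor. Shardings on dims that only exist in the broadcast
# view are dropped and reported, mirroring the removed_dims idea above.
def remap_broadcast_partition(logical_shape, physical_shape, logical_partition):
    logical_ndim, physical_ndim = len(logical_shape), len(physical_shape)
    physical_partition, removed_dims = {}, []
    for shape_dim, mesh_dims in logical_partition.items():
        physical_dim = physical_ndim - (logical_ndim - shape_dim)
        if physical_dim < 0 or physical_shape[physical_dim] != logical_shape[shape_dim]:
            # this dim exists (or is expanded) only in the broadcast view
            removed_dims.extend(mesh_dims)
        else:
            physical_partition[physical_dim] = mesh_dims
    return physical_partition, removed_dims

# A (4, 8, 16) logical tensor sharded as {0: [0], 2: [1]}, broadcast from a
# physical (8, 16) tensor, keeps only the last-dim sharding:
print(remap_broadcast_partition((4, 8, 16), (8, 16), {0: [0], 2: [1]}))
# -> ({1: [1]}, [0])
```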
@@ -124,15 +129,18 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe physical_dim = physical_num_dims - (logical_num_dims - shape_dim) physical_dim_partition[physical_dim] = mesh_dim - physical_sharding_spec = ShardingSpec(device_mesh=logical_sharding_spec.device_mesh, - entire_shape=physical_shape, - dim_partition_dict=physical_dim_partition) + physical_sharding_spec = ShardingSpec( + device_mesh=logical_sharding_spec.device_mesh, + entire_shape=physical_shape, + dim_partition_dict=physical_dim_partition, + ) return physical_sharding_spec, removed_dims -def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: OperationData, - sharding_spec: ShardingSpec) -> CommAction: +def comm_actions_for_oprands( + node: Node, removed_dims: List[int], op_data: OperationData, sharding_spec: ShardingSpec +) -> CommAction: """ This method is used to generate communication actions for oprands which lose information during convert logical shape to physical shape. @@ -140,9 +148,11 @@ def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: Opera if len(removed_dims) == 1: # if list length is 1, extract element from list to avoid using flatten device mesh removed_dims = removed_dims[0] - comm_spec = CommSpec(comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - sharding_spec=sharding_spec, - logical_process_axis=removed_dims) + comm_spec = CommSpec( + comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + sharding_spec=sharding_spec, + logical_process_axis=removed_dims, + ) if op_data.type == OperationDataType.PARAM: comm_type = CommType.HOOK else: @@ -151,7 +161,7 @@ def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: Opera for index, arg in enumerate(node.args): if op_data.name == str(arg): arg_index = index - assert arg_index >= 0, f'op_data should be an argument of node.' + assert arg_index >= 0, f"op_data should be an argument of node." comm_action = CommAction( comm_spec=comm_spec, comm_type=comm_type, diff --git a/colossalai/auto_parallel/tensor_shard/utils/factory.py b/colossalai/auto_parallel/tensor_shard/utils/factory.py index 347c10aa102d..aaca923a5eee 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/factory.py +++ b/colossalai/auto_parallel/tensor_shard/utils/factory.py @@ -14,11 +14,12 @@ from ..constants import INFINITY_COST -__all__ = ['generate_sharding_spec', 'generate_resharding_costs'] +__all__ = ["generate_sharding_spec", "generate_resharding_costs"] -def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh, - dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec: +def generate_sharding_spec( + input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh, dim_partition_dict: Dict[int, List[int]] +) -> ShardingSpec: """ Generate the sharding spec of the tensor based on the given dim_partition_dict. 
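For context, a `dim_partition_dict` maps a tensor dimension to the device-mesh axes it is sharded over, and the hunk below enforces that each such dimension is divisible by the product of those mesh-axis sizes. The following is a minimal standalone sketch of that rule (no colossalai imports; names are illustrative).

```python
# Divisibility rule behind generate_sharding_spec: every tensor dim listed in
# dim_partition_dict must split evenly across the mesh axes it is sharded over.
from functools import reduce
import operator

def check_partition(shape, mesh_shape, dim_partition_dict):
    for dim, mesh_axes in dim_partition_dict.items():
        sharding_size = reduce(operator.mul, (mesh_shape[a] for a in mesh_axes), 1)
        if shape[dim] % sharding_size != 0:
            raise ValueError(
                f"dim {dim} of size {shape[dim]} cannot be split into {sharding_size} shards"
            )

# A (16, 1024) weight on a 2x4 mesh, sharded over both axes along dim 1:
check_partition((16, 1024), (2, 4), {1: [0, 1]})  # 1024 % 8 == 0, passes
```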
@@ -30,7 +31,7 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic """ if isinstance(input_, Node): - assert hasattr(input_, '_meta_data'), f'The given node has no attribute _meta_data' + assert hasattr(input_, "_meta_data"), f"The given node has no attribute _meta_data" meta_tensor = input_._meta_data assert meta_tensor is not None, "The given node's _meta_data attribute is None" shape = meta_tensor.shape @@ -38,24 +39,27 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic shape = input_.shape else: raise TypeError( - f'We cannot generate sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected.' + f"We cannot generate sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected." ) for dim_index, sharding_index_list in dim_partition_dict.items(): sharding_list = [device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list] sharding_size = reduce(operator.mul, sharding_list, 1) - assert shape[ - dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.' + assert ( + shape[dim_index] % sharding_size == 0 + ), f"we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions." sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict) return sharding_spec -def generate_resharding_costs(nodes: List[Node], - sharding_specs: List[ShardingSpec], - count_backward: Optional[bool] = True, - dtype: Optional[torch.dtype] = None, - index=None): - ''' +def generate_resharding_costs( + nodes: List[Node], + sharding_specs: List[ShardingSpec], + count_backward: Optional[bool] = True, + dtype: Optional[torch.dtype] = None, + index=None, +): + """ Compute the resharding costs with this specific strategy. Argument: @@ -63,7 +67,7 @@ def generate_resharding_costs(nodes: List[Node], sharding_spec_for_input(ShardingSpec): a list of ShardingSpec for the nodes. count_backward (Optional[bool]): whether to include the cost of resharding in the backward pass, default is True. False can be used for inference. dtype (Optional[torch.dtype]): the data type for cost calculation, default is None. - ''' + """ # The resharding_cost of weight is counted due to sharing weight cases. resharding_costs = {} size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() @@ -76,38 +80,39 @@ def generate_resharding_costs(nodes: List[Node], for strategy in input_node.strategies_vector: input_sharding_spec = strategy.output_sharding_spec if not isinstance(input_sharding_spec, ShardingSpec): - assert isinstance(input_sharding_spec, list), 'only ShardingSpec or List[ShardingSpec] is expected.' + assert isinstance(input_sharding_spec, list), "only ShardingSpec or List[ShardingSpec] is expected." input_sharding_spec = input_sharding_spec[index] - assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.' + assert isinstance(input_sharding_spec, ShardingSpec), f"The input node should NOT be a tuple of tensor." 
try: # compute the resharding cost _, _, total_resharding_cost = shape_consistency_manager.shape_consistency( - input_sharding_spec, input_spec) + input_sharding_spec, input_spec + ) # we need multiply the size of elem dtype to get correct communication cost resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes except AssertionError as e: - warnings.warn(f'{e}') + warnings.warn(f"{e}") resharding_cost = INFINITY_COST resharding_costs[input_node].append(resharding_cost) return resharding_costs def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_length_threshold: int = 20): - ''' + """ Find the largest repeat blocks in the graph, whose length is larger than the threshold. Args: gm (GraphModule): the graph module to be analyzed. common_length_threshold (int): the threshold of the repeat block length. - ''' + """ # graph = gm.graph def _process_args(args): new_args = [] for arg in args: - if hasattr(arg, '_meta_data'): + if hasattr(arg, "_meta_data"): meta_data = arg._meta_data else: meta_data = arg @@ -145,7 +150,7 @@ def _check_node_equal(node1, node2): return False for index, node in enumerate(node_list): - if node.op == 'call_module': + if node.op == "call_module": target = node.target submod = root_module.get_submodule(target) submod_type = type(submod) @@ -155,12 +160,12 @@ def _check_node_equal(node1, node2): new_args = _process_args(node.args) - if node.op != 'get_attr': + if node.op != "get_attr": hash_key = (node.op, target, *new_args) else: hash_key = (node.op,) - setattr(node, 'hash_key', hash_key) + setattr(node, "hash_key", hash_key) hash_value_to_node_dict = {} @@ -179,7 +184,7 @@ def _check_node_equal(node1, node2): # the comparison will be triggered if a common node appears if len(hash_value_to_node_dict[hash(node.hash_key)]) >= 2: start_index_list = hash_value_to_node_dict[hash(node.hash_key)] - check_block_list = [node_list[start:start + max_common_length] for start in start_index_list] + check_block_list = [node_list[start : start + max_common_length] for start in start_index_list] common_label = True if not _all_equal(check_block_list, _check_node_list_equal): @@ -201,6 +206,6 @@ def _check_node_equal(node1, node2): # recover common subgraph from the index common_blocks = [] for start in common_blocks_index: - common_blocks.append(node_list[start:start + max_common_length]) + common_blocks.append(node_list[start : start + max_common_length]) return common_blocks diff --git a/colossalai/auto_parallel/tensor_shard/utils/misc.py b/colossalai/auto_parallel/tensor_shard/utils/misc.py index 475e95fc4326..42ec2a8ee428 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/misc.py +++ b/colossalai/auto_parallel/tensor_shard/utils/misc.py @@ -1,12 +1,12 @@ import functools -from typing import Any, Callable, Dict, List, Tuple, Type, Union +from typing import Any, Callable, Tuple, Type, Union import torch from colossalai.logging import get_dist_logger from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException -__all__ = ['ignore_sharding_exception', 'pytree_map'] +__all__ = ["ignore_sharding_exception", "pytree_map"] def ignore_sharding_exception(func): @@ -48,29 +48,32 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens tensor_num_dim = tensor.dim() num_devices_in_col = sharding_spec.device_mesh.shape[0] num_devices_in_row = sharding_spec.device_mesh.shape[1] - assert sharding_len == tensor_num_dim, \ - f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for 
{sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).' + assert ( + sharding_len == tensor_num_dim + ), f"The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape})." # make sure the sharding is valid for each dim for i in range(tensor_num_dim): dim_size = tensor.shape[i] dim_spec = sharding_spec.sharding_sequence[i] - if str(dim_spec).startswith('S'): - devices_str = str(dim_spec).lstrip('S') + if str(dim_spec).startswith("S"): + devices_str = str(dim_spec).lstrip("S") num_devices = 1 - if '0' in devices_str: + if "0" in devices_str: num_devices *= num_devices_in_col - if '1' in devices_str: + if "1" in devices_str: num_devices *= num_devices_in_row - assert dim_size >= num_devices and dim_size % num_devices == 0, \ - f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.' + assert ( + dim_size >= num_devices and dim_size % num_devices == 0 + ), f"The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices." # make sure the entire shape matches the physical tensor shape - assert sharding_spec.entire_shape == tensor.shape, \ - f'The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}' + assert ( + sharding_spec.entire_shape == tensor.shape + ), f"The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}" def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any: diff --git a/colossalai/auto_parallel/tensor_shard/utils/reshape.py b/colossalai/auto_parallel/tensor_shard/utils/reshape.py index d0ebbd7e8b1b..329312ef797f 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/reshape.py +++ b/colossalai/auto_parallel/tensor_shard/utils/reshape.py @@ -8,6 +8,7 @@ class PreviousStatus(Enum): """ This class shows the status of previous comparison. """ + RESET = 0 # ORIGIN means the dimension size of original tensor is larger in the previous comparison. ORIGIN = 1 @@ -130,8 +131,9 @@ def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> D return reshape_mapping_dict -def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]], - reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> bool: +def check_keep_sharding_status( + input_dim_partition_dict: Dict[int, List[int]], reshape_mapping_dict: Dict[Tuple[int], Tuple[int]] +) -> bool: """ This method is used to check whether the reshape operation could implement without converting the input to fully replicated status. @@ -172,14 +174,16 @@ def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]], return True -def infer_output_dim_partition_dict(input_dim_partition_dict: Dict[int, List[int]], - reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> Dict[Tuple[int], Tuple[int]]: +def infer_output_dim_partition_dict( + input_dim_partition_dict: Dict[int, List[int]], reshape_mapping_dict: Dict[Tuple[int], Tuple[int]] +) -> Dict[Tuple[int], Tuple[int]]: """ This method is used to infer the output dim partition dict for a reshape operation, given the input dim partition dict and reshape mapping dict. """ - assert check_keep_sharding_status(input_dim_partition_dict, reshape_mapping_dict), \ - 'we only infer output dim partition dict for the reshape operation could keep sharding spec.' 
+ assert check_keep_sharding_status( + input_dim_partition_dict, reshape_mapping_dict + ), "we only infer output dim partition dict for the reshape operation could keep sharding spec." sharded_dims = list(input_dim_partition_dict.keys()) output_dim_partition_dict = {} for input_dims, output_dims in reshape_mapping_dict.items(): diff --git a/colossalai/auto_parallel/tensor_shard/utils/sharding.py b/colossalai/auto_parallel/tensor_shard/utils/sharding.py index e2ce59e0b577..b5386d599be4 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/sharding.py +++ b/colossalai/auto_parallel/tensor_shard/utils/sharding.py @@ -8,8 +8,11 @@ from colossalai.tensor.sharding_spec import ShardingSpec __all__ = [ - 'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding', - 'enumerate_all_possible_2d_sharding', 'generate_sharding_size' + "transpose_partition_dim", + "update_partition_dim", + "enumerate_all_possible_1d_sharding", + "enumerate_all_possible_2d_sharding", + "generate_sharding_size", ] @@ -22,8 +25,7 @@ def transpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) - dim1 (int): the tensor dimension to switch dim2 (int): the tensor dimension to switch """ - assert len(sharding_spec.entire_shape) >= 2, \ - 'The entire_shape of the sharding spec must have at least 2 dimensions' + assert len(sharding_spec.entire_shape) >= 2, "The entire_shape of the sharding spec must have at least 2 dimensions" dim_partition_dict = sharding_spec.dim_partition_dict # transpose the dim partition @@ -45,10 +47,9 @@ def transpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) - return sharding_spec -def update_partition_dim(sharding_spec: ShardingSpec, - dim_mapping: Dict[int, int], - physical_shape: torch.Size, - inplace: bool = False): +def update_partition_dim( + sharding_spec: ShardingSpec, dim_mapping: Dict[int, int], physical_shape: torch.Size, inplace: bool = False +): """ This method is used to update the partition dim dict from the logical one to the physical one. 
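A rough sketch of the `dim_mapping` step that `update_partition_dim` applies before rebuilding the spec with the physical shape in the hunk below: each sharded logical dim is looked up in the mapping and its mesh axes are re-attached to the corresponding physical dim. This is an assumption-level illustration; the helper and variable names here are not the real API.

```python
# Remap sharded dims from the logical layout to the physical layout.
def remap_partition_dims(dim_partition_dict, dim_mapping):
    return {
        dim_mapping.get(logical_dim, logical_dim): mesh_dims
        for logical_dim, mesh_dims in dim_partition_dict.items()
    }

# e.g. a spec sharded on dims 0 and 1 under a transpose-style mapping {0: 1, 1: 0}:
print(remap_partition_dims({0: [0], 1: [1]}, {0: 1, 1: 0}))
# -> {1: [0], 0: [1]}
```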
@@ -78,9 +79,9 @@ def update_partition_dim(sharding_spec: ShardingSpec, new_dim_partition_dict[tensor_dim] = mesh_dims # update sharding spec - current_sharding_spec.__init__(device_mesh=sharding_spec.device_mesh, - entire_shape=physical_shape, - dim_partition_dict=new_dim_partition_dict) + current_sharding_spec.__init__( + device_mesh=sharding_spec.device_mesh, entire_shape=physical_shape, dim_partition_dict=new_dim_partition_dict + ) return current_sharding_spec diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py index cc98c1570b4a..9571fa2c17f0 100644 --- a/colossalai/autochunk/autochunk_codegen.py +++ b/colossalai/autochunk/autochunk_codegen.py @@ -9,7 +9,18 @@ AUTOCHUNK_AVAILABLE = CODEGEN_AVAILABLE and is_compatible_with_meta() if AUTOCHUNK_AVAILABLE: - from torch.fx.graph import CodeGen, PythonCode, _custom_builtins, _CustomBuiltin, _format_target, _is_from_torch, _Namespace, _origin_type_map, inplace_methods, magic_methods + from torch.fx.graph import ( + CodeGen, + PythonCode, + _custom_builtins, + _CustomBuiltin, + _format_target, + _is_from_torch, + _Namespace, + _origin_type_map, + inplace_methods, + magic_methods, + ) from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg @@ -64,14 +75,21 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_out for i in range(len(chunk_output)): shape_str = str(list(get_node_shape(chunk_output[i]))) if get_node_name(chunk_output[i]) in ["split", "unbind"]: - tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (shape_str, input_node.name, - input_node.name) - tensor_str = tensor_str * len(chunk_output[i].meta['tensor_meta']) + tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % ( + shape_str, + input_node.name, + input_node.name, + ) + tensor_str = tensor_str * len(chunk_output[i].meta["tensor_meta"]) tensor_str = "[" + tensor_str[:-2] + "]" context += "%s = %s; " % (chunk_output[i].name, tensor_str) else: - context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % (chunk_output[i].name, shape_str, - input_node.name, input_node.name) + context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % ( + chunk_output[i].name, + shape_str, + input_node.name, + input_node.name, + ) out_shape = get_node_shape(chunk_output[0]) chunk_shape = out_shape[chunk_output_dim[0]] @@ -79,8 +97,14 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_out return context -def _gen_loop_end(chunk_inputs: List[Node], chunk_non_compute_inputs: List[Node], node_list: List[Node], - chunk_outputs_idx: int, chunk_outputs_non_tensor: List[Node], search_chunk: SearchChunk) -> str: +def _gen_loop_end( + chunk_inputs: List[Node], + chunk_non_compute_inputs: List[Node], + node_list: List[Node], + chunk_outputs_idx: int, + chunk_outputs_non_tensor: List[Node], + search_chunk: SearchChunk, +) -> str: """ Generate chunk loop end @@ -148,8 +172,10 @@ def _replace_new_tensor_like_shape( chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"] if get_node_shape(meta_node)[chunk_dim] != 1: source_node = meta_node.args[0].args[0] - if (source_node not in chunk_infos[region_idx]["node_chunk_dim"] - or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None): + if ( + source_node not in chunk_infos[region_idx]["node_chunk_dim"] + or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None + ): chunk_slice = _gen_chunk_slice_dim(chunk_dim, 
"chunk_idx", get_node_shape(node)) body[-1] = _replace_name(body[-1], node.args[0].name, node.args[0].name + chunk_slice) return body @@ -203,11 +229,12 @@ def _add_node_slice( # outputs node else: if chunk_node.name == node.name or (chunk_node.name in [i.name for i in node.all_input_nodes]): - chunk_slice = _gen_chunk_slice_dim(chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx", - get_node_shape(chunk_node)) + chunk_slice = _gen_chunk_slice_dim( + chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx", get_node_shape(chunk_node) + ) if get_node_name(chunk_node) in ["split", "unbind"]: split_chunk_slice = "" - for i in range(len(chunk_node.meta['tensor_meta'])): + for i in range(len(chunk_node.meta["tensor_meta"])): split_chunk_slice += "%s[%d]%s, " % (chunk_node.name, i, chunk_slice) split_chunk_slice = split_chunk_slice[:-2] body[-1] = _replace_name(body[-1], chunk_node.name, split_chunk_slice) @@ -216,13 +243,15 @@ def _add_node_slice( return body -def emit_code_with_chunk(body: List[str], - nodes: Iterable[Node], - emit_node_func: Callable, - delete_unused_value_func: Callable, - search_chunk: SearchChunk, - chunk_infos: List, - eval_mem: bool = False): +def emit_code_with_chunk( + body: List[str], + nodes: Iterable[Node], + emit_node_func: Callable, + delete_unused_value_func: Callable, + search_chunk: SearchChunk, + chunk_infos: List, + eval_mem: bool = False, +): """ Emit code with chunk according to chunk_infos. @@ -244,9 +273,9 @@ def emit_code_with_chunk(body: List[str], chunk_ends = [i["region"][1] for i in chunk_infos] # chunk inputs - chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk - chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] # input without chunk - chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim + chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk + chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] # input without chunk + chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i] # chunk outputs @@ -275,7 +304,8 @@ def emit_code_with_chunk(body: List[str], chunk_outputs[region_idx], chunk_outputs_dim[region_idx], chunk_infos[region_idx]["chunk_size"], - )) + ) + ) if within_chunk_region: emit_node_func(node, body) @@ -294,7 +324,8 @@ def emit_code_with_chunk(body: List[str], if eval_mem: body.append( " if chunk_idx == 0:\n print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n" - % (node.name)) + % (node.name) + ) else: emit_node_func(node, body) if node_idx not in chunk_inputs: @@ -302,13 +333,21 @@ def emit_code_with_chunk(body: List[str], if eval_mem: body.append( "print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n" - % (node.name)) + % (node.name) + ) # generate chunk region end if node_idx in chunk_ends: body.append( - _gen_loop_end(chunk_inputs[region_idx], chunk_inputs_non_chunk[region_idx], node_list, - chunk_ends[region_idx], chunk_outputs_non_tensor[region_idx], search_chunk)) + _gen_loop_end( + chunk_inputs[region_idx], + chunk_inputs_non_chunk[region_idx], + node_list, + chunk_ends[region_idx], + chunk_outputs_non_tensor[region_idx], + search_chunk, + ) + ) within_chunk_region = False node_idx += 1 @@ -317,13 +356,14 @@ def emit_code_with_chunk(body: List[str], if AUTOCHUNK_AVAILABLE: class 
AutoChunkCodeGen(CodeGen): - - def __init__(self, - meta_graph, - max_memory: int = None, - print_mem: bool = False, - print_progress: bool = False, - eval_mem: bool = False) -> None: + def __init__( + self, + meta_graph, + max_memory: int = None, + print_mem: bool = False, + print_progress: bool = False, + eval_mem: bool = False, + ) -> None: super().__init__() self.eval_mem = eval_mem # find the chunk regions @@ -349,7 +389,7 @@ def add_global(name_hint: str, obj: Any): Returns: the global name that should be used to reference 'obj' in generated source. """ - if (_is_from_torch(obj) and obj != torch.device): # to support registering torch.device + if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device # HACK: workaround for how torch custom ops are registered. We # can't import them like normal modules so they must retain their # fully qualified name. @@ -402,7 +442,6 @@ def type_repr(o: Any): return add_global(typename, o) def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str: - def _get_repr(arg): # Handle NamedTuples (if it has `_fields`) via add_global. if isinstance(arg, tuple) and hasattr(arg, "_fields"): @@ -457,10 +496,10 @@ def delete_unused_values(user: Node, body, to_keep=[]): # NOTE: we add a variable to distinguish body and ckpt_func def emit_node(node: Node, body): - maybe_type_annotation = ("" if node.type is None else f" : {type_repr(node.type)}") + maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}" if node.op == "placeholder": assert isinstance(node.target, str) - maybe_default_arg = ("" if not node.args else f" = {repr(node.args[0])}") + maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}" free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}") raw_name = node.target.replace("*", "") if raw_name != repr(node): @@ -470,42 +509,56 @@ def emit_node(node: Node, body): assert isinstance(node.target, str) body.append( f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}" - f"({_format_args(node.args[1:], node.kwargs)})") + f"({_format_args(node.args[1:], node.kwargs)})" + ) return elif node.op == "call_function": assert callable(node.target) # pretty print operators - if (node.target.__module__ == "_operator" and node.target.__name__ in magic_methods): + if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods: assert isinstance(node.args, tuple) - body.append(f"{repr(node)}{maybe_type_annotation} = " - f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}") + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}" + ) return # pretty print inplace operators; required for jit.script to work properly # not currently supported in normal FX graphs, but generated by torchdynamo - if (node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods): - body.append(f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; " - f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}") + if node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods: + body.append( + f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; " + f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}" + ) return qualified_name = _get_qualified_name(node.target) global_name = add_global(qualified_name, node.target) # 
special case for getattr: node.args could be 2-argument or 3-argument # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value - if (global_name == "getattr" and isinstance(node.args, tuple) and isinstance(node.args[1], str) - and node.args[1].isidentifier() and len(node.args) == 2): + if ( + global_name == "getattr" + and isinstance(node.args, tuple) + and isinstance(node.args[1], str) + and node.args[1].isidentifier() + and len(node.args) == 2 + ): body.append( - f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}") + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}" + ) return body.append( - f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})") + f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})" + ) if node.meta.get("is_wrapped", False): wrapped_fns.setdefault(global_name) return elif node.op == "call_module": assert isinstance(node.target, str) - body.append(f"{repr(node)}{maybe_type_annotation} = " - f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})") + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})" + ) return elif node.op == "get_attr": assert isinstance(node.target, str) @@ -523,8 +576,9 @@ def emit_node(node: Node, body): # if any node has a list of labels for activation_checkpoint, we # will use nested type of activation checkpoint codegen - emit_code_with_chunk(body, nodes, emit_node, delete_unused_values, self.search_chunk, self.chunk_infos, - self.eval_mem) + emit_code_with_chunk( + body, nodes, emit_node, delete_unused_values, self.search_chunk, self.chunk_infos, self.eval_mem + ) if len(body) == 0: # If the Graph has no non-placeholder nodes, no lines for the body diff --git a/colossalai/autochunk/estimate_memory.py b/colossalai/autochunk/estimate_memory.py index 77bc2ef17bc3..a85ad429e261 100644 --- a/colossalai/autochunk/estimate_memory.py +++ b/colossalai/autochunk/estimate_memory.py @@ -1,11 +1,8 @@ -import copy -from typing import Any, Callable, Dict, Iterable, List, Tuple +from typing import Dict, List import torch from torch.fx.node import Node -from colossalai.fx.profiler import activation_size, parameter_size - from .utils import NodeMgr, get_node_shape, is_non_memory_node @@ -62,12 +59,9 @@ def _build_delete_node_dict(self, node_mgr: NodeMgr) -> Dict: delete_node_dict[node] = max(node_user_idx) return delete_node_dict - def _remove_deactive_node(self, - user_idx: int, - user: Node, - active_nodes: List, - delete_node_dict: List, - kept_nodes: List = None) -> None: + def _remove_deactive_node( + self, user_idx: int, user: Node, active_nodes: List, delete_node_dict: List, kept_nodes: List = None + ) -> None: """ remove deactivate nodes from active nodes """ @@ -169,7 +163,7 @@ def estimate_chunk_inference_mem(self, node_list: List, chunk_infos: Dict = None use_chunk = True if chunk_infos is not None else False chunk_within = False chunk_region_idx = None - chunk_ratio = 1 # use it to estimate chunk mem + chunk_ratio = 1 # use it to estimate chunk mem chunk_inputs_all = [] if use_chunk: @@ -184,7 +178,6 @@ def estimate_chunk_inference_mem(self, node_list: List, chunk_infos: Dict = None chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos] for idx, node in enumerate(node_mgr.get_node_list()): - # if node 
in chunk start nodes, change chunk ratio and add chunk_tensor if use_chunk and idx in chunk_starts: chunk_within = True @@ -193,8 +186,9 @@ def estimate_chunk_inference_mem(self, node_list: List, chunk_infos: Dict = None # determine chunk ratio for current node if chunk_within: - chunk_ratio = self._get_chunk_ratio(node, chunk_node_dim[chunk_region_idx], - chunk_sizes[chunk_region_idx]) + chunk_ratio = self._get_chunk_ratio( + node, chunk_node_dim[chunk_region_idx], chunk_sizes[chunk_region_idx] + ) # add current node as active node self._add_active_node(node, active_nodes, chunk_ratio) @@ -222,7 +216,7 @@ def estimate_chunk_inference_mem(self, node_list: List, chunk_infos: Dict = None # if node in chunk end nodes, restore chunk settings if use_chunk and idx in chunk_ends: - self._remove_deactive_node(idx, node, active_nodes, delete_node_dict) # dont provide kept nodes now + self._remove_deactive_node(idx, node, active_nodes, delete_node_dict) # dont provide kept nodes now chunk_within = False chunk_ratio = 1 chunk_region_idx = None diff --git a/colossalai/autochunk/search_chunk.py b/colossalai/autochunk/search_chunk.py index 59645c80e808..1c599049d9eb 100644 --- a/colossalai/autochunk/search_chunk.py +++ b/colossalai/autochunk/search_chunk.py @@ -8,7 +8,7 @@ from .select_chunk import SelectChunk from .trace_flow import TraceFlow from .trace_indice import TraceIndice -from .utils import NodeMgr, get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder +from .utils import NodeMgr, get_logger, is_non_compute_node, is_non_compute_node_except_placeholder class SearchChunk(object): @@ -121,8 +121,10 @@ def _search_max_chunk_region(self, active_node: List, peak_region: int, chunk_re # check if peak node already in chunk info if chunk_regions is not None: for i in chunk_regions: - if i["region"][0] < peak_region[0] <= i["region"][1] or \ - i["region"][0] < peak_region[1] <= i["region"][1]: + if ( + i["region"][0] < peak_region[0] <= i["region"][1] + or i["region"][0] < peak_region[1] <= i["region"][1] + ): return None active_node_num = [len(i) for i in active_node] @@ -146,9 +148,9 @@ def _search_max_chunk_region(self, active_node: List, peak_region: int, chunk_re region = i["region"] if chunk_region_start >= region[0] and chunk_region_end <= region[1]: return None - elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]): + elif region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]: chunk_region_start = region[1] + 1 - elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]): + elif region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]: chunk_region_end = region[0] - 1 return chunk_region_start, chunk_region_end @@ -171,7 +173,7 @@ def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> Lis chunk_infos: possible regions found """ start_traces = input_trace[start_idx] - if len(start_traces) > 1: # TODO need to be removed + if len(start_traces) > 1: # TODO need to be removed return [] end_trace = output_trace[end_idx] end_node = self.node_mgr.get_node_by_idx(end_idx) @@ -180,8 +182,9 @@ def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> Lis for end_dim, _ in enumerate(end_trace["indice"]): for start_node, start_trace in start_traces.items(): for start_dim, _ in enumerate(start_trace["indice"]): - if not self.trace_flow.check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim, - end_idx): + if not 
self.trace_flow.check_region_start_end( + start_node, start_dim, start_idx, end_node, end_dim, end_idx + ): continue # flow search chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim) @@ -203,7 +206,7 @@ def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_region: N """ possible_chunk_region = [] output_trace = copy.deepcopy(self.trace_indice.indice_trace_list) - input_trace = [] # trace of a node's input nodes + input_trace = [] # trace of a node's input nodes for _, n in enumerate(self.node_mgr.get_node_list()): cur_trace = {} for arg in n.args: @@ -215,7 +218,8 @@ def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_region: N for end_idx in range(peak_region[1], max_chunk_region[1] + 1): # skip non compute nodes if is_non_compute_node(self.node_mgr.get_node_by_idx(start_idx)) or is_non_compute_node( - self.node_mgr.get_node_by_idx(end_idx)): + self.node_mgr.get_node_by_idx(end_idx) + ): continue # select free dim chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx) @@ -279,15 +283,18 @@ def search_region(self) -> Dict: chunk_infos.append(chunk_info) mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem( - self.node_mgr.get_node_list(), chunk_infos) + self.node_mgr.get_node_list(), chunk_infos + ) if self.print_progress: - get_logger().info("AutoChunk find chunk region %d = (%d, %d)" % - (len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1])) + get_logger().info( + "AutoChunk find chunk region %d = (%d, %d)" + % (len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1]) + ) if self.print_mem: self.print_mem = False - self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(), - chunk_infos, - print_mem=True) + self.estimate_memory.estimate_chunk_inference_mem( + self.node_mgr.get_node_list(), chunk_infos, print_mem=True + ) return chunk_infos diff --git a/colossalai/autochunk/select_chunk.py b/colossalai/autochunk/select_chunk.py index 94a29bfd5691..8a60ba681f70 100644 --- a/colossalai/autochunk/select_chunk.py +++ b/colossalai/autochunk/select_chunk.py @@ -5,7 +5,6 @@ class SelectChunk(object): - def __init__( self, trace_indice: TraceIndice, @@ -20,7 +19,7 @@ def __init__( self.node_mgr = node_mgr if max_memory is not None: self.stratge = "fit_memory" - self.max_memory = max_memory # MB + self.max_memory = max_memory # MB else: self.stratge = "min_memory" @@ -57,16 +56,18 @@ def _select_fit_memory_chunk_region(self, possible_chunk_regions, chunk_infos, m cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region) cur_chunk_infos = chunk_infos + [cur_region] cur_mem = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0] - cur_chunk_region_peak = cur_mem[cur_region["region"][0]:cur_region["region"][1] + 1] + cur_chunk_region_peak = cur_mem[cur_region["region"][0] : cur_region["region"][1] + 1] cur_chunk_region_max_peak = max(cur_chunk_region_peak) if cur_chunk_region_max_peak < self.max_memory: - regions_dict.append({ - "chunk_info": region, - "chunk_max_mem": cur_chunk_region_max_peak, - "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]), - "reorder_chunk_info": cur_region, - "reorder_node_list": cur_node_list, - }) + regions_dict.append( + { + "chunk_info": region, + "chunk_max_mem": cur_chunk_region_max_peak, + "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]), + "reorder_chunk_info": 
cur_region, + "reorder_node_list": cur_node_list, + } + ) # no region found if len(regions_dict) == 0: raise RuntimeError("Search failed. Try a larger memory threshold.") @@ -90,13 +91,15 @@ def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos): chunk_size *= 2 reorder_chunk_info["chunk_size"] = chunk_size cur_chunk_infos = chunk_infos + [reorder_chunk_info] - cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"], - cur_chunk_infos)[0] - cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0]:reorder_chunk_info["region"][1] + 1]) + cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem( + chunk_region_dict["reorder_node_list"], cur_chunk_infos + )[0] + cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1] + 1]) # search exact size chunk_info = chunk_region_dict["chunk_info"] - chunk_info["chunk_size"] = self._chunk_size_binary_search(chunk_size // 2, chunk_size, chunk_region_dict, - chunk_infos) + chunk_info["chunk_size"] = self._chunk_size_binary_search( + chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos + ) return chunk_info def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos): @@ -109,9 +112,10 @@ def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos) mid = int((left + right) / 2 + 0.5) chunk_info["chunk_size"] = mid cur_chunk_infos = chunk_infos + [chunk_info] - cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"], - cur_chunk_infos)[0] - cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0]:chunk_info["region"][1] + 1]) + cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem( + chunk_region_dict["reorder_node_list"], cur_chunk_infos + )[0] + cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]) if cur_chunk_max_mem >= self.max_memory: right = mid - gap else: @@ -139,8 +143,10 @@ def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos): return None # get max possible chunk region - max_possible_chunk_region = (min([i["region"][0] for i in possible_chunk_regions]), - max([i["region"][1] for i in possible_chunk_regions])) + max_possible_chunk_region = ( + min([i["region"][0] for i in possible_chunk_regions]), + max([i["region"][1] for i in possible_chunk_regions]), + ) # get mem for chunk region regions_dict_list = [] @@ -149,15 +155,17 @@ def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos): cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region) cur_chunk_infos = chunk_infos + [cur_region] cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0] - cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1] + cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0] : max_possible_chunk_region[1] + 1] cur_chunk_region_max_peak = max(cur_chunk_region_peak) - regions_dict_list.append({ - "chunk_info": region, - "chunk_max_mem": cur_chunk_region_max_peak, - "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]), - "reorder_chunk_info": cur_region, - "reorder_node_list": cur_node_list, - }) + regions_dict_list.append( + { + "chunk_info": region, + "chunk_max_mem": cur_chunk_region_max_peak, + "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]), + 
"reorder_chunk_info": cur_region, + "reorder_node_list": cur_node_list, + } + ) # select the min mem chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict_list] @@ -175,7 +183,9 @@ def _is_legal_region(self, cur_chunk_info, chunk_infos): return False for i in chunk_infos: region = i["region"] - if not ((chunk_region_start > region[1] and chunk_region_end > region[1]) or - (chunk_region_start < region[0] and chunk_region_end < region[0])): + if not ( + (chunk_region_start > region[1] and chunk_region_end > region[1]) + or (chunk_region_start < region[0] and chunk_region_end < region[0]) + ): return False return True diff --git a/colossalai/autochunk/trace_flow.py b/colossalai/autochunk/trace_flow.py index a1080fda1541..8b36c99bbadd 100644 --- a/colossalai/autochunk/trace_flow.py +++ b/colossalai/autochunk/trace_flow.py @@ -16,7 +16,6 @@ class TraceFlow(object): - def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None: self.trace_indice = trace_indice self.node_mgr = node_mgr @@ -151,7 +150,7 @@ def _assign_single_node_flow( return True def _get_all_node_info(self, end_dim, start_idx, end_idx): - cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)] # start from the last node + cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)] # start from the last node all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}} while len(cur_node_list) > 0: @@ -266,7 +265,7 @@ def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int, maybe_prepose_nodes.sort( key=lambda x: self.node_mgr.find_node_idx(x), reverse=True, - ) # from last node to first node + ) # from last node to first node prepose_nodes = [] # set every node as root, search its args, if all legal, turn root and args as prepose nodes while len(maybe_prepose_nodes) > 0: @@ -328,7 +327,8 @@ def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx): def flow_search(self, start_idx, start_dim, end_idx, end_dim): inputs, outputs = find_chunk_compute_input_and_output_nodes( - self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1)) + self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1) + ) # get every node's chunk dim and fix dim all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx) @@ -371,8 +371,9 @@ def flow_search(self, start_idx, start_dim, end_idx, end_dim): return chunk_info - def _get_other_output_info(self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int, - chunk_info: Dict): + def _get_other_output_info( + self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int, chunk_info: Dict + ): start_node = self.node_mgr.get_node_by_idx(start_idx) # loop all outputs for output in outputs: @@ -384,8 +385,8 @@ def _get_other_output_info(self, outputs: List[Node], start_idx: int, start_dim: # skip non tensor if get_node_shape(output) is None: # log shape tensor - if len(output.meta['fwd_out']) > 0 and isinstance(output.meta['fwd_out'][0], int): - chunk_info["outputs_non_tensor"][output] = str(output.meta['fwd_out']) + if len(output.meta["fwd_out"]) > 0 and isinstance(output.meta["fwd_out"][0], int): + chunk_info["outputs_non_tensor"][output] = str(output.meta["fwd_out"]) continue # loop every dim of outputs, try to find a legal one for output_dim in range(len(get_node_shape(output))): @@ -421,7 +422,8 @@ def _update_chunk_info(self, chunk_info: Dict, new_all_node_info: Dict, output: for k, v in new_all_node_info.items(): if k in chunk_info["node_chunk_dim"]: 
chunk_info["node_chunk_dim"][k]["fix_dim"] = list( - set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"])) + set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"]) + ) else: chunk_info["node_chunk_dim"][k] = v chunk_info["outputs"].append(output) @@ -443,8 +445,11 @@ def _reassign_reshape_size(self, chunk_info): if node.args[0] in chunk_info["inputs_non_chunk"]: continue reshape_args = flat_list(node.args[1:]) - if len(reshape_args) == 1 and get_node_shape(reshape_args[0]) is None and len( - reshape_args[0].meta['fwd_out']) > 1: + if ( + len(reshape_args) == 1 + and get_node_shape(reshape_args[0]) is None + and len(reshape_args[0].meta["fwd_out"]) > 1 + ): continue chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"] new_shape = "" @@ -462,16 +467,17 @@ def _reassign_reshape_size(self, chunk_info): chunk_info["reshape_size"] = reshape_size return chunk_info - def check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int, - end_idx: int) -> bool: + def check_region_start_end( + self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int, end_idx: int + ) -> bool: """ check if region start and end is legal """ # dim cannot be None - if (get_node_shape(end_node) is None or get_node_shape(start_node) is None): + if get_node_shape(end_node) is None or get_node_shape(start_node) is None: return False # dim size cannot be 1 - if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1): + if get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1: return False # must have users if len(end_node.users) == 0: diff --git a/colossalai/autochunk/trace_indice.py b/colossalai/autochunk/trace_indice.py index fbe0741b8827..378c54acf782 100644 --- a/colossalai/autochunk/trace_indice.py +++ b/colossalai/autochunk/trace_indice.py @@ -1,5 +1,5 @@ import copy -from typing import Dict, List, Tuple +from typing import Dict, List from torch.fx.node import Node @@ -412,7 +412,7 @@ def _assign_interpolate_indice(self, node: Node, node_idx: int) -> None: node_idx (int) """ # get conv input - assert node.kwargs['size'] is None + assert node.kwargs["size"] is None assert len(get_node_shape(node)) == 4 # assign index @@ -826,7 +826,7 @@ def _clear_trace(self, node_idx: int) -> None: # clear compute for dim_compute in trace["compute"]: for i in range(len(dim_compute) - 1, -1, -1): - if (dim_compute[i] < trace_barrier and dim_compute[i] not in active_nodes): + if dim_compute[i] < trace_barrier and dim_compute[i] not in active_nodes: dim_compute.pop(i) continue # clear source @@ -876,10 +876,24 @@ def trace_indice(self) -> None: self._assign_matmul_indice(node, idx) elif "softmax" == node_name: self._assign_softmax_indice(node, idx) - elif any(n == node_name for n in [ - "mul", "add", "sigmoid", "relu", "sub", "truediv", "pow", "dropout", "where", "tanh", "exp", - "sin", "cos" - ]): + elif any( + n == node_name + for n in [ + "mul", + "add", + "sigmoid", + "relu", + "sub", + "truediv", + "pow", + "dropout", + "where", + "tanh", + "exp", + "sin", + "cos", + ] + ): self._assign_elementwise_indice(node, idx) elif "einsum" == node_name: self._assign_einsum_indice(node, idx) @@ -920,7 +934,7 @@ def trace_indice(self) -> None: else: raise NotImplementedError(node_name, "module not implemented yet!") elif node.op == "get_attr": - self._assign_all_indice(node, idx) # get param + self._assign_all_indice(node, idx) # get param elif node.op == "output": continue else: diff --git 
a/colossalai/autochunk/utils.py b/colossalai/autochunk/utils.py index 064baa047155..f6f803a5ce0a 100644 --- a/colossalai/autochunk/utils.py +++ b/colossalai/autochunk/utils.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable, List, Tuple, Union +from typing import Any, Dict, List, Union from torch.fx.node import Node @@ -10,7 +10,6 @@ class NodeMgr(object): - def __init__(self, nodes_list: List[Node]) -> None: self._node_list = nodes_list self._node_dict = {} @@ -174,16 +173,22 @@ def find_chunk_compute_input_and_output_nodes(nodes: List[Node]) -> Union[List, # we treat that input node as the input of the checkpoint function for node in nodes: for input_node in node._input_nodes.keys(): - if (input_node not in nodes and input_node not in input_nodes - and not is_non_compute_node_except_placeholder(input_node)): + if ( + input_node not in nodes + and input_node not in input_nodes + and not is_non_compute_node_except_placeholder(input_node) + ): input_nodes.append(input_node) # if a node has a user node which is not in the node list # we treat that user node as the node receiving the current node output for node in nodes: for output_node in node.users.keys(): - if (output_node not in nodes and node not in output_nodes - and not is_non_compute_node_except_placeholder_output(output_node)): + if ( + output_node not in nodes + and node not in output_nodes + and not is_non_compute_node_except_placeholder_output(output_node) + ): output_nodes.append(node) return input_nodes, output_nodes @@ -238,7 +243,10 @@ def find_tensor_shape_node(node_list: List[Node]) -> List[Node]: for node in node_list: if get_node_shape(node) is not None: out.append(node) - elif len(node.meta['fwd_out']) > 0 and isinstance(node.meta['fwd_out'], list) and isinstance( - node.meta['fwd_out'][0], int): + elif ( + len(node.meta["fwd_out"]) > 0 + and isinstance(node.meta["fwd_out"], list) + and isinstance(node.meta["fwd_out"][0], int) + ): out.append(node) return out diff --git a/colossalai/booster/accelerator.py b/colossalai/booster/accelerator.py index fc2c4a40068b..92990907bc2e 100644 --- a/colossalai/booster/accelerator.py +++ b/colossalai/booster/accelerator.py @@ -1,12 +1,11 @@ import torch import torch.nn as nn -__all__ = ['Accelerator'] +__all__ = ["Accelerator"] _supported_devices = [ - 'cpu', - 'cuda', - + "cpu", + "cuda", # To be supported # 'xpu', # 'npu', @@ -25,21 +24,22 @@ class Accelerator: def __init__(self, device: str): self.device = device - assert self.device in _supported_devices, f"Device {self.device} is not supported yet, supported devices include {_supported_devices}" + assert ( + self.device in _supported_devices + ), f"Device {self.device} is not supported yet, supported devices include {_supported_devices}" def bind(self): """ Set the default device for the current process. 
""" - if self.device == 'cpu': + if self.device == "cpu": pass - elif self.device == 'cuda': + elif self.device == "cuda": # TODO(FrankLeeeee): use global environment to check if it is a dist job # if is_distributed: # local_rank = EnvTable().get_local_rank() # torch.cuda.set_device(torch.device(f'cuda:{local_rank}')) - torch.cuda.set_device(torch.device('cuda')) - pass + torch.cuda.set_device(torch.device("cuda")) else: raise ValueError(f"Device {self.device} is not supported yet") diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py index fb9dae7c9650..2aee72cbf2f1 100644 --- a/colossalai/booster/booster.py +++ b/colossalai/booster/booster.py @@ -16,7 +16,7 @@ from .plugin import Plugin from .plugin.pp_plugin_base import PipelinePluginBase -__all__ = ['Booster'] +__all__ = ["Booster"] class Booster: @@ -60,28 +60,31 @@ class Booster: plugin (Plugin): The plugin to run the training. Default: None. """ - def __init__(self, - device: Optional[str] = None, - mixed_precision: Optional[Union[MixedPrecision, str]] = None, - plugin: Optional[Plugin] = None) -> None: + def __init__( + self, + device: Optional[str] = None, + mixed_precision: Optional[Union[MixedPrecision, str]] = None, + plugin: Optional[Plugin] = None, + ) -> None: if plugin is not None: assert isinstance( - plugin, Plugin), f'Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}.' + plugin, Plugin + ), f"Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}." self.plugin = plugin # set accelerator if self.plugin and self.plugin.control_device(): self.accelerator = None if device is not None: - warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.') + warnings.warn("The plugin will control the accelerator, so the device argument will be ignored.") else: - device = device or 'cuda' + device = device or "cuda" self.accelerator = Accelerator(device) # set precision if self.plugin and self.plugin.control_precision(): if mixed_precision is not None: - warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.') + warnings.warn("The plugin will control the precision, so the mixed_precision argument will be ignored.") self.mixed_precision = None elif mixed_precision is None: self.mixed_precision = None @@ -95,7 +98,7 @@ def __init__(self, self.mixed_precision = mixed_precision else: raise ValueError( - f'Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}.' + f"Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}." 
) if self.plugin is not None and self.plugin.control_checkpoint_io(): @@ -131,7 +134,8 @@ def boost( # transform model for mixed precision if self.plugin: model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure( - model, optimizer, criterion, dataloader, lr_scheduler) + model, optimizer, criterion, dataloader, lr_scheduler + ) if self.plugin and not self.plugin.control_device(): # transform model for accelerator @@ -154,13 +158,15 @@ def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None: # TODO(frank lee): implement this method with plugin optimizer.backward(loss) - def execute_pipeline(self, - data_iter: Iterator, - model: nn.Module, - criterion: Callable[[Any, Any], torch.Tensor], - optimizer: Optional[Optimizer] = None, - return_loss: bool = True, - return_outputs: bool = False) -> Dict[str, Any]: + def execute_pipeline( + self, + data_iter: Iterator, + model: nn.Module, + criterion: Callable[[Any, Any], torch.Tensor], + optimizer: Optional[Optimizer] = None, + return_loss: bool = True, + return_outputs: bool = False, + ) -> Dict[str, Any]: """ Execute forward & backward when utilizing pipeline parallel. Return loss or Huggingface style model outputs if needed. @@ -185,8 +191,9 @@ def execute_pipeline(self, ret_dict['loss'] is the loss of forward if return_loss is set to True, else None. ret_dict['outputs'] is the Huggingface style model outputs during forward if return_output is set to True, else None. """ - assert isinstance(self.plugin, - PipelinePluginBase), f'The plugin {self.plugin.__class__.__name__} does not support pipeline.' + assert isinstance( + self.plugin, PipelinePluginBase + ), f"The plugin {self.plugin.__class__.__name__} does not support pipeline." return self.plugin.execute_pipeline(data_iter, model, criterion, optimizer, return_loss, return_outputs) def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -> contextmanager: @@ -200,8 +207,10 @@ def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) - Returns: contextmanager: Context to disable gradient synchronization. """ - assert self.plugin is not None, f'no_sync is only enabled when a plugin is provided and the plugin supports no_sync.' - assert self.plugin.support_no_sync(), f'The plugin {self.plugin.__class__.__name__} does not support no_sync.' + assert ( + self.plugin is not None + ), f"no_sync is only enabled when a plugin is provided and the plugin supports no_sync." + assert self.plugin.support_no_sync(), f"The plugin {self.plugin.__class__.__name__} does not support no_sync." return self.plugin.no_sync(model, optimizer) def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None: @@ -217,14 +226,16 @@ def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, str """ self.checkpoint_io.load_model(model, checkpoint, strict) - def save_model(self, - model: Union[nn.Module, ModelWrapper], - checkpoint: str, - shard: bool = False, - gather_dtensor: bool = True, - prefix: Optional[str] = None, - size_per_shard: int = 1024, - use_safetensors: bool = False) -> None: + def save_model( + self, + model: Union[nn.Module, ModelWrapper], + checkpoint: str, + shard: bool = False, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + use_safetensors: bool = False, + ) -> None: """Save model to checkpoint. Args: @@ -239,13 +250,15 @@ def save_model(self, size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. 
This is useful only when ``shard=True``. Defaults to 1024. use_safetensors (bool, optional): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved. """ - self.checkpoint_io.save_model(model, - checkpoint=checkpoint, - shard=shard, - gather_dtensor=gather_dtensor, - prefix=prefix, - size_per_shard=size_per_shard, - use_safetensors=use_safetensors) + self.checkpoint_io.save_model( + model, + checkpoint=checkpoint, + shard=shard, + gather_dtensor=gather_dtensor, + prefix=prefix, + size_per_shard=size_per_shard, + use_safetensors=use_safetensors, + ) def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None: """Load optimizer from checkpoint. @@ -260,13 +273,15 @@ def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None: """ self.checkpoint_io.load_optimizer(optimizer, checkpoint) - def save_optimizer(self, - optimizer: Optimizer, - checkpoint: str, - shard: bool = False, - gather_dtensor: bool = True, - prefix: Optional[str] = None, - size_per_shard: int = 1024) -> None: + def save_optimizer( + self, + optimizer: Optimizer, + checkpoint: str, + shard: bool = False, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + ) -> None: """ Save optimizer to checkpoint. diff --git a/colossalai/booster/mixed_precision/__init__.py b/colossalai/booster/mixed_precision/__init__.py index 0df9d84159f9..68c6221ec809 100644 --- a/colossalai/booster/mixed_precision/__init__.py +++ b/colossalai/booster/mixed_precision/__init__.py @@ -6,16 +6,22 @@ from .mixed_precision_base import MixedPrecision __all__ = [ - 'MixedPrecision', 'mixed_precision_factory', 'FP16_Apex_MixedPrecision', 'FP16_Torch_MixedPrecision', - 'FP32_MixedPrecision', 'BF16_MixedPrecision', 'FP8_MixedPrecision', 'FP16NaiveMixedPrecision' + "MixedPrecision", + "mixed_precision_factory", + "FP16_Apex_MixedPrecision", + "FP16_Torch_MixedPrecision", + "FP32_MixedPrecision", + "BF16_MixedPrecision", + "FP8_MixedPrecision", + "FP16NaiveMixedPrecision", ] _mixed_precision_mapping = { - 'fp16': FP16TorchMixedPrecision, - 'fp16_apex': FP16ApexMixedPrecision, - 'fp16_naive': FP16NaiveMixedPrecision, - 'bf16': BF16MixedPrecision, - 'fp8': FP8MixedPrecision + "fp16": FP16TorchMixedPrecision, + "fp16_apex": FP16ApexMixedPrecision, + "fp16_naive": FP16NaiveMixedPrecision, + "bf16": BF16MixedPrecision, + "fp8": FP8MixedPrecision, } @@ -31,5 +37,5 @@ def mixed_precision_factory(mixed_precision_type: str) -> MixedPrecision: return _mixed_precision_mapping[mixed_precision_type]() else: raise ValueError( - f'Mixed precision type {mixed_precision_type} is not supported, support types include {list(_mixed_precision_mapping.keys())}' + f"Mixed precision type {mixed_precision_type} is not supported, support types include {list(_mixed_precision_mapping.keys())}" ) diff --git a/colossalai/booster/mixed_precision/fp16_apex.py b/colossalai/booster/mixed_precision/fp16_apex.py index e184271e932a..2fa7b54cdd30 100644 --- a/colossalai/booster/mixed_precision/fp16_apex.py +++ b/colossalai/booster/mixed_precision/fp16_apex.py @@ -23,16 +23,18 @@ class FP16ApexMixedPrecision(MixedPrecision): max_loss_scale(float, default=2.**24 ): Sets a ceiling for the loss scale values that can be chosen by dynamic loss scaling. If dynamic loss scaling is not used, max_loss_scale is ignored. 
""" - def __init__(self, - opt_level: Optional[str] = "O1", - cast_model_type: torch.dtype = None, - patch_torch_functions: bool = None, - keep_batchnorm_fp32: Union[bool, str] = None, - master_weights: bool = None, - loss_scale: Union[float, str] = None, - cast_model_outputs: Any = None, - num_losses: Optional[int] = 1, - verbosity: int = 1, - min_loss_scale: float = None, - max_loss_scale: float = 2.**24) -> None: + def __init__( + self, + opt_level: Optional[str] = "O1", + cast_model_type: torch.dtype = None, + patch_torch_functions: bool = None, + keep_batchnorm_fp32: Union[bool, str] = None, + master_weights: bool = None, + loss_scale: Union[float, str] = None, + cast_model_outputs: Any = None, + num_losses: Optional[int] = 1, + verbosity: int = 1, + min_loss_scale: float = None, + max_loss_scale: float = 2.0**24, + ) -> None: pass diff --git a/colossalai/booster/mixed_precision/fp16_naive.py b/colossalai/booster/mixed_precision/fp16_naive.py index 5d0d815257f3..e5624a9d7477 100644 --- a/colossalai/booster/mixed_precision/fp16_naive.py +++ b/colossalai/booster/mixed_precision/fp16_naive.py @@ -15,12 +15,14 @@ class FP16NaiveMixedPrecision(MixedPrecision): verbose(bool): if set to `True`, will print debug info. """ - def __init__(self, - log_num_zeros_in_grad: bool, - initial_scale: int, - growth_factor: int, - backoff_factor: float, - hysteresis: int, - max_scale: int, - verbose: bool = None) -> None: + def __init__( + self, + log_num_zeros_in_grad: bool, + initial_scale: int, + growth_factor: int, + backoff_factor: float, + hysteresis: int, + max_scale: int, + verbose: bool = None, + ) -> None: pass diff --git a/colossalai/booster/mixed_precision/fp16_torch.py b/colossalai/booster/mixed_precision/fp16_torch.py index 26fd92bd50b8..7dce6e6da33e 100644 --- a/colossalai/booster/mixed_precision/fp16_torch.py +++ b/colossalai/booster/mixed_precision/fp16_torch.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch import torch.nn as nn @@ -9,7 +9,7 @@ from .mixed_precision_base import MixedPrecision -__all__ = ['FP16_Torch_MixedPrecision', 'TorchAMPOptimizer', 'TorchAMPModule'] +__all__ = ["FP16_Torch_MixedPrecision", "TorchAMPOptimizer", "TorchAMPModule"] class TorchAMPOptimizer(OptimizerWrapper): @@ -29,17 +29,21 @@ class TorchAMPOptimizer(OptimizerWrapper): calls that may cause the scale to increase. Default: 2000. 
""" - def __init__(self, - optim: Optimizer, - init_scale: float = 2.**16, - growth_factor: float = 2.0, - backoff_factor: float = 0.5, - growth_interval: int = 2000) -> None: + def __init__( + self, + optim: Optimizer, + init_scale: float = 2.0**16, + growth_factor: float = 2.0, + backoff_factor: float = 0.5, + growth_interval: int = 2000, + ) -> None: super().__init__(optim) - self.scaler = torch.cuda.amp.GradScaler(init_scale=init_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval) + self.scaler = torch.cuda.amp.GradScaler( + init_scale=init_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + ) def backward(self, loss: Tensor, *args, **kwargs) -> None: scaled_loss = self.scale_loss(loss) @@ -60,12 +64,14 @@ def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None: self.unscale_grad() super().clip_grad_by_value(clip_value, *args, **kwargs) - def clip_grad_by_norm(self, - max_norm: Union[float, int], - norm_type: Union[float, int] = 2.0, - error_if_nonfinite: bool = False, - *args, - **kwargs) -> None: + def clip_grad_by_norm( + self, + max_norm: Union[float, int], + norm_type: Union[float, int] = 2.0, + error_if_nonfinite: bool = False, + *args, + **kwargs, + ) -> None: self.unscale_grad() super().clip_grad_by_norm(max_norm, norm_type, error_if_nonfinite, *args, **kwargs) @@ -102,22 +108,27 @@ class FP16TorchMixedPrecision(MixedPrecision): calls that may cause the scale to increase. Default: 2000. """ - def __init__(self, - init_scale: float = 2.**16, - growth_factor: float = 2.0, - backoff_factor: float = 0.5, - growth_interval: int = 2000) -> None: + def __init__( + self, + init_scale: float = 2.0**16, + growth_factor: float = 2.0, + backoff_factor: float = 0.5, + growth_interval: int = 2000, + ) -> None: super().__init__() - self.torch_amp_kwargs = dict(init_scale=init_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval) - - def configure(self, - model: nn.Module, - optimizer: Optional[Optimizer] = None, - criterion: Optional[Callable] = None, - ) -> Tuple[nn.Module, OptimizerWrapper, Callable]: + self.torch_amp_kwargs = dict( + init_scale=init_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + ) + + def configure( + self, + model: nn.Module, + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + ) -> Tuple[nn.Module, OptimizerWrapper, Callable]: model = TorchAMPModule(model) if optimizer is not None: optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs) diff --git a/colossalai/booster/plugin/__init__.py b/colossalai/booster/plugin/__init__.py index f48bf38bd724..62f3708fc629 100644 --- a/colossalai/booster/plugin/__init__.py +++ b/colossalai/booster/plugin/__init__.py @@ -4,11 +4,12 @@ from .plugin_base import Plugin from .torch_ddp_plugin import TorchDDPPlugin -__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin', 'HybridParallelPlugin'] +__all__ = ["Plugin", "TorchDDPPlugin", "GeminiPlugin", "LowLevelZeroPlugin", "HybridParallelPlugin"] import torch from packaging import version -if version.parse(torch.__version__) >= version.parse('1.12.0'): +if version.parse(torch.__version__) >= version.parse("1.12.0"): from .torch_fsdp_plugin import TorchFSDPPlugin - __all__.append('TorchFSDPPlugin') + + __all__.append("TorchFSDPPlugin") diff --git a/colossalai/booster/plugin/dp_plugin_base.py 
b/colossalai/booster/plugin/dp_plugin_base.py index d5da5938bfd9..d2dd00453e32 100644 --- a/colossalai/booster/plugin/dp_plugin_base.py +++ b/colossalai/booster/plugin/dp_plugin_base.py @@ -10,25 +10,19 @@ class DPPluginBase(Plugin): - """This is a base class for all DP plugins. It sets up world size and rank, and provides data loader creation. - """ + """This is a base class for all DP plugins. It sets up world size and rank, and provides data loader creation.""" def __init__(self) -> None: super().__init__() - assert dist.is_initialized( - ), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment' + assert ( + dist.is_initialized() + ), "torch.distributed is not initialized, please use colossalai.launch to create the distributed environment" self.rank = dist.get_rank() self.world_size = dist.get_world_size() - def prepare_dataloader(self, - dataset, - batch_size, - shuffle=False, - seed=1024, - drop_last=False, - pin_memory=False, - num_workers=0, - **kwargs): + def prepare_dataloader( + self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs + ): r""" Prepare a dataloader for distributed training. The dataloader will be wrapped by `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. @@ -60,11 +54,13 @@ def seed_worker(worker_id): torch.manual_seed(worker_seed) random.seed(worker_seed) - return DataLoader(dataset, - batch_size=batch_size, - sampler=sampler, - worker_init_fn=seed_worker, - drop_last=drop_last, - pin_memory=pin_memory, - num_workers=num_workers, - **_kwargs) + return DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs, + ) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index de03ba27bfda..83a00d4ee229 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -27,14 +27,13 @@ from .dp_plugin_base import DPPluginBase -__all__ = ['GeminiPlugin'] +__all__ = ["GeminiPlugin"] -SUPPORTED_PRECISION = ['fp16', 'bf16'] -PRECISION_STR_TO_DTYPE = {'fp16': torch.half, 'bf16': torch.bfloat16} +SUPPORTED_PRECISION = ["fp16", "bf16"] +PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16} class GeminiCheckpointIO(GeneralCheckpointIO): - def __init__(self) -> None: super().__init__() self.coordinator = DistCoordinator() @@ -74,13 +73,15 @@ def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str): """ super().load_unsharded_optimizer(optimizer, checkpoint) - def save_sharded_model(self, - model: GeminiDDP, - checkpoint_path: str, - gather_dtensor: bool = False, - prefix: Optional[str] = None, - max_shard_size: int = 1024, - use_safetensors: bool = False): + def save_sharded_model( + self, + model: GeminiDDP, + checkpoint_path: str, + gather_dtensor: bool = False, + prefix: Optional[str] = None, + max_shard_size: int = 1024, + use_safetensors: bool = False, + ): """ Save sharded model. As there is communication when getting state dict, model.state_dict() must be called on all processes. @@ -97,34 +98,37 @@ def save_sharded_model(self, # Save shards of optimizer states. 
is_master = self.coordinator.is_master() - total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, - checkpoint=checkpoint_path, - index_file=index_file, - base_filename=weights_name, - is_master=is_master, - use_safetensors=use_safetensors) + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint_path, + index_file=index_file, + base_filename=weights_name, + is_master=is_master, + use_safetensors=use_safetensors, + ) # only save the index file on the master rank if self.coordinator.is_master(): index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) save_config_file(model.module, checkpoint_path) - logging.info(f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}.") - - def load_sharded_model(self, - model: GeminiDDP, - checkpoint_index_file: Path, - strict: bool = False, - use_safetensors: bool = False): + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + def load_sharded_model( + self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False + ): """ Load shard model, load model from multiple files. """ return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False) - def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, - size_per_shard: int): + def save_sharded_optimizer( + self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int + ): """ Save sharded optimizer state dict to checkpoint folder. As there is communication when getting state dict, this must be called on all processes. @@ -153,20 +157,24 @@ def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_ # Save shards of optimizer states. is_master = self.coordinator.is_master() - total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, - checkpoint=checkpoint, - index_file=index_file, - base_filename=states_name, - is_master=is_master, - use_safetensors=False) + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=is_master, + use_safetensors=False, + ) # Wrap up index file. Only save it on master rank. if self.coordinator.is_master(): index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) - logging.info(f"The optimizer is going to be split to checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}.") + logging.info( + f"The optimizer is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Path, prefix: str): """ @@ -185,8 +193,10 @@ def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Pa # Load param_groups. param_group_path = ckpt_index_file.get_param_group_filename() if param_group_path is None: - raise RuntimeError(f'Invalid index file path {checkpoint_index_file} for an optimizer. 
\ - Lacking param group file under current directory.') + raise RuntimeError( + f"Invalid index file path {checkpoint_index_file} for an optimizer. \ + Lacking param group file under current directory." + ) saved_param_groups = torch.load(param_group_path) optimizer.load_param_groups(saved_param_groups) @@ -274,11 +284,11 @@ def __init__( chunk_config_dict: Optional[dict] = None, chunk_init_device: Optional[torch.device] = None, placement_policy: str = "static", - shard_param_frac: float = 1.0, # only for static placement - offload_optim_frac: float = 0.0, # only for static placement - offload_param_frac: float = 0.0, # only for static placement - warmup_non_model_data_ratio: float = 0.8, # only for auto placement - steady_cuda_cap_ratio: float = 0.9, # only for auto placement + shard_param_frac: float = 1.0, # only for static placement + offload_optim_frac: float = 0.0, # only for static placement + offload_param_frac: float = 0.0, # only for static placement + warmup_non_model_data_ratio: float = 0.8, # only for auto placement + steady_cuda_cap_ratio: float = 0.9, # only for auto placement precision: str = "fp16", pin_memory: bool = False, force_outputs_fp32: bool = False, @@ -300,7 +310,7 @@ def __init__( verbose: bool = False, ) -> None: super().__init__() - assert precision in SUPPORTED_PRECISION, f'precision {precision} is not supported' + assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported" self.gemini_config = dict( chunk_config_dict=chunk_config_dict, chunk_init_device=(chunk_init_device or get_current_device()), @@ -319,16 +329,20 @@ def __init__( memstats=memstats, mixed_precision=PRECISION_STR_TO_DTYPE[precision], ) - self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio,) - self.optim_kwargs = dict(initial_scale=initial_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - min_scale=min_scale, - max_scale=max_scale, - max_norm=max_norm, - norm_type=norm_type) + self.zero_optim_config = dict( + gpu_margin_mem_ratio=gpu_margin_mem_ratio, + ) + self.optim_kwargs = dict( + initial_scale=initial_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + min_scale=min_scale, + max_scale=max_scale, + max_norm=max_norm, + norm_type=norm_type, + ) self.verbose = verbose def support_no_sync(self) -> bool: @@ -344,7 +358,7 @@ def control_device(self) -> bool: return True def supported_devices(self) -> List[str]: - return ['cuda'] + return ["cuda"] def configure( self, @@ -354,7 +368,6 @@ def configure( dataloader: Optional[DataLoader] = None, lr_scheduler: Optional[LRScheduler] = None, ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: - if not isinstance(model, ModelWrapper): # convert model to sync bn # FIXME(ver217): gemini does not support sync bn @@ -368,13 +381,10 @@ def configure( # wrap the model with Gemini model = GeminiDDP(model, **self.gemini_config, verbose=self.verbose) - if optimizer is not None and \ - not isinstance(optimizer, OptimizerWrapper): - optimizer = GeminiOptimizer(optimizer, - model.unwrap(), - **self.zero_optim_config, - **self.optim_kwargs, - verbose=self.verbose) + if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): + optimizer = GeminiOptimizer( + optimizer, model.unwrap(), **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose + ) return model, optimizer, criterion, dataloader, lr_scheduler diff 
--git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index d15245523226..c1693fa8d3a1 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -37,10 +37,16 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16): class HybridParallelModule(ModelWrapper): - - def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool, - ddp_config: dict, custom_policy: Policy) -> None: - + def __init__( + self, + module: Module, + precision: str, + shard_config: ShardConfig, + dp_group: ProcessGroup, + use_ddp: bool, + ddp_config: dict, + custom_policy: Policy, + ) -> None: self.stage_manager = shard_config.pipeline_stage_manager self.dp_group = dp_group @@ -54,13 +60,14 @@ def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp for shared_param in self.shared_params: if len(shared_param) > 0: self.shared_param_process_groups.append( - self.stage_manager.init_process_group_by_stages(list(shared_param.keys()))) + self.stage_manager.init_process_group_by_stages(list(shared_param.keys())) + ) # setting mixed_precision self.mixed_precision = None - if precision == 'fp16': + if precision == "fp16": self.mixed_precision = torch.float16 - elif precision == 'bf16': + elif precision == "bf16": self.mixed_precision = torch.bfloat16 if self.mixed_precision is not None: module = module.to(self.mixed_precision) @@ -123,22 +130,21 @@ def get_param_info(optim: Optimizer): if optim is None: return {} - param_info = {'param_groups': [], 'param2id': {}, 'id2param': {}, 'param2shape': {}} + param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}} start_index = 0 for group in optim.param_groups: + packed_group = {k: v for k, v in group.items() if k != "params"} + packed_group["params"] = [] - packed_group = {k: v for k, v in group.items() if k != 'params'} - packed_group['params'] = [] - - for param_id, param in enumerate(group['params'], start_index): + for param_id, param in enumerate(group["params"], start_index): original_shape = param.shape if isinstance(param, torch.Tensor) else None - packed_group['params'].append(param_id) - param_info['param2id'][id(param)] = param_id - param_info['id2param'][param_id] = id(param) - param_info['param2shape'][id(param)] = original_shape + packed_group["params"].append(param_id) + param_info["param2id"][id(param)] = param_id + param_info["id2param"][param_id] = id(param) + param_info["param2shape"][id(param)] = original_shape - param_info['param_groups'].append(packed_group) - start_index += len(group['params']) + param_info["param_groups"].append(packed_group) + start_index += len(group["params"]) return param_info @@ -147,13 +153,12 @@ def init_pipeline_optimizer(optim: Optimizer, model: Module): model_params = set(model.parameters()) new_param_groups = [] for group in optim.param_groups: - params = [p for p in group['params'] if p in model_params] - new_param_groups.append({**group, 'params': params}) - optim.__setstate__({'param_groups': new_param_groups}) + params = [p for p in group["params"] if p in model_params] + new_param_groups.append({**group, "params": params}) + optim.__setstate__({"param_groups": new_param_groups}) class HybridParallelNaiveOptimizer(OptimizerWrapper): - def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict): self.param_info = param_info if use_pipeline: @@ -162,60 
+167,87 @@ def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_in class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): - - def __init__(self, - optim: Optimizer, - model: Module, - use_pipeline: bool, - param_info: OrderedDict, - precision: str = 'fp16', - initial_scale: float = 2**16, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: float = 2**32, - max_norm: float = 0): + def __init__( + self, + optim: Optimizer, + model: Module, + use_pipeline: bool, + param_info: OrderedDict, + precision: str = "fp16", + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0, + ): self.param_info = param_info if use_pipeline: init_pipeline_optimizer(optim, model) - super().__init__(optim, precision, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, - hysteresis, max_scale, max_norm) + super().__init__( + optim, + precision, + initial_scale, + min_scale, + growth_factor, + backoff_factor, + growth_interval, + hysteresis, + max_scale, + max_norm, + ) class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): - def __init__( - self, - optimizer: Optimizer, - model: Module, - use_pipeline: bool, - param_info: OrderedDict, - initial_scale: int = 2**16, # grad scaler config - min_scale: int = 1, - growth_factor: float = 2., - backoff_factor: float = .5, - growth_interval: int = 2000, - hysteresis: int = 2, - max_scale: int = 2**24, - clip_grad_norm: float = 0.0, # grad clipping - verbose: bool = False, - reduce_bucket_size: int = 1024 * 1024, # communication - communication_dtype: Optional[torch.dtype] = None, - overlap_communication: bool = True, - partition_grad: bool = False, # stage 2 flag - cpu_offload: bool = False, # cpu offload - dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm - tp_process_group: Optional[ProcessGroup] = None, # if using tp - forced_dtype: Optional[torch.dtype] = None): + self, + optimizer: Optimizer, + model: Module, + use_pipeline: bool, + param_info: OrderedDict, + initial_scale: int = 2**16, # grad scaler config + min_scale: int = 1, + growth_factor: float = 2.0, + backoff_factor: float = 0.5, + growth_interval: int = 2000, + hysteresis: int = 2, + max_scale: int = 2**24, + clip_grad_norm: float = 0.0, # grad clipping + verbose: bool = False, + reduce_bucket_size: int = 1024 * 1024, # communication + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True, + partition_grad: bool = False, # stage 2 flag + cpu_offload: bool = False, # cpu offload + dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm + tp_process_group: Optional[ProcessGroup] = None, # if using tp + forced_dtype: Optional[torch.dtype] = None, + ): self.param_info = param_info if use_pipeline: init_pipeline_optimizer(optimizer, model) - super().__init__(optimizer, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, - hysteresis, max_scale, clip_grad_norm, verbose, reduce_bucket_size, communication_dtype, - overlap_communication, partition_grad, cpu_offload, dp_process_group, tp_process_group, - forced_dtype) + super().__init__( + optimizer, + initial_scale, + min_scale, + growth_factor, + backoff_factor, + growth_interval, + hysteresis, + max_scale, + clip_grad_norm, + verbose, + reduce_bucket_size, + 
communication_dtype, + overlap_communication, + partition_grad, + cpu_offload, + dp_process_group, + tp_process_group, + forced_dtype, + ) class HybridParallelPlugin(PipelinePluginBase): @@ -276,46 +308,47 @@ class HybridParallelPlugin(PipelinePluginBase): custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None. """ - def __init__(self, - tp_size: int, - pp_size: int, - precision: str = 'fp16', - zero_stage: int = 0, - enable_all_optimization: bool = False, - enable_fused_normalization: bool = False, - enable_flash_attention: bool = False, - enable_jit_fused: bool = False, - enable_sequence_parallelism: bool = False, - enable_sequence_overlap: bool = False, - num_microbatches: Optional[int] = None, - microbatch_size: Optional[int] = None, - initial_scale: float = 2**16, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: float = 2**32, - max_norm: float = 0, - broadcast_buffers: bool = True, - ddp_bucket_cap_mb: int = 25, - find_unused_parameters: bool = False, - check_reduction: bool = False, - gradient_as_bucket_view: bool = False, - static_graph: bool = False, - zero_bucket_size_in_m: int = 12, - cpu_offload: bool = False, - communication_dtype: Optional[torch.dtype] = None, - overlap_communication: bool = True, - custom_policy: Policy = None) -> None: - + def __init__( + self, + tp_size: int, + pp_size: int, + precision: str = "fp16", + zero_stage: int = 0, + enable_all_optimization: bool = False, + enable_fused_normalization: bool = False, + enable_flash_attention: bool = False, + enable_jit_fused: bool = False, + enable_sequence_parallelism: bool = False, + enable_sequence_overlap: bool = False, + num_microbatches: Optional[int] = None, + microbatch_size: Optional[int] = None, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0, + broadcast_buffers: bool = True, + ddp_bucket_cap_mb: int = 25, + find_unused_parameters: bool = False, + check_reduction: bool = False, + gradient_as_bucket_view: bool = False, + static_graph: bool = False, + zero_bucket_size_in_m: int = 12, + cpu_offload: bool = False, + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True, + custom_policy: Policy = None, + ) -> None: super().__init__() - assert dist.get_world_size() % ( - tp_size * pp_size - ) == 0, f'world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}' + assert ( + dist.get_world_size() % (tp_size * pp_size) == 0 + ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" if enable_sequence_parallelism: - assert tp_size > 1, 'Sequence parallelism must be enabled when using tensor parallelism' + assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism" self.tp_size = tp_size self.pp_size = pp_size @@ -334,24 +367,28 @@ def __init__(self, self.custom_policy = custom_policy assert zero_stage in (0, 1, 2) if self.pp_size > 1: - assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism' - assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism' + assert ( + num_microbatches is not None or microbatch_size is not None + ), "num_microbatches or 
microbatch_size must be specified when using pipeline parallelism" + assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism" self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS) - self.schedule = OneForwardOneBackwardSchedule(self.stage_manager, - num_microbatches=num_microbatches, - microbatch_size=microbatch_size) + self.schedule = OneForwardOneBackwardSchedule( + self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size + ) self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) - self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group, - pipeline_stage_manager=self.stage_manager, - enable_tensor_parallelism=self.tp_size > 1, - enable_all_optimization=self.enable_all_optimization, - enable_fused_normalization=self.enable_fused_normalization, - enable_flash_attention=self.enable_flash_attention, - enable_jit_fused=self.enable_jit_fused, - enable_sequence_parallelism=enable_sequence_parallelism, - enable_sequence_overlap=enable_sequence_overlap) + self.shard_config = ShardConfig( + tensor_parallel_process_group=self.tp_group, + pipeline_stage_manager=self.stage_manager, + enable_tensor_parallelism=self.tp_size > 1, + enable_all_optimization=self.enable_all_optimization, + enable_fused_normalization=self.enable_fused_normalization, + enable_flash_attention=self.enable_flash_attention, + enable_jit_fused=self.enable_jit_fused, + enable_sequence_parallelism=enable_sequence_parallelism, + enable_sequence_overlap=enable_sequence_overlap, + ) self.amp_config = dict( initial_scale=initial_scale, growth_factor=growth_factor, @@ -362,18 +399,22 @@ def __init__(self, max_scale=max_scale, ) - self.ddp_config = dict(broadcast_buffers=broadcast_buffers, - bucket_cap_mb=ddp_bucket_cap_mb, - find_unused_parameters=find_unused_parameters, - check_reduction=check_reduction, - gradient_as_bucket_view=gradient_as_bucket_view, - static_graph=static_graph) + self.ddp_config = dict( + broadcast_buffers=broadcast_buffers, + bucket_cap_mb=ddp_bucket_cap_mb, + find_unused_parameters=find_unused_parameters, + check_reduction=check_reduction, + gradient_as_bucket_view=gradient_as_bucket_view, + static_graph=static_graph, + ) - self.zero_config = dict(reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024, - communication_dtype=communication_dtype, - overlap_communication=overlap_communication, - cpu_offload=cpu_offload, - partition_grad=(self.zero_stage == 2)) + self.zero_config = dict( + reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + cpu_offload=cpu_offload, + partition_grad=(self.zero_stage == 2), + ) self.max_norm = max_norm @@ -382,10 +423,10 @@ def enable_pipeline_parallelism(self) -> bool: return self.pp_size > 1 def supported_devices(self) -> List[str]: - return ['cuda'] + return ["cuda"] def supported_precisions(self) -> List[str]: - return ['fp16', 'bf16', 'fp32'] + return ["fp16", "bf16", "fp32"] def control_device(self) -> bool: return True @@ -410,57 +451,67 @@ def configure( param_info = get_param_info(optimizer) if not isinstance(model, ModelWrapper): use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 - model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp, - self.ddp_config, self.custom_policy) + model = HybridParallelModule( + 
model, self.precision, self.shard_config, self.dp_group, use_ddp, self.ddp_config, self.custom_policy + ) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if self.zero_stage == 0: - if self.precision in ['fp16', 'bf16']: - optimizer = HybridParallelAMPOptimizer(optimizer, - model, - use_pipeline=self.enable_pipeline_parallelism, - param_info=param_info, - precision=self.precision, - max_norm=self.max_norm, - **self.amp_config) - self.checkpoint_io.link_master_and_working_param(optimizer.working_to_master_map, - optimizer.master_to_working_map) + if self.precision in ["fp16", "bf16"]: + optimizer = HybridParallelAMPOptimizer( + optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, + precision=self.precision, + max_norm=self.max_norm, + **self.amp_config, + ) + self.checkpoint_io.link_master_and_working_param( + optimizer.working_to_master_map, optimizer.master_to_working_map + ) else: - optimizer = HybridParallelNaiveOptimizer(optimizer, - model, - use_pipeline=self.enable_pipeline_parallelism, - param_info=param_info) + optimizer = HybridParallelNaiveOptimizer( + optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info + ) else: assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1." - assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO." - optimizer = HybridParallelZeroOptimizer(optimizer, - model, - use_pipeline=self.enable_pipeline_parallelism, - param_info=param_info, - dp_process_group=self.dp_group, - tp_process_group=self.tp_group, - verbose=True, - clip_grad_norm=self.max_norm, - **self.zero_config, - **self.amp_config) - self.checkpoint_io.link_master_and_working_param(optimizer._param_store.working_to_master_param, - optimizer._param_store.master_to_working_param) + assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO." 
+ optimizer = HybridParallelZeroOptimizer( + optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, + dp_process_group=self.dp_group, + tp_process_group=self.tp_group, + verbose=True, + clip_grad_norm=self.max_norm, + **self.zero_config, + **self.amp_config, + ) + self.checkpoint_io.link_master_and_working_param( + optimizer._param_store.working_to_master_param, optimizer._param_store.master_to_working_param + ) return model, optimizer, criterion, dataloader, lr_scheduler - def execute_pipeline(self, - data_iter: Iterator, - model: HybridParallelModule, - criterion: Callable[[Any, Any], torch.Tensor], - optimizer: Optional[Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer, - HybridParallelZeroOptimizer]] = None, - return_loss: bool = True, - return_outputs: bool = False) -> dict: - assert self.enable_pipeline_parallelism, 'pipeline parallelism is not enabled' + def execute_pipeline( + self, + data_iter: Iterator, + model: HybridParallelModule, + criterion: Callable[[Any, Any], torch.Tensor], + optimizer: Optional[ + Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer, HybridParallelZeroOptimizer] + ] = None, + return_loss: bool = True, + return_outputs: bool = False, + ) -> dict: + assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled" # return loss or outputs if needed ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() with ctx: - outputs = self.schedule.forward_backward_step(model, data_iter, criterion, optimizer, return_loss, - return_outputs) + outputs = self.schedule.forward_backward_step( + model, data_iter, criterion, optimizer, return_loss, return_outputs + ) model.sync_shared_params() if isinstance(optimizer, HybridParallelZeroOptimizer): optimizer.sync_grad() @@ -468,15 +519,9 @@ def execute_pipeline(self, model.sync_grads() return outputs - def prepare_dataloader(self, - dataset, - batch_size, - shuffle=False, - seed=1024, - drop_last=False, - pin_memory=False, - num_workers=0, - **kwargs): + def prepare_dataloader( + self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs + ): r""" Prepare a dataloader for distributed training. The dataloader will be wrapped by `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. @@ -499,10 +544,9 @@ def prepare_dataloader(self, :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. 
""" _kwargs = kwargs.copy() - sampler = DistributedSampler(dataset, - num_replicas=self.pg_mesh.size(DP_AXIS), - rank=self.pg_mesh.coordinate(DP_AXIS), - shuffle=shuffle) + sampler = DistributedSampler( + dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle + ) # Deterministic dataloader def seed_worker(worker_id): @@ -511,14 +555,16 @@ def seed_worker(worker_id): torch.manual_seed(worker_seed) random.seed(worker_seed) - return DataLoader(dataset, - batch_size=batch_size, - sampler=sampler, - worker_init_fn=seed_worker, - drop_last=drop_last, - pin_memory=pin_memory, - num_workers=num_workers, - **_kwargs) + return DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs, + ) def get_checkpoint_io(self) -> CheckpointIO: self.checkpoint_io = HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 9adb4beec9b9..86adee7fe226 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -1,14 +1,12 @@ import logging import os -import warnings from functools import partial from pathlib import Path from types import MethodType -from typing import Callable, Iterator, List, Optional, Tuple, Union +from typing import Callable, Iterator, List, Optional, Tuple import torch import torch.nn as nn -from torch import Tensor from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils._pytree import tree_map @@ -33,7 +31,7 @@ from .dp_plugin_base import DPPluginBase from .torch_ddp_plugin import TorchDDPCheckpointIO -__all__ = ['LowLevelZeroPlugin'] +__all__ = ["LowLevelZeroPlugin"] def _convert_floating_point(x, dtype: torch.dtype = torch.float16): @@ -42,17 +40,16 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16): return x -SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32'] +SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"] class LowLevelZeroModel(ModelWrapper, AMPModelMixin): - def __init__(self, module: nn.Module, precision: str) -> None: super().__init__(module) self.dtype = None - if precision == 'fp16': + if precision == "fp16": self.dtype = torch.float16 - elif precision == 'bf16': + elif precision == "bf16": self.dtype = torch.bfloat16 if self.dtype is not None: module = module.to(self.dtype) @@ -74,7 +71,6 @@ def unwrap(self): class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): - def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False): """Save optimizer to checkpoint but only on master process. @@ -91,12 +87,14 @@ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, if self.coordinator.is_master(): save_state_dict(state_dict, checkpoint, use_safetensors=False) - def save_sharded_optimizer(self, - optimizer: OptimizerWrapper, - checkpoint: str, - gather_dtensor: bool = False, - prefix: str = None, - size_per_shard: int = 1024): + def save_sharded_optimizer( + self, + optimizer: OptimizerWrapper, + checkpoint: str, + gather_dtensor: bool = False, + prefix: str = None, + size_per_shard: int = 1024, + ): """ Save sharded Zero-optimizer checkpoint under the given checkpointing path. 
The following files will be created under the path: @@ -148,9 +146,11 @@ def save_sharded_optimizer(self, index_file.append_meta_data("total_size", total_size) if self.coordinator.is_master(): index_file.write_index_file(save_index_file) - logging.info(f"The optimizer is going to be split to checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}.") + logging.info( + f"The optimizer is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str): """Load sharded optimizer with the given path to index file. @@ -170,8 +170,10 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s # Load param_groups param_group_path = ckpt_index_file.get_param_group_filename() if param_group_path is None: - raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. \ - Lacking param group file under current directory.') + raise RuntimeError( + f"Invalid index file path {index_file_path} for an optimizer. \ + Lacking param group file under current directory." + ) id_map = load_param_groups_into_optimizer(optimizer, param_group_path) checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() @@ -181,9 +183,10 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s # shard state dict for param_idx, state in state_dict.items(): for k, v in state.items(): - if isinstance(v, torch.Tensor) and k != 'step': - padding_size = (self.coordinator.world_size - - v.numel() % self.coordinator.world_size) % self.coordinator.world_size + if isinstance(v, torch.Tensor) and k != "step": + padding_size = ( + self.coordinator.world_size - v.numel() % self.coordinator.world_size + ) % self.coordinator.world_size with torch.no_grad(): v = v.flatten() if padding_size > 0: @@ -194,33 +197,39 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s sharded_optimizer_loading_epilogue(optimizer) - def save_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool, - use_safetensors: bool): + def save_unsharded_model( + self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool, use_safetensors: bool + ): assert isinstance(model, LowLevelZeroModel) super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors) - def save_sharded_model(self, - model: nn.Module, - checkpoint_path: str, - gather_dtensor: bool = True, - prefix: Optional[str] = None, - max_shard_size: int = 1024, - use_safetensors: bool = False): + def save_sharded_model( + self, + model: nn.Module, + checkpoint_path: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + max_shard_size: int = 1024, + use_safetensors: bool = False, + ): assert isinstance(model, LowLevelZeroModel) - super().save_sharded_model(model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size, - use_safetensors) + super().save_sharded_model( + model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors + ) def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True): assert isinstance(model, LowLevelZeroModel) super().load_unsharded_model(model.module, checkpoint, strict) model.update_master_params() - def load_sharded_model(self, - model: LowLevelZeroModel, - checkpoint_index_file: Path, - 
strict: bool = False, - use_safetensors: bool = False, - load_sub_module: bool = True): + def load_sharded_model( + self, + model: LowLevelZeroModel, + checkpoint_index_file: Path, + strict: bool = False, + use_safetensors: bool = False, + load_sub_module: bool = True, + ): assert isinstance(model, LowLevelZeroModel) super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module) model.update_master_params() @@ -264,7 +273,7 @@ class LowLevelZeroPlugin(DPPluginBase): def __init__( self, stage: int = 1, - precision: str = 'fp16', + precision: str = "fp16", initial_scale: float = 2**32, min_scale: float = 1, growth_factor: float = 2, @@ -281,9 +290,9 @@ def __init__( verbose: bool = False, ) -> None: super().__init__() - assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training' - assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training' - assert norm_type == 2.0, f'LowLevelZeroPlugin only supports norm_type=2.0 now' + assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training" + assert precision in SUPPORTED_PRECISION, f"LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training" + assert norm_type == 2.0, f"LowLevelZeroPlugin only supports norm_type=2.0 now" self.stage = stage self.precision = precision self.zero_optim_kwargs = dict( @@ -319,7 +328,7 @@ def control_device(self) -> bool: return True def supported_devices(self) -> List[str]: - return ['cuda'] + return ["cuda"] def configure( self, @@ -329,15 +338,13 @@ def configure( dataloader: Optional[DataLoader] = None, lr_scheduler: Optional[LRScheduler] = None, ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: - if not isinstance(model, ModelWrapper): model = LowLevelZeroModel(model, self.precision) - if optimizer is not None and \ - not isinstance(optimizer, OptimizerWrapper): - optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(optimizer, - **self.zero_optim_kwargs, - verbose=self.verbose) + if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): + optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer( + optimizer, **self.zero_optim_kwargs, verbose=self.verbose + ) # inject update_master_params model.update_master_params = MethodType(optimizer.update_master_params, model) diff --git a/colossalai/booster/plugin/plugin_base.py b/colossalai/booster/plugin/plugin_base.py index fb21e57f41f7..4e570cbe8abc 100644 --- a/colossalai/booster/plugin/plugin_base.py +++ b/colossalai/booster/plugin/plugin_base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Callable, Iterator, List, Optional, Tuple, Union +from typing import Callable, Iterator, List, Optional, Tuple import torch.nn as nn from torch.optim import Optimizer @@ -9,11 +9,10 @@ from colossalai.checkpoint_io import CheckpointIO from colossalai.interface import OptimizerWrapper -__all__ = ['Plugin'] +__all__ = ["Plugin"] class Plugin(ABC): - @abstractmethod def supported_devices(self) -> List[str]: pass @@ -51,33 +50,31 @@ def control_checkpoint_io(self) -> bool: """ Whether the plugin controls the checkpoint io """ - pass @abstractmethod def get_checkpoint_io(self) -> CheckpointIO: """ Get checkpoint io object for this plugin, only invoked when control_checkpoint_io is True. """ - pass @abstractmethod def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: """ Context manager to disable gradient synchronization. 
""" - pass @abstractmethod - def prepare_dataloader(self, - dataset: Dataset, - batch_size: int, - shuffle: bool = False, - seed: int = 1024, - drop_last: bool = False, - pin_memory: bool = False, - num_workers: int = 0, - **kwargs): + def prepare_dataloader( + self, + dataset: Dataset, + batch_size: int, + shuffle: bool = False, + seed: int = 1024, + drop_last: bool = False, + pin_memory: bool = False, + num_workers: int = 0, + **kwargs, + ): """Prepare a dataloader for distributed training. The dataloader will be wrapped by `torch.utils.data.DataLoader` """ - pass diff --git a/colossalai/booster/plugin/pp_plugin_base.py b/colossalai/booster/plugin/pp_plugin_base.py index f52844db082f..3d91eb95b409 100644 --- a/colossalai/booster/plugin/pp_plugin_base.py +++ b/colossalai/booster/plugin/pp_plugin_base.py @@ -9,13 +9,14 @@ class PipelinePluginBase(Plugin): - @abstractmethod - def execute_pipeline(self, - data_iter: Iterator, - model: ModelWrapper, - criterion: Callable[[Any, Any], torch.Tensor], - optimizer: Optional[OptimizerWrapper] = None, - return_loss: bool = True, - return_outputs: bool = False) -> dict: + def execute_pipeline( + self, + data_iter: Iterator, + model: ModelWrapper, + criterion: Callable[[Any, Any], torch.Tensor], + optimizer: Optional[OptimizerWrapper] = None, + return_loss: bool = True, + return_outputs: bool = False, + ) -> dict: pass diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index f3f779c88e42..30d34e7dd5e5 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterator, List, Optional, Tuple, Union +from typing import Callable, Iterator, List, Optional, Tuple import torch.nn as nn from torch.nn.parallel import DistributedDataParallel as DDP @@ -12,11 +12,10 @@ from .dp_plugin_base import DPPluginBase -__all__ = ['TorchDDPPlugin'] +__all__ = ["TorchDDPPlugin"] class TorchDDPCheckpointIO(GeneralCheckpointIO): - def __init__(self) -> None: super().__init__() self.coordinator = DistCoordinator() @@ -49,25 +48,29 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): if self.coordinator.is_master(): super().save_lr_scheduler(lr_scheduler, checkpoint) - def save_sharded_model(self, - model: nn.Module, - checkpoint_path: str, - gather_dtensor: bool = True, - prefix: Optional[str] = None, - max_shard_size: int = 1024, - use_safetensors: bool = False): + def save_sharded_model( + self, + model: nn.Module, + checkpoint_path: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + max_shard_size: int = 1024, + use_safetensors: bool = False, + ): """ Save model to checkpoint but only on master process. """ if self.coordinator.is_master(): super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors) - def save_sharded_optimizer(self, - optimizer: Optimizer, - checkpoint: str, - gather_dtensor: bool = True, - prefix: Optional[str] = None, - size_per_shard: int = 1024): + def save_sharded_optimizer( + self, + optimizer: Optimizer, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + ): """ Save optimizer to checkpoint but only on master process. 
""" @@ -76,7 +79,6 @@ def save_sharded_optimizer(self, class TorchDDPModel(ModelWrapper): - def __init__(self, module: nn.Module, *args, **kwargs) -> None: super().__init__(module) self.module = DDP(module, *args, **kwargs) @@ -109,20 +111,24 @@ class TorchDDPPlugin(DPPluginBase): static_graph (bool, optional): Whether to use static graph. Defaults to False. """ - def __init__(self, - broadcast_buffers: bool = True, - bucket_cap_mb: int = 25, - find_unused_parameters: bool = False, - check_reduction: bool = False, - gradient_as_bucket_view: bool = False, - static_graph: bool = False) -> None: + def __init__( + self, + broadcast_buffers: bool = True, + bucket_cap_mb: int = 25, + find_unused_parameters: bool = False, + check_reduction: bool = False, + gradient_as_bucket_view: bool = False, + static_graph: bool = False, + ) -> None: super().__init__() - self.ddp_kwargs = dict(broadcast_buffers=broadcast_buffers, - bucket_cap_mb=bucket_cap_mb, - find_unused_parameters=find_unused_parameters, - check_reduction=check_reduction, - gradient_as_bucket_view=gradient_as_bucket_view, - static_graph=static_graph) + self.ddp_kwargs = dict( + broadcast_buffers=broadcast_buffers, + bucket_cap_mb=bucket_cap_mb, + find_unused_parameters=find_unused_parameters, + check_reduction=check_reduction, + gradient_as_bucket_view=gradient_as_bucket_view, + static_graph=static_graph, + ) def support_no_sync(self) -> bool: return True @@ -131,13 +137,13 @@ def control_precision(self) -> bool: return False def supported_precisions(self) -> List[str]: - return ['fp16', 'fp16_apex', 'bf16', 'fp8'] + return ["fp16", "fp16_apex", "bf16", "fp8"] def control_device(self) -> bool: return True def supported_devices(self) -> List[str]: - return ['cuda'] + return ["cuda"] def configure( self, @@ -156,8 +162,7 @@ def configure( # wrap the model with PyTorch DDP model = TorchDDPModel(model, **self.ddp_kwargs) - if optimizer is not None and \ - not isinstance(optimizer, OptimizerWrapper): + if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): optimizer = OptimizerWrapper(optimizer) return model, optimizer, criterion, dataloader, lr_scheduler @@ -169,5 +174,5 @@ def get_checkpoint_io(self) -> CheckpointIO: return TorchDDPCheckpointIO() def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: - assert isinstance(model, TorchDDPModel), 'Model is not boosted by TorchDDPPlugin.' + assert isinstance(model, TorchDDPModel), "Model is not boosted by TorchDDPPlugin." 
return model.module.no_sync() diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index fb7b5baadd0c..d12b784b4fc1 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -1,13 +1,13 @@ import warnings from pathlib import Path -from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union +from typing import Callable, Iterable, Iterator, List, Optional, Tuple import torch import torch.nn as nn from packaging import version from torch.distributed import ProcessGroup -if version.parse(torch.__version__) >= version.parse('1.12.0'): +if version.parse(torch.__version__) >= version.parse("1.12.0"): from torch.distributed.fsdp import FullStateDictConfig from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType @@ -31,11 +31,10 @@ from .dp_plugin_base import DPPluginBase -__all__ = ['TorchFSDPPlugin'] +__all__ = ["TorchFSDPPlugin"] class TorchFSDPCheckpointIO(GeneralCheckpointIO): - def __init__(self) -> None: super().__init__() self.coordinator = DistCoordinator() @@ -69,26 +68,36 @@ def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True) utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False) - def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str], - size_per_shard: int, use_safetensors: bool): + def save_sharded_model( + self, + model: nn.Module, + checkpoint: str, + gather_dtensor: bool, + prefix: Optional[str], + size_per_shard: int, + use_safetensors: bool, + ): """ Save model to checkpoint but only on master process. """ raise NotImplementedError("Sharded model checkpoint is not supported yet.") - def load_sharded_model(self, - model: nn.Module, - checkpoint_index_file: Path, - strict: bool = False, - use_safetensors: bool = False, - load_sub_module: bool = True): + def load_sharded_model( + self, + model: nn.Module, + checkpoint_index_file: Path, + strict: bool = False, + use_safetensors: bool = False, + load_sub_module: bool = True, + ): """ Load model to checkpoint but only on master process. """ raise NotImplementedError("Sharded model checkpoint is not supported yet.") - def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, - size_per_shard: int): + def save_sharded_optimizer( + self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int + ): """ Save optimizer to checkpoint but only on master process. """ @@ -109,7 +118,6 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): class TorchFSDPModel(ModelWrapper): - def __init__(self, module: nn.Module, *args, **kwargs) -> None: super().__init__(module) self.module = FSDP(module, *args, **kwargs) @@ -119,7 +127,6 @@ def unwrap(self): class FSDPOptimizerWrapper(OptimizerWrapper): - def __init__(self, optimizer: Optimizer, model: nn.Module): self.model = model super().__init__(optimizer) @@ -147,7 +154,7 @@ class TorchFSDPPlugin(DPPluginBase): See https://pytorch.org/docs/stable/fsdp.html for details. 
""" - if version.parse(torch.__version__) >= version.parse('1.12.0'): + if version.parse(torch.__version__) >= version.parse("1.12.0"): def __init__( self, @@ -162,15 +169,18 @@ def __init__( sync_module_states: bool = False, ): super().__init__() - self.fsdp_kwargs = dict(process_group=process_group, - sharding_strategy=sharding_strategy, - cpu_offload=cpu_offload, - auto_wrap_policy=auto_wrap_policy, - backward_prefetch=backward_prefetch, - mixed_precision=mixed_precision, - ignored_modules=ignored_modules, - param_init_fn=param_init_fn, - sync_module_states=sync_module_states) + self.fsdp_kwargs = dict( + process_group=process_group, + sharding_strategy=sharding_strategy, + cpu_offload=cpu_offload, + auto_wrap_policy=auto_wrap_policy, + backward_prefetch=backward_prefetch, + mixed_precision=mixed_precision, + ignored_modules=ignored_modules, + param_init_fn=param_init_fn, + sync_module_states=sync_module_states, + ) + else: raise RuntimeError("FSDP is not supported while torch version under 1.12.0.") @@ -184,13 +194,13 @@ def control_precision(self) -> bool: return True def supported_precisions(self) -> List[str]: - return ['fp16', 'bf16'] + return ["fp16", "bf16"] def control_device(self) -> bool: return True def supported_devices(self) -> List[str]: - return ['cuda'] + return ["cuda"] def configure( self, @@ -200,14 +210,13 @@ def configure( dataloader: Optional[DataLoader] = None, lr_scheduler: Optional[LRScheduler] = None, ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: - # wrap the model with PyTorch FSDP fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs) if optimizer is not None: if len(optimizer.param_groups) > 1: warnings.warn( - 'TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used.' + "TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used." ) optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults) diff --git a/colossalai/checkpoint_io/__init__.py b/colossalai/checkpoint_io/__init__.py index e1aa6543ef39..19b61730bded 100644 --- a/colossalai/checkpoint_io/__init__.py +++ b/colossalai/checkpoint_io/__init__.py @@ -3,4 +3,4 @@ from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO from .index_file import CheckpointIndexFile -__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO', 'HybridParallelCheckpointIO'] +__all__ = ["CheckpointIO", "CheckpointIndexFile", "GeneralCheckpointIO", "HybridParallelCheckpointIO"] diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index baff24e1cb25..f8ce8f4e5210 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -11,7 +11,7 @@ from .utils import has_index_file -__all__ = ['CheckpointIO'] +__all__ = ["CheckpointIO"] class CheckpointIO(ABC): @@ -61,10 +61,9 @@ class CheckpointIO(ABC): # ====================================== # Public methods # ====================================== - def load_model(self, - model: Union[nn.Module, ModelWrapper], - checkpoint: str, - strict: bool = True) -> Union[nn.Module, ModelWrapper]: + def load_model( + self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True + ) -> Union[nn.Module, ModelWrapper]: """ Load model from checkpoint. 
@@ -98,14 +97,16 @@ def load_model(self, return origin_model - def save_model(self, - model: Union[nn.Module, ModelWrapper], - checkpoint: str, - shard: bool = False, - gather_dtensor: bool = True, - prefix: str = None, - size_per_shard: int = 1024, - use_safetensors: bool = False): + def save_model( + self, + model: Union[nn.Module, ModelWrapper], + checkpoint: str, + shard: bool = False, + gather_dtensor: bool = True, + prefix: str = None, + size_per_shard: int = 1024, + use_safetensors: bool = False, + ): """ Save model to checkpoint. @@ -157,7 +158,7 @@ def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = No if Path(checkpoint).is_dir() and not index_file_exists: # if the checkpoint is a directory and there is no index file, raise error - raise ValueError(f'Cannot find index file in {checkpoint}') + raise ValueError(f"Cannot find index file in {checkpoint}") if index_file_exists: # the existence of index file means it is a sharded checkpoint @@ -165,13 +166,15 @@ def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = No else: self.load_unsharded_optimizer(optimizer, checkpoint) - def save_optimizer(self, - optimizer: Optimizer, - checkpoint: str, - shard: bool = False, - gather_dtensor=True, - prefix: str = None, - size_per_shard: int = 1024): + def save_optimizer( + self, + optimizer: Optimizer, + checkpoint: str, + shard: bool = False, + gather_dtensor=True, + prefix: str = None, + size_per_shard: int = 1024, + ): """ Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors. @@ -207,7 +210,6 @@ def load_sharded_model(self, model: nn.Module, index_file_path: str, strict: boo strict (bool): whether to strictly enforce that the param name in the checkpoint match the keys returned by this module's. """ - pass @abstractmethod def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool): @@ -220,11 +222,17 @@ def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool): strict (bool): whether to strictly enforce that the param name in the checkpoint match the keys returned by this module's. """ - pass @abstractmethod - def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str], - size_per_shard: int, use_safetensors: bool): + def save_sharded_model( + self, + model: nn.Module, + checkpoint: str, + gather_dtensor: bool, + prefix: Optional[str], + size_per_shard: int, + use_safetensors: bool, + ): """ Save model to sharded checkpoint. @@ -236,7 +244,6 @@ def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: size_per_shard (int): size per shard in MB. use_safetensors (bool): whether to use safe tensors. """ - pass @abstractmethod def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): @@ -249,7 +256,6 @@ def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor gather_dtensor (bool): whether to gather the distributed tensor to the first device. use_safetensors (bool): whether to use safe tensors. """ - pass # ======================================================== # Abstract methods for optimizer loading/saving implementation @@ -265,7 +271,6 @@ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, pre index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file. prefix (str): prefix for the optimizer checkpoint. 
""" - pass @abstractmethod def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): @@ -276,11 +281,11 @@ def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): optimizer (Optimizer): optimizer to be loaded. checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary. """ - pass @abstractmethod - def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, - size_per_shard: int): + def save_sharded_optimizer( + self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int + ): """ Save optimizer to sharded checkpoint. @@ -291,7 +296,6 @@ def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_ prefix (str): prefix for the optimizer checkpoint. size_per_shard (int): size per shard in MB. """ - pass @abstractmethod def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool): @@ -303,7 +307,6 @@ def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gathe checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary. gather_dtensor (bool): whether to gather the distributed tensor to the first device. """ - pass # ============================================ # methods for loading and saving lr scheduler diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index faaf1d22722a..b0e593e90d8c 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -3,9 +3,8 @@ import os from functools import reduce from pathlib import Path -from typing import Iterator, Optional, OrderedDict, Tuple +from typing import Optional -import torch.distributed as dist import torch.nn as nn from torch.optim import Optimizer @@ -16,7 +15,6 @@ from .utils import ( get_model_base_filenames, get_optimizer_base_filenames, - get_shard_filename, is_safetensors_available, load_param_groups_into_optimizer, load_shard_state_dict, @@ -33,7 +31,7 @@ unwrap_optimizer, ) -__all__ = ['GeneralCheckpointIO'] +__all__ = ["GeneralCheckpointIO"] class GeneralCheckpointIO(CheckpointIO): @@ -70,8 +68,10 @@ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, pre # Load param_groups param_group_path = ckpt_index_file.get_param_group_filename() if param_group_path is None: - raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. \ - Lacking param group file under current directory.') + raise RuntimeError( + f"Invalid index file path {index_file_path} for an optimizer. \ + Lacking param group file under current directory." + ) id_map = load_param_groups_into_optimizer(optimizer, param_group_path) checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() @@ -123,19 +123,23 @@ def save_sharded_optimizer( # Save shards of optimizer states. # In general cases, is_master is set to True to get the right behavior. - total_size = save_state_dict_shards(sharded_state_dict=sharded_state, - checkpoint=checkpoint, - index_file=index_file, - base_filename=states_name, - is_master=True, - use_safetensors=False) + total_size = save_state_dict_shards( + sharded_state_dict=sharded_state, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=True, + use_safetensors=False, + ) # Wrap up index file. 
         index_file.append_meta_data("total_size", total_size)
         index_file.write_index_file(save_index_file)
-        logging.info(f"The optimizer is going to be split to checkpoint shards. "
-                     f"You can find where each parameters has been saved in the "
-                     f"index located at {save_index_file}.")
+        logging.info(
+            f"The optimizer is going to be split into checkpoint shards. "
+            f"You can find where each parameter has been saved in the "
+            f"index located at {save_index_file}."
+        )
 
     def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
         checkpoint = load_state_dict(checkpoint)
@@ -150,13 +154,15 @@ def save_unsharded_optimizer(
         # TODO(FrankLeeeee): handle distributed tensors
         save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
 
-    def save_sharded_model(self,
-                           model: nn.Module,
-                           checkpoint_path: str,
-                           gather_dtensor: bool = False,
-                           prefix: Optional[str] = None,
-                           max_shard_size: int = 1024,
-                           use_safetensors: bool = False):
+    def save_sharded_model(
+        self,
+        model: nn.Module,
+        checkpoint_path: str,
+        gather_dtensor: bool = False,
+        prefix: Optional[str] = None,
+        max_shard_size: int = 1024,
+        use_safetensors: bool = False,
+    ):
         """
         implement this method as it can be supported by Huggingface model,
         save shard model, save model to multiple files
@@ -175,26 +181,32 @@ def save_sharded_model(self,
 
         # Save shards of optimizer states.
         # In general cases, is_master is set to True to get the right behavior.
-        total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
-                                            checkpoint=checkpoint_path,
-                                            index_file=index_file,
-                                            base_filename=weights_name,
-                                            is_master=True,
-                                            use_safetensors=use_safetensors)
+        total_size = save_state_dict_shards(
+            sharded_state_dict=state_dict_shard,
+            checkpoint=checkpoint_path,
+            index_file=index_file,
+            base_filename=weights_name,
+            is_master=True,
+            use_safetensors=use_safetensors,
+        )
 
         index_file.append_meta_data("total_size", total_size)
         index_file.write_index_file(save_index_file)
         save_config_file(model, checkpoint_path, is_master=True)
-        logging.info(f"The model is going to be split to checkpoint shards. "
-                     f"You can find where each parameters has been saved in the "
-                     f"index located at {save_index_file}.")
-
-    def load_sharded_model(self,
-                           model: nn.Module,
-                           checkpoint_index_file: Path,
-                           strict: bool = False,
-                           use_safetensors: bool = False,
-                           load_sub_module: bool = True):
+        logging.info(
+            f"The model is going to be split into checkpoint shards. "
+            f"You can find where each parameter has been saved in the "
+            f"index located at {save_index_file}."
+        )
+
+    def load_sharded_model(
+        self,
+        model: nn.Module,
+        checkpoint_index_file: Path,
+        strict: bool = False,
+        use_safetensors: bool = False,
+        load_sub_module: bool = True,
+    ):
         """
         load shard model, load model from multiple files
         """
@@ -219,7 +231,11 @@ def load_sharded_model(self,
         if strict:
             remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
             if len(remain_keys) > 0:
-                error_msgs = 'Missing key(s) in state_dict: {}. '.format(', '.join(
-                    '"{}"'.format(k) for k in missing_keys))
-                raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-                    self.__class__.__name__, "\n\t".join(error_msgs)))
+                error_msgs = "Missing key(s) in state_dict: {}. ".format(
+                    ", ".join('"{}"'.format(k) for k in missing_keys)
+                )
+                raise RuntimeError(
+                    "Error(s) in loading state_dict for {}:\n\t{}".format(
+                        self.__class__.__name__, "\n\t".join(error_msgs)
+                    )
+                )
diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index 270fd8564754..18c59a880dd6 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -1,5 +1,4 @@
 import copy
-import gc
 import logging
 import os
 from pathlib import Path
@@ -35,9 +34,9 @@
 )
 
 try:
-    from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
+    from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
 except ImportError:
-    _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
+    _EXTRA_STATE_KEY_SUFFIX = "_extra_state"
 
 
 class HybridParallelCheckpointIO(GeneralCheckpointIO):
@@ -52,12 +51,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         verbose (bool, optional): Whether to print a logging message when saving/loading has been successfully executed. Defaults to True.
     """
 
-    def __init__(self,
-                 dp_group: ProcessGroup,
-                 pp_group: ProcessGroup,
-                 tp_group: ProcessGroup,
-                 zero_stage: int,
-                 verbose: bool = True) -> None:
+    def __init__(
+        self,
+        dp_group: ProcessGroup,
+        pp_group: ProcessGroup,
+        tp_group: ProcessGroup,
+        zero_stage: int,
+        verbose: bool = True,
+    ) -> None:
         super().__init__()
         self.dp_group = dp_group
         self.pp_group = pp_group
@@ -68,17 +69,16 @@ def __init__(self,
         self.dp_size = dist.get_world_size(dp_group)
         self.pp_size = dist.get_world_size(pp_group)
         self.tp_size = dist.get_world_size(tp_group)
-        self.use_zero = (zero_stage > 0)
+        self.use_zero = zero_stage > 0
         self.verbose = verbose
         self.working_to_master_map = None
         self.master_to_working_map = None
         self.coordinator = DistCoordinator()
 
     @staticmethod
-    def _model_sharder(model: nn.Module,
-                       prefix: str = '',
-                       keep_vars: bool = False,
-                       size_per_shard: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
+    def _model_sharder(
+        model: nn.Module, prefix: str = "", keep_vars: bool = False, size_per_shard: int = 1024
+    ) -> Iterator[Tuple[OrderedDict, int]]:
         # An internal method that breaks state_dict of model into shards within limited size.
 
         state_dict_sharder = StateDictSharder(size_per_shard)
@@ -103,8 +103,10 @@ def _model_sharder(model: nn.Module,
 
         # Save extra states.
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX - if getattr(model.__class__, "get_extra_state", - torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: + if ( + getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state) + is not torch.nn.Module.get_extra_state + ): extra_state = model.get_extra_state() block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state) if block is not None: @@ -114,20 +116,20 @@ def _model_sharder(model: nn.Module, yield state_dict_sharder.current_block, state_dict_sharder.current_block_size @staticmethod - def _optimizer_sharder(optimizer: OptimizerWrapper, - use_zero: bool, - dp_group: ProcessGroup, - tp_group: ProcessGroup, - master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, - size_per_shard: int = 1024): - + def _optimizer_sharder( + optimizer: OptimizerWrapper, + use_zero: bool, + dp_group: ProcessGroup, + tp_group: ProcessGroup, + master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, + size_per_shard: int = 1024, + ): # An internel method that breaks state_dict of optimizer into shards within limited size. state_dict_sharder = StateDictSharder(size_per_shard) param_info = optimizer.param_info for param, state in optimizer.optim.state.items(): - if param is None: continue @@ -136,15 +138,17 @@ def _optimizer_sharder(optimizer: OptimizerWrapper, else: working_param = param - param_id = param_info['param2id'][id(working_param)] - original_shape = param_info['param2shape'][id(working_param)] - state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(state, - working_param, - original_shape=original_shape, - dp_group=dp_group, - tp_group=tp_group, - use_zero=use_zero, - inplace=False) + param_id = param_info["param2id"][id(working_param)] + original_shape = param_info["param2shape"][id(working_param)] + state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state( + state, + working_param, + original_shape=original_shape, + dp_group=dp_group, + tp_group=tp_group, + use_zero=use_zero, + inplace=False, + ) block, block_size = state_dict_sharder.append_optim_state(param_id, state_) if block is not None: @@ -153,13 +157,15 @@ def _optimizer_sharder(optimizer: OptimizerWrapper, # Return the last block in sharder. yield state_dict_sharder.current_block, state_dict_sharder.current_block_size - def save_sharded_model(self, - model: nn.Module, - checkpoint: str, - gather_dtensor: bool = True, - prefix: Optional[str] = None, - size_per_shard: int = 1024, - use_safetensors: bool = False) -> None: + def save_sharded_model( + self, + model: nn.Module, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + use_safetensors: bool = False, + ) -> None: """ Save sharded model checkpoint under the given checkpointing path. 
The following files will be created under the path: @@ -194,24 +200,28 @@ def save_sharded_model(self, state_dict_shard = HybridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard) weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) index_file = CheckpointIndexFile(checkpoint) - control_saving = (self.tp_rank == 0) + control_saving = self.tp_rank == 0 if self.pp_size == 1: # When pipeline is not used, save the model shards as in general checkpointIO - total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, - checkpoint=checkpoint, - index_file=index_file, - base_filename=weights_name, - is_master=control_saving, - use_safetensors=use_safetensors) + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_safetensors=use_safetensors, + ) if control_saving: index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) save_config_file(model, checkpoint) if self.verbose: - logging.info(f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}.") + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) else: # When pipeline is used, each stage produces its own shard files and index files. @@ -228,15 +238,19 @@ def save_sharded_model(self, save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json") save_index_file = os.path.join("tmp_index_files", save_index_file) - total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, - checkpoint=checkpoint, - index_file=index_file, - base_filename=weights_name, - is_master=control_saving, - use_safetensors=use_safetensors, - use_pp_format=True) + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_safetensors=use_safetensors, + use_pp_format=True, + ) if control_saving: - assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0." + assert ( + self.dp_rank == 0 and self.tp_rank == 0 + ), "The saving process should have both dp_rank and tp_rank as 0." index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) else: @@ -259,9 +273,11 @@ def save_sharded_model(self, save_config_file(model, checkpoint) rmtree(tmp_index_file_folder) if self.verbose: - logging.info(f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {final_index_file_path}.") + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {final_index_file_path}." 
+ ) def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False): """ @@ -305,11 +321,9 @@ def _load(name: str): state_dict = load_shard_state_dict(Path(file_path), use_safetensors) missing_keys = [] - load_state_dict_into_model(model, - state_dict, - missing_keys=missing_keys, - strict=strict, - load_sub_module=True) + load_state_dict_into_model( + model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True + ) loaded_file.add(filename) # Load parameters. @@ -319,15 +333,17 @@ def _load(name: str): # Load buffers. non_persistent_buffers = set() for n, m in model.named_modules(): - non_persistent_buffers |= set('.'.join((n, b)) for b in m._non_persistent_buffers_set) + non_persistent_buffers |= set(".".join((n, b)) for b in m._non_persistent_buffers_set) for name, buf in model.named_buffers(): if buf is not None and name not in non_persistent_buffers: _load(name) # Load extra states. extra_state_key = _EXTRA_STATE_KEY_SUFFIX - if getattr(model.__class__, "get_extra_state", - torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: + if ( + getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state) + is not torch.nn.Module.get_extra_state + ): _load(extra_state_key) # Update master params if mixed-precision training is enabled. @@ -352,12 +368,14 @@ def _load(name: str): if self.verbose: logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") - def save_sharded_optimizer(self, - optimizer: OptimizerWrapper, - checkpoint: str, - gather_dtensor: bool = True, - prefix: Optional[str] = None, - size_per_shard: int = 1024): + def save_sharded_optimizer( + self, + optimizer: OptimizerWrapper, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + ): """ Save sharded optimizer checkpoint under the given checkpointing path. The following files will be created under the path: @@ -393,18 +411,21 @@ def save_sharded_optimizer(self, dp_group=self.dp_group, tp_group=self.tp_group, master_to_working_map=self.master_to_working_map, - size_per_shard=size_per_shard) + size_per_shard=size_per_shard, + ) states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) index_file = CheckpointIndexFile(checkpoint) - control_saving = (self.dp_rank == 0 and self.tp_rank == 0) + control_saving = self.dp_rank == 0 and self.tp_rank == 0 if self.pp_size == 1: # When pipeline is not used, save the optimizer shards as in general checkpointIO - total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, - checkpoint=checkpoint, - index_file=index_file, - base_filename=states_name, - is_master=control_saving) + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving, + ) if control_saving: # Store param groups. @@ -415,9 +436,11 @@ def save_sharded_optimizer(self, index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) if self.verbose: - logging.info(f"The optimizer is going to be split to checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}.") + logging.info( + f"The optimizer is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." 
+ ) else: # When pipeline is used, each stage produces its own shard files and index files. @@ -433,15 +456,19 @@ def save_sharded_optimizer(self, save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json") save_index_file = os.path.join("tmp_index_files", save_index_file) - total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, - checkpoint=checkpoint, - index_file=index_file, - base_filename=states_name, - is_master=control_saving, - use_pp_format=True) + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving, + use_pp_format=True, + ) if control_saving: - assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0." + assert ( + self.dp_rank == 0 and self.tp_rank == 0 + ), "The saving process should have both dp_rank and tp_rank as 0." index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) else: @@ -451,7 +478,6 @@ def save_sharded_optimizer(self, # The global master rank integrates the index files and clean the folder. if self.pp_rank == 0: - final_index_file = CheckpointIndexFile(checkpoint) final_index_file.append_meta_data("total_size", 0) @@ -470,9 +496,11 @@ def save_sharded_optimizer(self, rmtree(tmp_index_file_folder) if self.verbose: - logging.info(f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {final_index_file_path}.") + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {final_index_file_path}." + ) def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""): """ @@ -484,20 +512,21 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_f prefix (str): Not used. """ - def _get_param_id_from_optimizer_param(param: torch.Tensor, - master_to_working_map: Optional[Dict[int, torch.Tensor]] = None): + def _get_param_id_from_optimizer_param( + param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None + ): if master_to_working_map is not None: working_param = master_to_working_map[id(param)] else: working_param = param - return optimizer.param_info['param2id'][id(working_param)] + return optimizer.param_info["param2id"][id(working_param)] # id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects. # When Zero is used, the mapped parameter objects should be fp32 master parameters. # IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info. 
         id_map = {}
         for pg in optimizer.optim.param_groups:
-            for param in pg['params']:
+            for param in pg["params"]:
                 param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
                 id_map[param_id] = param
 
@@ -505,28 +534,30 @@ def _get_param_id_from_optimizer_param(param: torch.Tensor,
         ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
         ckpt_root_path = ckpt_index_file.root_path
         weight_map = ckpt_index_file.weight_map
-        weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
+        weight_map = {int(k): v for k, v in weight_map.items()}  # convert saved id from str to int
 
         # Load param_groups
         param_group_path = ckpt_index_file.get_param_group_filename()
         if param_group_path is None:
-            raise RuntimeError(f'Invalid index file path {checkpoint_index_file} for an optimizer. \
-                               Lacking param group file under current directory.')
+            raise RuntimeError(
+                f"Invalid index file path {checkpoint_index_file} for an optimizer. \
+                               Lacking param group file under current directory."
+            )
         saved_groups = torch.load(param_group_path)
 
         updated_groups = []
         for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
             # obtain updated param group
             new_pg = copy.deepcopy(saved_pg)
-            new_pg['params'] = old_pg['params'] # The parameters in the same group shouln't change.
+            new_pg["params"] = old_pg["params"]  # The parameters in the same group shouldn't change.
             updated_groups.append(new_pg)
-        optimizer.optim.__dict__.update({'param_groups': updated_groups})
+        optimizer.optim.__dict__.update({"param_groups": updated_groups})
 
         # Load saved states to optimizer.
         # Keep a record of loaded files so that file will not be repeatedly loaded.
         loaded_file = set()
         for pg in optimizer.optim.param_groups:
-            for param in pg['params']:
+            for param in pg["params"]:
                 if param is None:
                     continue
                 param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
@@ -550,12 +581,10 @@ def _get_param_id_from_optimizer_param(param: torch.Tensor,
                 working_param = self.master_to_working_map[id(param)]
             else:
                 working_param = param
-            original_shape = optimizer.param_info['param2shape'][id(working_param)]
-            sharded_state = self.shard_from_complete_optimizer_state(state,
-                                                                     current_shape=working_param.shape,
-                                                                     original_shape=original_shape,
-                                                                     device=device,
-                                                                     inplace=True)
+            original_shape = optimizer.param_info["param2shape"][id(working_param)]
+            sharded_state = self.shard_from_complete_optimizer_state(
+                state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
+            )
             optimizer.optim.state[param] = sharded_state
 
         sharded_optimizer_loading_epilogue(optimizer.optim)
@@ -585,8 +614,11 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
         if self.coordinator.is_master():
             super().save_lr_scheduler(lr_scheduler, checkpoint)
 
-    def link_master_and_working_param(self, working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor],
-                                      master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor]):
+    def link_master_and_working_param(
+        self,
+        working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor],
+        master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor],
+    ):
         """
         Create mappings between working params (for forward/backward) and master params (for optimizer update) with passed-in mappings.
         This mapping can only be created when mixed precision is used.
@@ -604,7 +636,8 @@ def link_master_and_working_param(self, working_to_master_map: Dict[Union[int, t self.working_to_master_map[k] = v else: raise ValueError( - f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!") + f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!" + ) self.master_to_working_map = dict() for k, v in master_to_working_map.items(): @@ -614,12 +647,19 @@ def link_master_and_working_param(self, working_to_master_map: Dict[Union[int, t self.master_to_working_map[k] = v else: raise ValueError( - f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!") + f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!" + ) @staticmethod - def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor, original_shape: torch.Size, - dp_group: ProcessGroup, tp_group: ProcessGroup, use_zero: bool, - inplace: bool) -> OrderedDict: + def gather_from_sharded_optimizer_state( + state: OrderedDict, + param: torch.Tensor, + original_shape: torch.Size, + dp_group: ProcessGroup, + tp_group: ProcessGroup, + use_zero: bool, + inplace: bool, + ) -> OrderedDict: """ With given parameter and its optimizer states, gather the complete optimizer state for saving. @@ -641,14 +681,13 @@ def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor, state_ = state if inplace else copy.deepcopy(state) for k, v in state_.items(): - if isinstance(v, torch.Tensor) and k != 'step': - + if isinstance(v, torch.Tensor) and k != "step": # First gather Zero shards. if use_zero: v = v.cuda() gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)] dist.all_gather(gather_tensor, v, group=dp_group) - v = torch.stack(gather_tensor).view(-1)[:param.numel()].reshape_as(param) + v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param) # Then gather TP shards. partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size) @@ -661,9 +700,14 @@ def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor, return state_ - def shard_from_complete_optimizer_state(self, state: OrderedDict, current_shape: torch.Size, - original_shape: torch.Size, device: torch.device, - inplace: bool) -> OrderedDict: + def shard_from_complete_optimizer_state( + self, + state: OrderedDict, + current_shape: torch.Size, + original_shape: torch.Size, + device: torch.device, + inplace: bool, + ) -> OrderedDict: """ With complete optimizer states of a specific parameter loaded from checkpoint, slice out the sharded optimizer states kept by current device. @@ -681,8 +725,7 @@ def shard_from_complete_optimizer_state(self, state: OrderedDict, current_shape: state_ = state if inplace else copy.deepcopy(state) for k, v in state_.items(): - if isinstance(v, torch.Tensor) and k != 'step': - + if isinstance(v, torch.Tensor) and k != "step": # Shard state along tensor parallel group. 
partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size) if partition_dim is not None: diff --git a/colossalai/checkpoint_io/index_file.py b/colossalai/checkpoint_io/index_file.py index 388cf3fbe9bb..da12c146f2c3 100644 --- a/colossalai/checkpoint_io/index_file.py +++ b/colossalai/checkpoint_io/index_file.py @@ -6,7 +6,7 @@ from .utils import is_dtensor_checkpoint -__all__ = ['CheckpointIndexFile'] +__all__ = ["CheckpointIndexFile"] class CheckpointIndexFile: @@ -50,7 +50,7 @@ def load(self, json_path: str): json_path (str): path to the json file. """ # load the json file - with open(json_path, 'r') as f: + with open(json_path, "r") as f: index = json.load(f) # assign attributes if exists @@ -75,7 +75,7 @@ def export(self, json_path: str): index["weight_map"] = self.weight_map # export the index file - with open(json_path, 'w') as f: + with open(json_path, "w") as f: json.dump(index, f, indent=4) def append_weight_map(self, param_name: str, shard_file: str): diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 664ac63e45ac..c22b76dd46f7 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -1,5 +1,4 @@ # coding=utf-8 -import copy import os import re from collections import abc as container_abcs @@ -12,7 +11,7 @@ import torch.nn as nn from torch.optim import Optimizer -from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.interface import OptimizerWrapper from colossalai.tensor.d_tensor import ( is_customized_distributed_tensor, is_distributed_tensor, @@ -55,7 +54,6 @@ def is_safetensors_available() -> bool: bool: whether safetensors is available. """ try: - import safetensors return True except ImportError: return False @@ -71,7 +69,7 @@ def is_dtensor_checkpoint(checkpoint_file_path: str) -> bool: Returns: bool: whether the checkpoint file is a dtensor checkpoint. """ - if checkpoint_file_path.endswith('.*.safetensors') or checkpoint_file_path.endswith('.*.bin'): + if checkpoint_file_path.endswith(".*.safetensors") or checkpoint_file_path.endswith(".*.bin"): return True else: return False @@ -87,7 +85,7 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool: Returns: bool: whether the checkpoint file is a safetensor checkpoint. """ - if checkpoint_file_path.endswith('.safetensors'): + if checkpoint_file_path.endswith(".safetensors"): return True else: return False @@ -113,8 +111,9 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz partition_dim = dim break if partition_dim is not None: - assert original_shape[partition_dim] == tp_size * current_shape[partition_dim], \ - f"The parameter isn't evenly distributed among tensor parallel group: \ + assert ( + original_shape[partition_dim] == tp_size * current_shape[partition_dim] + ), f"The parameter isn't evenly distributed among tensor parallel group: \ shape before sharding {original_shape}, shape after sharding {current_shape}" return partition_dim @@ -124,24 +123,22 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz # Helper classes and functions for saving shard file # ====================================== def unwrap_optimizer(optimizer: OptimizerWrapper): - ''' + """ Unwrap a wrapped optimizer. This method should be used before saving/loading it to/from sharded checkpoints. 
- ''' + """ unwrapped_optim = optimizer.optim return unwrapped_optim class StateDictSharder: - def __init__(self, size_per_shard: int) -> None: self.max_shard_size = size_per_shard self.current_block = OrderedDict() self.current_block_size = 0 def append_param(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]: - tensor_size = calculate_tensor_size(tensor) ret_block = None ret_block_size = 0 @@ -159,13 +156,11 @@ def append_param(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[Ordere return ret_block, ret_block_size def append_optim_state(self, param_id: int, state: OrderedDict) -> Tuple[Optional[OrderedDict], int]: - # A state might contain more than one tensors. # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq' state_size = 0 isDTensor = False for state_tensor in state.values(): - # When state_tensor is not of Tensor class, # e.g., a SGD optimizer with momentum set to 0 can have None as state # The calculation of tensor size should be skipped to avoid error. @@ -217,14 +212,16 @@ def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False) -> to return param_ -def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]], - checkpoint: str, - index_file: "CheckpointIndexFile", - base_filename: str, - is_master: bool, - use_safetensors: bool = False, - use_pp_format: bool = False) -> int: - ''' +def save_state_dict_shards( + sharded_state_dict: Iterator[Tuple[OrderedDict, int]], + checkpoint: str, + index_file: "CheckpointIndexFile", + base_filename: str, + is_master: bool, + use_safetensors: bool = False, + use_pp_format: bool = False, +) -> int: + """ Save sharded state dict only on master rank, this method can be used by both model and optimizer states. Args: sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size. @@ -237,7 +234,7 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]] Returns: int: the total size of shards - ''' + """ total_size = 0 shard_filenames = [] @@ -288,7 +285,7 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> """ # Only split state_dict['state']; state_dict['param_group'] is not considered in this function. - states = state_dict['state'] + states = state_dict["state"] state_dict_sharder = StateDictSharder(max_shard_size) for param_id, state in states.items(): @@ -316,9 +313,11 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors """ if use_safetensors: assert is_safetensors_available(), "safetensors is not available." - assert checkpoint_file_path.endswith('.safetensors'), \ - "safetensors only supports .safetensors suffix for checkpoint file." + assert checkpoint_file_path.endswith( + ".safetensors" + ), "safetensors only supports .safetensors suffix for checkpoint file." 
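`StateDictSharder`, touched by the formatting changes above, accumulates tensors into blocks until a size budget is hit and then emits the full block. The standalone sketch below reproduces that accumulation pattern using raw byte counts instead of `calculate_tensor_size`; it is an approximation for illustration, not the exact class.

```python
from collections import OrderedDict
from typing import Optional, Tuple

import torch


class TinySharder:
    def __init__(self, max_shard_size: int) -> None:
        self.max_shard_size = max_shard_size
        self.current_block = OrderedDict()
        self.current_block_size = 0

    def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
        tensor_size = tensor.numel() * tensor.element_size()
        ret_block, ret_block_size = None, 0
        # close the current block if this tensor would overflow it
        if self.current_block_size + tensor_size > self.max_shard_size and self.current_block:
            ret_block, ret_block_size = self.current_block, self.current_block_size
            self.current_block, self.current_block_size = OrderedDict(), 0
        self.current_block[name] = tensor
        self.current_block_size += tensor_size
        return ret_block, ret_block_size


sharder = TinySharder(max_shard_size=64)
for i in range(4):
    block, size = sharder.append(f"w{i}", torch.zeros(8))  # 32 bytes each
    if block is not None:
        print(f"shard closed with {size} bytes: {list(block)}")
print(f"last shard: {list(sharder.current_block)}")
```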
from safetensors.torch import save_file as safe_save_file + safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"}) else: torch.save(state_dict, checkpoint_file_path) @@ -336,11 +335,13 @@ def save_param_groups(state_dict: dict, group_file_path: str) -> None: torch.save(param_groups, group_file_path) -def clean_folder(checkpoint_path: str, - weights_name: str, - shard_filenames: List[str], - is_master: bool = True, - use_pp_format: bool = False): +def clean_folder( + checkpoint_path: str, + weights_name: str, + shard_filenames: List[str], + is_master: bool = True, + use_pp_format: bool = False, +): """ Clean the unneeded files in checkpoint directory after shards of state_dict have been saved. @@ -362,8 +363,12 @@ def clean_folder(checkpoint_path: str, else: # When this checkpoint is created by pipeline parallel process, the pattern is a little different. reg = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}") - if (filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) - and filename not in shard_filenames and reg.fullmatch(filename_no_suffix) is not None): + if ( + filename.startswith(weights_no_suffix) + and os.path.isfile(full_filename) + and filename not in shard_filenames + and reg.fullmatch(filename_no_suffix) is not None + ): os.remove(full_filename) @@ -412,7 +417,7 @@ def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFi size_per_shard (int): size per shard in MB. """ root_path = index_file.root_path - output_root_path = root_path.joinpath('dtensor') + output_root_path = root_path.joinpath("dtensor") # create directory output_root_path.mkdir(exist_ok=True) @@ -432,7 +437,7 @@ def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFi # update the weight map # * means all shards - ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors) + ckpt_file_name_in_weight_map = "dtensor/" + generate_dtensor_file_name(name, "*", use_safetensors) index_file.append_weight_map(name, ckpt_file_name_in_weight_map) @@ -447,15 +452,14 @@ def get_checkpoint_file_suffix(use_safetensors: bool) -> str: str: checkpoint file suffix. """ if use_safetensors: - return '.safetensors' + return ".safetensors" else: - return '.bin' + return ".bin" -def generate_checkpoint_shard_file_name(index: int, - total_number: int, - use_safetensors: bool, - prefix: str = None) -> str: +def generate_checkpoint_shard_file_name( + index: int, total_number: int, use_safetensors: bool, prefix: str = None +) -> str: """ Generate checkpoint shard file name. @@ -489,7 +493,7 @@ def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: boo str: dtensor file name. """ suffix = get_checkpoint_file_suffix(use_safetensors) - return f'{param_name}.{index}.{suffix}' + return f"{param_name}.{index}.{suffix}" # ======================================== @@ -506,21 +510,21 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False): if use_safetensors: from safetensors.torch import load_file as safe_load_file from safetensors.torch import safe_open + with safe_open(checkpoint_file, framework="pt") as f: metadata = f.metadata() if metadata["format"] != "pt": raise NotImplementedError( - f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.") + f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet." 
+ ) return safe_load_file(checkpoint_file) else: - return torch.load(checkpoint_file, map_location=torch.device('cpu')) + return torch.load(checkpoint_file, map_location=torch.device("cpu")) -def load_state_dict_into_model(model: nn.Module, - state_dict: torch.Tensor, - missing_keys: List, - strict: bool = False, - load_sub_module: bool = True): +def load_state_dict_into_model( + model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False, load_sub_module: bool = True +): r"""Copies parameters and buffers from :attr:`state_dict` into this module and its descendants. @@ -536,7 +540,7 @@ def load_state_dict_into_model(model: nn.Module, error_msgs: List[str] = [] # copy state_dict so _load_from_state_dict can modify it - metadata = getattr(state_dict, '_metadata', None) + metadata = getattr(state_dict, "_metadata", None) state_dict = OrderedDict(state_dict) if metadata is not None: state_dict._metadata = metadata @@ -560,10 +564,12 @@ def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True) if strict: if len(unexpected_keys) > 0: - error_msgs = 'Unexpected key(s) in state_dict: {}. '.format(', '.join( - '"{}"'.format(k) for k in unexpected_keys)) - raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( - model.__class__.__name__, "\n\t".join(error_msgs))) + error_msgs = "Unexpected key(s) in state_dict: {}. ".format( + ", ".join('"{}"'.format(k) for k in unexpected_keys) + ) + raise RuntimeError( + "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)) + ) def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str) -> dict: @@ -573,9 +579,9 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str # Load list of param_groups from given file path. # The params in saved_groups are in the form of integer indices. - saved_groups = torch.load(param_group_path, map_location=torch.device('cpu')) + saved_groups = torch.load(param_group_path, map_location=torch.device("cpu")) if not isinstance(saved_groups, List): - raise ValueError(f'The param_groups saved at {param_group_path} is not of List type') + raise ValueError(f"The param_groups saved at {param_group_path} is not of List type") # The params in param_groups are in the form of pytorch tensors. # For more details, please view source code of Optimizer class in pytorch. @@ -584,26 +590,30 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str # Check the compatibility of saved_groups and param_groups. if len(param_groups) != len(saved_groups): raise ValueError("loaded state dict has a different number of original parameter groups") - param_lens = (len(g['params']) for g in param_groups) - saved_lens = (len(g['params']) for g in saved_groups) + param_lens = (len(g["params"]) for g in param_groups) + saved_lens = (len(g["params"]) for g in saved_groups) if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): - raise ValueError("loaded state dict contains a parameter group " - "that doesn't match the size of optimizer's group") + raise ValueError( + "loaded state dict contains a parameter group " "that doesn't match the size of optimizer's group" + ) # Creating mapping from id to parameters. 
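Taken together, `load_shard_state_dict` and `load_state_dict_into_model` support loading a sharded checkpoint shard by shard, driven by the `weight_map` of a `*.index.json` file produced by `CheckpointIndexFile`. A hedged sketch of that flow using the signatures visible in this diff; the index path, the model, and the helper name `load_sharded` are placeholders for illustration only.

```python
import json
from pathlib import Path

from colossalai.checkpoint_io.utils import load_shard_state_dict, load_state_dict_into_model


def load_sharded(model, index_json: str, use_safetensors: bool = False):
    index_path = Path(index_json)
    with open(index_path, "r") as f:
        index = json.load(f)

    missing_keys = []
    # weight_map maps each parameter name to the shard file that stores it
    for shard_file in sorted(set(index["weight_map"].values())):
        state_dict = load_shard_state_dict(index_path.parent / shard_file, use_safetensors)
        load_state_dict_into_model(model, state_dict, missing_keys, strict=False)
    return model
```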
id_map = { - old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups - )), chain.from_iterable((g['params'] for g in param_groups))) + old_id: p + for old_id, p in zip( + chain.from_iterable((g["params"] for g in saved_groups)), + chain.from_iterable((g["params"] for g in param_groups)), + ) } # Update parameter groups, setting their 'params' value. def update_group(group, new_group): - new_group['params'] = group['params'] + new_group["params"] = group["params"] return new_group updated_groups = [update_group(g, ng) for g, ng in zip(param_groups, saved_groups)] - optimizer.__dict__.update({'param_groups': updated_groups}) + optimizer.__dict__.update({"param_groups": updated_groups}) return id_map @@ -628,7 +638,7 @@ def cast(param, value, key=None): # Floating-point types are a bit special here. They are the only ones # that are assumed to always match the type of params. # Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424 - if (key != "step"): + if key != "step": if param.is_floating_point(): value = value.to(param.dtype) value = value.to(param.device) @@ -662,8 +672,8 @@ def sharded_optimizer_loading_epilogue(optimizer: Optimizer): """ # Do the cleaning up as in src code of Pytorch. - optimizer._hook_for_profile() # To support multiprocessing pickle/unpickle. - optimizer.defaults.setdefault('differentiable', False) + optimizer._hook_for_profile() # To support multiprocessing pickle/unpickle. + optimizer.defaults.setdefault("differentiable", False) def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]: @@ -686,20 +696,20 @@ def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]: return False, None elif checkpoint_path.is_dir(): # check if there is only one a file ending with .index.json in this directory - index_files = list(checkpoint_path.glob('*.index.*json')) + index_files = list(checkpoint_path.glob("*.index.*json")) # if we found a .index.json file, make sure there is only one if len(index_files) > 0: - assert len( - index_files - ) == 1, f'Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}' + assert ( + len(index_files) == 1 + ), f"Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}" if len(index_files) == 1: return True, index_files[0] else: return False, None else: - raise RuntimeError(f'Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.') + raise RuntimeError(f"Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.") def load_state_dict(checkpoint_file_path: Path): @@ -713,14 +723,17 @@ def load_state_dict(checkpoint_file_path: Path): dict: state dict. """ - assert not is_dtensor_checkpoint(checkpoint_file_path), \ - f'Cannot load state dict from dtensor checkpoint {checkpoint_file_path}, you should convert the distributed tensors to gathered tensors with our CLI offline.' + assert not is_dtensor_checkpoint( + checkpoint_file_path + ), f"Cannot load state dict from dtensor checkpoint {checkpoint_file_path}, you should convert the distributed tensors to gathered tensors with our CLI offline." if is_safetensor_checkpoint(checkpoint_file_path): - assert is_safetensors_available(), \ - f'Cannot load state dict from safetensor checkpoint {checkpoint_file_path}, because safetensors is not available. Please install safetensors first with pip install safetensors.' 
+ assert ( + is_safetensors_available() + ), f"Cannot load state dict from safetensor checkpoint {checkpoint_file_path}, because safetensors is not available. Please install safetensors first with pip install safetensors." # load with safetensors from safetensors import safe_open + state_dict = {} with safe_open(checkpoint_file_path, framework="pt", device="cpu") as f: for k in f.keys(): @@ -729,7 +742,7 @@ def load_state_dict(checkpoint_file_path: Path): else: # load with torch - return torch.load(checkpoint_file_path, map_location=torch.device('cpu')) + return torch.load(checkpoint_file_path, map_location=torch.device("cpu")) def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str: diff --git a/colossalai/cli/__init__.py b/colossalai/cli/__init__.py index 658e35e4c72e..c7cb19c19308 100644 --- a/colossalai/cli/__init__.py +++ b/colossalai/cli/__init__.py @@ -1,3 +1,3 @@ from .cli import cli -__all__ = ['cli'] +__all__ = ["cli"] diff --git a/colossalai/cli/check/__init__.py b/colossalai/cli/check/__init__.py index a86b32bb6a18..7c26ab6ade6c 100644 --- a/colossalai/cli/check/__init__.py +++ b/colossalai/cli/check/__init__.py @@ -1,11 +1,12 @@ import click + from .check_installation import check_installation -__all__ = ['check'] +__all__ = ["check"] @click.command(help="Check if Colossal-AI is correct based on the given option") -@click.option('-i', '--installation', is_flag=True, help="Check if Colossal-AI is built correctly") +@click.option("-i", "--installation", is_flag=True, help="Check if Colossal-AI is built correctly") def check(installation): if installation: check_installation() diff --git a/colossalai/cli/check/check_installation.py b/colossalai/cli/check/check_installation.py index 4a481f3bd122..772c513ffa06 100644 --- a/colossalai/cli/check/check_installation.py +++ b/colossalai/cli/check/check_installation.py @@ -9,7 +9,7 @@ def to_click_output(val): # installation check output to understandable symbols for readability - VAL_TO_SYMBOL = {True: u'\u2713', False: 'x', None: 'N/A'} + VAL_TO_SYMBOL = {True: "\u2713", False: "x", None: "N/A"} if val in VAL_TO_SYMBOL: return VAL_TO_SYMBOL[val] @@ -55,8 +55,8 @@ def check_installation(): else: torch_compatibility = _is_compatible([torch_version, prebuilt_torch_version_required]) - click.echo(f'#### Installation Report ####') - click.echo(f'\n------------ Environment ------------') + click.echo(f"#### Installation Report ####") + click.echo(f"\n------------ Environment ------------") click.echo(f"Colossal-AI version: {to_click_output(colossalai_version)}") click.echo(f"PyTorch version: {to_click_output(torch_version)}") click.echo(f"System CUDA version: {to_click_output(cuda_version)}") @@ -69,7 +69,7 @@ def check_installation(): f"3. If the CUDA version required by PyTorch is N/A, you probably did not install a CUDA-compatible PyTorch. This value is give by torch.version.cuda and you can go to https://pytorch.org/get-started/locally/ to download the correct version." ) - click.echo(f'\n------------ CUDA Extensions AOT Compilation ------------') + click.echo(f"\n------------ CUDA Extensions AOT Compilation ------------") click.echo(f"Found AOT CUDA Extension: {to_click_output(found_aot_cuda_ext)}") click.echo(f"PyTorch version used for AOT compilation: {to_click_output(prebuilt_torch_version_required)}") click.echo(f"CUDA version used for AOT compilation: {to_click_output(prebuilt_cuda_version_required)}") @@ -81,7 +81,7 @@ def check_installation(): click.echo(f"2. 
If AOT compilation is not enabled, stay calm as the CUDA kernels can still be built during runtime") click.echo(f"\n------------ Compatibility ------------") - click.echo(f'PyTorch version match: {to_click_output(torch_compatibility)}') + click.echo(f"PyTorch version match: {to_click_output(torch_compatibility)}") click.echo(f"System and PyTorch CUDA version match: {to_click_output(sys_torch_cuda_compatibility)}") click.echo(f"System and Colossal-AI CUDA version match: {to_click_output(sys_colossalai_cuda_compatibility)}") click.echo(f"") @@ -106,12 +106,12 @@ def _is_compatible(versions): return False # split version into [major, minor, patch] - versions = [version.split('.') for version in versions] + versions = [version.split(".") for version in versions] for version in versions: if len(version) == 2: # x means unknown - version.append('x') + version.append("x") for idx, version_values in enumerate(zip(*versions)): equal = len(set(version_values)) == 1 @@ -137,11 +137,11 @@ def _parse_colossalai_version(): # 1. X.X.X+torchX.XXcuXX.X (when colossalai is installed with CUDA extensions) # 2. X.X.X (when colossalai is not installed with CUDA extensions) # where X represents an integer. - colossalai_version = colossalai.__version__.split('+')[0] + colossalai_version = colossalai.__version__.split("+")[0] try: - torch_version_for_aot_build = colossalai.__version__.split('torch')[1].split('cu')[0] - cuda_version_for_aot_build = colossalai.__version__.split('cu')[1] + torch_version_for_aot_build = colossalai.__version__.split("torch")[1].split("cu")[0] + cuda_version_for_aot_build = colossalai.__version__.split("cu")[1] except: torch_version_for_aot_build = None cuda_version_for_aot_build = None @@ -156,7 +156,6 @@ def _check_aot_built_cuda_extension_installed(): JIT (just-in-time) compilation will build CUDA extensions to `~/.cache/colossalai/torch_extensions` during runtime. 
""" try: - import colossalai._C.fused_optim found_aot_cuda_ext = True except ImportError: found_aot_cuda_ext = False @@ -175,14 +174,14 @@ def _check_torch_version(): # torch version can be of two formats # - 1.13.1+cu113 # - 1.13.1.devxxx - torch_version = torch.__version__.split('+')[0] - torch_version = '.'.join(torch_version.split('.')[:3]) + torch_version = torch.__version__.split("+")[0] + torch_version = ".".join(torch_version.split(".")[:3]) # get cuda version in pytorch build try: torch_cuda_major = torch.version.cuda.split(".")[0] torch_cuda_minor = torch.version.cuda.split(".")[1] - torch_cuda_version = f'{torch_cuda_major}.{torch_cuda_minor}' + torch_cuda_version = f"{torch_cuda_major}.{torch_cuda_minor}" except: torch_cuda_version = None @@ -208,7 +207,7 @@ def _check_cuda_version(): release = output[release_idx].split(".") bare_metal_major = release[0] bare_metal_minor = release[1][0] - cuda_version = f'{bare_metal_major}.{bare_metal_minor}' + cuda_version = f"{bare_metal_major}.{bare_metal_minor}" except: cuda_version = None return cuda_version diff --git a/colossalai/cli/cli.py b/colossalai/cli/cli.py index 0dea7c504957..0d94fe59f8ae 100644 --- a/colossalai/cli/cli.py +++ b/colossalai/cli/cli.py @@ -4,8 +4,7 @@ from .launcher import run -class Arguments(): - +class Arguments: def __init__(self, arg_dict): for k, v in arg_dict.items(): self.__dict__[k] = v @@ -19,5 +18,5 @@ def cli(): cli.add_command(run) cli.add_command(check) -if __name__ == '__main__': +if __name__ == "__main__": cli() diff --git a/colossalai/cli/launcher/__init__.py b/colossalai/cli/launcher/__init__.py index 808e4e84574f..0f9ead6495db 100644 --- a/colossalai/cli/launcher/__init__.py +++ b/colossalai/cli/launcher/__init__.py @@ -5,56 +5,81 @@ from .run import launch_multi_processes -@click.command(help="Launch distributed training on a single node or multiple nodes", - context_settings=dict(ignore_unknown_options=True)) -@click.option("-H", - "-host", - "--host", - type=str, - default=None, - help="the list of hostnames to launch in the format ,") +@click.command( + help="Launch distributed training on a single node or multiple nodes", + context_settings=dict(ignore_unknown_options=True), +) +@click.option( + "-H", + "-host", + "--host", + type=str, + default=None, + help="the list of hostnames to launch in the format ,", +) @click.option( "--hostfile", type=str, default=None, - help="Hostfile path that defines the device pool available to the job, each line in the file is a hostname") -@click.option("--include", - type=str, - default=None, - help="Specify computing devices to use during execution. String format is ,," - " only effective when used with --hostfile.") + help="Hostfile path that defines the device pool available to the job, each line in the file is a hostname", +) +@click.option( + "--include", + type=str, + default=None, + help="Specify computing devices to use during execution. String format is ,," + " only effective when used with --hostfile.", +) @click.option( "--exclude", type=str, default=None, - help= - "Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --include," - " only effective when used with --hostfile.") -@click.option("--num_nodes", - type=int, - default=-1, - help="Total number of worker nodes to use, only effective when used with --hostfile.") + help="Specify computing devices to NOT use during execution. Mutually exclusive with --include. 
Formatting is the same as --include," + " only effective when used with --hostfile.", +) +@click.option( + "--num_nodes", + type=int, + default=-1, + help="Total number of worker nodes to use, only effective when used with --hostfile.", +) @click.option("--nproc_per_node", type=int, default=None, help="Number of GPUs to use on each node.") -@click.option("--master_port", - type=int, - default=29500, - help="(optional) Port used by PyTorch distributed for communication during distributed training.") -@click.option("--master_addr", - type=str, - default="127.0.0.1", - help="(optional) IP address of node 0, will be inferred via 'hostname -I' if not specified.") +@click.option( + "--master_port", + type=int, + default=29500, + help="(optional) Port used by PyTorch distributed for communication during distributed training.", +) +@click.option( + "--master_addr", + type=str, + default="127.0.0.1", + help="(optional) IP address of node 0, will be inferred via 'hostname -I' if not specified.", +) @click.option( "--extra_launch_args", type=str, default=None, - help= - "Set additional torch distributed launcher arguments such as --standalone. The format is --extra_launch_args arg1=1,arg2=2. " - "This will be converted to --arg1=1 --arg2=2 during execution") + help="Set additional torch distributed launcher arguments such as --standalone. The format is --extra_launch_args arg1=1,arg2=2. " + "This will be converted to --arg1=1 --arg2=2 during execution", +) @click.option("--ssh-port", type=int, default=None, help="(optional) the port used for ssh connection") @click.argument("user_script", type=str) -@click.argument('user_args', nargs=-1) -def run(host: str, hostfile: str, num_nodes: int, nproc_per_node: int, include: str, exclude: str, master_addr: str, - master_port: int, extra_launch_args: str, ssh_port: int, user_script: str, user_args: str) -> None: +@click.argument("user_args", nargs=-1) +def run( + host: str, + hostfile: str, + num_nodes: int, + nproc_per_node: int, + include: str, + exclude: str, + master_addr: str, + master_port: int, + extra_launch_args: str, + ssh_port: int, + user_script: str, + user_args: str, +) -> None: """ To launch multiple processes on a single node or multiple nodes via command line. @@ -77,8 +102,8 @@ def run(host: str, hostfile: str, num_nodes: int, nproc_per_node: int, include: # run with hostfile excluding the hosts selected colossalai run --hostfile --master_addr host1 --exclude host2 --nprocs_per_node 4 train.py """ - if not user_script.endswith('.py'): - click.echo(f'Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help') + if not user_script.endswith(".py"): + click.echo(f"Error: invalid Python file {user_script}. Did you use a wrong option? 
Try colossalai run --help") exit() args_dict = locals() diff --git a/colossalai/cli/launcher/hostinfo.py b/colossalai/cli/launcher/hostinfo.py index 2a6a111e4d72..684f64f59d28 100644 --- a/colossalai/cli/launcher/hostinfo.py +++ b/colossalai/cli/launcher/hostinfo.py @@ -1,5 +1,4 @@ import socket -from typing import List class HostInfo: @@ -34,7 +33,7 @@ def is_host_localhost(hostname: str, port: str = None) -> None: """ if port is None: - port = 22 # no port specified, lets just use the ssh port + port = 22 # no port specified, lets just use the ssh port # socket.getfqdn("127.0.0.1") does not return localhost # on some users' machines @@ -50,7 +49,7 @@ def is_host_localhost(hostname: str, port: str = None) -> None: return localaddrs == targetaddrs def __str__(self): - return f'hostname: {self.hostname}, port: {self.port}' + return f"hostname: {self.hostname}, port: {self.port}" def __repr__(self): return self.__str__() diff --git a/colossalai/cli/launcher/multinode_runner.py b/colossalai/cli/launcher/multinode_runner.py index 85b241e96292..99c4db406844 100644 --- a/colossalai/cli/launcher/multinode_runner.py +++ b/colossalai/cli/launcher/multinode_runner.py @@ -7,8 +7,13 @@ from .hostinfo import HostInfo, HostInfoList -def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Connection, - send_conn: mp_connection.Connection, env: dict) -> None: +def run_on_host( + hostinfo: HostInfo, + workdir: str, + recv_conn: mp_connection.Connection, + send_conn: mp_connection.Connection, + env: dict, +) -> None: """ Use fabric connection to execute command on local or remote hosts. @@ -22,14 +27,14 @@ def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Conne fab_conn = fabric.Connection(hostinfo.hostname, port=hostinfo.port) finish = False - env_msg = ' '.join([f'{k}=\"{v}\"' for k, v in env.items()]) + env_msg = " ".join([f'{k}="{v}"' for k, v in env.items()]) # keep listening until exit while not finish: # receive cmd cmds = recv_conn.recv() - if cmds == 'exit': + if cmds == "exit": # exit from the loop finish = True break @@ -46,12 +51,12 @@ def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Conne else: # execute on the remote machine fab_conn.run(cmds, hide=False) - send_conn.send('success') + send_conn.send("success") except Exception as e: click.echo( f"Error: failed to run {cmds} on {hostinfo.hostname}, is localhost: {hostinfo.is_local_host}, exception: {e}" ) - send_conn.send('failure') + send_conn.send("failure") # shutdown send_conn.send("finish") @@ -96,8 +101,7 @@ def send(self, hostinfo: HostInfo, cmd: str) -> None: cmd (str): the command to execute """ - assert hostinfo.hostname in self.master_send_conns, \ - f'{hostinfo} is not found in the current connections' + assert hostinfo.hostname in self.master_send_conns, f"{hostinfo} is not found in the current connections" conn = self.master_send_conns[hostinfo.hostname] conn.send(cmd) @@ -107,7 +111,7 @@ def stop_all(self) -> None: """ for hostname, conn in self.master_send_conns.items(): - conn.send('exit') + conn.send("exit") def recv_from_all(self) -> dict: """ diff --git a/colossalai/cli/launcher/run.py b/colossalai/cli/launcher/run.py index d2d02811ac9d..7ca8ee90386c 100644 --- a/colossalai/cli/launcher/run.py +++ b/colossalai/cli/launcher/run.py @@ -12,7 +12,7 @@ from .multinode_runner import MultiNodeRunner # Constants that define our syntax -NODE_SEP = ',' +NODE_SEP = "," def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList: @@ -34,12 +34,12 @@ def 
fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList: click.echo(f"Error: Unable to find the hostfile, no such file: {hostfile_path}") exit() - with open(hostfile_path, 'r') as fd: + with open(hostfile_path, "r") as fd: device_pool = HostInfoList() for line in fd.readlines(): line = line.strip() - if line == '': + if line == "": # skip empty lines continue @@ -56,7 +56,7 @@ def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList: def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str=None) -> HostInfoList: - '''Parse an inclusion or exclusion string and filter a hostfile dictionary. + """Parse an inclusion or exclusion string and filter a hostfile dictionary. Examples: include_str="worker-0,worker-1" will execute jobs only on worker-0 and worker-1. @@ -69,7 +69,7 @@ def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str Returns: filtered_hosts (HostInfoList): filtered hosts after inclusion/exclusion - ''' + """ # Ensure include/exclude are mutually exclusive if include_str and exclude_str: @@ -136,16 +136,16 @@ def _arg_dict_to_list(arg_dict): for k, v in arg_dict.items(): if v: - ret.append(f'--{k}={v}') + ret.append(f"--{k}={v}") else: - ret.append(f'--{k}') + ret.append(f"--{k}") return ret if extra_launch_args: extra_launch_args_dict = dict() - for arg in extra_launch_args.split(','): - if '=' in arg: - k, v = arg.split('=') + for arg in extra_launch_args.split(","): + if "=" in arg: + k, v = arg.split("=") extra_launch_args_dict[k] = v else: extra_launch_args_dict[arg] = None @@ -158,9 +158,14 @@ def _arg_dict_to_list(arg_dict): if torch_version.minor < 9: cmd = [ - sys.executable, "-m", "torch.distributed.launch", f"--nproc_per_node={nproc_per_node}", - f"--master_addr={master_addr}", f"--master_port={master_port}", f"--nnodes={num_nodes}", - f"--node_rank={node_rank}" + sys.executable, + "-m", + "torch.distributed.launch", + f"--nproc_per_node={nproc_per_node}", + f"--master_addr={master_addr}", + f"--master_port={master_port}", + f"--nnodes={num_nodes}", + f"--node_rank={node_rank}", ] else: # extra launch args for torch distributed launcher with torch >= 1.9 @@ -174,17 +179,24 @@ def _arg_dict_to_list(arg_dict): if torch_version.minor < 10: cmd = [ - sys.executable, "-m", "torch.distributed.run", f"--nproc_per_node={nproc_per_node}", - f"--nnodes={num_nodes}", f"--node_rank={node_rank}" + sys.executable, + "-m", + "torch.distributed.run", + f"--nproc_per_node={nproc_per_node}", + f"--nnodes={num_nodes}", + f"--node_rank={node_rank}", ] else: cmd = [ - "torchrun", f"--nproc_per_node={nproc_per_node}", f"--nnodes={num_nodes}", f"--node_rank={node_rank}" + "torchrun", + f"--nproc_per_node={nproc_per_node}", + f"--nnodes={num_nodes}", + f"--node_rank={node_rank}", ] cmd += _arg_dict_to_list(default_torchrun_rdzv_args) cmd += _arg_dict_to_list(extra_launch_args) + [user_script] + user_args - cmd = ' '.join(cmd) + cmd = " ".join(cmd) return cmd @@ -248,18 +260,18 @@ def launch_multi_processes(args: Config) -> None: # run on local node if not hosts or hostfile is given # add local node to host info list active_device_pool = HostInfoList() - localhost_info = HostInfo(hostname='127.0.0.1', port=args.ssh_port) + localhost_info = HostInfo(hostname="127.0.0.1", port=args.ssh_port) active_device_pool.append(localhost_info) # launch distributed processes runner = MultiNodeRunner() - curr_path = os.path.abspath('.') + curr_path = os.path.abspath(".") # collect current path env env = dict() for k, v in 
os.environ.items(): # do not support multi-line env var - if v and '\n' not in v: + if v and "\n" not in v: env[k] = v # establish remote connection @@ -271,14 +283,16 @@ def launch_multi_processes(args: Config) -> None: # execute distributed launching command for node_id, hostinfo in enumerate(active_device_pool): - cmd = get_launch_command(master_addr=args.master_addr, - master_port=args.master_port, - nproc_per_node=args.nproc_per_node, - user_script=args.user_script, - user_args=args.user_args, - node_rank=node_id, - num_nodes=len(active_device_pool), - extra_launch_args=args.extra_launch_args) + cmd = get_launch_command( + master_addr=args.master_addr, + master_port=args.master_port, + nproc_per_node=args.nproc_per_node, + user_script=args.user_script, + user_args=args.user_args, + node_rank=node_id, + num_nodes=len(active_device_pool), + extra_launch_args=args.extra_launch_args, + ) runner.send(hostinfo=hostinfo, cmd=cmd) # start training diff --git a/colossalai/cluster/__init__.py b/colossalai/cluster/__init__.py index 44f571ca2501..b8176feb647b 100644 --- a/colossalai/cluster/__init__.py +++ b/colossalai/cluster/__init__.py @@ -3,4 +3,4 @@ from .process_group_manager import ProcessGroupManager from .process_group_mesh import ProcessGroupMesh -__all__ = ['DistCoordinator', 'ProcessGroupManager', 'DeviceMeshManager', 'ProcessGroupMesh'] +__all__ = ["DistCoordinator", "ProcessGroupManager", "DeviceMeshManager", "ProcessGroupMesh"] diff --git a/colossalai/cluster/device_mesh_manager.py b/colossalai/cluster/device_mesh_manager.py index 8754baa19792..e35aca5f4d7e 100644 --- a/colossalai/cluster/device_mesh_manager.py +++ b/colossalai/cluster/device_mesh_manager.py @@ -10,13 +10,14 @@ @dataclass class DeviceMeshInfo: - ''' + """ This class is used to store the information used to initialize the device mesh. Args: physical_ids (List[int]): The physical ids of the current booster. For example, if we have the last 4 GPUs on a 8-devices cluster, then the physical ids should be [4, 5, 6, 7]. mesh_shapes (List[Union[torch.Size, List[int], Tuple[int]]]): The shape of the mesh. For example, if we have 4 GPUs and we want to use 2D mesh with mesh shape [2, 2], then the mesh shape should be [2, 2]. - ''' + """ + physical_ids: List[int] mesh_shape: Union[torch.Size, List[int], Tuple[int]] = None @@ -24,16 +25,18 @@ def __post_init__(self): if self.mesh_shape is not None: world_size = len(self.physical_ids) mesh_shape_numel = torch.Size(self.mesh_shape).numel() - assert world_size == mesh_shape_numel, f'the numel of mesh_shape should be equal to world size, but got {world_size} != {mesh_shape_numel}' + assert ( + world_size == mesh_shape_numel + ), f"the numel of mesh_shape should be equal to world size, but got {world_size} != {mesh_shape_numel}" def initialize_device_mesh(device_mesh_info: DeviceMeshInfo): - ''' + """ This method is used to initialize the device mesh. Args: device_mesh_info (DeviceMeshInfo): The information used to initialize device mesh. 
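As a quick reference for the `DeviceMeshInfo` dataclass reformatted above, the scenario described in its docstring (the last 4 GPUs of an 8-device cluster arranged as a 2 x 2 logical mesh) looks like this:

```python
from colossalai.cluster.device_mesh_manager import DeviceMeshInfo

# __post_init__ checks that the numel of mesh_shape equals len(physical_ids)
info = DeviceMeshInfo(physical_ids=[4, 5, 6, 7], mesh_shape=(2, 2))
```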
- ''' + """ # parse the device mesh info physical_devices = device_mesh_info.physical_ids physical_mesh = torch.tensor(physical_devices) @@ -67,13 +70,13 @@ def create_device_mesh(self, name, device_mesh_info: DeviceMeshInfo) -> DeviceMe Args: name (str): name of the device mesh device_mesh_info (DeviceMeshInfo): the information used to initialize the device mesh - """ + """ if name not in self.device_mesh_store: device_mesh = initialize_device_mesh(device_mesh_info) self.device_mesh_store[name] = device_mesh return device_mesh else: - raise ValueError(f'Device mesh {name} already exists.') + raise ValueError(f"Device mesh {name} already exists.") def get(self, name: str) -> DeviceMesh: """ @@ -88,7 +91,7 @@ def get(self, name: str) -> DeviceMesh: if name in self.device_mesh_store: return self.device_mesh_store[name] else: - raise ValueError(f'Device mesh {name} does not exist.') + raise ValueError(f"Device mesh {name} does not exist.") def destroy(self, name: str) -> None: """ @@ -103,7 +106,7 @@ def destroy(self, name: str) -> None: dist.destroy_process_group(pg) del self.device_mesh_store[name] else: - raise ValueError(f'Device mesh {name} does not exist.') + raise ValueError(f"Device mesh {name} does not exist.") def destroy_all(self): """ diff --git a/colossalai/cluster/dist_coordinator.py b/colossalai/cluster/dist_coordinator.py index 3ee364ec3364..5b66e88717ba 100644 --- a/colossalai/cluster/dist_coordinator.py +++ b/colossalai/cluster/dist_coordinator.py @@ -36,12 +36,13 @@ class in the whole program. """ def __init__(self): - assert dist.is_initialized( - ), 'Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first.' + assert ( + dist.is_initialized() + ), "Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first." self._rank = dist.get_rank() self._world_size = dist.get_world_size() # this is often passed by launchers such as torchrun - self._local_rank = os.environ.get('LOCAL_RANK', -1) + self._local_rank = os.environ.get("LOCAL_RANK", -1) @property def rank(self) -> int: @@ -59,7 +60,9 @@ def _assert_local_rank_set(self): """ Assert that the local rank is set. This is often passed by launchers such as torchrun. """ - assert self.local_rank >= 0, 'The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process.' + assert ( + self.local_rank >= 0 + ), "The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process." def is_master(self, process_group: ProcessGroup = None) -> bool: """ @@ -183,7 +186,6 @@ def on_master_only(self, process_group: ProcessGroup = None): # define an inner function def decorator(func): - @functools.wraps(func) def wrapper(*args, **kwargs): if is_master: diff --git a/colossalai/cluster/process_group_manager.py b/colossalai/cluster/process_group_manager.py index e52661846f3e..68106b503126 100644 --- a/colossalai/cluster/process_group_manager.py +++ b/colossalai/cluster/process_group_manager.py @@ -19,7 +19,7 @@ class ProcessGroupManager: def __init__(self): self.pg_store = dict() - def create_process_group(self, name: str, ranks: List[int], backend: str = 'nccl') -> ProcessGroup: + def create_process_group(self, name: str, ranks: List[int], backend: str = "nccl") -> ProcessGroup: """ Get a process group by name. If the process group does not exist, it will be created. 
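A small usage sketch for `ProcessGroupManager` based on the methods shown in this diff. It assumes `torch.distributed` has already been initialized with at least two ranks (for example via `colossalai.launch_from_torch`), and the group name used here is an arbitrary placeholder.

```python
from colossalai.cluster import ProcessGroupManager

# assumes torch.distributed is already initialized with world size >= 2
pg_manager = ProcessGroupManager()
tp_group = pg_manager.create_process_group(name="tp_group_0_1", ranks=[0, 1], backend="nccl")
assert pg_manager.get("tp_group_0_1") is tp_group  # groups are cached by name
pg_manager.destroy("tp_group_0_1")
```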
@@ -36,7 +36,7 @@ def create_process_group(self, name: str, ranks: List[int], backend: str = 'nccl self.pg_store[name] = pg return pg else: - raise ValueError(f'Process group {name} already exists.') + raise ValueError(f"Process group {name} already exists.") def get(self, name: str) -> ProcessGroup: """ @@ -51,7 +51,7 @@ def get(self, name: str) -> ProcessGroup: if name in self.pg_store: return self.pg_store[name] else: - raise ValueError(f'Process group {name} does not exist.') + raise ValueError(f"Process group {name} does not exist.") def destroy(self, name: str) -> None: """ @@ -64,7 +64,7 @@ def destroy(self, name: str) -> None: dist.destroy_process_group(self.pg_store[name]) del self.pg_store[name] else: - raise ValueError(f'Process group {name} does not exist.') + raise ValueError(f"Process group {name} does not exist.") def destroy_all(self) -> None: """ diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py index 623160003767..3885bc962561 100644 --- a/colossalai/cluster/process_group_mesh.py +++ b/colossalai/cluster/process_group_mesh.py @@ -94,7 +94,7 @@ def unravel(rank: int, shape: Tuple[int, ...]) -> Tuple[int, ...]: return np.unravel_index(rank, shape) @staticmethod - def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...], mode: str = 'raise') -> int: + def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...], mode: str = "raise") -> int: """Convert a coordinate to a rank. mode: ['raise', 'wrap', 'clip'], see https://numpy.org/doc/stable/reference/generated/numpy.ravel_multi_index.html. with wrap, index out of range would be wrapped around. @@ -141,8 +141,9 @@ def get_ranks_in_group(self, group: ProcessGroup) -> List[int]: return list(self._group_to_ranks[group]) @staticmethod - def get_coords_along_axis(base_coord: Tuple[int, ...], axis: int, - indices_at_axis: List[int]) -> List[Tuple[int, ...]]: + def get_coords_along_axis( + base_coord: Tuple[int, ...], axis: int, indices_at_axis: List[int] + ) -> List[Tuple[int, ...]]: """Get coordinates along the given axis. Args: @@ -155,13 +156,12 @@ def get_coords_along_axis(base_coord: Tuple[int, ...], axis: int, """ coords_in_group = [] for idx in indices_at_axis: - coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1:]) + coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1 :]) return coords_in_group - def create_group_along_axis(self, - axis: int, - indices_at_axis: Optional[List[int]] = None, - backend: Optional[str] = None) -> ProcessGroup: + def create_group_along_axis( + self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None + ) -> ProcessGroup: """Create all process groups along the given axis, and return the one which the current process belongs to. Args: @@ -186,10 +186,9 @@ def create_group_along_axis(self, target_group = group return target_group - def get_group_along_axis(self, - axis: int, - indices_at_axis: Optional[List[int]] = None, - backend: Optional[str] = None) -> ProcessGroup: + def get_group_along_axis( + self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None + ) -> ProcessGroup: """Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created. 
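`ProcessGroupMesh.get_coords_along_axis` is a static method, so its behavior can be checked without building a mesh or initializing distributed. A minimal illustration with an assumed 4-wide axis:

```python
from colossalai.cluster import ProcessGroupMesh

# all coordinates that share every index with base_coord except along axis 1
coords = ProcessGroupMesh.get_coords_along_axis(base_coord=(1, 0), axis=1, indices_at_axis=[0, 1, 2, 3])
print(coords)  # [(1, 0), (1, 1), (1, 2), (1, 3)]
```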
Args: diff --git a/colossalai/context/__init__.py b/colossalai/context/__init__.py index eb6d5d05a008..ab57301bb910 100644 --- a/colossalai/context/__init__.py +++ b/colossalai/context/__init__.py @@ -3,6 +3,6 @@ # from .moe_context import MOE_CONTEXT __all__ = [ - 'Config', - 'ConfigException', + "Config", + "ConfigException", ] diff --git a/colossalai/context/config.py b/colossalai/context/config.py index 8903707708df..05a2e4bf044a 100644 --- a/colossalai/context/config.py +++ b/colossalai/context/config.py @@ -5,6 +5,7 @@ import sys from importlib.machinery import SourceFileLoader from pathlib import Path + from colossalai.logging import get_dist_logger @@ -41,7 +42,7 @@ def _add_item(self, key, value): self.__setattr__(key, value) def update(self, config): - assert isinstance(config, (Config, dict)), 'can only update dictionary or Config objects.' + assert isinstance(config, (Config, dict)), "can only update dictionary or Config objects." for k, v in config.items(): self._add_item(k, v) return self @@ -66,11 +67,11 @@ def from_file(filename: str): elif isinstance(filename, Path): filepath = filename.absolute() - assert filepath.exists(), f'{filename} is not found, please check your configuration path' + assert filepath.exists(), f"{filename} is not found, please check your configuration path" # check extension extension = filepath.suffix - assert extension == '.py', 'only .py files are supported' + assert extension == ".py", "only .py files are supported" # import the config as module remove_path = False @@ -86,13 +87,13 @@ def from_file(filename: str): config = Config() for k, v in module.__dict__.items(): - if k.startswith('__') or inspect.ismodule(v) or inspect.isclass(v): + if k.startswith("__") or inspect.ismodule(v) or inspect.isclass(v): continue else: config._add_item(k, v) logger = get_dist_logger() - logger.debug('variables which starts with __, is a module or class declaration are omitted in config file') + logger.debug("variables which starts with __, is a module or class declaration are omitted in config file") # remove module del sys.modules[module_name] diff --git a/colossalai/context/moe_context.py b/colossalai/context/moe_context.py index b6e3b52017b2..066dfc7222e1 100644 --- a/colossalai/context/moe_context.py +++ b/colossalai/context/moe_context.py @@ -9,14 +9,13 @@ def _check_sanity(): from colossalai.legacy.core import global_context as gpc + if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1: - raise NotImplementedError("Moe is not compatible with tensor or " - "pipeline parallel at present.") + raise NotImplementedError("Moe is not compatible with tensor or " "pipeline parallel at present.") class MoeParallelInfo: - """Moe parallelism information, storing parallel sizes and groups. 
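For context on `Config.from_file` as reformatted above: it loads a plain `.py` file and exposes its top-level variables as attributes, skipping dunder names, modules, and classes. A hedged usage sketch with a hypothetical config file:

```python
from colossalai.context import Config

# ./train_config.py (hypothetical) contains a single line:  batch_size = 64
config = Config.from_file("./train_config.py")
print(config.batch_size)  # 64
```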
- """ + """Moe parallelism information, storing parallel sizes and groups.""" def __init__(self, ep_size: int, dp_size: int): _check_sanity() @@ -61,9 +60,11 @@ def setup(self, seed: int, use_kernel_optim: bool = True): self.world_size = dist.get_world_size() from colossalai.legacy.core import global_context as gpc - self.max_ep_size = gpc.config.get('max_ep_size', self.world_size) - assert self.world_size % self.max_ep_size == 0, \ - "Maximum expert parallel size must be a factor of the number of GPUs" + + self.max_ep_size = gpc.config.get("max_ep_size", self.world_size) + assert ( + self.world_size % self.max_ep_size == 0 + ), "Maximum expert parallel size must be a factor of the number of GPUs" self.min_dp_size = self.world_size // self.max_ep_size # Enabling kernel optimization may raise error in some cases @@ -71,6 +72,7 @@ def setup(self, seed: int, use_kernel_optim: bool = True): self.use_kernel_optim = use_kernel_optim from .random import moe_set_seed + moe_set_seed(seed) self.has_setup = True @@ -88,11 +90,13 @@ def get_info(self, num_experts: int) -> Tuple[int, MoeParallelInfo]: number of local experts, the MoeParallelInfo of the current ep_size """ - gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater - lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less + gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater + lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less - assert gt_flag or lt_flag, "Automatic experts placement dose not not support expert number" \ - " is not a multiple of ep size or vice versa." + assert gt_flag or lt_flag, ( + "Automatic experts placement dose not not support expert number" + " is not a multiple of ep size or vice versa." + ) # If the number of experts is greater than maximum expert parallel size. a.k.a ep_size, # there are multiple experts in each GPU and each GPU has different experts diff --git a/colossalai/context/singleton_meta.py b/colossalai/context/singleton_meta.py index 8ca335119d52..3088b0dffaac 100644 --- a/colossalai/context/singleton_meta.py +++ b/colossalai/context/singleton_meta.py @@ -16,6 +16,7 @@ def __call__(cls, *args, **kwargs): instance = super().__call__(*args, **kwargs) cls._instances[cls] = instance else: - assert len(args) == 0 and len( - kwargs) == 0, f'{cls.__name__} is a singleton class and a instance has been created.' + assert ( + len(args) == 0 and len(kwargs) == 0 + ), f"{cls.__name__} is a singleton class and a instance has been created." return cls._instances[cls] diff --git a/colossalai/device/__init__.py b/colossalai/device/__init__.py index 689189998c3f..34a7d2526fda 100644 --- a/colossalai/device/__init__.py +++ b/colossalai/device/__init__.py @@ -1,4 +1,4 @@ from .alpha_beta_profiler import AlphaBetaProfiler from .calc_pipeline_strategy import alpa_dp -__all__ = ['AlphaBetaProfiler', 'alpa_dp'] +__all__ = ["AlphaBetaProfiler", "alpa_dp"] diff --git a/colossalai/device/alpha_beta_profiler.py b/colossalai/device/alpha_beta_profiler.py index f4e6cfffbcdf..88520b2a14d0 100644 --- a/colossalai/device/alpha_beta_profiler.py +++ b/colossalai/device/alpha_beta_profiler.py @@ -13,7 +13,7 @@ class AlphaBetaProfiler: - ''' + """ Profile alpha and beta value for a given device list. 
Usage: @@ -27,17 +27,19 @@ class AlphaBetaProfiler: (1, 4): (1.9010603427886962e-05, 7.077968863788975e-11), (1, 5): (1.9807778298854827e-05, 6.928845708992215e-11), (4, 5): (1.8681809306144713e-05, 4.7522367291330524e-12), (1, 0): (1.9641406834125518e-05, 4.74049549614719e-12), (4, 0): (1.9506998360157013e-05, 6.97421973297474e-11), (5, 0): (2.293858677148819e-05, 7.129930361393644e-11), (4, 1): (1.9010603427886962e-05, 7.077968863788975e-11), (5, 1): (1.9807778298854827e-05, 6.928845708992215e-11), (5, 4): (1.8681809306144713e-05, 4.7522367291330524e-12)} - ''' - - def __init__(self, - physical_devices: List[int], - alpha_beta_dict: Dict[Tuple[int, int], Tuple[float, float]] = None, - ctype: str = 'a', - warmup: int = 5, - repeat: int = 25, - latency_iters: int = 5, - homogeneous_tolerance: float = 0.1): - ''' + """ + + def __init__( + self, + physical_devices: List[int], + alpha_beta_dict: Dict[Tuple[int, int], Tuple[float, float]] = None, + ctype: str = "a", + warmup: int = 5, + repeat: int = 25, + latency_iters: int = 5, + homogeneous_tolerance: float = 0.1, + ): + """ Args: physical_devices: A list of device id, each element inside it is the global rank of that device. alpha_beta_dict: A dict which maps a process group to alpha-beta value pairs. @@ -45,7 +47,7 @@ def __init__(self, warmup: Number of warmup iterations. repeat: Number of iterations to measure. latency_iters: Number of iterations to measure latency. - ''' + """ self.physical_devices = physical_devices self.ctype = ctype self.world_size = len(physical_devices) @@ -123,7 +125,7 @@ def _profile(self, process_group, pg_handler, nbytes): return (None, None) def profile_latency(self, process_group, pg_handler): - ''' + """ This function is used to profile the latency of the given process group with a series of bytes. Args: @@ -132,7 +134,7 @@ def profile_latency(self, process_group, pg_handler): Returns: latency: None if the latency is not measured, otherwise the median of the latency_list. - ''' + """ latency_list = [] for i in range(self.latency_iters): nbytes = int(BYTE << i) @@ -148,26 +150,26 @@ def profile_latency(self, process_group, pg_handler): return latency def profile_bandwidth(self, process_group, pg_handler, maxbytes=(1 * GB)): - ''' + """ This function is used to profile the bandwidth of the given process group. Args: process_group: A tuple of global rank of the process group. pg_handler: The handler of the process group. - ''' + """ (_, bandwidth) = self._profile(process_group, pg_handler, maxbytes) return bandwidth def profile_ab(self): - ''' + """ This method is used to profiling the alpha and beta value for a given device list. Returns: alpha_beta_dict: A dict which maps process group to its alpha and beta value. - ''' + """ alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = {} rank = dist.get_rank() - global_pg_handler = dist.new_group(self.physical_devices) + dist.new_group(self.physical_devices) def get_max_nbytes(process_group: Tuple[int], pg_handler: dist.ProcessGroup): assert rank in process_group @@ -208,7 +210,7 @@ def get_max_nbytes(process_group: Tuple[int], pg_handler: dist.ProcessGroup): return alpha_beta_dict def search_best_logical_mesh(self): - ''' + """ This method is used to search the best logical mesh for the given device list. 
The best logical mesh is searched in following steps: @@ -232,19 +234,19 @@ def search_best_logical_mesh(self): >>> best_logical_mesh = profiler.search_best_logical_mesh() >>> print(best_logical_mesh) [[0, 1], [2, 3]] - ''' + """ def _power_of_two(integer): return integer & (integer - 1) == 0 def _detect_homogeneous_device(alpha_beta_dict): - ''' + """ This function is used to detect whether the devices in the alpha_beta_dict are homogeneous. Note: we assume that the devices in the alpha_beta_dict are homogeneous if the beta value of the devices are in range of [(1 - self.homogeneous_tolerance), (1 + self.homogeneous_tolerance)] * base_beta. - ''' + """ homogeneous_device_dict: Dict[float, List[Tuple[int]]] = {} for process_group, (_, beta) in alpha_beta_dict.items(): if homogeneous_device_dict is None: @@ -254,7 +256,8 @@ def _detect_homogeneous_device(alpha_beta_dict): match_beta = None for beta_value in homogeneous_device_dict.keys(): if beta <= beta_value * (1 + self.homogeneous_tolerance) and beta >= beta_value * ( - 1 - self.homogeneous_tolerance): + 1 - self.homogeneous_tolerance + ): match_beta = beta_value break @@ -267,9 +270,9 @@ def _detect_homogeneous_device(alpha_beta_dict): return homogeneous_device_dict def _check_contain_all_devices(homogeneous_group: List[Tuple[int]]): - ''' + """ This function is used to check whether the homogeneous_group contains all physical devices. - ''' + """ flatten_mesh = [] for process_group in homogeneous_group: flatten_mesh.extend(process_group) @@ -277,9 +280,9 @@ def _check_contain_all_devices(homogeneous_group: List[Tuple[int]]): return len(non_duplicated_flatten_mesh) == len(self.physical_devices) def _construct_largest_ring(homogeneous_group: List[Tuple[int]]): - ''' + """ This function is used to construct the largest ring in the homogeneous_group for each rank. - ''' + """ # Construct the ring ring = [] ranks_in_ring = [] @@ -300,7 +303,9 @@ def _construct_largest_ring(homogeneous_group: List[Tuple[int]]): check_rank = check_rank_list.pop() for process_group in homogeneous_group: if check_rank in process_group: - rank_to_append = process_group[0] if process_group[1] == check_rank else process_group[1] + rank_to_append = ( + process_group[0] if process_group[1] == check_rank else process_group[1] + ) if rank_to_append not in ring_for_rank: stable_status = False rank_to_check_list.append(rank_to_append) @@ -314,7 +319,7 @@ def _construct_largest_ring(homogeneous_group: List[Tuple[int]]): assert _power_of_two(self.world_size) power_of_two = int(math.log2(self.world_size)) median = power_of_two // 2 - balanced_logical_mesh_shape = (2**median, 2**(power_of_two - median)) + balanced_logical_mesh_shape = (2**median, 2 ** (power_of_two - median)) row_size, column_size = balanced_logical_mesh_shape[0], balanced_logical_mesh_shape[1] balanced_logical_mesh = [] for row_index in range(row_size): @@ -348,7 +353,7 @@ def _construct_largest_ring(homogeneous_group: List[Tuple[int]]): return best_logical_mesh def extract_alpha_beta_for_device_mesh(self): - ''' + """ Extract the mesh_alpha list and mesh_beta list based on the best logical mesh, which will be used to initialize the device mesh. 
@@ -360,7 +365,7 @@ def extract_alpha_beta_for_device_mesh(self): [2.5917552411556242e-05, 0.00010312341153621673] >>> print(mesh_beta) [5.875573704655635e-11, 4.7361584445959614e-12] - ''' + """ best_logical_mesh = self.search_best_logical_mesh() first_axis = [row[0] for row in best_logical_mesh] diff --git a/colossalai/device/calc_pipeline_strategy.py b/colossalai/device/calc_pipeline_strategy.py index 4ab72dfe60f0..72d432701ada 100644 --- a/colossalai/device/calc_pipeline_strategy.py +++ b/colossalai/device/calc_pipeline_strategy.py @@ -10,8 +10,10 @@ def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"): while i <= num_devices_per_host: i *= 2 p += 1 - assert pow(2, p) == num_devices_per_host, ("Only supports the cases where num_devices_per_host is power of two, " - f"while now num_devices_per_host = {num_devices_per_host}") + assert pow(2, p) == num_devices_per_host, ( + "Only supports the cases where num_devices_per_host is power of two, " + f"while now num_devices_per_host = {num_devices_per_host}" + ) if mode == "alpa": for i in range(p + 1): submesh_choices.append((1, pow(2, i))) @@ -24,18 +26,19 @@ def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"): return submesh_choices -def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, compute_cost, max_stage_cost, - best_configs): +def alpa_dp_impl( + num_layers, num_devices, num_microbatches, submesh_choices, compute_cost, max_stage_cost, best_configs +): """Implementation of Alpa DP for pipeline strategy - Paper reference: https://www.usenix.org/system/files/osdi22-zheng-lianmin.pdf + Paper reference: https://www.usenix.org/system/files/osdi22-zheng-lianmin.pdf - Arguments: - num_layers: K - num_devices: N*M - num_microbatches: B - submesh_choices: List[(n_i,m_i)] - compute_cost: t_intra - """ + Arguments: + num_layers: K + num_devices: N*M + num_microbatches: B + submesh_choices: List[(n_i,m_i)] + compute_cost: t_intra + """ # For f, layer ID start from 0 # f[#pipeline stages, layer id that is currently being considered, number of devices used] f = np.full((num_layers + 1, num_layers + 1, num_devices + 1), np.inf, dtype=np.float32) @@ -54,7 +57,7 @@ def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, com for i in range(num_layers, k, -1): stage_cost = compute_cost[k, i, m] new_cost = f[s - 1, k, d - n_submesh_devices] + stage_cost - if (stage_cost <= max_stage_cost and new_cost < f[s, k, d]): + if stage_cost <= max_stage_cost and new_cost < f[s, k, d]: f[s, k, d] = new_cost f_stage_max[s, k, d] = max(stage_cost, f_stage_max[s - 1, i, d - n_submesh_devices]) f_argmin[s, k, d] = (i, m, best_configs[k, i, m]) @@ -75,34 +78,34 @@ def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, com res = [] while current_s > 0 and current_layer < num_layers and current_devices > 0: - next_start_layer, submesh_choice, autosharding_choice = (f_argmin[current_s, current_layer, current_devices]) + next_start_layer, submesh_choice, autosharding_choice = f_argmin[current_s, current_layer, current_devices] assert next_start_layer != -1 and current_devices != -1 res.append(((current_layer, next_start_layer), submesh_choice, autosharding_choice)) current_s -= 1 current_layer = next_start_layer current_devices -= np.prod(np.array(submesh_choices[submesh_choice])) - assert (current_s == 0 and current_layer == num_layers and current_devices == 0) + assert current_s == 0 and current_layer == num_layers and current_devices == 0 return total_cost, res -def 
alpa_dp(num_layers, - num_devices, - num_microbatches, - submesh_choices, - num_autosharding_configs, - compute_cost, - gap=1e-6): +def alpa_dp( + num_layers, num_devices, num_microbatches, submesh_choices, num_autosharding_configs, compute_cost, gap=1e-6 +): """Alpa auto stage dynamic programming. - Code reference: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/stage_construction.py + Code reference: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/stage_construction.py Arguments: submesh_choices: List[(int,int)] num_autosharding_configs: Max number of t_intra(start_layer, end_layer, LogicalMesh) compute_cost: np.array(num_layers,num_layers,num_submesh_choices,num_autosharding_configs) """ - assert np.shape(compute_cost) == (num_layers, num_layers, len(submesh_choices), - num_autosharding_configs), "Cost shape wrong." + assert np.shape(compute_cost) == ( + num_layers, + num_layers, + len(submesh_choices), + num_autosharding_configs, + ), "Cost shape wrong." all_possible_stage_costs = np.sort(np.unique(compute_cost)) best_cost = np.inf best_solution = None @@ -117,8 +120,9 @@ def alpa_dp(num_layers, break if max_stage_cost - last_max_stage_cost < gap: continue - cost, solution = alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, best_compute_cost, - max_stage_cost, best_configs) + cost, solution = alpa_dp_impl( + num_layers, num_devices, num_microbatches, submesh_choices, best_compute_cost, max_stage_cost, best_configs + ) if cost < best_cost: best_cost = cost best_solution = solution diff --git a/colossalai/device/device_mesh.py b/colossalai/device/device_mesh.py index f41af1161be1..72f199203a9d 100644 --- a/colossalai/device/device_mesh.py +++ b/colossalai/device/device_mesh.py @@ -40,14 +40,16 @@ class DeviceMesh: _DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo"} - def __init__(self, - physical_mesh_id: torch.Tensor, - mesh_shape: torch.Size = None, - logical_mesh_id: torch.Tensor = None, - mesh_alpha: List[float] = None, - mesh_beta: List[float] = None, - init_process_group: bool = False, - device: str = 'cuda'): + def __init__( + self, + physical_mesh_id: torch.Tensor, + mesh_shape: torch.Size = None, + logical_mesh_id: torch.Tensor = None, + mesh_alpha: List[float] = None, + mesh_beta: List[float] = None, + init_process_group: bool = False, + device: str = "cuda", + ): # ============================ # Physical & Logical Mesh IDs # ============================ @@ -57,9 +59,10 @@ def __init__(self, # logical mesh ids can be obtained via two ways # 1. provide physical mesh id and provide mesh shape # 2. directly supply the logical mesh id - assert mesh_shape is None or logical_mesh_id is None, \ - "Only one of mesh_shape and logical_mesh_id can be specified." \ + assert mesh_shape is None or logical_mesh_id is None, ( + "Only one of mesh_shape and logical_mesh_id can be specified." "Logical mesh IDs are obtained from either mesh_shape + physical_mesh_id or directly from the user-supplied logical_mesh_id" + ) if logical_mesh_id is None: self._mesh_shape = mesh_shape @@ -71,12 +74,15 @@ def __init__(self, # ensure two things: # 1. logical and physical mesh IDs should contain the same elements # 2. there is no duplicate IDs in each mesh, e.g. [2, 2] is not allowed - assert torch.equal(torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)), \ - "physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id." 
- assert torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel(), \ - "Found duplicate IDs in the physical_mesh_id and this is not allowed, please check your physical_mesh_id again." - assert torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel(), \ - "Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again." + assert torch.equal( + torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id) + ), "physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id." + assert ( + torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel() + ), "Found duplicate IDs in the physical_mesh_id and this is not allowed, please check your physical_mesh_id again." + assert ( + torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel() + ), "Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again." # =============================================== # coefficient for alpha-beta communication model @@ -92,8 +98,9 @@ def __init__(self, self.mesh_beta = tuple(mesh_beta) # ensure the alpha and beta have the same shape - assert len(self.mesh_alpha) == len(self.mesh_beta), \ - "mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again." + assert len(self.mesh_alpha) == len( + self.mesh_beta + ), "mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again." # ========================= # Device for Process Group @@ -109,8 +116,9 @@ def __init__(self, # : [ , , , ...] # } self._global_to_local_rank_mapping = dict() - self._init_global_to_logical_rank_mapping(mapping=self._global_to_local_rank_mapping, - tensor=self.logical_mesh_id) + self._init_global_to_logical_rank_mapping( + mapping=self._global_to_local_rank_mapping, tensor=self.logical_mesh_id + ) # create process group self._process_group_dict = {} @@ -194,8 +202,9 @@ def _get_device_by_backend(process_group): device_list = [_get_device_by_backend(pg) for pg in process_group] # make sure all devices are the same - assert all([device == device_list[0] for device in device_list]), \ - "All devices should be the same, please check your input process groups are created with the same distributed backend." + assert all( + [device == device_list[0] for device in device_list] + ), "All devices should be the same, please check your input process groups are created with the same distributed backend." # create a fake physical mesh id # as we only get the process group associated with the current process, @@ -270,7 +279,7 @@ def __deepcopy__(self, memo) -> "DeviceMesh": result = cls.__new__(cls) memo[id(self)] = result for k, v in self.__dict__.items(): - if k != '_process_group_dict': + if k != "_process_group_dict": setattr(result, k, __import__("copy").deepcopy(v, memo)) else: # process group cannot be copied @@ -278,10 +287,9 @@ def __deepcopy__(self, memo) -> "DeviceMesh": setattr(result, k, v) return result - def _init_global_to_logical_rank_mapping(self, - mapping: Dict, - tensor: torch.Tensor, - index_list: List[int] = []) -> Dict[int, List[int]]: + def _init_global_to_logical_rank_mapping( + self, mapping: Dict, tensor: torch.Tensor, index_list: List[int] = [] + ) -> Dict[int, List[int]]: """ Build a global rank to local rank mapping for each process group in different axis in the logical device mesh. 
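For context on the reformatted `DeviceMesh` constructor above: it accepts either a mesh shape or an explicit logical mesh id, never both, and checks that the logical and physical meshes contain the same ranks without duplicates. Below is a minimal usage sketch, assuming four ranks and that `DeviceMesh` is importable from `colossalai.device.device_mesh` as the file path suggests; `init_process_group` is left at `False` so no distributed backend needs to be launched.

```python
import torch
from colossalai.device.device_mesh import DeviceMesh

physical_mesh_id = torch.arange(4)

# Option 1: physical mesh id plus a mesh shape; the logical mesh id is derived by reshaping.
mesh_from_shape = DeviceMesh(physical_mesh_id, mesh_shape=(2, 2), init_process_group=False)

# Option 2: supply the logical mesh id directly; it must contain the same ranks, without duplicates.
logical_mesh_id = torch.tensor([[0, 1], [2, 3]])
mesh_from_ids = DeviceMesh(physical_mesh_id, logical_mesh_id=logical_mesh_id, init_process_group=False)
```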
@@ -311,15 +319,19 @@ def _init_global_to_logical_rank_mapping(self, self._init_global_to_logical_rank_mapping(mapping, inner_tensor, index_list + [index]) def init_logical_process_group(self): - ''' + """ This method is used to initialize the logical process groups which will be used in communications among logical device mesh. Note: if init_process_group set to False, you have to call this method manually. Otherwise, the communication related function, such as ShapeConsistencyManager.apply will raise errors. - ''' + """ # sanity check - assert dist.is_initialized, "The torch.distributed should be initialized before calling init_logical_process_group" - assert not self._is_initialized, "The logical process group has been initialized, do not call init_logical_process_group twice" + assert ( + dist.is_initialized + ), "The torch.distributed should be initialized before calling init_logical_process_group" + assert ( + not self._is_initialized + ), "The logical process group has been initialized, do not call init_logical_process_group twice" # update the global rank of the current process self._global_rank_of_current_process = dist.get_rank() @@ -389,7 +401,7 @@ def global_rank_to_local_rank(self, rank: int, axis: int = None) -> Union[List[i return local_ranks def _collate_global_ranks_in_same_process_group(self, global_rank): - ''' + """ Give a global rank and return all global ranks involved in its associated process group in each axis. Example: @@ -414,7 +426,7 @@ def _collate_global_ranks_in_same_process_group(self, global_rank): 0: [0, 4, 8, 12], 1: [0, 1, 2, 3] # } - ''' + """ # We have init the global rank to local rank by calling _init_global_to_logical_rank_mapping # for self._global_to_local_rank_mapping # the key is the global rank @@ -437,7 +449,6 @@ def _collate_global_ranks_in_same_process_group(self, global_rank): # in the same process group in the given axis # the _local_rank refers to the local rank of the current process for _local_rank in range(self.logical_mesh_id.shape[dim]): - # if this dimension is not initialized yet, # initialize it with an empty array if dim not in processes_in_the_same_process_group: @@ -478,29 +489,37 @@ def flatten(self): flatten_mesh_shape_size = len(self._mesh_shape) flatten_mesh_shape = [self.num_devices] - return DeviceMesh(self._physical_mesh_id, - tuple(flatten_mesh_shape), - mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1), - mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1), - init_process_group=self._init_process_group) + return DeviceMesh( + self._physical_mesh_id, + tuple(flatten_mesh_shape), + mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1), + mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1), + init_process_group=self._init_process_group, + ) def all_gather_cost(self, num_bytes, mesh_dim): num_devices = self.logical_mesh_id.shape[mesh_dim] - return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + - 0.1) + return self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + 0.1 def all_reduce_cost(self, num_bytes, mesh_dim): num_devices = self.logical_mesh_id.shape[mesh_dim] - return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * 2 * (num_devices - 1) / num_devices * num_bytes + - 0.01) + return ( + self.mesh_alpha[mesh_dim] + + self.mesh_beta[mesh_dim] * 2 * (num_devices - 1) / num_devices * num_bytes + + 0.01 + ) def reduce_scatter_cost(self, num_bytes, mesh_dim): 
num_devices = self.logical_mesh_id.shape[mesh_dim] - return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + - 0.001) + return ( + self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + 0.001 + ) def all_to_all_cost(self, num_bytes, mesh_dim): num_devices = self.logical_mesh_id.shape[mesh_dim] penalty_factor = num_devices / 2.0 - return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * - (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001) + return ( + self.mesh_alpha[mesh_dim] + + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + + 0.001 + ) diff --git a/colossalai/fx/_compatibility.py b/colossalai/fx/_compatibility.py index 0444a4816273..4d40d5badfd0 100644 --- a/colossalai/fx/_compatibility.py +++ b/colossalai/fx/_compatibility.py @@ -2,16 +2,14 @@ import torch -TORCH_MAJOR = int(torch.__version__.split('.')[0]) -TORCH_MINOR = int(torch.__version__.split('.')[1]) +TORCH_MAJOR = int(torch.__version__.split(".")[0]) +TORCH_MINOR = int(torch.__version__.split(".")[1]) if TORCH_MAJOR == 1 and TORCH_MINOR < 12: META_COMPATIBILITY = False elif TORCH_MAJOR == 1 and TORCH_MINOR == 12: - from . import _meta_regist_12 META_COMPATIBILITY = True elif TORCH_MAJOR == 1 and TORCH_MINOR == 13: - from . import _meta_regist_13 META_COMPATIBILITY = True elif TORCH_MAJOR == 2: META_COMPATIBILITY = True @@ -36,7 +34,7 @@ def decorator(func): else: def wrapper(*args, **kwargs): - raise RuntimeError(f'Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}') + raise RuntimeError(f"Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}") return wrapper diff --git a/colossalai/fx/_meta_regist_12.py b/colossalai/fx/_meta_regist_12.py index 52e8d63ae543..63f88682e85a 100644 --- a/colossalai/fx/_meta_regist_12.py +++ b/colossalai/fx/_meta_regist_12.py @@ -3,7 +3,7 @@ # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml # for more meta_registrations -from typing import Callable, List, Optional, Tuple, Union +from typing import List, Optional, Union import torch from torch.utils._pytree import tree_map @@ -16,13 +16,11 @@ def register_meta(op, register_dispatcher=True): - def wrapper(f): - def add_func(op): meta_table[op] = f if register_dispatcher: - name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__) + name = op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__ try: meta_lib.impl(name, f) except: @@ -48,7 +46,6 @@ def meta_conv( output_padding: List[int], groups: int, ): - def _formula(ln: int, p: int, d: int, k: int, s: int) -> int: """ Formula to apply to calculate the length of some dimension of the output @@ -125,7 +122,8 @@ def calc_conv_nd_return_shape( kernel_size[i], stride[i], output_padding_list[i], - )) + ) + ) else: ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i])) return ret_shape @@ -159,22 +157,42 @@ def pick_memory_format(): shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation) out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out)) mem_fmt = pick_memory_format() - out = out.to(memory_format=mem_fmt) # type: ignore[call-overload] + out = out.to(memory_format=mem_fmt) # type: ignore[call-overload] return out @register_meta(aten._convolution.default) -def 
meta_conv_1(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int], - padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int, - *extra_args): +def meta_conv_1( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: List[int], + padding: List[int], + dilation: List[int], + is_transposed: bool, + output_padding: List[int], + groups: int, + *extra_args, +): out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups) return out @register_meta(aten.convolution_backward.default) -def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride, - padding, dilation, transposed, output_padding, groups, output_mask): - return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device='meta') +def meta_conv_backward( + grad_output: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + bias_sizes, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + output_mask, +): + return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device="meta") # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp @@ -208,7 +226,6 @@ def meta_cuda_rnn( batch_sizes, dropout_state, ): - is_input_packed = len(batch_sizes) != 0 if is_input_packed: seq_length = len(batch_sizes) @@ -224,8 +241,11 @@ def meta_cuda_rnn( if is_input_packed: out_shape = [batch_sizes_sum, out_size * num_directions] else: - out_shape = ([mini_batch, seq_length, out_size * - num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions]) + out_shape = ( + [mini_batch, seq_length, out_size * num_directions] + if batch_first + else [seq_length, mini_batch, out_size * num_directions] + ) output = input.new_empty(out_shape) cell_shape = [num_layers * num_directions, mini_batch, hidden_size] @@ -242,18 +262,20 @@ def meta_cuda_rnn( # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp @register_meta(aten._cudnn_rnn_backward.default) -def meta_cudnn_rnn_backward(input: torch.Tensor, - weight: torch.Tensor, - weight_stride0: int, - hx: torch.Tensor, - cx: Optional[torch.Tensor] = None, - *args, - **kwargs): +def meta_cudnn_rnn_backward( + input: torch.Tensor, + weight: torch.Tensor, + weight_stride0: int, + hx: torch.Tensor, + cx: Optional[torch.Tensor] = None, + *args, + **kwargs, +): print(input, weight, hx, cx) grad_input = torch.empty_like(input) grad_weight = torch.empty_like(weight) grad_hx = torch.empty_like(hx) - grad_cx = torch.empty_like(cx) if cx is not None else torch.empty((), device='meta') + grad_cx = torch.empty_like(cx) if cx is not None else torch.empty((), device="meta") return grad_input, grad_weight, grad_hx, grad_cx @@ -298,15 +320,25 @@ def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, traini n_input = input.size(1) output = torch.empty_like(input) - running_mean = torch.empty((n_input), device='meta') - running_var = torch.empty((n_input), device='meta') + running_mean = torch.empty((n_input), device="meta") + running_var = torch.empty((n_input), device="meta") return output, running_mean, running_var # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp @register_meta(aten.native_batch_norm_backward.default) -def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, 
running_mean, running_var, save_mean, - save_invstd, train, eps, output_mask): +def meta_bn_backward( + dY: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + running_mean, + running_var, + save_mean, + save_invstd, + train, + eps, + output_mask, +): dX = torch.empty_like(input) dgamma = torch.empty_like(weight) dbeta = torch.empty_like(weight) @@ -319,9 +351,9 @@ def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, n_input = input.size(1) output = torch.empty_like(input) - running_mean = torch.empty((n_input), device='meta') - running_var = torch.empty((n_input), device='meta') - reserve = torch.empty((0), dtype=torch.uint8, device='meta') + running_mean = torch.empty((n_input), device="meta") + running_var = torch.empty((n_input), device="meta") + reserve = torch.empty((0), dtype=torch.uint8, device="meta") return output, running_mean, running_var, reserve @@ -330,8 +362,17 @@ def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, # in training mode (evaluation mode batchnorm has a different algorithm), # which is why this doesn't accept a 'training' parameter. @register_meta(aten.cudnn_batch_norm_backward.default) -def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, - save_mean, save_invstd, eps, reserve): +def meta_cudnn_bn_backward( + dY: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + running_mean, + running_var, + save_mean, + save_invstd, + eps, + reserve, +): dX = torch.empty_like(input) dgamma = torch.empty_like(weight) dbeta = torch.empty_like(weight) @@ -345,15 +386,16 @@ def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps): n_input = input.size(1) output = torch.empty_like(input) - running_mean = torch.empty((bs, n_input, 1), device='meta') - running_var = torch.empty((bs, n_input, 1), device='meta') + running_mean = torch.empty((bs, n_input, 1), device="meta") + running_var = torch.empty((bs, n_input, 1), device="meta") return output, running_mean, running_var # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp @register_meta(aten.native_layer_norm_backward.default) -def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, - grad_input_mask): +def meta_ln_backward( + dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, grad_input_mask +): dX = torch.empty_like(input) dgamma = torch.empty_like(weight) dbeta = torch.empty_like(bias) @@ -397,16 +439,19 @@ def meta_index_Tensor(self, indices): result: List[Optional[torch.Tensor]] = [] for i, index in enumerate(indices): if index is not None: - assert index.dtype in [torch.long, torch.int8, torch.bool],\ - "tensors used as indices must be long, byte or bool tensors" + assert index.dtype in [ + torch.long, + torch.int8, + torch.bool, + ], "tensors used as indices must be long, byte or bool tensors" if index.dtype in [torch.int8, torch.bool]: nonzero = index.nonzero() k = len(result) assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}" for j in range(index.ndim): - assert index.shape[j] == self.shape[ - k + - j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}" + assert ( + index.shape[j] == self.shape[k + j] + ), f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}" 
result.append(nonzero.select(1, j)) else: result.append(index) @@ -482,12 +527,15 @@ def meta_index_Tensor(self, indices): # ============================== Embedding ========================================= # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp @register_meta(aten.embedding_dense_backward.default) -def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, - scale_grad_by_freq): - return torch.empty((num_weights, grad_output.size(-1)), - dtype=grad_output.dtype, - device=grad_output.device, - layout=grad_output.layout) +def meta_embedding_dense_backward( + grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, scale_grad_by_freq +): + return torch.empty( + (num_weights, grad_output.size(-1)), + dtype=grad_output.dtype, + device=grad_output.device, + layout=grad_output.layout, + ) # ============================== Dropout =========================================== diff --git a/colossalai/fx/codegen/activation_checkpoint_codegen.py b/colossalai/fx/codegen/activation_checkpoint_codegen.py index 33b164800262..dfb5754d71c1 100644 --- a/colossalai/fx/codegen/activation_checkpoint_codegen.py +++ b/colossalai/fx/codegen/activation_checkpoint_codegen.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable, List, Tuple +from typing import Any, Dict, Iterable, List, Tuple import torch @@ -18,6 +18,7 @@ magic_methods, ) from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg + CODEGEN_AVAILABLE = True except: from torch.fx.graph import ( @@ -32,12 +33,13 @@ magic_methods, ) from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg + CODEGEN_AVAILABLE = False if CODEGEN_AVAILABLE: - __all__ = ['ActivationCheckpointCodeGen'] + __all__ = ["ActivationCheckpointCodeGen"] else: - __all__ = ['python_code_with_activation_checkpoint'] + __all__ = ["python_code_with_activation_checkpoint"] def _gen_saved_tensors_hooks(): @@ -125,15 +127,14 @@ def _find_ckpt_regions(nodes: List[Node]): Find the checkpoint regions given a list of consecutive nodes. The outputs will be list of tuples, each tuple is in the form of (start_index, end_index). 
""" - ckpt_nodes = [] ckpt_regions = [] start = -1 end = -1 current_region = None for idx, node in enumerate(nodes): - if 'activation_checkpoint' in node.meta: - act_ckpt_label = node.meta['activation_checkpoint'] + if "activation_checkpoint" in node.meta: + act_ckpt_label = node.meta["activation_checkpoint"] # this activation checkpoint label is not set yet # meaning this is the first node of the activation ckpt region @@ -150,7 +151,7 @@ def _find_ckpt_regions(nodes: List[Node]): current_region = act_ckpt_label start = idx end = -1 - elif current_region is not None and not 'activation_checkpoint' in node.meta: + elif current_region is not None and not "activation_checkpoint" in node.meta: # used to check the case below # node ckpt states = [ckpt, ckpt, non-ckpt] end = idx - 1 @@ -178,8 +179,8 @@ def _find_offload_regions(nodes: List[Node]): current_region = None for idx, node in enumerate(nodes): - if 'activation_offload' in node.meta and isinstance(node.meta['activation_offload'], Iterable): - act_offload_label = node.meta['activation_offload'] + if "activation_offload" in node.meta and isinstance(node.meta["activation_offload"], Iterable): + act_offload_label = node.meta["activation_offload"] if current_region == None: current_region = act_offload_label @@ -226,9 +227,9 @@ def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reen """ Generate the checkpoint function call code text """ - outputs = ', '.join(output_vars) - inputs = ', '.join(input_vars) - return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})' + outputs = ", ".join(output_vars) + inputs = ", ".join(input_vars) + return f"{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})" def _end_of_ckpt(node: Node, check_idx: int) -> bool: @@ -240,9 +241,9 @@ def _end_of_ckpt(node: Node, check_idx: int) -> bool: Returns: bool """ - if 'activation_checkpoint' in node.meta: - if isinstance(node.meta['activation_checkpoint'], list): - return node.meta['activation_checkpoint'][check_idx] == None + if "activation_checkpoint" in node.meta: + if isinstance(node.meta["activation_checkpoint"], list): + return node.meta["activation_checkpoint"][check_idx] == None else: return False else: @@ -260,11 +261,11 @@ def _find_nested_ckpt_regions(nodes, check_idx=0): current_region = None for idx, node in enumerate(nodes): - if 'activation_checkpoint' in node.meta: - if isinstance(node.meta['activation_checkpoint'], int): - act_ckpt_label = node.meta['activation_checkpoint'] + if "activation_checkpoint" in node.meta: + if isinstance(node.meta["activation_checkpoint"], int): + act_ckpt_label = node.meta["activation_checkpoint"] else: - act_ckpt_label = node.meta['activation_checkpoint'][check_idx] + act_ckpt_label = node.meta["activation_checkpoint"][check_idx] # this activation checkpoint label is not set yet # meaning this is the first node of the activation ckpt region @@ -298,13 +299,9 @@ def _find_nested_ckpt_regions(nodes, check_idx=0): return ckpt_regions -def emit_ckpt_func(body, - ckpt_func, - node_list: List[Node], - emit_node_func, - delete_unused_value_func, - level=0, - in_ckpt=False): +def emit_ckpt_func( + body, ckpt_func, node_list: List[Node], emit_node_func, delete_unused_value_func, level=0, in_ckpt=False +): """Emit ckpt function in nested way Args: body: forward code, in recursive calls, this part will be 
checkpoint @@ -321,17 +318,17 @@ def emit_ckpt_func(body, inputs, outputs = _find_input_and_output_nodes(node_list) # if the current checkpoint function use int as label, using old generation method - if isinstance(node_list[0].meta['activation_checkpoint'], int): - label = node_list[0].meta['activation_checkpoint'] + if isinstance(node_list[0].meta["activation_checkpoint"], int): + label = node_list[0].meta["activation_checkpoint"] ckpt_fn_def = _gen_ckpt_fn_def(label, inputs) - ckpt_func.append(f'{ckpt_fn_def}\n') + ckpt_func.append(f"{ckpt_fn_def}\n") for node in node_list: emit_node_func(node, ckpt_func) - ckpt_func[-1] = ' ' + ckpt_func[-1] + ckpt_func[-1] = " " + ckpt_func[-1] delete_unused_value_func(node, ckpt_func) - ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') - activation_offload = node_list[0].meta.get('activation_offload', False) + ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n") + activation_offload = node_list[0].meta.get("activation_offload", False) usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) usage += "\n" body.append(usage) @@ -340,12 +337,12 @@ def emit_ckpt_func(body, else: # label given by each layer, e.g. if you are currently at level [0, 1, 1] # the label will be '0_1_1' - label = "_".join([str(idx) for idx in node_list[0].meta['activation_checkpoint'][:level + 1]]) + label = "_".join([str(idx) for idx in node_list[0].meta["activation_checkpoint"][: level + 1]]) ckpt_fn_def = _gen_ckpt_fn_def(label, inputs) - ckpt_func.append(f'{ckpt_fn_def}\n') + ckpt_func.append(f"{ckpt_fn_def}\n") # if there is more level to fetch - if level + 1 < len(node_list[0].meta['activation_checkpoint']): + if level + 1 < len(node_list[0].meta["activation_checkpoint"]): ckpt_regions = _find_nested_ckpt_regions(node_list, level + 1) start_idx = [item[0] for item in ckpt_regions] end_idx = [item[1] for item in ckpt_regions] @@ -358,38 +355,45 @@ def emit_ckpt_func(body, break if node_idx in start_idx: - ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1] - emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func, - delete_unused_value_func, level + 1, True) + ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1] + emit_ckpt_func( + ckpt_func, + ckpt_func_buffer, + ckpt_node_list, + emit_node_func, + delete_unused_value_func, + level + 1, + True, + ) node_idx += len(ckpt_node_list) else: node = node_list[node_idx] emit_node_func(node, ckpt_func) - ckpt_func[-1] = ' ' + ckpt_func[-1] + ckpt_func[-1] = " " + ckpt_func[-1] delete_unused_value_func(node, ckpt_func) node_idx += 1 - ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') + ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n") ckpt_func += ckpt_func_buffer - activation_offload = node_list[0].meta.get('activation_offload', False) - usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n' + activation_offload = node_list[0].meta.get("activation_offload", False) + usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + "\n" if in_ckpt: - usage = ' ' + usage + usage = " " + usage body.append(usage) # last level else: for node in node_list: emit_node_func(node, ckpt_func) - ckpt_func[-1] = ' ' + ckpt_func[-1] + ckpt_func[-1] = " " + ckpt_func[-1] delete_unused_value_func(node, ckpt_func) - ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') - activation_offload = node_list[0].meta.get('activation_offload', False) - usage = _gen_ckpt_usage(label, 
activation_offload, inputs, outputs, False) + '\n' + ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n") + activation_offload = node_list[0].meta.get("activation_offload", False) + usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + "\n" if in_ckpt: - usage = ' ' + usage + usage = " " + usage body.append(usage) @@ -420,7 +424,7 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod # find the input and output var names for each offload region for idx, (start, end) in enumerate(offload_regions): - offload_node_list = node_list[start:end + 1] + offload_node_list = node_list[start : end + 1] inputs, outputs = _find_input_and_output_nodes(offload_node_list) offload_inputs.append(inputs) offload_outputs.append(outputs) @@ -436,7 +440,7 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod # process ckpt_regions if node_idx in start_idx: - ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1] + ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1] emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func) node_idx += len(ckpt_node_list) @@ -470,7 +474,7 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod if within_offload_region: emit_node_func(node, body) - body[-1] = ' ' + body[-1] + body[-1] = " " + body[-1] delete_unused_value_func(node, body) else: @@ -508,14 +512,14 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, # find the input and output var names for each region for idx, (start, end) in enumerate(ckpt_regions): - ckpt_node_list = node_list[start:end + 1] + ckpt_node_list = node_list[start : end + 1] inputs, outputs = _find_input_and_output_nodes(ckpt_node_list) input_vars.append(inputs) output_vars.append(outputs) # find the input and output var names for each offload region for idx, (start, end) in enumerate(offload_regions): - offload_node_list = node_list[start:end + 1] + offload_node_list = node_list[start : end + 1] inputs, outputs = _find_input_and_output_nodes(offload_node_list) offload_inputs.append(inputs) offload_outputs.append(outputs) @@ -527,7 +531,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, if idx in start_idx: label = start_idx.index(idx) ckpt_fn_def = _gen_ckpt_fn_def(label, input_vars[label]) - ckpt_func.append(f'{ckpt_fn_def}\n') + ckpt_func.append(f"{ckpt_fn_def}\n") within_ckpt_region = True if idx in offload_starts: @@ -559,12 +563,12 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, # NOTE: currently we separate body and ckpt_func definition if within_ckpt_region: emit_node_func(node, ckpt_func) - ckpt_func[-1] = ' ' + ckpt_func[-1] + ckpt_func[-1] = " " + ckpt_func[-1] delete_unused_value_func(node, ckpt_func) elif within_offload_region: emit_node_func(node, body) - body[-1] = ' ' + body[-1] + body[-1] = " " + body[-1] delete_unused_value_func(node, body) else: @@ -576,13 +580,13 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, # generate return statement label = end_idx.index(idx) return_statement = _gen_ckpt_output(output_vars[label]) - return_statement = f' {return_statement}\n\n' + return_statement = f" {return_statement}\n\n" ckpt_func.append(return_statement) # we need to check if the checkpoint need to offload the input start_node_idx = start_idx[label] - if 'activation_offload' in node_list[start_node_idx].meta: - 
activation_offload = node_list[start_node_idx].meta['activation_offload'] + if "activation_offload" in node_list[start_node_idx].meta: + activation_offload = node_list[start_node_idx].meta["activation_offload"] else: activation_offload = False @@ -594,8 +598,8 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, if input_node.op != "placeholder": non_leaf_input = 1 for user in input_node.users: - if 'activation_checkpoint' in user.meta: - if user.meta['activation_checkpoint'] == label: + if "activation_checkpoint" in user.meta: + if user.meta["activation_checkpoint"] == label: if user.op == "call_module": if hasattr(user.graph.owning_module.get_submodule(user.target), "inplace"): use_reentrant = not user.graph.owning_module.get_submodule(user.target).inplace @@ -610,7 +614,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, # generate checkpoint function call in a new line usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label], use_reentrant) - usage += '\n' + usage += "\n" body.append(usage) within_ckpt_region = False @@ -621,7 +625,6 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, if CODEGEN_AVAILABLE: class ActivationCheckpointCodeGen(CodeGen): - def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode: free_vars: List[str] = [] body: List[str] = [] @@ -629,7 +632,7 @@ def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> Py wrapped_fns: Dict[str, None] = {} # Wrap string in list to pass by reference - maybe_return_annotation: List[str] = [''] + maybe_return_annotation: List[str] = [""] def add_global(name_hint: str, obj: Any): """Add an obj to be tracked as a global. @@ -637,7 +640,7 @@ def add_global(name_hint: str, obj: Any): Graph, like functions or types. Returns: the global name that should be used to reference 'obj' in generated source. """ - if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device + if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device # HACK: workaround for how torch custom ops are registered. We # can't import them like normal modules so they must retain their # fully qualified name. @@ -662,16 +665,16 @@ def add_global(name_hint: str, obj: Any): def type_repr(o: Any): if o == (): # Empty tuple is used for empty tuple type annotation Tuple[()] - return '()' + return "()" typename = _type_repr(o) - if hasattr(o, '__origin__'): + if hasattr(o, "__origin__"): # This is a generic type, e.g. typing.List[torch.Tensor] origin_type = _origin_type_map.get(o.__origin__, o.__origin__) origin_typename = add_global(_type_repr(origin_type), origin_type) - if hasattr(o, '__args__'): + if hasattr(o, "__args__"): # Assign global names for each of the inner type variables. args = [type_repr(arg) for arg in o.__args__] @@ -690,19 +693,18 @@ def type_repr(o: Any): return add_global(typename, o) def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str: - def _get_repr(arg): # Handle NamedTuples (if it has `_fields`) via add_global. 
- if isinstance(arg, tuple) and hasattr(arg, '_fields'): + if isinstance(arg, tuple) and hasattr(arg, "_fields"): qualified_name = _get_qualified_name(type(arg)) global_name = add_global(qualified_name, type(arg)) return f"{global_name}{repr(tuple(arg))}" return repr(arg) - args_s = ', '.join(_get_repr(a) for a in args) - kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items()) + args_s = ", ".join(_get_repr(a) for a in args) + kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items()) if args_s and kwargs_s: - return f'{args_s}, {kwargs_s}' + return f"{args_s}, {kwargs_s}" return args_s or kwargs_s # Run through reverse nodes and record the first instance of a use @@ -728,90 +730,101 @@ def delete_unused_values(user: Node, body): not used in the remainder of the code are freed and the memory usage of the code is optimal. """ - if user.op == 'placeholder': + if user.op == "placeholder": return - if user.op == 'output': - body.append('\n') + if user.op == "output": + body.append("\n") return nodes_to_delete = user_to_last_uses.get(user, []) if len(nodes_to_delete): - to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None']) - body.append(f'; {to_delete_str}\n') + to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"]) + body.append(f"; {to_delete_str}\n") else: - body.append('\n') + body.append("\n") # NOTE: we add a variable to distinguish body and ckpt_func def emit_node(node: Node, body): - maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' - if node.op == 'placeholder': + maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}" + if node.op == "placeholder": assert isinstance(node.target, str) - maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}' - free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}') - raw_name = node.target.replace('*', '') + maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}" + free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}") + raw_name = node.target.replace("*", "") if raw_name != repr(node): - body.append(f'{repr(node)} = {raw_name}\n') + body.append(f"{repr(node)} = {raw_name}\n") return - elif node.op == 'call_method': + elif node.op == "call_method": assert isinstance(node.target, str) body.append( - f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}' - f'({_format_args(node.args[1:], node.kwargs)})') + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}" + f"({_format_args(node.args[1:], node.kwargs)})" + ) return - elif node.op == 'call_function': + elif node.op == "call_function": assert callable(node.target) # pretty print operators - if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods: + if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods: assert isinstance(node.args, tuple) - body.append(f'{repr(node)}{maybe_type_annotation} = ' - f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}') + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}" + ) return # pretty print inplace operators; required for jit.script to work properly # not currently supported in normal FX graphs, but generated by torchdynamo - if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods: - 
body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; ' - f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}') + if node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods: + body.append( + f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; " + f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}" + ) return qualified_name = _get_qualified_name(node.target) global_name = add_global(qualified_name, node.target) # special case for getattr: node.args could be 2-argument or 3-argument # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value - if global_name == 'getattr' and \ - isinstance(node.args, tuple) and \ - isinstance(node.args[1], str) and \ - node.args[1].isidentifier() and \ - len(node.args) == 2: + if ( + global_name == "getattr" + and isinstance(node.args, tuple) + and isinstance(node.args[1], str) + and node.args[1].isidentifier() + and len(node.args) == 2 + ): body.append( - f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}') + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}" + ) return body.append( - f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})') - if node.meta.get('is_wrapped', False): + f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})" + ) + if node.meta.get("is_wrapped", False): wrapped_fns.setdefault(global_name) return - elif node.op == 'call_module': + elif node.op == "call_module": assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = ' - f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})') + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})" + ) return - elif node.op == 'get_attr': + elif node.op == "get_attr": assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}') + body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}") return - elif node.op == 'output': + elif node.op == "output": if node.type is not None: maybe_return_annotation[0] = f" -> {type_repr(node.type)}" body.append(self.generate_output(node.args[0])) return - raise NotImplementedError(f'node: {node.op} {node.target}') + raise NotImplementedError(f"node: {node.op} {node.target}") # Modified for activation checkpointing ckpt_func = [] # if any node has a list of labels for activation_checkpoint, we # will use nested type of activation checkpoint codegen - if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in nodes): + if any(isinstance(node.meta.get("activation_checkpoint", None), Iterable) for node in nodes): emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values) else: emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values) @@ -820,13 +833,13 @@ def emit_node(node: Node, body): # If the Graph has no non-placeholder nodes, no lines for the body # have been emitted. 
To continue to have valid Python code, emit a # single pass statement - body.append('pass\n') + body.append("pass\n") if len(wrapped_fns) > 0: - wrap_name = add_global('wrap', torch.fx.wrap) - wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns]) + wrap_name = add_global("wrap", torch.fx.wrap) + wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns]) else: - wrap_stmts = '' + wrap_stmts = "" if self._body_transformer: body = self._body_transformer(body) @@ -837,11 +850,11 @@ def emit_node(node: Node, body): # as we need colossalai.utils.checkpoint, we need to import colossalai # in forward function prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) - prologue = ''.join(ckpt_func) + prologue + prologue = "".join(ckpt_func) + prologue prologue = prologue - code = ''.join(body) - code = '\n'.join(' ' + line for line in code.split('\n')) + code = "".join(body) + code = "\n".join(" " + line for line in code.split("\n")) fn_code = f""" {wrap_stmts} {prologue} @@ -861,7 +874,7 @@ def python_code_with_activation_checkpoint(self, root_module: str, namespace: _N wrapped_fns: Dict[str, None] = {} # Wrap string in list to pass by reference - maybe_return_annotation: List[str] = [''] + maybe_return_annotation: List[str] = [""] def add_global(name_hint: str, obj: Any): """Add an obj to be tracked as a global. @@ -869,7 +882,7 @@ def add_global(name_hint: str, obj: Any): Graph, like functions or types. Returns: the global name that should be used to reference 'obj' in generated source. """ - if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device + if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device # HACK: workaround for how torch custom ops are registered. We # can't import them like normal modules so they must retain their # fully qualified name. @@ -894,12 +907,12 @@ def add_global(name_hint: str, obj: Any): def type_repr(o: Any): if o == (): # Empty tuple is used for empty tuple type annotation Tuple[()] - return '()' + return "()" typename = _type_repr(o) # This is a generic type, e.g. typing.List[torch.Tensor] - if hasattr(o, '__origin__'): + if hasattr(o, "__origin__"): origin_type = _origin_type_map.get(o.__origin__, o.__origin__) origin_typename = add_global(_type_repr(origin_type), origin_type) @@ -934,84 +947,94 @@ def delete_unused_values(user: Node, body): not used in the remainder of the code are freed and the memory usage of the code is optimal. 
""" - if user.op == 'placeholder': + if user.op == "placeholder": return - if user.op == 'output': - body.append('\n') + if user.op == "output": + body.append("\n") return nodes_to_delete = user_to_last_uses.get(user, []) if len(nodes_to_delete): - to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None']) - body.append(f'; {to_delete_str}\n') + to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"]) + body.append(f"; {to_delete_str}\n") else: - body.append('\n') + body.append("\n") # NOTE: we add a variable to distinguish body and ckpt_func def emit_node(node: Node, body): - maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' - if node.op == 'placeholder': + maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}" + if node.op == "placeholder": assert isinstance(node.target, str) - maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}' - free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}') - raw_name = node.target.replace('*', '') + maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}" + free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}") + raw_name = node.target.replace("*", "") if raw_name != repr(node): - body.append(f'{repr(node)} = {raw_name}\n') + body.append(f"{repr(node)} = {raw_name}\n") return - elif node.op == 'call_method': + elif node.op == "call_method": assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}' - f'({_format_args(node.args[1:], node.kwargs)})') + body.append( + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}" + f"({_format_args(node.args[1:], node.kwargs)})" + ) return - elif node.op == 'call_function': + elif node.op == "call_function": assert callable(node.target) # pretty print operators - if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods: + if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods: assert isinstance(node.args, tuple) - body.append(f'{repr(node)}{maybe_type_annotation} = ' - f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}') + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}" + ) return qualified_name = _get_qualified_name(node.target) global_name = add_global(qualified_name, node.target) # special case for getattr: node.args could be 2-argument or 3-argument # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value - if global_name == 'getattr' and \ - isinstance(node.args, tuple) and \ - isinstance(node.args[1], str) and \ - node.args[1].isidentifier() and \ - len(node.args) == 2: + if ( + global_name == "getattr" + and isinstance(node.args, tuple) + and isinstance(node.args[1], str) + and node.args[1].isidentifier() + and len(node.args) == 2 + ): body.append( - f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}') + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}" + ) return body.append( - f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})') - if node.meta.get('is_wrapped', False): + f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})" + ) + if 
node.meta.get("is_wrapped", False): wrapped_fns.setdefault(global_name) return - elif node.op == 'call_module': + elif node.op == "call_module": assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = ' - f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})') + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})" + ) return - elif node.op == 'get_attr': + elif node.op == "get_attr": assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}') + body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}") return - elif node.op == 'output': + elif node.op == "output": if node.type is not None: maybe_return_annotation[0] = f" -> {type_repr(node.type)}" if self._pytree_info is None: - body.append(f'return {repr(node.args[0])}') + body.append(f"return {repr(node.args[0])}") else: - body.append(f'return pytree.tree_unflatten({repr(node.args[0])}, self._out_spec)') + body.append(f"return pytree.tree_unflatten({repr(node.args[0])}, self._out_spec)") return - raise NotImplementedError(f'node: {node.op} {node.target}') + raise NotImplementedError(f"node: {node.op} {node.target}") # Modified for activation checkpointing ckpt_func = [] # if any node has a list of labels for activation_checkpoint, we # will use nested type of activation checkpoint codegen - if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in self.nodes): + if any(isinstance(node.meta.get("activation_checkpoint", None), Iterable) for node in self.nodes): emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values) else: emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values) @@ -1020,33 +1043,34 @@ def emit_node(node: Node, body): # If the Graph has no non-placeholder nodes, no lines for the body # have been emitted. To continue to have valid Python code, emit a # single pass statement - body.append('pass\n') + body.append("pass\n") if self._pytree_info is not None: orig_args = self._pytree_info.orig_args - has_orig_self = (orig_args[0] == 'self') + has_orig_self = orig_args[0] == "self" if has_orig_self: - free_vars.insert(0, 'self') - if len(free_vars) > 0: # pytree has placeholders in it + free_vars.insert(0, "self") + if len(free_vars) > 0: # pytree has placeholders in it body.insert( 0, - f"{', '.join(free_vars)}, = fx_pytree.tree_flatten_spec([{', '.join(orig_args)}], self._in_spec)\n") + f"{', '.join(free_vars)}, = fx_pytree.tree_flatten_spec([{', '.join(orig_args)}], self._in_spec)\n", + ) else: orig_args = free_vars if len(wrapped_fns) > 0: - wrap_name = add_global('wrap', torch.fx.wrap) - wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns]) + wrap_name = add_global("wrap", torch.fx.wrap) + wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns]) else: - wrap_stmts = '' + wrap_stmts = "" - ckpt_func = ''.join(ckpt_func) + ckpt_func = "".join(ckpt_func) # If the original function didn't have self as its first argument, we # would have added it. 
- if len(orig_args) == 0 or orig_args[0] != 'self': - orig_args.insert(0, 'self') - code = ''.join(body) - code = '\n'.join(' ' + line for line in code.split('\n')) + if len(orig_args) == 0 or orig_args[0] != "self": + orig_args.insert(0, "self") + code = "".join(body) + code = "\n".join(" " + line for line in code.split("\n")) # as we need colossalai.utils.checkpoint, we need to import colossalai # in forward function diff --git a/colossalai/fx/graph_module.py b/colossalai/fx/graph_module.py index ebb9975f27db..8429a9607f7a 100644 --- a/colossalai/fx/graph_module.py +++ b/colossalai/fx/graph_module.py @@ -1,32 +1,35 @@ import os import warnings from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Type, Union +from typing import Any, Dict, Optional, Union import torch import torch.nn as nn from torch.nn.modules.module import _addindent try: - from torch.fx.graph import Graph, PythonCode, _custom_builtins, _is_from_torch, _PyTreeCodeGen - from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _exec_with_source, _forward_from_src, _WrappedCall + from torch.fx.graph import Graph, PythonCode, _PyTreeCodeGen + from torch.fx.graph_module import GraphModule, _exec_with_source, _forward_from_src, _WrappedCall from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen + COLOGM = True except: from torch.fx.graph import Graph from torch.fx.graph_module import GraphModule + COLOGM = False if COLOGM: class ColoGraphModule(GraphModule): - - def __init__(self, - root: Union[torch.nn.Module, Dict[str, Any]], - graph: Graph, - class_name: str = 'GraphModule', - ckpt_codegen: bool = True): + def __init__( + self, + root: Union[torch.nn.Module, Dict[str, Any]], + graph: Graph, + class_name: str = "GraphModule", + ckpt_codegen: bool = True, + ): if ckpt_codegen: graph.set_codegen(ActivationCheckpointCodeGen()) super().__init__(root, graph, class_name) @@ -60,7 +63,7 @@ def recompile(self) -> PythonCode: if isinstance(self._graph._codegen, _PyTreeCodeGen): self._in_spec = self._graph._codegen.pytree_info.in_spec self._out_spec = self._graph._codegen.pytree_info.out_spec - python_code = self._graph.python_code(root_module='self') + python_code = self._graph.python_code(root_module="self") self._code = python_code.src # To split ckpt functions code and forward code @@ -83,8 +86,8 @@ def recompile(self) -> PythonCode: # bypass patching of torch.nn.Module.__call__ done while symbolic tracing. 
cls_call = cls.__call__ if "__call__" in vars(cls) else None - if '_wrapped_call' not in vars(cls): - cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined] + if "_wrapped_call" not in vars(cls): + cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined] def call_wrapped(self, *args, **kwargs): return self._wrapped_call(self, *args, **kwargs) @@ -108,7 +111,7 @@ def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModul """ folder = Path(folder) Path(folder).mkdir(exist_ok=True) - torch.save(self.state_dict(), folder / 'state_dict.pt') + torch.save(self.state_dict(), folder / "state_dict.pt") tab = " " * 4 # we add import colossalai here @@ -125,7 +128,13 @@ def __init__(self): def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: safe_reprs = [ - nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.BatchNorm1d, + nn.BatchNorm2d, + nn.BatchNorm3d, ] if type(module) in safe_reprs: return f"{module.__repr__()}" @@ -136,10 +145,10 @@ def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: for module_name, module in self.named_children(): module_str = _gen_model_repr(module_name, module) if module_str is None: - module_file = folder / f'{module_name}.pt' + module_file = folder / f"{module_name}.pt" torch.save(module, module_file) blobified_modules.append(module_name) - module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ') + module_repr = module.__repr__().replace("\r", " ").replace("\n", " ") module_str = f"torch.load(r'{module_file}') # {module_repr}" model_str += f"{tab*2}self.{module_name} = {module_str}\n" @@ -156,19 +165,20 @@ def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n" model_str += f"{_addindent(self.code, 4)}\n" - module_file = folder / 'module.py' + module_file = folder / "module.py" module_file.write_text(model_str) - init_file = folder / '__init__.py' - init_file.write_text('from .module import *') + init_file = folder / "__init__.py" + init_file.write_text("from .module import *") if len(blobified_modules) > 0: - warnings.warn("Was not able to save the following children modules as reprs -" - f"saved as pickled files instead: {blobified_modules}") + warnings.warn( + "Was not able to save the following children modules as reprs -" + f"saved as pickled files instead: {blobified_modules}" + ) else: class ColoGraphModule(GraphModule): - - def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'): + def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = "GraphModule"): super().__init__(root, graph, class_name) diff --git a/colossalai/fx/passes/adding_split_node_pass.py b/colossalai/fx/passes/adding_split_node_pass.py index 245ba5d776da..99c8faaa0cc6 100644 --- a/colossalai/fx/passes/adding_split_node_pass.py +++ b/colossalai/fx/passes/adding_split_node_pass.py @@ -1,8 +1,6 @@ import numpy as np import torch import tqdm -from torch.fx import symbolic_trace -from torch.fx.node import Node from colossalai.fx.passes.split_module import split_module @@ -29,15 +27,15 @@ def construct_blocks(gm: torch.fx.GraphModule, limit=0.01): accumulate_bwd_flop = 0 block_nodes = [] for node in gm.graph.nodes: - if 'block_split' in node.name: + if "block_split" in 
node.name: continue accumulate_fwd_flop += node.fwd_flop accumulate_bwd_flop += node.bwd_flop if accumulate_fwd_flop + accumulate_bwd_flop >= per_block_flop: with gm.graph.inserting_after(node): - block_node = gm.graph.create_node('call_function', block_split) - setattr(block_node, 'fwd_flop', accumulate_fwd_flop) - setattr(block_node, 'bwd_flop', accumulate_bwd_flop) + block_node = gm.graph.create_node("call_function", block_split) + setattr(block_node, "fwd_flop", accumulate_fwd_flop) + setattr(block_node, "bwd_flop", accumulate_bwd_flop) accumulate_fwd_flop = 0 accumulate_bwd_flop = 0 block_nodes.append(block_node) @@ -47,7 +45,7 @@ def construct_blocks(gm: torch.fx.GraphModule, limit=0.01): def remove_blocks(gm: torch.fx.GraphModule): for node in gm.graph.nodes: - if (node.op, node.target) == ('call_function', block_split): + if (node.op, node.target) == ("call_function", block_split): gm.graph.erase_node(node) @@ -55,8 +53,8 @@ def get_compute_costs(node_list): num_nodes = len(node_list) all_compute_cost = np.full((num_nodes, num_nodes), np.inf, dtype=np.float64) - for start in tqdm.tqdm(range(num_nodes), desc='start pos', position=0): - for end in tqdm.tqdm(range(start, num_nodes), desc='end pos', position=1, leave=False): + for start in tqdm.tqdm(range(num_nodes), desc="start pos", position=0): + for end in tqdm.tqdm(range(start, num_nodes), desc="end pos", position=1, leave=False): selected_flops = [(node_list[i].fwd_flop + node_list[i].bwd_flop) for i in range(start, end + 1)] all_compute_cost[start, end] = sum(selected_flops) @@ -78,12 +76,14 @@ def do_dp_split_gpipe_impl(num_nodes, num_stages, num_microbatches, compute_cost # record start node index for next stage in this partition f_argmin = np.full((num_stages + 1, num_nodes + 1), -1, dtype=np.int32) f[0, num_nodes] = 0 - for s in tqdm.tqdm(range(1, num_stages + 1), desc='stage', position=2, leave=False): # pylint: disable=too-many-nested-blocks - for i in tqdm.tqdm(range(num_nodes - 1, -1, -1), desc='start node', position=3, leave=False): - for k in tqdm.tqdm(range(num_nodes, i, -1), desc='mid node', position=4, leave=False): + for s in tqdm.tqdm( + range(1, num_stages + 1), desc="stage", position=2, leave=False + ): # pylint: disable=too-many-nested-blocks + for i in tqdm.tqdm(range(num_nodes - 1, -1, -1), desc="start node", position=3, leave=False): + for k in tqdm.tqdm(range(num_nodes, i, -1), desc="mid node", position=4, leave=False): stage_cost = compute_costs[i, k - 1] new_cost = f[s - 1, k] + stage_cost - if (stage_cost <= max_compute_cost and new_cost < f[s, i]): + if stage_cost <= max_compute_cost and new_cost < f[s, i]: f[s, i] = new_cost f_stage_max[s, i] = max(f_stage_max[s - 1, k], stage_cost) f_argmin[s, i] = k @@ -113,7 +113,7 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche best_cost = np.inf best_solution = None last_max_compute_cost = 0.0 - gap = 1e6 # temporary magic number, unit: flops + gap = 1e6 # temporary magic number, unit: flops for max_compute_cost in tqdm.tqdm(max_compute_costs): # Pruning to reduce search space. 
@@ -122,8 +122,9 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche if max_compute_cost - last_max_compute_cost < gap: continue - cost, solution = do_dp_split_gpipe_impl(len(node_list), num_stages, num_microbatches, compute_costs, - max_compute_cost) + cost, solution = do_dp_split_gpipe_impl( + len(node_list), num_stages, num_microbatches, compute_costs, max_compute_cost + ) if cost < best_cost: best_cost = cost @@ -137,15 +138,15 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche # split_mode: # 'node': fx_node # 'block': many fx_nodes construct a block -def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches: int, mode='block', block_limit=0.01): - assert mode in ['node', 'block'] +def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches: int, mode="block", block_limit=0.01): + assert mode in ["node", "block"] # nodes or blocks will be used in partition. node_list = [] - if mode == 'node': + if mode == "node": for node in gm.graph.nodes: node_list.append(node) - elif mode == 'block': + elif mode == "block": node_list = construct_blocks(gm, limit=block_limit) else: pass @@ -154,16 +155,16 @@ def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches best_cost, best_solution = do_dp_split_gpipe(node_list, compute_costs, pp_size, num_microbatches) - for (_, next_start_node) in best_solution: + for _, next_start_node in best_solution: if pp_size <= 1: break node = node_list[next_start_node] with gm.graph.inserting_before(node): - split_node = gm.graph.create_node('call_function', pipe_split) + split_node = gm.graph.create_node("call_function", pipe_split) pp_size -= 1 # remove block node if possible - if mode == 'block': + if mode == "block": remove_blocks(gm) gm.recompile() @@ -178,7 +179,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int): # To use avgcompute_split_pass, we need run meta_info_prop interpreter first. # If nodes don't have meta info, this pass will fall back to normal balanced split pass. 
check_node = list(mod_graph.nodes)[0] - if 'tensor_meta' not in check_node.meta: + if "tensor_meta" not in check_node.meta: return balanced_split_pass(gm, pp_size) total_fwd_flop = 0 @@ -190,7 +191,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int): for node in mod_graph.nodes: if pp_size <= 1: break - if 'pipe_split' in node.name: + if "pipe_split" in node.name: continue accumulate_fwd_flop += node.fwd_flop if accumulate_fwd_flop >= partition_flop: @@ -199,7 +200,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int): pp_size -= 1 partition_flop = total_fwd_flop // pp_size with mod_graph.inserting_after(node): - split_node = mod_graph.create_node('call_function', pipe_split) + split_node = mod_graph.create_node("call_function", pipe_split) gm.recompile() return gm @@ -218,12 +219,12 @@ def avgnode_split_pass(gm: torch.fx.GraphModule, pp_size: int): if accumulate_num_node >= avg_num_node: accumulate_num_node = 0 pp_size -= 1 - if node.next.op == 'output': + if node.next.op == "output": with mod_graph.inserting_before(node): - split_node = mod_graph.create_node('call_function', pipe_split) + split_node = mod_graph.create_node("call_function", pipe_split) else: with mod_graph.inserting_after(node): - split_node = mod_graph.create_node('call_function', pipe_split) + split_node = mod_graph.create_node("call_function", pipe_split) gm.recompile() return gm @@ -250,18 +251,18 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int): pp_size -= 1 # If the next node is output node, we will insert split annotation before # node to make sure there is at least one node in last partition. - if node.next.op == 'output': + if node.next.op == "output": with mod_graph.inserting_before(node): - split_node = mod_graph.create_node('call_function', pipe_split) + split_node = mod_graph.create_node("call_function", pipe_split) else: with mod_graph.inserting_after(node): - split_node = mod_graph.create_node('call_function', pipe_split) + split_node = mod_graph.create_node("call_function", pipe_split) if pp_size > 1: node_counter = 0 for node in mod_graph.nodes: if pp_size <= 1: break - if node.op == 'placeholder': + if node.op == "placeholder": continue elif node_counter == 0: node_counter += 1 @@ -269,7 +270,7 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int): pp_size -= 1 node_counter = 0 with mod_graph.inserting_before(node): - split_node = mod_graph.create_node('call_function', pipe_split) + split_node = mod_graph.create_node("call_function", pipe_split) gm.recompile() return gm @@ -283,7 +284,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int): # To use balanced_split_pass_v2, we need run meta_info_prop interpreter first. # If nodes don't have meta info, this pass will fall back to normal balanced split pass. 
check_node = list(mod_graph.nodes)[0] - if 'tensor_meta' not in check_node.meta: + if "tensor_meta" not in check_node.meta: return balanced_split_pass(gm, pp_size) total_element_size = 0 @@ -295,7 +296,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int): for node in mod_graph.nodes: if pp_size <= 1: break - if 'pipe_split' in node.name: + if "pipe_split" in node.name: continue accumulate_node_size += node.node_size if accumulate_node_size >= partition_size: @@ -304,7 +305,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int): pp_size -= 1 partition_size = total_element_size // pp_size with mod_graph.inserting_after(node): - split_node = mod_graph.create_node('call_function', pipe_split) + split_node = mod_graph.create_node("call_function", pipe_split) gm.recompile() return gm @@ -333,7 +334,7 @@ def uniform_split_pass(gm: torch.fx.GraphModule, pp_size: int): accumulate_layer_amount = 0 pp_size -= 1 with mod_graph.inserting_after(node): - split_node = mod_graph.create_node('call_function', pipe_split) + split_node = mod_graph.create_node("call_function", pipe_split) gm.recompile() return gm @@ -346,7 +347,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output def split_callback(n: torch.fx.Node): nonlocal part_idx - if (n.op, n.target) == ('call_function', pipe_split): + if (n.op, n.target) == ("call_function", pipe_split): part_idx += 1 return part_idx @@ -355,7 +356,7 @@ def split_callback(n: torch.fx.Node): for name, submodule in split_mod.named_modules(): if isinstance(submodule, torch.fx.GraphModule): for node in submodule.graph.nodes: - if (node.op, node.target) == ('call_function', pipe_split): + if (node.op, node.target) == ("call_function", pipe_split): submodule.graph.erase_node(node) submodule.recompile() split_submodules.append(submodule) diff --git a/colossalai/fx/passes/concrete_info_prop.py b/colossalai/fx/passes/concrete_info_prop.py index 81ac64205528..5440a4eadbbf 100644 --- a/colossalai/fx/passes/concrete_info_prop.py +++ b/colossalai/fx/passes/concrete_info_prop.py @@ -1,5 +1,5 @@ from dataclasses import asdict -from typing import Any, Dict, List, NamedTuple, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import torch import torch.fx @@ -85,10 +85,10 @@ def run_node(self, n: Node) -> Any: self._is_proped = True result, meta_info = super().run_node(n) - n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta` + n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta` # TODO: the attribute node_size should be removed in the future - setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0)) - n.meta['type'] = type(result) + setattr(n, "node_size", n.meta.get("fwd_mem_tmp", 0) + n.meta.get("fwd_mem_out", 0)) + n.meta["type"] = type(result) # retain the autograd graph for param in self.module.parameters(): @@ -98,7 +98,7 @@ def run_node(self, n: Node) -> Any: # Main Node running APIs @compatibility(is_backward_compatible=True) - def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def placeholder(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``placeholder`` node. 
Note that this is stateful: ``Interpreter`` maintains an internal iterator over @@ -119,7 +119,7 @@ def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict return super().placeholder(target, args, kwargs), GraphInfo() @compatibility(is_backward_compatible=True) - def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def get_attr(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``get_attr`` node. Will retrieve an attribute value from the ``Module`` hierarchy of ``self.module``. @@ -138,7 +138,7 @@ def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[st return super().get_attr(target, args, kwargs), GraphInfo() @compatibility(is_backward_compatible=True) - def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``call_function`` node with meta tensor and return the result and its meta profile. @@ -157,7 +157,7 @@ def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Di return profile_function(target, self.device)(*args, **kwargs) @compatibility(is_backward_compatible=True) - def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``call_method`` node with meta tensor and return the result and its meta profile. @@ -175,7 +175,7 @@ def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict return profile_method(target, self.device)(*args, **kwargs) @compatibility(is_backward_compatible=True) - def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``call_module`` node with meta tensor and return the result and its meta profile. @@ -197,7 +197,7 @@ def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict return profile_module(submod, self.device)(*args, **kwargs) @compatibility(is_backward_compatible=True) - def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def output(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute an ``output`` node. This really just retrieves the value referenced by the ``output`` node and returns it. @@ -228,7 +228,7 @@ def propagate(self, *args): """ return self.run(*args) - def summary(self, unit: str = 'MB') -> str: + def summary(self, unit: str = "MB") -> str: """ Summarizes the memory and FLOPs statistics of the `GraphModule` in tabular format. Note that this API requires the ``tabulate`` module @@ -238,9 +238,11 @@ def summary(self, unit: str = 'MB') -> str: try: from tabulate import tabulate except ImportError: - print("`summary` relies on the library `tabulate`, " - "which could not be found on this machine. Run `pip " - "install tabulate` to install the library.") + print( + "`summary` relies on the library `tabulate`, " + "which could not be found on this machine. Run `pip " + "install tabulate` to install the library." + ) assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`." 
@@ -249,10 +251,10 @@ def summary(self, unit: str = 'MB') -> str: def mem_repr(mem: int) -> str: unit_divisor_map = { - 'kb': 1024, - 'mb': 1024**2, - 'gb': 1024**3, - 'tb': 1024**4, + "kb": 1024, + "mb": 1024**2, + "gb": 1024**3, + "tb": 1024**4, } return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}" @@ -261,30 +263,32 @@ def time_repr(time: float): for node in self.module.graph.nodes: node: Node - node_summaries.append([ - node.op, - str(node), - time_repr(node.meta['fwd_time']), - time_repr(node.meta['bwd_time']), - node.meta['save_fwd_in'], - mem_repr(node.meta['fwd_mem_out']), - mem_repr(node.meta['fwd_mem_tmp']), - mem_repr(node.meta['bwd_mem_out']), - mem_repr(node.meta['bwd_mem_tmp']), - ]) + node_summaries.append( + [ + node.op, + str(node), + time_repr(node.meta["fwd_time"]), + time_repr(node.meta["bwd_time"]), + node.meta["save_fwd_in"], + mem_repr(node.meta["fwd_mem_out"]), + mem_repr(node.meta["fwd_mem_tmp"]), + mem_repr(node.meta["bwd_mem_out"]), + mem_repr(node.meta["bwd_mem_tmp"]), + ] + ) # Use the ``tabulate`` library to create a well-formatted table # presenting our summary information headers: List[str] = [ - 'Op type', - 'Op', - 'Forward time', - 'Backward time', - 'SAVE_FWD_IN', - 'FWD_OUT', - 'FWD_TMP', - 'BWD_OUT', - 'BWD_TMP', + "Op type", + "Op", + "Forward time", + "Backward time", + "SAVE_FWD_IN", + "FWD_OUT", + "FWD_TMP", + "BWD_OUT", + "BWD_TMP", ] - return tabulate(node_summaries, headers=headers, stralign='right') + return tabulate(node_summaries, headers=headers, stralign="right") diff --git a/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py b/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py index 4571bd93a790..3d032a27db63 100644 --- a/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py +++ b/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py @@ -1,14 +1,11 @@ -import torch -from typing import List -from torch.fx import symbolic_trace -from torch.fx.node import Node -from colossalai.fx.passes.split_module import split_module -from colossalai.tensor.shape_consistency import ShapeConsistencyManager -from colossalai.device.device_mesh import DeviceMesh -from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec import builtins import operator -from copy import deepcopy +from typing import List + +import torch + +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec def apply(*args, **kwargs): @@ -24,16 +21,16 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int], devi origin_node_sharding_spec_dict = {} for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)): strategies_vector = node.strategies_vector - setattr(node, 'best_strategy', strategies_vector[strategy_index]) - setattr(node, 'sharding_spec', strategies_vector[strategy_index].output_sharding_spec) + setattr(node, "best_strategy", strategies_vector[strategy_index]) + setattr(node, "sharding_spec", strategies_vector[strategy_index].output_sharding_spec) origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].output_sharding_spec # apply the sharding spec of parameters for node in nodes: - if node.op == 'call_module': + if node.op == "call_module": target_module = node.graph.owning_module.get_submodule(node.target) origin_sharding_spec = ShardingSpec(device_mesh, target_module.weight.shape, {}) - setattr(target_module.weight, 'sharding_spec', origin_sharding_spec) + 
setattr(target_module.weight, "sharding_spec", origin_sharding_spec) target_weight_sharding_spec = node.best_strategy.input_shardings[1] target_module.weight.data = target_module.weight.data.permute((1, 0, 2, 3)) apply(target_module.weight, target_weight_sharding_spec) @@ -51,10 +48,10 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int], devi # add above dicts into graph for node in nodes: - if node.op != 'placeholder': + if node.op != "placeholder": with mod_graph.inserting_before(node): - input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict') - origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict') + input_specs_node = mod_graph.create_node("placeholder", target="sharding_spec_convert_dict") + origin_specs_node = mod_graph.create_node("placeholder", target="origin_node_sharding_spec_dict") break return sharding_spec_convert_dict, origin_node_sharding_spec_dict @@ -70,13 +67,13 @@ def shape_consistency_pass(gm: torch.fx.GraphModule): node_to_index_dict = {} index = 0 for node in nodes: - if node.target == 'sharding_spec_convert_dict': + if node.target == "sharding_spec_convert_dict": input_dict_node = node continue - if node.target == 'origin_node_sharding_spec_dict': + if node.target == "origin_node_sharding_spec_dict": origin_dict_node = node continue - if not hasattr(node, 'best_strategy'): + if not hasattr(node, "best_strategy"): continue node_to_index_dict[node] = index index += 1 @@ -84,28 +81,28 @@ def shape_consistency_pass(gm: torch.fx.GraphModule): # add shape consistency apply function into graph for node in nodes: - if not hasattr(node, 'best_strategy'): + if not hasattr(node, "best_strategy"): continue with mod_graph.inserting_after(node): - origin_spec_node = mod_graph.create_node('call_function', - operator.getitem, - args=(origin_dict_node, node_to_index_dict[node])) + origin_spec_node = mod_graph.create_node( + "call_function", operator.getitem, args=(origin_dict_node, node_to_index_dict[node]) + ) with mod_graph.inserting_after(origin_spec_node): - set_sharding_spec_node = mod_graph.create_node('call_function', - builtins.setattr, - args=(node, 'sharding_spec', origin_spec_node)) + set_sharding_spec_node = mod_graph.create_node( + "call_function", builtins.setattr, args=(node, "sharding_spec", origin_spec_node) + ) for user_node in node.strategies_vector.successor_nodes: node_index = user_node.strategies_vector.predecessor_nodes.index(node) with mod_graph.inserting_before(user_node): - input_specs_node = mod_graph.create_node('call_function', - operator.getitem, - args=(input_dict_node, node_to_index_dict[node])) + input_specs_node = mod_graph.create_node( + "call_function", operator.getitem, args=(input_dict_node, node_to_index_dict[node]) + ) with mod_graph.inserting_before(user_node): - sharding_spec_node = mod_graph.create_node('call_function', - operator.getitem, - args=(input_specs_node, node_index)) + sharding_spec_node = mod_graph.create_node( + "call_function", operator.getitem, args=(input_specs_node, node_index) + ) with mod_graph.inserting_before(user_node): - shape_consistency_node = mod_graph.create_node('call_function', apply, args=(node, sharding_spec_node)) + shape_consistency_node = mod_graph.create_node("call_function", apply, args=(node, sharding_spec_node)) return gm diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py index ab203dfd7440..1720aa58da2b 100644 --- a/colossalai/fx/passes/meta_info_prop.py +++ 
b/colossalai/fx/passes/meta_info_prop.py @@ -109,13 +109,13 @@ def extract_tensor_meta(obj): return TensorMetadata(None, None, False, None, 0, False) tensor_meta = tree_map(extract_tensor_meta, result) - n.meta['tensor_meta'] = tensor_meta - n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta` + n.meta["tensor_meta"] = tensor_meta + n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta` # TODO: the attribute node_size should be removed in the future - setattr(n, 'node_size', activation_size(n.meta.get('fwd_out', 0)) + activation_size(n.meta.get('fwd_tmp', 0))) - setattr(n, 'fwd_flop', n.meta.get('fwd_flop', 0)) - setattr(n, 'bwd_flop', n.meta.get('bwd_flop', 0)) - n.meta['type'] = type(result) + setattr(n, "node_size", activation_size(n.meta.get("fwd_out", 0)) + activation_size(n.meta.get("fwd_tmp", 0))) + setattr(n, "fwd_flop", n.meta.get("fwd_flop", 0)) + setattr(n, "bwd_flop", n.meta.get("bwd_flop", 0)) + n.meta["type"] = type(result) # retain the autograd graph for param in self.module.parameters(): @@ -125,7 +125,7 @@ def extract_tensor_meta(obj): # Main Node running APIs @compatibility(is_backward_compatible=True) - def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def placeholder(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``placeholder`` node. Note that this is stateful: ``Interpreter`` maintains an internal iterator over @@ -146,7 +146,7 @@ def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict return super().placeholder(target, args, kwargs), GraphInfo() @compatibility(is_backward_compatible=True) - def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def get_attr(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``get_attr`` node. Will retrieve an attribute value from the ``Module`` hierarchy of ``self.module``. @@ -165,7 +165,7 @@ def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[st return super().get_attr(target, args, kwargs), GraphInfo() @compatibility(is_backward_compatible=True) - def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``call_function`` node with meta tensor and return the result and its meta profile. @@ -184,7 +184,7 @@ def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Di return profile_function(target)(*args, **kwargs) @compatibility(is_backward_compatible=True) - def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``call_method`` node with meta tensor and return the result and its meta profile. @@ -202,7 +202,7 @@ def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict return profile_method(target)(*args, **kwargs) @compatibility(is_backward_compatible=True) - def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``call_module`` node with meta tensor and return the result and its meta profile. 
@@ -224,7 +224,7 @@ def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict return profile_module(submod)(*args, **kwargs) @compatibility(is_backward_compatible=True) - def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def output(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute an ``output`` node. This really just retrieves the value referenced by the ``output`` node and returns it. @@ -240,7 +240,7 @@ def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, result (Any): The argument value that was retrieved meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. """ - if hasattr(args[0], '_tensor'): + if hasattr(args[0], "_tensor"): return args[0], GraphInfo(fwd_in=[args[0]._tensor]) return args[0], GraphInfo(save_fwd_in=True) @@ -257,7 +257,7 @@ def propagate(self, *args): """ return super().run(*args) - def summary(self, unit: str = 'MB') -> str: + def summary(self, unit: str = "MB") -> str: """ Summarizes the memory and FLOPs statistics of the `GraphModule` in tabular format. Note that this API requires the ``tabulate`` module @@ -267,9 +267,11 @@ def summary(self, unit: str = 'MB') -> str: try: from tabulate import tabulate except ImportError: - print("`summary` relies on the library `tabulate`, " - "which could not be found on this machine. Run `pip " - "install tabulate` to install the library.") + print( + "`summary` relies on the library `tabulate`, " + "which could not be found on this machine. Run `pip " + "install tabulate` to install the library." + ) assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`." @@ -278,10 +280,10 @@ def summary(self, unit: str = 'MB') -> str: def mem_repr(mem: int) -> str: unit_divisor_map = { - 'kb': 1024, - 'mb': 1024**2, - 'gb': 1024**3, - 'tb': 1024**4, + "kb": 1024, + "mb": 1024**2, + "gb": 1024**3, + "tb": 1024**4, } return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}" @@ -292,35 +294,37 @@ def flops_repr(flop: int) -> str: for node in self.module.graph.nodes: node: Node accumulate_size += calculate_fwd_out(node) + calculate_fwd_tmp(node) - node_summaries.append([ - node.op, - str(node), - flops_repr(node.meta['fwd_flop']), - flops_repr(node.meta['bwd_flop']), - mem_repr(accumulate_size), - mem_repr(calculate_fwd_in(node)), - mem_repr(calculate_fwd_out(node)), - mem_repr(calculate_fwd_tmp(node)), - mem_repr(node.meta['bwd_mem_out']), - mem_repr(node.meta['bwd_mem_tmp']), - ]) + node_summaries.append( + [ + node.op, + str(node), + flops_repr(node.meta["fwd_flop"]), + flops_repr(node.meta["bwd_flop"]), + mem_repr(accumulate_size), + mem_repr(calculate_fwd_in(node)), + mem_repr(calculate_fwd_out(node)), + mem_repr(calculate_fwd_tmp(node)), + mem_repr(node.meta["bwd_mem_out"]), + mem_repr(node.meta["bwd_mem_tmp"]), + ] + ) # Use the ``tabulate`` library to create a well-formatted table # presenting our summary information headers: List[str] = [ - 'Op type', - 'Op', - 'Forward FLOPs', - 'Backward FLOPs', - 'Accumulated Memory', - 'FWD_IN', - 'FWD_OUT', - 'FWD_TMP', - 'BWD_OUT', - 'BWD_TMP', + "Op type", + "Op", + "Forward FLOPs", + "Backward FLOPs", + "Accumulated Memory", + "FWD_IN", + "FWD_OUT", + "FWD_TMP", + "BWD_OUT", + "BWD_TMP", ] - return tabulate(node_summaries, headers=headers, stralign='right') + return tabulate(node_summaries, headers=headers, stralign="right") def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: 
bool = False, unit: str = "MB", **kwargs) -> None: @@ -344,15 +348,16 @@ def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit: Returns: torch.fx.GraphModule: The ``GraphModule`` annotated with MetaInfo. """ - device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') + device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") interp = MetaInfoProp(gm.to(device)) if is_compatible_with_meta(): from colossalai.fx.profiler import MetaTensor + args = tree_map(lambda x: MetaTensor(x, fake_device=device), args) kwargs = tree_map(lambda x: MetaTensor(x, fake_device=device), kwargs) interp.propagate(*args, **kwargs) if verbose: interp.summary(unit) - gm.to('cpu') + gm.to("cpu") del interp return gm diff --git a/colossalai/fx/passes/passes_for_gpt2_test.py b/colossalai/fx/passes/passes_for_gpt2_test.py index efdd34a01fe0..73379f73689c 100644 --- a/colossalai/fx/passes/passes_for_gpt2_test.py +++ b/colossalai/fx/passes/passes_for_gpt2_test.py @@ -5,7 +5,6 @@ from packaging import version from torch.fx._compatibility import compatibility from torch.fx.graph_module import GraphModule -from torch.fx.node import Node from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, pipe_split from colossalai.fx.passes.meta_info_prop import TensorMetadata @@ -13,9 +12,9 @@ def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, partition_list: List[int]): - ''' + """ This pass is only used to do the gpt2 performance test, it may move into adding_split_node_pass.py, and will be deprecated in future. - ''' + """ mod_graph = gm.graph valid_children_size = 0 valid_children = [] @@ -39,40 +38,40 @@ def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, parti part_index += 1 pp_size -= 1 with mod_graph.inserting_after(node): - split_node = mod_graph.create_node('call_function', pipe_split) + split_node = mod_graph.create_node("call_function", pipe_split) gm.recompile() return gm def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule): - ''' + """ This pass will be used in gpt2 test, only a part of changes may be added into split_with_split_nodes_pass, and it will be deprecated in future. - ''' + """ part_idx = 0 def eliminate_unused_placeholders(gm): for node in gm.graph.nodes: - if node.op == 'placeholder': + if node.op == "placeholder": if not len(node.users): gm.graph.erase_node(node) gm.recompile() return gm def refill_outputs_and_placeholders(gm, next_partition_placeholders): - ''' + """ This method is used to eliminate the outputs in previous partition which is unused in next partition. In split module pass, it treats partitions as a DAG, but we need treat them as a single direction linked list in pipeline parallel. The difference is if a output from partition 0 is an input argument of partition 3, the DAG will not transfer it to partition 1 and partition 2. However, in single direction linked list, we need to do so. 
- ''' + """ output_type = None output_args = [] non_output_list = [] new_placeholder_list = [] for node in gm.graph.nodes: - if node.op == 'output': + if node.op == "output": if isinstance(node.args[0], (tuple, list)): output_type = node.args[0].__class__ output_args.extend([n.name for n in node.args[0]]) @@ -114,7 +113,7 @@ def refill_outputs_and_placeholders(gm, next_partition_placeholders): continue for node in gm.graph.nodes: - if node.op == 'placeholder': + if node.op == "placeholder": new_placeholder_list.append(node.name) if output_type is not None: gm.graph.output(output_type(output_args)) @@ -125,7 +124,7 @@ def refill_outputs_and_placeholders(gm, next_partition_placeholders): def split_callback(n: torch.fx.Node): nonlocal part_idx - if (n.op, n.target) == ('call_function', pipe_split): + if (n.op, n.target) == ("call_function", pipe_split): part_idx += 1 return part_idx @@ -134,7 +133,7 @@ def split_callback(n: torch.fx.Node): for name, submodule in split_mod.named_modules(): if isinstance(submodule, torch.fx.GraphModule): for node in submodule.graph.nodes: - if (node.op, node.target) == ('call_function', pipe_split): + if (node.op, node.target) == ("call_function", pipe_split): submodule.graph.erase_node(node) submodule.recompile() split_submodules.append(submodule) @@ -200,13 +199,12 @@ def _gen_all_ancestors_set(node): _gen_all_ancestors_set(node) for n in list(all_ancestors): - if n.op != 'placeholder' and n._fx_partition > partition_name: + if n.op != "placeholder" and n._fx_partition > partition_name: n._fx_partition = partition_name - def record_cross_partition_use(def_node: torch.fx.node.Node, - use_node: Optional[torch.fx.node.Node]): # noqa: B950 - def_partition_name = getattr(def_node, '_fx_partition', None) - use_partition_name = getattr(use_node, '_fx_partition', None) + def record_cross_partition_use(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950 + def_partition_name = getattr(def_node, "_fx_partition", None) + use_partition_name = getattr(use_node, "_fx_partition", None) if def_partition_name != use_partition_name: # if 'tensor_meta' in def_node.meta: # if not _node_with_all_tensor_element(def_node.meta['tensor_meta']): @@ -237,7 +235,7 @@ def record_cross_partition_use(def_node: torch.fx.node.Node, if node.op in ["placeholder"]: continue - if node.op == 'output': + if node.op == "output": # partition_name = str(split_callback(node)) # def _set_output_args_partition(n, partition_name): # n._fx_partition = partition_name @@ -252,12 +250,12 @@ def record_cross_partition_use(def_node: torch.fx.node.Node, partitions[partition_name] = partition = Partition(partition_name) partition.node_names.append(node.name) - origin_partition_name = getattr(node, '_fx_partition', None) + origin_partition_name = getattr(node, "_fx_partition", None) if origin_partition_name is None: node._fx_partition = partition_name torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node)) - torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950 + torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950 # find partitions with no dependencies root_partitions: List[str] = [] @@ -287,7 +285,7 @@ def record_cross_partition_use(def_node: torch.fx.node.Node, # Transform nodes and collect targets for partition's submodule for node in m.graph.nodes: - if hasattr(node, '_fx_partition'): + if hasattr(node, "_fx_partition"): 
partition = partitions[node._fx_partition] # swap out old graph nodes in kw/args with references to new nodes in this submodule @@ -295,26 +293,24 @@ def record_cross_partition_use(def_node: torch.fx.node.Node, gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n]) gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n]) - if node.op not in ['call_module', 'get_attr']: + if node.op not in ["call_module", "get_attr"]: target = node.target else: - target_atoms = node.target.split('.') + target_atoms = node.target.split(".") target_attr = m for atom in target_atoms: if not hasattr(target_attr, atom): - raise RuntimeError(f'Operator target {node.target} not found!') + raise RuntimeError(f"Operator target {node.target} not found!") target_attr = getattr(target_attr, atom) # target = target_atoms[-1] - target = '_'.join(target_atoms) + target = "_".join(target_atoms) partition.targets[target] = target_attr assert isinstance(gathered_args, tuple) assert isinstance(gathered_kwargs, dict) - new_node = partition.graph.create_node(op=node.op, - target=target, - args=gathered_args, - kwargs=gathered_kwargs, - name=node.name) + new_node = partition.graph.create_node( + op=node.op, target=target, args=gathered_args, kwargs=gathered_kwargs, name=node.name + ) new_node.meta = node.meta.copy() partition.environment[node] = new_node @@ -323,14 +319,14 @@ def record_cross_partition_use(def_node: torch.fx.node.Node, base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph() base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {} for node in m.graph.nodes: - if node.op == 'placeholder': - if version.parse(torch.__version__) < version.parse('1.11.0'): + if node.op == "placeholder": + if version.parse(torch.__version__) < version.parse("1.11.0"): base_mod_env[node.name] = base_mod_graph.placeholder(node.name, type_expr=node.type) else: default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty - base_mod_env[node.name] = base_mod_graph.placeholder(node.name, - type_expr=node.type, - default_value=default_value) + base_mod_env[node.name] = base_mod_graph.placeholder( + node.name, type_expr=node.type, default_value=default_value + ) base_mod_env[node.name].meta = node.meta.copy() # Do some things iterating over the partitions in topological order again: @@ -344,13 +340,14 @@ def record_cross_partition_use(def_node: torch.fx.node.Node, # Set correct output values output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs) - output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment] + output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment] partition.graph.output(output_vals) # Construct GraphModule for this partition - submod_name = f'submod_{partition_name}' - base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets, - partition.graph) # noqa: B950 + submod_name = f"submod_{partition_name}" + base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule( + partition.targets, partition.graph + ) # noqa: B950 # Emit call in base graph to this submodule output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs)) @@ -358,14 +355,14 @@ def record_cross_partition_use(def_node: torch.fx.node.Node, # Unpack multiple return values from submodule output_val_proxy = torch.fx.proxy.Proxy(output_val) for i, output_name in enumerate(partition.outputs): - 
base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index] + base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index] else: if not partition.outputs: continue base_mod_env[list(partition.outputs)[0]] = output_val for node in m.graph.nodes: - if node.op == 'output': - base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950 + if node.op == "output": + base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950 return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph) diff --git a/colossalai/fx/passes/shard_1d_pass.py b/colossalai/fx/passes/shard_1d_pass.py index ccbab0c38a29..be8261f2a3f4 100644 --- a/colossalai/fx/passes/shard_1d_pass.py +++ b/colossalai/fx/passes/shard_1d_pass.py @@ -9,8 +9,19 @@ ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU] ELEMENTWISE_FUNC_OP = [ - torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv, - operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout + torch.add, + operator.add, + torch.abs, + torch.cos, + torch.exp, + torch.mul, + operator.mul, + operator.floordiv, + operator.truediv, + operator.neg, + torch.multiply, + torch.nn.functional.relu, + torch.nn.functional.dropout, ] @@ -72,7 +83,7 @@ def _traverse_and_annotate(node, start_tracking, annotation_record, world_size): # traverse the graph to look for consecutive linear layers is_linear_module = False - if node.op == 'call_module': + if node.op == "call_module": # look for the linear layer module = node.graph.owning_module.get_submodule(node.target) if isinstance(module, nn.Linear): @@ -82,31 +93,31 @@ def _traverse_and_annotate(node, start_tracking, annotation_record, world_size): # it means the first linear has been found and the current module # is the second linear # set the current linear module to be row-sharded - annotation_record['row'] = module + annotation_record["row"] = module for shard_type, module in annotation_record.items(): # add row sharding spec - if shard_type == 'row': + if shard_type == "row": dist_spec = ShardSpec(dims=[-1], num_partitions=[world_size]) comp_spec = ComputeSpec(ComputePattern.TP1D) - setattr(module.weight, 'pg', process_group) - setattr(module.weight, 'dist_spec', dist_spec) - setattr(module.weight, 'comp_spec', comp_spec) - elif shard_type == 'col': + setattr(module.weight, "pg", process_group) + setattr(module.weight, "dist_spec", dist_spec) + setattr(module.weight, "comp_spec", comp_spec) + elif shard_type == "col": weight_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size]) weight_comp_spec = ComputeSpec(ComputePattern.TP1D) weight_comp_spec.output_replicate = False - setattr(module.weight, 'pg', process_group) - setattr(module.weight, 'dist_spec', weight_dist_spec) - setattr(module.weight, 'comp_spec', weight_comp_spec) + setattr(module.weight, "pg", process_group) + setattr(module.weight, "dist_spec", weight_dist_spec) + setattr(module.weight, "comp_spec", weight_comp_spec) if module.bias is not None: bias_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size]) bias_comp_spec = ComputeSpec(ComputePattern.TP1D) bias_comp_spec.output_replicate = False - setattr(module.bias, 'pg', process_group) - setattr(module.bias, 'dist_spec', bias_dist_spec) - setattr(module.bias, 'comp_spec', bias_comp_spec) + setattr(module.bias, "pg", process_group) + setattr(module.bias, "dist_spec", bias_dist_spec) + 
setattr(module.bias, "comp_spec", bias_comp_spec) start_tracking = False annotation_record.clear() else: @@ -114,16 +125,16 @@ def _traverse_and_annotate(node, start_tracking, annotation_record, world_size): # it means the current layer is the first linear # set the linear layer to be col-sharded start_tracking = True - annotation_record['col'] = module + annotation_record["col"] = module if start_tracking and not is_linear_module: # check against the white list # if non-element wise op is found, we reset the tracking - if node.op == 'call_module': + if node.op == "call_module": module = node.graph.owning_module.get_submodule(node.target) if module.__class__ not in ELEMENTWISE_MODULE_OP: start_tracking = False - elif node.op == 'call_function' or node.op == 'call_method': + elif node.op == "call_function" or node.op == "call_method": if node.target not in ELEMENTWISE_FUNC_OP: start_tracking = False elif len(node.users.keys()) > 1: diff --git a/colossalai/fx/passes/split_module.py b/colossalai/fx/passes/split_module.py index 61ed037ab7a1..67a2432595d6 100644 --- a/colossalai/fx/passes/split_module.py +++ b/colossalai/fx/passes/split_module.py @@ -25,12 +25,14 @@ def __init__(self, name: str): self.targets: Dict[str, Any] = {} def __repr__(self) -> str: - return f"name: {self.name},\n" \ - f" nodes: {self.node_names},\n" \ - f" inputs: {self.inputs},\n" \ - f" outputs: {self.outputs},\n" \ - f" partitions dependent on: {self.partitions_dependent_on},\n" \ + return ( + f"name: {self.name},\n" + f" nodes: {self.node_names},\n" + f" inputs: {self.inputs},\n" + f" outputs: {self.outputs},\n" + f" partitions dependent on: {self.partitions_dependent_on},\n" f" partition dependents: {self.partition_dependents}" + ) # Creates subgraphs out of main graph @@ -117,10 +119,9 @@ def forward(self, x, y): partitions: Dict[str, Partition] = {} orig_nodes: Dict[str, torch.fx.node.Node] = {} - def record_cross_partition_use(def_node: torch.fx.node.Node, - use_node: Optional[torch.fx.node.Node]): # noqa: B950 - def_partition_name = getattr(def_node, '_fx_partition', None) - use_partition_name = getattr(use_node, '_fx_partition', None) + def record_cross_partition_use(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950 + def_partition_name = getattr(def_node, "_fx_partition", None) + use_partition_name = getattr(use_node, "_fx_partition", None) if def_partition_name != use_partition_name: if def_partition_name is not None: def_partition = partitions[def_partition_name] @@ -134,7 +135,7 @@ def record_cross_partition_use(def_node: torch.fx.node.Node, if def_partition_name is not None: use_partition.partitions_dependent_on.setdefault(def_partition_name) - def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950 + def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950 def_partition_name = getattr(def_node, "_fx_partition", None) use_partition_name = getattr(use_node, "_fx_partition", None) if def_partition_name != use_partition_name: @@ -161,7 +162,7 @@ def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node if node.op in ["placeholder"]: continue - if node.op == 'output': + if node.op == "output": if merge_output: torch.fx.graph.map_arg(node.args[0], lambda n: record_output(n, node.prev)) else: @@ -178,7 +179,7 @@ def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node node._fx_partition = partition_name torch.fx.graph.map_arg(node.args, 
lambda def_node: record_cross_partition_use(def_node, node)) - torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950 + torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950 # find partitions with no dependencies root_partitions: List[str] = [] @@ -208,7 +209,7 @@ def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node # Transform nodes and collect targets for partition's submodule for node in m.graph.nodes: - if hasattr(node, '_fx_partition'): + if hasattr(node, "_fx_partition"): partition = partitions[node._fx_partition] # swap out old graph nodes in kw/args with references to new nodes in this submodule @@ -216,25 +217,24 @@ def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n]) gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n]) - if node.op not in ['call_module', 'get_attr']: + if node.op not in ["call_module", "get_attr"]: target = node.target else: - target_atoms = node.target.split('.') + target_atoms = node.target.split(".") target_attr = m for atom in target_atoms: if not hasattr(target_attr, atom): - raise RuntimeError(f'Operator target {node.target} not found!') + raise RuntimeError(f"Operator target {node.target} not found!") target_attr = getattr(target_attr, atom) # target = target_atoms[-1] - target = '_'.join(target_atoms) + target = "_".join(target_atoms) partition.targets[target] = target_attr assert isinstance(gathered_args, tuple) assert isinstance(gathered_kwargs, dict) - new_node = partition.graph.create_node(op=node.op, - target=target, - args=gathered_args, - kwargs=gathered_kwargs) + new_node = partition.graph.create_node( + op=node.op, target=target, args=gathered_args, kwargs=gathered_kwargs + ) new_node.meta = node.meta.copy() partition.environment[node] = new_node @@ -243,14 +243,14 @@ def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph() base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {} for node in m.graph.nodes: - if node.op == 'placeholder': - if version.parse(torch.__version__) < version.parse('1.11.0'): + if node.op == "placeholder": + if version.parse(torch.__version__) < version.parse("1.11.0"): base_mod_env[node.name] = base_mod_graph.placeholder(node.target, type_expr=node.type) else: default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty - base_mod_env[node.name] = base_mod_graph.placeholder(node.target, - type_expr=node.type, - default_value=default_value) + base_mod_env[node.name] = base_mod_graph.placeholder( + node.target, type_expr=node.type, default_value=default_value + ) base_mod_env[node.name].meta = node.meta.copy() # Do some things iterating over the partitions in topological order again: @@ -264,13 +264,14 @@ def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node # Set correct output values output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs) - output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment] + output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment] partition.graph.output(output_vals) # Construct GraphModule for this partition - submod_name = f'submod_{partition_name}' - base_mod_attrs[submod_name] 
= torch.fx.graph_module.GraphModule(partition.targets, - partition.graph) # noqa: B950 + submod_name = f"submod_{partition_name}" + base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule( + partition.targets, partition.graph + ) # noqa: B950 # Emit call in base graph to this submodule output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs)) @@ -278,15 +279,15 @@ def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node # Unpack multiple return values from submodule output_val_proxy = torch.fx.proxy.Proxy(output_val) for i, output_name in enumerate(partition.outputs): - base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index] + base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index] else: if not partition.outputs: continue base_mod_env[list(partition.outputs)[0]] = output_val for node in m.graph.nodes: - if node.op == 'output': - base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950 + if node.op == "output": + base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950 for partition_name in sorted_partitions: partition = partitions[partition_name] diff --git a/colossalai/fx/passes/utils.py b/colossalai/fx/passes/utils.py index bb4f3cd6a490..c51f49a30e8a 100644 --- a/colossalai/fx/passes/utils.py +++ b/colossalai/fx/passes/utils.py @@ -1,7 +1,9 @@ -import torch from typing import Dict -from torch.fx.node import Node, map_arg + +import torch from torch.fx.graph import Graph +from torch.fx.node import Node, map_arg + def get_comm_size(prev_partition, next_partition): """ @@ -23,7 +25,7 @@ def get_comm_size(prev_partition, next_partition): map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) for n in input_nodes: if n.name in parent_node_names and n not in visited_nodes: - comm_size += n.meta['tensor_meta'].numel + comm_size += n.meta["tensor_meta"].numel visited_nodes.add(n) return comm_size @@ -36,12 +38,12 @@ def get_leaf(graph: Graph): """ input_nodes: Dict[Node, None] = {} for node in graph.nodes: - if node.op == 'output': + if node.op == "output": map_arg(node.args, lambda n: input_nodes.setdefault(n)) map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) placeholder_nodes = [] for node in input_nodes.keys(): - if node.op == 'placeholder': + if node.op == "placeholder": placeholder_nodes.append(node) for node in placeholder_nodes: input_nodes.pop(node) @@ -60,13 +62,13 @@ def get_top(graph: Graph): """ top_node_list = set() for node in graph.nodes: - if node.op == 'output': + if node.op == "output": continue is_top = False def _get_top(node): nonlocal is_top - if node.op == 'placeholder': + if node.op == "placeholder": is_top = True map_arg(node.args, lambda n: _get_top(n)) @@ -83,7 +85,7 @@ def is_top(graph: Graph, node: Node): def get_all_consumers(graph: Graph, node: Node): """ Given a graph and a node of this graph, return all consumers of the node. - + Returns: List of ``Nodes`` that node appear in these nodes ``args`` and ``kwargs``. 
""" @@ -120,7 +122,7 @@ def forward(self, x): for node in gm.graph.nodes: if hasattr(node, 'bfs_level'): print(node.name, node.bfs_level) - + Output: graph(): %x : [#users=2] = placeholder[target=x] @@ -148,7 +150,7 @@ def forward(self, x): while nodes_to_process: new_process_list = [] for node in nodes_to_process: - if node.op == 'output': + if node.op == "output": continue node.bfs_level = current_level new_process_list.extend(get_all_consumers(graph, node)) @@ -165,8 +167,9 @@ def get_node_module(node) -> torch.nn.Module: torch.nn.Module: the module associated with the given node """ - assert node.graph.owning_module is not None, 'Cannot find the owning_module for node.graph, please make sure the graph is associated with a GraphModule object' - assert node.op == 'call_module', f'Expected node.op to be call_module, but found {node.op}' + assert ( + node.graph.owning_module is not None + ), "Cannot find the owning_module for node.graph, please make sure the graph is associated with a GraphModule object" + assert node.op == "call_module", f"Expected node.op to be call_module, but found {node.op}" module = node.graph.owning_module.get_submodule(node.target) return module - diff --git a/colossalai/fx/profiler/__init__.py b/colossalai/fx/profiler/__init__.py index 8bcbde0eb23b..89dd2b3df617 100644 --- a/colossalai/fx/profiler/__init__.py +++ b/colossalai/fx/profiler/__init__.py @@ -12,7 +12,16 @@ ) from .tensor import MetaTensor else: - from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out + from .experimental import ( + meta_profiler_function, + meta_profiler_module, + profile_function, + profile_method, + profile_module, + calculate_fwd_in, + calculate_fwd_tmp, + calculate_fwd_out, + ) from .dataflow import GraphInfo from .memory_utils import activation_size, is_inplace, parameter_size diff --git a/colossalai/fx/profiler/constants.py b/colossalai/fx/profiler/constants.py index 5763a46dc83f..fad9bb272bff 100644 --- a/colossalai/fx/profiler/constants.py +++ b/colossalai/fx/profiler/constants.py @@ -1,6 +1,6 @@ import torch -__all__ = ['ALIAS_ATEN', 'INPLACE_NEW', 'INPLACE_MATH_ATEN', 'CLONE_ATEN', 'RELU_LIKE_OPS', 'RELU_LIKE_MOD'] +__all__ = ["ALIAS_ATEN", "INPLACE_NEW", "INPLACE_MATH_ATEN", "CLONE_ATEN", "RELU_LIKE_OPS", "RELU_LIKE_MOD"] aten = torch.ops.aten diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py index a5e8880322b8..05f9b50ce575 100644 --- a/colossalai/fx/profiler/dataflow.py +++ b/colossalai/fx/profiler/dataflow.py @@ -1,6 +1,5 @@ from dataclasses import dataclass, field from enum import Enum -from functools import partial from typing import Dict, List from torch.fx import Graph, Node @@ -69,8 +68,8 @@ class GraphInfo: def is_phase(n: Node, phase: Phase) -> bool: - assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!' - return n.meta['phase'] == phase + assert "phase" in n.meta, f"Node meta of {n} has no key `phase`!" 
+ return n.meta["phase"] == phase @compatibility(is_backward_compatible=False) @@ -103,9 +102,9 @@ def _peak_memory(deps: Dict[Node, int]): peak_mem = 0 for k, v in deps.items(): if v > 0 and is_phase(k, Phase.BACKWARD) and not all(map(is_inplace, k.users)) and not is_inplace(k): - peak_mem += activation_size(k.meta['saved_tensor']) - if v <= float('-inf') and is_phase(k, Phase.FORWARD): - peak_mem -= activation_size(k.meta['saved_tensor']) + peak_mem += activation_size(k.meta["saved_tensor"]) + if v <= float("-inf") and is_phase(k, Phase.FORWARD): + peak_mem -= activation_size(k.meta["saved_tensor"]) return peak_mem # deps is used to track all the memory dependencies of the graph. @@ -123,19 +122,19 @@ def _peak_memory(deps: Dict[Node, int]): # Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint # the node, `fwd_mem_tmp` can be freed. if is_phase(n, Phase.PLACEHOLDER): - graph_info.fwd_in += n.meta['saved_tensor'] + graph_info.fwd_in += n.meta["saved_tensor"] if is_phase(n, Phase.FORWARD): - graph_info.fwd_tmp += n.meta['saved_tensor'] + graph_info.fwd_tmp += n.meta["saved_tensor"] elif is_phase(n, Phase.BACKWARD): if len(n.users): graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps)) else: # TODO: some of the bwd_mem_out might be model parameters. # basically a backward node without user is a `grad_out` node - graph_info.bwd_mem_out += activation_size(n.meta['saved_tensor']) + graph_info.bwd_mem_out += activation_size(n.meta["saved_tensor"]) for input_n in n.all_input_nodes: if input_n in deps: deps[input_n] -= 1 if deps[input_n] <= 0: - deps[input_n] = float('-inf') + deps[input_n] = float("-inf") return graph_info diff --git a/colossalai/fx/profiler/experimental/constants.py b/colossalai/fx/profiler/experimental/constants.py index 57ff3fd91299..02758e7643af 100644 --- a/colossalai/fx/profiler/experimental/constants.py +++ b/colossalai/fx/profiler/experimental/constants.py @@ -2,7 +2,7 @@ import torch -__all__ = ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD'] +__all__ = ["INPLACE_OPS", "INPLACE_METHOD", "NON_INPLACE_METHOD"] # TODO fill out the inplace ops INPLACE_OPS = [ @@ -20,25 +20,25 @@ # TODO: list all call_methods that are inplace here INPLACE_METHOD = [ - 'transpose', - 'permute', + "transpose", + "permute", # TODO: reshape may return a copy of the data if the data is not contiguous - 'reshape', - 'dim', - 'flatten', - 'size', - 'view', - 'unsqueeze', - 'to', - 'type', - 'flatten', + "reshape", + "dim", + "flatten", + "size", + "view", + "unsqueeze", + "to", + "type", + "flatten", ] # TODO: list all call_methods that are not inplace here NON_INPLACE_METHOD = [ - 'chunk', - 'contiguous', - 'expand', - 'mean', - 'split', + "chunk", + "contiguous", + "expand", + "mean", + "split", ] diff --git a/colossalai/fx/profiler/experimental/profiler.py b/colossalai/fx/profiler/experimental/profiler.py index 5c545260e72b..d890fdb66fc2 100644 --- a/colossalai/fx/profiler/experimental/profiler.py +++ b/colossalai/fx/profiler/experimental/profiler.py @@ -9,7 +9,7 @@ from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD from .registry import meta_profiler_function, meta_profiler_module -__all__ = ['profile_function', 'profile_module', 'profile_method'] +__all__ = ["profile_function", "profile_module", "profile_method"] # this is for compatibility use @@ -42,6 +42,7 @@ class GraphInfo: bwd_mem_tmp (int): See the above illustration. bwd_mem_out (int): See the above illustration. 
""" + fwd_flop: int = 0 bwd_flop: int = 0 fwd_mem_in: int = 0 @@ -50,8 +51,7 @@ class GraphInfo: bwd_mem_out: int = 0 -CALL_FUNCTION_MSG = \ -""" +CALL_FUNCTION_MSG = """ Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n from colossalai.fx.profiler.experimental import meta_profiler_function @meta_profiler_function.register(YOUR_FUNCTION) @@ -60,9 +60,8 @@ def profile_YOUR_FUNCTION(input: torch.Tensor, *args) -> Tuple[int, int]: macs = ... return flops, macs """ -CALL_METHOD_MSG = 'Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}' -CALL_MODULE_MSG = \ -""" +CALL_METHOD_MSG = "Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}" +CALL_MODULE_MSG = """ Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n from colossalai.fx.profiler.experimental import meta_profiler_module @meta_profiler_module.register(YOUR_MODULE) @@ -74,7 +73,7 @@ def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int @compatibility(is_backward_compatible=True) -def profile_function(target: 'Target') -> Callable: +def profile_function(target: "Target") -> Callable: """ Wrap a `call_function` node or `torch.nn.functional` in order to record the memory cost and FLOPs of the execution. @@ -92,12 +91,13 @@ def profile_function(target: 'Target') -> Callable: def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: assert meta_profiler_function.has(target) or meta_profiler_function.has( - target.__name__), CALL_FUNCTION_MSG.format(target) + target.__name__ + ), CALL_FUNCTION_MSG.format(target) fwd_tmp = 0 fwd_out = 0 out = func(*args, **kwargs) - if target not in INPLACE_OPS and not kwargs.get('inplace', False): + if target not in INPLACE_OPS and not kwargs.get("inplace", False): fwd_out = activation_size(out) if meta_profiler_function.has(target): profiler = meta_profiler_function.get(target) @@ -112,7 +112,7 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: @compatibility(is_backward_compatible=True) -def profile_method(target: 'Target') -> Callable: +def profile_method(target: "Target") -> Callable: """ Wrap a `call_method` node record the memory cost and FLOPs of the execution. @@ -126,11 +126,12 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: self_obj, *args_tail = args # execute the method and return the result - assert isinstance(target, str), f'{target} instance is not str.' + assert isinstance(target, str), f"{target} instance is not str." out = getattr(self_obj, target)(*args_tail, **kwargs) assert target in INPLACE_METHOD + NON_INPLACE_METHOD, CALL_METHOD_MSG.format( - target, INPLACE_METHOD, NON_INPLACE_METHOD) + target, INPLACE_METHOD, NON_INPLACE_METHOD + ) # call_method has no parameters and are MOSTLY(?) inplace, and has no FLOPs or MACs. 
fwd_tmp = 0 if target in INPLACE_METHOD else activation_size(out) fwd_out = 0 if target not in INPLACE_METHOD else activation_size(out) @@ -161,7 +162,7 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: fwd_tmp = 0 fwd_out = 0 out = func(*args, **kwargs) - if getattr(module, 'inplace', False): + if getattr(module, "inplace", False): fwd_out = activation_size(out) profiler = meta_profiler_module.get(type(module)) fwd_flop, _ = profiler(module, *args, **kwargs) diff --git a/colossalai/fx/profiler/experimental/profiler_function/activation_function.py b/colossalai/fx/profiler/experimental/profiler_function/activation_function.py index a43aef063e19..c518ec28da41 100644 --- a/colossalai/fx/profiler/experimental/profiler_function/activation_function.py +++ b/colossalai/fx/profiler/experimental/profiler_function/activation_function.py @@ -1,5 +1,7 @@ from typing import Tuple + import torch + from ..registry import meta_profiler_function # TODO: different activation has different FLOPs count, currently unused. diff --git a/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py b/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py index 8d1c8a8c6877..f1b9bb97c6c6 100644 --- a/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py +++ b/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py @@ -41,15 +41,15 @@ def _elementwise_flops_compute(input, other): @meta_profiler_function.register(torch.sub) @meta_profiler_function.register(torch.mul) @meta_profiler_function.register(torch.floor_divide) -@meta_profiler_function.register('add') # for built-in op + -@meta_profiler_function.register('iadd') # for built-in op += -@meta_profiler_function.register('eq') # for built-in op = -@meta_profiler_function.register('sub') # for built-in op - -@meta_profiler_function.register('isub') # for built-in op -= -@meta_profiler_function.register('mul') # for built-in op * -@meta_profiler_function.register('imul') # for built-in op *= -@meta_profiler_function.register('floordiv') # for built-in op // -@meta_profiler_function.register('ifloordiv') # for built-in op //= +@meta_profiler_function.register("add") # for built-in op + +@meta_profiler_function.register("iadd") # for built-in op += +@meta_profiler_function.register("eq") # for built-in op = +@meta_profiler_function.register("sub") # for built-in op - +@meta_profiler_function.register("isub") # for built-in op -= +@meta_profiler_function.register("mul") # for built-in op * +@meta_profiler_function.register("imul") # for built-in op *= +@meta_profiler_function.register("floordiv") # for built-in op // +@meta_profiler_function.register("ifloordiv") # for built-in op //= def torch_add_like_ops(input: Any, other: Any, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]: return _elementwise_flops_compute(input, other) @@ -62,7 +62,7 @@ def torch_elementwise_op(input: torch.Tensor, *, out: Optional[torch.Tensor] = N @meta_profiler_function.register(torch.matmul) -@meta_profiler_function.register('matmul') # for built-in op @ +@meta_profiler_function.register("matmul") # for built-in op @ @meta_profiler_function.register(torch.Tensor.matmul) def torch_matmul(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]: macs = reduce(operator.mul, input.shape) * other.shape[-1] @@ -78,13 +78,15 @@ def torch_bmm(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.T @meta_profiler_function.register(torch.var_mean) -def 
torch_var_mean(input: torch.Tensor, - dim: Union[int, Tuple[int, ...]], - unbiased: Optional[bool] = True, - keepdim: Optional[bool] = False, - *, - out: Optional[torch.Tensor] = None) -> Tuple[int, int]: - assert out is None, 'saving to out is not supported yet' +def torch_var_mean( + input: torch.Tensor, + dim: Union[int, Tuple[int, ...]], + unbiased: Optional[bool] = True, + keepdim: Optional[bool] = False, + *, + out: Optional[torch.Tensor] = None, +) -> Tuple[int, int]: + assert out is None, "saving to out is not supported yet" flops = input.numel() * 3 macs = 0 return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_function/embedding.py b/colossalai/fx/profiler/experimental/profiler_function/embedding.py index d6e43d781b8b..1d362015fc8b 100644 --- a/colossalai/fx/profiler/experimental/profiler_function/embedding.py +++ b/colossalai/fx/profiler/experimental/profiler_function/embedding.py @@ -1,5 +1,7 @@ -import torch from typing import Optional + +import torch + from ..registry import meta_profiler_function diff --git a/colossalai/fx/profiler/experimental/profiler_function/linear.py b/colossalai/fx/profiler/experimental/profiler_function/linear.py index 01fe4c871370..ecc578d61b91 100644 --- a/colossalai/fx/profiler/experimental/profiler_function/linear.py +++ b/colossalai/fx/profiler/experimental/profiler_function/linear.py @@ -1,5 +1,7 @@ from typing import Tuple + import torch + from ..registry import meta_profiler_function diff --git a/colossalai/fx/profiler/experimental/profiler_function/normalization.py b/colossalai/fx/profiler/experimental/profiler_function/normalization.py index c4ea508d70f8..2ad029eda039 100644 --- a/colossalai/fx/profiler/experimental/profiler_function/normalization.py +++ b/colossalai/fx/profiler/experimental/profiler_function/normalization.py @@ -1,5 +1,7 @@ from typing import List, Optional, Tuple + import torch + from ..registry import meta_profiler_function @@ -21,11 +23,13 @@ def torch_nn_func_instancenorm( @meta_profiler_function.register(torch.nn.functional.group_norm) -def torch_nn_func_groupnorm(input: torch.Tensor, - num_groups: int, - weight: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, - eps: float = 1e-5) -> Tuple[int, int]: +def torch_nn_func_groupnorm( + input: torch.Tensor, + num_groups: int, + weight: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + eps: float = 1e-5, +) -> Tuple[int, int]: has_affine = weight is not None flops = input.numel() * (5 if has_affine else 4) macs = 0 diff --git a/colossalai/fx/profiler/experimental/profiler_function/pooling.py b/colossalai/fx/profiler/experimental/profiler_function/pooling.py index a639f5ee83c1..c91deab906d4 100644 --- a/colossalai/fx/profiler/experimental/profiler_function/pooling.py +++ b/colossalai/fx/profiler/experimental/profiler_function/pooling.py @@ -1,5 +1,7 @@ -from typing import Tuple, Union +from typing import Tuple + import torch + from ..registry import meta_profiler_function diff --git a/colossalai/fx/profiler/experimental/profiler_function/python_ops.py b/colossalai/fx/profiler/experimental/profiler_function/python_ops.py index 1e8561206ba0..58c9889ad98e 100644 --- a/colossalai/fx/profiler/experimental/profiler_function/python_ops.py +++ b/colossalai/fx/profiler/experimental/profiler_function/python_ops.py @@ -1,6 +1,6 @@ import operator from typing import Any, Tuple -import torch + from ..registry import meta_profiler_function diff --git a/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py 
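The formulas touched by the hunks above can be checked by hand. The short sketch below (not part of the patch) evaluates the matmul MAC/FLOP rule and the affine group-norm estimate on concrete shapes:

```python
import operator
from functools import reduce
import torch

# Worked example of the FLOP formulas used above: for matmul,
# MACs = prod(input.shape) * other.shape[-1] and FLOPs = 2 * MACs;
# for an affine group norm, FLOPs = 5 * numel (4 without weight/bias).
a = torch.randn(8, 64, 128)
b = torch.randn(128, 256)
macs = reduce(operator.mul, a.shape) * b.shape[-1]
print(macs, 2 * macs)                 # 16777216 33554432

x = torch.randn(8, 32, 16, 16)
print(x.numel() * 5)                  # 327680 group-norm FLOPs with affine params
```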
b/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py index abdd7ad565ba..67e90fb69acd 100644 --- a/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py +++ b/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py @@ -1,7 +1,9 @@ -from functools import reduce import operator +from functools import reduce from typing import Any, Optional, Tuple + import torch + from ..registry import meta_profiler_function @@ -43,13 +45,11 @@ def torch_where(condition: torch.Tensor, x: Any, y: Any) -> Tuple[int, int]: @meta_profiler_function.register(torch.max) -def torch_max(input: torch.Tensor, - dim: int = None, - keepdim: bool = False, - *, - out: Optional[torch.Tensor] = None) -> Tuple[int, int]: +def torch_max( + input: torch.Tensor, dim: int = None, keepdim: bool = False, *, out: Optional[torch.Tensor] = None +) -> Tuple[int, int]: macs = 0 - assert out is None, 'assigning value to out is not supported yet' + assert out is None, "assigning value to out is not supported yet" if dim is not None: shape = list(input.shape) shape.pop(int(dim)) diff --git a/colossalai/fx/profiler/experimental/profiler_module/activation_function.py b/colossalai/fx/profiler/experimental/profiler_module/activation_function.py index 2ebf514ad269..ae065e0c7c17 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/activation_function.py +++ b/colossalai/fx/profiler/experimental/profiler_module/activation_function.py @@ -1,5 +1,7 @@ from typing import Tuple + import torch + from ..registry import meta_profiler_module # TODO: different activation has different FLOPs count, currently unused. diff --git a/colossalai/fx/profiler/experimental/profiler_module/attention.py b/colossalai/fx/profiler/experimental/profiler_module/attention.py index 8daf74b232bf..dfaee75e0432 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/attention.py +++ b/colossalai/fx/profiler/experimental/profiler_module/attention.py @@ -1,19 +1,23 @@ from typing import Optional, Tuple + import torch + from ..registry import meta_profiler_module # TODO: This is hard to compute memory cost @meta_profiler_module.register(torch.nn.MultiheadAttention) -def torch_nn_msa(self: torch.nn.MultiheadAttention, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - key_padding_mask: Optional[torch.Tensor] = None, - need_weights: bool = True, - attn_mask: Optional[torch.Tensor] = None, - average_attn_weights: bool = True) -> Tuple[int, int]: - if getattr(self, 'batch_first', False): +def torch_nn_msa( + self: torch.nn.MultiheadAttention, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_padding_mask: Optional[torch.Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[torch.Tensor] = None, + average_attn_weights: bool = True, +) -> Tuple[int, int]: + if getattr(self, "batch_first", False): batch_size = query.shape[0] len_idx = 1 else: @@ -44,15 +48,9 @@ def torch_nn_msa(self: torch.nn.MultiheadAttention, flops += qlen * qdim # Initial projections - flops += 2 * ((qlen * qdim * qdim) # QW - + (klen * kdim * kdim) # KW - + (vlen * vdim * vdim) # VW - ) + flops += 2 * ((qlen * qdim * qdim) + (klen * kdim * kdim) + (vlen * vdim * vdim)) # QW # KW # VW - macs += ((qlen * qdim * qdim) # QW - + (klen * kdim * kdim) # KW - + (vlen * vdim * vdim) # VW - ) + macs += (qlen * qdim * qdim) + (klen * kdim * kdim) + (vlen * vdim * vdim) # QW # KW # VW if self.in_proj_bias is not None: flops += (qlen + klen + vlen) * qdim @@ -62,13 +60,9 @@ def torch_nn_msa(self: 
torch.nn.MultiheadAttention, v_head_dim = vdim // num_heads head_flops = ( - 2 * (qlen * klen * qk_head_dim) # QK^T - + (qlen * klen) # softmax - + 2 * (qlen * klen * v_head_dim) # AV + 2 * (qlen * klen * qk_head_dim) + (qlen * klen) + 2 * (qlen * klen * v_head_dim) # QK^T # softmax # AV ) - head_macs = ((qlen * klen * qk_head_dim) # QK^T - + 2 * (qlen * klen * v_head_dim) # AV - ) + head_macs = (qlen * klen * qk_head_dim) + 2 * (qlen * klen * v_head_dim) # QK^T # AV flops += num_heads * head_flops macs += num_heads * head_flops diff --git a/colossalai/fx/profiler/experimental/profiler_module/convolution.py b/colossalai/fx/profiler/experimental/profiler_module/convolution.py index a4c15b91e611..90e494c77f5b 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/convolution.py +++ b/colossalai/fx/profiler/experimental/profiler_module/convolution.py @@ -17,8 +17,9 @@ def torch_nn_conv1d(self: torch.nn.Conv1d, input: torch.Tensor) -> Tuple[int, in # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html c_in, l_in = input.shape[-2:] c_out = self.out_channels - l_out = math.floor((l_in + 2 * self.padding[0] - self.dilation[0] * - (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) + l_out = math.floor( + (l_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1 + ) result_shape = input.shape[:-2] + ( c_out, l_out, @@ -38,10 +39,12 @@ def torch_nn_conv2d(self: torch.nn.Conv2d, input: torch.Tensor) -> Tuple[int, in # at https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html c_in, h_in, w_in = input.shape[-3:] c_out = self.out_channels - h_out = math.floor((h_in + 2 * self.padding[0] - self.dilation[0] * - (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) - w_out = math.floor((w_in + 2 * self.padding[1] - self.dilation[1] * - (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1) + h_out = math.floor( + (h_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1 + ) + w_out = math.floor( + (w_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1 + ) result_shape = input.shape[:-3] + ( c_out, h_out, @@ -62,12 +65,15 @@ def torch_nn_conv3d(self: torch.nn.Conv3d, input: torch.Tensor) -> Tuple[int, in # at https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html c_in, d_in, h_in, w_in = input.shape[-4:] c_out = self.out_channels - d_out = math.floor((d_in + 2 * self.padding[0] - self.dilation[0] * - (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) - h_out = math.floor((h_in + 2 * self.padding[1] - self.dilation[1] * - (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1) - w_out = math.floor((w_in + 2 * self.padding[2] - self.dilation[2] * - (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1) + d_out = math.floor( + (d_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1 + ) + h_out = math.floor( + (h_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1 + ) + w_out = math.floor( + (w_in + 2 * self.padding[2] - self.dilation[2] * (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1 + ) result_shape = input.shape[:-4] + ( c_out, d_out, @@ -89,8 +95,13 @@ def torch_nn_convtranspose1d(self: torch.nn.ConvTranspose1d, input: torch.Tensor # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html c_in, l_in = input.shape[-2:] c_out = self.out_channels - l_out = math.floor((l_in - 1) * self.stride[0] - 2 * self.padding[0] + 
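The multi-head attention counter above sums the input projections with the per-head QK^T, softmax, and AV terms. The back-of-the-envelope sketch below evaluates those two dominant contributions for made-up sizes; it deliberately omits the smaller bias and output-projection terms, so it is an approximation rather than a reproduction of the profiler's result.

```python
# Rough check of the MultiheadAttention FLOP count sketched above
# (illustrative sizes; bias and output projection omitted).
qlen = klen = vlen = 128
qdim = kdim = vdim = 512
num_heads = 8
head_dim = qdim // num_heads

proj_flops = 2 * (qlen * qdim * qdim + klen * kdim * kdim + vlen * vdim * vdim)
head_flops = 2 * qlen * klen * head_dim + qlen * klen + 2 * qlen * klen * head_dim
print(proj_flops + num_heads * head_flops)   # projection + attention FLOPs
```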
self.dilation[0] * - (self.kernel_size[0] - 1) + self.output_padding[0] + 1) + l_out = math.floor( + (l_in - 1) * self.stride[0] + - 2 * self.padding[0] + + self.dilation[0] * (self.kernel_size[0] - 1) + + self.output_padding[0] + + 1 + ) result_shape = input.shape[:-2] + ( c_out, l_out, @@ -98,7 +109,7 @@ def torch_nn_convtranspose1d(self: torch.nn.ConvTranspose1d, input: torch.Tensor macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups num_elem = reduce( operator.mul, input.shape - ) # see https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L604 + ) # see https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L604 macs = macs_per_elem * num_elem flops = 2 * macs if self.bias is not None: @@ -112,10 +123,20 @@ def torch_nn_convtranspose2d(self: torch.nn.ConvTranspose2d, input: torch.Tensor # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html c_in, h_in, w_in = input.shape[-3:] c_out = self.out_channels - h_out = math.floor((h_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * - (self.kernel_size[0] - 1) + self.output_padding[0] + 1) - w_out = math.floor((w_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] * - (self.kernel_size[1] - 1) + self.output_padding[1] + 1) + h_out = math.floor( + (h_in - 1) * self.stride[0] + - 2 * self.padding[0] + + self.dilation[0] * (self.kernel_size[0] - 1) + + self.output_padding[0] + + 1 + ) + w_out = math.floor( + (w_in - 1) * self.stride[1] + - 2 * self.padding[1] + + self.dilation[1] * (self.kernel_size[1] - 1) + + self.output_padding[1] + + 1 + ) result_shape = input.shape[:-3] + ( c_out, h_out, @@ -136,12 +157,27 @@ def torch_nn_convtranspose3d(self: torch.nn.ConvTranspose3d, input: torch.Tensor # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html c_in, d_in, h_in, w_in = input.shape[-4:] c_out = self.out_channels - d_out = math.floor((d_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * - (self.kernel_size[0] - 1) + self.output_padding[0] + 1) - h_out = math.floor((h_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] * - (self.kernel_size[1] - 1) + self.output_padding[1] + 1) - w_out = math.floor((w_in - 1) * self.stride[2] - 2 * self.padding[2] + self.dilation[2] * - (self.kernel_size[2] - 1) + self.output_padding[2] + 1) + d_out = math.floor( + (d_in - 1) * self.stride[0] + - 2 * self.padding[0] + + self.dilation[0] * (self.kernel_size[0] - 1) + + self.output_padding[0] + + 1 + ) + h_out = math.floor( + (h_in - 1) * self.stride[1] + - 2 * self.padding[1] + + self.dilation[1] * (self.kernel_size[1] - 1) + + self.output_padding[1] + + 1 + ) + w_out = math.floor( + (w_in - 1) * self.stride[2] + - 2 * self.padding[2] + + self.dilation[2] * (self.kernel_size[2] - 1) + + self.output_padding[2] + + 1 + ) result_shape = input.shape[:-4] + ( c_out, d_out, diff --git a/colossalai/fx/profiler/experimental/profiler_module/dropout.py b/colossalai/fx/profiler/experimental/profiler_module/dropout.py index 417e0ed46863..7361239eb1bd 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/dropout.py +++ b/colossalai/fx/profiler/experimental/profiler_module/dropout.py @@ -1,5 +1,7 @@ from typing import Tuple + import torch + from ..registry import meta_profiler_module diff --git a/colossalai/fx/profiler/experimental/profiler_module/linear.py b/colossalai/fx/profiler/experimental/profiler_module/linear.py index e1ffb6f244d2..71fed3196c13 
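The output-size expressions reformatted above follow the standard convolution and transposed-convolution shape formulas, which can be verified against PyTorch directly:

```python
import math
import torch

# Sanity check of the output-size formulas used in the hunks above
# (standalone sketch, not the profiler code).
x = torch.randn(1, 3, 32, 32)
conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1)
h_out = math.floor(
    (x.shape[-2] + 2 * conv.padding[0] - conv.dilation[0] * (conv.kernel_size[0] - 1) - 1)
    / conv.stride[0] + 1
)
print(h_out, conv(x).shape[-2])               # both 16

deconv = torch.nn.ConvTranspose2d(16, 3, kernel_size=3, stride=2, padding=1, output_padding=1)
h_out_t = (
    (h_out - 1) * deconv.stride[0]
    - 2 * deconv.padding[0]
    + deconv.dilation[0] * (deconv.kernel_size[0] - 1)
    + deconv.output_padding[0]
    + 1
)
print(h_out_t, deconv(conv(x)).shape[-2])     # both 32
```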
100644 --- a/colossalai/fx/profiler/experimental/profiler_module/linear.py +++ b/colossalai/fx/profiler/experimental/profiler_module/linear.py @@ -1,5 +1,7 @@ from typing import Tuple + import torch + from ..registry import meta_profiler_module diff --git a/colossalai/fx/profiler/experimental/profiler_module/normalization.py b/colossalai/fx/profiler/experimental/profiler_module/normalization.py index 49e5e6fa5384..5a64e44947b7 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/normalization.py +++ b/colossalai/fx/profiler/experimental/profiler_module/normalization.py @@ -16,8 +16,12 @@ @meta_profiler_module.register(torch.nn.BatchNorm1d) @meta_profiler_module.register(torch.nn.BatchNorm2d) @meta_profiler_module.register(torch.nn.BatchNorm3d) -def torch_nn_normalize(self: Union[torch.nn.LayerNorm, torch.nn.GroupNorm, torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, - torch.nn.BatchNorm3d], input: torch.Tensor) -> Tuple[int, int]: +def torch_nn_normalize( + self: Union[ + torch.nn.LayerNorm, torch.nn.GroupNorm, torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d + ], + input: torch.Tensor, +) -> Tuple[int, int]: # adopted from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L615 has_affine = self.weight is not None if self.training: @@ -30,6 +34,7 @@ def torch_nn_normalize(self: Union[torch.nn.LayerNorm, torch.nn.GroupNorm, torch try: import apex + meta_profiler_module.register(apex.normalization.FusedLayerNorm)(torch_nn_normalize) meta_profiler_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize) meta_profiler_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize) diff --git a/colossalai/fx/profiler/experimental/profiler_module/pooling.py b/colossalai/fx/profiler/experimental/profiler_module/pooling.py index e429ac3eea28..b3b630b2dee9 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/pooling.py +++ b/colossalai/fx/profiler/experimental/profiler_module/pooling.py @@ -1,5 +1,7 @@ from typing import Tuple + import torch + from ..registry import meta_profiler_module diff --git a/colossalai/fx/profiler/experimental/profiler_module/rnn.py b/colossalai/fx/profiler/experimental/profiler_module/rnn.py index 6e733d6da915..8a4c828dbd27 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/rnn.py +++ b/colossalai/fx/profiler/experimental/profiler_module/rnn.py @@ -1,12 +1,15 @@ -from functools import reduce import operator +from functools import reduce +from typing import Optional, Tuple + import torch + from ..registry import meta_profiler_module -from typing import Optional, Tuple, Union -def _rnn_flops(flops: int, macs: int, module: torch.nn.RNNBase, w_ih: torch.Tensor, - w_hh: torch.Tensor) -> Tuple[int, int]: +def _rnn_flops( + flops: int, macs: int, module: torch.nn.RNNBase, w_ih: torch.Tensor, w_hh: torch.Tensor +) -> Tuple[int, int]: # copied from https://github.com/sovrasov/flops-counter.pytorch/blob/master/ptflops/pytorch_ops.py # matrix matrix mult ih state and internal state @@ -42,12 +45,12 @@ def torch_nn_rnn(self: torch.nn.RNNBase, input: torch.Tensor, hx: Optional[torch flops = 0 macs = 0 for i in range(self.num_layers): - w_ih = self.__getattr__('weight_ih_l' + str(i)) - w_hh = self.__getattr__('weight_hh_l' + str(i)) + w_ih = self.__getattr__("weight_ih_l" + str(i)) + w_hh = self.__getattr__("weight_hh_l" + str(i)) flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh) if self.bias: - b_ih = self.__getattr__('bias_ih_l' + str(i)) - b_hh = 
self.__getattr__('bias_hh_l' + str(i)) + b_ih = self.__getattr__("bias_ih_l" + str(i)) + b_hh = self.__getattr__("bias_hh_l" + str(i)) flops += reduce(operator.mul, b_ih) + reduce(operator.mul, b_hh) flops *= reduce(operator.mul, input.shape[:2]) macs *= reduce(operator.mul, input.shape[:2]) @@ -63,12 +66,12 @@ def torch_nn_rnn(self: torch.nn.RNNBase, input: torch.Tensor, hx: Optional[torch def torch_nn_rnn(self: torch.nn.RNNCellBase, input: torch.Tensor, hx: Optional[torch.Tensor] = None) -> Tuple[int, int]: flops = 0 macs = 0 - w_ih = self.__getattr__('weight_ih_l') - w_hh = self.__getattr__('weight_hh_l') + w_ih = self.__getattr__("weight_ih_l") + w_hh = self.__getattr__("weight_hh_l") flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh) if self.bias: - b_ih = self.__getattr__('bias_ih_l') - b_hh = self.__getattr__('bias_hh_l') + b_ih = self.__getattr__("bias_ih_l") + b_hh = self.__getattr__("bias_hh_l") flops += reduce(operator.mul, b_ih) + reduce(operator.mul, b_hh) flops *= input.shape[0] macs *= input.shape[0] diff --git a/colossalai/fx/profiler/experimental/profiler_module/torch_op.py b/colossalai/fx/profiler/experimental/profiler_module/torch_op.py index d3aed874eb10..06be25246a71 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/torch_op.py +++ b/colossalai/fx/profiler/experimental/profiler_module/torch_op.py @@ -1,7 +1,8 @@ -import operator +from typing import Tuple + import torch + from ..registry import meta_profiler_module -from typing import Optional, Tuple, Union @meta_profiler_module.register(torch.nn.Flatten) diff --git a/colossalai/fx/profiler/experimental/registry.py b/colossalai/fx/profiler/experimental/registry.py index 7d73bce321e4..d47129cd2978 100644 --- a/colossalai/fx/profiler/experimental/registry.py +++ b/colossalai/fx/profiler/experimental/registry.py @@ -1,11 +1,9 @@ class ProfilerRegistry: - def __init__(self, name): self.name = name self.store = {} def register(self, source): - def wrapper(func): self.store[source] = func return func @@ -21,5 +19,5 @@ def has(self, source): return source in self.store -meta_profiler_function = ProfilerRegistry(name='patched_functions_for_meta_profile') -meta_profiler_module = ProfilerRegistry(name='patched_modules_for_meta_profile') +meta_profiler_function = ProfilerRegistry(name="patched_functions_for_meta_profile") +meta_profiler_module = ProfilerRegistry(name="patched_modules_for_meta_profile") diff --git a/colossalai/fx/profiler/experimental/shard_utils.py b/colossalai/fx/profiler/experimental/shard_utils.py index 1e53ed0bf8ec..90e8c3b7cfe4 100644 --- a/colossalai/fx/profiler/experimental/shard_utils.py +++ b/colossalai/fx/profiler/experimental/shard_utils.py @@ -1,8 +1,6 @@ # for PyTorch 1.11 compatibility uses -from typing import Dict, List, Tuple, Union -import torch -from torch.fx import GraphModule, Node +from torch.fx import Node from ..._compatibility import compatibility @@ -19,7 +17,7 @@ def calculate_fwd_in(n: Node) -> bool: Returns: save_fwd_in (bool): the result of `save_fwd_in` """ - return n.meta['save_fwd_in'] + return n.meta["save_fwd_in"] @compatibility(is_backward_compatible=True) @@ -45,4 +43,4 @@ def calculate_fwd_out(n: Node) -> int: Returns: fwd_out (int): the result of `fwd_out` """ - return n.meta['fwd_mem_out'] + return n.meta["fwd_mem_out"] diff --git a/colossalai/fx/profiler/memory_utils.py b/colossalai/fx/profiler/memory_utils.py index 6ccbcb01cdc1..e8eb5f25cb6c 100644 --- a/colossalai/fx/profiler/memory_utils.py +++ b/colossalai/fx/profiler/memory_utils.py @@ -1,11 +1,11 @@ 
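The `ProfilerRegistry` touched above is a plain mapping from a callable or string key to a profiling function. Assuming the module path shown in the diff header, a small usage sketch would be:

```python
from colossalai.fx.profiler.experimental.registry import ProfilerRegistry

# Usage sketch of the registry shown above: keys may be callables or strings
# (string keys are used for built-in operators such as "add" elsewhere in the patch).
demo_registry = ProfilerRegistry(name="demo")

@demo_registry.register("my_op")
def profile_my_op(*args):
    return 0, 0   # (flops, macs)

assert demo_registry.has("my_op")
print(demo_registry.get("my_op"))
```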
from typing import Dict, List, Tuple, Union import torch -from torch.fx import GraphModule, Node +from torch.fx import Node from .._compatibility import compatibility, is_compatible_with_meta -__all__ = ['activation_size', 'parameter_size', 'is_inplace'] +__all__ = ["activation_size", "parameter_size", "is_inplace"] @compatibility(is_backward_compatible=True) @@ -63,6 +63,7 @@ def is_inplace(n: Node): inplace = n.kwargs.get("inplace", False) if is_compatible_with_meta(): from .constants import ALIAS_ATEN + if n.target in ALIAS_ATEN: inplace = True elif n.op == "call_module": diff --git a/colossalai/fx/profiler/opcount.py b/colossalai/fx/profiler/opcount.py index ba090a2ec51b..8fae0f2ecb45 100644 --- a/colossalai/fx/profiler/opcount.py +++ b/colossalai/fx/profiler/opcount.py @@ -173,8 +173,11 @@ def norm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: # Inputs[0] contains the shape of the input. input_shape = inputs[input_arg_index].shape - has_affine = inputs[affine_arg_index].shape is not None if hasattr(inputs[affine_arg_index], - 'shape') else inputs[affine_arg_index] + has_affine = ( + inputs[affine_arg_index].shape is not None + if hasattr(inputs[affine_arg_index], "shape") + else inputs[affine_arg_index] + ) assert 2 <= len(input_shape) <= 5, input_shape # 5 is just a rough estimate flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4) @@ -188,7 +191,7 @@ def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = N training = inputs[-3] assert isinstance(training, bool), "Signature of aten::batch_norm has changed!" if training: - return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore + return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore has_affine = inputs[1].shape is not None input_shape = reduce(operator.mul, inputs[0].shape) return input_shape * (2 if has_affine else 1) @@ -218,15 +221,16 @@ def elementwise_flop(inputs: List[Any], outputs: List[Any]) -> Number: def zero_flop_jit(*args): """ - Count flops for zero flop layers. + Count flops for zero flop layers. 
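`activation_size` above is used throughout the memory accounting to turn a node's outputs into a byte count. A self-contained stand-in for that kind of accounting (a sketch, not the helper in `memory_utils.py`) is:

```python
import torch
from torch.utils._pytree import tree_flatten

# Stand-in for the byte accounting performed by activation_size above:
# flatten the output container and sum numel * element_size over the tensors.
def activation_bytes(out):
    leaves, _ = tree_flatten(out)
    return sum(t.numel() * t.element_size() for t in leaves if isinstance(t, torch.Tensor))

print(activation_bytes((torch.randn(4, 8), torch.randn(2, 10, dtype=torch.float16))))  # 168
```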
""" return 0 -if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse( - torch.__version__) < version.parse('2.0.0'): +if version.parse(torch.__version__) >= version.parse("1.12.0") and version.parse(torch.__version__) < version.parse( + "2.0.0" +): flop_mapping = { - # gemm, gemv and dot + # gemm, gemv and dot aten.mm.default: matmul_flop_jit, aten.mv.default: matmul_flop_jit, aten.dot.default: matmul_flop_jit, @@ -234,13 +238,11 @@ def zero_flop_jit(*args): aten.addmm.default: addmm_flop_jit, aten.bmm.default: bmm_flop_jit, aten.baddbmm.default: baddbmm_flop_jit, - - # convolution + # convolution aten.convolution.default: conv_flop_jit, aten._convolution.default: conv_flop_jit, aten.convolution_backward.default: conv_backward_flop_jit, - - # normalization + # normalization aten.native_batch_norm.default: batchnorm_flop_jit, aten.native_batch_norm_backward.default: batchnorm_flop_jit, aten.cudnn_batch_norm.default: batchnorm_flop_jit, @@ -249,8 +251,7 @@ def zero_flop_jit(*args): aten.native_layer_norm_backward.default: norm_flop_counter(2, 0), aten.native_group_norm.default: norm_flop_counter(2, 0), aten.native_group_norm_backward.default: norm_flop_counter(2, 0), - - # pooling + # pooling aten.avg_pool1d.default: elementwise_flop_counter(1, 0), aten.avg_pool2d.default: elementwise_flop_counter(1, 0), aten.avg_pool2d_backward.default: elementwise_flop_counter(0, 1), @@ -275,7 +276,7 @@ def zero_flop_jit(*args): } elementwise_flop_aten = [ - # basic op + # basic op aten.add.Tensor, aten.add_.Tensor, aten.div.Tensor, @@ -296,8 +297,7 @@ def zero_flop_jit(*args): aten.exp.default, aten.sin.default, aten.cos.default, - - # activation op + # activation op aten.hardswish.default, aten.hardswish_.default, aten.hardswish_backward.default, @@ -320,8 +320,7 @@ def zero_flop_jit(*args): aten.tanh.default, aten.tanh_backward.default, aten.threshold_backward.default, - - # dropout + # dropout aten.native_dropout.default, aten.native_dropout_backward.default, ] @@ -362,7 +361,7 @@ def zero_flop_jit(*args): aten.zero_.default, aten.zeros_like.default, aten.fill_.Scalar, - aten.stack.default + aten.stack.default, ] # yapf: disable for op in zero_flop_aten: diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index c87cd4321d31..97e70db6290e 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -15,7 +15,7 @@ from .opcount import flop_mapping from .tensor import MetaTensor -__all__ = ['profile_function', 'profile_module', 'profile_method'] +__all__ = ["profile_function", "profile_module", "profile_method"] # super-dainiu: this cache should be global, otherwise it cannot # track duplicated tensors between nodes @@ -174,7 +174,6 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G # backward is executed. # Hopefully, this attempt will provide a better estimation of memory. 
class FlopTensor(MetaTensor): - _node: Node = None def __repr__(self): @@ -186,24 +185,24 @@ def __repr__(self): def __torch_dispatch__(cls, func, types, args=(), kwargs=None): args_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, args) kwargs_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, kwargs) - node = subgraph.create_node('call_function', func, args_node, kwargs_node) + node = subgraph.create_node("call_function", func, args_node, kwargs_node) out = super().__torch_dispatch__(func, types, args, kwargs) flop_count[phase] += flop_mapping[func](args, normalize_tuple(out)) - node.meta['phase'] = phase + node.meta["phase"] = phase # super-dainiu: in `nn.MultiheadAttention` this weird thing occurs, # i.e. `Phase.PLACEHOLDER` tensors are aliased and saved during # `Phase.FORWARD` if phase == Phase.FORWARD: if all(map(partial(is_phase, phase=Phase.PLACEHOLDER), node.all_input_nodes)) and func in ALIAS_ATEN: - node.meta['phase'] = Phase.PLACEHOLDER + node.meta["phase"] = Phase.PLACEHOLDER # TODO(yby): specify `saved_tensors` for backward memory estimation - node.meta['saved_tensor'] = [] + node.meta["saved_tensor"] = [] if phase == Phase.BACKWARD: - node.meta['saved_tensor'] = normalize_tuple(out) + node.meta["saved_tensor"] = normalize_tuple(out) def wrap(x): if isinstance(x, MetaTensor): @@ -219,11 +218,14 @@ def wrap(x): x = FlopTensor(x) if is_autogradable(x): x.requires_grad_(True) - x._node = subgraph.create_node('placeholder', - 'placeholder', (subgraph._root,), - name=subgraph._graph_namespace.create_name('input', x._tensor)) - x._node.meta['phase'] = Phase.PLACEHOLDER - x._node.meta['saved_tensor'] = [] + x._node = subgraph.create_node( + "placeholder", + "placeholder", + (subgraph._root,), + name=subgraph._graph_namespace.create_name("input", x._tensor), + ) + x._node.meta["phase"] = Phase.PLACEHOLDER + x._node.meta["saved_tensor"] = [] return x # Basically, we need to detach the args and kwargs from the outer graph. @@ -235,7 +237,7 @@ def pack(x): if isinstance(x, FlopTensor) and not x._tensor.data_ptr() in cache: tensor = x._tensor.detach() tensor.data_ptr = x._tensor.data_ptr - x._node.meta['saved_tensor'] += [tensor] + x._node.meta["saved_tensor"] += [tensor] if not do_not_cache: cache.add(x._tensor.data_ptr()) return x @@ -284,7 +286,7 @@ def unwrap(x): @compatibility(is_backward_compatible=True) -def profile_function(target: 'Target', device: str = 'meta') -> Callable: +def profile_function(target: "Target", device: str = "meta") -> Callable: """ Wrap a `call_function` node or `torch.nn.functional` in order to record the memory cost and FLOPs of the execution. 
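The `pack` helper above records the tensors autograd stashes for backward; it is presumably installed through PyTorch's saved-tensor hooks. The snippet below shows that underlying mechanism in isolation (the profiler additionally de-duplicates by `data_ptr` and attaches sizes to graph nodes):

```python
import torch

# Minimal illustration of saved-tensor hooks, the mechanism the pack/unpack
# helpers above rely on: observe every tensor saved for backward.
saved = []

def pack(t):
    saved.append(t.numel() * t.element_size())
    return t

def unpack(t):
    return t

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    x = torch.randn(16, 16, requires_grad=True)
    y = torch.relu(x).sum()
y.backward()
print(sum(saved), "bytes saved for backward")
```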
@@ -300,7 +302,6 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable: """ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: - # find the grad for parameter in args and kwargs param_size = 0 @@ -316,18 +317,18 @@ def get_param_size(x): # still run the profiling but discard some results regarding `target` global do_not_cache - inplace = kwargs.get('inplace', False) + inplace = kwargs.get("inplace", False) if target in OUTPUT_SAVED_OPS: do_not_cache = True if inplace: do_not_cache = True - kwargs['inplace'] = False - if device == 'meta': + kwargs["inplace"] = False + if device == "meta": out, meta = _profile_meta(func, *args, **kwargs) else: out, meta = _profile_concrete(func, *args, **kwargs) if inplace: - kwargs['inplace'] = True + kwargs["inplace"] = True meta.bwd_mem_tmp = 0 meta.bwd_mem_out = 0 do_not_cache = False @@ -341,7 +342,7 @@ def get_param_size(x): @compatibility(is_backward_compatible=True) -def profile_method(target: 'Target', device: str = 'meta') -> Callable: +def profile_method(target: "Target", device: str = "meta") -> Callable: """ Wrap a `call_method` node record the memory cost and FLOPs of the execution. @@ -349,8 +350,8 @@ def profile_method(target: 'Target', device: str = 'meta') -> Callable: def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: # execute the method and return the result - assert isinstance(target, str), f'{target} instance is not str.' - if device == 'meta': + assert isinstance(target, str), f"{target} instance is not str." + if device == "meta": out, meta = _profile_meta(target, *args, **kwargs) else: out, meta = _profile_concrete(target, *args, **kwargs) @@ -360,7 +361,7 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: @compatibility(is_backward_compatible=True) -def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable: +def profile_module(module: torch.nn.Module, device: str = "meta") -> Callable: """ Wrap a `call_module` node or `torch.nn` in order to record the memory cost and FLOPs of the execution. @@ -376,7 +377,6 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable: """ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: - # calculate parameter size param_size = parameter_size(module) @@ -384,13 +384,13 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: # still run the profiling but discard some results regarding `module`. 
global do_not_cache - inplace = getattr(module, 'inplace', False) + inplace = getattr(module, "inplace", False) if type(module) in OUTPUT_SAVED_MOD: do_not_cache = True if inplace: do_not_cache = True module.inplace = False - if device == 'meta': + if device == "meta": out, meta = _profile_meta(func, *args, **kwargs) else: out, meta = _profile_concrete(func, *args, **kwargs) diff --git a/colossalai/fx/profiler/shard_utils.py b/colossalai/fx/profiler/shard_utils.py index 34feefb4336a..75b7c814f05f 100644 --- a/colossalai/fx/profiler/shard_utils.py +++ b/colossalai/fx/profiler/shard_utils.py @@ -59,9 +59,9 @@ def forward(self, input_2): Returns: bool: Whether the node is a ReLU-like node """ - if n.op == 'call_function': + if n.op == "call_function": return n.target in OUTPUT_SAVED_OPS - elif n.op == 'call_module': + elif n.op == "call_module": return type(n.graph.owning_module.get_submodule(n.target)) in OUTPUT_SAVED_MOD return False diff --git a/colossalai/fx/profiler/tensor.py b/colossalai/fx/profiler/tensor.py index 2ee5e5c47750..7c14b48bdaa1 100644 --- a/colossalai/fx/profiler/tensor.py +++ b/colossalai/fx/profiler/tensor.py @@ -1,13 +1,13 @@ import uuid import torch -from torch.types import _bool, _device, _dtype -from torch.utils._pytree import tree_flatten, tree_map +from torch.types import _device +from torch.utils._pytree import tree_map from .._compatibility import compatibility from .constants import ALIAS_ATEN -__all__ = ['MetaTensor'] +__all__ = ["MetaTensor"] def set_data_ptr(x): @@ -43,12 +43,13 @@ def __new__(cls, elem, fake_device=None): storage_offset=elem.storage_offset(), dtype=elem.dtype, layout=elem.layout, - device=fake_device or (elem.device if elem.device.type != 'meta' else torch.device('cpu')), - requires_grad=elem.requires_grad) # deceive the frontend for aten selections + device=fake_device or (elem.device if elem.device.type != "meta" else torch.device("cpu")), + requires_grad=elem.requires_grad, + ) # deceive the frontend for aten selections r._tensor = elem # ...the real tensor is held as an element on the tensor. 
if not r._tensor.is_meta: - r._tensor = r._tensor.to(torch.device('meta')) + r._tensor = r._tensor.to(torch.device("meta")) # only tensor not on `meta` should be copied to `meta` set_data_ptr(r._tensor) return r @@ -69,15 +70,15 @@ def unwrap(x): x = x._tensor elif isinstance(x, torch.Tensor): fake_device = x.device - x = x.to(torch.device('meta')) + x = x.to(torch.device("meta")) return x args = tree_map(unwrap, args) kwargs = tree_map(unwrap, kwargs) - if 'device' in kwargs: - fake_device = kwargs['device'] - kwargs['device'] = torch.device('meta') + if "device" in kwargs: + fake_device = kwargs["device"] + kwargs["device"] = torch.device("meta") # run aten for backend=CPU but actually on backend=Meta out = func(*args, **kwargs) @@ -93,7 +94,7 @@ def wrap(x): if isinstance(x, torch.Tensor): nonlocal fake_device if not x.is_meta: - x = x.to(torch.device('meta')) + x = x.to(torch.device("meta")) return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x return tree_map(wrap, out) @@ -120,18 +121,18 @@ def replace(x): nonlocal fake_device if isinstance(x, str) or isinstance(x, _device): fake_device = x - return 'meta' + return "meta" return x elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs)) return MetaTensor(elem, fake_device=fake_device) def cpu(self, *args, **kwargs): - if self.device.type == 'cpu': + if self.device.type == "cpu": return self.to(*args, **kwargs) - return self.to(*args, device='cpu', **kwargs) + return self.to(*args, device="cpu", **kwargs) def cuda(self, device=None, non_blocking=False): if device is not None: return self.to(device=device, non_blocking=non_blocking) - return self.to(device='cuda:0', non_blocking=non_blocking) + return self.to(device="cuda:0", non_blocking=non_blocking) diff --git a/colossalai/fx/proxy.py b/colossalai/fx/proxy.py index 7317072c6298..887832223fd6 100644 --- a/colossalai/fx/proxy.py +++ b/colossalai/fx/proxy.py @@ -1,12 +1,11 @@ -import operator -from typing import Any, List, Union +from typing import Any import torch -from torch.fx.proxy import Attribute, Proxy +from torch.fx.proxy import Proxy from colossalai.fx.tracer.meta_patch import meta_patched_function -__all__ = ['ColoProxy'] +__all__ = ["ColoProxy"] class ColoProxy(Proxy): @@ -39,11 +38,12 @@ def has_meta_data(self): return self._meta_data is not None def _assert_meta_data_is_tensor(self): - assert torch.is_tensor( - self._meta_data) and self._meta_data.is_meta, f'Meta data is not a meta tensor for {self.node.name}' + assert ( + torch.is_tensor(self._meta_data) and self._meta_data.is_meta + ), f"Meta data is not a meta tensor for {self.node.name}" def _assert_has_meta_data(self): - assert self._meta_data is not None, f'Meta data is not set for {self.node.name}' + assert self._meta_data is not None, f"Meta data is not set for {self.node.name}" def __len__(self): self._assert_has_meta_data() @@ -62,7 +62,6 @@ def __bool__(self): return self.meta_data def __getattr__(self, k): - return ColoAttribute(self, k) def __contains__(self, key): @@ -92,7 +91,6 @@ def _convert(val): class ColoAttribute(ColoProxy): - def __init__(self, root, attr: str): self.root = root self.attr = attr diff --git a/colossalai/fx/tracer/_meta_trace.py b/colossalai/fx/tracer/_meta_trace.py index 1c5abb81d271..63a7bab654d5 100644 --- a/colossalai/fx/tracer/_meta_trace.py +++ b/colossalai/fx/tracer/_meta_trace.py @@ -39,7 +39,7 @@ class MetaProxy(torch.Tensor): _tensor: torch.Tensor _node: Node - __slots__ = ['_tensor', '_node'] + __slots__ = ["_tensor", 
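`MetaTensor` above holds the real storage on the meta device while advertising a fake device, so shape and dtype propagation run without allocating memory. A usage sketch, assuming the import path from the profiler `__init__` hunk earlier and a PyTorch version new enough for the meta-tensor path, might be:

```python
import torch
from colossalai.fx.profiler import MetaTensor

# Illustrative only: wrap ordinary tensors; their storage is moved to the
# meta device, yet shapes still propagate through regular ops.
x = MetaTensor(torch.empty(64, 128))
w = MetaTensor(torch.empty(256, 128))
y = torch.nn.functional.linear(x, w)     # no real memory is allocated
print(type(y).__name__, y.shape)
```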
"_node"] @staticmethod def __new__(cls, tensor, fake_device=None, placeholder=False, name=None): @@ -51,22 +51,22 @@ def __new__(cls, tensor, fake_device=None, placeholder=False, name=None): dtype=tensor.dtype, layout=tensor.layout, device=fake_device if fake_device is not None else tensor.device, - requires_grad=tensor.requires_grad) # deceive the frontend for aten selections + requires_grad=tensor.requires_grad, + ) # deceive the frontend for aten selections r._tensor = tensor if placeholder: if name is None: - name = 'input' - r._node = graph.create_node('placeholder', - 'placeholder', (graph._root,), - name=namespace.create_name(name, tensor)) + name = "input" + r._node = graph.create_node( + "placeholder", "placeholder", (graph._root,), name=namespace.create_name(name, tensor) + ) # ...the real tensor is held as an element on the tensor. if not r._tensor.is_meta: - r._tensor = r._tensor.to(torch.device('meta')) + r._tensor = r._tensor.to(torch.device("meta")) return r @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - def unwrap(x): nonlocal fake_device if isinstance(x, MetaProxy): @@ -75,21 +75,21 @@ def unwrap(x): # assert not isinstance(x, MetaProxy) elif isinstance(x, torch.Tensor): fake_device = x.device - x = x.to(torch.device('meta')) + x = x.to(torch.device("meta")) return x def get_node(x): - if isinstance(x, torch.Tensor) and not hasattr(x, '_node'): - x = MetaProxy(x, placeholder=True, name='weight') - return x if not hasattr(x, '_node') else x._node + if isinstance(x, torch.Tensor) and not hasattr(x, "_node"): + x = MetaProxy(x, placeholder=True, name="weight") + return x if not hasattr(x, "_node") else x._node args_node = tree_map(get_node, args) kwargs_node = tree_map(get_node, kwargs) - node = graph.create_node('call_function', func, args_node, kwargs_node) + node = graph.create_node("call_function", func, args_node, kwargs_node) - if 'device' in kwargs: - fake_device = kwargs['device'] - kwargs['device'] = torch.device('meta') + if "device" in kwargs: + fake_device = kwargs["device"] + kwargs["device"] = torch.device("meta") args = tree_map(unwrap, args) kwargs = tree_map(unwrap, kwargs) @@ -103,9 +103,12 @@ def wrap(x): if isinstance(x, torch.Tensor): nonlocal fake_device if not x.is_meta: - x = x.to(torch.device('meta')) - return MetaProxy( - x, fake_device=fake_device) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x + x = x.to(torch.device("meta")) + return ( + MetaProxy(x, fake_device=fake_device) + if isinstance(x, torch.Tensor) and not hasattr(x, "_tensor") + else x + ) def set_node(x): x._node = node @@ -125,9 +128,12 @@ def wrap(x): for tensor in normalize_tuple(out): if is_autogradable(tensor) and tensor.requires_grad: - grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance( - tensor, MetaProxy) else torch.empty_like(tensor, device=torch.device('meta')) - torch.autograd.backward(tensor, - MetaProxy(grad, fake_device=tensor.device, placeholder=True), - retain_graph=True) + grad = ( + torch.empty_like(tensor._tensor, device=torch.device("meta")) + if isinstance(tensor, MetaProxy) + else torch.empty_like(tensor, device=torch.device("meta")) + ) + torch.autograd.backward( + tensor, MetaProxy(grad, fake_device=tensor.device, placeholder=True), retain_graph=True + ) return graph diff --git a/colossalai/fx/tracer/_tracer_utils.py b/colossalai/fx/tracer/_tracer_utils.py index e160497a7444..9cf1961d45ff 100644 --- a/colossalai/fx/tracer/_tracer_utils.py +++ 
b/colossalai/fx/tracer/_tracer_utils.py @@ -2,10 +2,10 @@ import torch -from ..proxy import ColoAttribute, ColoProxy -from .meta_patch import meta_patched_function, meta_patched_module +from ..proxy import ColoProxy +from .meta_patch import meta_patched_function -__all__ = ['is_element_in_list', 'extract_meta'] +__all__ = ["is_element_in_list", "extract_meta"] def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]): @@ -21,7 +21,6 @@ def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]): def extract_meta(*args, **kwargs): - def _convert(val): if isinstance(val, ColoProxy): return val.meta_data diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py index 859a19bf6241..84c09109877e 100644 --- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py @@ -1,7 +1,4 @@ -import operator - import torch -import torch.nn.functional as F from ...registry import bias_addition_function, bias_addition_method from .bias_addition_function import LinearBasedBiasFunc @@ -10,13 +7,12 @@ @bias_addition_method.register(torch.Tensor.addbmm) @bias_addition_function.register(torch.addbmm) class Addbmm(LinearBasedBiasFunc): - def extract_kwargs_from_origin_func(self): kwargs = {} - if 'beta' in self.kwargs: - kwargs['beta'] = self.kwargs['beta'] - if 'alpha' in self.kwargs: - kwargs['alpha'] = self.kwargs['alpha'] + if "beta" in self.kwargs: + kwargs["beta"] = self.kwargs["beta"] + if "alpha" in self.kwargs: + kwargs["alpha"] = self.kwargs["alpha"] return kwargs def create_non_bias_func_proxy(self, input_proxy, other_proxy): @@ -25,7 +21,7 @@ def create_non_bias_func_proxy(self, input_proxy, other_proxy): compute the main computation, such as convolution, with bias option banned. """ assert self.substitute_func == torch.bmm - node_kind = 'call_function' + node_kind = "call_function" node_target = self.substitute_func node_args = (input_proxy, other_proxy) @@ -35,10 +31,10 @@ def create_non_bias_func_proxy(self, input_proxy, other_proxy): return non_bias_func_proxy def insert_sum_node(self, input_proxy, sum_dims=0): - ''' + """ This method is used to sum the input_proxy through the sum_dims. 
- ''' - node_kind = 'call_function' + """ + node_kind = "call_function" node_target = torch.sum node_args = (input_proxy, sum_dims) node_kwargs = {} @@ -55,15 +51,15 @@ def generate(self): sum_proxy = self.insert_sum_node(non_bias_linear_func_proxy) kwargs = self.extract_kwargs_from_origin_func() - if 'beta' in kwargs: - beta = kwargs['beta'] + if "beta" in kwargs: + beta = kwargs["beta"] # doing the multiplication with beta if it exists(temp_2 = beta * input) beta_proxy = self.create_mul_node(self.args[0], beta) else: beta_proxy = self.args[0] - if 'alpha' in kwargs: - alpha = kwargs['alpha'] + if "alpha" in kwargs: + alpha = kwargs["alpha"] # doing the multiplication with alpha if it exists(temp_3 = alpha * temp_1) alpha_proxy = self.create_mul_node(alpha, sum_proxy) else: diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py index fe7d8d07aac9..d087b2913005 100644 --- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py @@ -1,7 +1,4 @@ -import operator - import torch -import torch.nn.functional as F from ...registry import bias_addition_function, bias_addition_method from .bias_addition_function import LinearBasedBiasFunc @@ -10,17 +7,16 @@ @bias_addition_method.register(torch.Tensor.addmm) @bias_addition_function.register(torch.addmm) class Addmm(LinearBasedBiasFunc): - def extract_kwargs_from_origin_func(self): kwargs = {} - if 'beta' in self.kwargs: - kwargs['beta'] = self.kwargs['beta'] - if 'alpha' in self.kwargs: - kwargs['alpha'] = self.kwargs['alpha'] + if "beta" in self.kwargs: + kwargs["beta"] = self.kwargs["beta"] + if "alpha" in self.kwargs: + kwargs["alpha"] = self.kwargs["alpha"] return kwargs def transpose_other_operand_for_linear(self, other_proxy): - ''' + """ This method is used to transpose the other operand for linear function. For example: input = torch.rand(3, 4) @@ -30,8 +26,8 @@ def transpose_other_operand_for_linear(self, other_proxy): # To keep the computation graph consistent with the origin computation graph, we need to transpose the m2 # before we call the linear function. 
new_output = torch.linear(m1, m2.transpose(0, 1)) + input - ''' - node_kind = 'call_function' + """ + node_kind = "call_function" node_target = torch.transpose node_args = (other_proxy, 0, 1) node_kwargs = {} @@ -43,14 +39,14 @@ def generate(self): non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[1], transpose_proxy) kwargs = self.extract_kwargs_from_origin_func() - if 'beta' in kwargs: - beta = kwargs['beta'] + if "beta" in kwargs: + beta = kwargs["beta"] beta_proxy = self.create_mul_node(self.args[0], beta) else: beta_proxy = self.args[0] - if 'alpha' in kwargs: - alpha = kwargs['alpha'] + if "alpha" in kwargs: + alpha = kwargs["alpha"] alpha_proxy = self.create_mul_node(alpha, non_bias_linear_func_proxy) else: alpha_proxy = non_bias_linear_func_proxy diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py index 8a3786332c08..42178b7b786e 100644 --- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py @@ -29,7 +29,6 @@ def extract_kwargs_from_origin_func(self): to insert two more operator.mul nodes for the computation graph to compute the final result. """ - pass @abstractmethod def generate(self): @@ -50,7 +49,6 @@ def generate(self): %mul_1 : [#users=1] = call_function[target=operator.mul](args = (2, %linear), kwargs = {}) %add : [#users=1] = call_function[target=operator.add](args = (%mul_1, %mul), kwargs = {}) """ - pass def create_mul_node(self, input_proxy, coefficent): """ @@ -59,7 +57,7 @@ def create_mul_node(self, input_proxy, coefficent): Therefore, we need to use this method insert two more operator.mul nodes for the computation graph to compute the final result. """ - node_kind = 'call_function' + node_kind = "call_function" node_target = operator.mul node_args = ( input_proxy, @@ -82,7 +80,7 @@ def create_non_bias_func_proxy(self, input_proxy, other_proxy): compute the main computation, such as convolution, with bias option banned. """ assert self.substitute_func == torch.nn.functional.linear - node_kind = 'call_function' + node_kind = "call_function" node_target = self.substitute_func node_args = (input_proxy, other_proxy) @@ -96,7 +94,7 @@ def create_bias_addition_proxy(self, non_bias_func_proxy, bias_proxy): This method is used to create the bias_addition_proxy, the node created by this proxy will compute the sum of non_bias_func result and bias with some reshape operation if needed. 
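The `Addmm` patch above rewrites `torch.addmm` into a bias-free `linear` on the transposed operand, two scaling multiplications, and an add. The decomposition can be verified numerically:

```python
import torch

# Numerical check of the decomposition the Addmm patch performs:
# addmm(input, m1, m2, beta, alpha) == beta * input + alpha * linear(m1, m2.T).
inp = torch.randn(3, 5)
m1 = torch.randn(3, 4)
m2 = torch.randn(4, 5)
beta, alpha = 0.5, 2.0

reference = torch.addmm(inp, m1, m2, beta=beta, alpha=alpha)
decomposed = beta * inp + alpha * torch.nn.functional.linear(m1, m2.transpose(0, 1))
print(torch.allclose(reference, decomposed, atol=1e-6))   # True
```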
""" - bias_add_node_kind = 'call_function' + bias_add_node_kind = "call_function" bias_add_node_target = operator.add bias_add_args = (non_bias_func_proxy, bias_proxy) bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {}) diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py index e11ec0a364f1..ed060a350739 100644 --- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py @@ -1,6 +1,3 @@ -import operator - -import torch import torch.nn.functional as F from ...registry import bias_addition_function @@ -9,17 +6,16 @@ @bias_addition_function.register(F.linear) class Linear(LinearBasedBiasFunc): - def extract_kwargs_from_origin_func(self): - assert 'bias' in self.kwargs + assert "bias" in self.kwargs kwargs = {} - if 'bias' in self.kwargs: - kwargs['bias'] = self.kwargs['bias'] + if "bias" in self.kwargs: + kwargs["bias"] = self.kwargs["bias"] return kwargs def generate(self): non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[0], self.args[1]) kwargs = self.extract_kwargs_from_origin_func() - bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, kwargs['bias']) + bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, kwargs["bias"]) return bias_addition_proxy diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py index 591485fdb1ca..19c0e21d7c17 100644 --- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py @@ -27,8 +27,8 @@ def _create_weight_proxy(self): Note: this function will be invoked during module initializing, you should never call this function. """ - weight_node_kind = 'get_attr' - weight_node_target = self.target + '.weight' + weight_node_kind = "get_attr" + weight_node_target = self.target + ".weight" weight_proxy = self.tracer.create_proxy(weight_node_kind, weight_node_target, (), {}) return weight_proxy @@ -39,8 +39,8 @@ def _create_bias_proxy(self): Note: this function will be invoked during module initializing, you should never call this function. """ - bias_node_kind = 'get_attr' - bias_node_target = self.target + '.bias' + bias_node_kind = "get_attr" + bias_node_target = self.target + ".bias" bias_proxy = self.tracer.create_proxy(bias_node_kind, bias_node_target, (), {}) return bias_proxy @@ -54,14 +54,13 @@ def extract_kwargs_from_mod(self): considered during module initializing. However, we need to consider those attributes as kwargs in F.conv2d. """ - pass def create_non_bias_func_proxy(self, input_proxy=None): """ This method is used to create the non_bias_func proxy, the node created by this proxy will compute the main computation, such as convolution, with bias option banned. 
""" - node_kind = 'call_function' + node_kind = "call_function" node_target = self.substitute_func if input_proxy is None: input_proxy = self.args[0] @@ -75,7 +74,7 @@ def create_bias_addition_proxy(self, non_bias_func_proxy, bias_proxy): This method is used to create the bias_addition_proxy, the node created by this proxy will compute the sum of non_bias_func result and bias with some reshape operation if needed. """ - bias_add_node_kind = 'call_function' + bias_add_node_kind = "call_function" bias_add_node_target = operator.add bias_add_args = (non_bias_func_proxy, bias_proxy) bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {}) @@ -100,7 +99,6 @@ def generate(self): %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {}) %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {}) """ - pass module_to_func_dict = { diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py index 4b6c82a74f57..812a141c1eab 100644 --- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py @@ -1,6 +1,5 @@ import torch -import torch.nn.functional as F -from torch.nn.modules.utils import _pair, _reverse_repeat_tuple, _single, _triple +from torch.nn.modules.utils import _pair, _single, _triple from ...registry import bias_addition_module from .bias_addition_module import BiasAdditionModule @@ -10,17 +9,16 @@ @bias_addition_module.register(torch.nn.Conv2d) @bias_addition_module.register(torch.nn.Conv3d) class BiasAdditionConv(BiasAdditionModule): - def extract_kwargs_from_mod(self): root = self.tracer.root conv_module = root.get_submodule(self.target) - kwarg_attributes = ['groups', 'dilation', 'stride'] + kwarg_attributes = ["groups", "dilation", "stride"] non_bias_kwargs = {} for attr_name in kwarg_attributes: if hasattr(conv_module, attr_name): non_bias_kwargs[attr_name] = getattr(conv_module, attr_name) if conv_module.padding_mode != "zeros": - #TODO: non zeros mode requires some extra processing for input + # TODO: non zeros mode requires some extra processing for input conv_type = type(conv_module) if conv_type == "torch.nn.Conv1d": padding_element = _single(0) @@ -28,9 +26,9 @@ def extract_kwargs_from_mod(self): padding_element = _pair(0) elif conv_type == "torch.nn.Conv3d": padding_element = _triple(0) - non_bias_kwargs['padding'] = padding_element + non_bias_kwargs["padding"] = padding_element else: - non_bias_kwargs['padding'] = getattr(conv_module, 'padding') + non_bias_kwargs["padding"] = getattr(conv_module, "padding") return non_bias_kwargs @@ -41,11 +39,12 @@ def create_bias_reshape_proxy(self, dimensions): """ bias_shape = [1] * (dimensions - 1) bias_shape[0] = -1 - bias_reshape_node_kind = 'call_method' - bias_reshape_node_target = 'view' + bias_reshape_node_kind = "call_method" + bias_reshape_node_target = "view" bias_reshape_node_args = (self.bias_proxy, torch.Size(bias_shape)) - bias_reshape_proxy = self.tracer.create_proxy(bias_reshape_node_kind, bias_reshape_node_target, - bias_reshape_node_args, {}) + bias_reshape_proxy = self.tracer.create_proxy( + bias_reshape_node_kind, bias_reshape_node_target, bias_reshape_node_args, {} + ) return bias_reshape_proxy def generate(self): diff --git 
a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py index f6f7b6ddab40..b397f009846c 100644 --- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py @@ -1,5 +1,4 @@ import torch -import torch.nn.functional as F from ...registry import bias_addition_module from .bias_addition_module import BiasAdditionModule @@ -7,7 +6,6 @@ @bias_addition_module.register(torch.nn.Linear) class BiasAdditionLinear(BiasAdditionModule): - def extract_kwargs_from_mod(self): return {} diff --git a/colossalai/fx/tracer/experimental.py b/colossalai/fx/tracer/experimental.py index 22a67d1ceccc..e6e511b72fbb 100644 --- a/colossalai/fx/tracer/experimental.py +++ b/colossalai/fx/tracer/experimental.py @@ -1,4 +1,3 @@ -import enum import functools import inspect import operator @@ -10,7 +9,7 @@ from torch.utils._pytree import tree_map from colossalai.fx import ColoGraphModule, compatibility, is_compatible_with_meta -from colossalai.fx.tracer._tracer_utils import extract_meta, is_element_in_list +from colossalai.fx.tracer._tracer_utils import is_element_in_list from colossalai.fx.tracer.bias_addition_patch import func_to_func_dict, method_to_func_dict, module_to_func_dict from colossalai.fx.tracer.registry import ( bias_addition_function, @@ -24,31 +23,45 @@ from colossalai.fx.profiler import MetaTensor Target = Union[Callable[..., Any], str] -Argument = Optional[Union[Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types - List[Any], # actually Argument - Dict[str, Any], # actually Argument - slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing - 'Node',]] -_CScriptMethod = ['add', 'mul', 'sub', 'div'] +Argument = Optional[ + Union[ + Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types + List[Any], # actually Argument + Dict[str, Any], # actually Argument + slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing + "Node", + ] +] +_CScriptMethod = ["add", "mul", "sub", "div"] _TorchNewMethod = [ - "arange", "zeros", "zeros_like", "ones", "ones_like", "full", "full_like", "empty", "empty_like", "eye", "tensor", - "finfo" + "arange", + "zeros", + "zeros_like", + "ones", + "ones_like", + "full", + "full_like", + "empty", + "empty_like", + "eye", + "tensor", + "finfo", ] _TensorPropertyMethod = ["dtype", "shape", "device", "requires_grad", "grad", "grad_fn", "data"] def _truncate_suffix(s: str): import re - return re.sub(r'_\d+$', '', s) + + return re.sub(r"_\d+$", "", s) def default_device(): - return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') + return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") @compatibility(is_backward_compatible=False) class ColoProxy(Proxy): - def __init__(self, *args, data=None, **kwargs): super().__init__(*args, **kwargs) self._meta_data = data @@ -100,7 +113,7 @@ def __getattr__(self, k): return ColoAttribute(self, k, getattr(self._meta_data, k, None)) def __setitem__(self, key, value): - proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {}) + proxy = self.tracer.create_proxy("call_function", operator.setitem, (self, key, value), {}) proxy.meta_data = self._meta_data return proxy @@ -125,29 +138,28 @@ def ndim(self): @property def device(self): - proxy 
= self.tracer.create_proxy('call_function', getattr, (self, 'device'), {}) + proxy = self.tracer.create_proxy("call_function", getattr, (self, "device"), {}) proxy.meta_data = self.meta_data.device return proxy @property def dtype(self): - proxy = self.tracer.create_proxy('call_function', getattr, (self, 'dtype'), {}) + proxy = self.tracer.create_proxy("call_function", getattr, (self, "dtype"), {}) proxy.meta_data = self.meta_data.dtype return proxy def to(self, *args, **kwargs): - return self.tracer.create_proxy('call_method', 'to', (self, *args), {**kwargs}) + return self.tracer.create_proxy("call_method", "to", (self, *args), {**kwargs}) def cpu(self, *args, **kwargs): - return self.tracer.create_proxy('call_method', 'cpu', (self, *args), {**kwargs}) + return self.tracer.create_proxy("call_method", "cpu", (self, *args), {**kwargs}) def cuda(self, *args, **kwargs): - return self.tracer.create_proxy('call_method', 'cuda', (self, *args), {**kwargs}) + return self.tracer.create_proxy("call_method", "cuda", (self, *args), {**kwargs}) @compatibility(is_backward_compatible=False) class ColoAttribute(ColoProxy): - def __init__(self, root, attr: str, data=None): self.root = root self.attr = attr @@ -160,11 +172,11 @@ def node(self): # the node for attributes is added lazily, since most will just be method calls # which do not rely on the getitem call if self._node is None: - self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node + self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node return self._node def __call__(self, *args, **kwargs): - return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs) + return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs) def __repr__(self): return f"ColoAttribute({self.node.name}, attr={self.attr})" @@ -172,7 +184,6 @@ def __repr__(self): @compatibility(is_backward_compatible=False) class ColoTracer(Tracer): - def __init__(self, trace_act_ckpt: bool = False, *args, **kwargs): super().__init__(*args, **kwargs) self._disable_module_getattr = False @@ -184,24 +195,28 @@ def __init__(self, trace_act_ckpt: bool = False, *args, **kwargs): self.inside_torch_checkpoint_func = False self.act_ckpt_region_count = 0 - def proxy(self, node: Node) -> 'ColoProxy': + def proxy(self, node: Node) -> "ColoProxy": return ColoProxy(node, self) - def create_proxy(self, - kind: str, - target: Target, - args: Tuple[Any, ...], - kwargs: Dict[str, Any], - name: Optional[str] = None, - type_expr: Optional[Any] = None, - proxy_factory_fn: Callable[[Node], 'Proxy'] = None): - + def create_proxy( + self, + kind: str, + target: Target, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + name: Optional[str] = None, + type_expr: Optional[Any] = None, + proxy_factory_fn: Callable[[Node], "Proxy"] = None, + ): proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn) unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p - if kind == 'placeholder': - proxy.meta_data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get( - _truncate_suffix(target), None) - elif kind == 'get_attr': + if kind == "placeholder": + proxy.meta_data = ( + self.meta_args[target] + if target in self.meta_args + else self.concrete_args.get(_truncate_suffix(target), None) + ) + elif kind == "get_attr": self._disable_module_getattr = True try: attr_itr = self.root @@ -211,20 +226,21 @@ def 
create_proxy(self, proxy.meta_data = attr_itr finally: self._disable_module_getattr = False - elif kind == 'call_function': + elif kind == "call_function": proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs)) - elif kind == 'call_method': + elif kind == "call_method": self._disable_module_getattr = True try: - if target == '__call__': + if target == "__call__": proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)) else: if target not in _TensorPropertyMethod: - proxy._meta_data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]), - **tree_map(unwrap_fn, kwargs)) + proxy._meta_data = getattr(unwrap_fn(args[0]), target)( + *tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs) + ) finally: self._disable_module_getattr = False - elif kind == 'call_module': + elif kind == "call_module": mod = self.root.get_submodule(target) self._disable_module_getattr = True try: @@ -238,14 +254,15 @@ def create_node(self, *args, **kwargs) -> Node: if self.inside_torch_checkpoint_func: # annotate the activation checkpoint module - node.meta['activation_checkpoint'] = self.act_ckpt_region_count + node.meta["activation_checkpoint"] = self.act_ckpt_region_count return node - def trace(self, - root: torch.nn.Module, - concrete_args: Optional[Dict[str, torch.Tensor]] = None, - meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph: - + def trace( + self, + root: torch.nn.Module, + concrete_args: Optional[Dict[str, torch.Tensor]] = None, + meta_args: Optional[Dict[str, torch.Tensor]] = None, + ) -> Graph: if meta_args is None: meta_args = {} @@ -260,20 +277,19 @@ def trace(self, # update concrete args with default values non_meta_arg_names = sig_names - meta_arg_names for k, v in sig.parameters.items(): - if k in non_meta_arg_names and \ - k not in concrete_args and \ - v.default is not inspect.Parameter.empty: + if k in non_meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty: concrete_args[k] = v.default # get non concrete arg names concrete_arg_names = set(concrete_args.keys()) - non_concrete_arg_names = sig_names - concrete_arg_names + sig_names - concrete_arg_names def _check_arg_name_valid(names): success, element = is_element_in_list(names, sig_names) if not success: raise KeyError( - f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function") + f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function" + ) _check_arg_name_valid(meta_arg_names) _check_arg_name_valid(concrete_arg_names) @@ -292,7 +308,6 @@ def trace_activation_checkpoint(self, enabled: bool): orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction class PatchedCheckpointFunction(torch.autograd.Function): - @staticmethod def forward(ctx, run_function, preserve_rng_state, *args): # signal that the current tracing occurs within activation checkpoint part @@ -305,7 +320,8 @@ def forward(ctx, run_function, preserve_rng_state, *args): @staticmethod def backward(ctx: Any, *grad_outputs: Any) -> Any: raise NotImplementedError( - "We do not implement the backward pass as we only trace the forward pass.") + "We do not implement the backward pass as we only trace the forward pass." 
+ ) # override the checkpoint function torch.utils.checkpoint.CheckpointFunction = PatchedCheckpointFunction @@ -356,10 +372,13 @@ def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cac if attr_val is p: if n not in parameter_proxy_cache: kwargs = {} - if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters: - kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else - lambda node: ColoProxy(self, node, n, attr_val)) - val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs) # type: ignore[arg-type] + if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters: + kwargs["proxy_factory_fn"] = ( + None + if not self.param_shapes_constant + else lambda node: ColoProxy(self, node, n, attr_val) + ) + val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type] parameter_proxy_cache[n] = val_proxy return parameter_proxy_cache[n] return None @@ -370,8 +389,9 @@ def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cac return maybe_buffer_proxy if isinstance(attr_val, torch.nn.Parameter): - maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(), - parameter_proxy_cache) + maybe_parameter_proxy = maybe_get_proxy_for_attr( + attr_val, self.root.named_parameters(), parameter_proxy_cache + ) if maybe_parameter_proxy is not None: return maybe_parameter_proxy @@ -389,42 +409,41 @@ def symbolic_trace( if meta_args is not None: root.to(default_device()) wrap_fn = lambda x: MetaTensor(x, fake_device=default_device()) if isinstance(x, torch.Tensor) else x - graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root, - concrete_args=concrete_args, - meta_args=tree_map(wrap_fn, meta_args)) + graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace( + root, concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args) + ) root.cpu() else: graph = Tracer().trace(root, concrete_args=concrete_args) else: from .tracer import ColoTracer as OrigColoTracer - graph = OrigColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root, - concrete_args=concrete_args, - meta_args=meta_args) + + graph = OrigColoTracer(trace_act_ckpt=trace_act_ckpt).trace( + root, concrete_args=concrete_args, meta_args=meta_args + ) name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ return ColoGraphModule(root, graph, name) @compatibility(is_backward_compatible=False) class _TorchTensorOverride(object): - def __init__(self, tracer: Tracer): self.overrides = {} self.tracer = tracer def __enter__(self): - def wrap_tensor_method(target): - @functools.wraps(target) def wrapper(*args, **kwargs): is_proxy = any(isinstance(p, ColoProxy) for p in args) | any( - isinstance(p, ColoProxy) for p in kwargs.values()) + isinstance(p, ColoProxy) for p in kwargs.values() + ) if is_proxy: # if the arg is a proxy, then need to record this function called on this proxy # e.g. 
torch.ones(size) where size is an input proxy self.tracer._disable_module_getattr = True try: - proxy = self.tracer.create_proxy('call_function', target, args, kwargs) + proxy = self.tracer.create_proxy("call_function", target, args, kwargs) finally: self.tracer._disable_module_getattr = False return proxy @@ -446,11 +465,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): setattr(torch, name, orig) -def meta_prop_pass(gm: ColoGraphModule, - root: torch.nn.Module, - meta_args: Optional[Dict[str, Any]] = None, - concrete_args: Optional[Dict[str, torch.Tensor]] = None): - +def meta_prop_pass( + gm: ColoGraphModule, + root: torch.nn.Module, + meta_args: Optional[Dict[str, Any]] = None, + concrete_args: Optional[Dict[str, torch.Tensor]] = None, +): if meta_args is None: meta_args = {} @@ -465,36 +485,36 @@ def meta_prop_pass(gm: ColoGraphModule, # update concrete args with default values non_meta_arg_names = sig_names - meta_arg_names for k, v in sig.parameters.items(): - if k in non_meta_arg_names and \ - k not in concrete_args and \ - v.default is not inspect.Parameter.empty: + if k in non_meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty: concrete_args[k] = v.default for node in gm.graph.nodes: - node._meta_data = _meta_data_computing(meta_args, concrete_args, root, node.op, node.target, node.args, - node.kwargs) + node._meta_data = _meta_data_computing( + meta_args, concrete_args, root, node.op, node.target, node.args, node.kwargs + ) def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwargs): unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n - if kind == 'placeholder': + if kind == "placeholder": meta_out = meta_args[target] if target in meta_args else concrete_args.get(_truncate_suffix(target), None) - elif kind == 'get_attr': + elif kind == "get_attr": attr_itr = root atoms = target.split(".") for atom in atoms: attr_itr = getattr(attr_itr, atom) meta_out = attr_itr - elif kind == 'call_function': + elif kind == "call_function": meta_out = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs)) - elif kind == 'call_method': - if target == '__call__': + elif kind == "call_method": + if target == "__call__": meta_out = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)) else: if target not in _TensorPropertyMethod: - meta_out = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]), - **tree_map(unwrap_fn, kwargs)) - elif kind == 'call_module': + meta_out = getattr(unwrap_fn(args[0]), target)( + *tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs) + ) + elif kind == "call_module": mod = root.get_submodule(target) meta_out = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs)) else: @@ -603,26 +623,30 @@ def wrap_fn(n): if kind == "call_function": if bias_addition_function.has(target): if target == torch.nn.functional.linear: - if 'bias' in kwargs and kwargs['bias'] is not None: + if "bias" in kwargs and kwargs["bias"] is not None: function_to_substitute = func_to_func_dict[target] - handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, - function_to_substitute) + handle = bias_addition_function.get(target)( + tracer, target, args_proxy, kwargs_proxy, function_to_substitute + ) else: function_to_substitute = func_to_func_dict[target] - handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, - function_to_substitute) + handle = bias_addition_function.get(target)( + tracer, 
target, args_proxy, kwargs_proxy, function_to_substitute + ) elif bias_addition_function.has(target.__name__): # use name for some builtin op like @ (matmul) function_to_substitute = func_to_func_dict[target] - handle = bias_addition_function.get(target.__name__)(tracer, target, args_proxy, kwargs_proxy, - function_to_substitute) + handle = bias_addition_function.get(target.__name__)( + tracer, target, args_proxy, kwargs_proxy, function_to_substitute + ) elif kind == "call_method": method = getattr(args_metas[0].__class__, target) if bias_addition_method.has(method): function_to_substitute = method_to_func_dict[method] - handle = bias_addition_method.get(method)(tracer, target, args_proxy, kwargs_proxy, - function_to_substitute) + handle = bias_addition_method.get(method)( + tracer, target, args_proxy, kwargs_proxy, function_to_substitute + ) elif kind == "call_module": # if not hasattr(self, "orig_forward"): @@ -631,8 +655,9 @@ def wrap_fn(n): mod_type = type(mod) if bias_addition_module.has(mod_type) and mod.bias is not None: function_to_substitute = module_to_func_dict[mod_type] - handle = bias_addition_module.get(mod_type)(tracer, target, args_proxy, kwargs_proxy, - function_to_substitute) + handle = bias_addition_module.get(mod_type)( + tracer, target, args_proxy, kwargs_proxy, function_to_substitute + ) if handle is not None: handle.generate() diff --git a/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py b/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py index 12c42514895e..75d7b18a067c 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py @@ -5,4 +5,4 @@ @meta_patched_function.register(torch.nn.functional.relu) def torch_nn_func_relu(input, inplace=False): - return torch.empty(input.shape, device='meta') + return torch.empty(input.shape, device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py b/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py index 042b92c5847a..3475f22e3b19 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py @@ -4,7 +4,7 @@ @meta_patched_function.register(torch.matmul) -@meta_patched_function.register('matmul') # for built-in op @ +@meta_patched_function.register("matmul") # for built-in op @ def torch_matmul(input, other, *, out=None): # copied from huggingface.utils.fx d1 = input.dim() @@ -44,8 +44,8 @@ def torch_matmul(input, other, *, out=None): @meta_patched_function.register(torch.abs) def torch_abs(input, *, out=None): - assert out is None, 'out is not supported yet' - return torch.empty(input.shape, device='meta') + assert out is None, "out is not supported yet" + return torch.empty(input.shape, device="meta") @meta_patched_function.register(torch.bmm) @@ -89,7 +89,7 @@ def torch_addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None): @meta_patched_function.register(torch.var_mean) def torch_var_mean(input, dim, unbiased=True, keepdim=False, *, out=None): - assert out is None, 'saving to out is not supported yet' - var = torch.empty(1).squeeze(0).to('meta') - mean = torch.empty(1).squeeze(0).to('meta') + assert out is None, "saving to out is not supported yet" + var = torch.empty(1).squeeze(0).to("meta") + mean = torch.empty(1).squeeze(0).to("meta") return var, mean diff --git a/colossalai/fx/tracer/meta_patch/patched_function/convolution.py 
b/colossalai/fx/tracer/meta_patch/patched_function/convolution.py index 8500e5c82508..26daf32a2afc 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/convolution.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/convolution.py @@ -8,7 +8,6 @@ def _ntuple(n, name="parse"): - def parse(x): if isinstance(x, collections.abc.Iterable): return tuple(x) @@ -24,21 +23,21 @@ def parse(x): def _extract_kwargs(kwargs): - if 'stride' in kwargs: - stride = kwargs['stride'] + if "stride" in kwargs: + stride = kwargs["stride"] else: stride = 1 # TODO: process str type padding - if 'padding' in kwargs: - padding = kwargs['padding'] + if "padding" in kwargs: + padding = kwargs["padding"] else: padding = 0 - if 'dilation' in kwargs: - dilation = kwargs['dilation'] + if "dilation" in kwargs: + dilation = kwargs["dilation"] else: dilation = 1 - if 'output_padding' in kwargs: - output_padding = kwargs['output_padding'] + if "output_padding" in kwargs: + output_padding = kwargs["output_padding"] else: output_padding = 0 @@ -61,7 +60,7 @@ def torch_nn_functional_conv1d(input, weight, **kwargs): c_out, l_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_function.register(torch.nn.functional.conv2d) @@ -82,7 +81,7 @@ def torch_nn_functional_conv2d(input, weight, **kwargs): h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_function.register(torch.nn.functional.conv3d) @@ -105,7 +104,7 @@ def torch_nn_functional_conv3d(input, weight, **kwargs): h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_function.register(torch.nn.functional.conv_transpose1d) @@ -120,13 +119,14 @@ def torch_nn_functional_convtranspose1d(input, weight, **kwargs): kernel_size = weight.shape[2:] l_in = input.shape[-1] c_out = weight.shape[1] - l_out = math.floor((l_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + - output_padding[0] + 1) + l_out = math.floor( + (l_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1 + ) result_shape = input.shape[:-2] + ( c_out, l_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_function.register(torch.nn.functional.conv_transpose2d) @@ -141,16 +141,18 @@ def torch_nn_functional_convtranspose2d(input, weight, **kwargs): kernel_size = weight.shape[2:] h_in, w_in = input.shape[-2:] c_out = weight.shape[1] - h_out = math.floor((h_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + - output_padding[0] + 1) - w_out = math.floor((w_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + - output_padding[1] + 1) + h_out = math.floor( + (h_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1 + ) + w_out = math.floor( + (w_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + output_padding[1] + 1 + ) result_shape = input.shape[:-3] + ( c_out, h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_function.register(torch.nn.functional.conv_transpose3d) @@ -165,16 +167,19 @@ def torch_nn_functional_convtranspose3d(input, weight, **kwargs): kernel_size = weight.shape[2:] d_in, h_in, w_in = input.shape[-3:] c_out = weight.shape[1] - d_out = 
math.floor((d_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + - output_padding[0] + 1) - h_out = math.floor((h_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + - output_padding[1] + 1) - w_out = math.floor((w_in - 1) * stride[2] - 2 * padding[2] + dilation[2] * (kernel_size[2] - 1) + - output_padding[2] + 1) + d_out = math.floor( + (d_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1 + ) + h_out = math.floor( + (h_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + output_padding[1] + 1 + ) + w_out = math.floor( + (w_in - 1) * stride[2] - 2 * padding[2] + dilation[2] * (kernel_size[2] - 1) + output_padding[2] + 1 + ) result_shape = input.shape[:-4] + ( c_out, d_out, h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_function/embedding.py b/colossalai/fx/tracer/meta_patch/patched_function/embedding.py index 6d8d864ea29a..27a79f18590a 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/embedding.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/embedding.py @@ -4,11 +4,7 @@ @meta_patched_function.register(torch.nn.functional.embedding) -def torch_nn_functional_embedding(input, - weight, - padding_idx=None, - max_norm=None, - norm_type=2.0, - scale_grad_by_freq=False, - sparse=False): +def torch_nn_functional_embedding( + input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False +): return torch.empty(*input.shape, weight.shape[-1], device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_function/normalization.py b/colossalai/fx/tracer/meta_patch/patched_function/normalization.py index e9e7eda6159c..8a6214990830 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/normalization.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/normalization.py @@ -5,16 +5,11 @@ @meta_patched_function.register(torch.nn.functional.layer_norm) def torch_nn_func_layernorm(input, normalized_shape, weight=None, bias=None, eps=1e-05): - return torch.empty(input.shape, device='meta') + return torch.empty(input.shape, device="meta") @meta_patched_function.register(torch.nn.functional.batch_norm) -def torch_nn_func_batchnorm(input, - running_mean, - running_var, - weight=None, - bias=None, - training=False, - momentum=0.1, - eps=1e-05): - return torch.empty(input.shape, device='meta') +def torch_nn_func_batchnorm( + input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05 +): + return torch.empty(input.shape, device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py b/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py index 4c171cb10991..7642934a409b 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py @@ -19,9 +19,9 @@ def to_concrete(t): return t def _slice_convert(slice_obj): - attrs = {'start': slice_obj.start, 'stop': slice_obj.stop, 'step': slice_obj.step} + attrs = {"start": slice_obj.start, "stop": slice_obj.stop, "step": slice_obj.step} new_attrs = _slice_attr_convert(attrs) - attr_dict_to_tuple = (new_attrs['start'], new_attrs['stop'], new_attrs['step']) + attr_dict_to_tuple = (new_attrs["start"], new_attrs["stop"], new_attrs["step"]) return slice(*attr_dict_to_tuple) def _slice_attr_convert(attrs): diff 
--git a/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py b/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py index b14ff10ce137..c61e1c4dc9e1 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py @@ -105,14 +105,15 @@ def torch_cat(tensors, dim=None, axis=None, *, out=None): shapes = [t.shape for t in tensors] shape = list(shapes[0]) concatenated_dim = sum(shape[dim] for shape in shapes) - final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1:] + final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1 :] return torch.empty(final_shape, device="meta") @meta_patched_function.register(torch.repeat_interleave) def torch_repeat_interleave(input, repeats, dim=None, output_size=None): - assert isinstance(repeats, int) or isinstance(repeats, torch.Tensor), \ - "Argument 'repeats' should be of type 'torch.Tensor' or 'int'" + assert isinstance(repeats, int) or isinstance( + repeats, torch.Tensor + ), "Argument 'repeats' should be of type 'torch.Tensor' or 'int'" shape = list(input.shape) if dim is not None else [input.numel()] dim = dim if dim is not None else 0 @@ -132,36 +133,36 @@ def torch_tensor_repeat_interleave(self, repeats, dim=None, *, output_size=None) @meta_patched_function.register(torch.roll) def torch_roll(input, shifts, dims=None): - return torch.empty(input.shape, device='meta') + return torch.empty(input.shape, device="meta") @meta_patched_function.register(torch.full) def torch_full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): - assert out is None, 'assigning result to out is not supported yet' - return torch.empty(size, device='meta', dtype=dtype, layout=layout, requires_grad=requires_grad) + assert out is None, "assigning result to out is not supported yet" + return torch.empty(size, device="meta", dtype=dtype, layout=layout, requires_grad=requires_grad) @meta_patched_function.register(torch.max) def torch_max(input, dim=None, keepdim=False, *, out=None): - assert out is None, 'assigning value to out is not supported yet' + assert out is None, "assigning value to out is not supported yet" if dim is not None: if isinstance(dim, int): shape = list(input.shape) shape.pop(dim) if keepdim: shape.insert(dim, 1) - return torch.empty(shape, device='meta', dtype=input.dtype), torch.empty(shape, - device='meta', - dtype=input.dtype) + return torch.empty(shape, device="meta", dtype=input.dtype), torch.empty( + shape, device="meta", dtype=input.dtype + ) elif isinstance(dim, torch.Tensor): # when dim is a 0D or 1D tensor, it will maintain the same shape num_dims = dim.dim() if num_dims in [0, 1]: - return torch.empty_like(input, device='meta') + return torch.empty_like(input, device="meta") else: raise ValueError(f"Expected dim to a 0D or 1D tensor but got {num_dims} dimensions") else: - return torch.empty([], device='meta', dtype=input.dtype) + return torch.empty([], device="meta", dtype=input.dtype) @meta_patched_function.register(torch.Tensor.cpu) diff --git a/colossalai/fx/tracer/meta_patch/patched_module/__init__.py b/colossalai/fx/tracer/meta_patch/patched_module/__init__.py index e28e52585fff..3f40ec2a67ee 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/__init__.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/__init__.py @@ -4,4 +4,4 @@ from .linear import * from .normalization import * from .pooling import * -from .rnn import * \ No newline at end of file +from .rnn import * diff 
--git a/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py b/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py index d03da6588c1c..aa2ede187d37 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py @@ -10,4 +10,4 @@ @meta_patched_module.register(torch.nn.ReLU6) @meta_patched_module.register(torch.nn.PReLU) def torch_nn_non_linear_act(self, input): - return torch.empty(input.shape, device='meta') + return torch.empty(input.shape, device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_module/convolution.py b/colossalai/fx/tracer/meta_patch/patched_module/convolution.py index cf9f3487aac9..35173a68a0be 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/convolution.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/convolution.py @@ -11,13 +11,14 @@ def torch_nn_conv1d(self, input): # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv1d l_in = input.shape[-1] c_out = self.out_channels - l_out = math.floor((l_in + 2 * self.padding[0] - self.dilation[0] * - (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) + l_out = math.floor( + (l_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1 + ) result_shape = input.shape[:-2] + ( c_out, l_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.Conv2d) @@ -26,16 +27,18 @@ def torch_nn_conv2d(self, input): # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv2d h_in, w_in = input.shape[-2:] c_out = self.out_channels - h_out = math.floor((h_in + 2 * self.padding[0] - self.dilation[0] * - (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) - w_out = math.floor((w_in + 2 * self.padding[1] - self.dilation[1] * - (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1) + h_out = math.floor( + (h_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1 + ) + w_out = math.floor( + (w_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1 + ) result_shape = input.shape[:-3] + ( c_out, h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.Conv3d) @@ -44,19 +47,22 @@ def torch_nn_conv3d(self, input): # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv3d d_in, h_in, w_in = input.shape[-3:] c_out = self.out_channels - d_out = math.floor((d_in + 2 * self.padding[0] - self.dilation[0] * - (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) - h_out = math.floor((h_in + 2 * self.padding[1] - self.dilation[1] * - (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1) - w_out = math.floor((w_in + 2 * self.padding[2] - self.dilation[2] * - (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1) + d_out = math.floor( + (d_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1 + ) + h_out = math.floor( + (h_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1 + ) + w_out = math.floor( + (w_in + 2 * self.padding[2] - self.dilation[2] * (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1 + ) result_shape = input.shape[:-4] + ( c_out, d_out, h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return 
torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.ConvTranspose1d) @@ -65,13 +71,18 @@ def torch_nn_convtranspose1d(self, input): # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html l_in = input.shape[-1] c_out = self.out_channels - l_out = math.floor((l_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * - (self.kernel_size[0] - 1) + self.output_padding[0] + 1) + l_out = math.floor( + (l_in - 1) * self.stride[0] + - 2 * self.padding[0] + + self.dilation[0] * (self.kernel_size[0] - 1) + + self.output_padding[0] + + 1 + ) result_shape = input.shape[:-2] + ( c_out, l_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.ConvTranspose2d) @@ -80,16 +91,26 @@ def torch_nn_convtranspose2d(self, input): # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html h_in, w_in = input.shape[-2:] c_out = self.out_channels - h_out = math.floor((h_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * - (self.kernel_size[0] - 1) + self.output_padding[0] + 1) - w_out = math.floor((w_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] * - (self.kernel_size[1] - 1) + self.output_padding[1] + 1) + h_out = math.floor( + (h_in - 1) * self.stride[0] + - 2 * self.padding[0] + + self.dilation[0] * (self.kernel_size[0] - 1) + + self.output_padding[0] + + 1 + ) + w_out = math.floor( + (w_in - 1) * self.stride[1] + - 2 * self.padding[1] + + self.dilation[1] * (self.kernel_size[1] - 1) + + self.output_padding[1] + + 1 + ) result_shape = input.shape[:-3] + ( c_out, h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.ConvTranspose3d) @@ -98,16 +119,31 @@ def torch_nn_convtranspose3d(self, input): # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html d_in, h_in, w_in = input.shape[-3:] c_out = self.out_channels - d_out = math.floor((d_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * - (self.kernel_size[0] - 1) + self.output_padding[0] + 1) - h_out = math.floor((h_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] * - (self.kernel_size[1] - 1) + self.output_padding[1] + 1) - w_out = math.floor((w_in - 1) * self.stride[2] - 2 * self.padding[2] + self.dilation[2] * - (self.kernel_size[2] - 1) + self.output_padding[2] + 1) + d_out = math.floor( + (d_in - 1) * self.stride[0] + - 2 * self.padding[0] + + self.dilation[0] * (self.kernel_size[0] - 1) + + self.output_padding[0] + + 1 + ) + h_out = math.floor( + (h_in - 1) * self.stride[1] + - 2 * self.padding[1] + + self.dilation[1] * (self.kernel_size[1] - 1) + + self.output_padding[1] + + 1 + ) + w_out = math.floor( + (w_in - 1) * self.stride[2] + - 2 * self.padding[2] + + self.dilation[2] * (self.kernel_size[2] - 1) + + self.output_padding[2] + + 1 + ) result_shape = input.shape[:-4] + ( c_out, d_out, h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_module/embedding.py b/colossalai/fx/tracer/meta_patch/patched_module/embedding.py index 999e33b17c1c..f28647e9caa5 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/embedding.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/embedding.py @@ -6,4 +6,4 @@ @meta_patched_module.register(torch.nn.Embedding) def torch_nn_embedding(self, 
input): result_shape = input.shape + (self.embedding_dim,) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_module/linear.py b/colossalai/fx/tracer/meta_patch/patched_module/linear.py index 56f13bf97532..97e6b0e96e83 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/linear.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/linear.py @@ -6,5 +6,7 @@ @meta_patched_module.register(torch.nn.Linear) def torch_nn_linear(self, input): last_dim = input.shape[-1] - assert last_dim == self.in_features, f'Expected hidden size {self.in_features} but got {last_dim} for the torch.nn.Linear patch' + assert ( + last_dim == self.in_features + ), f"Expected hidden size {self.in_features} but got {last_dim} for the torch.nn.Linear patch" return torch.empty(input.shape[:-1] + (self.out_features,), device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_module/normalization.py b/colossalai/fx/tracer/meta_patch/patched_module/normalization.py index c21ff64cf3de..198e72e342b1 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/normalization.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/normalization.py @@ -23,6 +23,7 @@ def torch_nn_normalize(self, input): try: import apex + meta_patched_module.register(apex.normalization.FusedLayerNorm)(torch_nn_normalize) meta_patched_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize) meta_patched_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize) diff --git a/colossalai/fx/tracer/meta_patch/patched_module/pooling.py b/colossalai/fx/tracer/meta_patch/patched_module/pooling.py index 7ce23fbf7ac9..450586d02f8f 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/pooling.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/pooling.py @@ -8,7 +8,7 @@ @meta_patched_module.register(torch.nn.AvgPool1d) def torch_nn_avgpool1d(self, input): num_dim = input.dim() - assert num_dim in [2, 3], f'expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions' + assert num_dim in [2, 3], f"expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions" l_in = input.shape[-1] @@ -25,13 +25,13 @@ def _convert_int_to_list(item): l_out = math.floor((l_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1) result_shape = tuple(input.shape[:-1]) + (l_out,) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.AvgPool2d) def torch_nn_avgpool2d(self, input): num_dim = input.dim() - assert num_dim in [3, 4], f'expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions' + assert num_dim in [3, 4], f"expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions" h_in, w_in = input.shape[-2:] @@ -52,13 +52,13 @@ def _convert_int_to_list(item): h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.AvgPool3d) def torch_nn_avgpool3d(self, input): num_dim = input.dim() - assert num_dim in [4, 5], f'expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions' + assert num_dim in [4, 5], f"expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions" d_in, h_in, w_in = input.shape[-3:] @@ -81,13 +81,13 @@ def _convert_int_to_list(item): h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, 
device="meta") @meta_patched_module.register(torch.nn.MaxPool1d) def torch_nn_maxpool1d(self, input): num_dim = input.dim() - assert num_dim in [2, 3], f'expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions' + assert num_dim in [2, 3], f"expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions" l_in = input.shape[-1] @@ -105,13 +105,13 @@ def _convert_int_to_list(item): l_out = math.floor((l_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1) result_shape = tuple(input.shape[:-1]) + (l_out,) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.MaxPool2d) def torch_nn_maxpool2d(self, input): num_dim = input.dim() - assert num_dim in [3, 4], f'expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions' + assert num_dim in [3, 4], f"expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions" h_in, w_in = input.shape[-2:] @@ -133,13 +133,13 @@ def _convert_int_to_list(item): h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.MaxPool3d) def torch_nn_maxpool3d(self, input): num_dim = input.dim() - assert num_dim in [4, 5], f'expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions' + assert num_dim in [4, 5], f"expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions" d_in, h_in, w_in = input.shape[-3:] @@ -163,7 +163,7 @@ def _convert_int_to_list(item): h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.AdaptiveAvgPool1d) @@ -175,7 +175,7 @@ def torch_nn_adapative_pooling_1d(self, input): else: output_size = self.output_size result_shape = tuple(input.shape[:-1]) + output_size - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.AdaptiveAvgPool2d) @@ -187,7 +187,7 @@ def torch_nn_adapative_pooling_2d(self, input): else: output_size = self.output_size result_shape = tuple(input.shape[:-2]) + output_size - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.AdaptiveAvgPool3d) @@ -199,4 +199,4 @@ def torch_nn_adapative_pooling_3d(self, input): else: output_size = self.output_size result_shape = tuple(input.shape[:-3]) + output_size - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_module/rnn.py b/colossalai/fx/tracer/meta_patch/patched_module/rnn.py index ee15ca34162e..bfb7ed171186 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/rnn.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/rnn.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch from ...registry import meta_patched_module @@ -8,9 +6,11 @@ @meta_patched_module.register(torch.nn.GRU) @meta_patched_module.register(torch.nn.RNN) def torch_nn_rnn(self, input, hx): - assert input.shape[ - -1] == self.input_size, f'Expected input to have input size {self.input_size} but got {input.shape[-1]} for the torch.nn.RNN patch' - assert hx.shape[ - -1] == self.hidden_size, f'Expected hx to have hidden size {self.hidden_size} but got {hx.shape[-1]} for the torch.nn.RNN patch' + assert ( + input.shape[-1] == self.input_size + ), f"Expected 
input to have input size {self.input_size} but got {input.shape[-1]} for the torch.nn.RNN patch" + assert ( + hx.shape[-1] == self.hidden_size + ), f"Expected hx to have hidden size {self.hidden_size} but got {hx.shape[-1]} for the torch.nn.RNN patch" d = 2 if self.bidirectional else 1 return torch.empty(input.shape[:-1] + (self.hidden_size * d,), device="meta"), hx diff --git a/colossalai/fx/tracer/registry.py b/colossalai/fx/tracer/registry.py index 12fc6de73d44..80b3868bb4fe 100644 --- a/colossalai/fx/tracer/registry.py +++ b/colossalai/fx/tracer/registry.py @@ -1,11 +1,9 @@ class PatchRegistry: - def __init__(self, name): self.name = name self.store = {} def register(self, source): - def wrapper(func): self.store[source] = func return func @@ -21,8 +19,8 @@ def has(self, source): return source in self.store -meta_patched_function = PatchRegistry(name='patched_functions_for_meta_execution') -meta_patched_module = PatchRegistry(name='patched_modules_for_meta_execution') -bias_addition_function = PatchRegistry(name='patched_function_for_bias_addition') -bias_addition_module = PatchRegistry(name='patched_module_for_bias_addition') -bias_addition_method = PatchRegistry(name='patched_method_for_bias_addition') +meta_patched_function = PatchRegistry(name="patched_functions_for_meta_execution") +meta_patched_module = PatchRegistry(name="patched_modules_for_meta_execution") +bias_addition_function = PatchRegistry(name="patched_function_for_bias_addition") +bias_addition_module = PatchRegistry(name="patched_module_for_bias_addition") +bias_addition_method = PatchRegistry(name="patched_method_for_bias_addition") diff --git a/colossalai/fx/tracer/tracer.py b/colossalai/fx/tracer/tracer.py index 28965a1b8e74..d9cb587b5d39 100644 --- a/colossalai/fx/tracer/tracer.py +++ b/colossalai/fx/tracer/tracer.py @@ -29,7 +29,7 @@ meta_patched_module, ) -__all__ = ['ColoTracer'] +__all__ = ["ColoTracer"] class TracerType(enum.Enum): @@ -103,7 +103,7 @@ def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, pr if kind == "call_function": if bias_addition_function.has(target): if target == torch.nn.functional.linear: - if 'bias' in kwargs and kwargs['bias'] is not None: + if "bias" in kwargs and kwargs["bias"] is not None: function_to_substitute = func_to_func_dict[target] handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute) else: @@ -160,22 +160,27 @@ def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cac if n not in parameter_proxy_cache: kwargs = {} if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters: - kwargs["proxy_factory_fn"] = (None if not self.param_shapes_constant else - lambda node: ParameterProxy(self, node, n, attr_val)) - val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type] + kwargs["proxy_factory_fn"] = ( + None + if not self.param_shapes_constant + else lambda node: ParameterProxy(self, node, n, attr_val) + ) + val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type] parameter_proxy_cache[n] = val_proxy return parameter_proxy_cache[n] return None if isinstance(attr_val, torch.nn.Parameter): - maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(), - parameter_proxy_cache) + maybe_parameter_proxy = maybe_get_proxy_for_attr( + attr_val, self.root.named_parameters(), parameter_proxy_cache + ) if maybe_parameter_proxy is not None: return maybe_parameter_proxy if self.proxy_buffer_attributes 
and isinstance(attr_val, torch.Tensor): - maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(), - parameter_proxy_cache) + maybe_buffer_proxy = maybe_get_proxy_for_attr( + attr_val, self.root.named_buffers(), parameter_proxy_cache + ) if maybe_buffer_proxy is not None: return maybe_buffer_proxy @@ -190,7 +195,7 @@ def call_module(self, m, forward, args, kwargs): # if a customized or third-party module like apex.normalization.FusedRMSNorm is patched, # we should treat it as leaf module as well if meta_patched_module.has(m.__class__) or self.is_leaf_module(m, module_qualified_name): - return self.create_proxy('call_module', module_qualified_name, args, kwargs) + return self.create_proxy("call_module", module_qualified_name, args, kwargs) else: return forward(*args, **kwargs) @@ -211,7 +216,6 @@ def _configure_tracer_type(self, tracer_type: TracerType): raise ValueError(f"Unrecognized tracer type {tracer_type}") def _meta_data_computing(self, kind, target, args, kwargs): - if kind == "placeholder" and target in self.meta_args and self.meta_args[target].is_meta: meta_out = self.meta_args[target] return meta_out @@ -235,8 +239,9 @@ def _meta_data_computing(self, kind, target, args, kwargs): # Therefore, I need to record the nn.parameter.Parameter attribute for the operation # added by the bias addition manipulation following the get_attr node. convert_to_parameter = False - if target in (torch.transpose, torch.reshape) and isinstance(args_metas[0], - torch.nn.parameter.Parameter): + if target in (torch.transpose, torch.reshape) and isinstance( + args_metas[0], torch.nn.parameter.Parameter + ): convert_to_parameter = True # fetch patched function if meta_patched_function.has(target): @@ -309,10 +314,12 @@ def _meta_data_computing(self, kind, target, args, kwargs): return meta_out - def trace(self, - root: nn.Module, - concrete_args: Optional[Dict[str, Tensor]] = None, - meta_args: Optional[Dict[str, Tensor]] = None) -> Graph: + def trace( + self, + root: nn.Module, + concrete_args: Optional[Dict[str, Tensor]] = None, + meta_args: Optional[Dict[str, Tensor]] = None, + ) -> Graph: """ Trace the forward computation graph using `torch.fx.Tracer`. This tracer enables data-dependent control flow. 
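The `trace()` hunks around this point show that `ColoTracer.trace` accepts both `concrete_args` and `meta_args`, and (in the checks that follow) asserts that `meta_args` hold meta tensors while `concrete_args` do not. As a minimal sketch of how such a tracer is typically invoked — not part of this patch; the import path and the toy `TinyNet` module are assumptions for illustration:

```python
# Illustrative only: trace a toy module with shape-only (meta) inputs.
# Assumption: ColoTracer is importable from colossalai.fx.tracer.
import torch
import torch.nn as nn

from colossalai.fx.tracer import ColoTracer


class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)


tracer = ColoTracer()
# meta_args must contain meta tensors (shape/dtype only, no data), matching the
# `_check_kwargs(meta_args, should_be_meta=True)` assertion in the hunk below;
# the key "x" matches the name of the forward argument.
graph = tracer.trace(TinyNet(), meta_args={"x": torch.empty(8, 4, device="meta")})
print(graph)
```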
@@ -341,9 +348,7 @@ def trace(self, # update concrete args with default values non_meta_arg_names = sig_names - meta_arg_names for k, v in sig.parameters.items(): - if k in non_meta_arg_names and \ - k not in concrete_args and \ - v.default is not inspect.Parameter.empty: + if k in non_meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty: concrete_args[k] = v.default # get non concrete arg names @@ -354,7 +359,8 @@ def _check_arg_name_valid(names): success, element = is_element_in_list(names, sig_names) if not success: raise KeyError( - f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function") + f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function" + ) _check_arg_name_valid(meta_arg_names) _check_arg_name_valid(concrete_arg_names) @@ -363,11 +369,13 @@ def _check_arg_name_valid(names): def _check_kwargs(kwargs, should_be_meta: bool): for k, v in kwargs.items(): if not should_be_meta: - assert not torch.is_tensor(v) or not v.is_meta, \ - f'Expected the {k} not to be a meta tensor, please check the args passed to the tracer' + assert ( + not torch.is_tensor(v) or not v.is_meta + ), f"Expected the {k} not to be a meta tensor, please check the args passed to the tracer" else: - assert v.is_meta == should_be_meta, \ - f'Expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer' + assert ( + v.is_meta == should_be_meta + ), f"Expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer" _check_kwargs(concrete_args, should_be_meta=False) _check_kwargs(meta_args, should_be_meta=True) @@ -442,7 +450,6 @@ def trace_activation_checkpoint(self, enabled: bool): orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction class PatchedCheckpointFunction(torch.autograd.Function): - @staticmethod def forward(ctx, run_function, preserve_rng_state, *args): # signal that the current tracing occurs within activation checkpoint part @@ -455,7 +462,8 @@ def forward(ctx, run_function, preserve_rng_state, *args): @staticmethod def backward(ctx: Any, *grad_outputs: Any) -> Any: raise NotImplementedError( - "We do not implement the backward pass as we only trace the forward pass.") + "We do not implement the backward pass as we only trace the forward pass." 
+ ) # override the checkpoint function torch.utils.checkpoint.CheckpointFunction = PatchedCheckpointFunction @@ -470,12 +478,11 @@ def create_node(self, *args, **kwargs) -> Node: if self.inside_torch_checkpoint_func: # annotate the activation checkpoint module - node.meta['activation_checkpoint'] = self.act_ckpt_region_count + node.meta["activation_checkpoint"] = self.act_ckpt_region_count return node def wrap_tensor_constructor_method(target): - def look_for_proxy(*args, **kwargs): # find in pos vars for arg in args: @@ -518,12 +525,10 @@ def wrapper(*args, **kwargs): for method in magic_methods: def _scope(method): - def impl(*args, **kwargs): - tracer = args[0].tracer target = getattr(operator, method) - proxy = tracer.create_proxy('call_function', target, args, kwargs) + proxy = tracer.create_proxy("call_function", target, args, kwargs) if not isinstance(proxy, ColoProxy): meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs) proxy = ColoProxy(proxy.node) @@ -542,7 +547,7 @@ def _define_reflectable(orig_method_name): def impl(self, rhs): target = getattr(operator, orig_method_name) - proxy = self.tracer.create_proxy('call_function', target, (rhs, self), {}) + proxy = self.tracer.create_proxy("call_function", target, (rhs, self), {}) if not isinstance(proxy, ColoProxy): meta_out = compute_meta_data_for_functions_proxy(target, *(rhs, self), {}) proxy = ColoProxy(proxy.node) diff --git a/colossalai/inference/tensor_parallel/__init__.py b/colossalai/inference/tensor_parallel/__init__.py index e467b4c73e6b..112b920ba158 100644 --- a/colossalai/inference/tensor_parallel/__init__.py +++ b/colossalai/inference/tensor_parallel/__init__.py @@ -1,4 +1,4 @@ from .engine import TPInferEngine from .kvcache_manager import MemoryManager -__all__ = ['MemoryManager', 'TPInferEngine'] +__all__ = ["MemoryManager", "TPInferEngine"] diff --git a/colossalai/inference/tensor_parallel/batch_infer_state.py b/colossalai/inference/tensor_parallel/batch_infer_state.py index 2bff9317283e..ac185f1b6529 100644 --- a/colossalai/inference/tensor_parallel/batch_infer_state.py +++ b/colossalai/inference/tensor_parallel/batch_infer_state.py @@ -1,6 +1,5 @@ # might want to consider combine with InferenceConfig in colossalai/ppinference/inference_config.py later from dataclasses import dataclass -from typing import Any import torch @@ -31,7 +30,7 @@ class BatchInferState: decode_mem_index: torch.Tensor = None decode_layer_id: int = None - device: torch.device = torch.device('cuda') + device: torch.device = torch.device("cuda") @property def total_token_num(self): @@ -43,13 +42,15 @@ def set_cache_manager(self, manager: MemoryManager): self.cache_manager = manager @staticmethod - def init_block_loc(b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, - alloc_mem_index: torch.Tensor): - """ in-place update block loc mapping based on the sequence length of the inputs in current bath""" + def init_block_loc( + b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, alloc_mem_index: torch.Tensor + ): + """in-place update block loc mapping based on the sequence length of the inputs in current bath""" start_index = 0 seq_len_numpy = seq_len.cpu().numpy() for i, cur_seq_len in enumerate(seq_len_numpy): - b_loc[i, max_len_in_batch - cur_seq_len:max_len_in_batch] = alloc_mem_index[start_index:start_index + - cur_seq_len] + b_loc[i, max_len_in_batch - cur_seq_len : max_len_in_batch] = alloc_mem_index[ + start_index : start_index + cur_seq_len + ] start_index += cur_seq_len return diff --git 
a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index a5a55702ade0..1335f13d66b8 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, List, Optional, Union import torch import torch.nn as nn @@ -15,7 +15,7 @@ DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 -_supported_models = ['LlamaForCausalLM', 'LlamaModel', 'BloomForCausalLM'] +_supported_models = ["LlamaForCausalLM", "LlamaModel", "BloomForCausalLM"] class TPInferEngine: @@ -39,14 +39,16 @@ class TPInferEngine: >>> outputs = infer_engine.generate(input_ids, **generate_kwargs) """ - def __init__(self, - model: nn.Module, - shard_config: ShardConfig, - max_batch_size: int, - max_input_len: int, - max_output_len: int, - dtype: torch.dtype = torch.float16, - device: str = 'cuda') -> None: + def __init__( + self, + model: nn.Module, + shard_config: ShardConfig, + max_batch_size: int, + max_input_len: int, + max_output_len: int, + dtype: torch.dtype = torch.float16, + device: str = "cuda", + ) -> None: self.max_batch_size = max_batch_size self.max_input_len = max_input_len self.max_output_len = max_output_len @@ -63,7 +65,7 @@ def __init__(self, self.head_num = model.config.num_attention_heads self.layer_num = model.config.num_hidden_layers - self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config + self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config self.cache_manager = None self.shard_config = shard_config @@ -74,9 +76,10 @@ def __init__(self, def _init_manager(self) -> None: assert self.tp_size >= 1, "TP size not initialized without providing a valid ShardConfig" assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}" - self.head_num //= self.tp_size # update sharded number of heads - self.cache_manager = MemoryManager(self.max_total_token_num, self.dtype, self.head_num, self.head_dim, - self.layer_num) + self.head_num //= self.tp_size # update sharded number of heads + self.cache_manager = MemoryManager( + self.max_total_token_num, self.dtype, self.head_num, self.head_dim, self.layer_num + ) def _optimize_model(self, model: nn.Module) -> None: """ @@ -90,7 +93,7 @@ def _optimize_model(self, model: nn.Module) -> None: self._shard_model_by(shardformer, model) def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) -> ShardConfig: - """ Prepare the engine with a given ShardConfig. + """Prepare the engine with a given ShardConfig. Args: shard_config (ShardConfig): shard config given to specify settings of the engine. @@ -118,9 +121,10 @@ def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) return shard_config def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None: - """ Shard original model by the given ShardFormer and store the sharded model. 
""" - assert self.tp_size == shardformer.shard_config.tensor_parallel_size, \ - "Discrepancy between the tp size of TPInferEngine and the tp size of shard config" + """Shard original model by the given ShardFormer and store the sharded model.""" + assert ( + self.tp_size == shardformer.shard_config.tensor_parallel_size + ), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config" model_name = model.__class__.__name__ assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference." policy = get_autopolicy(model, inference_only=True) @@ -147,7 +151,7 @@ def generate(self, input_tokens: Union[BatchEncoding, dict, list, torch.Tensor], for t in input_tokens: if torch.is_tensor(input_tokens[t]): input_tokens[t] = input_tokens[t].cuda() - if 'max_new_tokens' not in generate_kwargs: + if "max_new_tokens" not in generate_kwargs: generate_kwargs.update(max_new_tokens=self.max_output_len) return self._generate_by_set_infer_state(input_tokens, **generate_kwargs) @@ -176,18 +180,18 @@ def prepare_batch_state(self, inputs) -> BatchInferState: attention_mask = None if isinstance(inputs, (BatchEncoding, dict)): - input_ids_list = inputs['input_ids'] - attention_mask = inputs['attention_mask'] + input_ids_list = inputs["input_ids"] + attention_mask = inputs["attention_mask"] else: input_ids_list = inputs - if isinstance(input_ids_list[0], int): # for a single input + if isinstance(input_ids_list[0], int): # for a single input input_ids_list = [input_ids_list] attention_mask = [attention_mask] if attention_mask is not None else attention_mask batch_size = len(input_ids_list) - seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device='cuda') - seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device='cuda') + seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda") start_index = 0 max_len_in_batch = -1 @@ -210,10 +214,10 @@ def prepare_batch_state(self, inputs) -> BatchInferState: seq_start_indexes[i] = start_index start_index += curr_seq_len max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch - block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device='cuda') + block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device="cuda") batch_infer_state = BatchInferState(batch_size, max_len_in_batch) - batch_infer_state.seq_len = seq_lengths.to('cuda') - batch_infer_state.start_loc = seq_start_indexes.to('cuda') + batch_infer_state.seq_len = seq_lengths.to("cuda") + batch_infer_state.start_loc = seq_start_indexes.to("cuda") batch_infer_state.block_loc = block_loc batch_infer_state.decode_layer_id = 0 batch_infer_state.past_key_values_len = 0 @@ -248,7 +252,7 @@ def _generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch model = self.model.model elif isinstance(model, BloomForCausalLM): model = self.model.transformer - setattr(model, 'infer_state', batch_infer_state) + setattr(model, "infer_state", batch_infer_state) outputs = self.model.generate(**input_tokens, **generate_kwargs, early_stopping=False) @@ -262,14 +266,15 @@ def _generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch # as an arg into model.forward. # It requires rewriting model generate and replacing model forward. 
@torch.no_grad() - def _generate_by_pass_infer_state(self, - input_tokens, - max_out_length: int, - generation_config: Optional[GenerationConfig] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, - **model_kwargs) -> torch.Tensor: - + def _generate_by_pass_infer_state( + self, + input_tokens, + max_out_length: int, + generation_config: Optional[GenerationConfig] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, + **model_kwargs, + ) -> torch.Tensor: raise NotImplementedError("generate by passing BatchInferState is not implemented.") # might want to use in rewritten generate method: use after model.forward diff --git a/colossalai/inference/tensor_parallel/kvcache_manager.py b/colossalai/inference/tensor_parallel/kvcache_manager.py index 274c01841279..e74a3a491a7b 100644 --- a/colossalai/inference/tensor_parallel/kvcache_manager.py +++ b/colossalai/inference/tensor_parallel/kvcache_manager.py @@ -19,13 +19,15 @@ class MemoryManager: device: device used to store the key and value cache """ - def __init__(self, - size: int, - dtype: torch.dtype, - head_num: int, - head_dim: int, - layer_num: int, - device: torch.device = torch.device('cuda')): + def __init__( + self, + size: int, + dtype: torch.dtype, + head_num: int, + head_dim: int, + layer_num: int, + device: torch.device = torch.device("cuda"), + ): self.logger = logging.get_logger(__name__) self.available_size = size self.past_key_values_length = 0 @@ -33,13 +35,13 @@ def __init__(self, self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num) def _init_mem_states(self, size, device): - """ Initialize tensors used to manage memory states """ + """Initialize tensors used to manage memory states""" self.mem_state = torch.ones((size,), dtype=torch.bool, device=device) self.mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=device) self.indexes = torch.arange(0, size, dtype=torch.long, device=device) def _init_kv_buffers(self, size, device, dtype, head_num, head_dim, layer_num): - """ Initialize key buffer and value buffer on specified device """ + """Initialize key buffer and value buffer on specified device""" self.key_buffer = [ torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num) ] @@ -49,10 +51,9 @@ def _init_kv_buffers(self, size, device, dtype, head_num, head_dim, layer_num): @torch.no_grad() def alloc(self, required_size): - """ allocate space of required_size by providing indexes representing available physical spaces """ + """allocate space of required_size by providing indexes representing available physical spaces""" if required_size > self.available_size: - self.logger.warning(f"No enough cache: required_size {required_size} " - f"left_size {self.available_size}") + self.logger.warning(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}") return None torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum) select_index = torch.logical_and(self.mem_cum_sum <= required_size, self.mem_state == 1) @@ -63,23 +64,25 @@ def alloc(self, required_size): @torch.no_grad() def alloc_contiguous(self, required_size): - """ allocate contiguous space of required_size """ + """allocate contiguous space of required_size""" if required_size > self.available_size: - self.logger.warning(f"No enough cache: required_size {required_size} " - f"left_size 
{self.available_size}") + self.logger.warning(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}") return None torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum) sum_size = len(self.mem_cum_sum) - loc_sums = self.mem_cum_sum[required_size - 1:] - self.mem_cum_sum[0:sum_size - required_size + - 1] + self.mem_state[0:sum_size - - required_size + 1] - can_used_loc = self.indexes[0:sum_size - required_size + 1][loc_sums == required_size] + loc_sums = ( + self.mem_cum_sum[required_size - 1 :] + - self.mem_cum_sum[0 : sum_size - required_size + 1] + + self.mem_state[0 : sum_size - required_size + 1] + ) + can_used_loc = self.indexes[0 : sum_size - required_size + 1][loc_sums == required_size] if can_used_loc.shape[0] == 0: - self.logger.info(f"No enough contiguous cache: required_size {required_size} " - f"left_size {self.available_size}") + self.logger.info( + f"No enough contiguous cache: required_size {required_size} " f"left_size {self.available_size}" + ) return None start_loc = can_used_loc[0] - select_index = self.indexes[start_loc:start_loc + required_size] + select_index = self.indexes[start_loc : start_loc + required_size] self.mem_state[select_index] = 0 self.available_size -= len(select_index) start = start_loc.item() @@ -88,13 +91,13 @@ def alloc_contiguous(self, required_size): @torch.no_grad() def free(self, free_index): - """ free memory by updating memory states based on given indexes """ + """free memory by updating memory states based on given indexes""" self.available_size += free_index.shape[0] self.mem_state[free_index] = 1 @torch.no_grad() def free_all(self): - """ free all memory by updating memory states """ + """free all memory by updating memory states""" self.available_size = len(self.mem_state) self.mem_state[:] = 1 self.past_key_values_length = 0 diff --git a/colossalai/inference/tensor_parallel/modeling/__init__.py b/colossalai/inference/tensor_parallel/modeling/__init__.py index 7a98b033f37e..27cec5452ece 100644 --- a/colossalai/inference/tensor_parallel/modeling/__init__.py +++ b/colossalai/inference/tensor_parallel/modeling/__init__.py @@ -1,4 +1,4 @@ from .bloom import BloomInferenceForwards from .llama import LlamaInferenceForwards -__all__ = ['BloomInferenceForwards', 'LlamaInferenceForwards'] +__all__ = ["BloomInferenceForwards", "LlamaInferenceForwards"] diff --git a/colossalai/inference/tensor_parallel/modeling/bloom.py b/colossalai/inference/tensor_parallel/modeling/bloom.py index ba5eadc92be8..27a26caabefa 100644 --- a/colossalai/inference/tensor_parallel/modeling/bloom.py +++ b/colossalai/inference/tensor_parallel/modeling/bloom.py @@ -1,6 +1,6 @@ import math import warnings -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch import torch.distributed as dist @@ -31,17 +31,17 @@ def generate_alibi(n_head, dtype=torch.float16): """ def get_slopes_power_of_2(n): - start = 2**(-(2**-(math.log2(n) - 3))) + start = 2 ** (-(2 ** -(math.log2(n) - 3))) return [start * start**i for i in range(n)] def get_slopes(n): if math.log2(n).is_integer(): return get_slopes_power_of_2(n) else: - closest_power_of_2 = 2**math.floor(math.log2(n)) + closest_power_of_2 = 2 ** math.floor(math.log2(n)) slopes_power_of_2 = get_slopes_power_of_2(closest_power_of_2) slopes_double = get_slopes(2 * closest_power_of_2) - slopes_combined = slopes_power_of_2 + slopes_double[0::2][:n - closest_power_of_2] + slopes_combined = slopes_power_of_2 + slopes_double[0::2][: n - 
closest_power_of_2] return slopes_combined slopes = get_slopes(n_head) @@ -72,7 +72,6 @@ def bloom_model_forward( infer_state: Optional[BatchInferState] = None, **deprecated_arguments, ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: - logger = logging.get_logger(__name__) if deprecated_arguments.pop("position_ids", False) is not False: @@ -86,8 +85,9 @@ def bloom_model_forward( raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -122,14 +122,15 @@ def bloom_model_forward( if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) use_cache = False # NOTE determine if BatchInferState is passed in via arg # if not, get the attr binded to the model # We might wantto remove setattr later if infer_state is None: - assert hasattr(self, 'infer_state') + assert hasattr(self, "infer_state") infer_state = self.infer_state # Compute alibi tensor: check build_alibi_tensor documentation @@ -146,10 +147,11 @@ def bloom_model_forward( if use_cache and seq_length != 1: # prefill stage - infer_state.is_context_stage = True # set prefill stage, notify attention layer + infer_state.is_context_stage = True # set prefill stage, notify attention layer infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) - BatchInferState.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length, - infer_state.context_mem_index) + BatchInferState.init_block_loc( + infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index + ) else: infer_state.is_context_stage = False alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) @@ -182,8 +184,11 @@ def bloom_model_forward( # alibi = generate_alibi(self.num_heads).contiguous().cuda() tp_size = dist.get_world_size() curr_tp_rank = dist.get_rank() - alibi = generate_alibi(self.num_heads * tp_size).contiguous()[curr_tp_rank * self.num_heads:(curr_tp_rank + 1) * - self.num_heads].cuda() + alibi = ( + generate_alibi(self.num_heads * tp_size) + .contiguous()[curr_tp_rank * self.num_heads : (curr_tp_rank + 1) * self.num_heads] + .cuda() + ) causal_mask = self._prepare_attn_mask( attention_mask, input_shape=(batch_size, seq_length), @@ -197,7 +202,6 @@ def bloom_model_forward( if self.gradient_checkpointing and self.training: # NOTE: currently our KV cache manager does not handle this condition def create_custom_forward(module): - def custom_forward(*inputs): # None for past_key_value return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) @@ -250,32 +254,34 @@ def custom_forward(*inputs): return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=presents, # should always be (None, None, ..., None) + past_key_values=presents, # should always be (None, None, 
..., None) hidden_states=all_hidden_states, attentions=all_self_attentions, ) @staticmethod - def bloom_for_causal_lm_forward(self: BloomForCausalLM, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - infer_state: Optional[BatchInferState] = None, - **deprecated_arguments): + def bloom_for_causal_lm_forward( + self: BloomForCausalLM, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + infer_state: Optional[BatchInferState] = None, + **deprecated_arguments, + ): r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` """ - logger = logging.get_logger(__name__) + logging.get_logger(__name__) if deprecated_arguments.pop("position_ids", False) is not False: # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` @@ -289,17 +295,19 @@ def bloom_for_causal_lm_forward(self: BloomForCausalLM, return_dict = return_dict if return_dict is not None else self.config.use_return_dict - transformer_outputs = BloomInferenceForwards.bloom_model_forward(self.transformer, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - infer_state=infer_state) + transformer_outputs = BloomInferenceForwards.bloom_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + infer_state=infer_state, + ) hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states) @@ -314,8 +322,9 @@ def bloom_for_causal_lm_forward(self: BloomForCausalLM, batch_size, seq_length, vocab_size = shift_logits.shape # Flatten the tokens loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size), - shift_labels.view(batch_size * seq_length)) + loss = loss_fct( + shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length) + ) if not return_dict: output = (lm_logits,) + transformer_outputs[1:] @@ -353,11 +362,13 @@ def 
bloom_for_causal_lm_prepare_inputs_for_generation( else: model_inputs = {"input_ids": input_ids} - model_inputs.update({ - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - }) + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) return model_inputs @staticmethod @@ -416,7 +427,7 @@ def bloom_block_forward( else: outputs = (output,) + outputs[1:] - return outputs # hidden_states, present, attentions + return outputs # hidden_states, present, attentions @staticmethod def bloom_attention_forward( @@ -431,20 +442,19 @@ def bloom_attention_forward( output_attentions: bool = False, infer_state: Optional[BatchInferState] = None, ): - - fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] # 3 x [batch_size, seq_length, num_heads, head_dim] (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) batch_size, q_length, H, D_HEAD = query_layer.shape - k = key_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1 - v = value_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1 + k = key_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1 + v = value_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1 mem_manager = infer_state.cache_manager layer_id = infer_state.decode_layer_id - if layer_id == 0: # once per model.forward - infer_state.cache_manager.past_key_values_length += q_length # += 1 + if layer_id == 0: # once per model.forward + infer_state.cache_manager.past_key_values_length += q_length # += 1 if infer_state.is_context_stage: # context process @@ -471,9 +481,11 @@ def bloom_attention_forward( if infer_state.decode_is_contiguous: # if decode is contiguous, then we copy to key cache and value cache in cache manager directly cache_k = infer_state.cache_manager.key_buffer[layer_id][ - infer_state.decode_mem_start:infer_state.decode_mem_end, :, :] + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] cache_v = infer_state.cache_manager.value_buffer[layer_id][ - infer_state.decode_mem_start:infer_state.decode_mem_end, :, :] + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] cache_k.copy_(k) cache_v.copy_(v) else: @@ -486,8 +498,17 @@ def bloom_attention_forward( b_loc = infer_state.block_loc b_seq_len = infer_state.seq_len output = torch.empty_like(q) - token_attention_fwd(q, mem_manager.key_buffer[layer_id], mem_manager.value_buffer[layer_id], output, b_loc, - b_start_loc, b_seq_len, infer_state.cache_manager.past_key_values_length, alibi) + token_attention_fwd( + q, + mem_manager.key_buffer[layer_id], + mem_manager.value_buffer[layer_id], + output, + b_loc, + b_start_loc, + b_seq_len, + infer_state.cache_manager.past_key_values_length, + alibi, + ) context_layer = output.view(batch_size, q_length, H * D_HEAD) @@ -504,8 +525,8 @@ def bloom_attention_forward( output_tensor = torch.zeros_like(context_layer) for i in range(self.pretraining_tp): output_tensor = output_tensor + F.linear( - context_layer[:, :, int(i * slices):int((i + 1) * slices)], - self.dense.weight[:, int(i * slices):int((i + 1) * slices)], + context_layer[:, :, int(i * slices) : int((i + 1) * slices)], + self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], ) else: output_tensor = 
self.dense(context_layer) diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index 07b73a6f4ca6..4795162f1980 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -1,6 +1,5 @@ from typing import List, Optional, Tuple -import numpy as np import torch from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm @@ -15,6 +14,7 @@ try: from vllm import layernorm_ops, pos_encoding_ops + rms_norm = layernorm_ops.rms_norm rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox HAS_VLLM_KERNERL = True @@ -29,17 +29,17 @@ def rotate_half(x): """Rotates half the hidden dims of the input.""" - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids): # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. - cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) @@ -71,8 +71,7 @@ def llama_model_forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): - - batch_size = input_ids.shape[0] # input_ids.shape[0] + batch_size = input_ids.shape[0] # input_ids.shape[0] infer_state = self.infer_state @@ -103,10 +102,11 @@ def llama_model_forward( if use_cache and seq_length != 1: # NOTE assuem prefill stage # allocate memory block - infer_state.is_context_stage = True # set prefill stage, notify attention layer + infer_state.is_context_stage = True # set prefill stage, notify attention layer infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) - infer_state.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length, - infer_state.context_mem_index) + infer_state.init_block_loc( + infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index + ) else: infer_state.is_context_stage = False alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) @@ -129,20 +129,20 @@ def llama_model_forward( infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange(past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device) + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() if infer_state.is_context_stage: - infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( - position_ids.view(-1).shape[0], -1) + position_ids.view(-1).shape[0], -1 + ) 
infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( - position_ids.view(-1).shape[0], -1) + position_ids.view(-1).shape[0], -1 + ) else: seq_len = infer_state.seq_len infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) @@ -153,12 +153,13 @@ def llama_model_forward( # embed positions if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length_with_past), - dtype=torch.bool, - device=inputs_embeds.device) + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) - attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds, - past_key_values_length) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) hidden_states = inputs_embeds @@ -216,7 +217,6 @@ def llama_decoder_layer_forward( use_cache: Optional[bool] = False, infer_state: Optional[BatchInferState] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -261,7 +261,6 @@ def llama_flash_attn_kvcache_forward( use_cache: bool = False, infer_state: Optional[BatchInferState] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - assert use_cache is True, "use_cache should be set to True using this llama attention" bsz, q_len, _ = hidden_states.size() @@ -277,8 +276,8 @@ def llama_flash_attn_kvcache_forward( # NOTE might want to revise # need some way to record the length of past key values cache # since we won't return past_key_value_cache right now - if infer_state.decode_layer_id == 0: # once per model.forward - infer_state.cache_manager.past_key_values_length += q_len # seq_len + if infer_state.decode_layer_id == 0: # once per model.forward + infer_state.cache_manager.past_key_values_length += q_len # seq_len cos, sin = infer_state.position_cos, infer_state.position_sin # print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, ) @@ -299,38 +298,62 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, # first token generation # copy key and value calculated in current step to memory manager - _copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, infer_state.context_mem_index, - infer_state.cache_manager) + _copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_states, + value_states, + infer_state.context_mem_index, + infer_state.cache_manager, + ) attn_output = torch.empty_like(query_states) - llama_context_attn_fwd(query_states, key_states, value_states, attn_output, infer_state.start_loc, - infer_state.seq_len, infer_state.cache_manager.past_key_values_length) + llama_context_attn_fwd( + query_states, + key_states, + value_states, + attn_output, + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length, + ) else: - if infer_state.decode_is_contiguous: # if decode is contiguous, then we copy to key cache and value cache in cache manager directly cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ - infer_state.decode_mem_start:infer_state.decode_mem_end, :, :] + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ - 
infer_state.decode_mem_start:infer_state.decode_mem_end, :, :] + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] cache_k.copy_(key_states) cache_v.copy_(value_states) else: # if decode is not contiguous, use triton kernel to copy key and value cache # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head - _copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, - infer_state.decode_mem_index, infer_state.cache_manager) + _copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_states, + value_states, + infer_state.decode_mem_index, + infer_state.cache_manager, + ) # second token and follows # kv = torch.stack((key_states, value_states), dim=2) # (batch_size, seqlen, nheads, headdim) attn_output = torch.empty_like(query_states) - token_attention_fwd(query_states, infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], - infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], attn_output, - infer_state.block_loc, infer_state.start_loc, infer_state.seq_len, - infer_state.cache_manager.past_key_values_length) + token_attention_fwd( + query_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + attn_output, + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length, + ) attn_output = attn_output.view(bsz, q_len, self.hidden_size) @@ -341,7 +364,6 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, def get_llama_vllm_rmsnorm_forward(): - if HAS_VLLM_KERNERL: def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): diff --git a/colossalai/inference/tensor_parallel/policies/__init__.py b/colossalai/inference/tensor_parallel/policies/__init__.py index 48f8db62c32a..fcb1b6a3bd8f 100644 --- a/colossalai/inference/tensor_parallel/policies/__init__.py +++ b/colossalai/inference/tensor_parallel/policies/__init__.py @@ -1,4 +1,4 @@ from .bloom import BloomModelInferPolicy from .llama import LlamaModelInferPolicy -__all__ = ['BloomModelInferPolicy', 'LlamaModelInferPolicy'] +__all__ = ["BloomModelInferPolicy", "LlamaModelInferPolicy"] diff --git a/colossalai/inference/tensor_parallel/policies/bloom.py b/colossalai/inference/tensor_parallel/policies/bloom.py index cae43aa20421..2d18a3922c1e 100644 --- a/colossalai/inference/tensor_parallel/policies/bloom.py +++ b/colossalai/inference/tensor_parallel/policies/bloom.py @@ -9,6 +9,7 @@ try: from colossalai.kernel.triton import layer_norm + HAS_TRITON_NORM = True except: print("Some of our kernels require triton. 
You might want to install triton from https://github.com/openai/triton") @@ -27,40 +28,40 @@ def _triton_layernorm_forward(self: LayerNorm, hidden_states: torch.Tensor): class BloomModelInferPolicy(BloomForCausalLMPolicy): - def __init__(self) -> None: super().__init__() def module_policy(self): from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel + policy = super().module_policy() # NOTE set inference mode to shard config self.shard_config._infer() method_replacement = { - 'forward': BloomInferenceForwards.bloom_for_causal_lm_forward, - 'prepare_inputs_for_generation': BloomInferenceForwards.bloom_for_causal_lm_prepare_inputs_for_generation + "forward": BloomInferenceForwards.bloom_for_causal_lm_forward, + "prepare_inputs_for_generation": BloomInferenceForwards.bloom_for_causal_lm_prepare_inputs_for_generation, } - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=BloomForCausalLM) + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=BloomForCausalLM + ) - method_replacement = {'forward': BloomInferenceForwards.bloom_model_forward} + method_replacement = {"forward": BloomInferenceForwards.bloom_model_forward} self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomModel) - method_replacement = {'forward': BloomInferenceForwards.bloom_block_forward} + method_replacement = {"forward": BloomInferenceForwards.bloom_block_forward} self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomBlock) - method_replacement = {'forward': BloomInferenceForwards.bloom_attention_forward} - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=BloomAttention) + method_replacement = {"forward": BloomInferenceForwards.bloom_attention_forward} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=BloomAttention + ) if HAS_TRITON_NORM: infer_method = get_triton_layernorm_forward() - method_replacement = {'forward': partial(infer_method)} - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=LayerNorm) + method_replacement = {"forward": partial(infer_method)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LayerNorm + ) return policy diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py index 4844415d612c..9bbb547dbcae 100644 --- a/colossalai/inference/tensor_parallel/policies/llama.py +++ b/colossalai/inference/tensor_parallel/policies/llama.py @@ -10,6 +10,7 @@ try: from colossalai.kernel.triton import rmsnorm_forward + HAS_TRITON_RMSNORM = True except: print("you should install triton from https://github.com/openai/triton") @@ -28,7 +29,6 @@ def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): class LlamaModelInferPolicy(LlamaForCausalLMPolicy): - def __init__(self) -> None: super().__init__() @@ -37,20 +37,20 @@ def module_policy(self): self.shard_config._infer() infer_forward = LlamaInferenceForwards.llama_model_forward - method_replacement = {'forward': partial(infer_forward)} + method_replacement = {"forward": partial(infer_forward)} self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel) infer_forward = 
LlamaInferenceForwards.llama_decoder_layer_forward - method_replacement = {'forward': partial(infer_forward)} - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=LlamaDecoderLayer) + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaDecoderLayer + ) infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward - method_replacement = {'forward': partial(infer_forward)} - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=LlamaAttention) + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaAttention + ) infer_forward = None if HAS_TRITON_RMSNORM: @@ -60,9 +60,9 @@ def module_policy(self): infer_forward = get_llama_vllm_rmsnorm_forward() if infer_forward is not None: - method_replacement = {'forward': partial(infer_forward)} - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=LlamaRMSNorm) + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaRMSNorm + ) return policy diff --git a/colossalai/initialize.py b/colossalai/initialize.py index b8718abc80bd..aac57d34a2c1 100644 --- a/colossalai/initialize.py +++ b/colossalai/initialize.py @@ -14,15 +14,17 @@ from colossalai.utils import set_device, set_seed -def launch(config: Union[str, Path, Config, Dict], - rank: int, - world_size: int, - host: str, - port: int, - backend: str = 'nccl', - local_rank: int = None, - seed: int = 1024, - verbose: bool = True): +def launch( + config: Union[str, Path, Config, Dict], + rank: int, + world_size: int, + host: str, + port: int, + backend: str = "nccl", + local_rank: int = None, + seed: int = 1024, + verbose: bool = True, +): """This function first parses the configuration arguments, using :func:`parse_args()` in case one of the input arguments are not given. Then initialize and set distributed environment by calling global_context's functions. @@ -46,7 +48,7 @@ def launch(config: Union[str, Path, Config, Dict], warnings.warn("`config` is deprecated and will be removed soon.") # init default process group - init_method = f'tcp://[{host}]:{port}' + init_method = f"tcp://[{host}]:{port}" dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method) # set cuda device @@ -58,15 +60,17 @@ def launch(config: Union[str, Path, Config, Dict], if verbose: logger = get_dist_logger() - logger.info(f'Distributed environment is initialized, world size: {dist.get_world_size()}', ranks=[0]) + logger.info(f"Distributed environment is initialized, world size: {dist.get_world_size()}", ranks=[0]) -def launch_from_slurm(config: Union[str, Path, Config, Dict], - host: str, - port: int, - backend: str = 'nccl', - seed: int = 1024, - verbose: bool = True): +def launch_from_slurm( + config: Union[str, Path, Config, Dict], + host: str, + port: int, + backend: str = "nccl", + seed: int = 1024, + verbose: bool = True, +): """A wrapper for colossalai.launch for SLURM launcher by reading rank and world size from the environment variables set by SLURM @@ -79,29 +83,33 @@ def launch_from_slurm(config: Union[str, Path, Config, Dict], verbose (bool, optional): Whether to print logs. Defaults to True. 
""" try: - rank = int(os.environ['SLURM_PROCID']) - world_size = int(os.environ['SLURM_NPROCS']) + rank = int(os.environ["SLURM_PROCID"]) + world_size = int(os.environ["SLURM_NPROCS"]) except KeyError as e: raise RuntimeError( f"Could not find {e} in the SLURM environment, visit https://www.colossalai.org/ for more information on launching with SLURM" ) - launch(config=config, - rank=rank, - world_size=world_size, - host=host, - port=port, - backend=backend, - seed=seed, - verbose=verbose) - - -def launch_from_openmpi(config: Union[str, Path, Config, Dict], - host: str, - port: int, - backend: str = 'nccl', - seed: int = 1024, - verbose: bool = True): + launch( + config=config, + rank=rank, + world_size=world_size, + host=host, + port=port, + backend=backend, + seed=seed, + verbose=verbose, + ) + + +def launch_from_openmpi( + config: Union[str, Path, Config, Dict], + host: str, + port: int, + backend: str = "nccl", + seed: int = 1024, + verbose: bool = True, +): """A wrapper for colossalai.launch for OpenMPI launcher by reading rank and world size from the environment variables set by OpenMPI @@ -114,29 +122,30 @@ def launch_from_openmpi(config: Union[str, Path, Config, Dict], verbose (bool, optional): Whether to print logs. Defaults to True. """ try: - rank = int(os.environ['OMPI_COMM_WORLD_RANK']) - local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) - world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) + rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) + local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) + world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) except KeyError as e: raise RuntimeError( f"Could not find {e} in the OpenMPI environment, visit https://www.colossalai.org/ for more information on launching with OpenMPI" ) - launch(config=config, - local_rank=local_rank, - rank=rank, - world_size=world_size, - host=host, - port=port, - backend=backend, - seed=seed, - verbose=verbose) - - -def launch_from_torch(config: Union[str, Path, Config, Dict], - backend: str = 'nccl', - seed: int = 1024, - verbose: bool = True): + launch( + config=config, + local_rank=local_rank, + rank=rank, + world_size=world_size, + host=host, + port=port, + backend=backend, + seed=seed, + verbose=verbose, + ) + + +def launch_from_torch( + config: Union[str, Path, Config, Dict], backend: str = "nccl", seed: int = 1024, verbose: bool = True +): """A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size from the environment variables set by PyTorch @@ -147,22 +156,24 @@ def launch_from_torch(config: Union[str, Path, Config, Dict], verbose (bool, optional): Whether to print logs. Defaults to True. 
""" try: - rank = int(os.environ['RANK']) - local_rank = int(os.environ['LOCAL_RANK']) - world_size = int(os.environ['WORLD_SIZE']) - host = os.environ['MASTER_ADDR'] - port = int(os.environ['MASTER_PORT']) + rank = int(os.environ["RANK"]) + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + host = os.environ["MASTER_ADDR"] + port = int(os.environ["MASTER_PORT"]) except KeyError as e: raise RuntimeError( f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch" ) - launch(config=config, - local_rank=local_rank, - rank=rank, - world_size=world_size, - host=host, - port=port, - backend=backend, - seed=seed, - verbose=verbose) + launch( + config=config, + local_rank=local_rank, + rank=rank, + world_size=world_size, + host=host, + port=port, + backend=backend, + seed=seed, + verbose=verbose, + ) diff --git a/colossalai/interface/__init__.py b/colossalai/interface/__init__.py index 1c3199fc1aff..98b21c9c02c1 100644 --- a/colossalai/interface/__init__.py +++ b/colossalai/interface/__init__.py @@ -1,4 +1,4 @@ from .model import AMPModelMixin, ModelWrapper from .optimizer import OptimizerWrapper -__all__ = ['OptimizerWrapper', 'ModelWrapper', 'AMPModelMixin'] +__all__ = ["OptimizerWrapper", "ModelWrapper", "AMPModelMixin"] diff --git a/colossalai/interface/model.py b/colossalai/interface/model.py index 7b3d9435d255..58df09b853ee 100644 --- a/colossalai/interface/model.py +++ b/colossalai/interface/model.py @@ -26,11 +26,9 @@ def forward(self, *args, **kwargs): class AMPModelMixin: - """This mixin class defines the interface for AMP training. - """ + """This mixin class defines the interface for AMP training.""" def update_master_params(self): """ Update the master parameters for AMP training. """ - pass diff --git a/colossalai/interface/optimizer.py b/colossalai/interface/optimizer.py index bc270b1d9c89..95d11087bece 100644 --- a/colossalai/interface/optimizer.py +++ b/colossalai/interface/optimizer.py @@ -22,7 +22,7 @@ def parameters(self): params = [] for group in self.param_groups: - params += group['params'] + params += group["params"] return params @property @@ -82,12 +82,14 @@ def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None: """ nn.utils.clip_grad_value_(self.parameters, clip_value, *args, **kwargs) - def clip_grad_by_norm(self, - max_norm: Union[float, int], - norm_type: Union[float, int] = 2.0, - error_if_nonfinite: bool = False, - *args, - **kwargs) -> Tensor: + def clip_grad_by_norm( + self, + max_norm: Union[float, int], + norm_type: Union[float, int] = 2.0, + error_if_nonfinite: bool = False, + *args, + **kwargs, + ) -> Tensor: """ Clips gradient norm of an iterable of parameters. @@ -113,7 +115,8 @@ def scale_loss(self, loss: Tensor): loss (Tensor): The loss to be scaled. """ raise NotImplementedError( - "The method scale_loss is only available for optimizers with mixed precision training") + "The method scale_loss is only available for optimizers with mixed precision training" + ) def unscale_grad(self): """ @@ -122,7 +125,8 @@ def unscale_grad(self): Note: Only available for optimizers with mixed precision training. 
""" raise NotImplementedError( - "The method unscale_grad is only available for optimizers with mixed precision training") + "The method unscale_grad is only available for optimizers with mixed precision training" + ) def unwrap(self): """ diff --git a/colossalai/kernel/cuda_native/__init__.py b/colossalai/kernel/cuda_native/__init__.py index e0136d86e561..f8a974b5fb26 100644 --- a/colossalai/kernel/cuda_native/__init__.py +++ b/colossalai/kernel/cuda_native/__init__.py @@ -4,6 +4,10 @@ from .scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax __all__ = [ - 'LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax', 'ColoAttention', - 'AttnMaskType' + "LayerNorm", + "MultiHeadAttention", + "FusedScaleMaskSoftmax", + "ScaledUpperTriangMaskedSoftmax", + "ColoAttention", + "AttnMaskType", ] diff --git a/colossalai/kernel/cuda_native/csrc/compat.h b/colossalai/kernel/cuda_native/csrc/compat.h index 00066dc95475..a62beef91a8a 100644 --- a/colossalai/kernel/cuda_native/csrc/compat.h +++ b/colossalai/kernel/cuda_native/csrc/compat.h @@ -7,4 +7,4 @@ #define DATA_PTR data_ptr #else #define DATA_PTR data -#endif \ No newline at end of file +#endif diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu b/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu index 26efa2ad6f31..9a6a8ebc3983 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu +++ b/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu @@ -1,7 +1,6 @@ #include #include - #include "cuda_util.h" /* GPU function guard */ diff --git a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu index a39a6dae0f7f..ce0b017f12e1 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu @@ -1,1002 +1,1002 @@ -#include -#include - -#include "kernels.h" - -#include - - -namespace cg = cooperative_groups; - -curandStatePhilox4_32_10_t *curandstate; - -/** - * @brief element-wise activation function on device, like Relu, Gelu - * - * @tparam enum class ActivationType, kRelu, kGelu - * @tparam input type - * @param any shape of float and __half2 - * @return same shape and type with input - */ -template -__forceinline__ __device__ T activation_kernel(T x); - -template <> -__device__ float activation_kernel(float x) { - float cdf = - 0.5f * - (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); - return x * cdf; -} - -template <> -__device__ __half2 -activation_kernel(__half2 val) { - __half2 val_pow3 = __hmul2(val, __hmul2(val, val)); - float2 tmp_pow = __half22float2(val_pow3); - float2 tmp = __half22float2(val); - - tmp.x = - 0.5f * - (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x)))); - tmp.y = - 0.5f * - (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y)))); - return __hmul2(val, __float22half2_rn(tmp)); -} - -template <> -__device__ float activation_kernel(float x) { - return fmaxf(x, 0); -} - -template <> -__device__ __half2 -activation_kernel(__half2 x) { - return __floats2half2_rn(fmaxf(0.f, __half2float(x.x)), - fmaxf(0.f, __half2float(x.y))); -} - -/** - * @brief element-wise activation backward function on device - * - * @tparam enum class ActivationType - * @tparam input type - * @param any shape of float and __half2 - * @return same shape of input - */ -template -__forceinline__ __device__ T activation_bwd_kernel(T grad, T x); - 
-template <> -__device__ float activation_bwd_kernel(float grad, - float x) { - const float sqrt_param = 0.79788456080286535587989211986876f; - const float mul_param = 0.044715; - - float x2mul = x * x * mul_param; - float tan_h = tanhf(sqrt_param * (x + x * x2mul)); - float dg1 = 0.5f * (1.0f + tan_h); - float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); - float dg3 = dg2 * 3 * x2mul; - return grad * (dg1 + dg2 + dg3); -} - -template <> -__device__ __half activation_bwd_kernel( - __half grad, __half x_half) { - float x = __half2float(x_half); - const float sqrt_param = 0.79788456080286535587989211986876f; - const float mul_param = 0.044715; - - float x2mul = x * x * mul_param; - float tan_h = tanhf(sqrt_param * (x + x * x2mul)); - float dg1 = 0.5f * (1.0f + tan_h); - float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); - float dg3 = dg2 * 3 * x2mul; - return grad * __float2half(dg1 + dg2 + dg3); -} - -template <> -__device__ float activation_bwd_kernel(float grad, - float x) { - return x > 0.f ? grad : 0.f; -} - -template <> -__device__ __half -activation_bwd_kernel(__half grad, __half x) { - const __half half_zero = __float2half(0.f); - return x > half_zero ? grad : half_zero; -} - -template <> -__device__ __half2 activation_bwd_kernel( - __half2 grad2, __half2 x_half2) { - const __half half_zero = __float2half(0.f); - return __floats2half2_rn(x_half2.x > half_zero ? grad2.x : half_zero, - x_half2.y > half_zero ? grad2.y : half_zero); -} - -/** - * @brief init curand states in global memory - * - * @thread grid_dim * block*dim to suuport any size of states - * @param state persistant curand states - * @param seed seed to init states - * @return void - */ -__global__ void curand_init_kernel(curandStatePhilox4_32_10_t *state, - int seed) { - /* Each thread gets same seed, a different sequence - number, no offset */ - int id = threadIdx.x + blockIdx.x * blockDim.x; - curand_init(seed, id, 0, &state[id]); -} - -void launch_curand_init(int total_count, int dim, cudaStream_t stream) { - cudaMalloc(&curandstate, total_count * sizeof(curandStatePhilox4_32_10_t)); - int grid_dim = total_count >> 9; - curand_init_kernel<<>>( - curandstate, std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count()); -} - -/** - * @brief element-wise dropout, store dropped position in mask, it's not - * in-place - * - * @thread - * gridDim.x = total_count / 1024 - * blockDim.x = 1024 - * - * @param total_count total elements - * @param ratio drop ratio - * @param out any size of float and __half - * @param in same with out - * @param mask uint8 type, same size with out - * @param seed seed to curand - * @return void - */ -__global__ void ls_dropout_kernel(const int total_count, const float ratio, - float *__restrict__ out, - const float *__restrict__ in, - uint8_t *__restrict__ mask, const int seed) { - const float scale = 1.f / (1.f - ratio); - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 4 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - uint8_t m[4]; - - float4 *out4 = reinterpret_cast(out); - const float4 *data4 = reinterpret_cast(in); - uint32_t *mask4 = reinterpret_cast(mask); - float4 rand = curand_uniform4(&state); - - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - - uint32_t *m4 = reinterpret_cast(m); - mask4[i] = m4[0]; - - float4 input4 = data4[i]; - float4 res4; - res4.x = input4.x * scale * m[0]; - 
res4.y = input4.y * scale * m[1]; - res4.z = input4.z * scale * m[2]; - res4.w = input4.w * scale * m[3]; - out4[i] = res4; -} - -__global__ void ls_dropout_kernel(const int total_count, const float ratio, - __half *__restrict__ out, - const __half *__restrict__ in, - uint8_t *__restrict__ mask, const int seed) { - const float scale = 1.f / (1.f - ratio); - - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 8 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - - const float4 *vals_float4 = reinterpret_cast(in); - float4 *outs_float4 = reinterpret_cast(out); - uint64_t *mask8 = reinterpret_cast(mask); - - uint8_t m[8]; - float4 rand = curand_uniform4(&state); - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - rand = curand_uniform4(&state); - m[4] = (uint8_t)(rand.x > ratio); - m[5] = (uint8_t)(rand.y > ratio); - m[6] = (uint8_t)(rand.z > ratio); - m[7] = (uint8_t)(rand.w > ratio); - uint64_t *m8 = reinterpret_cast(m); - mask8[i] = *m8; - - float4 val_float4 = vals_float4[i]; - float4 out_float4; - __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); - __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); - __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]); - __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]); - __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]); - __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]); - out_half2[0] = __hmul2(val_half2[0], scale_mask_1); - out_half2[1] = __hmul2(val_half2[1], scale_mask_2); - out_half2[2] = __hmul2(val_half2[2], scale_mask_3); - out_half2[3] = __hmul2(val_half2[3], scale_mask_4); - outs_float4[i] = out_float4; -} - -/** - * @brief element-wise dropout backward with dropout mask, it's - * not in-place - * - * @thread - * gridDim.x = total_count / 1024 - * blockDim.x = 1024 - * - * @param total_count total elements - * @param ratio drop ratio - * @param in any size of float and __half - * @param mask uint8 type, same size with in - * @return void - */ -__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, - float *out, const float *in, - const uint8_t *__restrict__ mask) { - const float scale = 1.f / (1.f - ratio); - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 4 >= total_count) return; - - uint8_t m[4]; - - float4 *out4 = reinterpret_cast(out); - const float4 *in4 = reinterpret_cast(in); - const uint32_t *mask4 = reinterpret_cast(mask); - - uint32_t *m4 = reinterpret_cast(m); - m4[0] = mask4[i]; - - float4 input4 = in4[i]; - float4 res4; - res4.x = input4.x * scale * static_cast(m[0]); - res4.y = input4.y * scale * static_cast(m[1]); - res4.z = input4.z * scale * static_cast(m[2]); - res4.w = input4.w * scale * static_cast(m[3]); - out4[i] = res4; -} - -__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, - __half *out, const __half *in, - const uint8_t *__restrict__ mask) { - const __half scale = 1.f / (1.f - ratio); - - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 8 >= total_count) return; - - float4 *out4 = reinterpret_cast(out); - const float4 *vals_float4 = reinterpret_cast(in); - const uint64_t *mask8 = reinterpret_cast(mask); - - uint8_t m[8]; - uint64_t *m8 = reinterpret_cast(m); - m8[0] = mask8[i]; - - float4 val_float4 = vals_float4[i]; - float4 out_float4; - __half2 *val_half2 = reinterpret_cast<__half2 
*>(&val_float4); - __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); - __half2 scale_mask_1 = - __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1])); - __half2 scale_mask_2 = - __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3])); - __half2 scale_mask_3 = - __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5])); - __half2 scale_mask_4 = - __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7])); - out_half2[0] = __hmul2(val_half2[0], scale_mask_1); - out_half2[1] = __hmul2(val_half2[1], scale_mask_2); - out_half2[2] = __hmul2(val_half2[2], scale_mask_3); - out_half2[3] = __hmul2(val_half2[3], scale_mask_4); - out4[i] = out_float4; -} - -template <> -void launch_ls_dropout(float *out, const float *vals, uint8_t *mask, - int total_count, float ratio, cudaStream_t stream, - bool backward) { - int grid_dim = total_count >> 12; - if (!backward) { - ls_dropout_kernel<<>>( - total_count, ratio, out, vals, mask, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count()); - } else { - ls_dropout_bwd_kernel<<>>(total_count, ratio, - out, vals, mask); - } -} - -template <> -void launch_ls_dropout<__half>(__half *out, const __half *vals, uint8_t *mask, - int total_count, float ratio, - cudaStream_t stream, bool backward) { - int grid_dim = total_count >> 13; - if (!backward) { - ls_dropout_kernel<<>>( - total_count, ratio, out, vals, mask, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count()); - } else { - ls_dropout_bwd_kernel<<>>(total_count, ratio, - out, vals, mask); - } -} - -/** - * @brief fused bias, dropout, and residual at the end of Attention and FFN, - * store dropped position in mask, it's not in-place - * - * @thread - * gridDim.x = total_count / 1024 - * blockDim.x = 1024 - * - * @param total_count total elements - * @param ratio drop ratio - * @param out [batch_size, seq_len, hidden_size], float and __half - * @param in [batch_size, seq_len, hidden_size], float and __half - * @param mask [batch_size, seq_len, hidden_size], uint8 type - * @param bias [hidden_size], ffn bias - * @param residual [batch_size, seq_len, hidden_size], float and __half - * @param seed seed to curand - * @param hidden_size hidden size - * @return void - */ -__global__ void ls_dropout_res_bias_kernel( - const int total_count, const float ratio, float *__restrict__ out, - const float *__restrict__ in, uint8_t *__restrict__ mask, - const float *__restrict__ bias, const float *__restrict__ residual, - const int seed, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 4 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - uint8_t m[4]; - - float4 *out4 = reinterpret_cast(out); - const float4 *data4 = reinterpret_cast(in); - const float4 *residual4 = reinterpret_cast(residual); - const float4 *bias4 = reinterpret_cast(bias); - uint32_t *mask4 = reinterpret_cast(mask); - float4 rand = curand_uniform4(&state); - - m[0] = static_cast(rand.x > ratio); - m[1] = static_cast(rand.y > ratio); - m[2] = static_cast(rand.z > ratio); - m[3] = static_cast(rand.w > ratio); - - int bias_i = i % (hidden_size >> 2); - uint32_t *m4 = reinterpret_cast(m); - mask4[i] = m4[0]; - const float4 input4 = data4[i]; - const float4 b4 = __ldg(&bias4[bias_i]); - const float4 res4 = residual4[i]; - float4 output4; - - output4.x = (input4.x + b4.x) * scale * m[0] + 
res4.x; - output4.y = (input4.y + b4.y) * scale * m[1] + res4.y; - output4.z = (input4.z + b4.z) * scale * m[2] + res4.z; - output4.w = (input4.w + b4.w) * scale * m[3] + res4.w; - - out4[i] = output4; -} - -__global__ void ls_dropout_res_bias_kernel( - const int total_count, const float ratio, __half *__restrict__ out, - const __half *__restrict__ in, uint8_t *__restrict__ mask, - const __half *__restrict__ bias, const __half *__restrict__ residual, - const int seed, const int hidden_size) { - const __half scale = 1. / (1. - ratio); - - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 8 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - - const float4 *vals_float4 = reinterpret_cast(in); - float4 *outs_float4 = reinterpret_cast(out); - const float4 *residual4 = reinterpret_cast(residual); - const float4 *bias4 = reinterpret_cast(bias); - uint64_t *mask8 = reinterpret_cast(mask); - - uint8_t m[8]; - float4 rand = curand_uniform4(&state); - m[0] = static_cast(rand.x > ratio); - m[1] = static_cast(rand.y > ratio); - m[2] = static_cast(rand.z > ratio); - m[3] = static_cast(rand.w > ratio); - rand = curand_uniform4(&state); - m[4] = static_cast(rand.x > ratio); - m[5] = static_cast(rand.y > ratio); - m[6] = static_cast(rand.z > ratio); - m[7] = static_cast(rand.w > ratio); - uint64_t *m8 = reinterpret_cast(m); - mask8[i] = m8[0]; - - int bias_i = i % (hidden_size >> 3); - float4 val_float4 = vals_float4[i]; - const float4 b4 = __ldg(&bias4[bias_i]); - const float4 res4 = residual4[i]; - float4 out_float4; - - __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); - __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); - const __half2 *b_half2 = reinterpret_cast(&b4); - const __half2 *res_half2 = reinterpret_cast(&res4); - __half2 scale_mask_1 = - __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1])); - __half2 scale_mask_2 = - __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3])); - __half2 scale_mask_3 = - __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5])); - __half2 scale_mask_4 = - __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7])); - out_half2[0] = - __hfma2(__hadd2(val_half2[0], b_half2[0]), scale_mask_1, res_half2[0]); - out_half2[1] = - __hfma2(__hadd2(val_half2[1], b_half2[1]), scale_mask_2, res_half2[1]); - out_half2[2] = - __hfma2(__hadd2(val_half2[2], b_half2[2]), scale_mask_3, res_half2[2]); - out_half2[3] = - __hfma2(__hadd2(val_half2[3], b_half2[3]), scale_mask_4, res_half2[3]); - outs_float4[i] = out_float4; -} - -template <> -void launch_ls_dropout_res_bias(float *out, const float *vals, - uint8_t *mask, const float *bias, - const float *residual, int total_count, - int dim, float ratio, - cudaStream_t stream) { - int grid_dim = total_count >> 12; - ls_dropout_res_bias_kernel<<>>( - total_count, ratio, out, vals, mask, bias, residual, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -template <> -void launch_ls_dropout_res_bias<__half>(__half *out, const __half *vals, - uint8_t *mask, const __half *bias, - const __half *residual, int total_count, - int dim, float ratio, - cudaStream_t stream) { - int grid_dim = total_count >> 13; - ls_dropout_res_bias_kernel<<>>( - total_count, ratio, out, vals, mask, bias, residual, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -/** - * @brief fused bias and dropout backward 
at the end of Attention and FFN - * - * @thread - * gridDim.x = hidden_size / 8 - * blockDim.x = 8 - * blockDim.y = 1024 / 8 = 128 - * - * @param row_size batch_size * seq_len - * @param ratio dropout ratio - * @param in_grad [batch_size, seq_len, hidden_size], input grad - * @param bias_grad [hidden_size], bias grad - * @param out_grad [batch_size, seq_len, hidden_size], output grad - * @param mask [batch_size, seq_len, hidden_size], dropout mask - * @param hidden_size - * @return void - */ -__global__ void ls_dropout_bias_bwd_kernel( - const int row_size, const float ratio, float *__restrict__ in_grad, - float *__restrict__ bias_grad, const float *__restrict__ out_grad, - const uint8_t *__restrict__ mask, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - // every block generate 8 bias result - __shared__ float tile[8][129]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8); - int stride = hidden_size * 128; - float local_sum = 0; - - int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); - for (int r = threadIdx.y; r < row_size; r += 128) { - float val = out_grad[idx]; - val *= scale * static_cast(mask[idx]); - local_sum += val; - in_grad[idx] = val; - idx += stride; - } - - tile[threadIdx.x][threadIdx.y] = local_sum; - __syncthreads(); - - float sum = 0; - int tid = threadIdx.y * blockDim.x + threadIdx.x; - int x = tid >> 7; - int y = tid & (127); - if (y < 32) { -#pragma unroll - for (int i = 0; i < 4; i++) { - sum += tile[x][y + i * 32]; - } - } - __syncthreads(); - - for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i); - - if (y == 0) tile[0][x] = sum; - __syncthreads(); - - if (threadIdx.x < 8) { - int pos = flat_2dim(blockIdx.x, threadIdx.x, 8); - bias_grad[pos] = tile[0][threadIdx.x]; - } -} - -__global__ void ls_dropout_bias_bwd_kernel( - const int row_size, const float ratio, __half *__restrict__ in_grad, - __half *__restrict__ bias_grad, const __half *__restrict__ out_grad, - const uint8_t *__restrict__ mask, const int hidden_size) { - const __half2 scale = __float2half2_rn(1.f / (1.f - ratio)); - __shared__ __half2 tile[8][129]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad); - const __half2 *out_grad2 = reinterpret_cast(out_grad); - __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad); - - int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8); - int stride = hidden_size * 128; - __half2 local_sum = __float2half2_rn(0.f); - - int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); - for (int r = threadIdx.y; r < row_size; r += 128) { - __half2 val = out_grad2[idx]; - __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]); - val *= scale * m2; - local_sum += val; - in_grad2[idx] = val; - idx += stride; - } - - tile[threadIdx.x][threadIdx.y] = local_sum; - __syncthreads(); - - __half2 sum = __float2half2_rn(0.f); - int tid = threadIdx.y * blockDim.x + threadIdx.x; - int x = tid >> 7; - int y = tid & (127); - if (y < 32) { -#pragma unroll - for (int i = 0; i < 4; i++) { - sum += tile[x][y + i * 32]; - } - } - __syncthreads(); - - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); - - if (y == 0) tile[0][x] = sum; - __syncthreads(); - - if (threadIdx.x < 8) { - int pos = flat_2dim(blockIdx.x, threadIdx.x, 8); - bias_grad2[pos] = tile[0][threadIdx.x]; - } -} - -template -void 
launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad, - const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream) { - dim3 grid_dim((dim - 1) / 8 + 1); - dim3 block_dim(8, 128); - ls_dropout_bias_bwd_kernel<<>>( - row_size, ratio, in_grad, bias_grad, out_grad, mask, dim); -} - -template <> -void launch_ls_dropout_bias_bwd(__half *in_grad, __half *bias_grad, - const __half *out_grad, const uint8_t *mask, - int row_size, int dim, float ratio, - cudaStream_t stream) { - dim >>= 1; - dim3 grid_dim((dim - 1) / 8 + 1); - dim3 block_dim(8, 128); - ls_dropout_bias_bwd_kernel<<>>( - row_size, ratio, in_grad, bias_grad, out_grad, mask, dim); -} - -template void launch_ls_dropout_bias_bwd(float *in_grad, float *bias_grad, - const float *out_grad, - const uint8_t *mask, int row_size, - int dim, float ratio, - cudaStream_t stream); - -/** - * @brief fused bias, activation, and dropout at the end of first ffn - * - * @thread - * gridDim.x = hidden_size / 8 - * blockDim.x = 8 - * blockDim.y = 1024 / 8 = 128 - * - * @tparam act_type activation function, like kRelu, kGelu - * @param total_count total elements - * @param ratio drop ratio - * @param out [batch_size, seq_len, hidden_size], float and __half - * @param in [batch_size, seq_len, hidden_size], float and __half - * @param mask [batch_size, seq_len, hidden_size], uint8 type - * @param bias [hidden_size], ffn bias - * @param seed seed to curand - * @param hidden_size - * @return void - */ -template -__global__ void ls_dropout_act_bias_kernel( - const int total_count, const float ratio, float *__restrict__ out, - const float *__restrict__ in, uint8_t *__restrict__ mask, - const float *__restrict__ bias, const int seed, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 4 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - uint8_t m[4]; - - float4 *out4 = reinterpret_cast(out); - const float4 *data4 = reinterpret_cast(in); - const float4 *bias4 = reinterpret_cast(bias); - uint32_t *mask4 = reinterpret_cast(mask); - float4 rand = curand_uniform4(&state); - - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - - int bias_i = i % (hidden_size >> 2); - uint32_t *m4 = reinterpret_cast(m); - mask4[i] = m4[0]; - const float4 input4 = data4[i]; - const float4 b4 = __ldg(&bias4[bias_i]); - float4 output4; - - output4.x = - activation_kernel(input4.x + b4.x) * scale * m[0]; - output4.y = - activation_kernel(input4.y + b4.y) * scale * m[1]; - output4.z = - activation_kernel(input4.z + b4.z) * scale * m[2]; - output4.w = - activation_kernel(input4.w + b4.w) * scale * m[3]; - - out4[i] = output4; -} - -template -__global__ void ls_dropout_act_bias_kernel( - const int total_count, const float ratio, __half *__restrict__ out, - const __half *__restrict__ in, uint8_t *__restrict__ mask, - const __half *__restrict__ bias, const int seed, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 8 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - - const float4 *vals_float4 = reinterpret_cast(in); - float4 *outs_float4 = reinterpret_cast(out); - const float4 *bias4 = reinterpret_cast(bias); - uint64_t *mask8 = reinterpret_cast(mask); - - uint8_t m[8]; - float4 rand = curand_uniform4(&state); - m[0] 
= (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - rand = curand_uniform4(&state); - m[4] = (uint8_t)(rand.x > ratio); - m[5] = (uint8_t)(rand.y > ratio); - m[6] = (uint8_t)(rand.z > ratio); - m[7] = (uint8_t)(rand.w > ratio); - uint64_t *m8 = reinterpret_cast(m); - mask8[i] = *m8; - - int bias_i = i % (hidden_size >> 3); - float4 val_float4 = vals_float4[i]; - const float4 b4 = __ldg(&bias4[bias_i]); - float4 out_float4; - - __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); - __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); - const __half2 *b_half2 = reinterpret_cast(&b4); - - __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]); - __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]); - __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]); - __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]); - out_half2[0] = __hmul2( - activation_kernel(__hadd2(val_half2[0], b_half2[0])), - scale_mask_1); - out_half2[1] = __hmul2( - activation_kernel(__hadd2(val_half2[1], b_half2[1])), - scale_mask_2); - out_half2[2] = __hmul2( - activation_kernel(__hadd2(val_half2[2], b_half2[2])), - scale_mask_3); - out_half2[3] = __hmul2( - activation_kernel(__hadd2(val_half2[3], b_half2[3])), - scale_mask_4); - outs_float4[i] = out_float4; -} - -template <> -void launch_ls_dropout_act_bias( - float *out, const float *vals, uint8_t *mask, const float *bias, - int total_count, int dim, float ratio, cudaStream_t stream) { - int grid_dim = total_count >> 10; - ls_dropout_act_bias_kernel - <<>>( - total_count, ratio, out, vals, mask, bias, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -template <> -void launch_ls_dropout_act_bias( - __half *out, const __half *vals, uint8_t *mask, const __half *bias, - int total_count, int dim, float ratio, cudaStream_t stream) { - int grid_dim = total_count >> 11; - ls_dropout_act_bias_kernel - <<>>( - total_count, ratio, out, vals, mask, bias, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -template <> -void launch_ls_dropout_act_bias( - float *out, const float *vals, uint8_t *mask, const float *bias, - int total_count, int dim, float ratio, cudaStream_t stream) { - int grid_dim = total_count >> 10; - ls_dropout_act_bias_kernel - <<>>( - total_count, ratio, out, vals, mask, bias, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -template <> -void launch_ls_dropout_act_bias( - __half *out, const __half *vals, uint8_t *mask, const __half *bias, - int total_count, int dim, float ratio, cudaStream_t stream) { - int grid_dim = total_count >> 11; - ls_dropout_act_bias_kernel - <<>>( - total_count, ratio, out, vals, mask, bias, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -/** - * @brief fused bias, activation, and dropout backward - * - * @thread - * gridDim.x = total_count / 1024 - * blockDim.x = 1024 - * - * @tparam act_type kRelu - * @param row_size batch_size * seq_len - * @param ratio dropout ratio - * @param in_grad [batch_size, seq_len, hidden_size], input grad - * @param bias_grad [hidden_size], bias grad - * @param out_grad [batch_size, seq_len, hidden_size], output grad - * @param mask [batch_size, seq_len, hidden_size], dropout mask - * @param hidden_size - 
* @return void - */ -template -__global__ void ls_dropout_act_bias_bwd_kernel( - const int row_size, const float ratio, T *in_grad, - T *__restrict__ bias_grad, const T *__restrict__ input, - const T *__restrict__ bias, const T *out_grad, - const uint8_t *__restrict__ mask, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - __shared__ float tile[WARP_SIZE][WARP_SIZE + 1]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); - - int stride = hidden_size * WARP_SIZE; - float local_sum = 0; - - int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); - if (col_idx < hidden_size) { - for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) { - float val = out_grad[idx]; - float in = input[idx]; - float b = bias[idx % hidden_size]; - val = activation_bwd_kernel( - val * scale * static_cast(mask[idx]), in + b); - local_sum += val; - in_grad[idx] = val; - idx += stride; - } - } - - tile[threadIdx.x][threadIdx.y] = local_sum; - __syncthreads(); - float sum = tile[threadIdx.y][threadIdx.x]; - __syncthreads(); - - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); - - if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; - __syncthreads(); - - if (threadIdx.y == 0) { - int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); - bias_grad[pos] = tile[0][threadIdx.x]; - } -} - -// @brief fused bias, activation, and dropout backward -// It is deprecated for precision reason. Keep it for future optimization. -// -// template -// __global__ void ls_dropout_act_bias_bwd_kernel( -// const int row_size, const float ratio, __half * in_grad, -// __half *__restrict__ bias_grad, const __half *__restrict__ input, const -// __half *__restrict__ bias, const __half * out_grad, const uint8_t -// *__restrict__ mask, const int hidden_size) { -// const __half2 scale = __float2half2_rn(1.f / (1.f - ratio)); -// __shared__ __half2 tile[WARP_SIZE][WARP_SIZE + 1]; - -// cg::thread_block b = cg::this_thread_block(); -// cg::thread_block_tile g = cg::tiled_partition(b); - -// __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad); -// __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad); -// const __half2 *out_grad2 = reinterpret_cast(out_grad); -// const __half2 *input2 = reinterpret_cast(input); -// const __half2 *bias2 = reinterpret_cast(bias); - -// int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); - -// int stride = hidden_size * WARP_SIZE; -// __half2 local_sum = __float2half2_rn(0.f); - -// int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); -// if (col_idx < hidden_size) { -// for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) { -// __half2 val = out_grad2[idx]; -// __half2 in2 = input2[idx]; -// __half2 b2 = bias2[idx % hidden_size ]; -// __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]); -// val = activation_bwd_kernel(val * scale -// * -// m2, -// in2+b2); -// local_sum += val; -// in_grad2[idx] = val; -// idx += stride; -// } -// } - -// tile[threadIdx.x][threadIdx.y] = local_sum; -// __syncthreads(); -// __half2 sum = tile[threadIdx.y][threadIdx.x]; -// __syncthreads(); - -// for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); - -// if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; -// __syncthreads(); - -// if (threadIdx.y == 0) { -// int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); -// bias_grad2[pos] = tile[0][threadIdx.x]; -// } -// } - -template -void launch_ls_dropout_act_bias_bwd(T *in_grad, T 
*bias_grad, const T *input, - const T *bias, const T *out_grad, - const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream) { - dim3 grid_dim((dim - 1) / WARP_SIZE + 1); - dim3 block_dim(WARP_SIZE, WARP_SIZE); - ls_dropout_act_bias_bwd_kernel<<>>( - row_size, ratio, in_grad, bias_grad, input, bias, out_grad, mask, dim); -} - -// template <> -// void launch_ls_dropout_act_bias_bwd( -// __half *in_grad, __half *bias_grad,const __half *input, const __half -// *bias, const __half *out_grad, const uint8_t *mask, int row_size, int -// dim, float ratio, cudaStream_t stream) { -// dim >>= 1; -// dim3 grid_dim((dim - 1) / WARP_SIZE + 1); -// dim3 block_dim(WARP_SIZE, WARP_SIZE); -// ls_dropout_act_bias_bwd_kernel -// <<>>(row_size, ratio, in_grad, -// bias_grad, -// input, bias,out_grad, mask, dim); -// } - -template void launch_ls_dropout_act_bias_bwd( - float *in_grad, float *bias_grad, const float *input, const float *bias, - const float *out_grad, const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); - -template void launch_ls_dropout_act_bias_bwd( - __half *in_grad, __half *bias_grad, const __half *input, const __half *bias, - const __half *out_grad, const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); - -template void launch_ls_dropout_act_bias_bwd( - float *in_grad, float *bias_grad, const float *input, const float *bias, - const float *out_grad, const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); - -template void launch_ls_dropout_act_bias_bwd( - __half *in_grad, __half *bias_grad, const __half *input, const __half *bias, - const __half *out_grad, const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); +#include +#include + +#include "kernels.h" + +#include + + +namespace cg = cooperative_groups; + +curandStatePhilox4_32_10_t *curandstate; + +/** + * @brief element-wise activation function on device, like Relu, Gelu + * + * @tparam enum class ActivationType, kRelu, kGelu + * @tparam input type + * @param any shape of float and __half2 + * @return same shape and type with input + */ +template +__forceinline__ __device__ T activation_kernel(T x); + +template <> +__device__ float activation_kernel(float x) { + float cdf = + 0.5f * + (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); + return x * cdf; +} + +template <> +__device__ __half2 +activation_kernel(__half2 val) { + __half2 val_pow3 = __hmul2(val, __hmul2(val, val)); + float2 tmp_pow = __half22float2(val_pow3); + float2 tmp = __half22float2(val); + + tmp.x = + 0.5f * + (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x)))); + tmp.y = + 0.5f * + (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y)))); + return __hmul2(val, __float22half2_rn(tmp)); +} + +template <> +__device__ float activation_kernel(float x) { + return fmaxf(x, 0); +} + +template <> +__device__ __half2 +activation_kernel(__half2 x) { + return __floats2half2_rn(fmaxf(0.f, __half2float(x.x)), + fmaxf(0.f, __half2float(x.y))); +} + +/** + * @brief element-wise activation backward function on device + * + * @tparam enum class ActivationType + * @tparam input type + * @param any shape of float and __half2 + * @return same shape of input + */ +template +__forceinline__ __device__ T activation_bwd_kernel(T grad, T x); + +template <> +__device__ float activation_bwd_kernel(float grad, + float x) { + const float sqrt_param = 0.79788456080286535587989211986876f; + const float mul_param = 0.044715; + + 
float x2mul = x * x * mul_param; + float tan_h = tanhf(sqrt_param * (x + x * x2mul)); + float dg1 = 0.5f * (1.0f + tan_h); + float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); + float dg3 = dg2 * 3 * x2mul; + return grad * (dg1 + dg2 + dg3); +} + +template <> +__device__ __half activation_bwd_kernel( + __half grad, __half x_half) { + float x = __half2float(x_half); + const float sqrt_param = 0.79788456080286535587989211986876f; + const float mul_param = 0.044715; + + float x2mul = x * x * mul_param; + float tan_h = tanhf(sqrt_param * (x + x * x2mul)); + float dg1 = 0.5f * (1.0f + tan_h); + float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); + float dg3 = dg2 * 3 * x2mul; + return grad * __float2half(dg1 + dg2 + dg3); +} + +template <> +__device__ float activation_bwd_kernel(float grad, + float x) { + return x > 0.f ? grad : 0.f; +} + +template <> +__device__ __half +activation_bwd_kernel(__half grad, __half x) { + const __half half_zero = __float2half(0.f); + return x > half_zero ? grad : half_zero; +} + +template <> +__device__ __half2 activation_bwd_kernel( + __half2 grad2, __half2 x_half2) { + const __half half_zero = __float2half(0.f); + return __floats2half2_rn(x_half2.x > half_zero ? grad2.x : half_zero, + x_half2.y > half_zero ? grad2.y : half_zero); +} + +/** + * @brief init curand states in global memory + * + * @thread grid_dim * block*dim to suuport any size of states + * @param state persistant curand states + * @param seed seed to init states + * @return void + */ +__global__ void curand_init_kernel(curandStatePhilox4_32_10_t *state, + int seed) { + /* Each thread gets same seed, a different sequence + number, no offset */ + int id = threadIdx.x + blockIdx.x * blockDim.x; + curand_init(seed, id, 0, &state[id]); +} + +void launch_curand_init(int total_count, int dim, cudaStream_t stream) { + cudaMalloc(&curandstate, total_count * sizeof(curandStatePhilox4_32_10_t)); + int grid_dim = total_count >> 9; + curand_init_kernel<<>>( + curandstate, std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count()); +} + +/** + * @brief element-wise dropout, store dropped position in mask, it's not + * in-place + * + * @thread + * gridDim.x = total_count / 1024 + * blockDim.x = 1024 + * + * @param total_count total elements + * @param ratio drop ratio + * @param out any size of float and __half + * @param in same with out + * @param mask uint8 type, same size with out + * @param seed seed to curand + * @return void + */ +__global__ void ls_dropout_kernel(const int total_count, const float ratio, + float *__restrict__ out, + const float *__restrict__ in, + uint8_t *__restrict__ mask, const int seed) { + const float scale = 1.f / (1.f - ratio); + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 4 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + uint8_t m[4]; + + float4 *out4 = reinterpret_cast(out); + const float4 *data4 = reinterpret_cast(in); + uint32_t *mask4 = reinterpret_cast(mask); + float4 rand = curand_uniform4(&state); + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + uint32_t *m4 = reinterpret_cast(m); + mask4[i] = m4[0]; + + float4 input4 = data4[i]; + float4 res4; + res4.x = input4.x * scale * m[0]; + res4.y = input4.y * scale * m[1]; + res4.z = input4.z * scale * m[2]; + res4.w = input4.w * scale * m[3]; + out4[i] = res4; +} + +__global__ void ls_dropout_kernel(const int 
total_count, const float ratio, + __half *__restrict__ out, + const __half *__restrict__ in, + uint8_t *__restrict__ mask, const int seed) { + const float scale = 1.f / (1.f - ratio); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 8 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + + const float4 *vals_float4 = reinterpret_cast(in); + float4 *outs_float4 = reinterpret_cast(out); + uint64_t *mask8 = reinterpret_cast(mask); + + uint8_t m[8]; + float4 rand = curand_uniform4(&state); + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + rand = curand_uniform4(&state); + m[4] = (uint8_t)(rand.x > ratio); + m[5] = (uint8_t)(rand.y > ratio); + m[6] = (uint8_t)(rand.z > ratio); + m[7] = (uint8_t)(rand.w > ratio); + uint64_t *m8 = reinterpret_cast(m); + mask8[i] = *m8; + + float4 val_float4 = vals_float4[i]; + float4 out_float4; + __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); + __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); + __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]); + __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]); + __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]); + __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]); + out_half2[0] = __hmul2(val_half2[0], scale_mask_1); + out_half2[1] = __hmul2(val_half2[1], scale_mask_2); + out_half2[2] = __hmul2(val_half2[2], scale_mask_3); + out_half2[3] = __hmul2(val_half2[3], scale_mask_4); + outs_float4[i] = out_float4; +} + +/** + * @brief element-wise dropout backward with dropout mask, it's + * not in-place + * + * @thread + * gridDim.x = total_count / 1024 + * blockDim.x = 1024 + * + * @param total_count total elements + * @param ratio drop ratio + * @param in any size of float and __half + * @param mask uint8 type, same size with in + * @return void + */ +__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, + float *out, const float *in, + const uint8_t *__restrict__ mask) { + const float scale = 1.f / (1.f - ratio); + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 4 >= total_count) return; + + uint8_t m[4]; + + float4 *out4 = reinterpret_cast(out); + const float4 *in4 = reinterpret_cast(in); + const uint32_t *mask4 = reinterpret_cast(mask); + + uint32_t *m4 = reinterpret_cast(m); + m4[0] = mask4[i]; + + float4 input4 = in4[i]; + float4 res4; + res4.x = input4.x * scale * static_cast(m[0]); + res4.y = input4.y * scale * static_cast(m[1]); + res4.z = input4.z * scale * static_cast(m[2]); + res4.w = input4.w * scale * static_cast(m[3]); + out4[i] = res4; +} + +__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, + __half *out, const __half *in, + const uint8_t *__restrict__ mask) { + const __half scale = 1.f / (1.f - ratio); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 8 >= total_count) return; + + float4 *out4 = reinterpret_cast(out); + const float4 *vals_float4 = reinterpret_cast(in); + const uint64_t *mask8 = reinterpret_cast(mask); + + uint8_t m[8]; + uint64_t *m8 = reinterpret_cast(m); + m8[0] = mask8[i]; + + float4 val_float4 = vals_float4[i]; + float4 out_float4; + __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); + __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); + __half2 scale_mask_1 = + __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1])); + 
__half2 scale_mask_2 = + __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3])); + __half2 scale_mask_3 = + __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5])); + __half2 scale_mask_4 = + __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7])); + out_half2[0] = __hmul2(val_half2[0], scale_mask_1); + out_half2[1] = __hmul2(val_half2[1], scale_mask_2); + out_half2[2] = __hmul2(val_half2[2], scale_mask_3); + out_half2[3] = __hmul2(val_half2[3], scale_mask_4); + out4[i] = out_float4; +} + +template <> +void launch_ls_dropout(float *out, const float *vals, uint8_t *mask, + int total_count, float ratio, cudaStream_t stream, + bool backward) { + int grid_dim = total_count >> 12; + if (!backward) { + ls_dropout_kernel<<>>( + total_count, ratio, out, vals, mask, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count()); + } else { + ls_dropout_bwd_kernel<<>>(total_count, ratio, + out, vals, mask); + } +} + +template <> +void launch_ls_dropout<__half>(__half *out, const __half *vals, uint8_t *mask, + int total_count, float ratio, + cudaStream_t stream, bool backward) { + int grid_dim = total_count >> 13; + if (!backward) { + ls_dropout_kernel<<>>( + total_count, ratio, out, vals, mask, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count()); + } else { + ls_dropout_bwd_kernel<<>>(total_count, ratio, + out, vals, mask); + } +} + +/** + * @brief fused bias, dropout, and residual at the end of Attention and FFN, + * store dropped position in mask, it's not in-place + * + * @thread + * gridDim.x = total_count / 1024 + * blockDim.x = 1024 + * + * @param total_count total elements + * @param ratio drop ratio + * @param out [batch_size, seq_len, hidden_size], float and __half + * @param in [batch_size, seq_len, hidden_size], float and __half + * @param mask [batch_size, seq_len, hidden_size], uint8 type + * @param bias [hidden_size], ffn bias + * @param residual [batch_size, seq_len, hidden_size], float and __half + * @param seed seed to curand + * @param hidden_size hidden size + * @return void + */ +__global__ void ls_dropout_res_bias_kernel( + const int total_count, const float ratio, float *__restrict__ out, + const float *__restrict__ in, uint8_t *__restrict__ mask, + const float *__restrict__ bias, const float *__restrict__ residual, + const int seed, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 4 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + uint8_t m[4]; + + float4 *out4 = reinterpret_cast(out); + const float4 *data4 = reinterpret_cast(in); + const float4 *residual4 = reinterpret_cast(residual); + const float4 *bias4 = reinterpret_cast(bias); + uint32_t *mask4 = reinterpret_cast(mask); + float4 rand = curand_uniform4(&state); + + m[0] = static_cast(rand.x > ratio); + m[1] = static_cast(rand.y > ratio); + m[2] = static_cast(rand.z > ratio); + m[3] = static_cast(rand.w > ratio); + + int bias_i = i % (hidden_size >> 2); + uint32_t *m4 = reinterpret_cast(m); + mask4[i] = m4[0]; + const float4 input4 = data4[i]; + const float4 b4 = __ldg(&bias4[bias_i]); + const float4 res4 = residual4[i]; + float4 output4; + + output4.x = (input4.x + b4.x) * scale * m[0] + res4.x; + output4.y = (input4.y + b4.y) * scale * m[1] + res4.y; + output4.z = (input4.z + b4.z) * scale * m[2] + res4.z; + output4.w = (input4.w + b4.w) * scale * m[3] + res4.w; + + 
out4[i] = output4; +} + +__global__ void ls_dropout_res_bias_kernel( + const int total_count, const float ratio, __half *__restrict__ out, + const __half *__restrict__ in, uint8_t *__restrict__ mask, + const __half *__restrict__ bias, const __half *__restrict__ residual, + const int seed, const int hidden_size) { + const __half scale = 1. / (1. - ratio); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 8 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + + const float4 *vals_float4 = reinterpret_cast(in); + float4 *outs_float4 = reinterpret_cast(out); + const float4 *residual4 = reinterpret_cast(residual); + const float4 *bias4 = reinterpret_cast(bias); + uint64_t *mask8 = reinterpret_cast(mask); + + uint8_t m[8]; + float4 rand = curand_uniform4(&state); + m[0] = static_cast(rand.x > ratio); + m[1] = static_cast(rand.y > ratio); + m[2] = static_cast(rand.z > ratio); + m[3] = static_cast(rand.w > ratio); + rand = curand_uniform4(&state); + m[4] = static_cast(rand.x > ratio); + m[5] = static_cast(rand.y > ratio); + m[6] = static_cast(rand.z > ratio); + m[7] = static_cast(rand.w > ratio); + uint64_t *m8 = reinterpret_cast(m); + mask8[i] = m8[0]; + + int bias_i = i % (hidden_size >> 3); + float4 val_float4 = vals_float4[i]; + const float4 b4 = __ldg(&bias4[bias_i]); + const float4 res4 = residual4[i]; + float4 out_float4; + + __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); + __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); + const __half2 *b_half2 = reinterpret_cast(&b4); + const __half2 *res_half2 = reinterpret_cast(&res4); + __half2 scale_mask_1 = + __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1])); + __half2 scale_mask_2 = + __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3])); + __half2 scale_mask_3 = + __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5])); + __half2 scale_mask_4 = + __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7])); + out_half2[0] = + __hfma2(__hadd2(val_half2[0], b_half2[0]), scale_mask_1, res_half2[0]); + out_half2[1] = + __hfma2(__hadd2(val_half2[1], b_half2[1]), scale_mask_2, res_half2[1]); + out_half2[2] = + __hfma2(__hadd2(val_half2[2], b_half2[2]), scale_mask_3, res_half2[2]); + out_half2[3] = + __hfma2(__hadd2(val_half2[3], b_half2[3]), scale_mask_4, res_half2[3]); + outs_float4[i] = out_float4; +} + +template <> +void launch_ls_dropout_res_bias(float *out, const float *vals, + uint8_t *mask, const float *bias, + const float *residual, int total_count, + int dim, float ratio, + cudaStream_t stream) { + int grid_dim = total_count >> 12; + ls_dropout_res_bias_kernel<<>>( + total_count, ratio, out, vals, mask, bias, residual, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +template <> +void launch_ls_dropout_res_bias<__half>(__half *out, const __half *vals, + uint8_t *mask, const __half *bias, + const __half *residual, int total_count, + int dim, float ratio, + cudaStream_t stream) { + int grid_dim = total_count >> 13; + ls_dropout_res_bias_kernel<<>>( + total_count, ratio, out, vals, mask, bias, residual, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +/** + * @brief fused bias and dropout backward at the end of Attention and FFN + * + * @thread + * gridDim.x = hidden_size / 8 + * blockDim.x = 8 + * blockDim.y = 1024 / 8 = 128 + * + * @param row_size batch_size * seq_len + * 
@param ratio dropout ratio + * @param in_grad [batch_size, seq_len, hidden_size], input grad + * @param bias_grad [hidden_size], bias grad + * @param out_grad [batch_size, seq_len, hidden_size], output grad + * @param mask [batch_size, seq_len, hidden_size], dropout mask + * @param hidden_size + * @return void + */ +__global__ void ls_dropout_bias_bwd_kernel( + const int row_size, const float ratio, float *__restrict__ in_grad, + float *__restrict__ bias_grad, const float *__restrict__ out_grad, + const uint8_t *__restrict__ mask, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + // every block generate 8 bias result + __shared__ float tile[8][129]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8); + int stride = hidden_size * 128; + float local_sum = 0; + + int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); + for (int r = threadIdx.y; r < row_size; r += 128) { + float val = out_grad[idx]; + val *= scale * static_cast(mask[idx]); + local_sum += val; + in_grad[idx] = val; + idx += stride; + } + + tile[threadIdx.x][threadIdx.y] = local_sum; + __syncthreads(); + + float sum = 0; + int tid = threadIdx.y * blockDim.x + threadIdx.x; + int x = tid >> 7; + int y = tid & (127); + if (y < 32) { +#pragma unroll + for (int i = 0; i < 4; i++) { + sum += tile[x][y + i * 32]; + } + } + __syncthreads(); + + for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i); + + if (y == 0) tile[0][x] = sum; + __syncthreads(); + + if (threadIdx.x < 8) { + int pos = flat_2dim(blockIdx.x, threadIdx.x, 8); + bias_grad[pos] = tile[0][threadIdx.x]; + } +} + +__global__ void ls_dropout_bias_bwd_kernel( + const int row_size, const float ratio, __half *__restrict__ in_grad, + __half *__restrict__ bias_grad, const __half *__restrict__ out_grad, + const uint8_t *__restrict__ mask, const int hidden_size) { + const __half2 scale = __float2half2_rn(1.f / (1.f - ratio)); + __shared__ __half2 tile[8][129]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad); + const __half2 *out_grad2 = reinterpret_cast(out_grad); + __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad); + + int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8); + int stride = hidden_size * 128; + __half2 local_sum = __float2half2_rn(0.f); + + int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); + for (int r = threadIdx.y; r < row_size; r += 128) { + __half2 val = out_grad2[idx]; + __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]); + val *= scale * m2; + local_sum += val; + in_grad2[idx] = val; + idx += stride; + } + + tile[threadIdx.x][threadIdx.y] = local_sum; + __syncthreads(); + + __half2 sum = __float2half2_rn(0.f); + int tid = threadIdx.y * blockDim.x + threadIdx.x; + int x = tid >> 7; + int y = tid & (127); + if (y < 32) { +#pragma unroll + for (int i = 0; i < 4; i++) { + sum += tile[x][y + i * 32]; + } + } + __syncthreads(); + + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); + + if (y == 0) tile[0][x] = sum; + __syncthreads(); + + if (threadIdx.x < 8) { + int pos = flat_2dim(blockIdx.x, threadIdx.x, 8); + bias_grad2[pos] = tile[0][threadIdx.x]; + } +} + +template +void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad, + const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream) { + dim3 grid_dim((dim - 1) / 8 + 1); + dim3 
block_dim(8, 128); + ls_dropout_bias_bwd_kernel<<>>( + row_size, ratio, in_grad, bias_grad, out_grad, mask, dim); +} + +template <> +void launch_ls_dropout_bias_bwd(__half *in_grad, __half *bias_grad, + const __half *out_grad, const uint8_t *mask, + int row_size, int dim, float ratio, + cudaStream_t stream) { + dim >>= 1; + dim3 grid_dim((dim - 1) / 8 + 1); + dim3 block_dim(8, 128); + ls_dropout_bias_bwd_kernel<<>>( + row_size, ratio, in_grad, bias_grad, out_grad, mask, dim); +} + +template void launch_ls_dropout_bias_bwd(float *in_grad, float *bias_grad, + const float *out_grad, + const uint8_t *mask, int row_size, + int dim, float ratio, + cudaStream_t stream); + +/** + * @brief fused bias, activation, and dropout at the end of first ffn + * + * @thread + * gridDim.x = hidden_size / 8 + * blockDim.x = 8 + * blockDim.y = 1024 / 8 = 128 + * + * @tparam act_type activation function, like kRelu, kGelu + * @param total_count total elements + * @param ratio drop ratio + * @param out [batch_size, seq_len, hidden_size], float and __half + * @param in [batch_size, seq_len, hidden_size], float and __half + * @param mask [batch_size, seq_len, hidden_size], uint8 type + * @param bias [hidden_size], ffn bias + * @param seed seed to curand + * @param hidden_size + * @return void + */ +template +__global__ void ls_dropout_act_bias_kernel( + const int total_count, const float ratio, float *__restrict__ out, + const float *__restrict__ in, uint8_t *__restrict__ mask, + const float *__restrict__ bias, const int seed, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 4 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + uint8_t m[4]; + + float4 *out4 = reinterpret_cast(out); + const float4 *data4 = reinterpret_cast(in); + const float4 *bias4 = reinterpret_cast(bias); + uint32_t *mask4 = reinterpret_cast(mask); + float4 rand = curand_uniform4(&state); + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + int bias_i = i % (hidden_size >> 2); + uint32_t *m4 = reinterpret_cast(m); + mask4[i] = m4[0]; + const float4 input4 = data4[i]; + const float4 b4 = __ldg(&bias4[bias_i]); + float4 output4; + + output4.x = + activation_kernel(input4.x + b4.x) * scale * m[0]; + output4.y = + activation_kernel(input4.y + b4.y) * scale * m[1]; + output4.z = + activation_kernel(input4.z + b4.z) * scale * m[2]; + output4.w = + activation_kernel(input4.w + b4.w) * scale * m[3]; + + out4[i] = output4; +} + +template +__global__ void ls_dropout_act_bias_kernel( + const int total_count, const float ratio, __half *__restrict__ out, + const __half *__restrict__ in, uint8_t *__restrict__ mask, + const __half *__restrict__ bias, const int seed, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 8 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + + const float4 *vals_float4 = reinterpret_cast(in); + float4 *outs_float4 = reinterpret_cast(out); + const float4 *bias4 = reinterpret_cast(bias); + uint64_t *mask8 = reinterpret_cast(mask); + + uint8_t m[8]; + float4 rand = curand_uniform4(&state); + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + rand = curand_uniform4(&state); + m[4] = (uint8_t)(rand.x > 
ratio); + m[5] = (uint8_t)(rand.y > ratio); + m[6] = (uint8_t)(rand.z > ratio); + m[7] = (uint8_t)(rand.w > ratio); + uint64_t *m8 = reinterpret_cast(m); + mask8[i] = *m8; + + int bias_i = i % (hidden_size >> 3); + float4 val_float4 = vals_float4[i]; + const float4 b4 = __ldg(&bias4[bias_i]); + float4 out_float4; + + __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); + __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); + const __half2 *b_half2 = reinterpret_cast(&b4); + + __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]); + __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]); + __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]); + __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]); + out_half2[0] = __hmul2( + activation_kernel(__hadd2(val_half2[0], b_half2[0])), + scale_mask_1); + out_half2[1] = __hmul2( + activation_kernel(__hadd2(val_half2[1], b_half2[1])), + scale_mask_2); + out_half2[2] = __hmul2( + activation_kernel(__hadd2(val_half2[2], b_half2[2])), + scale_mask_3); + out_half2[3] = __hmul2( + activation_kernel(__hadd2(val_half2[3], b_half2[3])), + scale_mask_4); + outs_float4[i] = out_float4; +} + +template <> +void launch_ls_dropout_act_bias( + float *out, const float *vals, uint8_t *mask, const float *bias, + int total_count, int dim, float ratio, cudaStream_t stream) { + int grid_dim = total_count >> 10; + ls_dropout_act_bias_kernel + <<>>( + total_count, ratio, out, vals, mask, bias, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +template <> +void launch_ls_dropout_act_bias( + __half *out, const __half *vals, uint8_t *mask, const __half *bias, + int total_count, int dim, float ratio, cudaStream_t stream) { + int grid_dim = total_count >> 11; + ls_dropout_act_bias_kernel + <<>>( + total_count, ratio, out, vals, mask, bias, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +template <> +void launch_ls_dropout_act_bias( + float *out, const float *vals, uint8_t *mask, const float *bias, + int total_count, int dim, float ratio, cudaStream_t stream) { + int grid_dim = total_count >> 10; + ls_dropout_act_bias_kernel + <<>>( + total_count, ratio, out, vals, mask, bias, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +template <> +void launch_ls_dropout_act_bias( + __half *out, const __half *vals, uint8_t *mask, const __half *bias, + int total_count, int dim, float ratio, cudaStream_t stream) { + int grid_dim = total_count >> 11; + ls_dropout_act_bias_kernel + <<>>( + total_count, ratio, out, vals, mask, bias, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +/** + * @brief fused bias, activation, and dropout backward + * + * @thread + * gridDim.x = total_count / 1024 + * blockDim.x = 1024 + * + * @tparam act_type kRelu + * @param row_size batch_size * seq_len + * @param ratio dropout ratio + * @param in_grad [batch_size, seq_len, hidden_size], input grad + * @param bias_grad [hidden_size], bias grad + * @param out_grad [batch_size, seq_len, hidden_size], output grad + * @param mask [batch_size, seq_len, hidden_size], dropout mask + * @param hidden_size + * @return void + */ +template +__global__ void ls_dropout_act_bias_bwd_kernel( + const int row_size, const float ratio, T *in_grad, + T *__restrict__ bias_grad, const T *__restrict__ input, + const 
T *__restrict__ bias, const T *out_grad, + const uint8_t *__restrict__ mask, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + __shared__ float tile[WARP_SIZE][WARP_SIZE + 1]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); + + int stride = hidden_size * WARP_SIZE; + float local_sum = 0; + + int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); + if (col_idx < hidden_size) { + for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) { + float val = out_grad[idx]; + float in = input[idx]; + float b = bias[idx % hidden_size]; + val = activation_bwd_kernel( + val * scale * static_cast(mask[idx]), in + b); + local_sum += val; + in_grad[idx] = val; + idx += stride; + } + } + + tile[threadIdx.x][threadIdx.y] = local_sum; + __syncthreads(); + float sum = tile[threadIdx.y][threadIdx.x]; + __syncthreads(); + + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); + + if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; + __syncthreads(); + + if (threadIdx.y == 0) { + int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); + bias_grad[pos] = tile[0][threadIdx.x]; + } +} + +// @brief fused bias, activation, and dropout backward +// It is deprecated for precision reason. Keep it for future optimization. +// +// template +// __global__ void ls_dropout_act_bias_bwd_kernel( +// const int row_size, const float ratio, __half * in_grad, +// __half *__restrict__ bias_grad, const __half *__restrict__ input, const +// __half *__restrict__ bias, const __half * out_grad, const uint8_t +// *__restrict__ mask, const int hidden_size) { +// const __half2 scale = __float2half2_rn(1.f / (1.f - ratio)); +// __shared__ __half2 tile[WARP_SIZE][WARP_SIZE + 1]; + +// cg::thread_block b = cg::this_thread_block(); +// cg::thread_block_tile g = cg::tiled_partition(b); + +// __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad); +// __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad); +// const __half2 *out_grad2 = reinterpret_cast(out_grad); +// const __half2 *input2 = reinterpret_cast(input); +// const __half2 *bias2 = reinterpret_cast(bias); + +// int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); + +// int stride = hidden_size * WARP_SIZE; +// __half2 local_sum = __float2half2_rn(0.f); + +// int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); +// if (col_idx < hidden_size) { +// for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) { +// __half2 val = out_grad2[idx]; +// __half2 in2 = input2[idx]; +// __half2 b2 = bias2[idx % hidden_size ]; +// __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]); +// val = activation_bwd_kernel(val * scale +// * +// m2, +// in2+b2); +// local_sum += val; +// in_grad2[idx] = val; +// idx += stride; +// } +// } + +// tile[threadIdx.x][threadIdx.y] = local_sum; +// __syncthreads(); +// __half2 sum = tile[threadIdx.y][threadIdx.x]; +// __syncthreads(); + +// for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); + +// if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; +// __syncthreads(); + +// if (threadIdx.y == 0) { +// int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); +// bias_grad2[pos] = tile[0][threadIdx.x]; +// } +// } + +template +void launch_ls_dropout_act_bias_bwd(T *in_grad, T *bias_grad, const T *input, + const T *bias, const T *out_grad, + const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream) { + dim3 grid_dim((dim - 1) / WARP_SIZE + 1); + dim3 
block_dim(WARP_SIZE, WARP_SIZE); + ls_dropout_act_bias_bwd_kernel<<>>( + row_size, ratio, in_grad, bias_grad, input, bias, out_grad, mask, dim); +} + +// template <> +// void launch_ls_dropout_act_bias_bwd( +// __half *in_grad, __half *bias_grad,const __half *input, const __half +// *bias, const __half *out_grad, const uint8_t *mask, int row_size, int +// dim, float ratio, cudaStream_t stream) { +// dim >>= 1; +// dim3 grid_dim((dim - 1) / WARP_SIZE + 1); +// dim3 block_dim(WARP_SIZE, WARP_SIZE); +// ls_dropout_act_bias_bwd_kernel +// <<>>(row_size, ratio, in_grad, +// bias_grad, +// input, bias,out_grad, mask, dim); +// } + +template void launch_ls_dropout_act_bias_bwd( + float *in_grad, float *bias_grad, const float *input, const float *bias, + const float *out_grad, const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); + +template void launch_ls_dropout_act_bias_bwd( + __half *in_grad, __half *bias_grad, const __half *input, const __half *bias, + const __half *out_grad, const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); + +template void launch_ls_dropout_act_bias_bwd( + float *in_grad, float *bias_grad, const float *input, const float *bias, + const float *out_grad, const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); + +template void launch_ls_dropout_act_bias_bwd( + __half *in_grad, __half *bias_grad, const __half *input, const __half *bias, + const __half *out_grad, const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu index bc90c54c0a00..625b02cd25d9 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu @@ -1,232 +1,232 @@ -#include - -#include "kernels.h" - -namespace cg = cooperative_groups; - -/** -@brief: fuse_transpose_bias -Calculate the sum of elements in each column of the matrix. - -@thread -gridDim.x = ceil(cols / WARP_SIZE) -blockDim.x = WARP_SIZE -blockDim.y = WARP_SIZE - -@param -inp: [rows, cols] -out: [cols] -rows: the number of rows in the matrix -cols: the number of cols in the matrix -*/ -template -__global__ void column_sum_reduce(const T *__restrict__ inp, - T *__restrict__ out, int rows, int cols) { - __shared__ float tile[WARP_SIZE][WARP_SIZE]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); - int y_stride = cols * WARP_SIZE; - float localSum = 0; - - // Loop across matrix row - // TODO: optimize to log complexity - if (idx < cols) { - int offset = flat_2dim(threadIdx.y, idx, cols); - for (int r = threadIdx.y; r < rows; r += WARP_SIZE) { - localSum += (float)inp[offset]; - offset += y_stride; - } - } - - // The sum of a row in tile is equal to the sum of a col in original matrix - tile[threadIdx.x][threadIdx.y] = localSum; - - __syncthreads(); - - // Sum the shared buffer. 
- // The change of threadIdx.x is continuous - float sum = tile[threadIdx.y][threadIdx.x]; - - __syncthreads(); - - // Calculate the sum of a row in tile - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); - - if (threadIdx.x == 0) { - int pos = flat_2dim(blockIdx.x, threadIdx.y, WARP_SIZE); - if (pos < cols) out[pos] = sum; - } -} - -// [r, c] -> [c] -template <> -void launch_fuse_transpose_bias_kernel(const float *inp, float *out, - int rows, int cols, - cudaStream_t stream) { - dim3 grid_dim((cols - 1) / WARP_SIZE + 1); - dim3 block_dim(WARP_SIZE, WARP_SIZE); - - column_sum_reduce - <<>>(inp, out, rows, cols); -} - -template <> -void launch_fuse_transpose_bias_kernel<__half>(const __half *inp, __half *out, - int rows, int cols, - cudaStream_t stream) { - dim3 grid_dim((cols - 1) / WARP_SIZE + 1); - dim3 block_dim(WARP_SIZE, WARP_SIZE); - - column_sum_reduce<__half> - <<>>(inp, out, rows, cols); -} - -/** -@brief: fused_add2 -Add two matrix inp1 and inp2 to out. - -@thread -gridDim.x = batch_size * seq_len -blockDim.x = min(hidden_dim, MAX_THREADS) - -@param -inp1: [batch_size, seq_len, hidden_dim] -inp2: [batch_size, seq_len, hidden_dim] -out: [batch_size, seq_len, hidden_dim] -batch_size: the size of the current batch -seq_len: the sequence length of the current batch -hidden_dim: dim of the hidden tensor -*/ -template -__global__ void fused_add2_kernel(T *out, const T *inp1, const T *inp2, - int hidden_dim); - -template <> -__global__ void fused_add2_kernel(float *out, const float *inp1, - const float *inp2, int hidden_dim) { - int row_id = blockIdx.x; - int offset = flat_2dim(row_id, 0, hidden_dim); - - const float4 *inp1_4 = reinterpret_cast(inp1); - const float4 *inp2_4 = reinterpret_cast(inp2); - float4 *out_4 = reinterpret_cast(out); - float4 vinp1; - float4 vinp2; - float4 val; - - for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - vinp1 = inp1_4[offset + i]; - vinp2 = inp2_4[offset + i]; - val.x = vinp1.x + vinp2.x; - val.y = vinp1.y + vinp2.y; - val.z = vinp1.z + vinp2.z; - val.w = vinp1.w + vinp2.w; - out_4[offset + i] = val; - } -} - -template <> -__global__ void fused_add2_kernel<__half>(__half *out, const __half *inp1, - const __half *inp2, int hidden_dim) { - int row_id = blockIdx.x; - int offset = flat_2dim(row_id, 0, hidden_dim); - - const float4 *inp1_4 = reinterpret_cast(inp1); - const float4 *inp2_4 = reinterpret_cast(inp2); - float4 *out_4 = reinterpret_cast(out); - float4 vinp1; - float4 vinp2; - float4 val; - __half2 *h2_inp1 = reinterpret_cast<__half2 *>(&vinp1); - __half2 *h2_inp2 = reinterpret_cast<__half2 *>(&vinp2); - __half2 *h2_val = reinterpret_cast<__half2 *>(&val); - - for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - vinp1 = inp1_4[offset + i]; - vinp2 = inp2_4[offset + i]; - h2_val[0] = __hadd2(h2_inp1[0], h2_inp2[0]); - h2_val[1] = __hadd2(h2_inp1[1], h2_inp2[1]); - h2_val[2] = __hadd2(h2_inp1[2], h2_inp2[2]); - h2_val[3] = __hadd2(h2_inp1[3], h2_inp2[3]); - out_4[offset + i] = val; - } -} - -//[b, s, h] -> [b, s, h] -template <> -void launch_fused_add2(float *out, const float *inp1, const float *inp2, - int batch_size, int seq_len, int hidden_dim, - cudaStream_t &stream) { - hidden_dim >>= 2; - - dim3 grid_dim(batch_size * seq_len); - dim3 block_dim(min(hidden_dim, MAX_THREADS)); - - fused_add2_kernel<<>>(out, inp1, inp2, - hidden_dim); -} - -template <> -void launch_fused_add2<__half>(__half *out, const __half *inp1, - const __half *inp2, int batch_size, int seq_len, - int hidden_dim, 
cudaStream_t &stream) { - hidden_dim >>= 3; - - dim3 grid_dim(batch_size * seq_len); - dim3 block_dim(min(hidden_dim, MAX_THREADS)); - - fused_add2_kernel<<>>(out, inp1, inp2, - hidden_dim); -} - -template -__global__ void kernel_concat3_dim1(const T *inp1, const T *inp2, T *output, - int sz0, int sz2, int sz1_1, int sz1_2) { - int nele = sz0 * sz2 * (sz1_1 + sz1_2); - int idx = flat_2dim(blockIdx.x, threadIdx.x, blockDim.x); - if (idx >= nele) { - return; - } - float4 *dst_ptr = (float4 *)output + idx; - int idx2 = idx % sz2; - idx = idx / sz2; - int idx1 = idx % (sz1_1 + sz1_2); - int idx0 = idx / (sz1_1 + sz1_2); - float4 *src_ptr = nullptr; - int sz1 = 0; - if (idx1 < sz1_1) { - sz1 = sz1_1; - src_ptr = (float4 *)inp1; - } else { - idx1 -= sz1_1; - sz1 = sz1_2; - src_ptr = (float4 *)inp2; - } - src_ptr += flat_3dim(idx0, idx1, idx2, sz1, sz2); - dst_ptr[0] = src_ptr[0]; -} - -template <> -void launch_concat3_dim1(const float *inp1, const float *inp2, - float *output, int sz0, int sz2, int sz1_1, - int sz1_2, cudaStream_t stream) { - sz2 >>= 2; - int nele = sz0 * sz2 * (sz1_1 + sz1_2); - int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; - kernel_concat3_dim1<<>>( - inp1, inp2, output, sz0, sz2, sz1_1, sz1_2); -} - -template <> -void launch_concat3_dim1<__half>(const __half *inp1, const __half *inp2, - __half *output, int sz0, int sz2, int sz1_1, - int sz1_2, cudaStream_t stream) { - sz2 >>= 3; - int nele = sz0 * sz2 * (sz1_1 + sz1_2); - int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; - kernel_concat3_dim1<<>>( - inp1, inp2, output, sz0, sz2, sz1_1, sz1_2); -} +#include + +#include "kernels.h" + +namespace cg = cooperative_groups; + +/** +@brief: fuse_transpose_bias +Calculate the sum of elements in each column of the matrix. + +@thread +gridDim.x = ceil(cols / WARP_SIZE) +blockDim.x = WARP_SIZE +blockDim.y = WARP_SIZE + +@param +inp: [rows, cols] +out: [cols] +rows: the number of rows in the matrix +cols: the number of cols in the matrix +*/ +template +__global__ void column_sum_reduce(const T *__restrict__ inp, + T *__restrict__ out, int rows, int cols) { + __shared__ float tile[WARP_SIZE][WARP_SIZE]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); + int y_stride = cols * WARP_SIZE; + float localSum = 0; + + // Loop across matrix row + // TODO: optimize to log complexity + if (idx < cols) { + int offset = flat_2dim(threadIdx.y, idx, cols); + for (int r = threadIdx.y; r < rows; r += WARP_SIZE) { + localSum += (float)inp[offset]; + offset += y_stride; + } + } + + // The sum of a row in tile is equal to the sum of a col in original matrix + tile[threadIdx.x][threadIdx.y] = localSum; + + __syncthreads(); + + // Sum the shared buffer. 
+ // The change of threadIdx.x is continuous + float sum = tile[threadIdx.y][threadIdx.x]; + + __syncthreads(); + + // Calculate the sum of a row in tile + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); + + if (threadIdx.x == 0) { + int pos = flat_2dim(blockIdx.x, threadIdx.y, WARP_SIZE); + if (pos < cols) out[pos] = sum; + } +} + +// [r, c] -> [c] +template <> +void launch_fuse_transpose_bias_kernel(const float *inp, float *out, + int rows, int cols, + cudaStream_t stream) { + dim3 grid_dim((cols - 1) / WARP_SIZE + 1); + dim3 block_dim(WARP_SIZE, WARP_SIZE); + + column_sum_reduce + <<>>(inp, out, rows, cols); +} + +template <> +void launch_fuse_transpose_bias_kernel<__half>(const __half *inp, __half *out, + int rows, int cols, + cudaStream_t stream) { + dim3 grid_dim((cols - 1) / WARP_SIZE + 1); + dim3 block_dim(WARP_SIZE, WARP_SIZE); + + column_sum_reduce<__half> + <<>>(inp, out, rows, cols); +} + +/** +@brief: fused_add2 +Add two matrix inp1 and inp2 to out. + +@thread +gridDim.x = batch_size * seq_len +blockDim.x = min(hidden_dim, MAX_THREADS) + +@param +inp1: [batch_size, seq_len, hidden_dim] +inp2: [batch_size, seq_len, hidden_dim] +out: [batch_size, seq_len, hidden_dim] +batch_size: the size of the current batch +seq_len: the sequence length of the current batch +hidden_dim: dim of the hidden tensor +*/ +template +__global__ void fused_add2_kernel(T *out, const T *inp1, const T *inp2, + int hidden_dim); + +template <> +__global__ void fused_add2_kernel(float *out, const float *inp1, + const float *inp2, int hidden_dim) { + int row_id = blockIdx.x; + int offset = flat_2dim(row_id, 0, hidden_dim); + + const float4 *inp1_4 = reinterpret_cast(inp1); + const float4 *inp2_4 = reinterpret_cast(inp2); + float4 *out_4 = reinterpret_cast(out); + float4 vinp1; + float4 vinp2; + float4 val; + + for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { + vinp1 = inp1_4[offset + i]; + vinp2 = inp2_4[offset + i]; + val.x = vinp1.x + vinp2.x; + val.y = vinp1.y + vinp2.y; + val.z = vinp1.z + vinp2.z; + val.w = vinp1.w + vinp2.w; + out_4[offset + i] = val; + } +} + +template <> +__global__ void fused_add2_kernel<__half>(__half *out, const __half *inp1, + const __half *inp2, int hidden_dim) { + int row_id = blockIdx.x; + int offset = flat_2dim(row_id, 0, hidden_dim); + + const float4 *inp1_4 = reinterpret_cast(inp1); + const float4 *inp2_4 = reinterpret_cast(inp2); + float4 *out_4 = reinterpret_cast(out); + float4 vinp1; + float4 vinp2; + float4 val; + __half2 *h2_inp1 = reinterpret_cast<__half2 *>(&vinp1); + __half2 *h2_inp2 = reinterpret_cast<__half2 *>(&vinp2); + __half2 *h2_val = reinterpret_cast<__half2 *>(&val); + + for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { + vinp1 = inp1_4[offset + i]; + vinp2 = inp2_4[offset + i]; + h2_val[0] = __hadd2(h2_inp1[0], h2_inp2[0]); + h2_val[1] = __hadd2(h2_inp1[1], h2_inp2[1]); + h2_val[2] = __hadd2(h2_inp1[2], h2_inp2[2]); + h2_val[3] = __hadd2(h2_inp1[3], h2_inp2[3]); + out_4[offset + i] = val; + } +} + +//[b, s, h] -> [b, s, h] +template <> +void launch_fused_add2(float *out, const float *inp1, const float *inp2, + int batch_size, int seq_len, int hidden_dim, + cudaStream_t &stream) { + hidden_dim >>= 2; + + dim3 grid_dim(batch_size * seq_len); + dim3 block_dim(min(hidden_dim, MAX_THREADS)); + + fused_add2_kernel<<>>(out, inp1, inp2, + hidden_dim); +} + +template <> +void launch_fused_add2<__half>(__half *out, const __half *inp1, + const __half *inp2, int batch_size, int seq_len, + int hidden_dim, 
cudaStream_t &stream) { + hidden_dim >>= 3; + + dim3 grid_dim(batch_size * seq_len); + dim3 block_dim(min(hidden_dim, MAX_THREADS)); + + fused_add2_kernel<<>>(out, inp1, inp2, + hidden_dim); +} + +template +__global__ void kernel_concat3_dim1(const T *inp1, const T *inp2, T *output, + int sz0, int sz2, int sz1_1, int sz1_2) { + int nele = sz0 * sz2 * (sz1_1 + sz1_2); + int idx = flat_2dim(blockIdx.x, threadIdx.x, blockDim.x); + if (idx >= nele) { + return; + } + float4 *dst_ptr = (float4 *)output + idx; + int idx2 = idx % sz2; + idx = idx / sz2; + int idx1 = idx % (sz1_1 + sz1_2); + int idx0 = idx / (sz1_1 + sz1_2); + float4 *src_ptr = nullptr; + int sz1 = 0; + if (idx1 < sz1_1) { + sz1 = sz1_1; + src_ptr = (float4 *)inp1; + } else { + idx1 -= sz1_1; + sz1 = sz1_2; + src_ptr = (float4 *)inp2; + } + src_ptr += flat_3dim(idx0, idx1, idx2, sz1, sz2); + dst_ptr[0] = src_ptr[0]; +} + +template <> +void launch_concat3_dim1(const float *inp1, const float *inp2, + float *output, int sz0, int sz2, int sz1_1, + int sz1_2, cudaStream_t stream) { + sz2 >>= 2; + int nele = sz0 * sz2 * (sz1_1 + sz1_2); + int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; + kernel_concat3_dim1<<>>( + inp1, inp2, output, sz0, sz2, sz1_1, sz1_2); +} + +template <> +void launch_concat3_dim1<__half>(const __half *inp1, const __half *inp2, + __half *output, int sz0, int sz2, int sz1_1, + int sz1_2, cudaStream_t stream) { + sz2 >>= 3; + int nele = sz0 * sz2 * (sz1_1 + sz1_2); + int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; + kernel_concat3_dim1<<>>( + inp1, inp2, output, sz0, sz2, sz1_1, sz1_2); +} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h b/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h index 563a7fe284a3..025fbf3f8f15 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h @@ -1,96 +1,96 @@ -#pragma once - -#include -#include -#include - -#include - -#include "kernels.h" - -template -class Dropout { - public: - struct Config { - float ratio; - bool training; - - Config(float r) : ratio(r), training(true) {} - float RATIO() const { return training ? ratio : 0.0; } - }; - - Dropout(const Config &config, size_t max_ele_num) - : _config(config), _mask(nullptr) { - _mask = cuda_malloc(max_ele_num); - } - - virtual ~Dropout() { cuda_free(_mask); } - - // after attention softmax - void dropout(T *output, const T *input, int count, cudaStream_t stream, - bool bwd = false) { - launch_ls_dropout(output, input, _mask, count, _config.RATIO(), stream, - bwd); - } - - void d_dropout(T *d_inp_out, int count, cudaStream_t stream) { - launch_ls_dropout(d_inp_out, d_inp_out, _mask, count, _config.RATIO(), - stream, true); - } - - // transformer layer's postprocessing dropout, after attn or ffn module, - // before residual add. - void bias_dropout_residual(T *output, const T *input, const T *residual, - const T *bias, int rows, int cols, - cudaStream_t stream) { - launch_ls_dropout_res_bias(output, input, _mask, bias, residual, - rows * cols, cols, _config.RATIO(), stream); - } - - void d_bias_dropout_residual(T *d_input, T *d_bias, const T *d_output, - int rows, int cols, cudaStream_t stream) { - launch_ls_dropout_bias_bwd(d_input, d_bias, d_output, _mask, rows, cols, - _config.RATIO(), stream); - } - - // dropout inside ffn. 
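// Reference aside: bias_act_dropout fuses the ffn-internal path
// "out = dropout(act(input + bias))"; the activation is chosen at compile
// time through the ActivationType template argument, and the string check at
// the call site only selects which instantiation to launch. A standalone host
// sketch of the relu case (illustrative only; assumes _mask holds 0/1 keep
// flags, as the backward kernels use it):
#include <algorithm>
#include <cstdint>
#include <vector>

static void bias_relu_dropout_reference(std::vector<float> &out,
                                        const std::vector<float> &inp,
                                        const std::vector<float> &bias,
                                        const std::vector<uint8_t> &mask,
                                        int rows, int cols, float ratio) {
  const float scale = 1.f / (1.f - ratio);  // inverted-dropout scaling
  out.resize(static_cast<size_t>(rows) * cols);
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      const size_t i = static_cast<size_t>(r) * cols + c;
      const float activated = std::max(inp[i] + bias[c], 0.f);  // relu(x + b)
      out[i] = activated * (mask[i] ? scale : 0.f);
    }
  }
}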
- void bias_act_dropout(T *output, const T *input, const T *bias, int rows, - int cols, std::string activation_fn, - cudaStream_t stream) { - if (activation_fn == "relu") { - launch_ls_dropout_act_bias( - output, input, _mask, bias, rows * cols, cols, _config.RATIO(), - stream); - } else if (activation_fn == "gelu") { - launch_ls_dropout_act_bias( - output, input, _mask, bias, rows * cols, cols, _config.RATIO(), - stream); - } else { - throw std::runtime_error("not supported activation: " + activation_fn); - } - } - - void d_bias_act_dropout(T *d_inp_out, T *d_bias_out, const T *input, - const T *bias, int rows, int cols, - std::string activation_fn, cudaStream_t stream) { - if (activation_fn == "relu") { - launch_ls_dropout_act_bias_bwd( - d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols, - _config.RATIO(), stream); - } else if (activation_fn == "gelu") { - launch_ls_dropout_act_bias_bwd( - d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols, - _config.RATIO(), stream); - } else { - throw std::runtime_error("not supported activation: " + activation_fn); - } - } - - bool HasDropout() const { return _config.RATIO() > 0.0; } - - void SetTrainingMode(bool training) { _config.training = training; } - - private: - uint8_t *_mask; - Config _config; -}; +#pragma once + +#include +#include +#include + +#include + +#include "kernels.h" + +template +class Dropout { + public: + struct Config { + float ratio; + bool training; + + Config(float r) : ratio(r), training(true) {} + float RATIO() const { return training ? ratio : 0.0; } + }; + + Dropout(const Config &config, size_t max_ele_num) + : _config(config), _mask(nullptr) { + _mask = cuda_malloc(max_ele_num); + } + + virtual ~Dropout() { cuda_free(_mask); } + + // after attention softmax + void dropout(T *output, const T *input, int count, cudaStream_t stream, + bool bwd = false) { + launch_ls_dropout(output, input, _mask, count, _config.RATIO(), stream, + bwd); + } + + void d_dropout(T *d_inp_out, int count, cudaStream_t stream) { + launch_ls_dropout(d_inp_out, d_inp_out, _mask, count, _config.RATIO(), + stream, true); + } + + // transformer layer's postprocessing dropout, after attn or ffn module, + // before residual add. + void bias_dropout_residual(T *output, const T *input, const T *residual, + const T *bias, int rows, int cols, + cudaStream_t stream) { + launch_ls_dropout_res_bias(output, input, _mask, bias, residual, + rows * cols, cols, _config.RATIO(), stream); + } + + void d_bias_dropout_residual(T *d_input, T *d_bias, const T *d_output, + int rows, int cols, cudaStream_t stream) { + launch_ls_dropout_bias_bwd(d_input, d_bias, d_output, _mask, rows, cols, + _config.RATIO(), stream); + } + + // dropout inside ffn. 
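// Reference aside for the bias_dropout_residual pair above: it fuses the
// post-attention / post-ffn path "out = dropout(input + bias) + residual" and
// records the keep mask so d_bias_dropout_residual can reuse it. A standalone
// host sketch of the intended math (illustrative only):
#include <cstdint>
#include <vector>

static void bias_dropout_residual_reference(
    std::vector<float> &out, const std::vector<float> &inp,
    const std::vector<float> &residual, const std::vector<float> &bias,
    const std::vector<uint8_t> &mask, int rows, int cols, float ratio) {
  const float scale = 1.f / (1.f - ratio);
  out.resize(static_cast<size_t>(rows) * cols);
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      const size_t i = static_cast<size_t>(r) * cols + c;
      out[i] = (inp[i] + bias[c]) * (mask[i] ? scale : 0.f) + residual[i];
    }
  }
}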
+ void bias_act_dropout(T *output, const T *input, const T *bias, int rows, + int cols, std::string activation_fn, + cudaStream_t stream) { + if (activation_fn == "relu") { + launch_ls_dropout_act_bias( + output, input, _mask, bias, rows * cols, cols, _config.RATIO(), + stream); + } else if (activation_fn == "gelu") { + launch_ls_dropout_act_bias( + output, input, _mask, bias, rows * cols, cols, _config.RATIO(), + stream); + } else { + throw std::runtime_error("not supported activation: " + activation_fn); + } + } + + void d_bias_act_dropout(T *d_inp_out, T *d_bias_out, const T *input, + const T *bias, int rows, int cols, + std::string activation_fn, cudaStream_t stream) { + if (activation_fn == "relu") { + launch_ls_dropout_act_bias_bwd( + d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols, + _config.RATIO(), stream); + } else if (activation_fn == "gelu") { + launch_ls_dropout_act_bias_bwd( + d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols, + _config.RATIO(), stream); + } else { + throw std::runtime_error("not supported activation: " + activation_fn); + } + } + + bool HasDropout() const { return _config.RATIO() > 0.0; } + + void SetTrainingMode(bool training) { _config.training = training; } + + private: + uint8_t *_mask; + Config _config; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h b/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h index fbb9c5465c24..735e1363cc46 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h @@ -3,10 +3,11 @@ #include #include #include -#include #include #include +#include + #define MAX_THREADS 1024 #define WARP_SIZE 32 @@ -132,8 +133,9 @@ __forceinline__ __host__ __device__ int flat_3dim(int id1, int id2, int id3, } /* Convert 4-dim tensor index into vector index */ -__forceinline__ __host__ __device__ int -flat_4dim(int id1, int id2, int id3, int id4, int dim2, int dim3, int dim4) { +__forceinline__ __host__ __device__ int flat_4dim(int id1, int id2, int id3, + int id4, int dim2, int dim3, + int dim4) { // return id1*(dim2*dim3*dim4) + id2*(dim3*dim4) + id3*dim4 + id4; int res = id4; @@ -201,9 +203,9 @@ __forceinline__ __host__ __device__ int flat_6dim(int id1, int id2, int id3, } /* Convert vector index to 6-dim tensor index */ -__forceinline__ __host__ __device__ void -decompose_6dim(int src, int dim1, int dim2, int dim3, int dim4, int dim5, - int *id0, int *id1, int *id2, int *id3, int *id4, int *id5) { +__forceinline__ __host__ __device__ void decompose_6dim( + int src, int dim1, int dim2, int dim3, int dim4, int dim5, int *id0, + int *id1, int *id2, int *id3, int *id4, int *id5) { *id5 = src % dim5; src /= dim5; @@ -221,9 +223,11 @@ decompose_6dim(int src, int dim1, int dim2, int dim3, int dim4, int dim5, } /* Convert vector index to 5-dim tensor index */ -__forceinline__ __host__ __device__ void -decompose_5dim(int src, int dim1, int dim2, int dim3, int dim4, int *id0, - int *id1, int *id2, int *id3, int *id4) { +__forceinline__ __host__ __device__ void decompose_5dim(int src, int dim1, + int dim2, int dim3, + int dim4, int *id0, + int *id1, int *id2, + int *id3, int *id4) { *id4 = src % dim4; src /= dim4; @@ -253,8 +257,9 @@ __forceinline__ __host__ __device__ void decompose_4dim(int src, int dim1, } /* Convert vector index to 3-dim tensor index */ -__forceinline__ __host__ __device__ void -decompose_3dim(int src, int dim1, int dim2, int *id0, int *id1, int *id2) { +__forceinline__ __host__ 
__device__ void decompose_3dim(int src, int dim1, + int dim2, int *id0, + int *id1, int *id2) { *id2 = src % dim2; src /= dim2; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h b/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h index ded5c0fdcbee..a7767e187ffc 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h @@ -1,64 +1,65 @@ -#pragma once - -#include -#include -#include - -#include - -#include "kernels.h" - -using namespace std; - -template class Normalize_Layer { -public: - struct Config { - uint32_t hidden_dim; - bool use_mean; - Config(uint32_t hidden_dim, bool use_mean = false) - : hidden_dim(hidden_dim), use_mean(use_mean) {} - }; - - Normalize_Layer(Config config, size_t max_rows) - : config_(config), vars_(nullptr), means_(nullptr) { - vars_ = cuda_malloc(max_rows); - if (config_.use_mean) { - means_ = cuda_malloc(max_rows); - } - } - - ~Normalize_Layer() { - cuda_free(vars_); - cuda_free(means_); - } - - void Forward(T *ln_res, const T *inp, const T *gamma, const T *betta, - int batch_size, cudaStream_t stream) { - launch_layer_norm(ln_res, vars_, means_, inp, gamma, betta, batch_size, - config_.hidden_dim, stream); - } - - /* - residual_grad, inp_or_out, betta should be treated carefully. - inp_or_out = input if use_mean else output - residual_grad, betta can be nullptr. - residual_grad will be added to dinp if it is not nullptr - which is useful in transformer layer when pre-ln - betta are only used to compute xhat, - (use_mean == false) ^ (betta == nullptr) should be true - */ - void Backward(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad, - const T *residual_grad, const T *inp_or_out, const T *gamma, - const T *betta, int batch_size, cudaStream_t stream[2]) { - launch_ln_bw(gamma_grad, betta_grad, inp_grad, out_grad, residual_grad, - inp_or_out, gamma, betta, vars_, means_, batch_size, - config_.hidden_dim, stream); - } - - inline bool use_mean() const { return config_.use_mean; } - -private: - Config config_; - T *vars_; - T *means_; -}; +#pragma once + +#include +#include +#include + +#include + +#include "kernels.h" + +using namespace std; + +template +class Normalize_Layer { + public: + struct Config { + uint32_t hidden_dim; + bool use_mean; + Config(uint32_t hidden_dim, bool use_mean = false) + : hidden_dim(hidden_dim), use_mean(use_mean) {} + }; + + Normalize_Layer(Config config, size_t max_rows) + : config_(config), vars_(nullptr), means_(nullptr) { + vars_ = cuda_malloc(max_rows); + if (config_.use_mean) { + means_ = cuda_malloc(max_rows); + } + } + + ~Normalize_Layer() { + cuda_free(vars_); + cuda_free(means_); + } + + void Forward(T *ln_res, const T *inp, const T *gamma, const T *betta, + int batch_size, cudaStream_t stream) { + launch_layer_norm(ln_res, vars_, means_, inp, gamma, betta, batch_size, + config_.hidden_dim, stream); + } + + /* + residual_grad, inp_or_out, betta should be treated carefully. + inp_or_out = input if use_mean else output + residual_grad, betta can be nullptr. 
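// Reference aside: the backward pass reconstructs the normalized value xhat
// from whichever tensor was kept. A scalar sketch of the two modes
// (illustrative only; eps stands for LN_EPSILON):
#include <cmath>

static float xhat_from_input(float inp, float mean, float var, float eps) {
  return (inp - mean) / std::sqrt(var + eps);  // use_mean == true: input kept
}

static float xhat_from_output(float out, float betta, float gamma) {
  return (out - betta) / gamma;  // use_mean == false: output, gamma, betta kept
}
// This is why means_ is only allocated when use_mean is set: the output-side
// reconstruction needs neither the mean nor the original input. (The CUDA
// kernels additionally clamp gamma away from zero via add_eps before the
// division.)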
+ residual_grad will be added to dinp if it is not nullptr + which is useful in transformer layer when pre-ln + betta are only used to compute xhat, + (use_mean == false) ^ (betta == nullptr) should be true + */ + void Backward(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad, + const T *residual_grad, const T *inp_or_out, const T *gamma, + const T *betta, int batch_size, cudaStream_t stream[2]) { + launch_ln_bw(gamma_grad, betta_grad, inp_grad, out_grad, residual_grad, + inp_or_out, gamma, betta, vars_, means_, batch_size, + config_.hidden_dim, stream); + } + + inline bool use_mean() const { return config_.use_mean; } + + private: + Config config_; + T *vars_; + T *means_; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h b/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h index ec447ad84c54..b917abaf0336 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h @@ -1,42 +1,42 @@ -#pragma once - -#include -#include -#include - -#include - -#include "kernels.h" - -using namespace std; - -template -class Softmax { - public: - struct Config { - size_t nhead; - Config(size_t nhead) : nhead(nhead) {} - }; - - Softmax(Config config) : config_(config) {} - - ~Softmax() {} - - void Forward(T *vals, const T *attn_mask, int batch_size, int from_len, - int to_len, cudaStream_t &stream, bool mask_future = true) { - launch_attn_softmax(vals, attn_mask, batch_size, config_.nhead, from_len, - to_len, mask_future, stream); - } - - void Backward(T *out_grad, const T *soft_out, int batch_size, int from_len, - int to_len, cudaStream_t stream) { - launch_attn_softmax_bw(out_grad, soft_out, - batch_size * config_.nhead * from_len, to_len, - stream); - } - - void reset_size(size_t nhead) { config_.nhead = nhead; } - - private: - Config config_; -}; +#pragma once + +#include +#include +#include + +#include + +#include "kernels.h" + +using namespace std; + +template +class Softmax { + public: + struct Config { + size_t nhead; + Config(size_t nhead) : nhead(nhead) {} + }; + + Softmax(Config config) : config_(config) {} + + ~Softmax() {} + + void Forward(T *vals, const T *attn_mask, int batch_size, int from_len, + int to_len, cudaStream_t &stream, bool mask_future = true) { + launch_attn_softmax(vals, attn_mask, batch_size, config_.nhead, from_len, + to_len, mask_future, stream); + } + + void Backward(T *out_grad, const T *soft_out, int batch_size, int from_len, + int to_len, cudaStream_t stream) { + launch_attn_softmax_bw(out_grad, soft_out, + batch_size * config_.nhead * from_len, to_len, + stream); + } + + void reset_size(size_t nhead) { config_.nhead = nhead; } + + private: + Config config_; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu index 3e61d4e35832..e2f1869b165e 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu @@ -1,1169 +1,1172 @@ -#include "block_reduce.h" -#include "kernels.h" -#include - -namespace cg = cooperative_groups; -const float LN_EPSILON = 1e-8f; -#define TILE_DIM 32 - -template __forceinline__ __device__ T add_eps(T x) { - return fabsf(x) > LN_EPSILON ? x : (x < 0 ? -LN_EPSILON : LN_EPSILON); -} - -/** -@brief: ker_layer_norm -Standard layer normalization. -It will not only output the layer norm result, - but also outputs variance. 
- may also output means, depends on whether - the means argument is nullptr - -@thread -gridDim.x = batch_size * seq_len -blockDim.x = hidden_size - -@param -ln_res: [batch_size* seq_len, hidden_size], ln result. -vars: [batch_size* seq_len], variance per token -means: [batch_size* seq_len], means per token, can be nullput -inp: [batch_size * seq_len, hidden_size], ln input. -scale: [hidden_size], ln scale -bias: [hidden_size], ln bias -*/ -template -__global__ void ker_layer_norm(T *ln_res, T *vars, T *means, const T *inp, - const T *scale, const T *bias, int hidden_size) { - // step 0. compute local sum - float l_sum = 0; - float l_square_sum = 0; - const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size; - for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float4 val = inp_f4[idx]; - l_sum += val.x + val.y + val.z + val.w; - l_square_sum += - val.x * val.x + val.y * val.y + val.z * val.z + val.w * val.w; - } - - // step 1. compute reduce sum - float mean_dim = float(hidden_size) * 4.f; - float reduce_val[2] = {l_sum, l_square_sum}; - blockReduce(reduce_val); - __shared__ float s_mean, s_var; - if (threadIdx.x == 0) { - s_mean = reduce_val[0] / mean_dim; - if (means != nullptr) { - means[blockIdx.x] = s_mean; - } - s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; - vars[blockIdx.x] = s_var; - s_var = rsqrtf(s_var); - } - __syncthreads(); - - // step 2. layer norm result - float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size; - for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float4 vscale = __ldg((const float4 *)scale + idx); - float4 vbias = __ldg((const float4 *)bias + idx); - float4 val = inp_f4[idx]; - val.x = (val.x - s_mean) * s_var * vscale.x + vbias.x; - val.y = (val.y - s_mean) * s_var * vscale.y + vbias.y; - val.z = (val.z - s_mean) * s_var * vscale.z + vbias.z; - val.w = (val.w - s_mean) * s_var * vscale.w + vbias.w; - output_f4[idx] = val; - } -} - -template <> -__global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars, - __half *means, const __half *inp, - const __half *scale, const __half *bias, - int hidden_size) { - // step 0. compute local sum - float l_sum = 0; - float l_square_sum = 0; - const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size; - for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float4 val_f4 = inp_f4[idx]; - __half2 *val_h2 = (__half2 *)(&val_f4); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 val_f2 = __half22float2(val_h2[i]); - l_sum += val_f2.x + val_f2.y; - l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y; - } - } - - // step 1. compute reduce sum - float mean_dim = float(hidden_size) * 8.f; - float reduce_val[2] = {l_sum, l_square_sum}; - blockReduce(reduce_val); - __shared__ float s_mean, s_var; - if (threadIdx.x == 0) { - s_mean = reduce_val[0] / mean_dim; - if (means != nullptr) { - means[blockIdx.x] = s_mean; - } - s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; - vars[blockIdx.x] = s_var; - s_var = rsqrtf(s_var); - } - __syncthreads(); - - // step 2. 
layer norm result - float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size; - for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - // load scale, bias, input - float4 scale_f4 = __ldg((const float4 *)scale + idx); - __half2 *scale_h2 = (__half2 *)(&scale_f4); - float4 bias_f4 = __ldg((const float4 *)bias + idx); - __half2 *bias_h2 = (__half2 *)(&bias_f4); - float4 val_f4 = inp_f4[idx]; - __half2 *val_h2 = (__half2 *)(&val_f4); - -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 scale_f2 = __half22float2(scale_h2[i]); - float2 bias_f2 = __half22float2(bias_h2[i]); - float2 val_f2 = __half22float2(val_h2[i]); - val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; - val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; - val_h2[i] = __float22half2_rn(val_f2); - } - output_f4[idx] = val_f4; - } -} - -// __global__ void ker_layer_norm_x2(__half *ln_res, __half *vars, -// __half *means, const __half *inp, -// const __half *scale, const __half -// *bias, int hidden_size) { -// // step 0. compute local sum -// float l_sum = 0; -// float l_square_sum = 0; -// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * 2 * hidden_size; -// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * -// 2) { -// float4 val_f4 = inp_f4[idx]; -// float4 val_f4_1 = inp_f4[idx+1]; -// __half2 *val_h2 = (__half2 *)(&val_f4); -// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// float2 val_f2 = __half22float2(val_h2[i]); -// float2 val_f2_1 = __half22float2(val_h2_1[i]); -// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y; -// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y + val_f2_1.x -// * val_f2_1.x + val_f2_1.y * val_f2_1.y; -// } -// } - -// // step 1. compute reduce sum -// float mean_dim = float(hidden_size) * 8.f * 2; -// float reduce_val[2] = {l_sum, l_square_sum}; -// blockReduce(reduce_val); -// __shared__ float s_mean, s_var; -// if (threadIdx.x == 0) { -// s_mean = reduce_val[0] / mean_dim; -// if (means != nullptr) { -// means[blockIdx.x] = s_mean; -// } -// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; -// vars[blockIdx.x] = s_var; -// s_var = rsqrtf(s_var); -// } -// __syncthreads(); - -// // step 2. 
layer norm result -// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 2; -// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * -// 2) { -// // load scale, bias, input -// float4 scale_f4 = __ldg((const float4 *)scale + idx); -// __half2 *scale_h2 = (__half2 *)(&scale_f4); -// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1); -// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1); -// float4 bias_f4 = __ldg((const float4 *)bias + idx); -// __half2 *bias_h2 = (__half2 *)(&bias_f4); -// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1); -// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1); -// float4 val_f4 = inp_f4[idx]; -// __half2 *val_h2 = (__half2 *)(&val_f4); -// float4 val_f4_1 = inp_f4[idx+1]; -// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); - -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// float2 scale_f2 = __half22float2(scale_h2[i]); -// float2 scale_f2_1 = __half22float2(scale_h2_1[i]); -// float2 bias_f2 = __half22float2(bias_h2[i]); -// float2 bias_f2_1 = __half22float2(bias_h2_1[i]); -// float2 val_f2 = __half22float2(val_h2[i]); -// float2 val_f2_1 = __half22float2(val_h2_1[i]); -// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; -// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; -// val_h2[i] = __float22half2_rn(val_f2); -// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + -// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y -// + bias_f2_1.y; val_h2_1[i] = __float22half2_rn(val_f2_1); -// } -// output_f4[idx] = val_f4; -// output_f4[idx+1] = val_f4_1; -// } -// } - -// __global__ void ker_layer_norm_x4(__half *ln_res, __half *vars, -// __half *means, const __half *inp, -// const __half *scale, const __half -// *bias, int hidden_size) { -// // step 0. compute local sum -// float l_sum = 0; -// float l_square_sum = 0; -// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size * 4; -// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * -// 4) { -// float4 val_f4 = inp_f4[idx]; -// float4 val_f4_1 = inp_f4[idx+1]; -// float4 val_f4_2 = inp_f4[idx+2]; -// float4 val_f4_3 = inp_f4[idx+3]; -// __half2 *val_h2 = (__half2 *)(&val_f4); -// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); -// __half2 *val_h2_2 = (__half2 *)(&val_f4_2); -// __half2 *val_h2_3 = (__half2 *)(&val_f4_3); -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// float2 val_f2 = __half22float2(val_h2[i]); -// float2 val_f2_1 = __half22float2(val_h2_1[i]); -// float2 val_f2_2 = __half22float2(val_h2_2[i]); -// float2 val_f2_3 = __half22float2(val_h2_3[i]); -// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y + val_f2_2.x + -// val_f2_2.y + val_f2_3.x + val_f2_3.y; l_square_sum += val_f2.x * -// val_f2.x + val_f2.y * val_f2.y; l_square_sum += val_f2_1.x * val_f2_1.x -// + val_f2_1.y * val_f2_1.y; l_square_sum += val_f2_2.x * val_f2_2.x + -// val_f2_2.y * val_f2_2.y; l_square_sum += val_f2_3.x * val_f2_3.x + -// val_f2_3.y * val_f2_3.y; -// } -// } - -// // step 1. compute reduce sum -// float mean_dim = float(hidden_size) * 8.f * 4; -// float reduce_val[2] = {l_sum, l_square_sum}; -// blockReduce(reduce_val); -// __shared__ float s_mean, s_var; -// if (threadIdx.x == 0) { -// s_mean = reduce_val[0] / mean_dim; -// if (means != nullptr) { -// means[blockIdx.x] = s_mean; -// } -// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; -// vars[blockIdx.x] = s_var; -// s_var = rsqrtf(s_var); -// } -// __syncthreads(); - -// // step 2. 
layer norm result -// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 4; -// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * -// 4) { -// // load scale, bias, input -// float4 scale_f4 = __ldg((const float4 *)scale + idx); -// __half2 *scale_h2 = (__half2 *)(&scale_f4); -// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1); -// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1); -// float4 scale_f4_2 = __ldg((const float4 *)scale + idx + 2); -// __half2 *scale_h2_2 = (__half2 *)(&scale_f4_2); -// float4 scale_f4_3 = __ldg((const float4 *)scale + idx + 3); -// __half2 *scale_h2_3 = (__half2 *)(&scale_f4_3); -// float4 bias_f4 = __ldg((const float4 *)bias + idx); -// __half2 *bias_h2 = (__half2 *)(&bias_f4); -// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1); -// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1); -// float4 bias_f4_2 = __ldg((const float4 *)bias + idx + 2); -// __half2 *bias_h2_2 = (__half2 *)(&bias_f4_2); -// float4 bias_f4_3 = __ldg((const float4 *)bias + idx + 3); -// __half2 *bias_h2_3 = (__half2 *)(&bias_f4_3); -// float4 val_f4 = inp_f4[idx]; -// __half2 *val_h2 = (__half2 *)(&val_f4); -// float4 val_f4_1 = inp_f4[idx+1]; -// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); -// float4 val_f4_2 = inp_f4[idx+2]; -// __half2 *val_h2_2 = (__half2 *)(&val_f4_2); -// float4 val_f4_3 = inp_f4[idx+3]; -// __half2 *val_h2_3 = (__half2 *)(&val_f4_3); - -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// float2 scale_f2 = __half22float2(scale_h2[i]); -// float2 scale_f2_1 = __half22float2(scale_h2_1[i]); -// float2 scale_f2_2 = __half22float2(scale_h2_2[i]); -// float2 scale_f2_3 = __half22float2(scale_h2_3[i]); -// float2 bias_f2 = __half22float2(bias_h2[i]); -// float2 bias_f2_1 = __half22float2(bias_h2_1[i]); -// float2 bias_f2_2 = __half22float2(bias_h2_2[i]); -// float2 bias_f2_3 = __half22float2(bias_h2_3[i]); -// float2 val_f2 = __half22float2(val_h2[i]); -// float2 val_f2_1 = __half22float2(val_h2_1[i]); -// float2 val_f2_2 = __half22float2(val_h2_2[i]); -// float2 val_f2_3 = __half22float2(val_h2_3[i]); -// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; -// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; -// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + -// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y -// + bias_f2_1.y; val_f2_2.x = (val_f2_2.x - s_mean) * s_var * -// scale_f2_2.x + bias_f2_2.x; val_f2_2.y = (val_f2_2.y - s_mean) * s_var -// * scale_f2_2.y + bias_f2_2.y; val_f2_3.x = (val_f2_3.x - s_mean) * -// s_var * scale_f2_3.x + bias_f2_3.x; val_f2_3.y = (val_f2_3.y - s_mean) -// * s_var * scale_f2_3.y + bias_f2_3.y; val_h2[i] = -// __float22half2_rn(val_f2); val_h2_1[i] = __float22half2_rn(val_f2_1); -// val_h2_2[i] = __float22half2_rn(val_f2_2); -// val_h2_3[i] = __float22half2_rn(val_f2_3); -// } -// output_f4[idx] = val_f4; -// output_f4[idx+1] = val_f4_1; -// output_f4[idx+2] = val_f4_2; -// output_f4[idx+3] = val_f4_3; -// } -// } - -template <> -void launch_layer_norm(float *ln_res, float *vars, float *means, - const float *inp, const float *scale, - const float *bias, int batch_size, int hidden_dim, - cudaStream_t stream) { - if (hidden_dim % 4 != 0) { - throw std::runtime_error("violate hidden_dim % 4 = 0"); - } - hidden_dim >>= 2; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - dim3 grid_dim(batch_size); - dim3 block_dim(nthread); - - ker_layer_norm<<>>( - ln_res, vars, means, inp, scale, bias, hidden_dim); 
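// Note on the launch configuration above (reference aside): hidden_dim is
// divided by 4 because ker_layer_norm<float> loads float4 vectors, so the
// kernel works in units of 4 floats per access; grid_dim launches one block
// per row, where the batch_size argument is expected to already be
// batch_size * seq_len (see the kernel's @thread comment). Per row x of
// width H, the kernel computes
//   mean = sum(x) / H,   var = sum(x * x) / H - mean * mean,
//   y    = (x - mean) * rsqrt(var + LN_EPSILON) * scale + bias,
// saving the per-row variance term (and optionally the mean) for backward.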
-} - -template <> -void launch_layer_norm<__half>(__half *ln_res, __half *vars, __half *means, - const __half *inp, const __half *scale, - const __half *bias, int batch_size, - int hidden_dim, cudaStream_t stream) { - if (hidden_dim % 8 != 0) { - throw std::runtime_error("violate hidden_dim % 8 = 0"); - } - hidden_dim >>= 3; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - dim3 grid_dim(batch_size); - dim3 block_dim(nthread); - - ker_layer_norm<__half><<>>( - ln_res, vars, means, inp, scale, bias, hidden_dim); - // if (hidden_dim % 8 != 0) { - // throw std::runtime_error("violate hidden_dim % 8 = 0"); - // } - // hidden_dim >>= 3; - - // if (hidden_dim * 8 < 8192) { - // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - // dim3 grid_dim(batch_size); - // dim3 block_dim(nthread); - // ker_layer_norm<__half><<>>( - // ln_res, vars, means, inp, scale, bias, hidden_dim); - // } else if (hidden_dim * 8 >= 8192 && hidden_dim * 8 <= 8192 * 2) { - // hidden_dim >>= 1; - // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - // dim3 grid_dim(batch_size); - // dim3 block_dim(nthread); - // ker_layer_norm_x2<<>>( - // ln_res, vars, means, inp, scale, bias, hidden_dim); - // } else if (hidden_dim * 8 > 8192 * 2 && hidden_dim * 8 <= 8192 * 4) { - // hidden_dim >>= 2; - // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - // dim3 grid_dim(batch_size); - // dim3 block_dim(nthread); - // ker_layer_norm_x4<<>>( - // ln_res, vars, means, inp, scale, bias, hidden_dim); - // } else { - // throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768"); - // } -} - -/** -@brief: ker_ln_bw_dgamma_dbetta -Layer norm backword kernel, compute the gradient of gamma and betta. -dbetta = sum(dout, dim=0) -dgamma = sum(xhat * dout, dim=0) -xhat = (input - mean) * rsqrt(var) or - (output - betta) / gamma - - -@thread -gridDim.x = hidden_size / 32 -blockDim.x = 32 -blockDim.y = 32 - -@param -gamma_grad: [hidden_size], gradient of gamma -betta_grad: [hidden_size], gradient of betta -out_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output -inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr - ln input if means is not nullptr -gamma: [hidden_size], gamma of ln, - used to compute xhat, maybe nullptr -betta: [hidden_size], betta of ln, - used to compute xhat, maybe nullptr -vars: [batch_size * seq_len], variance of ln forward, - used to compute xhat, maybe nullptr -means: [batch_size * seq_len], mean of ln forward, - used to compute xhat, maybe nullptr -(gamma && betta) ^ (vars && means) should be true -*/ -template -__global__ void -ker_ln_bw_dgamma_dbetta(T *gamma_grad, T *betta_grad, const T *out_grad, - const T *inp_or_out, const T *gamma, const T *betta, - const T *vars, const T *means, int rows, int width) { - __shared__ float betta_buffer[TILE_DIM][TILE_DIM]; - __shared__ float gamma_buffer[TILE_DIM][TILE_DIM]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int idx = blockDim.x * blockIdx.x + threadIdx.x; - int offset = threadIdx.y * width + idx; - int y_stride = width * TILE_DIM; - - // Loop across inp height - float dbetta = 0; - float dgamma = 0; - float dout, val; - if (idx < width) { - if (means == nullptr) { - float vbetta = (float)betta[idx]; - float vgamma = (float)gamma[idx]; - for (int r = threadIdx.y; r < rows; r += TILE_DIM) { - dout = (float)out_grad[offset]; - // inp_or_out is output - val = (float)inp_or_out[offset]; - dbetta += dout; 
- dgamma += ((val - vbetta) / add_eps(vgamma) * dout); - offset += y_stride; - } - } else { - for (int r = threadIdx.y; r < rows; r += TILE_DIM) { - dout = (float)out_grad[offset]; - // inp_or_out is input - val = (float)inp_or_out[offset]; - dbetta += dout; - dgamma += ((val - (float)means[r]) * - rsqrtf((float)vars[r] + LN_EPSILON) * dout); - offset += y_stride; - } - } - } - - // Sum the shared buffer. - betta_buffer[threadIdx.x][threadIdx.y] = dbetta; - gamma_buffer[threadIdx.x][threadIdx.y] = dgamma; - __syncthreads(); - float s1 = betta_buffer[threadIdx.y][threadIdx.x]; - float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; - __syncthreads(); - - for (int i = 1; i < TILE_DIM; i <<= 1) { - s1 += g.shfl_down(s1, i); - s2 += g.shfl_down(s2, i); - } - - int pos = blockIdx.x * TILE_DIM + threadIdx.y; - if (threadIdx.x == 0 && idx < width) { - betta_grad[pos] = s1; - gamma_grad[pos] = s2; - } -} - -/** -@brief: ker_ln_bw_dinp -Layer norm backword kernel, compute the gradient of input. -dinp = (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / hidden_dim) - * rsqrt(var) -xhat = (input - mean) * rsqrt(var) if mean is not nullptr - (output - betta) / gamma if mean is nullptr -dxhat = dout * gamma - - -@thread -gridDim.x = batch_size * seq_len -blockDim.x = hidden_size - -@param -inp_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output -out_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output -residual_grad: [batch_size * seq_len, hidden_size], gradient of residual input, - usually appear in pre-layer-norm for transformer layer, maybe nullptr -inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr - ln input if means is not nullptr -gamma: [hidden_size], gamma of ln, - used to compute xhat and dxhat -betta: [hidden_size], betta of ln, - used to compute xhat, maybe nullptr -vars: [batch_size * seq_len], variance of ln forward, - used to compute xhat and dinp -means: [batch_size * seq_len], mean of ln forward, - used to compute xhat, maybe nullptr -*/ -template -__global__ void ker_ln_bw_dinp(T *inp_grad, const T *out_grad, - const T *residual_grad, const T *inp_or_out, - const T *gamma, const T *betta, const T *vars, - const T *means, int hidden_dim) { - int offset = blockIdx.x * hidden_dim + threadIdx.x; - float4 dxhat, xhat; - float var_rsqrt; - - if (threadIdx.x < hidden_dim) { - // step 0. dxhat = dout * gamma - dxhat = ((const float4 *)out_grad)[offset]; - float4 vgamma = ((const float4 *)gamma)[threadIdx.x]; - dxhat.x *= vgamma.x; - dxhat.y *= vgamma.y; - dxhat.z *= vgamma.z; - dxhat.w *= vgamma.w; - - /* - step 1. xhat = (output - betta) / gamma or - (input - mean) * rsqrtf(var) - */ - xhat = ((const float4 *)inp_or_out)[offset]; - var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); - if (means == nullptr) { - // inp_or_out is output, xhat = (output - betta) / gamma - float4 vbetta = ((const float4 *)betta)[threadIdx.x]; - xhat.x = (xhat.x - vbetta.x) / add_eps(vgamma.x); - xhat.y = (xhat.y - vbetta.y) / add_eps(vgamma.y); - xhat.z = (xhat.z - vbetta.z) / add_eps(vgamma.z); - xhat.w = (xhat.w - vbetta.w) / add_eps(vgamma.w); - } else { - // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) - float fmean = (float)means[blockIdx.x]; - xhat.x = (xhat.x - fmean) * var_rsqrt; - xhat.y = (xhat.y - fmean) * var_rsqrt; - xhat.z = (xhat.z - fmean) * var_rsqrt; - xhat.w = (xhat.w - fmean) * var_rsqrt; - } - } - - /* step2. 
block reduce sum for dxhat and dxhat*xhat */ - float reduce_val[2] = {0.f, 0.f}; - if (threadIdx.x < hidden_dim) { - reduce_val[0] = dxhat.x + dxhat.y + dxhat.z + dxhat.w; - reduce_val[1] = dxhat.x * xhat.x + dxhat.y * xhat.y + dxhat.z * xhat.z + - dxhat.w * xhat.w; - } - blockReduce(reduce_val); - __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; - if (threadIdx.x == 0) { - float mean_dim = hidden_dim * 4; - s_sum_dxhat = reduce_val[0] / mean_dim; - s_sum_dxhat_xhat = reduce_val[1] / mean_dim; - } - __syncthreads(); - - /* - step3. compute input gradient - (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) - */ - if (threadIdx.x >= hidden_dim) { - return; - } - dxhat.x = (dxhat.x - s_sum_dxhat - xhat.x * s_sum_dxhat_xhat) * var_rsqrt; - dxhat.y = (dxhat.y - s_sum_dxhat - xhat.y * s_sum_dxhat_xhat) * var_rsqrt; - dxhat.z = (dxhat.z - s_sum_dxhat - xhat.z * s_sum_dxhat_xhat) * var_rsqrt; - dxhat.w = (dxhat.w - s_sum_dxhat - xhat.w * s_sum_dxhat_xhat) * var_rsqrt; - if (residual_grad) { - // Add the residual grad, - // usually in pre-layer-norm for transformer layer - float4 dresidual = ((const float4 *)residual_grad)[offset]; - dxhat.x += dresidual.x; - dxhat.y += dresidual.y; - dxhat.z += dresidual.z; - dxhat.w += dresidual.w; - } - ((float4 *)inp_grad)[offset] = dxhat; -} - -template <> -__global__ void ker_ln_bw_dinp<__half>(__half *inp_grad, const __half *out_grad, - const __half *residual_grad, - const __half *inp_or_out, - const __half *gamma, const __half *betta, - const __half *vars, const __half *means, - int hidden_dim) { - int offset = blockIdx.x * hidden_dim + threadIdx.x; - - float2 dxhat[4], xhat[4]; - float var_rsqrt; - float4 vtmp; - __half2 *tmp_h2; - float reduce_val[2] = {0.f, 0.f}; - - if (threadIdx.x < hidden_dim) { - // step 0. dxhat = dout * gamma - vtmp = ((const float4 *)out_grad)[offset]; - tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); - float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x]; - __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vdout = __half22float2(tmp_h2[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - dxhat[i].x = vdout.x * vgamma.x; - dxhat[i].y = vdout.y * vgamma.y; - reduce_val[0] += dxhat[i].x + dxhat[i].y; - } - - /* - step 1. xhat = (output - betta) / gamma or - (input - mean) * rsqrtf(var) - */ - vtmp = ((const float4 *)inp_or_out)[offset]; - var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); - if (means == nullptr) { - // inp_or_out is output, xhat = (output - betta) / gamma - float4 vbetta = ((const float4 *)betta)[threadIdx.x]; - __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vout = __half22float2(tmp_h2[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vbetta = __half22float2(betta_h2[i]); - xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); - xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - } - } else { - // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) - float fmean = (float)means[blockIdx.x]; -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vinp = __half22float2(tmp_h2[i]); - xhat[i].x = (vinp.x - fmean) * var_rsqrt; - xhat[i].y = (vinp.y - fmean) * var_rsqrt; - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - } - } - } - - /* step2. 
block reduce sum for dxhat and dxhat*xhat */ - blockReduce(reduce_val); - __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; - if (threadIdx.x == 0) { - float mean_dim = hidden_dim * 8; - s_sum_dxhat = reduce_val[0] / mean_dim; - s_sum_dxhat_xhat = reduce_val[1] / mean_dim; - } - __syncthreads(); - - /* - step3. compute input gradient - (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) - */ - if (threadIdx.x >= hidden_dim) { - return; - } - if (residual_grad) { - // Add the residual grad, - // usually in pre-layer-norm for transformer layer - float4 dresidual = ((const float4 *)residual_grad)[offset]; - __half *hdres = reinterpret_cast<__half *>(&dresidual); -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i])); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i + 1])); - } - } else { -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - } - } - ((float4 *)inp_grad)[offset] = vtmp; -} - -__global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad, - const __half *residual_grad, - const __half *inp_or_out, const __half *gamma, - const __half *betta, const __half *vars, - const __half *means, int hidden_dim) { - int offset = blockIdx.x * hidden_dim * 2 + threadIdx.x * 2; - - float2 dxhat[4], xhat[4]; - float2 dxhat_1[4], xhat_1[4]; - float var_rsqrt; - float4 vtmp, vtmp_1; - __half2 *tmp_h2; - __half2 *tmp_h2_1; - float reduce_val[2] = {0.f, 0.f}; - - if (threadIdx.x < hidden_dim) { - // step 0. dxhat = dout * gamma - vtmp = ((const float4 *)out_grad)[offset]; - vtmp_1 = ((const float4 *)out_grad)[offset + 1]; - tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); - tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1); - float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 2]; - float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 2 + 1]; - __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); - __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vdout = __half22float2(tmp_h2[i]); - float2 vdout_1 = __half22float2(tmp_h2_1[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vgamma_1 = __half22float2(gamma_h2_1[i]); - dxhat[i].x = vdout.x * vgamma.x; - dxhat[i].y = vdout.y * vgamma.y; - dxhat_1[i].x = vdout_1.x * vgamma_1.x; - dxhat_1[i].y = vdout_1.y * vgamma_1.y; - reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y; - } - - /* - step 1. 
xhat = (output - betta) / gamma or - (input - mean) * rsqrtf(var) - */ - vtmp = ((const float4 *)inp_or_out)[offset]; - vtmp_1 = ((const float4 *)inp_or_out)[offset + 1]; - var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); - if (means == nullptr) { - // inp_or_out is output, xhat = (output - betta) / gamma - float4 vbetta = ((const float4 *)betta)[2 * threadIdx.x]; - float4 vbetta_1 = ((const float4 *)betta)[2 * threadIdx.x + 1]; - __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); - __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vout = __half22float2(tmp_h2[i]); - float2 vout_1 = __half22float2(tmp_h2_1[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vgamma_1 = __half22float2(gamma_h2_1[i]); - float2 vbetta = __half22float2(betta_h2[i]); - float2 vbetta_1 = __half22float2(betta_h2_1[i]); - xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); - xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x); - xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); - xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y); - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += - xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - } - } else { - // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) - float fmean = (float)means[blockIdx.x]; -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vinp = __half22float2(tmp_h2[i]); - float2 vinp_1 = __half22float2(tmp_h2_1[i]); - xhat[i].x = (vinp.x - fmean) * var_rsqrt; - xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt; - xhat[i].y = (vinp.y - fmean) * var_rsqrt; - xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt; - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += - xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - } - } - } - - /* step2. block reduce sum for dxhat and dxhat*xhat */ - blockReduce(reduce_val); - __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; - if (threadIdx.x == 0) { - float mean_dim = hidden_dim * 8 * 2; - s_sum_dxhat = reduce_val[0] / mean_dim; - s_sum_dxhat_xhat = reduce_val[1] / mean_dim; - } - __syncthreads(); - - /* - step3. 
compute input gradient - (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) - */ - if (threadIdx.x >= hidden_dim) { - return; - } - if (residual_grad) { - // Add the residual grad, - // usually in pre-layer-norm for transformer layer - float4 dresidual = ((const float4 *)residual_grad)[offset]; - float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1]; - __half *hdres = reinterpret_cast<__half *>(&dresidual); - __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i])); - tmp_h2_1[i].x = __float2half( - (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i])); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i + 1])); - tmp_h2_1[i].y = __float2half( - (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i + 1])); - } - } else { -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_1[i].x = __float2half( - (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_1[i].y = __float2half( - (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - } - } - ((float4 *)inp_grad)[offset] = vtmp; - ((float4 *)inp_grad)[offset + 1] = vtmp_1; -} - -__global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad, - const __half *residual_grad, - const __half *inp_or_out, const __half *gamma, - const __half *betta, const __half *vars, - const __half *means, int hidden_dim) { - int offset = blockIdx.x * hidden_dim * 4 + threadIdx.x * 4; - - float2 dxhat[4], xhat[4]; - float2 dxhat_1[4], xhat_1[4]; - float2 dxhat_2[4], xhat_2[4]; - float2 dxhat_3[4], xhat_3[4]; - float var_rsqrt; - float4 vtmp, vtmp_1, vtmp_2, vtmp_3; - __half2 *tmp_h2; - __half2 *tmp_h2_1; - __half2 *tmp_h2_2; - __half2 *tmp_h2_3; - float reduce_val[2] = {0.f, 0.f}; - - if (threadIdx.x < hidden_dim) { - // step 0. 
dxhat = dout * gamma - vtmp = ((const float4 *)out_grad)[offset]; - vtmp_1 = ((const float4 *)out_grad)[offset + 1]; - vtmp_2 = ((const float4 *)out_grad)[offset + 2]; - vtmp_3 = ((const float4 *)out_grad)[offset + 3]; - tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); - tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1); - tmp_h2_2 = reinterpret_cast<__half2 *>(&vtmp_2); - tmp_h2_3 = reinterpret_cast<__half2 *>(&vtmp_3); - float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 4]; - float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 4 + 1]; - float4 gamma_f4_2 = ((const float4 *)gamma)[threadIdx.x * 4 + 2]; - float4 gamma_f4_3 = ((const float4 *)gamma)[threadIdx.x * 4 + 3]; - __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); - __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1); - __half2 *gamma_h2_2 = reinterpret_cast<__half2 *>(&gamma_f4_2); - __half2 *gamma_h2_3 = reinterpret_cast<__half2 *>(&gamma_f4_3); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vdout = __half22float2(tmp_h2[i]); - float2 vdout_1 = __half22float2(tmp_h2_1[i]); - float2 vdout_2 = __half22float2(tmp_h2_2[i]); - float2 vdout_3 = __half22float2(tmp_h2_3[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vgamma_1 = __half22float2(gamma_h2_1[i]); - float2 vgamma_2 = __half22float2(gamma_h2_2[i]); - float2 vgamma_3 = __half22float2(gamma_h2_3[i]); - dxhat[i].x = vdout.x * vgamma.x; - dxhat[i].y = vdout.y * vgamma.y; - dxhat_1[i].x = vdout_1.x * vgamma_1.x; - dxhat_1[i].y = vdout_1.y * vgamma_1.y; - dxhat_2[i].x = vdout_2.x * vgamma_2.x; - dxhat_2[i].y = vdout_2.y * vgamma_2.y; - dxhat_3[i].x = vdout_3.x * vgamma_3.x; - dxhat_3[i].y = vdout_3.y * vgamma_3.y; - reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y + - dxhat_2[i].x + dxhat_2[i].y + dxhat_3[i].x + - dxhat_3[i].y; - } - - /* - step 1. 
xhat = (output - betta) / gamma or - (input - mean) * rsqrtf(var) - */ - vtmp = ((const float4 *)inp_or_out)[offset]; - vtmp_1 = ((const float4 *)inp_or_out)[offset + 1]; - vtmp_2 = ((const float4 *)inp_or_out)[offset + 2]; - vtmp_3 = ((const float4 *)inp_or_out)[offset + 3]; - var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); - if (means == nullptr) { - // inp_or_out is output, xhat = (output - betta) / gamma - float4 vbetta = ((const float4 *)betta)[4 * threadIdx.x]; - float4 vbetta_1 = ((const float4 *)betta)[4 * threadIdx.x + 1]; - float4 vbetta_2 = ((const float4 *)betta)[4 * threadIdx.x + 2]; - float4 vbetta_3 = ((const float4 *)betta)[4 * threadIdx.x + 3]; - __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); - __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1); - __half2 *betta_h2_2 = reinterpret_cast<__half2 *>(&vbetta_2); - __half2 *betta_h2_3 = reinterpret_cast<__half2 *>(&vbetta_3); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vout = __half22float2(tmp_h2[i]); - float2 vout_1 = __half22float2(tmp_h2_1[i]); - float2 vout_2 = __half22float2(tmp_h2_2[i]); - float2 vout_3 = __half22float2(tmp_h2_3[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vgamma_1 = __half22float2(gamma_h2_1[i]); - float2 vgamma_2 = __half22float2(gamma_h2_2[i]); - float2 vgamma_3 = __half22float2(gamma_h2_3[i]); - float2 vbetta = __half22float2(betta_h2[i]); - float2 vbetta_1 = __half22float2(betta_h2_1[i]); - float2 vbetta_2 = __half22float2(betta_h2_2[i]); - float2 vbetta_3 = __half22float2(betta_h2_3[i]); - xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); - xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x); - xhat_2[i].x = (vout_2.x - vbetta_2.x) / add_eps(vgamma_2.x); - xhat_3[i].x = (vout_3.x - vbetta_3.x) / add_eps(vgamma_3.x); - xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); - xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y); - xhat_2[i].y = (vout_2.y - vbetta_2.y) / add_eps(vgamma_2.y); - xhat_3[i].y = (vout_3.y - vbetta_3.y) / add_eps(vgamma_3.y); - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += - xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - reduce_val[1] += - xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; - reduce_val[1] += - xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; - } - } else { - // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) - float fmean = (float)means[blockIdx.x]; -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vinp = __half22float2(tmp_h2[i]); - float2 vinp_1 = __half22float2(tmp_h2_1[i]); - float2 vinp_2 = __half22float2(tmp_h2_2[i]); - float2 vinp_3 = __half22float2(tmp_h2_3[i]); - xhat[i].x = (vinp.x - fmean) * var_rsqrt; - xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt; - xhat_2[i].x = (vinp_2.x - fmean) * var_rsqrt; - xhat_3[i].x = (vinp_3.x - fmean) * var_rsqrt; - xhat[i].y = (vinp.y - fmean) * var_rsqrt; - xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt; - xhat_2[i].y = (vinp_2.y - fmean) * var_rsqrt; - xhat_3[i].y = (vinp_3.y - fmean) * var_rsqrt; - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += - xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - reduce_val[1] += - xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; - reduce_val[1] += - xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; - } - } - } - - /* step2. 
block reduce sum for dxhat and dxhat*xhat */ - blockReduce(reduce_val); - __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; - if (threadIdx.x == 0) { - float mean_dim = hidden_dim * 8 * 4; - s_sum_dxhat = reduce_val[0] / mean_dim; - s_sum_dxhat_xhat = reduce_val[1] / mean_dim; - } - __syncthreads(); - - /* - step3. compute input gradient - (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) - */ - if (threadIdx.x >= hidden_dim) { - return; - } - if (residual_grad) { - // Add the residual grad, - // usually in pre-layer-norm for transformer layer - float4 dresidual = ((const float4 *)residual_grad)[offset]; - float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1]; - float4 dresidual_2 = ((const float4 *)residual_grad)[offset + 2]; - float4 dresidual_3 = ((const float4 *)residual_grad)[offset + 3]; - __half *hdres = reinterpret_cast<__half *>(&dresidual); - __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); - __half *hdres_2 = reinterpret_cast<__half *>(&dresidual_2); - __half *hdres_3 = reinterpret_cast<__half *>(&dresidual_3); -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i])); - tmp_h2_1[i].x = __float2half( - (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i])); - tmp_h2_2[i].x = __float2half( - (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_2[2 * i])); - tmp_h2_3[i].x = __float2half( - (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_3[2 * i])); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i + 1])); - tmp_h2_1[i].y = __float2half( - (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i + 1])); - tmp_h2_2[i].y = __float2half( - (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i + 1])); - tmp_h2_3[i].y = __float2half( - (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i + 1])); - } - } else { -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_1[i].x = __float2half( - (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_2[i].x = __float2half( - (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_3[i].x = __float2half( - (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_1[i].y = __float2half( - (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_2[i].y = __float2half( - (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_3[i].y = __float2half( - (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - } - } - ((float4 *)inp_grad)[offset] = vtmp; - ((float4 *)inp_grad)[offset + 1] = vtmp_1; - ((float4 *)inp_grad)[offset + 2] = vtmp_2; - ((float4 *)inp_grad)[offset + 3] = vtmp_3; -} - -/** -Layer norm backword, - compute the gradient of gamma, betta and input. 
-dbetta = sum(dout, dim=0) -xhat = (input - mean) * rsqrt(var) if mean is not nullptr - (output - betta) / gamma if mean is nullptr -dgamma = sum(xhat * dout, dim=0) -dxhat = dout * gamma -dinp = (dxhat - (sum(dxhat, 1) + xhat * sum(dxhat * xhat, 1)) / hidden_dim) - * rsqrt(var) - -residual_grad, means, betta can be nullptr. -residual_grad will be added to dinp if it is not nullptr - which is useful in transformer layer when pre-ln -means and betta are only used to compute xhat, - (means == nullptr) ^ (betta == nullptr) should be true -*/ -template <> -void launch_ln_bw(float *gamma_grad, float *betta_grad, float *inp_grad, - const float *out_grad, const float *residual_grad, - const float *inp_or_out, const float *gamma, - const float *betta, const float *vars, - const float *means, int batch, int hidden_dim, - cudaStream_t stream[2]) { - // compute grad of gamma and betta - dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM); - dim3 block_dim(TILE_DIM, TILE_DIM); - ker_ln_bw_dgamma_dbetta<<>>( - gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means, - batch, hidden_dim); - - // compute grad of input - if (hidden_dim % 4 != 0 || hidden_dim > 4096) { - throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 4096"); - } - hidden_dim >>= 2; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - ker_ln_bw_dinp<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means, - hidden_dim); -} - -template <> -void launch_ln_bw<__half>(__half *gamma_grad, __half *betta_grad, - __half *inp_grad, const __half *out_grad, - const __half *residual_grad, const __half *inp_or_out, - const __half *gamma, const __half *betta, - const __half *vars, const __half *means, int batch, - int hidden_dim, cudaStream_t stream[2]) { - // compute grad of gamma and betta - dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM); - dim3 block_dim(TILE_DIM, TILE_DIM); - ker_ln_bw_dgamma_dbetta<__half><<>>( - gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means, - batch, hidden_dim); - - // compute grad of input - if (hidden_dim % 8 != 0) { - throw std::runtime_error("hidden_dim % 8 != 0"); - } - hidden_dim >>= 3; - - if (hidden_dim * 8 <= 8192) { - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - ker_ln_bw_dinp<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, - means, hidden_dim); - } else if (hidden_dim * 8 > 8192 && hidden_dim * 8 <= 8192 * 2) { - hidden_dim >>= 1; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - ker_ln_bw_dinp_x2<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, - means, hidden_dim); - } else if (hidden_dim * 8 > 2 * 8192 && hidden_dim * 8 <= 8192 * 4) { - hidden_dim >>= 2; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - ker_ln_bw_dinp_x4<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, - means, hidden_dim); - } else { - throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768"); - } -} +#include + +#include "block_reduce.h" +#include "kernels.h" + +namespace cg = cooperative_groups; +const float LN_EPSILON = 1e-8f; +#define TILE_DIM 32 + +template +__forceinline__ __device__ T add_eps(T x) { + return fabsf(x) > LN_EPSILON ? x : (x < 0 ? -LN_EPSILON : LN_EPSILON); +} + +/** +@brief: ker_layer_norm +Standard layer normalization. +It will not only output the layer norm result, + but also outputs variance. 
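+  Concretely, for each input row x with N float elements
+  (N = hidden_size * 4 in this kernel, or hidden_size * 8 in the __half one):
+    mean   = sum(x) / N
+    var    = sum(x * x) / N - mean * mean + LN_EPSILON
+    ln_res = (x - mean) * rsqrt(var) * scale + bias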
+ may also output means, depends on whether + the means argument is nullptr + +@thread +gridDim.x = batch_size * seq_len +blockDim.x = hidden_size + +@param +ln_res: [batch_size* seq_len, hidden_size], ln result. +vars: [batch_size* seq_len], variance per token +means: [batch_size* seq_len], means per token, can be nullput +inp: [batch_size * seq_len, hidden_size], ln input. +scale: [hidden_size], ln scale +bias: [hidden_size], ln bias +*/ +template +__global__ void ker_layer_norm(T *ln_res, T *vars, T *means, const T *inp, + const T *scale, const T *bias, int hidden_size) { + // step 0. compute local sum + float l_sum = 0; + float l_square_sum = 0; + const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size; + for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float4 val = inp_f4[idx]; + l_sum += val.x + val.y + val.z + val.w; + l_square_sum += + val.x * val.x + val.y * val.y + val.z * val.z + val.w * val.w; + } + + // step 1. compute reduce sum + float mean_dim = float(hidden_size) * 4.f; + float reduce_val[2] = {l_sum, l_square_sum}; + blockReduce(reduce_val); + __shared__ float s_mean, s_var; + if (threadIdx.x == 0) { + s_mean = reduce_val[0] / mean_dim; + if (means != nullptr) { + means[blockIdx.x] = s_mean; + } + s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; + vars[blockIdx.x] = s_var; + s_var = rsqrtf(s_var); + } + __syncthreads(); + + // step 2. layer norm result + float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size; + for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float4 vscale = __ldg((const float4 *)scale + idx); + float4 vbias = __ldg((const float4 *)bias + idx); + float4 val = inp_f4[idx]; + val.x = (val.x - s_mean) * s_var * vscale.x + vbias.x; + val.y = (val.y - s_mean) * s_var * vscale.y + vbias.y; + val.z = (val.z - s_mean) * s_var * vscale.z + vbias.z; + val.w = (val.w - s_mean) * s_var * vscale.w + vbias.w; + output_f4[idx] = val; + } +} + +template <> +__global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars, + __half *means, const __half *inp, + const __half *scale, const __half *bias, + int hidden_size) { + // step 0. compute local sum + float l_sum = 0; + float l_square_sum = 0; + const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size; + for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float4 val_f4 = inp_f4[idx]; + __half2 *val_h2 = (__half2 *)(&val_f4); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 val_f2 = __half22float2(val_h2[i]); + l_sum += val_f2.x + val_f2.y; + l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y; + } + } + + // step 1. compute reduce sum + float mean_dim = float(hidden_size) * 8.f; + float reduce_val[2] = {l_sum, l_square_sum}; + blockReduce(reduce_val); + __shared__ float s_mean, s_var; + if (threadIdx.x == 0) { + s_mean = reduce_val[0] / mean_dim; + if (means != nullptr) { + means[blockIdx.x] = s_mean; + } + s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; + vars[blockIdx.x] = s_var; + s_var = rsqrtf(s_var); + } + __syncthreads(); + + // step 2. 
layer norm result + float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size; + for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + // load scale, bias, input + float4 scale_f4 = __ldg((const float4 *)scale + idx); + __half2 *scale_h2 = (__half2 *)(&scale_f4); + float4 bias_f4 = __ldg((const float4 *)bias + idx); + __half2 *bias_h2 = (__half2 *)(&bias_f4); + float4 val_f4 = inp_f4[idx]; + __half2 *val_h2 = (__half2 *)(&val_f4); + +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 scale_f2 = __half22float2(scale_h2[i]); + float2 bias_f2 = __half22float2(bias_h2[i]); + float2 val_f2 = __half22float2(val_h2[i]); + val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; + val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; + val_h2[i] = __float22half2_rn(val_f2); + } + output_f4[idx] = val_f4; + } +} + +// __global__ void ker_layer_norm_x2(__half *ln_res, __half *vars, +// __half *means, const __half *inp, +// const __half *scale, const __half +// *bias, int hidden_size) { +// // step 0. compute local sum +// float l_sum = 0; +// float l_square_sum = 0; +// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * 2 * hidden_size; +// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * +// 2) { +// float4 val_f4 = inp_f4[idx]; +// float4 val_f4_1 = inp_f4[idx+1]; +// __half2 *val_h2 = (__half2 *)(&val_f4); +// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); +// #pragma unroll +// for (int i = 0; i < 4; i++) { +// float2 val_f2 = __half22float2(val_h2[i]); +// float2 val_f2_1 = __half22float2(val_h2_1[i]); +// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y; +// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y + val_f2_1.x +// * val_f2_1.x + val_f2_1.y * val_f2_1.y; +// } +// } + +// // step 1. compute reduce sum +// float mean_dim = float(hidden_size) * 8.f * 2; +// float reduce_val[2] = {l_sum, l_square_sum}; +// blockReduce(reduce_val); +// __shared__ float s_mean, s_var; +// if (threadIdx.x == 0) { +// s_mean = reduce_val[0] / mean_dim; +// if (means != nullptr) { +// means[blockIdx.x] = s_mean; +// } +// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; +// vars[blockIdx.x] = s_var; +// s_var = rsqrtf(s_var); +// } +// __syncthreads(); + +// // step 2. 
layer norm result +// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 2; +// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * +// 2) { +// // load scale, bias, input +// float4 scale_f4 = __ldg((const float4 *)scale + idx); +// __half2 *scale_h2 = (__half2 *)(&scale_f4); +// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1); +// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1); +// float4 bias_f4 = __ldg((const float4 *)bias + idx); +// __half2 *bias_h2 = (__half2 *)(&bias_f4); +// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1); +// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1); +// float4 val_f4 = inp_f4[idx]; +// __half2 *val_h2 = (__half2 *)(&val_f4); +// float4 val_f4_1 = inp_f4[idx+1]; +// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); + +// #pragma unroll +// for (int i = 0; i < 4; i++) { +// float2 scale_f2 = __half22float2(scale_h2[i]); +// float2 scale_f2_1 = __half22float2(scale_h2_1[i]); +// float2 bias_f2 = __half22float2(bias_h2[i]); +// float2 bias_f2_1 = __half22float2(bias_h2_1[i]); +// float2 val_f2 = __half22float2(val_h2[i]); +// float2 val_f2_1 = __half22float2(val_h2_1[i]); +// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; +// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; +// val_h2[i] = __float22half2_rn(val_f2); +// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + +// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y +// + bias_f2_1.y; val_h2_1[i] = __float22half2_rn(val_f2_1); +// } +// output_f4[idx] = val_f4; +// output_f4[idx+1] = val_f4_1; +// } +// } + +// __global__ void ker_layer_norm_x4(__half *ln_res, __half *vars, +// __half *means, const __half *inp, +// const __half *scale, const __half +// *bias, int hidden_size) { +// // step 0. compute local sum +// float l_sum = 0; +// float l_square_sum = 0; +// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size * 4; +// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * +// 4) { +// float4 val_f4 = inp_f4[idx]; +// float4 val_f4_1 = inp_f4[idx+1]; +// float4 val_f4_2 = inp_f4[idx+2]; +// float4 val_f4_3 = inp_f4[idx+3]; +// __half2 *val_h2 = (__half2 *)(&val_f4); +// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); +// __half2 *val_h2_2 = (__half2 *)(&val_f4_2); +// __half2 *val_h2_3 = (__half2 *)(&val_f4_3); +// #pragma unroll +// for (int i = 0; i < 4; i++) { +// float2 val_f2 = __half22float2(val_h2[i]); +// float2 val_f2_1 = __half22float2(val_h2_1[i]); +// float2 val_f2_2 = __half22float2(val_h2_2[i]); +// float2 val_f2_3 = __half22float2(val_h2_3[i]); +// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y + val_f2_2.x + +// val_f2_2.y + val_f2_3.x + val_f2_3.y; l_square_sum += val_f2.x * +// val_f2.x + val_f2.y * val_f2.y; l_square_sum += val_f2_1.x * val_f2_1.x +// + val_f2_1.y * val_f2_1.y; l_square_sum += val_f2_2.x * val_f2_2.x + +// val_f2_2.y * val_f2_2.y; l_square_sum += val_f2_3.x * val_f2_3.x + +// val_f2_3.y * val_f2_3.y; +// } +// } + +// // step 1. compute reduce sum +// float mean_dim = float(hidden_size) * 8.f * 4; +// float reduce_val[2] = {l_sum, l_square_sum}; +// blockReduce(reduce_val); +// __shared__ float s_mean, s_var; +// if (threadIdx.x == 0) { +// s_mean = reduce_val[0] / mean_dim; +// if (means != nullptr) { +// means[blockIdx.x] = s_mean; +// } +// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; +// vars[blockIdx.x] = s_var; +// s_var = rsqrtf(s_var); +// } +// __syncthreads(); + +// // step 2. 
layer norm result +// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 4; +// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * +// 4) { +// // load scale, bias, input +// float4 scale_f4 = __ldg((const float4 *)scale + idx); +// __half2 *scale_h2 = (__half2 *)(&scale_f4); +// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1); +// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1); +// float4 scale_f4_2 = __ldg((const float4 *)scale + idx + 2); +// __half2 *scale_h2_2 = (__half2 *)(&scale_f4_2); +// float4 scale_f4_3 = __ldg((const float4 *)scale + idx + 3); +// __half2 *scale_h2_3 = (__half2 *)(&scale_f4_3); +// float4 bias_f4 = __ldg((const float4 *)bias + idx); +// __half2 *bias_h2 = (__half2 *)(&bias_f4); +// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1); +// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1); +// float4 bias_f4_2 = __ldg((const float4 *)bias + idx + 2); +// __half2 *bias_h2_2 = (__half2 *)(&bias_f4_2); +// float4 bias_f4_3 = __ldg((const float4 *)bias + idx + 3); +// __half2 *bias_h2_3 = (__half2 *)(&bias_f4_3); +// float4 val_f4 = inp_f4[idx]; +// __half2 *val_h2 = (__half2 *)(&val_f4); +// float4 val_f4_1 = inp_f4[idx+1]; +// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); +// float4 val_f4_2 = inp_f4[idx+2]; +// __half2 *val_h2_2 = (__half2 *)(&val_f4_2); +// float4 val_f4_3 = inp_f4[idx+3]; +// __half2 *val_h2_3 = (__half2 *)(&val_f4_3); + +// #pragma unroll +// for (int i = 0; i < 4; i++) { +// float2 scale_f2 = __half22float2(scale_h2[i]); +// float2 scale_f2_1 = __half22float2(scale_h2_1[i]); +// float2 scale_f2_2 = __half22float2(scale_h2_2[i]); +// float2 scale_f2_3 = __half22float2(scale_h2_3[i]); +// float2 bias_f2 = __half22float2(bias_h2[i]); +// float2 bias_f2_1 = __half22float2(bias_h2_1[i]); +// float2 bias_f2_2 = __half22float2(bias_h2_2[i]); +// float2 bias_f2_3 = __half22float2(bias_h2_3[i]); +// float2 val_f2 = __half22float2(val_h2[i]); +// float2 val_f2_1 = __half22float2(val_h2_1[i]); +// float2 val_f2_2 = __half22float2(val_h2_2[i]); +// float2 val_f2_3 = __half22float2(val_h2_3[i]); +// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; +// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; +// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + +// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y +// + bias_f2_1.y; val_f2_2.x = (val_f2_2.x - s_mean) * s_var * +// scale_f2_2.x + bias_f2_2.x; val_f2_2.y = (val_f2_2.y - s_mean) * s_var +// * scale_f2_2.y + bias_f2_2.y; val_f2_3.x = (val_f2_3.x - s_mean) * +// s_var * scale_f2_3.x + bias_f2_3.x; val_f2_3.y = (val_f2_3.y - s_mean) +// * s_var * scale_f2_3.y + bias_f2_3.y; val_h2[i] = +// __float22half2_rn(val_f2); val_h2_1[i] = __float22half2_rn(val_f2_1); +// val_h2_2[i] = __float22half2_rn(val_f2_2); +// val_h2_3[i] = __float22half2_rn(val_f2_3); +// } +// output_f4[idx] = val_f4; +// output_f4[idx+1] = val_f4_1; +// output_f4[idx+2] = val_f4_2; +// output_f4[idx+3] = val_f4_3; +// } +// } + +template <> +void launch_layer_norm(float *ln_res, float *vars, float *means, + const float *inp, const float *scale, + const float *bias, int batch_size, int hidden_dim, + cudaStream_t stream) { + if (hidden_dim % 4 != 0) { + throw std::runtime_error("violate hidden_dim % 4 = 0"); + } + hidden_dim >>= 2; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + dim3 grid_dim(batch_size); + dim3 block_dim(nthread); + + ker_layer_norm<<>>( + ln_res, vars, means, inp, scale, bias, hidden_dim); 
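+  // Launch note: one block per input row, with the thread count rounded up
+  // to a warp multiple and capped at MAX_THREADS. hidden_dim is divided by 4
+  // above because every thread reads and writes one float4 per iteration,
+  // hence the hidden_dim % 4 == 0 requirement.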
+} + +template <> +void launch_layer_norm<__half>(__half *ln_res, __half *vars, __half *means, + const __half *inp, const __half *scale, + const __half *bias, int batch_size, + int hidden_dim, cudaStream_t stream) { + if (hidden_dim % 8 != 0) { + throw std::runtime_error("violate hidden_dim % 8 = 0"); + } + hidden_dim >>= 3; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + dim3 grid_dim(batch_size); + dim3 block_dim(nthread); + + ker_layer_norm<__half><<>>( + ln_res, vars, means, inp, scale, bias, hidden_dim); + // if (hidden_dim % 8 != 0) { + // throw std::runtime_error("violate hidden_dim % 8 = 0"); + // } + // hidden_dim >>= 3; + + // if (hidden_dim * 8 < 8192) { + // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + // dim3 grid_dim(batch_size); + // dim3 block_dim(nthread); + // ker_layer_norm<__half><<>>( + // ln_res, vars, means, inp, scale, bias, hidden_dim); + // } else if (hidden_dim * 8 >= 8192 && hidden_dim * 8 <= 8192 * 2) { + // hidden_dim >>= 1; + // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + // dim3 grid_dim(batch_size); + // dim3 block_dim(nthread); + // ker_layer_norm_x2<<>>( + // ln_res, vars, means, inp, scale, bias, hidden_dim); + // } else if (hidden_dim * 8 > 8192 * 2 && hidden_dim * 8 <= 8192 * 4) { + // hidden_dim >>= 2; + // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + // dim3 grid_dim(batch_size); + // dim3 block_dim(nthread); + // ker_layer_norm_x4<<>>( + // ln_res, vars, means, inp, scale, bias, hidden_dim); + // } else { + // throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768"); + // } +} + +/** +@brief: ker_ln_bw_dgamma_dbetta +Layer norm backword kernel, compute the gradient of gamma and betta. +dbetta = sum(dout, dim=0) +dgamma = sum(xhat * dout, dim=0) +xhat = (input - mean) * rsqrt(var) or + (output - betta) / gamma + + +@thread +gridDim.x = hidden_size / 32 +blockDim.x = 32 +blockDim.y = 32 + +@param +gamma_grad: [hidden_size], gradient of gamma +betta_grad: [hidden_size], gradient of betta +out_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output +inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr + ln input if means is not nullptr +gamma: [hidden_size], gamma of ln, + used to compute xhat, maybe nullptr +betta: [hidden_size], betta of ln, + used to compute xhat, maybe nullptr +vars: [batch_size * seq_len], variance of ln forward, + used to compute xhat, maybe nullptr +means: [batch_size * seq_len], mean of ln forward, + used to compute xhat, maybe nullptr +(gamma && betta) ^ (vars && means) should be true +*/ +template +__global__ void ker_ln_bw_dgamma_dbetta(T *gamma_grad, T *betta_grad, + const T *out_grad, const T *inp_or_out, + const T *gamma, const T *betta, + const T *vars, const T *means, int rows, + int width) { + __shared__ float betta_buffer[TILE_DIM][TILE_DIM]; + __shared__ float gamma_buffer[TILE_DIM][TILE_DIM]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int offset = threadIdx.y * width + idx; + int y_stride = width * TILE_DIM; + + // Loop across inp height + float dbetta = 0; + float dgamma = 0; + float dout, val; + if (idx < width) { + if (means == nullptr) { + float vbetta = (float)betta[idx]; + float vgamma = (float)gamma[idx]; + for (int r = threadIdx.y; r < rows; r += TILE_DIM) { + dout = (float)out_grad[offset]; + // inp_or_out is output + val = (float)inp_or_out[offset]; + dbetta += 
dout; + dgamma += ((val - vbetta) / add_eps(vgamma) * dout); + offset += y_stride; + } + } else { + for (int r = threadIdx.y; r < rows; r += TILE_DIM) { + dout = (float)out_grad[offset]; + // inp_or_out is input + val = (float)inp_or_out[offset]; + dbetta += dout; + dgamma += ((val - (float)means[r]) * + rsqrtf((float)vars[r] + LN_EPSILON) * dout); + offset += y_stride; + } + } + } + + // Sum the shared buffer. + betta_buffer[threadIdx.x][threadIdx.y] = dbetta; + gamma_buffer[threadIdx.x][threadIdx.y] = dgamma; + __syncthreads(); + float s1 = betta_buffer[threadIdx.y][threadIdx.x]; + float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; + __syncthreads(); + + for (int i = 1; i < TILE_DIM; i <<= 1) { + s1 += g.shfl_down(s1, i); + s2 += g.shfl_down(s2, i); + } + + int pos = blockIdx.x * TILE_DIM + threadIdx.y; + if (threadIdx.x == 0 && idx < width) { + betta_grad[pos] = s1; + gamma_grad[pos] = s2; + } +} + +/** +@brief: ker_ln_bw_dinp +Layer norm backword kernel, compute the gradient of input. +dinp = (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / hidden_dim) + * rsqrt(var) +xhat = (input - mean) * rsqrt(var) if mean is not nullptr + (output - betta) / gamma if mean is nullptr +dxhat = dout * gamma + + +@thread +gridDim.x = batch_size * seq_len +blockDim.x = hidden_size + +@param +inp_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output +out_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output +residual_grad: [batch_size * seq_len, hidden_size], gradient of residual input, + usually appear in pre-layer-norm for transformer layer, maybe nullptr +inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr + ln input if means is not nullptr +gamma: [hidden_size], gamma of ln, + used to compute xhat and dxhat +betta: [hidden_size], betta of ln, + used to compute xhat, maybe nullptr +vars: [batch_size * seq_len], variance of ln forward, + used to compute xhat and dinp +means: [batch_size * seq_len], mean of ln forward, + used to compute xhat, maybe nullptr +*/ +template +__global__ void ker_ln_bw_dinp(T *inp_grad, const T *out_grad, + const T *residual_grad, const T *inp_or_out, + const T *gamma, const T *betta, const T *vars, + const T *means, int hidden_dim) { + int offset = blockIdx.x * hidden_dim + threadIdx.x; + float4 dxhat, xhat; + float var_rsqrt; + + if (threadIdx.x < hidden_dim) { + // step 0. dxhat = dout * gamma + dxhat = ((const float4 *)out_grad)[offset]; + float4 vgamma = ((const float4 *)gamma)[threadIdx.x]; + dxhat.x *= vgamma.x; + dxhat.y *= vgamma.y; + dxhat.z *= vgamma.z; + dxhat.w *= vgamma.w; + + /* + step 1. xhat = (output - betta) / gamma or + (input - mean) * rsqrtf(var) + */ + xhat = ((const float4 *)inp_or_out)[offset]; + var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); + if (means == nullptr) { + // inp_or_out is output, xhat = (output - betta) / gamma + float4 vbetta = ((const float4 *)betta)[threadIdx.x]; + xhat.x = (xhat.x - vbetta.x) / add_eps(vgamma.x); + xhat.y = (xhat.y - vbetta.y) / add_eps(vgamma.y); + xhat.z = (xhat.z - vbetta.z) / add_eps(vgamma.z); + xhat.w = (xhat.w - vbetta.w) / add_eps(vgamma.w); + } else { + // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) + float fmean = (float)means[blockIdx.x]; + xhat.x = (xhat.x - fmean) * var_rsqrt; + xhat.y = (xhat.y - fmean) * var_rsqrt; + xhat.z = (xhat.z - fmean) * var_rsqrt; + xhat.w = (xhat.w - fmean) * var_rsqrt; + } + } + + /* step2. 
block reduce sum for dxhat and dxhat*xhat */ + float reduce_val[2] = {0.f, 0.f}; + if (threadIdx.x < hidden_dim) { + reduce_val[0] = dxhat.x + dxhat.y + dxhat.z + dxhat.w; + reduce_val[1] = dxhat.x * xhat.x + dxhat.y * xhat.y + dxhat.z * xhat.z + + dxhat.w * xhat.w; + } + blockReduce(reduce_val); + __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; + if (threadIdx.x == 0) { + float mean_dim = hidden_dim * 4; + s_sum_dxhat = reduce_val[0] / mean_dim; + s_sum_dxhat_xhat = reduce_val[1] / mean_dim; + } + __syncthreads(); + + /* + step3. compute input gradient + (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) + */ + if (threadIdx.x >= hidden_dim) { + return; + } + dxhat.x = (dxhat.x - s_sum_dxhat - xhat.x * s_sum_dxhat_xhat) * var_rsqrt; + dxhat.y = (dxhat.y - s_sum_dxhat - xhat.y * s_sum_dxhat_xhat) * var_rsqrt; + dxhat.z = (dxhat.z - s_sum_dxhat - xhat.z * s_sum_dxhat_xhat) * var_rsqrt; + dxhat.w = (dxhat.w - s_sum_dxhat - xhat.w * s_sum_dxhat_xhat) * var_rsqrt; + if (residual_grad) { + // Add the residual grad, + // usually in pre-layer-norm for transformer layer + float4 dresidual = ((const float4 *)residual_grad)[offset]; + dxhat.x += dresidual.x; + dxhat.y += dresidual.y; + dxhat.z += dresidual.z; + dxhat.w += dresidual.w; + } + ((float4 *)inp_grad)[offset] = dxhat; +} + +template <> +__global__ void ker_ln_bw_dinp<__half>(__half *inp_grad, const __half *out_grad, + const __half *residual_grad, + const __half *inp_or_out, + const __half *gamma, const __half *betta, + const __half *vars, const __half *means, + int hidden_dim) { + int offset = blockIdx.x * hidden_dim + threadIdx.x; + + float2 dxhat[4], xhat[4]; + float var_rsqrt; + float4 vtmp; + __half2 *tmp_h2; + float reduce_val[2] = {0.f, 0.f}; + + if (threadIdx.x < hidden_dim) { + // step 0. dxhat = dout * gamma + vtmp = ((const float4 *)out_grad)[offset]; + tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); + float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x]; + __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vdout = __half22float2(tmp_h2[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + dxhat[i].x = vdout.x * vgamma.x; + dxhat[i].y = vdout.y * vgamma.y; + reduce_val[0] += dxhat[i].x + dxhat[i].y; + } + + /* + step 1. xhat = (output - betta) / gamma or + (input - mean) * rsqrtf(var) + */ + vtmp = ((const float4 *)inp_or_out)[offset]; + var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); + if (means == nullptr) { + // inp_or_out is output, xhat = (output - betta) / gamma + float4 vbetta = ((const float4 *)betta)[threadIdx.x]; + __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vout = __half22float2(tmp_h2[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vbetta = __half22float2(betta_h2[i]); + xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); + xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + } + } else { + // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) + float fmean = (float)means[blockIdx.x]; +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vinp = __half22float2(tmp_h2[i]); + xhat[i].x = (vinp.x - fmean) * var_rsqrt; + xhat[i].y = (vinp.y - fmean) * var_rsqrt; + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + } + } + } + + /* step2. 
block reduce sum for dxhat and dxhat*xhat */ + blockReduce(reduce_val); + __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; + if (threadIdx.x == 0) { + float mean_dim = hidden_dim * 8; + s_sum_dxhat = reduce_val[0] / mean_dim; + s_sum_dxhat_xhat = reduce_val[1] / mean_dim; + } + __syncthreads(); + + /* + step3. compute input gradient + (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) + */ + if (threadIdx.x >= hidden_dim) { + return; + } + if (residual_grad) { + // Add the residual grad, + // usually in pre-layer-norm for transformer layer + float4 dresidual = ((const float4 *)residual_grad)[offset]; + __half *hdres = reinterpret_cast<__half *>(&dresidual); +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i])); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i + 1])); + } + } else { +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + } + } + ((float4 *)inp_grad)[offset] = vtmp; +} + +__global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad, + const __half *residual_grad, + const __half *inp_or_out, const __half *gamma, + const __half *betta, const __half *vars, + const __half *means, int hidden_dim) { + int offset = blockIdx.x * hidden_dim * 2 + threadIdx.x * 2; + + float2 dxhat[4], xhat[4]; + float2 dxhat_1[4], xhat_1[4]; + float var_rsqrt; + float4 vtmp, vtmp_1; + __half2 *tmp_h2; + __half2 *tmp_h2_1; + float reduce_val[2] = {0.f, 0.f}; + + if (threadIdx.x < hidden_dim) { + // step 0. dxhat = dout * gamma + vtmp = ((const float4 *)out_grad)[offset]; + vtmp_1 = ((const float4 *)out_grad)[offset + 1]; + tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); + tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1); + float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 2]; + float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 2 + 1]; + __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); + __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vdout = __half22float2(tmp_h2[i]); + float2 vdout_1 = __half22float2(tmp_h2_1[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vgamma_1 = __half22float2(gamma_h2_1[i]); + dxhat[i].x = vdout.x * vgamma.x; + dxhat[i].y = vdout.y * vgamma.y; + dxhat_1[i].x = vdout_1.x * vgamma_1.x; + dxhat_1[i].y = vdout_1.y * vgamma_1.y; + reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y; + } + + /* + step 1. 
xhat = (output - betta) / gamma or + (input - mean) * rsqrtf(var) + */ + vtmp = ((const float4 *)inp_or_out)[offset]; + vtmp_1 = ((const float4 *)inp_or_out)[offset + 1]; + var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); + if (means == nullptr) { + // inp_or_out is output, xhat = (output - betta) / gamma + float4 vbetta = ((const float4 *)betta)[2 * threadIdx.x]; + float4 vbetta_1 = ((const float4 *)betta)[2 * threadIdx.x + 1]; + __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); + __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vout = __half22float2(tmp_h2[i]); + float2 vout_1 = __half22float2(tmp_h2_1[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vgamma_1 = __half22float2(gamma_h2_1[i]); + float2 vbetta = __half22float2(betta_h2[i]); + float2 vbetta_1 = __half22float2(betta_h2_1[i]); + xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); + xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x); + xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); + xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y); + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + reduce_val[1] += + xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; + } + } else { + // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) + float fmean = (float)means[blockIdx.x]; +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vinp = __half22float2(tmp_h2[i]); + float2 vinp_1 = __half22float2(tmp_h2_1[i]); + xhat[i].x = (vinp.x - fmean) * var_rsqrt; + xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt; + xhat[i].y = (vinp.y - fmean) * var_rsqrt; + xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt; + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + reduce_val[1] += + xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; + } + } + } + + /* step2. block reduce sum for dxhat and dxhat*xhat */ + blockReduce(reduce_val); + __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; + if (threadIdx.x == 0) { + float mean_dim = hidden_dim * 8 * 2; + s_sum_dxhat = reduce_val[0] / mean_dim; + s_sum_dxhat_xhat = reduce_val[1] / mean_dim; + } + __syncthreads(); + + /* + step3. 
compute input gradient + (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) + */ + if (threadIdx.x >= hidden_dim) { + return; + } + if (residual_grad) { + // Add the residual grad, + // usually in pre-layer-norm for transformer layer + float4 dresidual = ((const float4 *)residual_grad)[offset]; + float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1]; + __half *hdres = reinterpret_cast<__half *>(&dresidual); + __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i])); + tmp_h2_1[i].x = __float2half( + (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i])); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i + 1])); + tmp_h2_1[i].y = __float2half( + (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i + 1])); + } + } else { +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_1[i].x = __float2half( + (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_1[i].y = __float2half( + (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + } + } + ((float4 *)inp_grad)[offset] = vtmp; + ((float4 *)inp_grad)[offset + 1] = vtmp_1; +} + +__global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad, + const __half *residual_grad, + const __half *inp_or_out, const __half *gamma, + const __half *betta, const __half *vars, + const __half *means, int hidden_dim) { + int offset = blockIdx.x * hidden_dim * 4 + threadIdx.x * 4; + + float2 dxhat[4], xhat[4]; + float2 dxhat_1[4], xhat_1[4]; + float2 dxhat_2[4], xhat_2[4]; + float2 dxhat_3[4], xhat_3[4]; + float var_rsqrt; + float4 vtmp, vtmp_1, vtmp_2, vtmp_3; + __half2 *tmp_h2; + __half2 *tmp_h2_1; + __half2 *tmp_h2_2; + __half2 *tmp_h2_3; + float reduce_val[2] = {0.f, 0.f}; + + if (threadIdx.x < hidden_dim) { + // step 0. 
dxhat = dout * gamma + vtmp = ((const float4 *)out_grad)[offset]; + vtmp_1 = ((const float4 *)out_grad)[offset + 1]; + vtmp_2 = ((const float4 *)out_grad)[offset + 2]; + vtmp_3 = ((const float4 *)out_grad)[offset + 3]; + tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); + tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1); + tmp_h2_2 = reinterpret_cast<__half2 *>(&vtmp_2); + tmp_h2_3 = reinterpret_cast<__half2 *>(&vtmp_3); + float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 4]; + float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 4 + 1]; + float4 gamma_f4_2 = ((const float4 *)gamma)[threadIdx.x * 4 + 2]; + float4 gamma_f4_3 = ((const float4 *)gamma)[threadIdx.x * 4 + 3]; + __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); + __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1); + __half2 *gamma_h2_2 = reinterpret_cast<__half2 *>(&gamma_f4_2); + __half2 *gamma_h2_3 = reinterpret_cast<__half2 *>(&gamma_f4_3); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vdout = __half22float2(tmp_h2[i]); + float2 vdout_1 = __half22float2(tmp_h2_1[i]); + float2 vdout_2 = __half22float2(tmp_h2_2[i]); + float2 vdout_3 = __half22float2(tmp_h2_3[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vgamma_1 = __half22float2(gamma_h2_1[i]); + float2 vgamma_2 = __half22float2(gamma_h2_2[i]); + float2 vgamma_3 = __half22float2(gamma_h2_3[i]); + dxhat[i].x = vdout.x * vgamma.x; + dxhat[i].y = vdout.y * vgamma.y; + dxhat_1[i].x = vdout_1.x * vgamma_1.x; + dxhat_1[i].y = vdout_1.y * vgamma_1.y; + dxhat_2[i].x = vdout_2.x * vgamma_2.x; + dxhat_2[i].y = vdout_2.y * vgamma_2.y; + dxhat_3[i].x = vdout_3.x * vgamma_3.x; + dxhat_3[i].y = vdout_3.y * vgamma_3.y; + reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y + + dxhat_2[i].x + dxhat_2[i].y + dxhat_3[i].x + + dxhat_3[i].y; + } + + /* + step 1. 
xhat = (output - betta) / gamma or + (input - mean) * rsqrtf(var) + */ + vtmp = ((const float4 *)inp_or_out)[offset]; + vtmp_1 = ((const float4 *)inp_or_out)[offset + 1]; + vtmp_2 = ((const float4 *)inp_or_out)[offset + 2]; + vtmp_3 = ((const float4 *)inp_or_out)[offset + 3]; + var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); + if (means == nullptr) { + // inp_or_out is output, xhat = (output - betta) / gamma + float4 vbetta = ((const float4 *)betta)[4 * threadIdx.x]; + float4 vbetta_1 = ((const float4 *)betta)[4 * threadIdx.x + 1]; + float4 vbetta_2 = ((const float4 *)betta)[4 * threadIdx.x + 2]; + float4 vbetta_3 = ((const float4 *)betta)[4 * threadIdx.x + 3]; + __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); + __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1); + __half2 *betta_h2_2 = reinterpret_cast<__half2 *>(&vbetta_2); + __half2 *betta_h2_3 = reinterpret_cast<__half2 *>(&vbetta_3); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vout = __half22float2(tmp_h2[i]); + float2 vout_1 = __half22float2(tmp_h2_1[i]); + float2 vout_2 = __half22float2(tmp_h2_2[i]); + float2 vout_3 = __half22float2(tmp_h2_3[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vgamma_1 = __half22float2(gamma_h2_1[i]); + float2 vgamma_2 = __half22float2(gamma_h2_2[i]); + float2 vgamma_3 = __half22float2(gamma_h2_3[i]); + float2 vbetta = __half22float2(betta_h2[i]); + float2 vbetta_1 = __half22float2(betta_h2_1[i]); + float2 vbetta_2 = __half22float2(betta_h2_2[i]); + float2 vbetta_3 = __half22float2(betta_h2_3[i]); + xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); + xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x); + xhat_2[i].x = (vout_2.x - vbetta_2.x) / add_eps(vgamma_2.x); + xhat_3[i].x = (vout_3.x - vbetta_3.x) / add_eps(vgamma_3.x); + xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); + xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y); + xhat_2[i].y = (vout_2.y - vbetta_2.y) / add_eps(vgamma_2.y); + xhat_3[i].y = (vout_3.y - vbetta_3.y) / add_eps(vgamma_3.y); + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + reduce_val[1] += + xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; + reduce_val[1] += + xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; + reduce_val[1] += + xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; + } + } else { + // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) + float fmean = (float)means[blockIdx.x]; +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vinp = __half22float2(tmp_h2[i]); + float2 vinp_1 = __half22float2(tmp_h2_1[i]); + float2 vinp_2 = __half22float2(tmp_h2_2[i]); + float2 vinp_3 = __half22float2(tmp_h2_3[i]); + xhat[i].x = (vinp.x - fmean) * var_rsqrt; + xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt; + xhat_2[i].x = (vinp_2.x - fmean) * var_rsqrt; + xhat_3[i].x = (vinp_3.x - fmean) * var_rsqrt; + xhat[i].y = (vinp.y - fmean) * var_rsqrt; + xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt; + xhat_2[i].y = (vinp_2.y - fmean) * var_rsqrt; + xhat_3[i].y = (vinp_3.y - fmean) * var_rsqrt; + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + reduce_val[1] += + xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; + reduce_val[1] += + xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; + reduce_val[1] += + xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; + } + } + } + + /* step2. 
block reduce sum for dxhat and dxhat*xhat */ + blockReduce(reduce_val); + __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; + if (threadIdx.x == 0) { + float mean_dim = hidden_dim * 8 * 4; + s_sum_dxhat = reduce_val[0] / mean_dim; + s_sum_dxhat_xhat = reduce_val[1] / mean_dim; + } + __syncthreads(); + + /* + step3. compute input gradient + (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) + */ + if (threadIdx.x >= hidden_dim) { + return; + } + if (residual_grad) { + // Add the residual grad, + // usually in pre-layer-norm for transformer layer + float4 dresidual = ((const float4 *)residual_grad)[offset]; + float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1]; + float4 dresidual_2 = ((const float4 *)residual_grad)[offset + 2]; + float4 dresidual_3 = ((const float4 *)residual_grad)[offset + 3]; + __half *hdres = reinterpret_cast<__half *>(&dresidual); + __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); + __half *hdres_2 = reinterpret_cast<__half *>(&dresidual_2); + __half *hdres_3 = reinterpret_cast<__half *>(&dresidual_3); +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i])); + tmp_h2_1[i].x = __float2half( + (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i])); + tmp_h2_2[i].x = __float2half( + (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_2[2 * i])); + tmp_h2_3[i].x = __float2half( + (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_3[2 * i])); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i + 1])); + tmp_h2_1[i].y = __float2half( + (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i + 1])); + tmp_h2_2[i].y = __float2half( + (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i + 1])); + tmp_h2_3[i].y = __float2half( + (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i + 1])); + } + } else { +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_1[i].x = __float2half( + (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_2[i].x = __float2half( + (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_3[i].x = __float2half( + (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_1[i].y = __float2half( + (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_2[i].y = __float2half( + (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_3[i].y = __float2half( + (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + } + } + ((float4 *)inp_grad)[offset] = vtmp; + ((float4 *)inp_grad)[offset + 1] = vtmp_1; + ((float4 *)inp_grad)[offset + 2] = vtmp_2; + ((float4 *)inp_grad)[offset + 3] = vtmp_3; +} + +/** +Layer norm backword, + compute the gradient of gamma, betta and input. 
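+Two kernels do the work: ker_ln_bw_dgamma_dbetta reduces over the batch
+  dimension for dgamma/dbetta, and a per-row kernel computes dinp. For
+  __half inputs the per-row kernel is chosen by hidden size:
+  ker_ln_bw_dinp covers up to 8192 elements (8 per thread),
+  ker_ln_bw_dinp_x2 up to 16384 (16 per thread),
+  ker_ln_bw_dinp_x4 up to 32768 (32 per thread).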
+dbetta = sum(dout, dim=0) +xhat = (input - mean) * rsqrt(var) if mean is not nullptr + (output - betta) / gamma if mean is nullptr +dgamma = sum(xhat * dout, dim=0) +dxhat = dout * gamma +dinp = (dxhat - (sum(dxhat, 1) + xhat * sum(dxhat * xhat, 1)) / hidden_dim) + * rsqrt(var) + +residual_grad, means, betta can be nullptr. +residual_grad will be added to dinp if it is not nullptr + which is useful in transformer layer when pre-ln +means and betta are only used to compute xhat, + (means == nullptr) ^ (betta == nullptr) should be true +*/ +template <> +void launch_ln_bw(float *gamma_grad, float *betta_grad, float *inp_grad, + const float *out_grad, const float *residual_grad, + const float *inp_or_out, const float *gamma, + const float *betta, const float *vars, + const float *means, int batch, int hidden_dim, + cudaStream_t stream[2]) { + // compute grad of gamma and betta + dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + ker_ln_bw_dgamma_dbetta<<>>( + gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means, + batch, hidden_dim); + + // compute grad of input + if (hidden_dim % 4 != 0 || hidden_dim > 4096) { + throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 4096"); + } + hidden_dim >>= 2; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + ker_ln_bw_dinp<<>>( + inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means, + hidden_dim); +} + +template <> +void launch_ln_bw<__half>(__half *gamma_grad, __half *betta_grad, + __half *inp_grad, const __half *out_grad, + const __half *residual_grad, const __half *inp_or_out, + const __half *gamma, const __half *betta, + const __half *vars, const __half *means, int batch, + int hidden_dim, cudaStream_t stream[2]) { + // compute grad of gamma and betta + dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + ker_ln_bw_dgamma_dbetta<__half><<>>( + gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means, + batch, hidden_dim); + + // compute grad of input + if (hidden_dim % 8 != 0) { + throw std::runtime_error("hidden_dim % 8 != 0"); + } + hidden_dim >>= 3; + + if (hidden_dim * 8 <= 8192) { + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + ker_ln_bw_dinp<<>>( + inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, + means, hidden_dim); + } else if (hidden_dim * 8 > 8192 && hidden_dim * 8 <= 8192 * 2) { + hidden_dim >>= 1; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + ker_ln_bw_dinp_x2<<>>( + inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, + means, hidden_dim); + } else if (hidden_dim * 8 > 2 * 8192 && hidden_dim * 8 <= 8192 * 4) { + hidden_dim >>= 2; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + ker_ln_bw_dinp_x4<<>>( + inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, + means, hidden_dim); + } else { + throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768"); + } +} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu index 98af433fe397..3862a699d3c3 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu @@ -1,365 +1,365 @@ -#include -#include - -#include -#include - -#include "block_reduce.h" -#include "kernels.h" - -namespace cg = cooperative_groups; -const float EPSILON = 
1e-8f; - -/** -@brief: softmax_kernel -Softmax forward kernel for - enc-self-attn, dec-self-attn, encdec-attn - -@thread -gridDim.x = dynamic -gridDim.y = batch_size -gridDim.z = nhead -blockDim.x = from_len - -@param -inp: [batch_size, nhead, from_len, to_len], softmax input. -attn_mask: [batch_size, to_len], padding tokens are -inf, - non padding tokens are 0. - attn_mask!=nullptr for enc-self-attn and enc-dec-attn - attn_mask=nullptr and mask_future=ture for dec-self-attn training - attn_mask=nullptr and mask_future=false for dec-self-attn infer -*/ -template -__global__ void ker_attn_softmax(T *inp, const T *attn_mask, int from_len, - int to_len, bool mask_future) { - int batch_id = blockIdx.y; - int head_id = blockIdx.z; - const int nhead = gridDim.z; - const int token_per_reduce = 1; - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - T mval[ele_per_thread]; - if (attn_mask) { - attn_mask += batch_id * to_len; - BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG); - } - - inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len); - for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len; - token_id += gridDim.x * token_per_reduce) { - T inp_val[token_per_reduce][ele_per_thread]; - for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { - BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len, - REDUCE_FLOAT_INF_NEG); - } - - /* step 1. compute max */ - // thread local max - float val[token_per_reduce][ele_per_thread]; - float l_max[token_per_reduce]; - for (int i = 0; i < token_per_reduce; i++) { - l_max[i] = REDUCE_FLOAT_INF_NEG; - for (int j = 0; j < ele_per_thread; j++) { - if (attn_mask) { - val[i][j] = (float)inp_val[i][j] + (float)mval[j]; - } else { - if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) { - val[i][j] = REDUCE_FLOAT_INF_NEG; - } else { - val[i][j] = (float)inp_val[i][j]; - } - } - l_max[i] = fmaxf(l_max[i], val[i][j]); - } - } - // block reduce max - blockReduce(l_max); - // write shared - __shared__ float s_max[token_per_reduce]; - if (threadIdx.x == 0) { - for (int i = 0; i < token_per_reduce; i++) { - s_max[i] = l_max[i]; - } - } - __syncthreads(); - - /* step 2. compute sum */ - // thread local sum - float l_sum[token_per_reduce]; - for (int i = 0; i < token_per_reduce; i++) { - l_sum[i] = 0.f; - for (int j = 0; j < ele_per_thread; j++) { - val[i][j] = __expf(val[i][j] - s_max[i]); - l_sum[i] += val[i][j]; - } - } - // block reduce sum - blockReduce(l_sum); - // write shared - __shared__ float s_sum[token_per_reduce]; - if (threadIdx.x == 0) { - for (int i = 0; i < token_per_reduce; i++) { - s_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON); - } - } - __syncthreads(); - - /* step 3. 
compute final result */ - for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { - for (int j = 0; j < ele_per_thread; j++) { - inp_val[i][j] = (T)(val[i][j] * s_sum[i]); - } - BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], - to_len); - } - } // blockIdx.x -} - -template -__global__ void ker_attn_softmax_lt32(T *inp, const T *attn_mask, int from_len, - int to_len, bool mask_future) { - int batch_id = blockIdx.y; - int head_id = blockIdx.z; - const int nhead = gridDim.z; - const int token_per_reduce = 1; - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - T mval[ele_per_thread]; - if (attn_mask) { - attn_mask += batch_id * to_len; - BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG); - } - - inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len); - for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len; - token_id += gridDim.x * token_per_reduce) { - T inp_val[token_per_reduce][ele_per_thread]; - for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { - BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len, - REDUCE_FLOAT_INF_NEG); - } - - /* step 1. compute max */ - // thread local max - float val[token_per_reduce][ele_per_thread]; - float l_max[token_per_reduce]; - for (int i = 0; i < token_per_reduce; i++) { - l_max[i] = REDUCE_FLOAT_INF_NEG; - for (int j = 0; j < ele_per_thread; j++) { - if (attn_mask) { - val[i][j] = (float)inp_val[i][j] + (float)mval[j]; - } else { - if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) { - val[i][j] = REDUCE_FLOAT_INF_NEG; - } else { - val[i][j] = (float)inp_val[i][j]; - } - } - l_max[i] = fmaxf(l_max[i], val[i][j]); - } - } - // warp reduce max - warpReduce(l_max); - - /* step 2. compute sum */ - // thread local sum - float l_sum[token_per_reduce]; - for (int i = 0; i < token_per_reduce; i++) { - l_sum[i] = 0.f; - for (int j = 0; j < ele_per_thread; j++) { - val[i][j] = __expf(val[i][j] - l_max[i]); - l_sum[i] += val[i][j]; - } - } - // warp reduce sum - warpReduce(l_sum); - - /* step 3. 
compute final result */ - for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { - l_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON); - for (int j = 0; j < ele_per_thread; j++) { - inp_val[i][j] = (T)(val[i][j] * l_sum[i]); - } - BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], - to_len); - } - } // blockIdx.x -} - -/* - attn_mask!=nullptr for enc-self-attn and enc-dec-attn - attn_mask=nullptr and mask_future=ture for dec-self-attn training - attn_mask=nullptr and mask_future=false for dec-self-attn infer -*/ -template <> -void launch_attn_softmax(float *inp, const float *attn_mask, - int batch_size, int nhead, int from_len, - int to_len, bool mask_future, - cudaStream_t stream) { - dim3 grid_dim(1, batch_size, nhead); - if (to_len <= 32) { - ker_attn_softmax_lt32<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 64) { - ker_attn_softmax_lt32<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 128) { - grid_dim.x = 16; - ker_attn_softmax<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 256) { - grid_dim.x = 32; - ker_attn_softmax<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 512) { - grid_dim.x = 64; - ker_attn_softmax<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else { - throw std::runtime_error( - "Sequence length greater than 512 is currently not supported"); - } -} - -template <> -void launch_attn_softmax<__half>(__half *inp, const __half *attn_mask, - int batch_size, int nhead, int from_len, - int to_len, bool mask_future, - cudaStream_t stream) { - dim3 grid_dim(1, batch_size, nhead); - if (to_len <= 32) { - ker_attn_softmax_lt32<__half, 32, 1><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 64) { - ker_attn_softmax_lt32<__half, 32, 2><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 128) { - grid_dim.x = 8; - ker_attn_softmax<__half, 64, 2><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 256) { - grid_dim.x = 16; - ker_attn_softmax<__half, 128, 2><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 512) { - grid_dim.x = 32; - ker_attn_softmax<__half, 256, 2><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else { - throw std::runtime_error( - "Sequence length greater than 512 is currently not supported"); - } -} - -/** -@brief: ker_attn_softmax_bw -Softmax backward in self attention. - -@thread -gridDim.x = batch_size * nhead * seq_len / warps_per_block -blockDim.x = WARP_SIZE -blockDim.y = warps_per_block - -@param -grad: [batch_size, nhead, seq_len, seq_len], output grad. -output: [batch_size, nhead, seq_len, seq_len], output of softmax forward. 
-*/ -template -__global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) { - int batch_idx = blockIdx.x * blockDim.y + threadIdx.y; - int offset = batch_idx * softmax_length + threadIdx.x; - - grad += offset; - inp += offset; - - T grad_reg[ITERATIONS]; - T inp_reg[ITERATIONS]; - float sum = 0.0; - -#pragma unroll - for (int i = 0; i < ITERATIONS; ++i) { - int curr_idx = threadIdx.x + i * WARP_SIZE; - if (curr_idx < softmax_length) { - grad_reg[i] = grad[i * WARP_SIZE]; - inp_reg[i] = inp[i * WARP_SIZE]; - sum += (float)grad_reg[i] * (float)inp_reg[i]; - } - } - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i); - -#pragma unroll - for (int i = 0; i < ITERATIONS; ++i) { - int curr_idx = threadIdx.x + i * WARP_SIZE; - if (curr_idx < softmax_length) - grad[i * WARP_SIZE] = (T)((float)inp_reg[i] * ((float)grad_reg[i] - sum)); - } -} - -template -void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows, - int softmax_len, cudaStream_t stream) { - const int warps_per_block = 4; - // rows = batch_size * nhead * from_len - dim3 grid_dim(rows / warps_per_block); - dim3 block_dim(WARP_SIZE, warps_per_block); - - if (softmax_len <= 32) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 64) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 128) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 256) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 384) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 512) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 768) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 1024) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 2048) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else - throw std::runtime_error( - std::string( - "Special sequence length found in softmax backward, seq_len: ") + - std::to_string(softmax_len)); -} - -template void launch_attn_softmax_bw<__half>(__half *out_grad, - const __half *soft_inp, int rows, - int softmax_len, - cudaStream_t stream); -template void launch_attn_softmax_bw(float *out_grad, - const float *soft_inp, int rows, - int softmax_len, - cudaStream_t stream); +#include +#include + +#include +#include + +#include "block_reduce.h" +#include "kernels.h" + +namespace cg = cooperative_groups; +const float EPSILON = 1e-8f; + +/** +@brief: softmax_kernel +Softmax forward kernel for + enc-self-attn, dec-self-attn, encdec-attn + +@thread +gridDim.x = dynamic +gridDim.y = batch_size +gridDim.z = nhead +blockDim.x = from_len + +@param +inp: [batch_size, nhead, from_len, to_len], softmax input. +attn_mask: [batch_size, to_len], padding tokens are -inf, + non padding tokens are 0. 
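             (Clarifying note, inferred from the kernel body below: the mask
             values are added to the logits before the max/exp steps, so -inf
             positions end up with effectively zero probability; when
             attn_mask == nullptr and mask_future == true, entries whose
             column index exceeds the row index are set to -inf, i.e. causal
             masking for decoder self-attention.)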
+ attn_mask!=nullptr for enc-self-attn and enc-dec-attn + attn_mask=nullptr and mask_future=ture for dec-self-attn training + attn_mask=nullptr and mask_future=false for dec-self-attn infer +*/ +template +__global__ void ker_attn_softmax(T *inp, const T *attn_mask, int from_len, + int to_len, bool mask_future) { + int batch_id = blockIdx.y; + int head_id = blockIdx.z; + const int nhead = gridDim.z; + const int token_per_reduce = 1; + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + T mval[ele_per_thread]; + if (attn_mask) { + attn_mask += batch_id * to_len; + BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG); + } + + inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len); + for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len; + token_id += gridDim.x * token_per_reduce) { + T inp_val[token_per_reduce][ele_per_thread]; + for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { + BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len, + REDUCE_FLOAT_INF_NEG); + } + + /* step 1. compute max */ + // thread local max + float val[token_per_reduce][ele_per_thread]; + float l_max[token_per_reduce]; + for (int i = 0; i < token_per_reduce; i++) { + l_max[i] = REDUCE_FLOAT_INF_NEG; + for (int j = 0; j < ele_per_thread; j++) { + if (attn_mask) { + val[i][j] = (float)inp_val[i][j] + (float)mval[j]; + } else { + if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) { + val[i][j] = REDUCE_FLOAT_INF_NEG; + } else { + val[i][j] = (float)inp_val[i][j]; + } + } + l_max[i] = fmaxf(l_max[i], val[i][j]); + } + } + // block reduce max + blockReduce(l_max); + // write shared + __shared__ float s_max[token_per_reduce]; + if (threadIdx.x == 0) { + for (int i = 0; i < token_per_reduce; i++) { + s_max[i] = l_max[i]; + } + } + __syncthreads(); + + /* step 2. compute sum */ + // thread local sum + float l_sum[token_per_reduce]; + for (int i = 0; i < token_per_reduce; i++) { + l_sum[i] = 0.f; + for (int j = 0; j < ele_per_thread; j++) { + val[i][j] = __expf(val[i][j] - s_max[i]); + l_sum[i] += val[i][j]; + } + } + // block reduce sum + blockReduce(l_sum); + // write shared + __shared__ float s_sum[token_per_reduce]; + if (threadIdx.x == 0) { + for (int i = 0; i < token_per_reduce; i++) { + s_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON); + } + } + __syncthreads(); + + /* step 3. 
compute final result */ + for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { + for (int j = 0; j < ele_per_thread; j++) { + inp_val[i][j] = (T)(val[i][j] * s_sum[i]); + } + BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], + to_len); + } + } // blockIdx.x +} + +template +__global__ void ker_attn_softmax_lt32(T *inp, const T *attn_mask, int from_len, + int to_len, bool mask_future) { + int batch_id = blockIdx.y; + int head_id = blockIdx.z; + const int nhead = gridDim.z; + const int token_per_reduce = 1; + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + T mval[ele_per_thread]; + if (attn_mask) { + attn_mask += batch_id * to_len; + BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG); + } + + inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len); + for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len; + token_id += gridDim.x * token_per_reduce) { + T inp_val[token_per_reduce][ele_per_thread]; + for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { + BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len, + REDUCE_FLOAT_INF_NEG); + } + + /* step 1. compute max */ + // thread local max + float val[token_per_reduce][ele_per_thread]; + float l_max[token_per_reduce]; + for (int i = 0; i < token_per_reduce; i++) { + l_max[i] = REDUCE_FLOAT_INF_NEG; + for (int j = 0; j < ele_per_thread; j++) { + if (attn_mask) { + val[i][j] = (float)inp_val[i][j] + (float)mval[j]; + } else { + if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) { + val[i][j] = REDUCE_FLOAT_INF_NEG; + } else { + val[i][j] = (float)inp_val[i][j]; + } + } + l_max[i] = fmaxf(l_max[i], val[i][j]); + } + } + // warp reduce max + warpReduce(l_max); + + /* step 2. compute sum */ + // thread local sum + float l_sum[token_per_reduce]; + for (int i = 0; i < token_per_reduce; i++) { + l_sum[i] = 0.f; + for (int j = 0; j < ele_per_thread; j++) { + val[i][j] = __expf(val[i][j] - l_max[i]); + l_sum[i] += val[i][j]; + } + } + // warp reduce sum + warpReduce(l_sum); + + /* step 3. 
compute final result */ + for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { + l_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON); + for (int j = 0; j < ele_per_thread; j++) { + inp_val[i][j] = (T)(val[i][j] * l_sum[i]); + } + BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], + to_len); + } + } // blockIdx.x +} + +/* + attn_mask!=nullptr for enc-self-attn and enc-dec-attn + attn_mask=nullptr and mask_future=ture for dec-self-attn training + attn_mask=nullptr and mask_future=false for dec-self-attn infer +*/ +template <> +void launch_attn_softmax(float *inp, const float *attn_mask, + int batch_size, int nhead, int from_len, + int to_len, bool mask_future, + cudaStream_t stream) { + dim3 grid_dim(1, batch_size, nhead); + if (to_len <= 32) { + ker_attn_softmax_lt32<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 64) { + ker_attn_softmax_lt32<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 128) { + grid_dim.x = 16; + ker_attn_softmax<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 256) { + grid_dim.x = 32; + ker_attn_softmax<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 512) { + grid_dim.x = 64; + ker_attn_softmax<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else { + throw std::runtime_error( + "Sequence length greater than 512 is currently not supported"); + } +} + +template <> +void launch_attn_softmax<__half>(__half *inp, const __half *attn_mask, + int batch_size, int nhead, int from_len, + int to_len, bool mask_future, + cudaStream_t stream) { + dim3 grid_dim(1, batch_size, nhead); + if (to_len <= 32) { + ker_attn_softmax_lt32<__half, 32, 1><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 64) { + ker_attn_softmax_lt32<__half, 32, 2><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 128) { + grid_dim.x = 8; + ker_attn_softmax<__half, 64, 2><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 256) { + grid_dim.x = 16; + ker_attn_softmax<__half, 128, 2><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 512) { + grid_dim.x = 32; + ker_attn_softmax<__half, 256, 2><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else { + throw std::runtime_error( + "Sequence length greater than 512 is currently not supported"); + } +} + +/** +@brief: ker_attn_softmax_bw +Softmax backward in self attention. + +@thread +gridDim.x = batch_size * nhead * seq_len / warps_per_block +blockDim.x = WARP_SIZE +blockDim.y = warps_per_block + +@param +grad: [batch_size, nhead, seq_len, seq_len], output grad. +output: [batch_size, nhead, seq_len, seq_len], output of softmax forward. 
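Math note (inferred from the kernel below; this is the standard softmax
backward): given the forward output y and the upstream gradient g of one row,
    s      = sum_j g_j * y_j
    grad_j = y_j * (g_j - s)
Each warp handles one row of length softmax_length, so the launcher chooses
ITERATIONS such that ITERATIONS * WARP_SIZE >= softmax_length
(see launch_attn_softmax_bw).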
+*/ +template +__global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) { + int batch_idx = blockIdx.x * blockDim.y + threadIdx.y; + int offset = batch_idx * softmax_length + threadIdx.x; + + grad += offset; + inp += offset; + + T grad_reg[ITERATIONS]; + T inp_reg[ITERATIONS]; + float sum = 0.0; + +#pragma unroll + for (int i = 0; i < ITERATIONS; ++i) { + int curr_idx = threadIdx.x + i * WARP_SIZE; + if (curr_idx < softmax_length) { + grad_reg[i] = grad[i * WARP_SIZE]; + inp_reg[i] = inp[i * WARP_SIZE]; + sum += (float)grad_reg[i] * (float)inp_reg[i]; + } + } + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i); + +#pragma unroll + for (int i = 0; i < ITERATIONS; ++i) { + int curr_idx = threadIdx.x + i * WARP_SIZE; + if (curr_idx < softmax_length) + grad[i * WARP_SIZE] = (T)((float)inp_reg[i] * ((float)grad_reg[i] - sum)); + } +} + +template +void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows, + int softmax_len, cudaStream_t stream) { + const int warps_per_block = 4; + // rows = batch_size * nhead * from_len + dim3 grid_dim(rows / warps_per_block); + dim3 block_dim(WARP_SIZE, warps_per_block); + + if (softmax_len <= 32) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 64) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 128) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 256) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 384) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 512) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 768) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 1024) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 2048) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else + throw std::runtime_error( + std::string( + "Special sequence length found in softmax backward, seq_len: ") + + std::to_string(softmax_len)); +} + +template void launch_attn_softmax_bw<__half>(__half *out_grad, + const __half *soft_inp, int rows, + int softmax_len, + cudaStream_t stream); +template void launch_attn_softmax_bw(float *out_grad, + const float *soft_inp, int rows, + int softmax_len, + cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu index d03084b22e12..04de3c092ee0 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu @@ -1,312 +1,314 @@ -#include -#include -#include - -#include "kernels.h" - -using namespace cub; - -/** -@brief: transform_0213 -Split the attention heads and reshape input -during backward progress of encoder self-attention - -@thread -gridDim.x = batch_size -gridDim.y = seq_len -blockDim.x = min(hidden_dim, MAX_THREADS) - -@param -input: [batch_size, seq_len, hidden_dim] -output: [batch_size, nhead, seq_len, head_dim] -batch_size: the size of the current batch -seq_len: the sequence length of the current batch -hidden_dim: dim of the hidden tensor -nhead: number of attention heads -*/ - -template -__global__ void transform_0213(T *output, const T *input, int hidden_dim, - int 
head_dim); - -template <> -__global__ void transform_0213(float *output, const float *input, - int hidden_dim, int head_dim) { - int batch_id = blockIdx.x; - int token_id = blockIdx.y; - int seq_len = gridDim.y; - int nhead = hidden_dim / head_dim; - - // [b, s, h] - int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim); - // [b, nh, s, ad] - int trg_offset = - flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim); - - const float4 *input4 = reinterpret_cast(input); - float4 *res4 = reinterpret_cast(output); - float4 vinput4; - - for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - vinput4 = input4[src_offset + i]; - - int head_id = i / head_dim; - int dim_id = i % head_dim; - int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim); - res4[trg_offset + cur_trg_offset] = vinput4; - } -} - -template <> -__global__ void transform_0213<__half>(__half *output, const __half *input, - int hidden_dim, int head_dim) { - int batch_id = blockIdx.x; - int token_id = blockIdx.y; - int seq_len = gridDim.y; - int nhead = hidden_dim / head_dim; - - // [b, s, h] - int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim); - // [b, nh, s, ad] - int trg_offset = - flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim); - - const float4 *input4 = reinterpret_cast(input); - float4 *res4 = reinterpret_cast(output); - float4 vinput4; - - for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - vinput4 = input4[src_offset + i]; - - int head_id = i / head_dim; - int dim_id = i % head_dim; - int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim); - res4[trg_offset + cur_trg_offset] = vinput4; - } -} - -// [b, s, h] -> [b, nh, s, ad] -template <> -void launch_transform_0213(float *output, const float *input, - int batch_size, int seq_len, int hidden_dim, - int nhead, cudaStream_t stream) { - hidden_dim >>= 2; - int head_dim = hidden_dim / nhead; - - dim3 grid_dim(batch_size, seq_len); - dim3 block_dim(min(hidden_dim, MAX_THREADS)); - - transform_0213 - <<>>(output, input, hidden_dim, head_dim); -} - -template <> -void launch_transform_0213<__half>(__half *output, const __half *input, - int batch_size, int seq_len, int hidden_dim, - int nhead, cudaStream_t stream) { - hidden_dim >>= 3; - int head_dim = hidden_dim / nhead; - - dim3 grid_dim(batch_size, seq_len); - dim3 block_dim(min(hidden_dim, MAX_THREADS)); - - transform_0213<__half> - <<>>(output, input, hidden_dim, head_dim); -} - -/** -@brief: bias_add_transform_20314 -Add bias to input, transform from -[0, 1, 2, 3, 4] to [2, 0, 3, 1, 4] - -@thread -gridDim.x = dim_0 -gridDim.y = dim_1 -gridDim.z = dim_2 -blockDim.x = min(dim_3 * dim_4, MAX_THREADS) - -@param -input: [dim_0, dim_1, dim_2, dim_3, dim_4] -bias: [dim_2, dim_3, dim_4] -output: [dim_2, dim_0, dim_3, dim_1, dim_4] -*/ -template -__global__ void bias_add_transform_20314(T *output, const T *input, - const T *bias, int dim_3, int dim_4); - -template <> -__global__ void -bias_add_transform_20314(float *output, const float *input, - const float *bias, int dim_3, int dim_4) { - int id0 = blockIdx.x; - int id1 = blockIdx.y; - int id2 = blockIdx.z; - int dim_0 = gridDim.x; - int dim_1 = gridDim.y; - int dim_2 = gridDim.z; - int dim_34 = dim_3 * dim_4; - - int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34); - int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4); - int bias_offset = flat_2dim(id2, 0, dim_34); - - const float4 *qkv4 = reinterpret_cast(input); - const 
float4 *bias4 = reinterpret_cast(bias); - float4 *res4 = reinterpret_cast(output); - float4 vqkv4; - float4 vbias4; - float4 vres4; - - for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) { - vqkv4 = qkv4[src_offset + i]; - vbias4 = bias4[bias_offset + i]; - vres4.x = vqkv4.x + vbias4.x; - vres4.y = vqkv4.y + vbias4.y; - vres4.z = vqkv4.z + vbias4.z; - vres4.w = vqkv4.w + vbias4.w; - - int id3 = i / dim_4; - int id4 = i % dim_4; - int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4); - res4[trg_offset + cur_trg_offset] = vres4; - } -} - -template <> -__global__ void -bias_add_transform_20314<__half>(__half *output, const __half *input, - const __half *bias, int dim_3, int dim_4) { - int id0 = blockIdx.x; - int id1 = blockIdx.y; - int id2 = blockIdx.z; - int dim_0 = gridDim.x; - int dim_1 = gridDim.y; - int dim_2 = gridDim.z; - int dim_34 = dim_3 * dim_4; - - int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34); - int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4); - int bias_offset = flat_2dim(id2, 0, dim_34); - - const float4 *qkv4 = reinterpret_cast(input); - const float4 *bias4 = reinterpret_cast(bias); - float4 *res4 = reinterpret_cast(output); - float4 vqkv4; - float4 vbias4; - float4 vres4; - __half2 *h2_qkv = reinterpret_cast<__half2 *>(&vqkv4); - __half2 *h2_bias = reinterpret_cast<__half2 *>(&vbias4); - __half2 *h2_res = reinterpret_cast<__half2 *>(&vres4); - - for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) { - vqkv4 = qkv4[src_offset + i]; - vbias4 = bias4[bias_offset + i]; - h2_res[0] = __hadd2(h2_qkv[0], h2_bias[0]); - h2_res[1] = __hadd2(h2_qkv[1], h2_bias[1]); - h2_res[2] = __hadd2(h2_qkv[2], h2_bias[2]); - h2_res[3] = __hadd2(h2_qkv[3], h2_bias[3]); - - int id3 = i / dim_4; - int id4 = i % dim_4; - int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4); - res4[trg_offset + cur_trg_offset] = vres4; - } -} - -// [b, s, 3, h] -> [3, b, nh, s, ad] -template <> -void launch_bias_add_transform_20314(float *output, const float *input, - const float *bias, int dim_0, - int dim_1, int dim_2, int dim_3, - int dim_4, cudaStream_t stream) { - dim_4 >>= 2; - - dim3 grid_dim(dim_0, dim_1, dim_2); - dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS)); - - bias_add_transform_20314 - <<>>(output, input, bias, dim_3, dim_4); -} - -template <> -void launch_bias_add_transform_20314<__half>(__half *output, - const __half *input, - const __half *bias, int dim_0, - int dim_1, int dim_2, int dim_3, - int dim_4, cudaStream_t stream) { - dim_4 >>= 3; - - dim3 grid_dim(dim_0, dim_1, dim_2); - dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS)); - - bias_add_transform_20314<__half> - <<>>(output, input, bias, dim_3, dim_4); -} - -/** -@brief: transform4d_0213 -Reshape the input matrix to merge the heads - -@thread -gridDim.x = (num_all + max_block_thread - 1) / max_block_thread -blockDim.x = max_block_thread - -@param -input: [trans_count, batch_size, nhead, seq_len, head_dim] -output: [batch_size, seq_len, trans_count, nhead, head_dim] -batch_size: the size of the current batch -seq_len: the sequence length of the current batch -hidden_dim: dim of the hidden tensor -nhead: number of attention heads -trans_count: 1 or 3, the count of matrice need to be transformed -*/ -template -__global__ void transform4d_0213(T *output, const T *input, int batch_size, - int seq_len, int trans_count, int nhead, - int head_dim, int num_all) { - int offset = blockIdx.x * blockDim.x + threadIdx.x; - if (offset >= num_all) { - return; - } - int trans_id, 
batch_id, head_id, token_id, dim_id; - decompose_5dim(offset, batch_size, nhead, seq_len, head_dim, &trans_id, - &batch_id, &head_id, &token_id, &dim_id); - // [b, s, tc, nh, ad] - int trg_offset = flat_5dim(batch_id, token_id, trans_id, head_id, dim_id, - seq_len, trans_count, nhead, head_dim); - - const float4 *input4 = reinterpret_cast(input); - float4 *res4 = reinterpret_cast(output); - res4[trg_offset] = input4[offset]; -} - -// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad] -template <> -void launch_transform4d_0213(float *output, const float *input, - int batch_size, int seq_len, int hidden_dim, - int nhead, int trans_count, - cudaStream_t stream) { - hidden_dim >>= 2; - int head_dim = hidden_dim / nhead; - int num_all = batch_size * seq_len * trans_count * hidden_dim; - int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS; - - transform4d_0213<<>>( - output, input, batch_size, seq_len, trans_count, nhead, head_dim, - num_all); -} - -template <> -void launch_transform4d_0213<__half>(__half *output, const __half *input, - int batch_size, int seq_len, - int hidden_dim, int nhead, int trans_count, - cudaStream_t stream) { - hidden_dim >>= 3; - int head_dim = hidden_dim / nhead; - int num_all = batch_size * seq_len * trans_count * hidden_dim; - int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS; - - transform4d_0213<__half><<>>( - output, input, batch_size, seq_len, trans_count, nhead, head_dim, - num_all); -} +#include +#include +#include + +#include "kernels.h" + +using namespace cub; + +/** +@brief: transform_0213 +Split the attention heads and reshape input +during backward progress of encoder self-attention + +@thread +gridDim.x = batch_size +gridDim.y = seq_len +blockDim.x = min(hidden_dim, MAX_THREADS) + +@param +input: [batch_size, seq_len, hidden_dim] +output: [batch_size, nhead, seq_len, head_dim] +batch_size: the size of the current batch +seq_len: the sequence length of the current batch +hidden_dim: dim of the hidden tensor +nhead: number of attention heads +*/ + +template +__global__ void transform_0213(T *output, const T *input, int hidden_dim, + int head_dim); + +template <> +__global__ void transform_0213(float *output, const float *input, + int hidden_dim, int head_dim) { + int batch_id = blockIdx.x; + int token_id = blockIdx.y; + int seq_len = gridDim.y; + int nhead = hidden_dim / head_dim; + + // [b, s, h] + int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim); + // [b, nh, s, ad] + int trg_offset = + flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim); + + const float4 *input4 = reinterpret_cast(input); + float4 *res4 = reinterpret_cast(output); + float4 vinput4; + + for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { + vinput4 = input4[src_offset + i]; + + int head_id = i / head_dim; + int dim_id = i % head_dim; + int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim); + res4[trg_offset + cur_trg_offset] = vinput4; + } +} + +template <> +__global__ void transform_0213<__half>(__half *output, const __half *input, + int hidden_dim, int head_dim) { + int batch_id = blockIdx.x; + int token_id = blockIdx.y; + int seq_len = gridDim.y; + int nhead = hidden_dim / head_dim; + + // [b, s, h] + int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim); + // [b, nh, s, ad] + int trg_offset = + flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim); + + const float4 *input4 = reinterpret_cast(input); + float4 *res4 = reinterpret_cast(output); + float4 vinput4; + + for (std::size_t i = threadIdx.x; i 
< hidden_dim; i += blockDim.x) { + vinput4 = input4[src_offset + i]; + + int head_id = i / head_dim; + int dim_id = i % head_dim; + int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim); + res4[trg_offset + cur_trg_offset] = vinput4; + } +} + +// [b, s, h] -> [b, nh, s, ad] +template <> +void launch_transform_0213(float *output, const float *input, + int batch_size, int seq_len, int hidden_dim, + int nhead, cudaStream_t stream) { + hidden_dim >>= 2; + int head_dim = hidden_dim / nhead; + + dim3 grid_dim(batch_size, seq_len); + dim3 block_dim(min(hidden_dim, MAX_THREADS)); + + transform_0213 + <<>>(output, input, hidden_dim, head_dim); +} + +template <> +void launch_transform_0213<__half>(__half *output, const __half *input, + int batch_size, int seq_len, int hidden_dim, + int nhead, cudaStream_t stream) { + hidden_dim >>= 3; + int head_dim = hidden_dim / nhead; + + dim3 grid_dim(batch_size, seq_len); + dim3 block_dim(min(hidden_dim, MAX_THREADS)); + + transform_0213<__half> + <<>>(output, input, hidden_dim, head_dim); +} + +/** +@brief: bias_add_transform_20314 +Add bias to input, transform from +[0, 1, 2, 3, 4] to [2, 0, 3, 1, 4] + +@thread +gridDim.x = dim_0 +gridDim.y = dim_1 +gridDim.z = dim_2 +blockDim.x = min(dim_3 * dim_4, MAX_THREADS) + +@param +input: [dim_0, dim_1, dim_2, dim_3, dim_4] +bias: [dim_2, dim_3, dim_4] +output: [dim_2, dim_0, dim_3, dim_1, dim_4] +*/ +template +__global__ void bias_add_transform_20314(T *output, const T *input, + const T *bias, int dim_3, int dim_4); + +template <> +__global__ void bias_add_transform_20314(float *output, + const float *input, + const float *bias, int dim_3, + int dim_4) { + int id0 = blockIdx.x; + int id1 = blockIdx.y; + int id2 = blockIdx.z; + int dim_0 = gridDim.x; + int dim_1 = gridDim.y; + int dim_2 = gridDim.z; + int dim_34 = dim_3 * dim_4; + + int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34); + int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4); + int bias_offset = flat_2dim(id2, 0, dim_34); + + const float4 *qkv4 = reinterpret_cast(input); + const float4 *bias4 = reinterpret_cast(bias); + float4 *res4 = reinterpret_cast(output); + float4 vqkv4; + float4 vbias4; + float4 vres4; + + for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) { + vqkv4 = qkv4[src_offset + i]; + vbias4 = bias4[bias_offset + i]; + vres4.x = vqkv4.x + vbias4.x; + vres4.y = vqkv4.y + vbias4.y; + vres4.z = vqkv4.z + vbias4.z; + vres4.w = vqkv4.w + vbias4.w; + + int id3 = i / dim_4; + int id4 = i % dim_4; + int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4); + res4[trg_offset + cur_trg_offset] = vres4; + } +} + +template <> +__global__ void bias_add_transform_20314<__half>(__half *output, + const __half *input, + const __half *bias, int dim_3, + int dim_4) { + int id0 = blockIdx.x; + int id1 = blockIdx.y; + int id2 = blockIdx.z; + int dim_0 = gridDim.x; + int dim_1 = gridDim.y; + int dim_2 = gridDim.z; + int dim_34 = dim_3 * dim_4; + + int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34); + int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4); + int bias_offset = flat_2dim(id2, 0, dim_34); + + const float4 *qkv4 = reinterpret_cast(input); + const float4 *bias4 = reinterpret_cast(bias); + float4 *res4 = reinterpret_cast(output); + float4 vqkv4; + float4 vbias4; + float4 vres4; + __half2 *h2_qkv = reinterpret_cast<__half2 *>(&vqkv4); + __half2 *h2_bias = reinterpret_cast<__half2 *>(&vbias4); + __half2 *h2_res = reinterpret_cast<__half2 *>(&vres4); + + 
for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) { + vqkv4 = qkv4[src_offset + i]; + vbias4 = bias4[bias_offset + i]; + h2_res[0] = __hadd2(h2_qkv[0], h2_bias[0]); + h2_res[1] = __hadd2(h2_qkv[1], h2_bias[1]); + h2_res[2] = __hadd2(h2_qkv[2], h2_bias[2]); + h2_res[3] = __hadd2(h2_qkv[3], h2_bias[3]); + + int id3 = i / dim_4; + int id4 = i % dim_4; + int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4); + res4[trg_offset + cur_trg_offset] = vres4; + } +} + +// [b, s, 3, h] -> [3, b, nh, s, ad] +template <> +void launch_bias_add_transform_20314(float *output, const float *input, + const float *bias, int dim_0, + int dim_1, int dim_2, int dim_3, + int dim_4, cudaStream_t stream) { + dim_4 >>= 2; + + dim3 grid_dim(dim_0, dim_1, dim_2); + dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS)); + + bias_add_transform_20314 + <<>>(output, input, bias, dim_3, dim_4); +} + +template <> +void launch_bias_add_transform_20314<__half>(__half *output, + const __half *input, + const __half *bias, int dim_0, + int dim_1, int dim_2, int dim_3, + int dim_4, cudaStream_t stream) { + dim_4 >>= 3; + + dim3 grid_dim(dim_0, dim_1, dim_2); + dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS)); + + bias_add_transform_20314<__half> + <<>>(output, input, bias, dim_3, dim_4); +} + +/** +@brief: transform4d_0213 +Reshape the input matrix to merge the heads + +@thread +gridDim.x = (num_all + max_block_thread - 1) / max_block_thread +blockDim.x = max_block_thread + +@param +input: [trans_count, batch_size, nhead, seq_len, head_dim] +output: [batch_size, seq_len, trans_count, nhead, head_dim] +batch_size: the size of the current batch +seq_len: the sequence length of the current batch +hidden_dim: dim of the hidden tensor +nhead: number of attention heads +trans_count: 1 or 3, the count of matrice need to be transformed +*/ +template +__global__ void transform4d_0213(T *output, const T *input, int batch_size, + int seq_len, int trans_count, int nhead, + int head_dim, int num_all) { + int offset = blockIdx.x * blockDim.x + threadIdx.x; + if (offset >= num_all) { + return; + } + int trans_id, batch_id, head_id, token_id, dim_id; + decompose_5dim(offset, batch_size, nhead, seq_len, head_dim, &trans_id, + &batch_id, &head_id, &token_id, &dim_id); + // [b, s, tc, nh, ad] + int trg_offset = flat_5dim(batch_id, token_id, trans_id, head_id, dim_id, + seq_len, trans_count, nhead, head_dim); + + const float4 *input4 = reinterpret_cast(input); + float4 *res4 = reinterpret_cast(output); + res4[trg_offset] = input4[offset]; +} + +// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad] +template <> +void launch_transform4d_0213(float *output, const float *input, + int batch_size, int seq_len, int hidden_dim, + int nhead, int trans_count, + cudaStream_t stream) { + hidden_dim >>= 2; + int head_dim = hidden_dim / nhead; + int num_all = batch_size * seq_len * trans_count * hidden_dim; + int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS; + + transform4d_0213<<>>( + output, input, batch_size, seq_len, trans_count, nhead, head_dim, + num_all); +} + +template <> +void launch_transform4d_0213<__half>(__half *output, const __half *input, + int batch_size, int seq_len, + int hidden_dim, int nhead, int trans_count, + cudaStream_t stream) { + hidden_dim >>= 3; + int head_dim = hidden_dim / nhead; + int num_all = batch_size * seq_len * trans_count * hidden_dim; + int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS; + + transform4d_0213<__half><<>>( + output, input, batch_size, seq_len, trans_count, nhead, head_dim, + num_all); +} diff --git 
a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp index 4690277e63db..15a07bb0c7ac 100644 --- a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp +++ b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp @@ -138,4 +138,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)"); m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)"); -} \ No newline at end of file +} diff --git a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu index ad7066bbd9df..72b84d6ca40f 100644 --- a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu +++ b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu @@ -680,4 +680,4 @@ void cuda_layer_norm_gradient(at::Tensor* dout, at::Tensor* mean, grad_input->DATA_PTR(), gamma != NULL ? grad_gamma->DATA_PTR() : NULL, gamma != NULL ? grad_beta->DATA_PTR() : NULL);) -} \ No newline at end of file +} diff --git a/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp b/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp index 61c8a725052f..8c0b89eb06d1 100644 --- a/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp +++ b/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp @@ -1,97 +1,97 @@ -#include - -torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h, - torch::Tensor batch_tokens, - torch::Tensor mask, - torch::Tensor dest_idx); - -torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h, - torch::Tensor expert_grad, - torch::Tensor mask, - torch::Tensor dest_idx); - -torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h, - torch::Tensor expert_tokens, - torch::Tensor logits, torch::Tensor mask, - torch::Tensor dest_idx); - -std::vector moe_combine_cuda_backward( - int s, int e, int c, int h, torch::Tensor tokens_grad, - torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask, - torch::Tensor dest_idx); - -torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask); - -#define CHECK_CUDA(x) \ - TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) \ - TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_CONTIGUOUS(x) - -torch::Tensor moe_dispatch_forward(int s, int ec, int h, - torch::Tensor batch_tokens, - torch::Tensor mask, torch::Tensor dest_idx) { - CHECK_INPUT(batch_tokens); - CHECK_CUDA(mask); - CHECK_CUDA(dest_idx); - - return moe_dispatch_cuda_forward(s, ec, h, batch_tokens, mask, dest_idx); -} - -torch::Tensor moe_dispatch_backward(int s, int ec, int h, - torch::Tensor expert_grad, - torch::Tensor mask, - torch::Tensor dest_idx) { - CHECK_INPUT(expert_grad); - CHECK_CUDA(mask); - CHECK_CUDA(dest_idx); - - return moe_dispatch_cuda_backward(s, ec, h, expert_grad, mask, dest_idx); -} - -torch::Tensor moe_combine_forward(int s, int e, int c, int h, - torch::Tensor expert_tokens, - torch::Tensor logits, torch::Tensor mask, - torch::Tensor dest_idx) { - CHECK_INPUT(expert_tokens); - CHECK_INPUT(logits); - CHECK_CUDA(mask); - CHECK_CUDA(dest_idx); - - return moe_combine_cuda_forward(s, e, c, h, expert_tokens, logits, mask, - dest_idx); -} - -std::vector moe_combine_backward(int s, int e, int c, int h, - torch::Tensor tokens_grad, - torch::Tensor expert_tokens, - torch::Tensor logits, - torch::Tensor mask, - torch::Tensor dest_idx) { - CHECK_INPUT(tokens_grad); - CHECK_INPUT(logits); - CHECK_CUDA(mask); - 
CHECK_CUDA(dest_idx); - - return moe_combine_cuda_backward(s, e, c, h, tokens_grad, expert_tokens, - logits, mask, dest_idx); -} - -torch::Tensor moe_cumsum(torch::Tensor mask) { - CHECK_INPUT(mask); - return cumsum_sub_one_in_dim0(mask); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("cumsum_sub_one", &moe_cumsum, "Fast cumsum operation in dim0"); - m.def("dispatch_forward", &moe_dispatch_forward, - "Forward operation in MoE dispatch function"); - m.def("dispatch_backward", &moe_dispatch_backward, - "Backward operation in MoE dispatch function"); - m.def("combine_forward", &moe_combine_forward, - "Combine operation in MoE combine function"); - m.def("combine_backward", &moe_combine_backward, - "Combine operation in MoE combine function"); -} +#include + +torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h, + torch::Tensor batch_tokens, + torch::Tensor mask, + torch::Tensor dest_idx); + +torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h, + torch::Tensor expert_grad, + torch::Tensor mask, + torch::Tensor dest_idx); + +torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h, + torch::Tensor expert_tokens, + torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx); + +std::vector moe_combine_cuda_backward( + int s, int e, int c, int h, torch::Tensor tokens_grad, + torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx); + +torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask); + +#define CHECK_CUDA(x) \ + TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +torch::Tensor moe_dispatch_forward(int s, int ec, int h, + torch::Tensor batch_tokens, + torch::Tensor mask, torch::Tensor dest_idx) { + CHECK_INPUT(batch_tokens); + CHECK_CUDA(mask); + CHECK_CUDA(dest_idx); + + return moe_dispatch_cuda_forward(s, ec, h, batch_tokens, mask, dest_idx); +} + +torch::Tensor moe_dispatch_backward(int s, int ec, int h, + torch::Tensor expert_grad, + torch::Tensor mask, + torch::Tensor dest_idx) { + CHECK_INPUT(expert_grad); + CHECK_CUDA(mask); + CHECK_CUDA(dest_idx); + + return moe_dispatch_cuda_backward(s, ec, h, expert_grad, mask, dest_idx); +} + +torch::Tensor moe_combine_forward(int s, int e, int c, int h, + torch::Tensor expert_tokens, + torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx) { + CHECK_INPUT(expert_tokens); + CHECK_INPUT(logits); + CHECK_CUDA(mask); + CHECK_CUDA(dest_idx); + + return moe_combine_cuda_forward(s, e, c, h, expert_tokens, logits, mask, + dest_idx); +} + +std::vector moe_combine_backward(int s, int e, int c, int h, + torch::Tensor tokens_grad, + torch::Tensor expert_tokens, + torch::Tensor logits, + torch::Tensor mask, + torch::Tensor dest_idx) { + CHECK_INPUT(tokens_grad); + CHECK_INPUT(logits); + CHECK_CUDA(mask); + CHECK_CUDA(dest_idx); + + return moe_combine_cuda_backward(s, e, c, h, tokens_grad, expert_tokens, + logits, mask, dest_idx); +} + +torch::Tensor moe_cumsum(torch::Tensor mask) { + CHECK_INPUT(mask); + return cumsum_sub_one_in_dim0(mask); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("cumsum_sub_one", &moe_cumsum, "Fast cumsum operation in dim0"); + m.def("dispatch_forward", &moe_dispatch_forward, + "Forward operation in MoE dispatch function"); + m.def("dispatch_backward", &moe_dispatch_backward, + "Backward operation in MoE dispatch function"); + m.def("combine_forward", 
&moe_combine_forward, + "Combine operation in MoE combine function"); + m.def("combine_backward", &moe_combine_backward, + "Combine operation in MoE combine function"); +} diff --git a/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu b/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu index 0454377a2fad..66c1e6bd260e 100644 --- a/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu +++ b/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu @@ -1,659 +1,659 @@ -#include -#include -#include - -#include - -#include "block_reduce.h" - -template -__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) { - assert(cols % pack_size == 0); - const int bpack_size = block_size * pack_size; - - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - int tps = threadIdx.x * pack_size; - T pack[pack_size]; - for (int idx = 0; idx + tps < cols; idx += bpack_size) { - BlockLoad(ts_load).Load(src_row + idx, pack); - BlockStore(ts_store).Store(dst_row + idx, pack); - } -} - -template -__device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) { - assert(cols % pack_size == 0); - const int bpack_size = block_size * pack_size; - - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - int tps = threadIdx.x * pack_size; - T pack[pack_size]; - for (int idx = 0; idx + tps < cols; idx += bpack_size) { - BlockLoad(ts_load).Load(dst_row + idx, pack); - BlockStore(ts_store).Store(src_row + idx, pack); - } -} - -template -__device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2, - const int cols) { - assert(cols % pack_size == 0); - const int bpack_size = block_size * pack_size; - - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - int tps = threadIdx.x * pack_size; - T pack[pack_size]; - for (int idx = 0; idx + tps < cols; idx += bpack_size) { - BlockLoad(ts_load).Load(src_row + idx, pack); - BlockStore(ts_store).Store(dst_row1 + idx, pack); - BlockStore(ts_store).Store(dst_row2 + idx, pack); - } -} - -template -__device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2, - const int cols) { - assert(cols % pack_size == 0); - const int bpack_size = block_size * pack_size; - - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - int tps = threadIdx.x * pack_size; - T pack1[pack_size], pack2[pack_size]; - for (int idx = 0; idx + tps < cols; idx += bpack_size) { - BlockLoad(ts_load).Load(dst_row1 + idx, pack1); - BlockLoad(ts_load).Load(dst_row2 + idx, pack2); - -#pragma unroll - for (int i = 0; i < pack_size; ++i) { - pack1[i] += pack2[i]; - } - - BlockStore(ts_store).Store(src_row + idx, pack1); - } -} - -template -__device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight, - const int cols) { - assert(cols % pack_size == 0); - const int bpack_size = block_size * pack_size; - - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - int tps = threadIdx.x * 
pack_size; - T pack[pack_size]; - for (int idx = 0; idx + tps < cols; idx += bpack_size) { - BlockLoad(ts_load).Load(src_row + idx, pack); - -#pragma unroll - for (int i = 0; i < pack_size; ++i) { - pack[i] *= weight; - } - - BlockStore(ts_store).Store(dst_row + idx, pack); - } -} - -template -__device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row, - T *weight_grad, const T weight, const int cols) { - assert(cols % pack_size == 0); - const int bpack_size = block_size * pack_size; - - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - int tps = threadIdx.x * pack_size; - T grad[pack_size], tokens[pack_size]; - float thread_sum = 0; - for (int idx = 0; idx + tps < cols; idx += bpack_size) { - BlockLoad(ts_load).Load(dst_row + idx, grad); - BlockLoad(ts_load).Load(tks_row + idx, tokens); - -#pragma unroll - for (int i = 0; i < pack_size; ++i) { - thread_sum += grad[i] * tokens[i]; - grad[i] *= weight; - } - - BlockStore(ts_store).Store(src_row + idx, grad); - } - - blockReduce(&thread_sum); - - if (threadIdx.x == 0) *weight_grad = static_cast(thread_sum); -} - -template -__device__ void moe_cb_two_fwd(T *src_row1, T *src_row2, T *dst_row, - const T weight1, const T weight2, - const int cols) { - assert(cols % pack_size == 0); - const int bpack_size = block_size * pack_size; - - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - int tps = threadIdx.x * pack_size; - T pack1[pack_size], pack2[pack_size]; - for (int idx = 0; idx + tps < cols; idx += bpack_size) { - BlockLoad(ts_load).Load(src_row1 + idx, pack1); - BlockLoad(ts_load).Load(src_row2 + idx, pack2); - -#pragma unroll - for (int i = 0; i < pack_size; ++i) { - pack1[i] = pack1[i] * weight1 + pack2[i] * weight2; - } - - BlockStore(ts_store).Store(dst_row + idx, pack1); - } -} - -template -__device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row, - T *tks_row1, T *tks_row2, T *weight_grad1, - T *weight_grad2, const T weight1, - const T weight2, const int cols) { - assert(cols % pack_size == 0); - const int bpack_size = block_size * pack_size; - - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - int tps = threadIdx.x * pack_size; - T grad[pack_size], tokens1[pack_size], tokens2[pack_size], sgrad1[pack_size], - sgrad2[pack_size]; - float thread_sum[2] = {0, 0}; - for (int idx = 0; idx + tps < cols; idx += bpack_size) { - BlockLoad(ts_load).Load(dst_row + idx, grad); - BlockLoad(ts_load).Load(tks_row1 + idx, tokens1); - BlockLoad(ts_load).Load(tks_row2 + idx, tokens2); - -#pragma unroll - for (int i = 0; i < pack_size; ++i) { - thread_sum[0] += grad[i] * tokens1[i]; - thread_sum[1] += grad[i] * tokens2[i]; - sgrad1[i] = weight1 * grad[i]; - sgrad2[i] = weight2 * grad[i]; - } - - BlockStore(ts_store).Store(src_row1 + idx, sgrad1); - BlockStore(ts_store).Store(src_row2 + idx, sgrad2); - } - - blockReduce(thread_sum); - - if (threadIdx.x == 0) - *weight_grad1 = static_cast(thread_sum[0]); - else if (threadIdx.x == 1) - *weight_grad2 = static_cast(thread_sum[1]); -} - -// DISPATCH KERNELS -------------------------------- - -template -__device__ void moe_dpch_fwd_selector(T *src_row, T *dst_row1, T 
*dst_row2, - const int cols, const int indicator1, - const int indicator2) { - if (indicator1 != 0 && indicator2 != 0) - moe_dpch_two_fwd(src_row, dst_row1, dst_row2, - cols); - else if (indicator1 != 0) - moe_dpch_one_fwd(src_row, dst_row1, cols); - else if (indicator2 != 0) - moe_dpch_one_fwd(src_row, dst_row2, cols); - else - return; -} - -template -__device__ void moe_dpch_bwd_selector(T *src_row, T *dst_row1, T *dst_row2, - const int cols, const int indicator1, - const int indicator2) { - if (indicator1 != 0 && indicator2 != 0) - moe_dpch_two_bwd(src_row, dst_row1, dst_row2, - cols); - else if (indicator1 != 0) - moe_dpch_one_bwd(src_row, dst_row1, cols); - else if (indicator2 != 0) - moe_dpch_one_bwd(src_row, dst_row2, cols); - else - return; -} - -template -__global__ void moe_dpch_fwd_kernel(T *batch_tokens, T *expert_input, - int *mask1, int *mask2, int *dest1, - int *dest2, const int h) { - int row = blockIdx.x; - int indicator2 = mask2 == nullptr ? 0 : mask2[row]; - moe_dpch_fwd_selector( - batch_tokens + (row * h), expert_input + (dest1[row] * h), - expert_input + (dest2[row] * h), h, mask1[row], indicator2); -} - -template -__global__ void moe_dpch_bwd_kernel(T *tokens_grad, T *expert_grad, int *mask1, - int *mask2, int *dest1, int *dest2, - const int h) { - int row = blockIdx.x; - int indicator2 = mask2 == nullptr ? 0 : mask2[row]; - moe_dpch_bwd_selector( - tokens_grad + (row * h), expert_grad + (dest1[row] * h), - expert_grad + (dest2[row] * h), h, mask1[row], indicator2); -} - -// COMBINE KERNELS -------------------------------- - -template -__device__ void moe_cb_fwd_selector(T *src_row1, T *src_row2, T *dst_row, - const int cols, const T weight1, - const T weight2, const int indicator1, - const int indicator2) { - if (indicator1 != 0 && indicator2 != 0) - moe_cb_two_fwd(src_row1, src_row2, dst_row, - weight1, weight2, cols); - else if (indicator1 != 0) - moe_cb_one_fwd(src_row1, dst_row, weight1, cols); - else if (indicator2 != 0) - moe_cb_one_fwd(src_row2, dst_row, weight2, cols); - else - return; -} - -template -__device__ void moe_cb_bwd_selector(T *src_row1, T *src_row2, T *dst_row, - const int cols, T *tks_row1, T *tks_row2, - T *wt_grad1, T *wt_grad2, const T weight1, - const T weight2, const int indicator1, - const int indicator2) { - if (indicator1 != 0 && indicator2 != 0) - moe_cb_two_bwd(src_row1, src_row2, dst_row, - tks_row1, tks_row2, wt_grad1, - wt_grad2, weight1, weight2, cols); - else if (indicator1 != 0) - moe_cb_one_bwd(src_row1, dst_row, tks_row1, - wt_grad1, weight1, cols); - else if (indicator2 != 0) - moe_cb_one_bwd(src_row2, dst_row, tks_row2, - wt_grad2, weight2, cols); - else - return; -} - -template -__global__ void moe_cb_fwd_kernel(T *expert_tokens, T *combine_tokens, - T *logits, int *mask1, int *mask2, int *dest1, - int *dest2, const int e, const int c, - const int h) { - int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c; - int indicator2 = mask2 == nullptr ? 0 : mask2[row]; - T *row_log = logits + (row * e); - moe_cb_fwd_selector( - expert_tokens + (dest1[row] * h), expert_tokens + (dest2[row] * h), - combine_tokens + (row * h), h, row_log[eid1], row_log[eid2], mask1[row], - indicator2); -} - -template -__global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks, - T *logits, T *logits_grad, int *mask1, - int *mask2, int *dest1, int *dest2, - const int e, const int c, const int h) { - int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c; - int indicator2 = mask2 == nullptr ? 
0 : mask2[row]; - T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e); - moe_cb_bwd_selector( - expert_grad + (dest1[row] * h), expert_grad + (dest2[row] * h), - tokens_grad + (row * h), h, tks + (dest1[row] * h), - tks + (dest2[row] * h), row_grad + eid1, row_grad + eid2, row_log[eid1], - row_log[eid2], mask1[row], indicator2); -} - -// CUMSUM KERNEL -------------------------------- - -template -__global__ void cumsum_kernel(int *inputs, int *outputs, const int s, - const int e) { - assert(s % pack_size == 0); - constexpr int bpack_size = block_size * pack_size; - int tid = threadIdx.x, bid = blockIdx.x, tps = tid * pack_size, last_sum = -1; - __shared__ int temp[block_size + 1]; - int pack[pack_size]; - - for (int idx = 0; idx < s; idx += bpack_size) { - int offset = 1; - - if (idx + tps < s) { - temp[tid] = inputs[tps * e + bid]; -#pragma unroll - for (int i = 1; i < pack_size; ++i) { - pack[i] = inputs[(tps + i) * e + bid]; - } -#pragma unroll - for (int i = 1; i < pack_size; ++i) { - temp[tid] += pack[i]; - } - } - - for (int i = block_size >> 1; i > 0; i >>= 1) { - __syncthreads(); - if (tid < i) { - int j = offset * (2 * tid + 1) - 1; - temp[j + offset] += temp[j]; - } - offset <<= 1; - } - - if (tid == 0) { - temp[block_size] = temp[block_size - 1]; - temp[block_size - 1] = 0; - } - - for (int i = 1; i < block_size; i <<= 1) { - offset >>= 1; - __syncthreads(); - if (tid < i) { - int j = offset * (2 * tid + 1) - 1, k = j + offset, ts = temp[j]; - temp[j] = temp[k]; - temp[k] += ts; - } - } - __syncthreads(); - - if (tid == 0) temp[0] = temp[block_size]; - __syncthreads(); - - if (idx + tps < s) { - temp[tid + 1] += last_sum; -#pragma unroll - for (int i = pack_size - 1; i > 0; --i) { - outputs[(tps + i) * e + bid] = temp[tid + 1]; - temp[tid + 1] -= pack[i]; - } - outputs[tps * e + bid] = temp[tid + 1]; - } - __syncthreads(); - - last_sum += temp[0]; - inputs += bpack_size * e; - outputs += bpack_size * e; - } -} - -// LAUNCH FUNCTIONS -------------------------------- - -template -void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1, - int *mask2, int *dest1, int *dest2, const int s, - const int h) { - if (h < 256) - moe_dpch_fwd_kernel - <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); - else if (h < 512) - moe_dpch_fwd_kernel - <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); - else if (h < 1024) - moe_dpch_fwd_kernel - <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); - else if (h < 2048) - moe_dpch_fwd_kernel - <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); - else - moe_dpch_fwd_kernel - <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); -} - -template -void moe_dpch_bwd_launch(T *tokens_grad, T *expert_grad, int *mask1, int *mask2, - int *dest1, int *dest2, const int s, const int h) { - if (h < 256) - moe_dpch_bwd_kernel - <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); - else if (h < 512) - moe_dpch_bwd_kernel - <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); - else if (h < 1024) - moe_dpch_bwd_kernel - <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); - else if (h < 2048) - moe_dpch_bwd_kernel - <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); - else - moe_dpch_bwd_kernel - <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); -} - -template -void moe_cb_fwd_launch(T *expert_tokens, T *combine_tokens, T *logits, - int *mask1, int *mask2, int *dest1, int *dest2, - const int s, const int e, const int 
c, const int h) { - if (h < 256) - moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, - logits, mask1, mask2, dest1, dest2, - e, c, h); - else if (h < 512) - moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, - logits, mask1, mask2, dest1, dest2, - e, c, h); - else if (h < 1024) - moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, - logits, mask1, mask2, dest1, dest2, - e, c, h); - else if (h < 2048) - moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, - logits, mask1, mask2, dest1, dest2, - e, c, h); - else - moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, - logits, mask1, mask2, dest1, - dest2, e, c, h); -} - -template -void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits, - T *logits_grad, int *mask1, int *mask2, int *dest1, - int *dest2, const int s, const int e, const int c, - const int h) { - if (h < 256) - moe_cb_bwd_kernel<<>>(tokens_grad, expert_grad, tks, - logits, logits_grad, mask1, mask2, - dest1, dest2, e, c, h); - else // if (h < 512) - moe_cb_bwd_kernel<<>>(tokens_grad, expert_grad, tks, - logits, logits_grad, mask1, mask2, - dest1, dest2, e, c, h); - // else if (h < 1024) - // moe_cb_bwd_kernel<<>> - // (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, - // dest1, dest2, e, c, h); - // else - // moe_cb_bwd_kernel<<>> - // (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, - // dest1, dest2, e, c, h); -} - -void cumsum_launch(int *inputs, int *outputs, const int s, const int e) { - if (s <= 256) - cumsum_kernel<256, 1><<>>(inputs, outputs, s, e); - else if (s <= 512) - cumsum_kernel<512, 1><<>>(inputs, outputs, s, e); - else if (s <= 1024) - cumsum_kernel<1024, 1><<>>(inputs, outputs, s, e); - else if (s <= 2048) - cumsum_kernel<1024, 2><<>>(inputs, outputs, s, e); - else - cumsum_kernel<1024, 4><<>>(inputs, outputs, s, e); -} - -// API FUNCTIONS -------------------------------- - -#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \ - switch (TYPE) { \ - case at::ScalarType::Float: { \ - using scalar_t = float; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Half: { \ - using scalar_t = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented yet for specific data type."); \ - } - -torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h, - torch::Tensor batch_tokens, - torch::Tensor mask, - torch::Tensor dest_idx) { - assert(h % 16 == 0); - auto res = torch::zeros( - {ec, h}, - torch::dtype(batch_tokens.dtype()).device(batch_tokens.device())); - auto k = mask.size(0); - - DISPATCH_FLOAT_AND_HALF( - batch_tokens.scalar_type(), "moe dispatch forward", - moe_dpch_fwd_launch( - batch_tokens.data(), res.data(), - mask[0].data(), k == 1 ? nullptr : mask[1].data(), - dest_idx[0].data(), - k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, h)); - - return res; -} - -torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h, - torch::Tensor expert_grad, - torch::Tensor mask, - torch::Tensor dest_idx) { - assert(h % 16 == 0); - auto res = torch::zeros( - {s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device())); - auto k = mask.size(0); - - DISPATCH_FLOAT_AND_HALF( - expert_grad.scalar_type(), "moe dispatch backward", - moe_dpch_bwd_launch( - res.data(), expert_grad.data(), - mask[0].data(), k == 1 ? nullptr : mask[1].data(), - dest_idx[0].data(), - k == 1 ? 
dest_idx[0].data() : dest_idx[1].data(), s, h)); - - return res; -} - -torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h, - torch::Tensor expert_tokens, - torch::Tensor logits, torch::Tensor mask, - torch::Tensor dest_idx) { - assert(h % 16 == 0); - assert(expert_tokens.dtype() == logits.dtype()); - - auto res = torch::zeros( - {s, h}, - torch::dtype(expert_tokens.dtype()).device(expert_tokens.device())); - auto k = mask.size(0); - - DISPATCH_FLOAT_AND_HALF( - expert_tokens.scalar_type(), "moe combine forward", - moe_cb_fwd_launch( - expert_tokens.data(), res.data(), - logits.data(), mask[0].data(), - k == 1 ? nullptr : mask[1].data(), dest_idx[0].data(), - k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, e, c, - h)); - - return res; -} - -std::vector moe_combine_cuda_backward( - int s, int e, int c, int h, torch::Tensor tokens_grad, - torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask, - torch::Tensor dest_idx) { - assert(h % 16 == 0); - assert(tokens_grad.dtype() == expert_tokens.dtype()); - assert(expert_tokens.dtype() == logits.dtype()); - - auto egrad = torch::zeros( - {e * c, h}, - torch::dtype(tokens_grad.dtype()).device(tokens_grad.device())), - wgrad = torch::zeros( - {s, e}, torch::dtype(logits.dtype()).device(logits.device())); - auto k = mask.size(0); - - DISPATCH_FLOAT_AND_HALF( - tokens_grad.scalar_type(), "moe combine backward", - moe_cb_bwd_launch( - tokens_grad.data(), egrad.data(), - expert_tokens.data(), logits.data(), - wgrad.data(), mask[0].data(), - k == 1 ? nullptr : mask[1].data(), dest_idx[0].data(), - k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, e, c, - h)); - - return {egrad, wgrad}; -} - -torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) { - assert(mask.dim() == 2); - assert(mask.dtype() == torch::kInt32); - - const int s = mask.size(0), e = mask.size(1); - auto res = - torch::empty({s, e}, torch::dtype(torch::kInt32).device(mask.device())); - cumsum_launch(mask.data(), res.data(), s, e); - - return res; -} +#include +#include +#include + +#include + +#include "block_reduce.h" + +template +__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(src_row + idx, pack); + BlockStore(ts_store).Store(dst_row + idx, pack); + } +} + +template +__device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(dst_row + idx, pack); + BlockStore(ts_store).Store(src_row + idx, pack); + } +} + +template +__device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2, + const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + 
typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(src_row + idx, pack); + BlockStore(ts_store).Store(dst_row1 + idx, pack); + BlockStore(ts_store).Store(dst_row2 + idx, pack); + } +} + +template +__device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2, + const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack1[pack_size], pack2[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(dst_row1 + idx, pack1); + BlockLoad(ts_load).Load(dst_row2 + idx, pack2); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + pack1[i] += pack2[i]; + } + + BlockStore(ts_store).Store(src_row + idx, pack1); + } +} + +template +__device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight, + const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(src_row + idx, pack); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + pack[i] *= weight; + } + + BlockStore(ts_store).Store(dst_row + idx, pack); + } +} + +template +__device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row, + T *weight_grad, const T weight, const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T grad[pack_size], tokens[pack_size]; + float thread_sum = 0; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(dst_row + idx, grad); + BlockLoad(ts_load).Load(tks_row + idx, tokens); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + thread_sum += grad[i] * tokens[i]; + grad[i] *= weight; + } + + BlockStore(ts_store).Store(src_row + idx, grad); + } + + blockReduce(&thread_sum); + + if (threadIdx.x == 0) *weight_grad = static_cast(thread_sum); +} + +template +__device__ void moe_cb_two_fwd(T *src_row1, T *src_row2, T *dst_row, + const T weight1, const T weight2, + const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack1[pack_size], pack2[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(src_row1 + idx, pack1); + BlockLoad(ts_load).Load(src_row2 + idx, pack2); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + pack1[i] = pack1[i] * weight1 + pack2[i] * weight2; + } + + 
BlockStore(ts_store).Store(dst_row + idx, pack1); + } +} + +template +__device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row, + T *tks_row1, T *tks_row2, T *weight_grad1, + T *weight_grad2, const T weight1, + const T weight2, const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T grad[pack_size], tokens1[pack_size], tokens2[pack_size], sgrad1[pack_size], + sgrad2[pack_size]; + float thread_sum[2] = {0, 0}; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(dst_row + idx, grad); + BlockLoad(ts_load).Load(tks_row1 + idx, tokens1); + BlockLoad(ts_load).Load(tks_row2 + idx, tokens2); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + thread_sum[0] += grad[i] * tokens1[i]; + thread_sum[1] += grad[i] * tokens2[i]; + sgrad1[i] = weight1 * grad[i]; + sgrad2[i] = weight2 * grad[i]; + } + + BlockStore(ts_store).Store(src_row1 + idx, sgrad1); + BlockStore(ts_store).Store(src_row2 + idx, sgrad2); + } + + blockReduce(thread_sum); + + if (threadIdx.x == 0) + *weight_grad1 = static_cast(thread_sum[0]); + else if (threadIdx.x == 1) + *weight_grad2 = static_cast(thread_sum[1]); +} + +// DISPATCH KERNELS -------------------------------- + +template +__device__ void moe_dpch_fwd_selector(T *src_row, T *dst_row1, T *dst_row2, + const int cols, const int indicator1, + const int indicator2) { + if (indicator1 != 0 && indicator2 != 0) + moe_dpch_two_fwd(src_row, dst_row1, dst_row2, + cols); + else if (indicator1 != 0) + moe_dpch_one_fwd(src_row, dst_row1, cols); + else if (indicator2 != 0) + moe_dpch_one_fwd(src_row, dst_row2, cols); + else + return; +} + +template +__device__ void moe_dpch_bwd_selector(T *src_row, T *dst_row1, T *dst_row2, + const int cols, const int indicator1, + const int indicator2) { + if (indicator1 != 0 && indicator2 != 0) + moe_dpch_two_bwd(src_row, dst_row1, dst_row2, + cols); + else if (indicator1 != 0) + moe_dpch_one_bwd(src_row, dst_row1, cols); + else if (indicator2 != 0) + moe_dpch_one_bwd(src_row, dst_row2, cols); + else + return; +} + +template +__global__ void moe_dpch_fwd_kernel(T *batch_tokens, T *expert_input, + int *mask1, int *mask2, int *dest1, + int *dest2, const int h) { + int row = blockIdx.x; + int indicator2 = mask2 == nullptr ? 0 : mask2[row]; + moe_dpch_fwd_selector( + batch_tokens + (row * h), expert_input + (dest1[row] * h), + expert_input + (dest2[row] * h), h, mask1[row], indicator2); +} + +template +__global__ void moe_dpch_bwd_kernel(T *tokens_grad, T *expert_grad, int *mask1, + int *mask2, int *dest1, int *dest2, + const int h) { + int row = blockIdx.x; + int indicator2 = mask2 == nullptr ? 
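As a reading aid for the dispatch forward kernel above, its per-row semantics reduce to a small CPU reference. This is an editor's sketch rather than part of the patch: the function name `moe_dispatch_forward_ref` and the flat `std::vector` buffers are illustrative, and the top-2 routing interpretation of `mask1`/`mask2` and `dest1`/`dest2` is inferred from the kernel signatures.

```cpp
// CPU reference for the dispatch forward pass (illustrative sketch, not part of the patch).
// Each token row is copied into the [ec, h] expert buffer at dest1[row] (and at dest2[row]
// when the second routing mask is set), mirroring moe_dpch_fwd_kernel's selector logic.
#include <algorithm>
#include <vector>

void moe_dispatch_forward_ref(const std::vector<float>& tokens,    // [s, h] input tokens
                              const std::vector<int>& mask1,       // [s] first-expert mask
                              const std::vector<int>& mask2,       // [s] second-expert mask (may be all zero)
                              const std::vector<int>& dest1,       // [s] destination rows for the first expert
                              const std::vector<int>& dest2,       // [s] destination rows for the second expert
                              int s, int h,
                              std::vector<float>& expert_input) {  // [ec, h] output buffer
  for (int row = 0; row < s; ++row) {
    if (mask1[row] != 0)
      std::copy_n(&tokens[row * h], h, &expert_input[dest1[row] * h]);
    if (mask2[row] != 0)
      std::copy_n(&tokens[row * h], h, &expert_input[dest2[row] * h]);
  }
}
```

The backward kernel that follows reverses these copies, accumulating the gradients of the (up to two) expert rows back into the token gradient.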
0 : mask2[row]; + moe_dpch_bwd_selector( + tokens_grad + (row * h), expert_grad + (dest1[row] * h), + expert_grad + (dest2[row] * h), h, mask1[row], indicator2); +} + +// COMBINE KERNELS -------------------------------- + +template +__device__ void moe_cb_fwd_selector(T *src_row1, T *src_row2, T *dst_row, + const int cols, const T weight1, + const T weight2, const int indicator1, + const int indicator2) { + if (indicator1 != 0 && indicator2 != 0) + moe_cb_two_fwd(src_row1, src_row2, dst_row, + weight1, weight2, cols); + else if (indicator1 != 0) + moe_cb_one_fwd(src_row1, dst_row, weight1, cols); + else if (indicator2 != 0) + moe_cb_one_fwd(src_row2, dst_row, weight2, cols); + else + return; +} + +template +__device__ void moe_cb_bwd_selector(T *src_row1, T *src_row2, T *dst_row, + const int cols, T *tks_row1, T *tks_row2, + T *wt_grad1, T *wt_grad2, const T weight1, + const T weight2, const int indicator1, + const int indicator2) { + if (indicator1 != 0 && indicator2 != 0) + moe_cb_two_bwd(src_row1, src_row2, dst_row, + tks_row1, tks_row2, wt_grad1, + wt_grad2, weight1, weight2, cols); + else if (indicator1 != 0) + moe_cb_one_bwd(src_row1, dst_row, tks_row1, + wt_grad1, weight1, cols); + else if (indicator2 != 0) + moe_cb_one_bwd(src_row2, dst_row, tks_row2, + wt_grad2, weight2, cols); + else + return; +} + +template +__global__ void moe_cb_fwd_kernel(T *expert_tokens, T *combine_tokens, + T *logits, int *mask1, int *mask2, int *dest1, + int *dest2, const int e, const int c, + const int h) { + int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c; + int indicator2 = mask2 == nullptr ? 0 : mask2[row]; + T *row_log = logits + (row * e); + moe_cb_fwd_selector( + expert_tokens + (dest1[row] * h), expert_tokens + (dest2[row] * h), + combine_tokens + (row * h), h, row_log[eid1], row_log[eid2], mask1[row], + indicator2); +} + +template +__global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks, + T *logits, T *logits_grad, int *mask1, + int *mask2, int *dest1, int *dest2, + const int e, const int c, const int h) { + int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c; + int indicator2 = mask2 == nullptr ? 
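The combine kernels invert the dispatch: expert outputs are gathered back per token and weighted by the routing logits, with the expert index recovered as `dest / c` from the flattened `[e * c, h]` layout. A CPU sketch of the forward direction (again an editor's illustration with an assumed function name, not part of the patch):

```cpp
// CPU reference for the combine forward pass (illustrative sketch, not part of the patch).
// combined[row] = logits[row][eid1] * expert_tokens[dest1[row]]
//               + logits[row][eid2] * expert_tokens[dest2[row]]   // second term only if mask2[row] != 0
#include <vector>

void moe_combine_forward_ref(const std::vector<float>& expert_tokens,  // [e * c, h]
                             const std::vector<float>& logits,         // [s, e]
                             const std::vector<int>& mask1, const std::vector<int>& mask2,
                             const std::vector<int>& dest1, const std::vector<int>& dest2,
                             int s, int e, int c, int h,
                             std::vector<float>& combined) {            // [s, h]
  for (int row = 0; row < s; ++row) {
    const int eid1 = dest1[row] / c, eid2 = dest2[row] / c;
    for (int col = 0; col < h; ++col) {
      float acc = 0.0f;
      if (mask1[row] != 0) acc += logits[row * e + eid1] * expert_tokens[dest1[row] * h + col];
      if (mask2[row] != 0) acc += logits[row * e + eid2] * expert_tokens[dest2[row] * h + col];
      combined[row * h + col] = acc;
    }
  }
}
```

The matching backward kernel additionally emits a gradient for each selected logit, computed as the dot product of the incoming token gradient with the corresponding expert output row.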
0 : mask2[row]; + T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e); + moe_cb_bwd_selector( + expert_grad + (dest1[row] * h), expert_grad + (dest2[row] * h), + tokens_grad + (row * h), h, tks + (dest1[row] * h), + tks + (dest2[row] * h), row_grad + eid1, row_grad + eid2, row_log[eid1], + row_log[eid2], mask1[row], indicator2); +} + +// CUMSUM KERNEL -------------------------------- + +template +__global__ void cumsum_kernel(int *inputs, int *outputs, const int s, + const int e) { + assert(s % pack_size == 0); + constexpr int bpack_size = block_size * pack_size; + int tid = threadIdx.x, bid = blockIdx.x, tps = tid * pack_size, last_sum = -1; + __shared__ int temp[block_size + 1]; + int pack[pack_size]; + + for (int idx = 0; idx < s; idx += bpack_size) { + int offset = 1; + + if (idx + tps < s) { + temp[tid] = inputs[tps * e + bid]; +#pragma unroll + for (int i = 1; i < pack_size; ++i) { + pack[i] = inputs[(tps + i) * e + bid]; + } +#pragma unroll + for (int i = 1; i < pack_size; ++i) { + temp[tid] += pack[i]; + } + } + + for (int i = block_size >> 1; i > 0; i >>= 1) { + __syncthreads(); + if (tid < i) { + int j = offset * (2 * tid + 1) - 1; + temp[j + offset] += temp[j]; + } + offset <<= 1; + } + + if (tid == 0) { + temp[block_size] = temp[block_size - 1]; + temp[block_size - 1] = 0; + } + + for (int i = 1; i < block_size; i <<= 1) { + offset >>= 1; + __syncthreads(); + if (tid < i) { + int j = offset * (2 * tid + 1) - 1, k = j + offset, ts = temp[j]; + temp[j] = temp[k]; + temp[k] += ts; + } + } + __syncthreads(); + + if (tid == 0) temp[0] = temp[block_size]; + __syncthreads(); + + if (idx + tps < s) { + temp[tid + 1] += last_sum; +#pragma unroll + for (int i = pack_size - 1; i > 0; --i) { + outputs[(tps + i) * e + bid] = temp[tid + 1]; + temp[tid + 1] -= pack[i]; + } + outputs[tps * e + bid] = temp[tid + 1]; + } + __syncthreads(); + + last_sum += temp[0]; + inputs += bpack_size * e; + outputs += bpack_size * e; + } +} + +// LAUNCH FUNCTIONS -------------------------------- + +template +void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1, + int *mask2, int *dest1, int *dest2, const int s, + const int h) { + if (h < 256) + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); + else if (h < 512) + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); + else if (h < 1024) + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); + else if (h < 2048) + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); + else + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); +} + +template +void moe_dpch_bwd_launch(T *tokens_grad, T *expert_grad, int *mask1, int *mask2, + int *dest1, int *dest2, const int s, const int h) { + if (h < 256) + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); + else if (h < 512) + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); + else if (h < 1024) + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); + else if (h < 2048) + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); + else + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); +} + +template +void moe_cb_fwd_launch(T *expert_tokens, T *combine_tokens, T *logits, + int *mask1, int *mask2, int *dest1, int *dest2, + const int s, const int e, const int 
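The cumsum kernel above is a blocked Blelchoch-style scan per column; functionally it reduces to the following CPU reference (editor's sketch, assumed name): an inclusive prefix sum along dim 0 shifted down by one, which turns a per-expert membership mask into zero-based destination slots.

```cpp
// CPU reference for the column-wise cumulative sum minus one (illustrative sketch, not part of the patch).
// For each expert column j: out[i][j] = mask[0][j] + ... + mask[i][j] - 1, i.e. the zero-based
// slot of token i inside expert j's capacity buffer (matches last_sum = -1 in the kernel).
void cumsum_sub_one_ref(const int* mask, int* out, int s, int e) {
  for (int j = 0; j < e; ++j) {
    int acc = -1;
    for (int i = 0; i < s; ++i) {
      acc += mask[i * e + j];
      out[i * e + j] = acc;
    }
  }
}
```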
c, const int h) { + if (h < 256) + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, dest2, + e, c, h); + else if (h < 512) + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, dest2, + e, c, h); + else if (h < 1024) + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, dest2, + e, c, h); + else if (h < 2048) + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, dest2, + e, c, h); + else + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, + dest2, e, c, h); +} + +template +void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits, + T *logits_grad, int *mask1, int *mask2, int *dest1, + int *dest2, const int s, const int e, const int c, + const int h) { + if (h < 256) + moe_cb_bwd_kernel<<>>(tokens_grad, expert_grad, tks, + logits, logits_grad, mask1, mask2, + dest1, dest2, e, c, h); + else // if (h < 512) + moe_cb_bwd_kernel<<>>(tokens_grad, expert_grad, tks, + logits, logits_grad, mask1, mask2, + dest1, dest2, e, c, h); + // else if (h < 1024) + // moe_cb_bwd_kernel<<>> + // (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, + // dest1, dest2, e, c, h); + // else + // moe_cb_bwd_kernel<<>> + // (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, + // dest1, dest2, e, c, h); +} + +void cumsum_launch(int *inputs, int *outputs, const int s, const int e) { + if (s <= 256) + cumsum_kernel<256, 1><<>>(inputs, outputs, s, e); + else if (s <= 512) + cumsum_kernel<512, 1><<>>(inputs, outputs, s, e); + else if (s <= 1024) + cumsum_kernel<1024, 1><<>>(inputs, outputs, s, e); + else if (s <= 2048) + cumsum_kernel<1024, 2><<>>(inputs, outputs, s, e); + else + cumsum_kernel<1024, 4><<>>(inputs, outputs, s, e); +} + +// API FUNCTIONS -------------------------------- + +#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Float: { \ + using scalar_t = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented yet for specific data type."); \ + } + +torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h, + torch::Tensor batch_tokens, + torch::Tensor mask, + torch::Tensor dest_idx) { + assert(h % 16 == 0); + auto res = torch::zeros( + {ec, h}, + torch::dtype(batch_tokens.dtype()).device(batch_tokens.device())); + auto k = mask.size(0); + + DISPATCH_FLOAT_AND_HALF( + batch_tokens.scalar_type(), "moe dispatch forward", + moe_dpch_fwd_launch( + batch_tokens.data(), res.data(), + mask[0].data(), k == 1 ? nullptr : mask[1].data(), + dest_idx[0].data(), + k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, h)); + + return res; +} + +torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h, + torch::Tensor expert_grad, + torch::Tensor mask, + torch::Tensor dest_idx) { + assert(h % 16 == 0); + auto res = torch::zeros( + {s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device())); + auto k = mask.size(0); + + DISPATCH_FLOAT_AND_HALF( + expert_grad.scalar_type(), "moe dispatch backward", + moe_dpch_bwd_launch( + res.data(), expert_grad.data(), + mask[0].data(), k == 1 ? nullptr : mask[1].data(), + dest_idx[0].data(), + k == 1 ? 
dest_idx[0].data() : dest_idx[1].data(), s, h)); + + return res; +} + +torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h, + torch::Tensor expert_tokens, + torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx) { + assert(h % 16 == 0); + assert(expert_tokens.dtype() == logits.dtype()); + + auto res = torch::zeros( + {s, h}, + torch::dtype(expert_tokens.dtype()).device(expert_tokens.device())); + auto k = mask.size(0); + + DISPATCH_FLOAT_AND_HALF( + expert_tokens.scalar_type(), "moe combine forward", + moe_cb_fwd_launch( + expert_tokens.data(), res.data(), + logits.data(), mask[0].data(), + k == 1 ? nullptr : mask[1].data(), dest_idx[0].data(), + k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, e, c, + h)); + + return res; +} + +std::vector moe_combine_cuda_backward( + int s, int e, int c, int h, torch::Tensor tokens_grad, + torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx) { + assert(h % 16 == 0); + assert(tokens_grad.dtype() == expert_tokens.dtype()); + assert(expert_tokens.dtype() == logits.dtype()); + + auto egrad = torch::zeros( + {e * c, h}, + torch::dtype(tokens_grad.dtype()).device(tokens_grad.device())), + wgrad = torch::zeros( + {s, e}, torch::dtype(logits.dtype()).device(logits.device())); + auto k = mask.size(0); + + DISPATCH_FLOAT_AND_HALF( + tokens_grad.scalar_type(), "moe combine backward", + moe_cb_bwd_launch( + tokens_grad.data(), egrad.data(), + expert_tokens.data(), logits.data(), + wgrad.data(), mask[0].data(), + k == 1 ? nullptr : mask[1].data(), dest_idx[0].data(), + k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, e, c, + h)); + + return {egrad, wgrad}; +} + +torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) { + assert(mask.dim() == 2); + assert(mask.dtype() == torch::kInt32); + + const int s = mask.size(0), e = mask.size(1); + auto res = + torch::empty({s, e}, torch::dtype(torch::kInt32).device(mask.device())); + cumsum_launch(mask.data(), res.data(), s, e); + + return res; +} diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu index 49ab83e8fc81..85f935152f8a 100644 --- a/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu @@ -379,4 +379,4 @@ void multi_tensor_norm_out_cuda( norm_type, alpha, beta); return; -} \ No newline at end of file +} diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu index 54c4220190d8..63771cf40bcb 100644 --- a/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu @@ -351,4 +351,4 @@ void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag, lr, weight_decay, use_nvlamb);) AT_CUDA_CHECK(cudaGetLastError()); -} \ No newline at end of file +} diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu index 360485dcd02f..2f58a0f16dce 100644 --- a/colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu @@ -122,4 +122,4 @@ void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag, AT_CUDA_CHECK(cudaGetLastError()); // AT_CUDA_CHECK(cudaDeviceSynchronize()); -} \ No newline at end of file +} diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu 
b/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu index 35f2c9b4ed15..7f48dbd5d497 100644 --- a/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu @@ -164,4 +164,4 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, } AT_CUDA_CHECK(cudaGetLastError()); -} \ No newline at end of file +} diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp index 4ae3c853ca5e..8c2982b0cff9 100644 --- a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp +++ b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp @@ -3,82 +3,68 @@ #include #include + #include namespace multihead_attn { namespace fused_softmax { namespace scaled_masked_softmax { -torch::Tensor fwd_cuda( - torch::Tensor const& input, - torch::Tensor const& mask, - float scale_factor); - -torch::Tensor bwd_cuda( - torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor); - -int get_batch_per_block_cuda( - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads); - -torch::Tensor fwd( - torch::Tensor const& input, - torch::Tensor const& mask, - float scale_factor) { +torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, + float scale_factor); + +torch::Tensor bwd_cuda(torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor); + +int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, + int attn_heads); + +torch::Tensor fwd(torch::Tensor const& input, torch::Tensor const& mask, + float scale_factor) { AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); return fwd_cuda(input, mask, scale_factor); } -torch::Tensor bwd( - torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor) { - +torch::Tensor bwd(torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, float scale_factor) { AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); return bwd_cuda(output_grads, softmax_results, scale_factor); } -int get_batch_per_block( - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads) { - return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads); +int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, + int attn_heads) { + return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, + attn_heads); } -} // end namespace scaled_masked_softmax -} // end namespace fused_softmax -} // end namespace multihead_attn +} // 
end namespace scaled_masked_softmax +} // end namespace fused_softmax +} // end namespace multihead_attn PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", - &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, - "Self Multihead Attention scaled, time masked softmax -- Forward."); + m.def("forward", &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, + "Self Multihead Attention scaled, time masked softmax -- Forward."); - m.def("backward", - &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, - "Self Multihead Attention scaled, time masked softmax -- Backward."); + m.def("backward", &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, + "Self Multihead Attention scaled, time masked softmax -- Backward."); m.def("get_batch_per_block", - &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block, - "Return Batch per block size." - ); + &multihead_attn::fused_softmax::scaled_masked_softmax:: + get_batch_per_block, + "Return Batch per block size."); } diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h index 1583030b8235..d3e6f04e6093 100644 --- a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h +++ b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h @@ -4,12 +4,12 @@ #pragma once #include +#include #include +#include + #include #include -#include -#include -#include namespace { @@ -17,37 +17,53 @@ template __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } +__device__ __inline__ void copy_vector( + c10::BFloat16 *dst, const c10::BFloat16 *src) { + *dst = *src; +} template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } +__device__ __inline__ void copy_vector( + c10::BFloat16 *dst, const c10::BFloat16 *src) { + *((float2 *)dst) = *((float2 *)src); +} template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } +__device__ __inline__ void copy_vector(c10::Half *dst, + const c10::Half *src) { + *dst = *src; +} template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } +__device__ __inline__ void copy_vector(c10::Half *dst, + const c10::Half *src) { + *((float2 *)dst) = *((float2 *)src); +} template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } +__device__ __inline__ void copy_vector(uint8_t *dst, + const uint8_t *src) { + *dst = *src; +} template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } +__device__ __inline__ void copy_vector(uint8_t *dst, + const uint8_t *src) { + *((half2 *)dst) = *((half2 *)src); +} int log2_ceil(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) ++log2_value; - return log2_value; + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; } -template +template struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } + __device__ __forceinline__ T operator()(T a, T b) const { return a + b; } }; -template +template struct Max { __device__ __forceinline__ T operator()(T a, T b) const { return a < b ? 
b : a; @@ -55,438 +71,468 @@ struct Max { }; template -__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) -{ +__device__ __forceinline__ T +WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, + unsigned int mask = 0xffffffff) { #if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); + return __shfl_xor_sync(mask, value, laneMask, width); #else - return __shfl_xor(value, laneMask, width); + return __shfl_xor(value, laneMask, width); #endif } -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t* sum) { - ReduceOp r; - #pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); - sum[i] = r(sum[i], b); - } +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t *sum) { + ReduceOp r; +#pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); } + } } /* - * Extended softmax (from native aten pytorch) with following additional features - * 1) input scaling - * 2) Explicit masking - */ -template + * Extended softmax (from native aten pytorch) with following additional + * features 1) input scaling 2) Explicit masking + */ +template __global__ void scaled_masked_softmax_warp_forward( - output_t *dst, - const input_t *src, - const uint8_t *mask, - const acc_t scale, - int micro_batch_size, - int element_count, - int pad_batches) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_forward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) - int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH; - int pad_first_batch = 0; - if (pad_batches != 1) { // bert style - pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH; - } else { // gpt2 style - pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - } + output_t *dst, const input_t *src, const uint8_t *mask, const acc_t scale, + int micro_batch_size, int element_count, int pad_batches) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 
1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = + (blockDim.y * + (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + + threadIdx.y) * + WARP_BATCH; + int pad_first_batch = 0; + if (pad_batches != 1) { // bert style + pad_first_batch = + (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * + WARP_BATCH; + } else { // gpt2 style + pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + } - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - - // load data from global memory - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - input_t temp_data[ELEMENTS_PER_LDG_STG]; - uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < batch_element_count) { - int itr_idx = i*element_count+it*WARP_SIZE; - copy_vector(temp_data, src + itr_idx); - copy_vector(temp_mask, mask + itr_idx); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (temp_mask[element] != 1) { - elements[i][it + element] = (acc_t)temp_data[element] * scale; - } else { - elements[i][it + element] = -10000.0; - } - } - } else { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the + // batch + int local_idx = threadIdx.x; + + src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 
0 : element_count; + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + int itr_idx = i * element_count + it * WARP_SIZE; + copy_vector(temp_data, src + itr_idx); + copy_vector(temp_mask, mask + itr_idx); + +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (temp_mask[element] != 1) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -10000.0; + } + } + } else { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); } + } } + } - // compute max_value - acc_t max_value[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; - } + // compute max_value + acc_t max_value[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = + (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; } - warp_reduce(max_value); - - acc_t sum[WARP_BATCH] { 0.0f }; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - elements[i][it] = std::exp((elements[i][it] - max_value[i])); - sum[i] += elements[i][it]; - } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH]{0.0f}; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; } - warp_reduce(sum); - - // store result - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = elements[i][it + element] / sum[i]; - } - copy_vector(dst + i * element_count + it * WARP_SIZE, out); - } else { - break; - } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = elements[i][it + element] / sum[i]; } + copy_vector( + dst + i * element_count + it * WARP_SIZE, out); + } else { + break; + } } + } } -template +template __global__ void scaled_masked_softmax_warp_backward( - output_t *gradInput, - input_t *grad, - const input_t *output, - acc_t scale, - int micro_batch_size, - int element_count) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_backward_kernel. 
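The reformatted warp kernel above computes a numerically stable softmax over `scale * x`, forcing masked positions to a large negative logit (-10000) before the max/exp/normalize steps. A one-row scalar CPU reference (editor's sketch with an assumed function name, not part of the patch):

```cpp
// CPU reference for one row of the scaled, masked softmax forward pass
// (illustrative sketch, not part of the patch).
#include <cmath>
#include <cstdint>
#include <limits>

void scaled_masked_softmax_row_ref(float* dst, const float* src, const uint8_t* mask,
                                   float scale, int n) {
  float row_max = -std::numeric_limits<float>::infinity();
  for (int i = 0; i < n; ++i) {
    const float v = (mask[i] != 1) ? src[i] * scale : -10000.0f;  // mask == 1 marks padded keys
    row_max = std::fmax(row_max, v);
  }
  float sum = 0.0f;
  for (int i = 0; i < n; ++i) {
    const float v = (mask[i] != 1) ? src[i] * scale : -10000.0f;
    dst[i] = std::exp(v - row_max);
    sum += dst[i];
  }
  for (int i = 0; i < n; ++i) dst[i] /= sum;  // normalize the row
}
```

Using a large negative constant instead of negative infinity means a fully masked row degrades to a uniform distribution rather than producing NaNs.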
- constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - // the first element to process by the current thread - int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; - - // load data from global memory - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - input_t temp_grad[ELEMENTS_PER_LDG_STG]; - input_t temp_output[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - copy_vector(temp_grad, grad + i * element_count + it * WARP_SIZE); - copy_vector(temp_output, output + i * element_count + it * WARP_SIZE); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - output_reg[i][it + element] = (acc_t)temp_output[element]; - } - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; - } - } + output_t *gradInput, input_t *grad, const input_t *output, acc_t scale, + int micro_batch_size, int element_count) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. 
compute the index within the + // batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + int thread_offset = + first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector( + temp_grad, grad + i * element_count + it * WARP_SIZE); + copy_vector( + temp_output, output + i * element_count + it * WARP_SIZE); + +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + output_reg[i][it + element] = (acc_t)temp_output[element]; } - } - - acc_t sum[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_reg[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + grad_reg[i][it + element] = + (acc_t)temp_grad[element] * output_reg[i][it + element]; } + } } - warp_reduce(sum); - - // store result - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); - } - copy_vector(gradInput + i * element_count + it * WARP_SIZE, out); - } + } + + acc_t sum[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); + +// store result +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = + (output_t)(scale * (grad_reg[i][it + element] - + output_reg[i][it + element] * sum[i])); } + copy_vector( + gradInput + i * element_count + it * WARP_SIZE, out); + } } + } +} +} // end of anonymous namespace + +int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, + int attn_heads) { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + + int warp_size = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + int batches_per_warp = (next_power_of_two <= 128) ? 
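For completeness, the backward warp kernel above applies the standard softmax Jacobian-vector product scaled by `scale`: `dx_i = scale * y_i * (g_i - sum_j g_j * y_j)`, where `y` is the saved softmax output. A one-row CPU sketch (editor's illustration with an assumed name, not part of the patch):

```cpp
// CPU reference for one row of the scaled, masked softmax backward pass
// (illustrative sketch, not part of the patch).
void scaled_masked_softmax_row_backward_ref(float* dx, const float* g, const float* y,
                                            float scale, int n) {
  float dot = 0.0f;
  for (int i = 0; i < n; ++i) dot += g[i] * y[i];            // sum_j g_j * y_j
  for (int i = 0; i < n; ++i) dx[i] = scale * y[i] * (g[i] - dot);
}
```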
2 : 1; + + constexpr int threads_per_block = 128; + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + + return batches_per_block; } -} // end of anonymous namespace -int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){ +template +void dispatch_scaled_masked_softmax_forward(output_t *dst, const input_t *src, + const uint8_t *mask, + const input_t scale, + int query_seq_len, int key_seq_len, + int batches, int attn_heads, + int pad_batches) { + TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048); + if (key_seq_len == 0) { + return; + } else { int log2_elements = log2_ceil(key_seq_len); const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + // This value must match the WARP_SIZE constexpr value computed inside + // softmax_warp_forward. + int warp_size = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside + // softmax_warp_forward. int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + // use 128 threads per block to maximimize gpu utilization constexpr int threads_per_block = 128; + int warps_per_block = (threads_per_block / warp_size); int batches_per_block = warps_per_block * batches_per_warp; - - return batches_per_block; -} - -template -void dispatch_scaled_masked_softmax_forward( - output_t *dst, - const input_t *src, - const uint8_t *mask, - const input_t scale, - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads, - int pad_batches) -{ - TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 ); - if (key_seq_len == 0) { - return; - } else { - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - int batch_count = batches * attn_heads * query_seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. - int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0); - dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 1: // 2 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 2: // 4 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 3: // 8 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 4: // 16 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 5: // 32 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 6: // 64 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 7: // 128 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 8: // 256 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 9: // 512 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 10: // 1024 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 11: // 2048 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - default: - break; - } + TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0); + dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 1: // 2 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 2: // 4 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 3: // 8 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 4: // 16 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 5: // 32 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 6: // 64 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 7: // 128 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 8: // 256 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, 
mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 9: // 512 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 10: // 1024 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 11: // 2048 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + default: + break; } + } } -template -void dispatch_scaled_masked_softmax_backward( - output_t *grad_input, - input_t *grad, - const input_t *output, - const acc_t scale, - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads) -{ - TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 ); - if (key_seq_len == 0) { - return; - } else { - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - int batch_count = batches * attn_heads * query_seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = batch_count/batches_per_block; - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 1: // 2 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 2: // 4 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 3: // 8 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 4: // 16 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 5: // 32 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 6: // 64 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 7: // 128 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 8: // 256 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 9: // 512 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 10: // 1024 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 11: // 2048 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - default: - break; - } +template +void dispatch_scaled_masked_softmax_backward(output_t *grad_input, + input_t *grad, + const input_t *output, + const acc_t scale, + int query_seq_len, int key_seq_len, + int batches, int 
attn_heads) { + TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside + // softmax_warp_backward. + int warp_size = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside + // softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = batch_count / batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 1: // 2 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 2: // 4 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 3: // 8 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 4: // 16 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 5: // 32 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 6: // 64 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 7: // 128 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 8: // 256 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 9: // 512 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 10: // 1024 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 11: // 2048 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + default: + break; } + } } diff --git a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h index 3af487f9de0f..54c8e9133a1b 100644 --- a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h +++ b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h @@ -4,11 +4,12 @@ #pragma once #include +#include #include +#include + #include #include -#include -#include namespace { @@ -16,53 +17,78 @@ template __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } +__device__ __inline__ void copy_vector( + c10::BFloat16 *dst, const c10::BFloat16 *src) { + *dst = *src; +} template <> -__device__ __inline__ void 
copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } - +__device__ __inline__ void copy_vector( + c10::BFloat16 *dst, const c10::BFloat16 *src) { + *((float2 *)dst) = *((float2 *)src); +} + template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } +__device__ __inline__ void copy_vector(c10::Half *dst, + const c10::Half *src) { + *dst = *src; +} template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } +__device__ __inline__ void copy_vector(c10::Half *dst, + const c10::Half *src) { + *((float2 *)dst) = *((float2 *)src); +} template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } +__device__ __inline__ void copy_vector(uint8_t *dst, + const uint8_t *src) { + *dst = *src; +} template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } +__device__ __inline__ void copy_vector(uint8_t *dst, + const uint8_t *src) { + *((half2 *)dst) = *((half2 *)src); +} template __device__ __inline__ void copy_zero_vector(Datatype *dst); template <> -__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *dst = 0.0; } +__device__ __inline__ void copy_zero_vector( + c10::BFloat16 *dst) { + *dst = 0.0; +} template <> -__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } +__device__ __inline__ void copy_zero_vector( + c10::BFloat16 *dst) { + *((float2 *)dst) = make_float2(0.0f, 0.0f); +} template <> -__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *dst = 0.0; } +__device__ __inline__ void copy_zero_vector(c10::Half *dst) { + *dst = 0.0; +} template <> -__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } - +__device__ __inline__ void copy_zero_vector(c10::Half *dst) { + *((float2 *)dst) = make_float2(0.0f, 0.0f); +} int log2_ceil(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) ++log2_value; - return log2_value; + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; } -template +template struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } + __device__ __forceinline__ T operator()(T a, T b) const { return a + b; } }; -template +template struct Max { __device__ __forceinline__ T operator()(T a, T b) const { return a < b ? 
b : a; @@ -70,431 +96,505 @@ struct Max { }; template -__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) -{ +__device__ __forceinline__ T +WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, + unsigned int mask = 0xffffffff) { #if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); + return __shfl_xor_sync(mask, value, laneMask, width); #else - return __shfl_xor(value, laneMask, width); + return __shfl_xor(value, laneMask, width); #endif } -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t* sum) { - ReduceOp r; - #pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); - sum[i] = r(sum[i], b); - } +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t *sum) { + ReduceOp r; +#pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); } + } } /* - * Extended softmax (from native aten pytorch) with following additional features - * 1) input scaling - * 2) Implicit time (diagonal masking) + * Extended softmax (from native aten pytorch) with following additional + * features 1) input scaling 2) Implicit time (diagonal masking) */ -template +template __global__ void scaled_upper_triang_masked_softmax_warp_forward( - output_t *dst, - const input_t *src, - const acc_t scale, - int micro_batch_size, - int stride, - int element_count) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_forward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; - int local_seq = blockIdx.x + 1; - int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - - // load data from global memory - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - input_t temp_data[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 
0 : local_seq; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < batch_element_count) { - copy_vector(temp_data, src + i*element_count*stride + it*WARP_SIZE); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if ((element_index + element) < batch_element_count) { - elements[i][it+element] = (acc_t)temp_data[element] * scale; - } else { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } - } else { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } + output_t *dst, const input_t *src, const acc_t scale, int micro_batch_size, + int stride, int element_count) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + int first_batch = + (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + + blockIdx.x; + int local_seq = blockIdx.x + 1; + int warp_iteration_limit = + (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1) / WARP_SIZE; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the + // batch + int local_idx = threadIdx.x; + + src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : local_seq; + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + copy_vector( + temp_data, src + i * element_count * stride + it * WARP_SIZE); + +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if ((element_index + element) < batch_element_count) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -std::numeric_limits::infinity(); + } } + } else { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } } + } - // compute max_value - acc_t max_value[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - max_value[i] = (max_value[i] > elements[i][it]) ? 
max_value[i] : elements[i][it]; - } + // compute max_value + acc_t max_value[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = + (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; } - warp_reduce(max_value); - - acc_t sum[WARP_BATCH] { 0.0f }; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - if (it < warp_iteration_limit) { - elements[i][it] = std::exp((elements[i][it] - max_value[i])); - sum[i] += elements[i][it]; - } - } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH]{0.0f}; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + if (it < warp_iteration_limit) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } } - warp_reduce(sum); - - // store result - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < local_seq) { - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < local_seq) { - out[element] = elements[i][it + element] / sum[i]; - } else { - out[element] = 0; - } - } - copy_vector(dst + i * element_count * stride + it * WARP_SIZE, out); - } else if (element_index < element_count) { - copy_zero_vector(dst + i * element_count * stride + it * WARP_SIZE); - } else { - break; - } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < local_seq) { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < local_seq) { + out[element] = elements[i][it + element] / sum[i]; + } else { + out[element] = 0; + } } + copy_vector( + dst + i * element_count * stride + it * WARP_SIZE, out); + } else if (element_index < element_count) { + copy_zero_vector( + dst + i * element_count * stride + it * WARP_SIZE); + } else { + break; + } } + } } -template +template __global__ void scaled_upper_triang_masked_softmax_warp_backward( - output_t *gradInput, - input_t *grad, - const input_t *output, - acc_t scale, - int micro_batch_size, - int stride, - int element_count) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_backward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; - int local_seq = blockIdx.x + 1; - - // micro_batch_size might not be a multiple of WARP_BATCH. 
Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - // the first element to process by the current thread - int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; - - // load data from global memory - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - input_t temp_grad[ELEMENTS_PER_LDG_STG]; - input_t temp_output[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : local_seq; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - copy_vector(temp_grad, grad + i * element_count * stride + it * WARP_SIZE); - copy_vector(temp_output, output + i * element_count * stride + it * WARP_SIZE); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < batch_element_count) { - output_reg[i][it + element] = (acc_t)temp_output[element]; - } - } - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < batch_element_count) { - grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; - } - } - } + output_t *gradInput, input_t *grad, const input_t *output, acc_t scale, + int micro_batch_size, int stride, int element_count) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + int first_batch = + (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + + blockIdx.x; + int local_seq = blockIdx.x + 1; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the + // batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 
0 : local_seq; + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector( + temp_grad, grad + i * element_count * stride + it * WARP_SIZE); + copy_vector( + temp_output, output + i * element_count * stride + it * WARP_SIZE); + +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + output_reg[i][it + element] = (acc_t)temp_output[element]; + } } - } - - acc_t sum[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_reg[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + grad_reg[i][it + element] = + (acc_t)temp_grad[element] * output_reg[i][it + element]; + } } + } } - warp_reduce(sum); - - // store result - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); - } - copy_vector(gradInput + i * element_count * stride + it * WARP_SIZE, out); - } + } + + acc_t sum[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); + +// store result +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = + (output_t)(scale * (grad_reg[i][it + element] - + output_reg[i][it + element] * sum[i])); } + copy_vector( + gradInput + i * element_count * stride + it * WARP_SIZE, out); + } } + } } -} // end of anonymous namespace +} // end of anonymous namespace -template +template void dispatch_scaled_upper_triang_masked_softmax_forward( - output_t *dst, - const input_t *src, - const input_t scale, - int softmax_elements, - int softmax_elements_stride, - int attn_batches) -{ - TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 ); - if (softmax_elements == 0) { - return; - } else { - int log2_elements = log2_ceil(softmax_elements); - const int next_power_of_two = 1 << log2_elements; - int seq_len = softmax_elements; - int batch_count = attn_batches * seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. - int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); - - int blocks_per_seq = attn_batches / batches_per_block; - dim3 blocks(seq_len, blocks_per_seq, 1); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 1: // 2 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 2: // 4 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 3: // 8 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 4: // 16 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 5: // 32 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 6: // 64 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 7: // 128 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 8: // 256 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 9: // 512 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 10: // 1024 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 11: // 2048 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - default: - break; - } + output_t *dst, const input_t *src, const input_t scale, + int softmax_elements, int softmax_elements_stride, int attn_batches) { + TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside + // softmax_warp_forward. + int warp_size = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside + // softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); + + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + default: + break; } + } } -template +template void dispatch_scaled_upper_triang_masked_softmax_backward( - output_t *grad_input, - input_t *grad, - const input_t *output, - const acc_t scale, - int softmax_elements, - int softmax_elements_stride, - int attn_batches) -{ - TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 ); - if (softmax_elements == 0) { - return; - } else { - int log2_elements = log2_ceil(softmax_elements); - const int next_power_of_two = 1 << log2_elements; - int seq_len = softmax_elements; - int batch_count = attn_batches * seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. - int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); - - int blocks_per_seq = attn_batches / batches_per_block; - dim3 blocks(seq_len, blocks_per_seq, 1); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 1: // 2 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 2: // 4 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 3: // 8 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 4: // 16 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 5: // 32 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 6: // 64 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 7: // 128 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 8: // 256 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 9: // 512 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 10: // 1024 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 11: // 2048 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - default: - break; - } + output_t *grad_input, input_t *grad, const input_t *output, + const acc_t scale, int softmax_elements, int softmax_elements_stride, + int attn_batches) { + TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside + // softmax_warp_backward. + int warp_size = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside + // softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); + + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + default: + break; } + } } diff --git a/colossalai/kernel/cuda_native/layer_norm.py b/colossalai/kernel/cuda_native/layer_norm.py index 40355a41ed0d..c7d2a3a45022 100644 --- a/colossalai/kernel/cuda_native/layer_norm.py +++ b/colossalai/kernel/cuda_native/layer_norm.py @@ -18,7 +18,6 @@ class FusedLayerNormAffineFunction(torch.autograd.Function): - @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward(ctx, input, weight, bias, normalized_shape, eps): @@ -30,7 +29,6 @@ def forward(ctx, input, weight, bias, normalized_shape, eps): global layer_norm if layer_norm is None: - layer_norm = LayerNormBuilder().load() output, mean, invvar = layer_norm.forward_affine(input_, ctx.normalized_shape, weight_, bias_, ctx.eps) ctx.layernorm_op = layer_norm @@ -43,17 +41,14 @@ def forward(ctx, input, weight, bias, normalized_shape, eps): def backward(ctx, 
grad_output): input_, weight_, bias_, mean, invvar = ctx.saved_tensors grad_input = grad_weight = grad_bias = None - grad_input, grad_weight, grad_bias \ - = layer_norm.backward_affine( - grad_output.contiguous(), mean, invvar, - input_, ctx.normalized_shape, - weight_, bias_, ctx.eps) + grad_input, grad_weight, grad_bias = layer_norm.backward_affine( + grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, weight_, bias_, ctx.eps + ) return grad_input, grad_weight, grad_bias, None, None class MixedFusedLayerNorm(torch.nn.Module): - def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None): super(MixedFusedLayerNorm, self).__init__() @@ -66,13 +61,11 @@ def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None): self.reset_parameters() def reset_parameters(self): - init.ones_(self.weight) init.zeros_(self.bias) def forward(self, input): - return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias, self.normalized_shape, self.eps) def __repr__(self): - return f'MixedFusedLayerNorm(normalized_shape={self.normalized_shape}, eps={self.eps})' + return f"MixedFusedLayerNorm(normalized_shape={self.normalized_shape}, eps={self.eps})" diff --git a/colossalai/kernel/cuda_native/mha/__init__.py b/colossalai/kernel/cuda_native/mha/__init__.py index 21fddd512957..cad36e598d14 100644 --- a/colossalai/kernel/cuda_native/mha/__init__.py +++ b/colossalai/kernel/cuda_native/mha/__init__.py @@ -1,3 +1,3 @@ from .mha import ColoAttention -__all__ = ['ColoAttention'] +__all__ = ["ColoAttention"] diff --git a/colossalai/kernel/cuda_native/mha/flash_attn_2.py b/colossalai/kernel/cuda_native/mha/flash_attn_2.py index 6a8d74f70c1d..9ee83915b1b4 100644 --- a/colossalai/kernel/cuda_native/mha/flash_attn_2.py +++ b/colossalai/kernel/cuda_native/mha/flash_attn_2.py @@ -8,7 +8,7 @@ def is_ampere_or_better_gpu(): if torch.cuda.is_available(): device = torch.device("cuda") properties = torch.cuda.get_device_properties(device) - if properties.major >= 8: # Ampere GPUs or newer + if properties.major >= 8: # Ampere GPUs or newer return True return False @@ -18,30 +18,33 @@ def is_ampere_or_better_gpu(): if is_ampere_or_better_gpu(): HAS_FLASH_ATTN = True else: - warnings.warn('FlashAttention only supports Ampere GPUs or newer.') + warnings.warn("FlashAttention only supports Ampere GPUs or newer.") HAS_FLASH_ATTN = False try: from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func + HAS_FLASH_ATTN = True except ImportError: - warnings.warn('please install flash_attn from https://github.com/HazyResearch/flash-attention') + warnings.warn("please install flash_attn from https://github.com/HazyResearch/flash-attention") HAS_FLASH_ATTN = False if HAS_FLASH_ATTN: - from einops import rearrange + pass from .utils import SeqLenInfo - def flash_attention(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - seq_len_info_q: SeqLenInfo, - seq_len_info_kv: SeqLenInfo, - bias: Optional[torch.Tensor] = None, - dropout_p: float = 0., - scale: float = None, - causal: bool = False, - padded: bool = False): + def flash_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_len_info_q: SeqLenInfo, + seq_len_info_kv: SeqLenInfo, + bias: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + scale: float = None, + causal: bool = False, + padded: bool = False, + ): """ Arguments: q: (batch, q_seqlen, nheads, headdim) @@ -60,9 +63,18 @@ def flash_attention(q: torch.Tensor, if seq_len_info_kv == None: seq_len_info_kv = seq_len_info_q - 
attn_out = flash_attn_varlen_func(q, k, v, seq_len_info_q.cu_seqlens, seq_len_info_kv.cu_seqlens, - seq_len_info_q.max_seqlen, seq_len_info_kv.max_seqlen, dropout_p, scale, - causal) + attn_out = flash_attn_varlen_func( + q, + k, + v, + seq_len_info_q.cu_seqlens, + seq_len_info_kv.cu_seqlens, + seq_len_info_q.max_seqlen, + seq_len_info_kv.max_seqlen, + dropout_p, + scale, + causal, + ) else: attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal) return attn_out diff --git a/colossalai/kernel/cuda_native/mha/mem_eff_attn.py b/colossalai/kernel/cuda_native/mha/mem_eff_attn.py index 8a898080877c..649e74d61bab 100644 --- a/colossalai/kernel/cuda_native/mha/mem_eff_attn.py +++ b/colossalai/kernel/cuda_native/mha/mem_eff_attn.py @@ -9,9 +9,10 @@ LowerTriangularMask, LowerTriangularMaskWithTensorBias, ) + HAS_MEM_EFF_ATTN = True except ImportError: - warnings.warn('please install xformers from https://github.com/facebookresearch/xformers') + warnings.warn("please install xformers from https://github.com/facebookresearch/xformers") HAS_MEM_EFF_ATTN = False if HAS_MEM_EFF_ATTN: @@ -29,30 +30,30 @@ for op in MemoryEfficientAttentionCutlassOp: allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES) - def mem_eff_attention(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - seq_len_info_q: SeqLenInfo, - seq_len_info_kv: SeqLenInfo, - bias: Optional[torch.Tensor] = None, - dropout_p: float = 0., - scale: float = None, - causal: bool = False, - padded: bool = False): - + def mem_eff_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_len_info_q: SeqLenInfo, + seq_len_info_kv: SeqLenInfo, + bias: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + scale: float = None, + causal: bool = False, + padded: bool = False, + ): attn_bias = None - if padded: # bert style + if padded: # bert style if not causal: attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) else: attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) - elif causal: # gpt style + elif causal: # gpt style attn_bias = LowerTriangularMask() - if bias is not None: # alibi / relative position embedding + if bias is not None: # alibi / relative position embedding assert allow_alibi, "flash attention with bias is not supported in this system." - assert causal, \ - "attention with bias is only supported for causal attention so far." + assert causal, "attention with bias is only supported for causal attention so far." attn_bias = attn_bias.add_bias(bias) if padded: diff --git a/colossalai/kernel/cuda_native/mha/mha.py b/colossalai/kernel/cuda_native/mha/mha.py index 8f449a138c51..1c778439d33f 100644 --- a/colossalai/kernel/cuda_native/mha/mha.py +++ b/colossalai/kernel/cuda_native/mha/mha.py @@ -2,7 +2,6 @@ from typing import Optional import torch -import torch.nn.functional as F from einops import rearrange from ..scaled_softmax import AttnMaskType @@ -17,11 +16,11 @@ class ColoAttention(torch.nn.Module): - def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None): super().__init__() - assert embed_dim % num_heads == 0, \ - f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})." + assert ( + embed_dim % num_heads == 0 + ), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})." 
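# ---------------------------------------------------------------------------
# Editor's sketch (annotation, not part of the patch): the mem_eff_attention
# hunk above picks an xformers attention bias from the (padded, causal) flags.
# The helper below restates that decision table on its own; the function name
# is illustrative and the xformers import path is assumed (adjust to your
# installed version). In the real code an alibi/relative-position bias is
# additionally folded in via `attn_bias.add_bias(bias)` for the causal case.
def select_attn_bias(padded: bool, causal: bool, q_seqlens, kv_seqlens):
    from xformers.ops.fmha.attn_bias import (  # import path assumed
        BlockDiagonalCausalMask,
        BlockDiagonalMask,
        LowerTriangularMask,
    )

    if padded:  # bert style: variable-length sequences, per-sequence block masks
        if causal:
            return BlockDiagonalCausalMask.from_seqlens(q_seqlens, kv_seqlens)
        return BlockDiagonalMask.from_seqlens(q_seqlens, kv_seqlens)
    if causal:  # gpt style: one lower-triangular mask for the whole batch
        return LowerTriangularMask()
    return None  # dense, unmasked attention
# ---------------------------------------------------------------------------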
if scale is not None: self.scale = scale else: @@ -39,14 +38,15 @@ def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor: return Repad.apply(tensor, indices, batch_size, seq_len) - def forward(self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - attn_mask_type: Optional[AttnMaskType] = None, - bias: Optional[torch.Tensor] = None): - + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + attn_mask_type: Optional[AttnMaskType] = None, + bias: Optional[torch.Tensor] = None, + ): attn = None if HAS_FLASH_ATTN and query.dtype in [torch.float16, torch.bfloat16] and bias == None: attn = flash_attention @@ -62,18 +62,21 @@ def forward(self, seq_len_info_kv = None if padded: # bert style, unpad process - assert attn_mask is not None, \ - f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}." - assert attn_mask.dim() == 2, \ - "attention mask is supposed to have shape (batch_size, seq_len), " + \ - f"but got {attn_mask.dim()} dimensions." + assert ( + attn_mask is not None + ), f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}." + assert attn_mask.dim() == 2, ( + "attention mask is supposed to have shape (batch_size, seq_len), " + + f"but got {attn_mask.dim()} dimensions." + ) # bert style if tgt_len == src_len: seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device) if batch_size > 1: - query, key, value = self.unpad(torch.stack([query, key, value], dim=2), - seq_len_info_q.indices).unbind(dim=1) + query, key, value = self.unpad( + torch.stack([query, key, value], dim=2), seq_len_info_q.indices + ).unbind(dim=1) else: query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1) seq_len_info_kv = seq_len_info_q @@ -82,26 +85,29 @@ def forward(self, seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device) if batch_size > 1: query = rearrange(query, "b s ... -> c (b s) ...", c=1) - key, value = self.unpad(torch.stack([query, key, value], dim=2), - seq_len_info_kv.indices).unbind(dim=1) + key, value = self.unpad(torch.stack([query, key, value], dim=2), seq_len_info_kv.indices).unbind( + dim=1 + ) else: query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1) - out = attn(query, - key, - value, - seq_len_info_q, - seq_len_info_kv, - dropout_p=self.dropout, - scale=self.scale, - causal=causal, - padded=padded) + out = attn( + query, + key, + value, + seq_len_info_q, + seq_len_info_kv, + dropout_p=self.dropout, + scale=self.scale, + causal=causal, + padded=padded, + ) # repad if padded: if batch_size > 1: out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len) - out = rearrange(out, '(b s) h d -> b s h d', b=batch_size) + out = rearrange(out, "(b s) h d -> b s h d", b=batch_size) - out = rearrange(out, 'b s h d -> b s (h d)') + out = rearrange(out, "b s h d -> b s (h d)") return out diff --git a/colossalai/kernel/cuda_native/mha/utils.py b/colossalai/kernel/cuda_native/mha/utils.py index e3e431fa7e99..fe31921b961b 100644 --- a/colossalai/kernel/cuda_native/mha/utils.py +++ b/colossalai/kernel/cuda_native/mha/utils.py @@ -20,18 +20,18 @@ def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor): # [b, s, ...] 
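# ---------------------------------------------------------------------------
# Editor's sketch (annotation, not part of the patch): Unpad gathers only the
# non-padding token rows of a (batch, seq_len, ...) tensor and Repad scatters
# them back, which is what the autograd Functions in this file implement with
# gradient support. A minimal functional restatement (names are illustrative,
# `attn_mask` is assumed to be a 0/1 tensor of shape (batch, seq_len)):
import torch

def unpad_demo(tensor: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
    indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten()
    flat = tensor.reshape(-1, *tensor.shape[2:])  # (b * s, ...)
    return flat[indices]  # (ntokens, ...)

def repad_demo(tokens: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
    b, s = attn_mask.shape
    indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten()
    out = torch.zeros(b * s, *tokens.shape[1:], dtype=tokens.dtype, device=tokens.device)
    out[indices] = tokens  # scatter valid tokens back to their padded positions
    return out.reshape(b, s, *tokens.shape[1:])
# ---------------------------------------------------------------------------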
assert tensor.ndim >= 3 ctx.bsz = tensor.shape[0] - out = rearrange(tensor, 'b s ... -> (b s) ...') + out = rearrange(tensor, "b s ... -> (b s) ...") ctx.shape = out.shape # [ntokens, ...] return out[indices] @staticmethod def backward(ctx, grad_output): - indices, = ctx.saved_tensors + (indices,) = ctx.saved_tensors # [ntokens, ...] grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device) grad[indices] = grad_output - grad = rearrange(grad, '(b s) ... -> b s ...', b=ctx.bsz) + grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz) # [b, s, ...] return grad, None @@ -54,7 +54,7 @@ def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, s @staticmethod def backward(ctx, grad_output): - indices, = ctx.saved_tensors + (indices,) = ctx.saved_tensors # [b*s, ...] grad = grad_output[indices] # [ntokens, ...] diff --git a/colossalai/kernel/cuda_native/multihead_attention.py b/colossalai/kernel/cuda_native/multihead_attention.py index 69246f2f3854..87afc1862847 100644 --- a/colossalai/kernel/cuda_native/multihead_attention.py +++ b/colossalai/kernel/cuda_native/multihead_attention.py @@ -36,34 +36,64 @@ def calc_offset(sizes): @dataclass class Config: - max_batch_tokens: int # max batch token numbers - max_seq_len: int # max sequence length - hidden_size: int # size of transformer hidden layers - nhead: int # number of heads in attention - attn_prob_dropout_ratio: float # attention score dropout ratio - hidden_dropout_ratio: float # dropout ration before residual - norm_first: bool # norm_first - fp16: bool # fp16 precision + max_batch_tokens: int # max batch token numbers + max_seq_len: int # max sequence length + hidden_size: int # size of transformer hidden layers + nhead: int # number of heads in attention + attn_prob_dropout_ratio: float # attention score dropout ratio + hidden_dropout_ratio: float # dropout ration before residual + norm_first: bool # norm_first + fp16: bool # fp16 precision class MultiHeadAttention1DFunc(Function): - @staticmethod - def forward(ctx, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, out_proj_bias, norm_weight, - norm_bias, config): + def forward( + ctx, + input, + input_mask, + in_proj_weight, + in_proj_bias, + out_proj_weight, + out_proj_bias, + norm_weight, + norm_bias, + config, + ): cuda_module = colossal_multihead_attention - forward_func = (cuda_module.multihead_attention_fw_fp16 - if config.fp16 else cuda_module.multihead_attention_fw_fp32) + forward_func = ( + cuda_module.multihead_attention_fw_fp16 if config.fp16 else cuda_module.multihead_attention_fw_fp32 + ) if config.fp16: input = input.to(torch.half) input_mask = input_mask.to(torch.half) - (output,) = forward_func(config.layer_id, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, - out_proj_bias, norm_weight, norm_bias, config.training, config.norm_first) + (output,) = forward_func( + config.layer_id, + input, + input_mask, + in_proj_weight, + in_proj_bias, + out_proj_weight, + out_proj_bias, + norm_weight, + norm_bias, + config.training, + config.norm_first, + ) if config.is_grad_enabled and config.training: - ctx.save_for_backward(output, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, - out_proj_bias, norm_weight, norm_bias) + ctx.save_for_backward( + output, + input, + input_mask, + in_proj_weight, + in_proj_bias, + out_proj_weight, + out_proj_bias, + norm_weight, + norm_bias, + ) ctx.config = config return output @@ -72,11 +102,21 @@ def backward(ctx, grad_output): assert 
ctx.config.training cuda_module = colossal_multihead_attention - backward_func = (cuda_module.multihead_attention_bw_fp16 - if ctx.config.fp16 else cuda_module.multihead_attention_bw_fp32) + backward_func = ( + cuda_module.multihead_attention_bw_fp16 if ctx.config.fp16 else cuda_module.multihead_attention_bw_fp32 + ) - output, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, \ - out_proj_bias, norm_weight, norm_bias = ctx.saved_tensors + ( + output, + input, + input_mask, + in_proj_weight, + in_proj_bias, + out_proj_weight, + out_proj_bias, + norm_weight, + norm_bias, + ) = ctx.saved_tensors grad_input = None grad_in_proj_weight = None @@ -91,13 +131,39 @@ def backward(ctx, grad_output): output = output.to(torch.half) input = input.to(torch.half) input_mask = input_mask.to(torch.half) - grad_input, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, \ - grad_out_proj_bias, grad_norm_weight, grad_norm_bias = backward_func( - ctx.config.layer_id, grad_output, output, input, input_mask, in_proj_weight, - in_proj_bias, out_proj_weight, out_proj_bias, norm_weight, norm_bias) + ( + grad_input, + grad_in_proj_weight, + grad_in_proj_bias, + grad_out_proj_weight, + grad_out_proj_bias, + grad_norm_weight, + grad_norm_bias, + ) = backward_func( + ctx.config.layer_id, + grad_output, + output, + input, + input_mask, + in_proj_weight, + in_proj_bias, + out_proj_weight, + out_proj_bias, + norm_weight, + norm_bias, + ) - return (grad_input, None, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, grad_out_proj_bias, - grad_norm_weight, grad_norm_bias, None) + return ( + grad_input, + None, + grad_in_proj_weight, + grad_in_proj_bias, + grad_out_proj_weight, + grad_out_proj_bias, + grad_norm_weight, + grad_norm_bias, + None, + ) class MultiHeadAttention(nn.Module): @@ -122,8 +188,9 @@ class MultiHeadAttention(nn.Module): def __init__(self, hidden_size, nhead, batch_size, max_seq_len, dropout=0.0, norm_first=False, fp16=True, pg=None): super(MultiHeadAttention, self).__init__() - self.config = Config(batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout, dropout, norm_first, - fp16) + self.config = Config( + batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout, dropout, norm_first, fp16 + ) check_config(self.config) self.pg = pg self.pg_size = 1 @@ -136,13 +203,17 @@ def __init__(self, hidden_size, nhead, batch_size, max_seq_len, dropout=0.0, nor global colossal_multihead_attention if colossal_multihead_attention is None: from colossalai.kernel.op_builder import MultiHeadAttnBuilder + multihead_attention = MultiHeadAttnBuilder().load() colossal_multihead_attention = multihead_attention # create the layer in cuda kernels. 
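# ---------------------------------------------------------------------------
# Editor's sketch (annotation, not part of the patch): the block above builds
# the fused CUDA extension once via `MultiHeadAttnBuilder().load()` and caches
# it in a module-level global so later layers reuse it. The same lazy-load
# pattern, restated generically (`builder_factory` is a stand-in parameter,
# not a real API):
_cached_ext = None

def load_extension_once(builder_factory):
    """Build/JIT-load an op extension on first use, then reuse the cached module."""
    global _cached_ext
    if _cached_ext is None:
        _cached_ext = builder_factory().load()
    return _cached_ext
# ---------------------------------------------------------------------------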
cuda_module = colossal_multihead_attention - create_layer_func = (cuda_module.create_multihead_attention_fp16 - if self.config.fp16 else cuda_module.create_multihead_attention_fp32) + create_layer_func = ( + cuda_module.create_multihead_attention_fp16 + if self.config.fp16 + else cuda_module.create_multihead_attention_fp32 + ) create_layer_func( self.config.layer_id, @@ -204,13 +275,15 @@ def reset_parameters(self): with torch.no_grad(): self.in_proj_weight.copy_( - attn_qkvw_global.view(3, hs, hs)[:, - int(hs * rank_in_pg / self.pg_size):int(hs * (rank_in_pg + 1) / - self.pg_size), :]) + attn_qkvw_global.view(3, hs, hs)[ + :, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size), : + ] + ) self.in_proj_bias.copy_( - attn_qkvb_global.view(3, hs)[:, - int(hs * rank_in_pg / self.pg_size):int(hs * (rank_in_pg + 1) / - self.pg_size)]) + attn_qkvb_global.view(3, hs)[ + :, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size) + ] + ) attn_ow_global = torch.empty(hs, hs) nn.init.xavier_uniform_(attn_ow_global, 1.0) @@ -218,9 +291,9 @@ def reset_parameters(self): torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg) attn_ow_global = attn_ow_global.cpu() with torch.no_grad(): - self.out_proj_weight.copy_(attn_ow_global[:, - int(hs * rank_in_pg / - self.pg_size):int(hs * (rank_in_pg + 1) / self.pg_size)]) + self.out_proj_weight.copy_( + attn_ow_global[:, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size)] + ) else: attn_qkvw = self.in_proj_weight.view(-1, hs) @@ -238,7 +311,7 @@ def forward(self, hidden_states, encoder_padding_mask): self.config.training = self.training self.config.is_grad_enabled = torch.is_grad_enabled() hidden_states = hidden_states.contiguous() - encoder_padding_mask = ((encoder_padding_mask * -1e8).type_as(hidden_states).contiguous()) + encoder_padding_mask = (encoder_padding_mask * -1e8).type_as(hidden_states).contiguous() bs, sl, dim = hidden_states.size() if bs * sl > self.config.max_batch_tokens: @@ -250,8 +323,16 @@ def forward(self, hidden_states, encoder_padding_mask): else: assert bs == encoder_padding_mask.size(0) and sl == encoder_padding_mask.size(1) - output = MultiHeadAttention1DFunc.apply(hidden_states, encoder_padding_mask, self.in_proj_weight, - self.in_proj_bias, self.out_proj_weight, self.out_proj_bias, - self.norm_weight, self.norm_bias, self.config) + output = MultiHeadAttention1DFunc.apply( + hidden_states, + encoder_padding_mask, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj_weight, + self.out_proj_bias, + self.norm_weight, + self.norm_bias, + self.config, + ) return output.to(self.precision) diff --git a/colossalai/kernel/cuda_native/scaled_softmax.py b/colossalai/kernel/cuda_native/scaled_softmax.py index 41cd4b20faa1..26a5bce16d5c 100644 --- a/colossalai/kernel/cuda_native/scaled_softmax.py +++ b/colossalai/kernel/cuda_native/scaled_softmax.py @@ -108,15 +108,16 @@ def __init__( super(FusedScaleMaskSoftmax, self).__init__() self.input_in_fp16 = input_in_fp16 self.input_in_bf16 = input_in_bf16 - assert not (self.input_in_fp16 - and self.input_in_bf16), "both fp16 and bf16 flags cannot be active at the same time." + assert not ( + self.input_in_fp16 and self.input_in_bf16 + ), "both fp16 and bf16 flags cannot be active at the same time." 
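# ---------------------------------------------------------------------------
# Editor's sketch (annotation, not part of the patch): in
# MultiHeadAttention.reset_parameters earlier in this patch, every
# tensor-parallel rank initializes (and broadcasts) the full QKV weight and
# then keeps only its own column slice, so the shards across ranks reassemble
# one consistently initialized matrix. The slicing rule, restated with
# illustrative names (expects attn_qkvw_global of shape (3 * hs, hs)):
import torch

def slice_qkv_for_rank(attn_qkvw_global: torch.Tensor, hs: int, rank: int, world_size: int) -> torch.Tensor:
    start = int(hs * rank / world_size)      # first column owned by this rank
    end = int(hs * (rank + 1) / world_size)  # one past the last owned column
    return attn_qkvw_global.view(3, hs, hs)[:, start:end, :].contiguous()
# ---------------------------------------------------------------------------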
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 self.attn_mask_type = attn_mask_type self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion self.mask_func = mask_func self.softmax_in_fp32 = softmax_in_fp32 self.scale = scale - assert (self.scale is None or softmax_in_fp32), "softmax should be in fp32 when scaled" + assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled" def forward(self, input, mask): # [b, np, sq, sk] @@ -130,13 +131,14 @@ def forward(self, input, mask): def is_kernel_available(self, mask, b, np, sq, sk): attn_batches = b * np - if (self.scaled_masked_softmax_fusion # user want to fuse - and self.input_in_float16 # input must be fp16 - and mask is not None # mask tensor must not be None - and 16 < sk <= 2048 # sk must be 16 ~ 2048 - and sq % 4 == 0 # sq must be divisor of 4 - and attn_batches % 4 == 0 # np * b must be divisor of 4 - ): + if ( + self.scaled_masked_softmax_fusion # user want to fuse + and self.input_in_float16 # input must be fp16 + and mask is not None # mask tensor must not be None + and 16 < sk <= 2048 # sk must be 16 ~ 2048 + and sq % 4 == 0 # sq must be divisor of 4 + and attn_batches % 4 == 0 # np * b must be divisor of 4 + ): if 0 <= sk <= 2048: batch_per_block = self.get_batch_per_block(sq, sk, b, np) diff --git a/colossalai/kernel/jit/__init__.py b/colossalai/kernel/jit/__init__.py index 57b8fb7b2e99..67a147cd581c 100644 --- a/colossalai/kernel/jit/__init__.py +++ b/colossalai/kernel/jit/__init__.py @@ -1,8 +1,10 @@ -from .option import set_jit_fusion_options -from .bias_dropout_add import bias_dropout_add_fused_train, bias_dropout_add_fused_inference +from .bias_dropout_add import bias_dropout_add_fused_inference, bias_dropout_add_fused_train from .bias_gelu import bias_gelu_impl +from .option import set_jit_fusion_options __all__ = [ - "bias_dropout_add_fused_train", "bias_dropout_add_fused_inference", "bias_gelu_impl", - "set_jit_fusion_options" + "bias_dropout_add_fused_train", + "bias_dropout_add_fused_inference", + "bias_gelu_impl", + "set_jit_fusion_options", ] diff --git a/colossalai/kernel/jit/bias_dropout_add.py b/colossalai/kernel/jit/bias_dropout_add.py index 32965c1ebd69..e046ee2964af 100644 --- a/colossalai/kernel/jit/bias_dropout_add.py +++ b/colossalai/kernel/jit/bias_dropout_add.py @@ -1,5 +1,4 @@ import torch -from torch import Tensor def bias_dropout_add(x, bias, residual, prob, training): @@ -10,16 +9,14 @@ def bias_dropout_add(x, bias, residual, prob, training): @torch.jit.script -def bias_dropout_add_fused_train(x: torch.Tensor, - bias: torch.Tensor, - residual: torch.Tensor, - prob: float) -> torch.Tensor: +def bias_dropout_add_fused_train( + x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float +) -> torch.Tensor: return bias_dropout_add(x, bias, residual, prob, True) @torch.jit.script -def bias_dropout_add_fused_inference(x: torch.Tensor, - bias: torch.Tensor, - residual: torch.Tensor, - prob: float) -> torch.Tensor: +def bias_dropout_add_fused_inference( + x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float +) -> torch.Tensor: return bias_dropout_add(x, bias, residual, prob, False) diff --git a/colossalai/kernel/jit/bias_gelu.py b/colossalai/kernel/jit/bias_gelu.py index 33b4ac32b044..5fa0d07015be 100644 --- a/colossalai/kernel/jit/bias_gelu.py +++ b/colossalai/kernel/jit/bias_gelu.py @@ -29,7 +29,6 @@ def bias_gelu_back(g, bias, y): class GeLUFunction(torch.autograd.Function): - @staticmethod # bias is an optional argument def 
forward(ctx, input, bias): diff --git a/colossalai/kernel/jit/option.py b/colossalai/kernel/jit/option.py index 8eb4e0c880a0..8bebad894ca4 100644 --- a/colossalai/kernel/jit/option.py +++ b/colossalai/kernel/jit/option.py @@ -10,15 +10,14 @@ def set_jit_fusion_options(): - """Set PyTorch JIT layer fusion options. - """ + """Set PyTorch JIT layer fusion options.""" # LSG: the latest pytorch and CUDA versions may not support # the following jit settings global JIT_OPTIONS_SET if JIT_OPTIONS_SET == False: # flags required to enable jit fusion kernels - TORCH_MAJOR = int(torch.__version__.split('.')[0]) - TORCH_MINOR = int(torch.__version__.split('.')[1]) + TORCH_MAJOR = int(torch.__version__.split(".")[0]) + TORCH_MINOR = int(torch.__version__.split(".")[1]) if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10): # nvfuser torch._C._jit_set_profiling_executor(True) @@ -38,12 +37,14 @@ def set_jit_fusion_options(): JIT_OPTIONS_SET = True -def warmup_jit_fusion(batch_size: int, - hidden_size: int, - seq_length: int = 512, - vocab_size: int = 32768, - dtype: torch.dtype = torch.float32): - """ Compile JIT functions before the main training steps """ +def warmup_jit_fusion( + batch_size: int, + hidden_size: int, + seq_length: int = 512, + vocab_size: int = 32768, + dtype: torch.dtype = torch.float32, +): + """Compile JIT functions before the main training steps""" embed = Embedding(vocab_size, hidden_size).to(get_current_device()) linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_current_device()) diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py index 75812db036a9..bc68a07e6fba 100644 --- a/colossalai/kernel/triton/__init__.py +++ b/colossalai/kernel/triton/__init__.py @@ -1,5 +1,6 @@ try: import triton + HAS_TRITON = True from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd @@ -11,8 +12,14 @@ from .token_attention_kernel import token_attention_fwd __all__ = [ - "llama_context_attn_fwd", "bloom_context_attn_fwd", "softmax", "layer_norm", "rmsnorm_forward", - "copy_kv_cache_to_dest", "rotary_embedding_fwd", "token_attention_fwd" + "llama_context_attn_fwd", + "bloom_context_attn_fwd", + "softmax", + "layer_norm", + "rmsnorm_forward", + "copy_kv_cache_to_dest", + "rotary_embedding_fwd", + "token_attention_fwd", ] except ImportError: diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py index 38db2048c6a4..dac95bfb14ae 100644 --- a/colossalai/kernel/triton/context_attention.py +++ b/colossalai/kernel/triton/context_attention.py @@ -1,8 +1,11 @@ -import torch import math + +import torch + try: import triton import triton.language as tl + HAS_TRITON = True except ImportError: HAS_TRITON = False @@ -10,28 +13,42 @@ if HAS_TRITON: - ''' - this function is modified from - https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10 - ''' + """ + this function is modified from + https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10 + """ + @triton.jit def _context_flash_attention_kernel( - Q, K, V, sm_scale, - B_Start_Loc, B_Seqlen, - TMP, + Q, + K, + V, + sm_scale, + B_Start_Loc, + B_Seqlen, + TMP, alibi_ptr, Out, - stride_qbs, stride_qh, stride_qd, - stride_kbs, stride_kh, stride_kd, - stride_vbs, stride_vh, stride_vd, - stride_obs, stride_oh, stride_od, - 
stride_tmp_b, stride_tmp_h, stride_tmp_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_tmp_b, + stride_tmp_h, + stride_tmp_s, # suggtest set-up 64, 128, 256, 512 - BLOCK_M: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, ): - batch_id = tl.program_id(0) cur_head = tl.program_id(1) start_m = tl.program_id(2) @@ -40,13 +57,18 @@ def _context_flash_attention_kernel( offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - - # get batch info + + # get batch info cur_batch_seq_len = tl.load(B_Seqlen + batch_id) cur_batch_start_index = tl.load(B_Start_Loc + batch_id) block_start_loc = BLOCK_M * start_m - - load_p_ptrs = Q + (cur_batch_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd + + load_p_ptrs = ( + Q + + (cur_batch_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd + ) q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd @@ -56,7 +78,7 @@ def _context_flash_attention_kernel( m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - + if alibi_ptr is not None: alibi_m = tl.load(alibi_ptr + cur_head) @@ -64,8 +86,11 @@ def _context_flash_attention_kernel( for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) - k = tl.load(k_ptrs + (cur_batch_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0) + k = tl.load( + k_ptrs + (cur_batch_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, + other=0.0, + ) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) @@ -95,21 +120,25 @@ def _context_flash_attention_kernel( acc_scale = tl.load(t_ptrs) acc = acc * acc_scale[:, None] # update acc - v = tl.load(v_ptrs + (cur_batch_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0) + v = tl.load( + v_ptrs + (cur_batch_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, + other=0.0, + ) p = p.to(v.dtype) acc += tl.dot(p, v) # update m_i and l_i l_i = l_i_new m_i = m_i_new - - off_o = (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od + + off_o = ( + (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od + ) out_ptrs = Out + off_o tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) return - - + @torch.no_grad() def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, alibi=None): BLOCK = 128 @@ -129,17 +158,31 @@ def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, al tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) _context_flash_attention_kernel[grid]( - q, k, v, sm_scale, - b_start_loc, b_seq_len, + q, + k, + v, + sm_scale, + b_start_loc, + b_seq_len, tmp, alibi, o, - q.stride(0), q.stride(1), q.stride(2), - k.stride(0), k.stride(1), k.stride(2), - v.stride(0), v.stride(1), v.stride(2), - o.stride(0), 
o.stride(1), o.stride(2), - tmp.stride(0), tmp.stride(1), tmp.stride(2), - # manually setting this blcok num, we can use tuning config to futher speed-up + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + tmp.stride(0), + tmp.stride(1), + tmp.stride(2), + # manually setting this blcok num, we can use tuning config to futher speed-up BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK, @@ -147,7 +190,7 @@ def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, al num_stages=1, ) return - + @torch.no_grad() def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): BLOCK = 128 @@ -166,19 +209,34 @@ def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): num_warps = 4 if Lk <= 64 else 8 # num_warps = 4 _context_flash_attention_kernel[grid]( - q, k, v, sm_scale, b_start_loc, b_seq_len, + q, + k, + v, + sm_scale, + b_start_loc, + b_seq_len, tmp, None, o, - q.stride(0), q.stride(1), q.stride(2), - k.stride(0), k.stride(1), k.stride(2), - v.stride(0), v.stride(1), v.stride(2), - o.stride(0), o.stride(1), o.stride(2), - tmp.stride(0), tmp.stride(1), tmp.stride(2), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + tmp.stride(0), + tmp.stride(1), + tmp.stride(2), BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK, num_warps=num_warps, num_stages=1, ) - return \ No newline at end of file + return diff --git a/colossalai/kernel/triton/copy_kv_cache_dest.py b/colossalai/kernel/triton/copy_kv_cache_dest.py index c1eaa8a10ed1..02edcc9a903a 100644 --- a/colossalai/kernel/triton/copy_kv_cache_dest.py +++ b/colossalai/kernel/triton/copy_kv_cache_dest.py @@ -3,25 +3,28 @@ try: import triton import triton.language as tl + HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") if HAS_TRITON: + @triton.jit def _fwd_copy_kv_cache_dest( - kv_cache_ptr, dest_index_ptr, + kv_cache_ptr, + dest_index_ptr, out, - stride_k_bs, - stride_k_h, + stride_k_bs, + stride_k_h, stride_k_d, - stride_o_bs, - stride_o_h, + stride_o_bs, + stride_o_h, stride_o_d, head_num, BLOCK_DMODEL: tl.constexpr, - BLOCK_HEAD: tl.constexpr + BLOCK_HEAD: tl.constexpr, ): cur_index = tl.program_id(0) offs_h = tl.arange(0, BLOCK_HEAD) @@ -31,15 +34,14 @@ def _fwd_copy_kv_cache_dest( cache_offsets = stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :] k_ptrs = kv_cache_ptr + cur_index * stride_k_bs + cache_offsets - + o_offsets = stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :] o_ptrs = out + dest_index * stride_o_bs + o_offsets k = tl.load(k_ptrs, mask=offs_h[:, None] < head_num, other=0.0) tl.store(o_ptrs, k, mask=offs_h[:, None] < head_num) return - - + @torch.no_grad() def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out): seq_len = dest_index_ptr.shape[0] @@ -47,16 +49,18 @@ def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out): head_dim = k_ptr.shape[2] assert head_num == out.shape[1], "head_num should be the same for k_ptr and out" assert head_dim == out.shape[2], "head_dim should be the same for k_ptr and out" - + num_warps = 2 _fwd_copy_kv_cache_dest[(seq_len,)]( - k_ptr, dest_index_ptr, out, - k_ptr.stride(0), - k_ptr.stride(1), + k_ptr, + dest_index_ptr, + out, + k_ptr.stride(0), + k_ptr.stride(1), k_ptr.stride(2), - out.stride(0), - 
out.stride(1), + out.stride(0), + out.stride(1), out.stride(2), head_num, BLOCK_DMODEL=head_dim, @@ -65,5 +69,3 @@ def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out): num_stages=2, ) return - - diff --git a/colossalai/kernel/triton/fused_layernorm.py b/colossalai/kernel/triton/fused_layernorm.py index 99800acfbb92..24083b050808 100644 --- a/colossalai/kernel/triton/fused_layernorm.py +++ b/colossalai/kernel/triton/fused_layernorm.py @@ -3,6 +3,7 @@ try: import triton import triton.language as tl + HAS_TRITON = True except ImportError: HAS_TRITON = False @@ -14,13 +15,13 @@ @triton.jit def _layer_norm_fwd_fused( - X, # pointer to the input - Y, # pointer to the output - W, # pointer to the weights - B, # pointer to the biases - stride, # how much to increase the pointer when moving by 1 row - N, # number of columns in X - eps, # epsilon to avoid division by zero + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + eps, # epsilon to avoid division by zero BLOCK_SIZE: tl.constexpr, ): # Map the program id to the row of X and Y it should compute. @@ -32,15 +33,15 @@ def _layer_norm_fwd_fused( _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) for off in range(0, N, BLOCK_SIZE): cols = off + tl.arange(0, BLOCK_SIZE) - a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + a = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) _mean += a mean = tl.sum(_mean, axis=0) / N # Compute variance _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) for off in range(0, N, BLOCK_SIZE): cols = off + tl.arange(0, BLOCK_SIZE) - x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) - x = tl.where(cols < N, x - mean, 0.) 
+ x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + x = tl.where(cols < N, x - mean, 0.0) _var += x * x var = tl.sum(_var, axis=0) / N rstd = 1 / tl.sqrt(var + eps) @@ -50,7 +51,7 @@ def _layer_norm_fwd_fused( mask = cols < N w = tl.load(W + cols, mask=mask) b = tl.load(B + cols, mask=mask) - x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32) + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) x_hat = (x - mean) * rstd y = x_hat * w + b # Write output @@ -71,13 +72,7 @@ def layer_norm(x, weight, bias, eps): # heuristics for number of warps num_warps = min(max(BLOCK_SIZE // 256, 1), 8) # enqueue kernel - _layer_norm_fwd_fused[(M,)](x_arg, - y, - weight, - bias, - x_arg.stride(0), - N, - eps, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=num_warps) + _layer_norm_fwd_fused[(M,)]( + x_arg, y, weight, bias, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps + ) return y diff --git a/colossalai/kernel/triton/qkv_matmul_kernel.py b/colossalai/kernel/triton/qkv_matmul_kernel.py index 62fc6bba0360..7b5cd2923f0e 100644 --- a/colossalai/kernel/triton/qkv_matmul_kernel.py +++ b/colossalai/kernel/triton/qkv_matmul_kernel.py @@ -1,7 +1,7 @@ -import torch try: import triton import triton.language as tl + HAS_TRITON = True except ImportError: HAS_TRITON = False @@ -9,9 +9,10 @@ if HAS_TRITON: - ''' + """ this kernel function is modified from https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html - ''' + """ + @triton.jit def qkv_gemm_4d_kernel( a_ptr, @@ -34,12 +35,12 @@ def qkv_gemm_4d_kernel( stride_cn, scale, # Meta-parameters - BLOCK_SIZE_M : tl.constexpr = 64, - BLOCK_SIZE_N : tl.constexpr = 32, - BLOCK_SIZE_K : tl.constexpr = 32, - GROUP_SIZE_M : tl.constexpr = 8, + BLOCK_SIZE_M: tl.constexpr = 64, + BLOCK_SIZE_N: tl.constexpr = 32, + BLOCK_SIZE_K: tl.constexpr = 32, + GROUP_SIZE_M: tl.constexpr = 8, ): - r""" A kernel function which is used to do batch-matmul for Q*K^T or score_matrix * V for attention layer, + r"""A kernel function which is used to do batch-matmul for Q*K^T or score_matrix * V for attention layer, where score_matrix is softmax(Q*V^T/sqrt(hidden_size)) Args: a_ptr(torch.Tensor): pointer to input tensor array (bs, M, h, K) or (bs, h, M, K) @@ -53,21 +54,21 @@ def qkv_gemm_4d_kernel( stride_bh(tl.constexpr): stride for h-dimention for tensor array B stride_bk(tl.constexpr): stride for k-dimention for tensor array B stride_bn(tl.constexpr): stride for n-dimention for tensor array B - stride_cb(tl.constexpr): stride for bs-dimention for tensor array output + stride_cb(tl.constexpr): stride for bs-dimention for tensor array output stride_ch(tl.constexpr): stride for h-dimention for tensor array output stride_cm(tl.constexpr): stride for m-dimention for tensor array output stride_cn(tl.constexpr): stride for n-dimention for tensor array output BLOCK_SIZE_M : tiling size for M-dimension of tensor Array a BLOCK_SIZE_N : tiling size for N-dimension of tensor Array b BLOCK_SIZE_K : tiling size for K-dimension of a and b - GROUP_SIZE_M : group size for reducing cache miss, more details: + GROUP_SIZE_M : group size for reducing cache miss, more details: """ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - batch = tl.program_id(axis = 0) - head = tl.program_id(axis = 1) - pid = tl.program_id(axis = 2) + batch = tl.program_id(axis=0) + head = tl.program_id(axis=1) + pid = tl.program_id(axis=2) # the following is from tutorial: 
https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html num_pid_in_group = GROUP_SIZE_M * num_pid_n @@ -77,33 +78,38 @@ def qkv_gemm_4d_kernel( pid_m = first_pid_m + (pid % group_size_m) pid_n = (pid % num_pid_in_group) // group_size_m - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = (a_ptr + batch * stride_ab + head * stride_ah + - (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)) - b_ptrs = (b_ptr + batch * stride_bb + head * stride_bh + - (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)) + a_ptrs = ( + a_ptr + batch * stride_ab + head * stride_ah + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + ) + b_ptrs = ( + b_ptr + batch * stride_bb + head * stride_bh + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + ) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, K, BLOCK_SIZE_K): a_mask = (offs_am[:, None] < M) & (offs_k[None, :] + k < K) b_mask = (offs_k[:, None] + k < K) & (offs_bn[None, :] < N) - a = tl.load(a_ptrs, mask=a_mask, other=0.) - b = tl.load(b_ptrs, mask=b_mask, other=0.) + a = tl.load(a_ptrs, mask=a_mask, other=0.0) + b = tl.load(b_ptrs, mask=b_mask, other=0.0) accumulator += tl.dot(a, b) a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk - + accumulator = accumulator.to(c_ptr.dtype.element_ty) if scale > 0: accumulator = accumulator * scale.to(c_ptr.dtype.element_ty) - offs_accumu_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_accumu_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = (c_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_accumu_m[:, None] + - stride_cn * offs_accumu_n[None, :]) + c_ptrs = ( + c_ptr + + batch * stride_cb + + head * stride_ch + + stride_cm * offs_accumu_m[:, None] + + stride_cn * offs_accumu_n[None, :] + ) accumulator_mask = (offs_accumu_m[:, None] < M) & (offs_accumu_n[None, :] < N) tl.store(c_ptrs, accumulator, mask=accumulator_mask) diff --git a/colossalai/kernel/triton/rms_norm.py b/colossalai/kernel/triton/rms_norm.py index 1fb79115f8ce..d5d6f9d85df1 100644 --- a/colossalai/kernel/triton/rms_norm.py +++ b/colossalai/kernel/triton/rms_norm.py @@ -3,17 +3,19 @@ try: import triton import triton.language as tl + HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") - + if HAS_TRITON: - ''' - this kernel function is modified from - https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/rmsnorm.py - ''' + """ + this kernel function is modified from + https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/rmsnorm.py + """ + @triton.jit def _rms_norm_fwd_fused( X, # pointer to the input @@ -32,7 +34,7 @@ def _rms_norm_fwd_fused( _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) for off in range(0, N, BLOCK_SIZE): cols = off + tl.arange(0, BLOCK_SIZE) - x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) _var += x * x var = tl.sum(_var, axis=0) / N rstd = 1 / tl.sqrt(var + eps) @@ -41,13 +43,12 @@ def _rms_norm_fwd_fused( cols = off + tl.arange(0, BLOCK_SIZE) mask = cols < N w = tl.load(W + cols, mask=mask).to(tl.float32) - x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32) + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) x_hat = x * rstd y = 
x_hat * w # Write output tl.store(Y + cols, y.to(tl.float16), mask=mask) - def rmsnorm_forward(x, weight, eps): # allocate output y = torch.empty_like(x) @@ -66,7 +67,5 @@ def rmsnorm_forward(x, weight, eps): BLOCK_SIZE = 128 * 2 * 2 * 2 * 2 * 2 * 2 * 2 num_warps = 8 # enqueue kernel - _rms_norm_fwd_fused[(M,)](x_arg, y, weight, - x_arg.stride(0), N, eps, - BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps) + _rms_norm_fwd_fused[(M,)](x_arg, y, weight, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps) return y diff --git a/colossalai/kernel/triton/rotary_embedding_kernel.py b/colossalai/kernel/triton/rotary_embedding_kernel.py index d9d1b2bcf026..eb43fab7935c 100644 --- a/colossalai/kernel/triton/rotary_embedding_kernel.py +++ b/colossalai/kernel/triton/rotary_embedding_kernel.py @@ -29,19 +29,29 @@ def _rotary_kernel( dim_range0 = tl.arange(0, HEAD_DIM // 2) dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM) - off_q0 = current_seq_range[:, None, None] * q_bs_stride + current_head_range[ - None, :, None] * q_h_stride + dim_range0[None, None, :] * q_d_stride - off_q1 = current_seq_range[:, None, None] * q_bs_stride + current_head_range[ - None, :, None] * q_h_stride + dim_range1[None, None, :] * q_d_stride + off_q0 = ( + current_seq_range[:, None, None] * q_bs_stride + + current_head_range[None, :, None] * q_h_stride + + dim_range0[None, None, :] * q_d_stride + ) + off_q1 = ( + current_seq_range[:, None, None] * q_bs_stride + + current_head_range[None, :, None] * q_h_stride + + dim_range1[None, None, :] * q_d_stride + ) off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride - q0 = tl.load(q + off_q0, - mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), - other=0.0) - q1 = tl.load(q + off_q1, - mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), - other=0.0) + q0 = tl.load( + q + off_q0, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), + other=0.0, + ) + q1 = tl.load( + q + off_q1, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), + other=0.0, + ) cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) @@ -49,12 +59,16 @@ def _rotary_kernel( out0 = q0 * cos - q1 * sin out1 = q0 * sin + q1 * cos - tl.store(q + off_q0, - out0, - mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM)) - tl.store(q + off_q1, - out1, - mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM)) + tl.store( + q + off_q0, + out0, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), + ) + tl.store( + q + off_q1, + out1, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), + ) return diff --git a/colossalai/kernel/triton/self_attention_nofusion.py b/colossalai/kernel/triton/self_attention_nofusion.py index 6ae54dcb0b38..4b56c8afd67f 100644 --- a/colossalai/kernel/triton/self_attention_nofusion.py +++ b/colossalai/kernel/triton/self_attention_nofusion.py @@ -1,9 +1,8 @@ import torch -from torch import nn try: import triton - import triton.language as tl + HAS_TRITON = True except ImportError: HAS_TRITON = False @@ -13,9 +12,10 
@@ from .qkv_matmul_kernel import qkv_gemm_4d_kernel from .softmax import softmax_kernel - def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - input_mask: torch.Tensor, scale: float): - r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels + def self_attention_forward_without_fusion( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float + ): + r"""A function to do QKV Attention calculation by calling GEMM and softmax triton kernels Args: q (torch.Tensor): Q embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size) k (torch.Tensor): K embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size) @@ -65,7 +65,7 @@ def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: t score_output.stride(2), score_output.stride(3), scale=scale, - # currently manually setting, later on we can use auto-tune config to match best setting + # currently manually setting, later on we can use auto-tune config to match best setting BLOCK_SIZE_M=64, BLOCK_SIZE_N=32, BLOCK_SIZE_K=32, @@ -79,7 +79,6 @@ def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: t n_rows, n_cols = score_output.shape if n_rows <= 350000: - block_size = max(triton.next_power_of_2(n_cols), 2) num_warps = 4 if block_size >= 4096: @@ -142,15 +141,9 @@ def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: t ) return output.view(batches, -1, d_model) - def self_attention_compute_using_triton(qkv, - input_mask, - layer_past, - alibi, - scale, - head_size, - triangular=False, - use_flash=False): - + def self_attention_compute_using_triton( + qkv, input_mask, layer_past, alibi, scale, head_size, triangular=False, use_flash=False + ): assert qkv.is_contiguous() assert alibi is None, "current triton self-attention does not support alibi" batches = qkv.shape[0] @@ -158,8 +151,8 @@ def self_attention_compute_using_triton(qkv, num_of_heads = d_model // head_size q = qkv[:, :, :d_model] - k = qkv[:, :, d_model:d_model * 2] - v = qkv[:, :, d_model * 2:] + k = qkv[:, :, d_model : d_model * 2] + v = qkv[:, :, d_model * 2 :] q = q.view(batches, -1, num_of_heads, head_size) k = k.view(batches, -1, num_of_heads, head_size) v = v.view(batches, -1, num_of_heads, head_size) diff --git a/colossalai/kernel/triton/softmax.py b/colossalai/kernel/triton/softmax.py index c65adaf40dda..8ffce80a3041 100644 --- a/colossalai/kernel/triton/softmax.py +++ b/colossalai/kernel/triton/softmax.py @@ -1,39 +1,42 @@ import torch + try: import triton import triton.language as tl + HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") if HAS_TRITON: - ''' - softmax kernel is modified based on + """ + softmax kernel is modified based on https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py - ''' + """ + @triton.jit def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr): - r""" the kernel function for implementing softmax operator + r"""the kernel function for implementing softmax operator Args: output_ptr: the output after finishing softmax operation, (N, hidden_dim) input_ptr: the tensor of input, shape should be (N, hidden_dim) n_cols(tl.constexpr): the number of cols of input - BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= 
hidden_dim + BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim """ row_idx = tl.program_id(0) row_start_ptr = input_ptr + row_idx * row_stride col_offsets = tl.arange(0, BLOCK_SIZE) input_ptrs = row_start_ptr + col_offsets - row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32) + row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float("inf")).to(tl.float32) row_minus_max = row - tl.max(row, axis=0) if mask_ptr is not None: - # load mask into SRAM + # load mask into SRAM mask_ptrs = (mask_ptr + (row_indx * row_stride)) + col_offsets mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32) - # update + # update row_minus_max = row_minus_max + mask numerator = tl.exp(row_minus_max) @@ -43,17 +46,16 @@ def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SI output_ptrs = output_row_start_ptr + col_offsets # Write back output to DRAM tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols) - - + def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor: if mask is not None: assert input[-1] == mask[-1], "the last dimentions should be the same for input and mask" - assert dim == -1 or dim == len(input.shape)-1, "currently softmax layer only support last dimention" - + assert dim == -1 or dim == len(input.shape) - 1, "currently softmax layer only support last dimention" + hidden_dim = input.shape[-1] output = torch.empty_like(input) input = input.view(-1, hidden_dim) - if mask is not None: + if mask is not None: mask = mask.view(-1, hidden_dim) assert input.shape[0] == mask.shape[0], "the fist dimention of mask and input should be the same" @@ -67,30 +69,31 @@ def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Ten else: num_warps = 4 - if num_rows <= 350000: + if num_rows <= 350000: grid = (num_rows,) - softmax_kernel[grid](output, input, input.stride(0), num_cols, mask, BLOCK_SIZE = block_size, num_warps=num_warps) + softmax_kernel[grid]( + output, input, input.stride(0), num_cols, mask, BLOCK_SIZE=block_size, num_warps=num_warps + ) else: grid = lambda meta: () - grid = lambda meta: ( - triton.cdiv(num_rows, meta["BLOCK_M"]), - ) + grid = lambda meta: (triton.cdiv(num_rows, meta["BLOCK_M"]),) - BLOCK_M = 32 if block_size >= 4096: - BLOCK_M = 4 + pass elif block_size >= 2048: - BLOCK_M = 8 + pass - softmax_kernel[grid](output_ptr = output, - input_ptr = input, - row_stride = input.stride(0), - n_rows = num_rows, - n_cols = num_cols, - mask_ptr = mask, - # currently manually setting up size - BLOCK_M = 32, - BLOCK_SIZE = block_size) + softmax_kernel[grid]( + output_ptr=output, + input_ptr=input, + row_stride=input.stride(0), + n_rows=num_rows, + n_cols=num_cols, + mask_ptr=mask, + # currently manually setting up size + BLOCK_M=32, + BLOCK_SIZE=block_size, + ) - return output \ No newline at end of file + return output diff --git a/colossalai/kernel/triton/token_attention_kernel.py b/colossalai/kernel/triton/token_attention_kernel.py index c6b25f4abcec..7d0f9708516a 100644 --- a/colossalai/kernel/triton/token_attention_kernel.py +++ b/colossalai/kernel/triton/token_attention_kernel.py @@ -1,12 +1,12 @@ # Adapted from ModelTC https://github.com/ModelTC/lightllm -import math import torch try: import triton import triton.language as tl + HAS_TRITON = True except ImportError: HAS_TRITON = False @@ -15,10 +15,28 @@ if HAS_TRITON: @triton.jit - def _token_attn_1_kernel(Q, K, sm_scale, kv_cache_loc, 
kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len, - attn_out, kv_cache_loc_b_stride, kv_cache_loc_s_stride, q_batch_stride, q_head_stride, - q_head_dim_stride, k_batch_stride, k_head_stride, k_head_dim_stride, attn_head_stride, - attn_batch_stride, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr): + def _token_attn_1_kernel( + Q, + K, + sm_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + attn_out, + kv_cache_loc_b_stride, + kv_cache_loc_s_stride, + q_batch_stride, + q_head_stride, + q_head_dim_stride, + k_batch_stride, + k_head_stride, + k_head_dim_stride, + attn_head_stride, + attn_batch_stride, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + ): current_batch = tl.program_id(0) current_head = tl.program_id(1) start_n = tl.program_id(2) @@ -40,9 +58,11 @@ def _token_attn_1_kernel(Q, K, sm_scale, kv_cache_loc, kv_cache_start_loc, kv_ca for start_mark in range(0, block_mask, 1): q = tl.load(Q + off_q + start_mark) offs_n_new = current_batch_start_index + offs_n - k_loc = tl.load(kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, - mask=offs_n_new < current_batch_end_index, - other=0) + k_loc = tl.load( + kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, + mask=offs_n_new < current_batch_end_index, + other=0, + ) off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) att_value = tl.sum(q[None, :] * k, 1) @@ -52,11 +72,29 @@ def _token_attn_1_kernel(Q, K, sm_scale, kv_cache_loc, kv_cache_start_loc, kv_ca return @triton.jit - def _token_attn_1_alibi_kernel(Q, K, sm_scale, alibi, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, - max_kv_cache_len, attn_out, kv_cache_loc_b_stride, kv_cache_loc_s_stride, - q_batch_stride, q_head_stride, q_head_dim_stride, k_batch_stride, k_head_stride, - k_head_dim_stride, attn_head_stride, attn_batch_stride, HEAD_DIM: tl.constexpr, - BLOCK_N: tl.constexpr): + def _token_attn_1_alibi_kernel( + Q, + K, + sm_scale, + alibi, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + attn_out, + kv_cache_loc_b_stride, + kv_cache_loc_s_stride, + q_batch_stride, + q_head_stride, + q_head_dim_stride, + k_batch_stride, + k_head_stride, + k_head_dim_stride, + attn_head_stride, + attn_batch_stride, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + ): current_batch = tl.program_id(0) current_head = tl.program_id(1) start_n = tl.program_id(2) @@ -79,9 +117,11 @@ def _token_attn_1_alibi_kernel(Q, K, sm_scale, alibi, kv_cache_loc, kv_cache_sta alibi_m = tl.load(alibi + current_head) q = tl.load(Q + off_q + start_mark) offs_n_new = current_batch_start_index + offs_n - k_loc = tl.load(kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, - mask=offs_n_new < current_batch_end_index, - other=0) + k_loc = tl.load( + kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, + mask=offs_n_new < current_batch_end_index, + other=0, + ) off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) att_value = tl.sum(q[None, :] * k, 1) @@ -92,14 +132,9 @@ def _token_attn_1_alibi_kernel(Q, K, sm_scale, alibi, kv_cache_loc, kv_cache_sta return @torch.no_grad() - def token_attn_fwd_1(q, - k, - attn_out, - 
kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - alibi=None): + def token_attn_fwd_1( + q, k, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len, alibi=None + ): BLOCK = 32 # shape constraints q_head_dim, k_head_dim = q.shape[-1], k.shape[-1] @@ -168,9 +203,17 @@ def token_attn_fwd_1(q, return @triton.jit - def _token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out, - logics_head_dim_stride, logics_batch_stride, prob_head_dim_stride, prob_batch_stride, - BLOCK_SIZE: tl.constexpr): + def _token_attn_softmax_fwd( + softmax_logics, + kv_cache_start_loc, + kv_cache_seqlen, + softmax_prob_out, + logics_head_dim_stride, + logics_batch_stride, + prob_head_dim_stride, + prob_batch_stride, + BLOCK_SIZE: tl.constexpr, + ): current_batch = tl.program_id(0) current_head = tl.program_id(1) @@ -178,20 +221,26 @@ def _token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) - row = tl.load(softmax_logics + current_head * logics_head_dim_stride + - (current_batch_in_all_start_index + col_offsets) * logics_batch_stride, - mask=col_offsets < current_batch_seq_len, - other=-float('inf')).to(tl.float32) + row = tl.load( + softmax_logics + + current_head * logics_head_dim_stride + + (current_batch_in_all_start_index + col_offsets) * logics_batch_stride, + mask=col_offsets < current_batch_seq_len, + other=-float("inf"), + ).to(tl.float32) row_minus_max = row - tl.max(row, axis=0) numerator = tl.exp(row_minus_max) denominator = tl.sum(numerator, axis=0) softmax_output = numerator / denominator - tl.store(softmax_prob_out + current_head * prob_head_dim_stride + - (current_batch_in_all_start_index + col_offsets) * prob_batch_stride, - softmax_output, - mask=col_offsets < current_batch_seq_len) + tl.store( + softmax_prob_out + + current_head * prob_head_dim_stride + + (current_batch_in_all_start_index + col_offsets) * prob_batch_stride, + softmax_output, + mask=col_offsets < current_batch_seq_len, + ) return @torch.no_grad() @@ -220,11 +269,27 @@ def token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, return @triton.jit - def _token_attn_2_kernel(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len, - kv_cache_loc_b_stride, kv_cache_loc_s_stride, prob_head_dim_stride, prob_batch_stride, - v_batch_stride, v_head_stride, v_head_dim_stride, attn_out_batch_stride, - attn_out_head_stride, attn_out_head_dim_stride, HEAD_DIM: tl.constexpr, - BLOCK_N: tl.constexpr): + def _token_attn_2_kernel( + Prob, + V, + attn_out, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + kv_cache_loc_b_stride, + kv_cache_loc_s_stride, + prob_head_dim_stride, + prob_batch_stride, + v_batch_stride, + v_head_stride, + v_head_dim_stride, + attn_out_batch_stride, + attn_out_head_stride, + attn_out_head_dim_stride, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + ): current_batch = tl.program_id(0) current_head = tl.program_id(1) @@ -232,7 +297,6 @@ def _token_attn_2_kernel(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv offs_d = tl.arange(0, HEAD_DIM) current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) current_batch_start_index = max_kv_cache_len - current_batch_seq_len - current_batch_end_index = current_batch_seq_len current_batch_in_all_start_index = tl.load(kv_cache_start_loc + 
current_batch) v_loc_off = current_batch * kv_cache_loc_b_stride + (current_batch_start_index + offs_n) * kv_cache_loc_s_stride @@ -242,19 +306,29 @@ def _token_attn_2_kernel(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv acc = tl.zeros([HEAD_DIM], dtype=tl.float32) for start_n in range(0, current_batch_seq_len, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) - p_value = tl.load(Prob + p_offs + start_n * kv_cache_loc_s_stride, - mask=(start_n + offs_n) < current_batch_seq_len, - other=0.0) - v_loc = tl.load(kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride, - mask=(start_n + offs_n) < current_batch_seq_len, - other=0.0) - v_value = tl.load(V + v_offs + v_loc[:, None] * v_batch_stride, - mask=(start_n + offs_n[:, None]) < current_batch_seq_len, - other=0.0) + p_value = tl.load( + Prob + p_offs + start_n * kv_cache_loc_s_stride, + mask=(start_n + offs_n) < current_batch_seq_len, + other=0.0, + ) + v_loc = tl.load( + kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride, + mask=(start_n + offs_n) < current_batch_seq_len, + other=0.0, + ) + v_value = tl.load( + V + v_offs + v_loc[:, None] * v_batch_stride, + mask=(start_n + offs_n[:, None]) < current_batch_seq_len, + other=0.0, + ) acc += tl.sum(p_value[:, None] * v_value, 0) acc = acc.to(tl.float16) - off_o = current_batch * attn_out_batch_stride + current_head * attn_out_head_stride + offs_d * attn_out_head_dim_stride + off_o = ( + current_batch * attn_out_batch_stride + + current_head * attn_out_head_stride + + offs_d * attn_out_head_dim_stride + ) out_ptrs = attn_out + off_o tl.store(out_ptrs, acc) return @@ -296,15 +370,9 @@ def token_attn_fwd_2(prob, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cac return @torch.no_grad() - def token_attention_fwd(q, - k, - v, - attn_out, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seq_len, - max_len_in_batch, - alibi=None): + def token_attention_fwd( + q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, alibi=None + ): head_num = k.shape[1] batch_size = kv_cache_seq_len.shape[0] calcu_shape1 = (batch_size, head_num, k.shape[2]) @@ -312,21 +380,24 @@ def token_attention_fwd(q, att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda") - token_attn_fwd_1(q.view(calcu_shape1), - k, - att_m_tensor, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seq_len, - max_len_in_batch, - alibi=alibi) + token_attn_fwd_1( + q.view(calcu_shape1), + k, + att_m_tensor, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seq_len, + max_len_in_batch, + alibi=alibi, + ) prob = torch.empty_like(att_m_tensor) token_attn_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch) att_m_tensor = None - token_attn_fwd_2(prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, - max_len_in_batch) + token_attn_fwd_2( + prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch + ) prob = None diff --git a/colossalai/lazy/__init__.py b/colossalai/lazy/__init__.py index 4387107bf773..c6b813c50036 100644 --- a/colossalai/lazy/__init__.py +++ b/colossalai/lazy/__init__.py @@ -1,6 +1,6 @@ from .lazy_init import LazyInitContext, LazyTensor __all__ = [ - 'LazyInitContext', - 'LazyTensor', + "LazyInitContext", + "LazyTensor", ] diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py index e071563c045a..ebaf2e1600fc 100644 --- a/colossalai/lazy/lazy_init.py +++ b/colossalai/lazy/lazy_init.py @@ -1,4 +1,3 @@ -from 
contextlib import contextmanager from types import MethodType from typing import Callable, Dict, Optional, Union @@ -35,43 +34,43 @@ "eye", ] -_EARLY_MATERIALIZED_OPS = ['__getitem__', 'split'] +_EARLY_MATERIALIZED_OPS = ["__getitem__", "split"] # If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset) # without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block. # These ops cannot be unwrapped using .data -_CHANGE_META_OPS = ['_cudnn_rnn_flatten_weight', 'requires_grad_', '__get__', '__set__', 'numel', 'size', 'dim'] +_CHANGE_META_OPS = ["_cudnn_rnn_flatten_weight", "requires_grad_", "__get__", "__set__", "numel", "size", "dim"] _LEGACY_TENSOR_CONSTRUCTOR = { - 'FloatTensor': torch.float, - 'DoubleTensor': torch.double, - 'HalfTensor': torch.half, - 'BFloat16Tensor': torch.bfloat16, - 'ByteTensor': torch.uint8, - 'CharTensor': torch.int8, - 'ShortTensor': torch.short, - 'IntTensor': torch.int, - 'LongTensor': torch.long, - 'BoolTensor': torch.bool, + "FloatTensor": torch.float, + "DoubleTensor": torch.double, + "HalfTensor": torch.half, + "BFloat16Tensor": torch.bfloat16, + "ByteTensor": torch.uint8, + "CharTensor": torch.int8, + "ShortTensor": torch.short, + "IntTensor": torch.int, + "LongTensor": torch.long, + "BoolTensor": torch.bool, } _EMPTY_DATA = torch.empty(0) class _MyTensor(Tensor): - """This class is only for correctness verification. - """ - _pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None + """This class is only for correctness verification.""" + + _pre_op_fn: Callable[["LazyTensor"], None] = lambda *args: None default_device: Optional[torch.device] = None - def __new__(cls, func, *args, concrete_data=None, **kwargs) -> '_MyTensor': + def __new__(cls, func, *args, concrete_data=None, **kwargs) -> "_MyTensor": cls._pre_op_fn() if concrete_data is not None: # uniform api as LazyTensor data = concrete_data else: - kwargs['device'] = cls.default_device + kwargs["device"] = cls.default_device data = func(*args, **kwargs) return Tensor._make_subclass(cls, data, require_grad=data.requires_grad) @@ -82,12 +81,11 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): def _data_tolist(tensor: torch.Tensor) -> list: - """tolist() method is not allowed for a subclass of tensor. Tensor.data returns a Tensor. - """ + """tolist() method is not allowed for a subclass of tensor. Tensor.data returns a Tensor.""" return tensor.data.tolist() -def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor: +def _convert_cls(tensor: "LazyTensor", target: torch.Tensor) -> torch.Tensor: """Convert a lazy tensor's class to target's class, with target's data. The reason why we change the class of a lazy tensor in-place is that this can easily handle shared modules/parameters, which is common in huggingface models. 
@@ -104,7 +102,7 @@ def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor: tensor.__class__ = cls_to_become if cls_to_become is Parameter: # to fit UninitializedParameter - delattr(tensor, '_is_param') + delattr(tensor, "_is_param") tensor.data = target tensor.requires_grad = target.requires_grad # subclass of torch.Tensor does not have tolist() method @@ -147,8 +145,8 @@ class LazyTensor(torch.Tensor): """ _repr = True - _meta_data: Optional[MetaTensor] = None # shape, dtype, device - _pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None + _meta_data: Optional[MetaTensor] = None # shape, dtype, device + _pre_op_fn: Callable[["LazyTensor"], None] = lambda *args: None default_device: Optional[torch.device] = None @@ -159,8 +157,8 @@ def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs): elem = concrete_data else: if meta_data is None: - device = kwargs.get('device', 'cpu') - elem = func(*args, **{**kwargs, 'device': 'meta'}) + device = kwargs.get("device", "cpu") + elem = func(*args, **{**kwargs, "device": "meta"}) meta_data = MetaTensor(elem, device=device) elem = meta_data._tensor # As a meta tensor cannot be modified __class__ to torch.Tensor, we should use an empty real tensor here @@ -170,10 +168,10 @@ def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs): def __init__(self, func, *args, meta_data=None, concrete_data=None, **kwargs): if func.__name__ in _NORMAL_FACTORY: - kwargs = {**kwargs, 'device': LazyTensor.default_device} - self._factory_method = (func, args, kwargs) # (func, args, kwargs) - self._op_buffer = [] # (func, args, kwargs, replace) - self._materialized_data: Optional[torch.Tensor] = concrete_data # materialized data + kwargs = {**kwargs, "device": LazyTensor.default_device} + self._factory_method = (func, args, kwargs) # (func, args, kwargs) + self._op_buffer = [] # (func, args, kwargs, replace) + self._materialized_data: Optional[torch.Tensor] = concrete_data # materialized data def materialize(self) -> torch.Tensor: """Materialize the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace). @@ -200,12 +198,11 @@ def distribute(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> to return _convert_cls(self, local_tensor) def clean(self) -> None: - """Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized. - """ - delattr(self, '_factory_method') - delattr(self, '_op_buffer') - delattr(self, '_materialized_data') - delattr(self, '_meta_data') + """Clean all stored operations, meta data and materialized data, which prevents memory leaking. 
This should be called after all tensors are materialized.""" + delattr(self, "_factory_method") + delattr(self, "_op_buffer") + delattr(self, "_materialized_data") + delattr(self, "_meta_data") @staticmethod def _replace_with_materialized(x): @@ -221,8 +218,9 @@ def _materialize_data(self) -> torch.Tensor: # apply cached sequence self._pre_op_fn() - init_val = func(*tree_map(self._replace_with_materialized, args), - **tree_map(self._replace_with_materialized, kwargs)) + init_val = func( + *tree_map(self._replace_with_materialized, args), **tree_map(self._replace_with_materialized, kwargs) + ) self._materialized_data = self._rerun_ops(init_val) return self._materialized_data @@ -243,13 +241,13 @@ def replace(x): packed = None - for (func, args, kwargs) in self._op_buffer: + for func, args, kwargs in self._op_buffer: if func == torch.Tensor.requires_grad_: - packed = func, args, kwargs # requires grad should be set at last + packed = func, args, kwargs # requires grad should be set at last else: self._pre_op_fn() o = func(*tree_map(replace, args), **tree_map(replace, kwargs)) - target = o if isinstance(o, torch.Tensor) else target # if func returns non-Tensor, discard the value + target = o if isinstance(o, torch.Tensor) else target # if func returns non-Tensor, discard the value # super-dainiu: set requires_grad after all inplace-ops are done if packed is not None: @@ -268,8 +266,11 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): # These OPs cannot be lazy and related tensors should be early materialized tree_map(cls._replace_with_materialized, args) tree_map(cls._replace_with_materialized, kwargs) - is_inplace: bool = (func.__name__.endswith('_') and not (func.__name__.endswith('__')) - or func.__name__ in ('__setitem__', '__set__')) + is_inplace: bool = ( + func.__name__.endswith("_") + and not (func.__name__.endswith("__")) + or func.__name__ in ("__setitem__", "__set__") + ) is_change_meta_op: bool = func.__name__ in _CHANGE_META_OPS @@ -285,11 +286,11 @@ def unwrap(x): target: LazyTensor = args[0].clone() target._op_buffer.append((func, args, kwargs)) - target._meta_data = getattr(target._meta_data, func.name)(*tree_map(unwrap, args[1:]), - **tree_map(unwrap, kwargs)) + target._meta_data = getattr(target._meta_data, func.name)( + *tree_map(unwrap, args[1:]), **tree_map(unwrap, kwargs) + ) return target else: - meta_to_lazy = {} def unwrap(x): @@ -328,10 +329,9 @@ def wrap(y, i=None): @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - pass # skip + pass # skip def clone(self) -> "LazyTensor": - def factory_fn(): # if self is materialized, return self new_tensor = self.materialize() if type(self) is LazyTensor else self @@ -346,8 +346,10 @@ def detach(self) -> Tensor: def __deepcopy__(self, memo): if not self.is_leaf: - raise RuntimeError("Only Tensors created explicitly by the user " - "(graph leaves) support the deepcopy protocol at the moment") + raise RuntimeError( + "Only Tensors created explicitly by the user " + "(graph leaves) support the deepcopy protocol at the moment" + ) if id(self) in memo: return memo[id(self)] @@ -375,7 +377,7 @@ def data(self): return self @data.setter - def data(self, other: 'LazyTensor'): + def data(self, other: "LazyTensor"): """This is sightly different from oringinal `data` setter. 
E.g.: @@ -413,7 +415,7 @@ def __hash__(self): def __rpow__(self, other): dtype = torch.result_type(self, other) - return torch.tensor(other, dtype=dtype, device=self.device)**self + return torch.tensor(other, dtype=dtype, device=self.device) ** self class LazyInitContext: @@ -444,11 +446,14 @@ class LazyInitContext: 1. Quantization strategies can be applied before allocating real memory. 2. Lazy initialization seems slower than normal initialization. """ + _replaced: bool = False - def __init__(self, - tensor_cls: Union[_MyTensor, LazyTensor] = LazyTensor, - default_device: Optional[Union[torch.device, str, int]] = None): + def __init__( + self, + tensor_cls: Union[_MyTensor, LazyTensor] = LazyTensor, + default_device: Optional[Union[torch.device, str, int]] = None, + ): assert tensor_cls is LazyTensor or tensor_cls is _MyTensor self.overrides = {} self.tensor_cls = tensor_cls @@ -457,7 +462,7 @@ def __init__(self, def __enter__(self): if LazyInitContext._replaced: - raise RuntimeError(f'LazyInitContext is not reentrant') + raise RuntimeError(f"LazyInitContext is not reentrant") LazyInitContext._replaced = True self.old_default_device = self.tensor_cls.default_device self.tensor_cls.default_device = self.default_device @@ -485,17 +490,17 @@ def wrapper(*args, **kwargs): return args[0] elif len(args) == 1: # (object data, *, torch.device device) - kwargs = {**kwargs, 'dtype': dtype} - replaced, orig = self.overrides['tensor'] + kwargs = {**kwargs, "dtype": dtype} + replaced, orig = self.overrides["tensor"] return replaced(*args, **kwargs) elif _is_int_tuple(args): # (tuple of ints size, *, torch.device device) - kwargs = {**kwargs, 'dtype': dtype} - replaced, orig = self.overrides['empty'] + kwargs = {**kwargs, "dtype": dtype} + replaced, orig = self.overrides["empty"] return replaced(*args, **kwargs) else: raise TypeError( - f'new() received an invalid combination of arguments - got {tuple(type(x) for x in args)}, but expected one of:\n * (Tensor other)\n * (tuple of ints size, *, torch.device device)\n * (object data, *, torch.device device)' + f"new() received an invalid combination of arguments - got {tuple(type(x) for x in args)}, but expected one of:\n * (Tensor other)\n * (tuple of ints size, *, torch.device device)\n * (object data, *, torch.device device)" ) return wrapper, target @@ -514,23 +519,29 @@ def wrapper(*args, **kwargs): if callable(getattr(torch, target, None)) } - self.overrides.update({ - target + '_like': wrap_factory_like_method(getattr(torch, target), getattr(torch, target + '_like')) - for target in _NORMAL_FACTORY - if callable(getattr(torch, target + '_like', None)) - }) - - self.overrides.update({ - target: wrap_legacy_constructor(getattr(torch, target), dtype) - for target, dtype in _LEGACY_TENSOR_CONSTRUCTOR.items() - if callable(getattr(torch, target, None)) - }) - - self.overrides.update({ - target: wrap_no_meta_factory(getattr(torch, target)) - for target in _NO_META_FACTORY - if callable(getattr(torch, target, None)) - }) + self.overrides.update( + { + target + "_like": wrap_factory_like_method(getattr(torch, target), getattr(torch, target + "_like")) + for target in _NORMAL_FACTORY + if callable(getattr(torch, target + "_like", None)) + } + ) + + self.overrides.update( + { + target: wrap_legacy_constructor(getattr(torch, target), dtype) + for target, dtype in _LEGACY_TENSOR_CONSTRUCTOR.items() + if callable(getattr(torch, target, None)) + } + ) + + self.overrides.update( + { + target: wrap_no_meta_factory(getattr(torch, target)) + for target in 
_NO_META_FACTORY + if callable(getattr(torch, target, None)) + } + ) for name, (wrapper, orig) in self.overrides.items(): setattr(torch, name, wrapper) @@ -556,10 +567,9 @@ def apply_fn(name: str, p: LazyTensor): return _apply_to_lazy_module(module, apply_fn, verbose) @staticmethod - def distribute(module: nn.Module, - device_mesh: DeviceMesh, - sharding_spec_dict: Dict[str, ShardingSpec], - verbose: bool = False) -> nn.Module: + def distribute( + module: nn.Module, device_mesh: DeviceMesh, sharding_spec_dict: Dict[str, ShardingSpec], verbose: bool = False + ) -> nn.Module: """Distribute all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place. Args: @@ -574,9 +584,9 @@ def apply_fn(name: str, p: LazyTensor): return _apply_to_lazy_module(module, apply_fn, verbose) -def _apply_to_lazy_module(module: nn.Module, - apply_fn: Callable[[str, torch.Tensor], None], - verbose: bool = False) -> nn.Module: +def _apply_to_lazy_module( + module: nn.Module, apply_fn: Callable[[str, torch.Tensor], None], verbose: bool = False +) -> nn.Module: if verbose: # verbose info param_cnt = 0 @@ -590,7 +600,7 @@ def _apply_to_lazy_module(module: nn.Module, if verbose: param_cnt += 1 total_numel += p.numel() - if getattr(p, '_materialized_data', False) is None: + if getattr(p, "_materialized_data", False) is None: # if no _materialized_data attr, the tensor is not lazy param_lazy_cnt += 1 else: @@ -612,10 +622,11 @@ def _apply_to_lazy_module(module: nn.Module, if verbose: non_lazy_numel_ratio = non_lazy_numel / total_numel * 100 if non_lazy_numel != 0 else 0 - _print_rank_0(f'Param lazy rate: {param_lazy_cnt}/{param_cnt}') - _print_rank_0(f'Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}') + _print_rank_0(f"Param lazy rate: {param_lazy_cnt}/{param_cnt}") + _print_rank_0(f"Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}") _print_rank_0( - f'Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%') + f"Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%" + ) return module diff --git a/colossalai/legacy/__init__.py b/colossalai/legacy/__init__.py index f51941ee800b..4d6ad357a2fa 100644 --- a/colossalai/legacy/__init__.py +++ b/colossalai/legacy/__init__.py @@ -1,9 +1,9 @@ from .initialize import initialize, launch, launch_from_openmpi, launch_from_slurm, launch_from_torch __all__ = [ - 'launch', - 'launch_from_openmpi', - 'launch_from_slurm', - 'launch_from_torch', - 'initialize', + "launch", + "launch_from_openmpi", + "launch_from_slurm", + "launch_from_torch", + "initialize", ] diff --git a/colossalai/legacy/amp/__init__.py b/colossalai/legacy/amp/__init__.py index e83a7f6ac5cd..9d17d88b4c79 100644 --- a/colossalai/legacy/amp/__init__.py +++ b/colossalai/legacy/amp/__init__.py @@ -12,7 +12,7 @@ from .naive_amp import convert_to_naive_amp from .torch_amp import convert_to_torch_amp -__all__ = ['convert_to_amp', 'convert_to_naive_amp', 'convert_to_apex_amp', 'convert_to_torch_amp', 'AMP_TYPE'] +__all__ = ["convert_to_amp", "convert_to_naive_amp", "convert_to_apex_amp", "convert_to_torch_amp", "AMP_TYPE"] def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mode: AMP_TYPE, amp_config: Config = None): @@ -38,8 +38,7 @@ def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mod For ``torch_amp``, please check `torch_amp config `_. 
""" - assert isinstance(mode, AMP_TYPE), \ - f'expected the argument mode be AMP_TYPE, but got {type(mode)}' + assert isinstance(mode, AMP_TYPE), f"expected the argument mode be AMP_TYPE, but got {type(mode)}" if amp_config is None: amp_config = Config() diff --git a/colossalai/legacy/amp/amp_type.py b/colossalai/legacy/amp/amp_type.py index 6f322f866cfc..5ad5faf08b71 100644 --- a/colossalai/legacy/amp/amp_type.py +++ b/colossalai/legacy/amp/amp_type.py @@ -5,6 +5,6 @@ class AMP_TYPE(Enum): - APEX = 'apex' - TORCH = 'torch' - NAIVE = 'naive' + APEX = "apex" + TORCH = "torch" + NAIVE = "naive" diff --git a/colossalai/legacy/amp/apex_amp/__init__.py b/colossalai/legacy/amp/apex_amp/__init__.py index 51b9b97dccce..680c6e45ca9d 100644 --- a/colossalai/legacy/amp/apex_amp/__init__.py +++ b/colossalai/legacy/amp/apex_amp/__init__.py @@ -34,9 +34,10 @@ def convert_to_apex_amp(model: nn.Module, optimizer: Optimizer, amp_config): More details about ``amp_config`` refer to `amp_config `_. """ import apex.amp as apex_amp + model, optimizer = apex_amp.initialize(model, optimizer, **amp_config) optimizer = ApexAMPOptimizer(optimizer) return model, optimizer -__all__ = ['convert_to_apex_amp', 'ApexAMPOptimizer'] +__all__ = ["convert_to_apex_amp", "ApexAMPOptimizer"] diff --git a/colossalai/legacy/amp/apex_amp/apex_amp.py b/colossalai/legacy/amp/apex_amp/apex_amp.py index acc051181562..048c51891b17 100644 --- a/colossalai/legacy/amp/apex_amp/apex_amp.py +++ b/colossalai/legacy/amp/apex_amp/apex_amp.py @@ -15,7 +15,7 @@ class ApexAMPOptimizer(OptimizerWrapper): - """ A wrapper class for APEX optimizer and it implements apex-specific backward and clip_grad_norm + """A wrapper class for APEX optimizer and it implements apex-specific backward and clip_grad_norm methods """ diff --git a/colossalai/legacy/amp/naive_amp/__init__.py b/colossalai/legacy/amp/naive_amp/__init__.py index 2ee84fc763b1..36e402299147 100644 --- a/colossalai/legacy/amp/naive_amp/__init__.py +++ b/colossalai/legacy/amp/naive_amp/__init__.py @@ -41,7 +41,7 @@ def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config): output_to_fp32 = is_no_pp_or_last_stage() model = NaiveAMPModel(model, output_to_fp32=output_to_fp32) - use_dynamic_grad_scaler = amp_config.pop('dynamic_grad_scale', True) + use_dynamic_grad_scaler = amp_config.pop("dynamic_grad_scale", True) if use_dynamic_grad_scaler: scaler_class = DynamicGradScaler else: @@ -57,4 +57,4 @@ def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config): return model, optimizer -__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer', 'FP16Optimizer'] +__all__ = ["convert_to_naive_amp", "NaiveAMPOptimizer", "FP16Optimizer"] diff --git a/colossalai/legacy/amp/naive_amp/_fp16_optimizer.py b/colossalai/legacy/amp/naive_amp/_fp16_optimizer.py index 2733477599f7..97ec57fbd007 100644 --- a/colossalai/legacy/amp/naive_amp/_fp16_optimizer.py +++ b/colossalai/legacy/amp/naive_amp/_fp16_optimizer.py @@ -21,7 +21,7 @@ except: fused_optim = None -__all__ = ['FP16Optimizer'] +__all__ = ["FP16Optimizer"] def load_fused_optim(): @@ -63,13 +63,15 @@ class FP16Optimizer(Optimizer): verbose (bool, optional): if set to `True`, will print debug info. Default False. 
""" - def __init__(self, - optimizer: Optimizer, - grad_scaler: BaseGradScaler, - verbose: bool = False, - clip_grad_norm=0, - dp_process_group: ProcessGroup = None, - mp_process_group: ProcessGroup = None): + def __init__( + self, + optimizer: Optimizer, + grad_scaler: BaseGradScaler, + verbose: bool = False, + clip_grad_norm=0, + dp_process_group: ProcessGroup = None, + mp_process_group: ProcessGroup = None, + ): # have a defaults for compatibility with pytorch optim self._optimizer = optimizer self._defaults = optimizer.defaults @@ -117,10 +119,10 @@ def _get_process_group(parallel_mode): fp32_master_params = [] fp32_params = [] # For all the parameters in this group: - for i, param in enumerate(param_group['params']): + for i, param in enumerate(param_group["params"]): if param.requires_grad: # float16 params: - if param.type() in ['torch.cuda.HalfTensor']: + if param.type() in ["torch.cuda.HalfTensor"]: fp16_params.append(param) # Create a fp32 copy @@ -129,7 +131,7 @@ def _get_process_group(parallel_mode): copy_tensor_parallel_attributes(param, fp32_param) # Replace the optimizer params with the new fp32 copy. - param_group['params'][i] = fp32_param + param_group["params"][i] = fp32_param fp32_master_params.append(fp32_param) # Reset existing state dict key to the new main param. @@ -137,11 +139,13 @@ def _get_process_group(parallel_mode): self._optimizer.state[fp32_param] = self._optimizer.state.pop(param) # fp32 params. - elif param.type() == 'torch.cuda.FloatTensor': + elif param.type() == "torch.cuda.FloatTensor": fp32_params.append(param) else: - raise TypeError('Expected parameter of type torch.cuda.FloatTensor ' - f'or torch.cuda.HalfTensor, but got {param.type()}') + raise TypeError( + "Expected parameter of type torch.cuda.FloatTensor " + f"or torch.cuda.HalfTensor, but got {param.type()}" + ) self._fp16_param_groups.append(fp16_params) self._fp32_master_param_groups.append(fp32_master_params) @@ -160,12 +164,12 @@ def _get_process_group(parallel_mode): f"clip_grad_norm = {clip_grad_norm}\n" f"grad_scaler = {self._grad_scaler.__class__.__name__}" f"==========================================", - ranks=[0]) + ranks=[0], + ) @property def max_norm(self): - """Returns the maximum norm of gradient clipping. - """ + """Returns the maximum norm of gradient clipping.""" return self._clip_grad_max_norm @property @@ -211,7 +215,7 @@ def _check_overflow(self): # check for overflow for group in self._optimizer.param_groups: - for p in group['params']: + for p in group["params"]: if p.grad is not None and has_inf_or_nan(p.grad): self._found_overflow.fill_(1.0) break @@ -235,7 +239,7 @@ def zero_grad(self, set_to_none=True): # set_to_none = True can save some memory space for param_group in self._optimizer.param_groups: - zero_gard_by_list(param_group['params'], set_to_none=set_to_none) + zero_gard_by_list(param_group["params"], set_to_none=set_to_none) def _get_fp32_param_groups_to_update(self): return self._fp32_master_param_groups + self._fp32_param_groups @@ -262,13 +266,12 @@ def _update_fp16_param_from_fp32_param(self): for fp16_param, fp32_param in zip(fp16_group, fp32_group): fp16_param_data.append(fp16_param.data) fp32_master_param_data.append(fp32_param.data) - _multi_tensor_copy_this_to_that(this=fp32_master_param_data, - that=fp16_param_data, - overflow_buf=self._dummy_overflow_buf) + _multi_tensor_copy_this_to_that( + this=fp32_master_param_data, that=fp16_param_data, overflow_buf=self._dummy_overflow_buf + ) def step(self): - """Update the model parameters. 
- """ + """Update the model parameters.""" # Copy gradients from model params to main params. self._assign_grad_to_fp32_master_param() @@ -307,14 +310,13 @@ def backward(self, loss): scaled_loss.backward() def state_dict(self): - """Returns the states of the fp16 optimizer as a dict object. - """ + """Returns the states of the fp16 optimizer as a dict object.""" state_dict = {} - state_dict['optimizer'] = self._optimizer.state_dict() + state_dict["optimizer"] = self._optimizer.state_dict() if self.grad_scaler: - state_dict['grad_scaler'] = self.grad_scaler.state_dict() - state_dict['fp32_master_param_groups'] = self._fp32_master_param_groups + state_dict["grad_scaler"] = self.grad_scaler.state_dict() + state_dict["fp32_master_param_groups"] = self._fp32_master_param_groups return state_dict def load_state_dict(self, state_dict): @@ -325,16 +327,17 @@ def load_state_dict(self, state_dict): """ # Optimizer. - self._optimizer.load_state_dict(state_dict['optimizer']) + self._optimizer.load_state_dict(state_dict["optimizer"]) # Grad scaler. - if 'grad_scaler' in state_dict: - self.grad_scaler.load_state_dict(state_dict['grad_scaler']) + if "grad_scaler" in state_dict: + self.grad_scaler.load_state_dict(state_dict["grad_scaler"]) # Copy data for the main params. - if 'fp32_master_param_groups' in state_dict: - for current_group, ckpt_group in zip(self._fp32_master_param_groups, - state_dict['fp32_master_param_groups']): + if "fp32_master_param_groups" in state_dict: + for current_group, ckpt_group in zip( + self._fp32_master_param_groups, state_dict["fp32_master_param_groups"] + ): for current_param, ckpt_param in zip(current_group, ckpt_group): current_param.data.copy_(ckpt_param.data) @@ -346,7 +349,7 @@ def clip_grad_norm(self, clip_grad): """ params = [] for param_group in self._optimizer.param_groups: - for param in param_group['params']: + for param in param_group["params"]: params.append(param) return clip_grad_norm_fp32(params, clip_grad) diff --git a/colossalai/legacy/amp/naive_amp/_utils.py b/colossalai/legacy/amp/naive_amp/_utils.py index 7633705e19fb..aa5a91146bb0 100644 --- a/colossalai/legacy/amp/naive_amp/_utils.py +++ b/colossalai/legacy/amp/naive_amp/_utils.py @@ -27,7 +27,7 @@ def has_inf_or_nan(tensor): raise return True else: - if tensor_sum == float('inf') or tensor_sum == -float('inf') or tensor_sum != tensor_sum: + if tensor_sum == float("inf") or tensor_sum == -float("inf") or tensor_sum != tensor_sum: return True return False diff --git a/colossalai/legacy/amp/naive_amp/naive_amp.py b/colossalai/legacy/amp/naive_amp/naive_amp.py index 1fab3e5a0d0d..f9c298941fa9 100644 --- a/colossalai/legacy/amp/naive_amp/naive_amp.py +++ b/colossalai/legacy/amp/naive_amp/naive_amp.py @@ -45,9 +45,11 @@ def step(self): def clip_grad_norm(self, model: nn.Module, max_norm: float): if self.optim.max_norm == max_norm: return - raise RuntimeError("NaiveAMP optimizer has clipped gradients during optimizer.step(). " - "If you have supplied clip_grad_norm in the amp_config, " - "executing the method clip_grad_norm is not allowed.") + raise RuntimeError( + "NaiveAMP optimizer has clipped gradients during optimizer.step(). " + "If you have supplied clip_grad_norm in the amp_config, " + "executing the method clip_grad_norm is not allowed." + ) class NaiveAMPModel(nn.Module): @@ -66,11 +68,13 @@ class NaiveAMPModel(nn.Module): in `parallel_mode `_. 
""" - def __init__(self, - model: nn.Module, - output_to_fp32: bool = True, - parallel_mode: ParallelMode = ParallelMode.DATA, - sync_buffer: bool = True): + def __init__( + self, + model: nn.Module, + output_to_fp32: bool = True, + parallel_mode: ParallelMode = ParallelMode.DATA, + sync_buffer: bool = True, + ): super().__init__() self.model = model.half() self._output_to_fp32 = output_to_fp32 diff --git a/colossalai/legacy/amp/torch_amp/__init__.py b/colossalai/legacy/amp/torch_amp/__init__.py index 893cc890d68e..ad2416eef06a 100644 --- a/colossalai/legacy/amp/torch_amp/__init__.py +++ b/colossalai/legacy/amp/torch_amp/__init__.py @@ -9,10 +9,9 @@ from .torch_amp import TorchAMPLoss, TorchAMPModel, TorchAMPOptimizer -def convert_to_torch_amp(model: nn.Module, - optimizer: Optimizer, - criterion: Optional[_Loss] = None, - amp_config: Optional[Config] = None): +def convert_to_torch_amp( + model: nn.Module, optimizer: Optimizer, criterion: Optional[_Loss] = None, amp_config: Optional[Config] = None +): """A helper function to wrap training components with Pytorch AMP modules Args: @@ -42,4 +41,4 @@ def convert_to_torch_amp(model: nn.Module, return model, optimizer, criterion -__all__ = ['convert_to_torch_amp', 'TorchAMPModel', 'TorchAMPLoss', 'TorchAMPOptimizer'] +__all__ = ["convert_to_torch_amp", "TorchAMPModel", "TorchAMPLoss", "TorchAMPOptimizer"] diff --git a/colossalai/legacy/amp/torch_amp/_grad_scaler.py b/colossalai/legacy/amp/torch_amp/_grad_scaler.py index 543dac6ab5ef..fc1aeec234fd 100644 --- a/colossalai/legacy/amp/torch_amp/_grad_scaler.py +++ b/colossalai/legacy/amp/torch_amp/_grad_scaler.py @@ -23,7 +23,7 @@ class _MultiDeviceReplicator(object): """ def __init__(self, master_tensor: torch.Tensor) -> None: - assert master_tensor.is_cuda or master_tensor.device.type == 'xla' + assert master_tensor.is_cuda or master_tensor.device.type == "xla" self.master = master_tensor self._per_device_tensors: Dict[torch.device, torch.Tensor] = {} @@ -118,7 +118,7 @@ class GradScaler(object): invokes the underlying ``optimizer.step()``, and other methods become no-ops. """ - def __init__(self, init_scale=2.**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True): + def __init__(self, init_scale=2.0**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True): if enabled and not torch.cuda.is_available(): warnings.warn("torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.") self._enabled = False @@ -174,7 +174,7 @@ def scale(self, outputs): # Short-circuit for the common case. if isinstance(outputs, torch.Tensor): - assert outputs.is_cuda or outputs.device.type == 'xla' + assert outputs.is_cuda or outputs.device.type == "xla" if self._scale is None: self._lazy_init_scale_growth_tracker(outputs.device) assert self._scale is not None @@ -186,7 +186,7 @@ def scale(self, outputs): def apply_scale(val): if isinstance(val, torch.Tensor): - assert val.is_cuda or val.device.type == 'xla' + assert val.is_cuda or val.device.type == "xla" if len(stash) == 0: if self._scale is None: self._lazy_init_scale_growth_tracker(val.device) @@ -214,7 +214,7 @@ def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict # Google says mypy struggles with defaultdicts type annotations. 
- per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated] + per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated] with torch.no_grad(): for group in optimizer.param_groups: for param in group["params"]: @@ -238,8 +238,9 @@ def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): for device, per_dtype_grads in per_device_and_dtype_grads.items(): for grads in per_dtype_grads.values(): - torch._amp_foreach_non_finite_check_and_unscale_(grads, per_device_found_inf.get(device), - per_device_inv_scale.get(device)) + torch._amp_foreach_non_finite_check_and_unscale_( + grads, per_device_found_inf.get(device), per_device_inv_scale.get(device) + ) # For tensor parallel parameters it should be all-reduced over tensor parallel process group if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1: vals = [val for val in per_device_found_inf._per_device_tensors.values()] @@ -328,7 +329,7 @@ def step(self, optimizer, *args, **kwargs): .. warning:: Closure use is not currently supported. """ - if (not self._enabled): + if not self._enabled: return optimizer.step(*args, **kwargs) if "closure" in kwargs: @@ -343,7 +344,7 @@ def step(self, optimizer, *args, **kwargs): retval = None - if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling): + if hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling: # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly. # The contract with custom optimizers is that their step() should accept an additional, # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information: @@ -391,14 +392,14 @@ def update(self, new_scale=None): if new_scale is not None: # Accept a new user-defined scale. if isinstance(new_scale, float): - self._scale.fill_(new_scale) # type: ignore[union-attr] + self._scale.fill_(new_scale) # type: ignore[union-attr] else: reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False." # type: ignore[attr-defined] assert isinstance(new_scale, torch.cuda.FloatTensor), reason assert new_scale.numel() == 1, reason assert new_scale.requires_grad is False, reason - self._scale.copy_(new_scale) # type: ignore[union-attr] + self._scale.copy_(new_scale) # type: ignore[union-attr] else: # Consume shared inf/nan data collected from optimizers to update the scale. # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. @@ -416,11 +417,23 @@ def update(self, new_scale=None): found_inf_combined += found_infs[i] if self._higher_than_torch18: - torch._amp_update_scale_(_scale, _growth_tracker, found_inf_combined, self._growth_factor, - self._backoff_factor, self._growth_interval) + torch._amp_update_scale_( + _scale, + _growth_tracker, + found_inf_combined, + self._growth_factor, + self._backoff_factor, + self._growth_interval, + ) else: - self._scale = torch._amp_update_scale(_growth_tracker, _scale, found_inf_combined, self._growth_factor, - self._backoff_factor, self._growth_interval) + self._scale = torch._amp_update_scale( + _growth_tracker, + _scale, + found_inf_combined, + self._growth_factor, + self._backoff_factor, + self._growth_interval, + ) # To prepare for next iteration, clear the data collected from optimizers this iteration. 
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) @@ -507,13 +520,17 @@ def state_dict(self): If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict` should be called after :meth:`update`. """ - return { - "scale": self.get_scale(), - "growth_factor": self._growth_factor, - "backoff_factor": self._backoff_factor, - "growth_interval": self._growth_interval, - "_growth_tracker": self._get_growth_tracker() - } if self._enabled else {} + return ( + { + "scale": self.get_scale(), + "growth_factor": self._growth_factor, + "backoff_factor": self._backoff_factor, + "growth_interval": self._growth_interval, + "_growth_tracker": self._get_growth_tracker(), + } + if self._enabled + else {} + ) def load_state_dict(self, state_dict): r""" @@ -526,8 +543,10 @@ def load_state_dict(self, state_dict): return if len(state_dict) == 0: - raise RuntimeError("The source state dict is empty, possibly because it was saved " - "from a disabled instance of GradScaler.") + raise RuntimeError( + "The source state dict is empty, possibly because it was saved " + "from a disabled instance of GradScaler." + ) self._init_scale = state_dict["scale"] if self._scale is not None: @@ -542,15 +561,17 @@ def load_state_dict(self, state_dict): def __getstate__(self): state = self.__dict__.copy() if self._enabled: - assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\ - "of an iteration, or at the end after scaler.update()." + assert len(self._per_optimizer_states) == 0, ( + "A GradScaler instance may only be pickled at the beginning " + "of an iteration, or at the end after scaler.update()." + ) # Pickling _scale and _growth_tracker Tensors directly triggers # "warnings.warn("pickle support for Storage will be removed in 1.5..." # so instead, we set the unpickled instance up to reinitialize them lazily. 
- state['_init_scale'] = self.get_scale() - state['_init_growth_tracker'] = self._get_growth_tracker() - state['_scale'] = None - state['_growth_tracker'] = None + state["_init_scale"] = self.get_scale() + state["_init_growth_tracker"] = self._get_growth_tracker() + state["_scale"] = None + state["_growth_tracker"] = None return state def __setstate__(self, state): @@ -562,8 +583,9 @@ def _check_inf_per_device(self, optimizer): dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32, device=_scale.device) found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=_scale.device) - self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \ - self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True) + self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = self._unscale_grads_( + optimizer, dummy_inv_scale, found_inf, True + ) return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] diff --git a/colossalai/legacy/amp/torch_amp/torch_amp.py b/colossalai/legacy/amp/torch_amp/torch_amp.py index c45a5956a205..ced5cc3e6647 100644 --- a/colossalai/legacy/amp/torch_amp/torch_amp.py +++ b/colossalai/legacy/amp/torch_amp/torch_amp.py @@ -42,8 +42,7 @@ def backward(self, loss: Tensor): self.scaler.scale(loss).backward() def step(self): - """Update the parameters of the model - """ + """Update the parameters of the model""" self.scaler.step(self.optim) self.scaler.update() diff --git a/colossalai/legacy/builder/__init__.py b/colossalai/legacy/builder/__init__.py index cf09e1e7a31a..9af3d139d3b1 100644 --- a/colossalai/legacy/builder/__init__.py +++ b/colossalai/legacy/builder/__init__.py @@ -1,3 +1,3 @@ from .builder import build_from_config, build_from_registry, build_gradient_handler -__all__ = ['build_gradient_handler', 'build_from_config', 'build_from_registry'] +__all__ = ["build_gradient_handler", "build_from_config", "build_from_registry"] diff --git a/colossalai/legacy/builder/builder.py b/colossalai/legacy/builder/builder.py index ff14f46dc61f..dec3bc1c2487 100644 --- a/colossalai/legacy/builder/builder.py +++ b/colossalai/legacy/builder/builder.py @@ -19,7 +19,7 @@ def build_from_config(module, config: dict): AssertionError: Raises an AssertionError if `module` is not a class """ - assert inspect.isclass(module), 'module must be a class' + assert inspect.isclass(module), "module must be a class" return module(**config) @@ -45,15 +45,15 @@ def build_from_registry(config, registry: Registry): Raises: Exception: Raises an Exception if an error occurred when building from registry. 
""" - config_ = config.copy() # keep the original config untouched - assert isinstance(registry, Registry), f'Expected type Registry but got {type(registry)}' + config_ = config.copy() # keep the original config untouched + assert isinstance(registry, Registry), f"Expected type Registry but got {type(registry)}" - mod_type = config_.pop('type') - assert registry.has(mod_type), f'{mod_type} is not found in registry {registry.name}' + mod_type = config_.pop("type") + assert registry.has(mod_type), f"{mod_type} is not found in registry {registry.name}" try: obj = registry.get_module(mod_type)(**config_) except Exception as e: - print(f'An error occurred when building {mod_type} from registry {registry.name}', flush=True) + print(f"An error occurred when building {mod_type} from registry {registry.name}", flush=True) raise e return obj @@ -74,6 +74,6 @@ def build_gradient_handler(config, model, optimizer): An object of :class:`colossalai.legacy.engine.BaseGradientHandler` """ config_ = config.copy() - config_['model'] = model - config_['optimizer'] = optimizer + config_["model"] = model + config_["optimizer"] = optimizer return build_from_registry(config_, GRADIENT_HANDLER) diff --git a/colossalai/legacy/communication/__init__.py b/colossalai/legacy/communication/__init__.py index 88ad0487b785..f4492b074425 100644 --- a/colossalai/legacy/communication/__init__.py +++ b/colossalai/legacy/communication/__init__.py @@ -14,21 +14,21 @@ from .utils import recv_obj_meta, send_obj_meta __all__ = [ - 'all_gather', - 'reduce_scatter', - 'all_reduce', - 'broadcast', - 'reduce', - 'send_forward', - 'send_forward_recv_forward', - 'send_forward_backward_recv_forward_backward', - 'send_backward', - 'send_backward_recv_backward', - 'send_backward_recv_forward', - 'send_forward_recv_backward', - 'recv_backward', - 'recv_forward', - 'ring_forward', - 'send_obj_meta', - 'recv_obj_meta', + "all_gather", + "reduce_scatter", + "all_reduce", + "broadcast", + "reduce", + "send_forward", + "send_forward_recv_forward", + "send_forward_backward_recv_forward_backward", + "send_backward", + "send_backward_recv_backward", + "send_backward_recv_forward", + "send_forward_recv_backward", + "recv_backward", + "recv_forward", + "ring_forward", + "send_obj_meta", + "recv_obj_meta", ] diff --git a/colossalai/legacy/communication/collective.py b/colossalai/legacy/communication/collective.py index 7471188226f0..9cf30f733dee 100644 --- a/colossalai/legacy/communication/collective.py +++ b/colossalai/legacy/communication/collective.py @@ -9,10 +9,10 @@ from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc -_all_gather_func = dist._all_gather_base \ - if "all_gather_into_tensor" not in dir(dist) else dist.all_gather_into_tensor -_reduce_scatter_func = dist._reduce_scatter_base \ - if "reduce_scatter_tensor" not in dir(dist) else dist.reduce_scatter_tensor +_all_gather_func = dist._all_gather_base if "all_gather_into_tensor" not in dir(dist) else dist.all_gather_into_tensor +_reduce_scatter_func = ( + dist._reduce_scatter_base if "reduce_scatter_tensor" not in dir(dist) else dist.reduce_scatter_tensor +) def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor: @@ -50,11 +50,9 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: return out -def reduce_scatter(tensor: Tensor, - dim: int, - parallel_mode: ParallelMode, - op: ReduceOp = ReduceOp.SUM, - async_op: bool = False) -> Tensor: +def reduce_scatter( + 
tensor: Tensor, dim: int, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False +) -> Tensor: r"""Reduces all tensors then scatters it in a specific dimension to all members in the parallel group. @@ -93,10 +91,9 @@ def reduce_scatter(tensor: Tensor, return out -def all_reduce(tensor: Tensor, - parallel_mode: ParallelMode, - op: ReduceOp = ReduceOp.SUM, - async_op: bool = False) -> Tensor: +def all_reduce( + tensor: Tensor, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False +) -> Tensor: r"""Reduces the tensor data across whole parallel group in such a way that all get the final result. Note: @@ -201,16 +198,17 @@ def scatter_object_list(scatter_object_output_list, scatter_object_input_list, s if dist.distributed_c10d._rank_not_in_group(group): return - if (not isinstance(scatter_object_output_list, list) or len(scatter_object_output_list) < 1): + if not isinstance(scatter_object_output_list, list) or len(scatter_object_output_list) < 1: raise RuntimeError("Expected argument scatter_object_output_list to be a list of size at least 1.") # set tensor device to cuda if backend is nccl - device = torch.cuda.current_device() if dist.get_backend(group) == 'nccl' else torch.device("cpu") + device = torch.cuda.current_device() if dist.get_backend(group) == "nccl" else torch.device("cpu") - my_rank = dist.get_rank() # use global rank + my_rank = dist.get_rank() # use global rank if my_rank == src: tensor_list, tensor_sizes = zip( - *[dist.distributed_c10d._object_to_tensor(obj) for obj in scatter_object_input_list]) + *[dist.distributed_c10d._object_to_tensor(obj) for obj in scatter_object_input_list] + ) tensor_list = list(map(lambda x: x.to(device), tensor_list)) tensor_sizes = list(map(lambda x: x.to(device), tensor_sizes)) diff --git a/colossalai/legacy/communication/p2p.py b/colossalai/legacy/communication/p2p.py index e3f9108ab840..19c3919b6e29 100644 --- a/colossalai/legacy/communication/p2p.py +++ b/colossalai/legacy/communication/p2p.py @@ -82,16 +82,18 @@ def filling_ops_queue(obj, comm_op, comm_rank, ops_queue): ops_queue.append(op_to_add) -def _communicate(object_send_next: Union[torch.Tensor, List[torch.Tensor]] = None, - object_send_prev: Union[torch.Tensor, List[torch.Tensor]] = None, - recv_prev: bool = False, - recv_next: bool = False, - recv_prev_shape: Union[torch.Size, List[torch.Size]] = None, - recv_next_shape: Union[torch.Size, List[torch.Size]] = None, - prev_rank: int = None, - next_rank: int = None, - dtype: torch.dtype = None, - scatter_gather_tensors: bool = False) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]: +def _communicate( + object_send_next: Union[torch.Tensor, List[torch.Tensor]] = None, + object_send_prev: Union[torch.Tensor, List[torch.Tensor]] = None, + recv_prev: bool = False, + recv_next: bool = False, + recv_prev_shape: Union[torch.Size, List[torch.Size]] = None, + recv_next_shape: Union[torch.Size, List[torch.Size]] = None, + prev_rank: int = None, + next_rank: int = None, + dtype: torch.dtype = None, + scatter_gather_tensors: bool = False, +) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]: """ Adapted from megatron.p2p_communication. Communicate tensors between stages. 
Used as helper method in other @@ -123,13 +125,15 @@ def _communicate(object_send_next: Union[torch.Tensor, List[torch.Tensor]] = Non if recv_prev: assert recv_prev_shape is not None - tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes(recv_prev_shape, dtype, - scatter_gather_tensors) + tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes( + recv_prev_shape, dtype, scatter_gather_tensors + ) if recv_next: assert recv_next_shape is not None - tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes(recv_next_shape, dtype, - scatter_gather_tensors) + tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes( + recv_next_shape, dtype, scatter_gather_tensors + ) if object_send_prev is not None or recv_prev: if prev_rank is None: @@ -170,24 +174,25 @@ def _communicate(object_send_next: Union[torch.Tensor, List[torch.Tensor]] = Non tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_() else: for index in range(len(tensor_recv_prev)): - tensor_recv_prev[index] = gather_split_1d_tensor(tensor_recv_prev[index]).view( - recv_prev_shape[index]).requires_grad_() + tensor_recv_prev[index] = ( + gather_split_1d_tensor(tensor_recv_prev[index]).view(recv_prev_shape[index]).requires_grad_() + ) if recv_next and recv_next_split: if isinstance(tensor_recv_next, torch.Tensor): tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_() else: for index in range(len(tensor_recv_next)): - tensor_recv_next[index] = gather_split_1d_tensor(tensor_recv_next[index]).view( - recv_next_shape[index]).requires_grad_() + tensor_recv_next[index] = ( + gather_split_1d_tensor(tensor_recv_next[index]).view(recv_next_shape[index]).requires_grad_() + ) return tensor_recv_prev, tensor_recv_next -def recv_forward(input_tensor_shape, - prev_rank=None, - dtype=torch.float, - scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: +def recv_forward( + input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False +) -> Union[torch.Tensor, List[torch.Tensor]]: """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. Args: @@ -200,18 +205,19 @@ def recv_forward(input_tensor_shape, if gpc.is_pipeline_first_stage(): input_tensor = None else: - input_tensor, _ = _communicate(recv_prev=True, - recv_prev_shape=input_tensor_shape, - prev_rank=prev_rank, - dtype=dtype, - scatter_gather_tensors=scatter_gather_tensors) + input_tensor, _ = _communicate( + recv_prev=True, + recv_prev_shape=input_tensor_shape, + prev_rank=prev_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors, + ) return input_tensor -def recv_backward(output_grad_shape, - next_rank=None, - dtype=torch.float, - scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: +def recv_backward( + output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False +) -> Union[torch.Tensor, List[torch.Tensor]]: """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. 
Args: @@ -224,11 +230,13 @@ def recv_backward(output_grad_shape, if gpc.is_pipeline_last_stage(): output_tensor_grad = None else: - _, output_tensor_grad = _communicate(recv_next=True, - recv_next_shape=output_grad_shape, - next_rank=next_rank, - dtype=dtype, - scatter_gather_tensors=scatter_gather_tensors) + _, output_tensor_grad = _communicate( + recv_next=True, + recv_next_shape=output_grad_shape, + next_rank=next_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors, + ) return output_tensor_grad @@ -251,17 +259,14 @@ def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=Fals prev_rank (int, optional): The rank of the recipient of the tensor """ if not gpc.is_pipeline_first_stage(): - _communicate(object_send_prev=input_tensor_grad, - prev_rank=prev_rank, - scatter_gather_tensors=scatter_gather_tensors) + _communicate( + object_send_prev=input_tensor_grad, prev_rank=prev_rank, scatter_gather_tensors=scatter_gather_tensors + ) -def send_forward_recv_backward(output_tensor, - output_grad_shape, - recv_next=True, - next_rank=None, - dtype=torch.float, - scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: +def send_forward_recv_backward( + output_tensor, output_grad_shape, recv_next=True, next_rank=None, dtype=torch.float, scatter_gather_tensors=False +) -> Union[torch.Tensor, List[torch.Tensor]]: """Batched communication operation. Sends the input tensor to the next stage in pipeline, while receives the gradient tensor from the next stage in pipeline as the input gradient tensor of this stage. @@ -276,21 +281,25 @@ def send_forward_recv_backward(output_tensor, if gpc.is_pipeline_last_stage(): output_tensor_grad = None else: - _, output_tensor_grad = _communicate(object_send_next=output_tensor, - recv_next=recv_next, - recv_next_shape=output_grad_shape, - next_rank=next_rank, - dtype=dtype, - scatter_gather_tensors=scatter_gather_tensors) + _, output_tensor_grad = _communicate( + object_send_next=output_tensor, + recv_next=recv_next, + recv_next_shape=output_grad_shape, + next_rank=next_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors, + ) return output_tensor_grad -def send_backward_recv_forward(input_tensor_grad, - input_tensor_shape, - recv_prev=True, - prev_rank=None, - dtype=torch.float, - scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: +def send_backward_recv_forward( + input_tensor_grad, + input_tensor_shape, + recv_prev=True, + prev_rank=None, + dtype=torch.float, + scatter_gather_tensors=False, +) -> Union[torch.Tensor, List[torch.Tensor]]: """Batched communication operation. Sends the gradient tensor to the previous stage in pipeline, while receives the output tensor from the previous stage in pipeline as the input of this stage. 
@@ -305,22 +314,26 @@ def send_backward_recv_forward(input_tensor_grad, if gpc.is_pipeline_first_stage(): input_tensor = None else: - input_tensor, _ = _communicate(object_send_prev=input_tensor_grad, - recv_prev=recv_prev, - recv_prev_shape=input_tensor_shape, - prev_rank=prev_rank, - dtype=dtype, - scatter_gather_tensors=scatter_gather_tensors) + input_tensor, _ = _communicate( + object_send_prev=input_tensor_grad, + recv_prev=recv_prev, + recv_prev_shape=input_tensor_shape, + prev_rank=prev_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors, + ) return input_tensor -def send_forward_recv_forward(output_tensor, - input_tensor_shape, - recv_prev=True, - prev_rank=None, - next_rank=None, - dtype=torch.float, - scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: +def send_forward_recv_forward( + output_tensor, + input_tensor_shape, + recv_prev=True, + prev_rank=None, + next_rank=None, + dtype=torch.float, + scatter_gather_tensors=False, +) -> Union[torch.Tensor, List[torch.Tensor]]: """Batched communication operation. Sends the input tensor to the next stage in pipeline, while receives the output tensor from the previous stage in pipeline as the input of this stage. @@ -332,23 +345,27 @@ def send_forward_recv_forward(output_tensor, Returns: Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor. """ - input_tensor, _ = _communicate(object_send_next=output_tensor, - recv_prev=recv_prev, - recv_prev_shape=input_tensor_shape, - prev_rank=prev_rank, - next_rank=next_rank, - dtype=dtype, - scatter_gather_tensors=scatter_gather_tensors) + input_tensor, _ = _communicate( + object_send_next=output_tensor, + recv_prev=recv_prev, + recv_prev_shape=input_tensor_shape, + prev_rank=prev_rank, + next_rank=next_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors, + ) return input_tensor -def send_backward_recv_backward(input_tensor_grad, - output_grad_shape, - recv_next=True, - prev_rank=None, - next_rank=None, - dtype=torch.float, - scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: +def send_backward_recv_backward( + input_tensor_grad, + output_grad_shape, + recv_next=True, + prev_rank=None, + next_rank=None, + dtype=torch.float, + scatter_gather_tensors=False, +) -> Union[torch.Tensor, List[torch.Tensor]]: """Batched communication operation. Sends the gradient tensor to the previous stage in pipeline, while receives the gradient tensor from the next member in pipeline as the input of this stage. @@ -360,27 +377,30 @@ def send_backward_recv_backward(input_tensor_grad, Returns: Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor. 
""" - _, output_tensor_grad = _communicate(object_send_prev=input_tensor_grad, - recv_next=recv_next, - recv_next_shape=output_grad_shape, - prev_rank=prev_rank, - next_rank=next_rank, - dtype=dtype, - scatter_gather_tensors=scatter_gather_tensors) + _, output_tensor_grad = _communicate( + object_send_prev=input_tensor_grad, + recv_next=recv_next, + recv_next_shape=output_grad_shape, + prev_rank=prev_rank, + next_rank=next_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors, + ) return output_tensor_grad def send_forward_backward_recv_forward_backward( - output_tensor, - input_tensor_grad, - input_tensor_shape, - output_grad_shape, - recv_prev=True, - recv_next=True, - prev_rank=None, - next_rank=None, - dtype=torch.float, - scatter_gather_tensors=False) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]: + output_tensor, + input_tensor_grad, + input_tensor_shape, + output_grad_shape, + recv_prev=True, + recv_next=True, + prev_rank=None, + next_rank=None, + dtype=torch.float, + scatter_gather_tensors=False, +) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]: """Batched communication operation. Sends the input tensor to the next stage in pipeline and the gradient tensor to the previous stage, while receives the input gradient tensor from the next stage and the input tensor from the previous stage. @@ -394,14 +414,16 @@ def send_forward_backward_recv_forward_backward( Returns: Tuple(Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]], Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): (the input tensor, the input gradient tensor) """ - input_tensor, output_tensor_grad = _communicate(object_send_next=output_tensor, - object_send_prev=input_tensor_grad, - recv_prev=recv_prev, - recv_next=recv_next, - recv_prev_shape=input_tensor_shape, - recv_next_shape=output_grad_shape, - prev_rank=prev_rank, - next_rank=next_rank, - dtype=dtype, - scatter_gather_tensors=scatter_gather_tensors) + input_tensor, output_tensor_grad = _communicate( + object_send_next=output_tensor, + object_send_prev=input_tensor_grad, + recv_prev=recv_prev, + recv_next=recv_next, + recv_prev_shape=input_tensor_shape, + recv_next_shape=output_grad_shape, + prev_rank=prev_rank, + next_rank=next_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors, + ) return input_tensor, output_tensor_grad diff --git a/colossalai/legacy/communication/p2p_v2.py b/colossalai/legacy/communication/p2p_v2.py index 66af214950f2..7c8d8bede069 100644 --- a/colossalai/legacy/communication/p2p_v2.py +++ b/colossalai/legacy/communication/p2p_v2.py @@ -62,10 +62,10 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) - Any: object after unpickled """ buf = tensor.numpy().tobytes()[:tensor_size] - if b'cuda' in buf: + if b"cuda" in buf: buf_array = bytearray(buf) device_index = torch.cuda.current_device() - buf_array[buf_array.find(b'cuda') + 5] = 48 + device_index + buf_array[buf_array.find(b"cuda") + 5] = 48 + device_index buf = bytes(buf_array) io_bytes = io.BytesIO(buf) @@ -123,8 +123,8 @@ def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=No if local_rank == src: object_tensor = torch.cat(tensor_list) else: - object_tensor = torch.empty( # type: ignore[call-overload] - torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] + object_tensor = torch.empty( # type: ignore[call-overload] + torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] dtype=torch.uint8, ) @@ -138,7 +138,7 @@ def _broadcast_object_list(object_list: List[Any], 
src: int, dst: int, device=No if local_rank != src: for i, obj_size in enumerate(object_sizes_tensor): - obj_view = object_tensor[offset:offset + obj_size] + obj_view = object_tensor[offset : offset + obj_size] obj_view = obj_view.type(torch.uint8) if obj_view.device != torch.device("cpu"): obj_view = obj_view.cpu() @@ -147,8 +147,10 @@ def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=No unpickle_object = _cuda_safe_tensor_to_object(obj_view, obj_size) # unconsistence in device - if isinstance(unpickle_object, - torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device(): + if ( + isinstance(unpickle_object, torch.Tensor) + and unpickle_object.device.index != torch.cuda.current_device() + ): unpickle_object = unpickle_object.cuda() object_list[i] = unpickle_object diff --git a/colossalai/legacy/communication/ring.py b/colossalai/legacy/communication/ring.py index e80192fb578d..a61dae56cd42 100644 --- a/colossalai/legacy/communication/ring.py +++ b/colossalai/legacy/communication/ring.py @@ -28,19 +28,20 @@ def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> ops = [] current_rank = gpc.get_global_rank() - tensor_recv_prev = torch.empty(buffer_shape, - requires_grad=True, - device=get_current_device(), - dtype=tensor_send_next.dtype) + tensor_recv_prev = torch.empty( + buffer_shape, requires_grad=True, device=get_current_device(), dtype=tensor_send_next.dtype + ) # send to next rank - send_next_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_next, - gpc.get_next_global_rank(parallel_mode)) + send_next_op = torch.distributed.P2POp( + torch.distributed.isend, tensor_send_next, gpc.get_next_global_rank(parallel_mode) + ) ops.append(send_next_op) # receive from prev rank - recv_prev_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_prev, - gpc.get_prev_global_rank(parallel_mode)) + recv_prev_op = torch.distributed.P2POp( + torch.distributed.irecv, tensor_recv_prev, gpc.get_prev_global_rank(parallel_mode) + ) ops.append(recv_prev_op) if current_rank % 2 == 0: diff --git a/colossalai/legacy/communication/utils.py b/colossalai/legacy/communication/utils.py index 7e3dcf1e9820..6d77f3753fe8 100644 --- a/colossalai/legacy/communication/utils.py +++ b/colossalai/legacy/communication/utils.py @@ -35,7 +35,7 @@ def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool: if next_rank is None: next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) - tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()} + tensor_kwargs = {"dtype": torch.long, "device": get_current_device()} if isinstance(obj, torch.Tensor): send_obj_nums = torch.tensor(1, **tensor_kwargs) dist.send(send_obj_nums, next_rank) @@ -74,7 +74,7 @@ def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size: if prev_rank is None: prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) - tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()} + tensor_kwargs = {"dtype": torch.long, "device": get_current_device()} recv_obj_nums = torch.empty((), **tensor_kwargs) dist.recv(recv_obj_nums, prev_rank) if recv_obj_nums.item() == 1: @@ -122,6 +122,6 @@ def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor: numel = torch.numel(tensor) numel_gathered = world_size * numel gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False) - chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)] + chunks = [gathered[i * numel : (i + 
1) * numel] for i in range(world_size)] dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.PARALLEL_1D)) return gathered diff --git a/colossalai/legacy/constants.py b/colossalai/legacy/constants.py index 6cf9085f9fbb..5d64b676e73d 100644 --- a/colossalai/legacy/constants.py +++ b/colossalai/legacy/constants.py @@ -1,32 +1,32 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -ALLOWED_MODES = [None, '1d', '2d', '2.5d', '3d', 'sequence'] -TENSOR_PARALLEL_MODE = 'tensor_parallel_mode' +ALLOWED_MODES = [None, "1d", "2d", "2.5d", "3d", "sequence"] +TENSOR_PARALLEL_MODE = "tensor_parallel_mode" # initializer INITIALIZER_MAPPING = { - 'data': 'Initializer_Data', - 'tensor': 'Initializer_Tensor', - 'pipeline': 'Initializer_Pipeline', - 'embedding': 'Initializer_Embedding', - '1d': 'Initializer_1D', - '2d': 'Initializer_2D', - '2.5d': 'Initializer_2p5D', - '3d': 'Initializer_3D', - 'sequence': 'Initializer_Sequence', - 'model': 'Initializer_Model', - 'moe': 'Initializer_Moe' + "data": "Initializer_Data", + "tensor": "Initializer_Tensor", + "pipeline": "Initializer_Pipeline", + "embedding": "Initializer_Embedding", + "1d": "Initializer_1D", + "2d": "Initializer_2D", + "2.5d": "Initializer_2p5D", + "3d": "Initializer_3D", + "sequence": "Initializer_Sequence", + "model": "Initializer_Model", + "moe": "Initializer_Moe", } # 3D parallelism groups -INPUT_GROUP_3D = 'input_group_3d' -WEIGHT_GROUP_3D = 'weight_group_3d' -OUTPUT_GROUP_3D = 'output_group_3d' -INPUT_X_WEIGHT_3D = 'input_x_weight_group_3d' -OUTPUT_X_WEIGHT_3D = 'output_x_weight_group_3d' +INPUT_GROUP_3D = "input_group_3d" +WEIGHT_GROUP_3D = "weight_group_3d" +OUTPUT_GROUP_3D = "output_group_3d" +INPUT_X_WEIGHT_3D = "input_x_weight_group_3d" +OUTPUT_X_WEIGHT_3D = "output_x_weight_group_3d" # Attributes of tensor parallel parameters -IS_TENSOR_PARALLEL = 'is_tensor_parallel' -NUM_PARTITIONS = 'num_partitions' +IS_TENSOR_PARALLEL = "is_tensor_parallel" +NUM_PARTITIONS = "num_partitions" TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL, NUM_PARTITIONS] diff --git a/colossalai/legacy/context/parallel_context.py b/colossalai/legacy/context/parallel_context.py index 8fdc3d6fea68..48bf8ab279e8 100644 --- a/colossalai/legacy/context/parallel_context.py +++ b/colossalai/legacy/context/parallel_context.py @@ -4,7 +4,6 @@ import random import socket from collections import Counter -from threading import local from typing import Union import numpy as np @@ -95,8 +94,9 @@ def detect_num_processes_on_current_node(self): @staticmethod def _check_parallel_mode(parallel_mode: ParallelMode): - assert isinstance(parallel_mode, ParallelMode), \ - f'expected the argument parallel_mode to be of enum ParallelMode, but got {type(parallel_mode)}' + assert isinstance( + parallel_mode, ParallelMode + ), f"expected the argument parallel_mode to be of enum ParallelMode, but got {type(parallel_mode)}" def get_global_rank(self): """Returns the global rank of the current device. 
@@ -239,8 +239,10 @@ def is_pipeline_first_stage(self, ignore_virtual=False): def is_pipeline_last_stage(self, ignore_virtual=False): if not ignore_virtual: - if self.virtual_pipeline_parallel_size \ - is not None and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1: + if ( + self.virtual_pipeline_parallel_size is not None + and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1 + ): return False return self.is_last_rank(ParallelMode.PIPELINE) @@ -371,12 +373,12 @@ def init_global_dist(self, rank: int, world_size: int, backend: str, host: str, port (str): the master port for distributed training """ # initialize the default process group - init_method = f'tcp://[{host}]:{port}' + init_method = f"tcp://[{host}]:{port}" dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method) # None will give the default global process group for pytorch dist operations ranks = list(range(world_size)) - cpu_group = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else None + cpu_group = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else None self._register_dist(rank, world_size, dist.GroupMember.WORLD, cpu_group, ranks, ParallelMode.GLOBAL) self.add_global_rank(ParallelMode.GLOBAL, rank) @@ -398,10 +400,11 @@ def check_sanity(self): pps = self.pipeline_parallel_size tps = self.tensor_parallel_size ws = self.world_size - assert ws == dps * pps * \ - tps, f"Expected the world size {ws} to be equal to data" \ - f" parallel size ({dps}) * pipeline parallel size " \ - f"({pps}) * tensor parallel size ({tps})" + assert ws == dps * pps * tps, ( + f"Expected the world size {ws} to be equal to data" + f" parallel size ({dps}) * pipeline parallel size " + f"({pps}) * tensor parallel size ({tps})" + ) def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str): if key in config: @@ -409,10 +412,11 @@ def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str) if isinstance(ele, int): setattr(self, attr_name, ele) elif isinstance(ele, dict): - setattr(self, attr_name, ele['size']) + setattr(self, attr_name, ele["size"]) else: raise NotImplementedError( - f'{"Parallel configuration does not support this kind of argument, please use int or dict"}') + f'{"Parallel configuration does not support this kind of argument, please use int or dict"}' + ) def init_parallel_groups(self): """Initializes the parallel groups. 
@@ -427,10 +431,10 @@ def init_parallel_groups(self): self.world_size = world_size # set parallel size as attributes for global context - parallel_config = self.config.get('parallel', None) + parallel_config = self.config.get("parallel", None) if parallel_config is not None: - self._set_parallel_size_from_config(parallel_config, 'pipeline', 'pipeline_parallel_size') - self._set_parallel_size_from_config(parallel_config, 'tensor', 'tensor_parallel_size') + self._set_parallel_size_from_config(parallel_config, "pipeline", "pipeline_parallel_size") + self._set_parallel_size_from_config(parallel_config, "tensor", "tensor_parallel_size") # the user should not set the data parallel size manually # instead, it should be calculated based on other parallel config @@ -438,33 +442,33 @@ def init_parallel_groups(self): # get the tensor parallel mode and check tensor_parallel_mode = None - if parallel_config is not None and 'tensor' in \ - parallel_config and 'mode' in parallel_config['tensor']: - tensor_parallel_mode = parallel_config['tensor']['mode'] - assert tensor_parallel_mode in ALLOWED_MODES, \ - f"mode in the parallel config must be set to one of {ALLOWED_MODES}" + if parallel_config is not None and "tensor" in parallel_config and "mode" in parallel_config["tensor"]: + tensor_parallel_mode = parallel_config["tensor"]["mode"] + assert ( + tensor_parallel_mode in ALLOWED_MODES + ), f"mode in the parallel config must be set to one of {ALLOWED_MODES}" env.mode = tensor_parallel_mode self.check_sanity() pg_init = [] # LSG: init data parallel process group for compatibility with other parallel module such as zero - pg_init.append(dict(type=INITIALIZER_MAPPING['data'])) + pg_init.append(dict(type=INITIALIZER_MAPPING["data"])) # LSG: init model parallel process group for compatibility with amp and clip grad - pg_init.append(dict(type=INITIALIZER_MAPPING['model'])) + pg_init.append(dict(type=INITIALIZER_MAPPING["model"])) if self.pipeline_parallel_size > 1: - pg_init.append(dict(type=INITIALIZER_MAPPING['pipeline'])) - pg_init.append(dict(type=INITIALIZER_MAPPING['tensor'])) + pg_init.append(dict(type=INITIALIZER_MAPPING["pipeline"])) + pg_init.append(dict(type=INITIALIZER_MAPPING["tensor"])) # init specific tensor parallel group if tensor_parallel_mode is not None: - tensor_parallel_cfg = parallel_config['tensor'].copy() + tensor_parallel_cfg = parallel_config["tensor"].copy() # remove duplicate parameters - tensor_parallel_cfg.pop('mode') - tensor_parallel_cfg.pop('size') + tensor_parallel_cfg.pop("mode") + tensor_parallel_cfg.pop("size") # add this config to initialize later pg_init.append(dict(type=INITIALIZER_MAPPING[tensor_parallel_mode.lower()], **tensor_parallel_cfg)) @@ -472,11 +476,16 @@ def init_parallel_groups(self): # run initialization of different process groups for initializer_cfg in pg_init: cfg = initializer_cfg.copy() - initializer_type = cfg.pop('type') - initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)(rank, world_size, self.config, - self.data_parallel_size, - self.pipeline_parallel_size, - self.tensor_parallel_size, **cfg) + initializer_type = cfg.pop("type") + initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)( + rank, + world_size, + self.config, + self.data_parallel_size, + self.pipeline_parallel_size, + self.tensor_parallel_size, + **cfg, + ) parallel_setting = initializer.init_dist_group() if isinstance(parallel_setting, list): for args in parallel_setting: @@ -497,8 +506,7 @@ def is_initialized(self, parallel_mode: ParallelMode): return 
parallel_mode in self._groups def destroy(self): - """Destroys the current distributed parallel environment. - """ + """Destroys the current distributed parallel environment.""" for mode, group in self._groups.items(): if mode is not ParallelMode.GLOBAL: dist.destroy_process_group(group) @@ -519,7 +527,7 @@ def set_device(self, device_ordinal: int = None): torch.cuda.set_device(device_ordinal) if self._verbose: - self._logger.info(f'process rank {global_rank} is bound to device {device_ordinal}') + self._logger.info(f"process rank {global_rank} is bound to device {device_ordinal}") def set_seed(self, seed: int): """Sets seeds for all random libraries. @@ -552,21 +560,25 @@ def set_seed(self, seed: int): set_mode(ParallelMode.DATA) seeds = get_seeds() - seed_str = ', '.join([f'{k}: {v}' for k, v in seeds.items()]) + seed_str = ", ".join([f"{k}: {v}" for k, v in seeds.items()]) if self._verbose: - self._logger.info(f"initialized seed on rank {global_rank}, " - f"numpy: {seed}, python random: {seed}, {seed_str}," - f"the default parallel seed is {ParallelMode.DATA}.") + self._logger.info( + f"initialized seed on rank {global_rank}, " + f"numpy: {seed}, python random: {seed}, {seed_str}," + f"the default parallel seed is {ParallelMode.DATA}." + ) else: if self._verbose: self._logger.info( f"initialized seed on rank {global_rank}, " f"numpy: {seed}, python random: {seed}, pytorch: {seed}", - ranks=[0]) + ranks=[0], + ) self._logger.info( - 'WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states', - ranks=[0]) + "WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states", + ranks=[0], + ) def set_virtual_pipeline_parallel_size(self, size): self.virtual_pipeline_parallel_size = size diff --git a/colossalai/legacy/context/parallel_mode.py b/colossalai/legacy/context/parallel_mode.py index 1cf6fa53dc1e..ceb52ff20da7 100644 --- a/colossalai/legacy/context/parallel_mode.py +++ b/colossalai/legacy/context/parallel_mode.py @@ -6,44 +6,43 @@ # parallel modes class ParallelMode(Enum): - """This is an enumeration class containing all possible parallel modes. 
- """ + """This is an enumeration class containing all possible parallel modes.""" - GLOBAL = 'global' + GLOBAL = "global" # common parallel - DATA = 'data' + DATA = "data" # model parallel - containing tensor and pipeline parallel groups # this is added to facilitate amp and grad clipping in hybrid parallel - MODEL = 'model' + MODEL = "model" # pipeline parallel - PIPELINE = 'pipe' + PIPELINE = "pipe" # containing all ranks in tensor parallel - TENSOR = 'tensor' + TENSOR = "tensor" # sequence parallel - SEQUENCE = 'sequence' - SEQUENCE_DP = 'sequence_dp' + SEQUENCE = "sequence" + SEQUENCE_DP = "sequence_dp" # 1D Parallel - PARALLEL_1D = '1d' + PARALLEL_1D = "1d" # 2D parallel - PARALLEL_2D_ROW = '2d_row' - PARALLEL_2D_COL = '2d_col' + PARALLEL_2D_ROW = "2d_row" + PARALLEL_2D_COL = "2d_col" # 3D parallel - PARALLEL_3D_INPUT = '3d_input' - PARALLEL_3D_WEIGHT = '3d_weight' - PARALLEL_3D_OUTPUT = '3d_output' + PARALLEL_3D_INPUT = "3d_input" + PARALLEL_3D_WEIGHT = "3d_weight" + PARALLEL_3D_OUTPUT = "3d_output" PARALLEL_3D_INPUT_X_WEIGHT = "3d_input_x_weight" PARALLEL_3D_OUTPUT_X_WEIGHT = "3d_output_x_weight" # 2.5D parallel - PARALLEL_2P5D_ROW = '2p5d_row' - PARALLEL_2P5D_COL = '2p5d_col' - PARALLEL_2P5D_DEP = '2p5d_dep' - PARALLEL_2P5D_XZ = '2p5d_xz' + PARALLEL_2P5D_ROW = "2p5d_row" + PARALLEL_2P5D_COL = "2p5d_col" + PARALLEL_2P5D_DEP = "2p5d_dep" + PARALLEL_2P5D_XZ = "2p5d_xz" diff --git a/colossalai/legacy/context/process_group_initializer/__init__.py b/colossalai/legacy/context/process_group_initializer/__init__.py index 48d52d7b9e52..a83165e40a8f 100644 --- a/colossalai/legacy/context/process_group_initializer/__init__.py +++ b/colossalai/legacy/context/process_group_initializer/__init__.py @@ -10,6 +10,14 @@ from .process_group_initializer import ProcessGroupInitializer __all__ = [ - 'Initializer_Tensor', 'Initializer_Sequence', 'Initializer_Pipeline', 'Initializer_Data', 'Initializer_2p5D', - 'Initializer_2D', 'Initializer_3D', 'Initializer_1D', 'ProcessGroupInitializer', 'Initializer_Model' + "Initializer_Tensor", + "Initializer_Sequence", + "Initializer_Pipeline", + "Initializer_Data", + "Initializer_2p5D", + "Initializer_2D", + "Initializer_3D", + "Initializer_1D", + "ProcessGroupInitializer", + "Initializer_Model", ] diff --git a/colossalai/legacy/context/process_group_initializer/initializer_1d.py b/colossalai/legacy/context/process_group_initializer/initializer_1d.py index d853c6f06fc0..110a42cf880e 100644 --- a/colossalai/legacy/context/process_group_initializer/initializer_1d.py +++ b/colossalai/legacy/context/process_group_initializer/initializer_1d.py @@ -45,7 +45,7 @@ def init_dist_group(self): for i in range(self.num_group): ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) diff --git a/colossalai/legacy/context/process_group_initializer/initializer_2d.py b/colossalai/legacy/context/process_group_initializer/initializer_2d.py index 39f6a46890b6..1c08d4d4296a 100644 --- a/colossalai/legacy/context/process_group_initializer/initializer_2d.py +++ b/colossalai/legacy/context/process_group_initializer/initializer_2d.py @@ -14,9 +14,10 @@ def _check_summa_env_var(summa_dim): env_summa_dim = env.summa_dim if env_summa_dim: - assert int(env_summa_dim) == summa_dim, \ - 
'SUMMA_DIM has been set in the current environment and ' \ - 'does not match with the value passed to this initialized' + assert int(env_summa_dim) == summa_dim, ( + "SUMMA_DIM has been set in the current environment and " + "does not match with the value passed to this initialized" + ) else: env.summa_dim = summa_dim @@ -57,7 +58,7 @@ def init_dist_group(self): for j in range(self.summa_dim): ranks = [i * self.tensor_parallel_size + j * self.summa_dim + k for k in range(self.summa_dim)] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -106,7 +107,7 @@ def init_dist_group(self): for j in range(self.summa_dim): ranks = [i * self.tensor_parallel_size + j + k * self.summa_dim for k in range(self.summa_dim)] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -137,8 +138,9 @@ def __init__(self, *args, **kwargs): self.num_group = self.world_size // self.tensor_parallel_size self.summa_dim = int(math.sqrt(self.tensor_parallel_size)) - assert self.tensor_parallel_size == self.summa_dim ** 2, \ - "2D summa dim should equal to tensor parallel size ^ 0.5" + assert ( + self.tensor_parallel_size == self.summa_dim**2 + ), "2D summa dim should equal to tensor parallel size ^ 0.5" _check_summa_env_var(self.summa_dim) self.col_initializer = Initializer_2D_Col(self.num_group, self.summa_dim, *args, **kwargs) diff --git a/colossalai/legacy/context/process_group_initializer/initializer_2p5d.py b/colossalai/legacy/context/process_group_initializer/initializer_2p5d.py index bb7a3509572f..b7d71b96334d 100644 --- a/colossalai/legacy/context/process_group_initializer/initializer_2p5d.py +++ b/colossalai/legacy/context/process_group_initializer/initializer_2p5d.py @@ -19,12 +19,14 @@ def _check_tesseract_env_var(tesseract_dim: int, tesseract_dep: int): env_tesseract_dep = env.tesseract_dep if env_tesseract_dim and env_tesseract_dep: - assert int(env_tesseract_dim) == tesseract_dim, \ - 'TESSERACT_DIM has been set in the current environment and ' \ - 'does not match with the value passed to this initialized' - assert int(env_tesseract_dep) == tesseract_dep, \ - 'TESSERACT_DEP has been set in the current environment and ' \ - 'does not match with the value passed to this initialized' + assert int(env_tesseract_dim) == tesseract_dim, ( + "TESSERACT_DIM has been set in the current environment and " + "does not match with the value passed to this initialized" + ) + assert int(env_tesseract_dep) == tesseract_dep, ( + "TESSERACT_DEP has been set in the current environment and " + "does not match with the value passed to this initialized" + ) else: env.tesseract_dim = tesseract_dim env.tesseract_dep = tesseract_dep @@ -50,8 +52,9 @@ def __init__(self, tesseract_dim: int, tesseract_dep: int, *args): self.num_group = self.world_size // self.tensor_parallel_size self.tesseract_dep = tesseract_dep self.tesseract_dim = tesseract_dim - assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \ - "Tensor parallel size should be depth * dim ** 2 in 2.5D parallel" + assert ( + self.tensor_parallel_size == self.tesseract_dim**2 * self.tesseract_dep + ), 
"Tensor parallel size should be depth * dim ** 2 in 2.5D parallel" def init_dist_group(self): """Initialize 2.5D tensor row parallel groups, and assign local_ranks and groups to each gpu. @@ -75,7 +78,7 @@ def init_dist_group(self): for i in range(self.tesseract_dim) ] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -129,7 +132,7 @@ def init_dist_group(self): for j in range(self.tesseract_dim) ] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -183,7 +186,7 @@ def init_dist_group(self): for k in range(self.tesseract_dep) ] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -238,7 +241,7 @@ def init_dist_group(self): for j in range(self.tesseract_dim) ] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -265,16 +268,25 @@ class Initializer_2p5D(ProcessGroupInitializer): depth (int): The depth of 2.5d parallel. """ - def __init__(self, rank: int, world_size: int, config: Config, data_parallel_size: int, pipeline_parallel_size: int, - tensor_parallel_size: int, depth: int): + def __init__( + self, + rank: int, + world_size: int, + config: Config, + data_parallel_size: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + depth: int, + ): args = (rank, world_size, config, data_parallel_size, pipeline_parallel_size, tensor_parallel_size) super().__init__(*args) self.num_group = self.world_size // self.tensor_parallel_size self.tesseract_dim = int(math.sqrt(self.tensor_parallel_size / depth)) self.tesseract_dep = depth - assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \ - "2.5D tesseract dim should equal to (tensor parallel size / tesseract dep) ^ 0.5" + assert ( + self.tensor_parallel_size == self.tesseract_dim**2 * self.tesseract_dep + ), "2.5D tesseract dim should equal to (tensor parallel size / tesseract dep) ^ 0.5" _check_tesseract_env_var(self.tesseract_dim, self.tesseract_dep) self.col_initializer = Initializer_2p5D_Col(self.tesseract_dim, self.tesseract_dep, *args) @@ -293,6 +305,6 @@ def init_dist_group(self): self.col_initializer.init_dist_group(), self.row_initializer.init_dist_group(), self.dep_initializer.init_dist_group(), - self.xz_initializer.init_dist_group() + self.xz_initializer.init_dist_group(), ] return parallel_setting diff --git a/colossalai/legacy/context/process_group_initializer/initializer_3d.py b/colossalai/legacy/context/process_group_initializer/initializer_3d.py index 3dfbf5223b12..5f96405e90aa 100644 --- a/colossalai/legacy/context/process_group_initializer/initializer_3d.py +++ b/colossalai/legacy/context/process_group_initializer/initializer_3d.py @@ -17,9 +17,10 @@ def _check_depth_env_var(depth): 
env_depth = env.depth_3d if env_depth: - assert int(env_depth) == depth, \ - 'DEPTH_3D has been set in the current environment and ' \ - 'does not match with the value passed to this initialized' + assert int(env_depth) == depth, ( + "DEPTH_3D has been set in the current environment and " + "does not match with the value passed to this initialized" + ) else: env.depth_3d = depth @@ -63,7 +64,7 @@ def init_dist_group(self): for k in range(self.depth): ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for j in range(self.depth)] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -114,7 +115,7 @@ def init_dist_group(self): for j in range(self.depth): ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for i in range(self.depth)] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -165,7 +166,7 @@ def init_dist_group(self): for j in range(self.depth): ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for k in range(self.depth)] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -219,7 +220,7 @@ def init_dist_group(self): for i in range(self.depth) ] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -273,7 +274,7 @@ def init_dist_group(self): for i in range(self.depth) ] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -302,8 +303,9 @@ def __init__(self, *args): super().__init__(*args) self.num_group = self.world_size // self.tensor_parallel_size self.depth = round(math.pow(self.tensor_parallel_size, 1 / 3)) - assert self.tensor_parallel_size == self.depth ** 3, \ - f'3D depth ({self.depth}) if not cube root of tensor parallel size ({self.tensor_parallel_size})' + assert ( + self.tensor_parallel_size == self.depth**3 + ), f"3D depth ({self.depth}) if not cube root of tensor parallel size ({self.tensor_parallel_size})" _check_depth_env_var(self.depth) self.input_initializer = Initializer_3D_Input(self.num_group, self.depth, *args) @@ -324,6 +326,6 @@ def init_dist_group(self): self.weight_initializer.init_dist_group(), self.output_initializer.init_dist_group(), self.input_x_weight_initializer.init_dist_group(), - self.output_x_weight_initializer.init_dist_group() + self.output_x_weight_initializer.init_dist_group(), ] return parallel_setting diff --git a/colossalai/legacy/context/process_group_initializer/initializer_data.py b/colossalai/legacy/context/process_group_initializer/initializer_data.py index 
b9dec4541dad..9c8bcf353c20 100644 --- a/colossalai/legacy/context/process_group_initializer/initializer_data.py +++ b/colossalai/legacy/context/process_group_initializer/initializer_data.py @@ -43,7 +43,7 @@ def init_dist_group(self): for i in range(self.num_data_parallel_group): ranks = [i + j * self.num_data_parallel_group for j in range(self.data_parallel_size)] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) diff --git a/colossalai/legacy/context/process_group_initializer/initializer_model.py b/colossalai/legacy/context/process_group_initializer/initializer_model.py index 614ba372fbcc..6aeae27756e7 100644 --- a/colossalai/legacy/context/process_group_initializer/initializer_model.py +++ b/colossalai/legacy/context/process_group_initializer/initializer_model.py @@ -45,7 +45,7 @@ def init_dist_group(self): for i in range(self.num_group): ranks = [i * self.model_parallel_size + j for j in range(self.model_parallel_size)] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) diff --git a/colossalai/legacy/context/process_group_initializer/initializer_pipeline.py b/colossalai/legacy/context/process_group_initializer/initializer_pipeline.py index e093333ad18a..3e69be75ff7e 100644 --- a/colossalai/legacy/context/process_group_initializer/initializer_pipeline.py +++ b/colossalai/legacy/context/process_group_initializer/initializer_pipeline.py @@ -38,10 +38,11 @@ def init_dist_group(self): for i in range(self.data_parallel_size): for j in range(self.pipeline_stage_size): pipe_ranks = list( - range(i * self.data_group_size + j, (i + 1) * self.data_group_size, self.pipeline_stage_size)) + range(i * self.data_group_size + j, (i + 1) * self.data_group_size, self.pipeline_stage_size) + ) pipe_group_size = len(pipe_ranks) pipe_group = dist.new_group(pipe_ranks) - group_cpu = dist.new_group(pipe_ranks, backend='gloo') if dist.get_backend() != 'gloo' else pipe_group + group_cpu = dist.new_group(pipe_ranks, backend="gloo") if dist.get_backend() != "gloo" else pipe_group if self.rank in pipe_ranks: local_rank = pipe_ranks.index(self.rank) @@ -50,7 +51,16 @@ def init_dist_group(self): cpu_group = group_cpu ranks_in_group = pipe_ranks dist_settings.append( - tuple((local_rank, group_world_size, process_group, cpu_group, ranks_in_group, - ParallelMode.PIPELINE))) + tuple( + ( + local_rank, + group_world_size, + process_group, + cpu_group, + ranks_in_group, + ParallelMode.PIPELINE, + ) + ) + ) return dist_settings diff --git a/colossalai/legacy/context/process_group_initializer/initializer_sequence.py b/colossalai/legacy/context/process_group_initializer/initializer_sequence.py index a6e26b6bcaa9..638b6d5ef2a6 100644 --- a/colossalai/legacy/context/process_group_initializer/initializer_sequence.py +++ b/colossalai/legacy/context/process_group_initializer/initializer_sequence.py @@ -46,7 +46,7 @@ def init_dist_group(self): for i in range(self.num_group): ranks = [i * self.dp_size + j for j in range(self.dp_size)] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = 
dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -91,8 +91,14 @@ def init_dist_group(self): parallel_setting = [] - local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode = \ - self._sequence_initializer.init_dist_group() + ( + local_rank, + group_world_size, + process_group, + cpu_group, + ranks_in_group, + mode, + ) = self._sequence_initializer.init_dist_group() # change mode to sequence mode = ParallelMode.SEQUENCE diff --git a/colossalai/legacy/context/process_group_initializer/initializer_tensor.py b/colossalai/legacy/context/process_group_initializer/initializer_tensor.py index 3be89e52a812..cb19a43bd373 100644 --- a/colossalai/legacy/context/process_group_initializer/initializer_tensor.py +++ b/colossalai/legacy/context/process_group_initializer/initializer_tensor.py @@ -43,7 +43,7 @@ def init_dist_group(self): for i in range(self.num_tensor_parallel_group): ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) diff --git a/colossalai/legacy/context/process_group_initializer/process_group_initializer.py b/colossalai/legacy/context/process_group_initializer/process_group_initializer.py index 98150ce8e428..98b5d7fc3882 100644 --- a/colossalai/legacy/context/process_group_initializer/process_group_initializer.py +++ b/colossalai/legacy/context/process_group_initializer/process_group_initializer.py @@ -18,8 +18,15 @@ class ProcessGroupInitializer(ABC): tensor_parallel_size (int): Size of tensor parallel. 
""" - def __init__(self, rank: int, world_size: int, config: Config, data_parallel_size: int, pipeline_parallel_size: int, - tensor_parallel_size: int): + def __init__( + self, + rank: int, + world_size: int, + config: Config, + data_parallel_size: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ): self.rank = rank self.world_size = world_size self.data_parallel_size = data_parallel_size diff --git a/colossalai/legacy/context/random/__init__.py b/colossalai/legacy/context/random/__init__.py index d64b993257c1..5e8d82922ddc 100644 --- a/colossalai/legacy/context/random/__init__.py +++ b/colossalai/legacy/context/random/__init__.py @@ -13,6 +13,15 @@ ) __all__ = [ - 'seed', 'set_mode', 'with_seed', 'add_seed', 'get_seeds', 'get_states', 'get_current_mode', 'set_seed_states', - 'sync_states', 'moe_set_seed', 'reset_seeds' + "seed", + "set_mode", + "with_seed", + "add_seed", + "get_seeds", + "get_states", + "get_current_mode", + "set_seed_states", + "sync_states", + "moe_set_seed", + "reset_seeds", ] diff --git a/colossalai/legacy/context/random/_helper.py b/colossalai/legacy/context/random/_helper.py index 4b5d5ef2fe55..be1d951d1229 100644 --- a/colossalai/legacy/context/random/_helper.py +++ b/colossalai/legacy/context/random/_helper.py @@ -100,7 +100,7 @@ def sync_states(): @contextmanager def seed(parallel_mode: ParallelMode): - """ A context for seed switch + """A context for seed switch Examples: @@ -162,6 +162,7 @@ def wrapper(*args, **kwargs): def moe_set_seed(seed): if torch.cuda.is_available(): from colossalai.legacy.core import global_context as gpc + global_rank = gpc.get_global_rank() diff_seed = seed + global_rank add_seed(ParallelMode.TENSOR, diff_seed, True) diff --git a/colossalai/legacy/context/random/seed_manager.py b/colossalai/legacy/context/random/seed_manager.py index b657ff7e1d32..c90e849631a1 100644 --- a/colossalai/legacy/context/random/seed_manager.py +++ b/colossalai/legacy/context/random/seed_manager.py @@ -42,7 +42,7 @@ def set_state(self, parallel_mode: ParallelMode, state: Tensor): Raises: AssertionError: Raises an AssertionError if `parallel_mode` is not found in the seed manager. """ - assert parallel_mode in self._seed_states, f'Parallel mode {parallel_mode} is not found in the seed manager' + assert parallel_mode in self._seed_states, f"Parallel mode {parallel_mode} is not found in the seed manager" self._seed_states[parallel_mode] = state def set_mode(self, parallel_mode: ParallelMode): @@ -71,9 +71,9 @@ def add_seed(self, parallel_mode: ParallelMode, seed: int, overwrite: bool = Fal AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of :class:`colossalai.legacy.context.ParallelMode` or the seed for `parallel_mode` has been added. 
""" - assert isinstance(parallel_mode, ParallelMode), 'A valid ParallelMode must be provided' + assert isinstance(parallel_mode, ParallelMode), "A valid ParallelMode must be provided" if overwrite is False: - assert parallel_mode not in self._seed_states, f'The seed for {parallel_mode} has been added' + assert parallel_mode not in self._seed_states, f"The seed for {parallel_mode} has been added" elif parallel_mode in self._seed_states: print(f"Warning: {parallel_mode} seed has been overwritten.", flush=True) diff --git a/colossalai/legacy/core.py b/colossalai/legacy/core.py index 0aaf1ee47730..80b6e4d25bd2 100644 --- a/colossalai/legacy/core.py +++ b/colossalai/legacy/core.py @@ -3,4 +3,4 @@ from colossalai.legacy.context.parallel_context import global_context -__all__ = ['global_context'] +__all__ = ["global_context"] diff --git a/colossalai/legacy/engine/__init__.py b/colossalai/legacy/engine/__init__.py index 158796befb31..581760748a16 100644 --- a/colossalai/legacy/engine/__init__.py +++ b/colossalai/legacy/engine/__init__.py @@ -1,4 +1,4 @@ from ._base_engine import Engine from .gradient_handler import * -__all__ = ['Engine'] +__all__ = ["Engine"] diff --git a/colossalai/legacy/engine/_base_engine.py b/colossalai/legacy/engine/_base_engine.py index 930caf20c1dd..0954e2be3eb1 100644 --- a/colossalai/legacy/engine/_base_engine.py +++ b/colossalai/legacy/engine/_base_engine.py @@ -59,15 +59,17 @@ class Engine: `Run resnet cifar10 with engine `_. """ - def __init__(self, - model: Module, - optimizer: "OptimizerWrapper", - criterion: Optional[_Loss] = None, - gradient_handlers: Optional[List[BaseGradientHandler]] = None, - clip_grad_norm: float = 0.0, - ophook_list: Optional[List[BaseOpHook]] = None, - verbose: bool = True, - schedule: Optional[BaseSchedule] = None): + def __init__( + self, + model: Module, + optimizer: "OptimizerWrapper", + criterion: Optional[_Loss] = None, + gradient_handlers: Optional[List[BaseGradientHandler]] = None, + clip_grad_norm: float = 0.0, + ophook_list: Optional[List[BaseOpHook]] = None, + verbose: bool = True, + schedule: Optional[BaseSchedule] = None, + ): self._model = model self._optimizer = optimizer self._criterion = criterion @@ -76,7 +78,7 @@ def __init__(self, self._logger = get_dist_logger() # state - self.training = True # default + self.training = True # default # build gradient handler if gradient_handlers: @@ -91,8 +93,9 @@ def __init__(self, # build schedule if schedule: - assert isinstance(schedule, BaseSchedule), \ - f'expected schedule to be of type BaseSchedule, but got {type(schedule)}' + assert isinstance( + schedule, BaseSchedule + ), f"expected schedule to be of type BaseSchedule, but got {type(schedule)}" self._schedule = schedule else: self._schedule = NonPipelineSchedule() @@ -149,13 +152,11 @@ def remove_hook(self, ophook: Type[BaseOpHook]) -> None: logger.warning(f"removing hooks is currently not supported") def zero_grad(self): - """Set the gradient of parameters to zero - """ + """Set the gradient of parameters to zero""" self.optimizer.zero_grad() def step(self): - """Execute parameter update - """ + """Execute parameter update""" self._all_reduce_gradients() self.optimizer.clip_grad_by_norm(self._clip_grad_norm) return self.optimizer.step() @@ -192,8 +193,7 @@ def __call__(self, *args, **kwargs): return self.model(*args, **kwargs) def _all_reduce_gradients(self): - """Handles all-reduce operations of gradients across different parallel groups. 
- """ + """Handles all-reduce operations of gradients across different parallel groups.""" for handler in self._gradient_handlers: handler.handle_gradient() @@ -208,13 +208,11 @@ def execute_schedule(self, data_iter: Iterable, **kwargs): return output, label, loss def train(self): - """Sets the model to training mode. - """ + """Sets the model to training mode.""" self.training = True self._model.train() def eval(self): - """Sets the model to evaluation mode. - """ + """Sets the model to evaluation mode.""" self.training = False self._model.eval() diff --git a/colossalai/legacy/engine/gradient_accumulation/__init__.py b/colossalai/legacy/engine/gradient_accumulation/__init__.py index 670c26d06e55..e0835318ed9f 100644 --- a/colossalai/legacy/engine/gradient_accumulation/__init__.py +++ b/colossalai/legacy/engine/gradient_accumulation/__init__.py @@ -14,17 +14,22 @@ ) __all__ = [ - 'accumulate_gradient', 'GradAccumDataloader', 'GradAccumOptimizer', 'GradAccumLrSchedulerByStep', - 'GradAccumGradientHandler' + "accumulate_gradient", + "GradAccumDataloader", + "GradAccumOptimizer", + "GradAccumLrSchedulerByStep", + "GradAccumGradientHandler", ] -def accumulate_gradient(model: nn.Module, - optimizer: Optimizer, - dataloader: Iterable, - accumulate_size: int, - gradient_handlers: List[BaseGradientHandler] = None, - lr_scheduler: _LRScheduler = None): +def accumulate_gradient( + model: nn.Module, + optimizer: Optimizer, + dataloader: Iterable, + accumulate_size: int, + gradient_handlers: List[BaseGradientHandler] = None, + lr_scheduler: _LRScheduler = None, +): r"""Turning model, optimizer, dataloader into corresponding object for gradient accumulation. Args: diff --git a/colossalai/legacy/engine/gradient_accumulation/_gradient_accumulation.py b/colossalai/legacy/engine/gradient_accumulation/_gradient_accumulation.py index c2270dc53a50..9de0f6c0ffd9 100644 --- a/colossalai/legacy/engine/gradient_accumulation/_gradient_accumulation.py +++ b/colossalai/legacy/engine/gradient_accumulation/_gradient_accumulation.py @@ -272,8 +272,9 @@ class GradAccumGradientHandler: """ def __init__(self, grad_handler: BaseGradientHandler, accumulate_size: int) -> None: - assert isinstance(grad_handler, BaseGradientHandler), \ - f'expected grad_handler to be type BaseGradientHandler, but got {type(grad_handler)}' + assert isinstance( + grad_handler, BaseGradientHandler + ), f"expected grad_handler to be type BaseGradientHandler, but got {type(grad_handler)}" self.grad_handler = grad_handler self.accumulate_size = accumulate_size self.accumulate_step = 0 diff --git a/colossalai/legacy/engine/gradient_handler/__init__.py b/colossalai/legacy/engine/gradient_handler/__init__.py index 2dea768bad7e..78928b138842 100644 --- a/colossalai/legacy/engine/gradient_handler/__init__.py +++ b/colossalai/legacy/engine/gradient_handler/__init__.py @@ -6,6 +6,10 @@ from ._zero_gradient_handler import ZeROGradientHandler __all__ = [ - 'BaseGradientHandler', 'DataParallelGradientHandler', 'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler', - 'MoeGradientHandler', 'SequenceParallelGradientHandler' + "BaseGradientHandler", + "DataParallelGradientHandler", + "ZeROGradientHandler", + "PipelineSharedModuleGradientHandler", + "MoeGradientHandler", + "SequenceParallelGradientHandler", ] diff --git a/colossalai/legacy/engine/gradient_handler/_base_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_base_gradient_handler.py index 7d96dd8a88a6..e594bb00f96b 100644 --- 
a/colossalai/legacy/engine/gradient_handler/_base_gradient_handler.py +++ b/colossalai/legacy/engine/gradient_handler/_base_gradient_handler.py @@ -22,4 +22,3 @@ def handle_gradient(self): """A method to accumulate gradients across different parallel groups. Users should write their own functions or just use the functions in pre-defined subclasses. """ - pass diff --git a/colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py index c692ee903442..3782adaf7187 100644 --- a/colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py +++ b/colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py @@ -20,8 +20,7 @@ class DataParallelGradientHandler(BaseGradientHandler): """ def handle_gradient(self): - """A method running a all-reduce operation in a data parallel group. - """ + """A method running a all-reduce operation in a data parallel group.""" # TODO: add memory buffer if gpc.data_parallel_size > 1: bucket_allreduce(param_list=self._model.parameters(), group=gpc.get_group(ParallelMode.DATA)) diff --git a/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py index e7a6df2d8ae8..6a7224cff7bd 100644 --- a/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py +++ b/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py @@ -42,5 +42,6 @@ def handle_gradient(self): for ep_size in epsize_param_dict: if ep_size != 1 and ep_size != MOE_CONTEXT.world_size: - bucket_allreduce(param_list=epsize_param_dict[ep_size], - group=MOE_CONTEXT.parallel_info_dict[ep_size].dp_group) + bucket_allreduce( + param_list=epsize_param_dict[ep_size], group=MOE_CONTEXT.parallel_info_dict[ep_size].dp_group + ) diff --git a/colossalai/legacy/engine/gradient_handler/_pipeline_parallel_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_pipeline_parallel_gradient_handler.py index 3eae7d58ac95..3a65f65abf73 100644 --- a/colossalai/legacy/engine/gradient_handler/_pipeline_parallel_gradient_handler.py +++ b/colossalai/legacy/engine/gradient_handler/_pipeline_parallel_gradient_handler.py @@ -26,17 +26,21 @@ class PipelineSharedModuleGradientHandler(BaseGradientHandler): """ def handle_gradient(self): - """A method running a all-reduce operation in sub pipeline parallel groups. - """ + """A method running a all-reduce operation in sub pipeline parallel groups.""" if gpc.pipeline_parallel_size > 1: # bucketize and all-reduce buckets = defaultdict(lambda: defaultdict(list)) # Pack the buckets. 
for param in self._model.parameters(): - group = getattr(param, 'pipeline_shared_module_pg', None) - if param.requires_grad and group is not None and ( - (hasattr(param, 'colo_attr') and not param.colo_attr.saved_grad.is_null()) - or param.grad is not None): + group = getattr(param, "pipeline_shared_module_pg", None) + if ( + param.requires_grad + and group is not None + and ( + (hasattr(param, "colo_attr") and not param.colo_attr.saved_grad.is_null()) + or param.grad is not None + ) + ): tp = param.data.type() buckets[group][tp].append(param) @@ -44,7 +48,7 @@ def handle_gradient(self): for group, group_buckets in buckets.items(): for tp, bucket in group_buckets.items(): grads = [ - param.colo_attr.grad_payload if hasattr(param, 'colo_attr') else param.grad.data + param.colo_attr.grad_payload if hasattr(param, "colo_attr") else param.grad.data for param in bucket ] coalesced = _flatten_dense_tensors(grads).to(torch.cuda.current_device()) diff --git a/colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py index 38b7f5993b73..6d507bcc0269 100644 --- a/colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py +++ b/colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py @@ -20,7 +20,6 @@ class SequenceParallelGradientHandler(BaseGradientHandler): """ def handle_gradient(self): - """A method running a all-reduce operation in a data parallel group. - """ + """A method running a all-reduce operation in a data parallel group.""" if gpc.get_world_size(ParallelMode.SEQUENCE_DP) > 1: bucket_allreduce(param_list=self._model.parameters(), group=gpc.get_group(ParallelMode.SEQUENCE_DP)) diff --git a/colossalai/legacy/engine/gradient_handler/_zero_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_zero_gradient_handler.py index 4ca7cd0b0702..63ec6e70ba06 100644 --- a/colossalai/legacy/engine/gradient_handler/_zero_gradient_handler.py +++ b/colossalai/legacy/engine/gradient_handler/_zero_gradient_handler.py @@ -16,6 +16,5 @@ class ZeROGradientHandler(BaseGradientHandler): """ def handle_gradient(self): - """A method running a all-reduce operation in a data parallel group. 
- """ + """A method running a all-reduce operation in a data parallel group.""" self._optimizer.sync_grad() diff --git a/colossalai/legacy/engine/schedule/__init__.py b/colossalai/legacy/engine/schedule/__init__.py index 0f2c039d7057..017231a9b4a8 100644 --- a/colossalai/legacy/engine/schedule/__init__.py +++ b/colossalai/legacy/engine/schedule/__init__.py @@ -2,4 +2,4 @@ from ._non_pipeline_schedule import NonPipelineSchedule from ._pipeline_schedule import InterleavedPipelineSchedule, PipelineSchedule, get_tensor_shape -__all__ = ['BaseSchedule', 'NonPipelineSchedule', 'PipelineSchedule', 'InterleavedPipelineSchedule', 'get_tensor_shape'] +__all__ = ["BaseSchedule", "NonPipelineSchedule", "PipelineSchedule", "InterleavedPipelineSchedule", "get_tensor_shape"] diff --git a/colossalai/legacy/engine/schedule/_base_schedule.py b/colossalai/legacy/engine/schedule/_base_schedule.py index 7505a3eb20e3..4a3ccfda1bb5 100644 --- a/colossalai/legacy/engine/schedule/_base_schedule.py +++ b/colossalai/legacy/engine/schedule/_base_schedule.py @@ -47,7 +47,8 @@ def _move_to_device(self, data): data = {k: self._move_tensor(v) for k, v in data.items()} else: raise TypeError( - f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}") + f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}" + ) return data def _get_batch_size(self, data): @@ -72,7 +73,7 @@ def load_batch(self, data_iter, to_gpu=True): Tuple (:class:`Tensor`, :class:`torch.Tensor`): A tuple of (data, label). """ if data_iter is None: - raise RuntimeError('Dataloader is not defined.') + raise RuntimeError("Dataloader is not defined.") batch_data = next(data_iter) if to_gpu: @@ -81,17 +82,17 @@ def load_batch(self, data_iter, to_gpu=True): return batch_data def pre_processing(self, engine): - """To perform actions before running the schedule. - """ - pass + """To perform actions before running the schedule.""" @abstractmethod - def forward_backward_step(self, - engine, - data_iter: Iterable, - forward_only: bool, - return_loss: bool = True, - return_output_label: bool = True): + def forward_backward_step( + self, + engine, + data_iter: Iterable, + forward_only: bool, + return_loss: bool = True, + return_output_label: bool = True, + ): """The process function over a batch of dataset for training or evaluation. Args: @@ -101,7 +102,6 @@ def forward_backward_step(self, return_loss (bool, optional): If False, the loss won't be returned. return_output_label (bool, optional): If False, the output and label won't be returned. 
""" - pass @staticmethod def _call_engine(engine, inputs): @@ -113,13 +113,14 @@ def _call_engine(engine, inputs): return engine(**inputs) else: TypeError( - f"Expected engine inputs to be of type torch.Tensor, list, tuple, or dict, but got {type(inputs)}") + f"Expected engine inputs to be of type torch.Tensor, list, tuple, or dict, but got {type(inputs)}" + ) @staticmethod def _call_engine_criterion(engine, outputs, labels): - assert isinstance(outputs, - (torch.Tensor, list, tuple, - dict)), f'Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}' + assert isinstance( + outputs, (torch.Tensor, list, tuple, dict) + ), f"Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}" if isinstance(outputs, torch.Tensor): outputs = (outputs,) if isinstance(labels, torch.Tensor): @@ -134,6 +135,8 @@ def _call_engine_criterion(engine, outputs, labels): elif isinstance(outputs, dict) and isinstance(labels, (list, tuple)): raise ValueError(f"Expected labels to be a dict when the model outputs are dict, but got {type(labels)}") else: - raise TypeError(f"Expected model outputs and labels to be of type torch.Tensor ' \ + raise TypeError( + f"Expected model outputs and labels to be of type torch.Tensor ' \ '(which is auto-converted to tuple), list, tuple, or dict, ' \ - 'but got {type(outputs)} (model outputs) and {type(labels)} (labels)") + 'but got {type(outputs)} (model outputs) and {type(labels)} (labels)" + ) diff --git a/colossalai/legacy/engine/schedule/_non_pipeline_schedule.py b/colossalai/legacy/engine/schedule/_non_pipeline_schedule.py index b67893c1a0bb..08c6cfd60f28 100644 --- a/colossalai/legacy/engine/schedule/_non_pipeline_schedule.py +++ b/colossalai/legacy/engine/schedule/_non_pipeline_schedule.py @@ -37,19 +37,22 @@ def __init__(self, data_process_func: Callable = None): if data_process_func: sig = inspect.signature(data_process_func) - assert len(sig.parameters) == 1, \ - 'The data_process_func only takes in one parameter for NonPipelineSchedule, ' \ - 'which is a tuple of tensors for the current batch, ' \ - 'i.e. data_process_func(dataloader_output).' + assert len(sig.parameters) == 1, ( + "The data_process_func only takes in one parameter for NonPipelineSchedule, " + "which is a tuple of tensors for the current batch, " + "i.e. data_process_func(dataloader_output)." + ) super().__init__(data_process_func) - def forward_backward_step(self, - engine, - data_iter: Iterable, - forward_only: bool = False, - return_loss: bool = True, - return_output_label: bool = True): + def forward_backward_step( + self, + engine, + data_iter: Iterable, + forward_only: bool = False, + return_loss: bool = True, + return_output_label: bool = True, + ): """The process function that loads a batch of dataset and feeds it to the model. The returned labels and loss will None if :attr:`return_loss` is False. @@ -64,8 +67,9 @@ def forward_backward_step(self, Returns: Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None. """ - assert forward_only or return_loss, \ - "The argument 'return_loss' has to be True when 'forward_only' is False, but got False." + assert ( + forward_only or return_loss + ), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False." 
batch_data = self.load_batch(data_iter) if self.data_process_func: data, label = self.data_process_func(batch_data) diff --git a/colossalai/legacy/engine/schedule/_pipeline_schedule.py b/colossalai/legacy/engine/schedule/_pipeline_schedule.py index 37eed82f8a28..4fc5040f6983 100644 --- a/colossalai/legacy/engine/schedule/_pipeline_schedule.py +++ b/colossalai/legacy/engine/schedule/_pipeline_schedule.py @@ -18,14 +18,18 @@ def get_tensor_shape(): - if hasattr(gpc.config, 'TENSOR_SHAPE'): + if hasattr(gpc.config, "TENSOR_SHAPE"): return gpc.config.TENSOR_SHAPE if not gpc.is_initialized(ParallelMode.PIPELINE): return None - if hasattr(gpc.config, 'SEQ_LENGTH') and hasattr(gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr( - gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(gpc.config, 'HIDDEN_SIZE'): + if ( + hasattr(gpc.config, "SEQ_LENGTH") + and hasattr(gpc.config, "GLOBAL_BATCH_SIZE") + and hasattr(gpc.config, "GLOBAL_BATCH_SIZE") + and hasattr(gpc.config, "HIDDEN_SIZE") + ): if gpc.is_initialized(ParallelMode.DATA): dp_size = gpc.get_world_size(ParallelMode.DATA) else: @@ -35,8 +39,11 @@ def get_tensor_shape(): else: seq_size = 1 - tensor_shape = (gpc.config.SEQ_LENGTH // seq_size, - gpc.config.GLOBAL_BATCH_SIZE // dp_size // gpc.config.NUM_MICRO_BATCHES, gpc.config.HIDDEN_SIZE) + tensor_shape = ( + gpc.config.SEQ_LENGTH // seq_size, + gpc.config.GLOBAL_BATCH_SIZE // dp_size // gpc.config.NUM_MICRO_BATCHES, + gpc.config.HIDDEN_SIZE, + ) return tensor_shape else: return None @@ -49,7 +56,7 @@ def pack_return_tensors(return_tensors): elif isinstance(output[0], (list, tuple)): output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*output)) else: - raise TypeError(f'Output of model must be tensor or list/tuple of tensors') + raise TypeError(f"Output of model must be tensor or list/tuple of tensors") if isinstance(label[0], torch.Tensor): label = torch.cat(label, dim=0) else: @@ -88,28 +95,31 @@ def data_process_func(stage_output, dataloader_output): """ - def __init__(self, - num_microbatches, - data_process_func: Callable = None, - tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None, - scatter_gather_tensors: bool = False): - + def __init__( + self, + num_microbatches, + data_process_func: Callable = None, + tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None, + scatter_gather_tensors: bool = False, + ): # we need to make sure that the signature of the data_process_func is valid if data_process_func: sig = inspect.signature(data_process_func) - assert len(sig.parameters) == 2, \ - 'The data_process_func only takes in two parameters for NonPipelineSchedule, ' \ - 'which is the tensors passed by the previous pipeline stage and the dataloader output from this stage, ' \ - 'i.e. data_process_func(stage_output, dataloader_output).' + assert len(sig.parameters) == 2, ( + "The data_process_func only takes in two parameters for NonPipelineSchedule, " + "which is the tensors passed by the previous pipeline stage and the dataloader output from this stage, " + "i.e. data_process_func(stage_output, dataloader_output)." + ) super().__init__(data_process_func=data_process_func) - assert num_microbatches > 0, f'expected num_microbatches to be larger then 1, but got {num_microbatches}' + assert num_microbatches > 0, f"expected num_microbatches to be larger then 1, but got {num_microbatches}" self.num_microbatches = num_microbatches self.dtype = torch.float - assert not isinstance(tensor_shape, - int), "tensor_shape type should be one of Union[torch.Size, List[int], Tuple[int]]." 
+ assert not isinstance( + tensor_shape, int + ), "tensor_shape type should be one of Union[torch.Size, List[int], Tuple[int]]." if tensor_shape is None: self.tensor_shape = tensor_shape elif isinstance(tensor_shape, torch.Size): @@ -128,26 +138,25 @@ def load_batch(self, data_iter): # Pipeline schedule just puts data in memory batch_data = super().load_batch(data_iter, to_gpu=False) self.microbatch_offset = 0 - assert self.batch_size % self.num_microbatches == 0, \ - "Batch size should divided by the number of microbatches" + assert self.batch_size % self.num_microbatches == 0, "Batch size should divided by the number of microbatches" self.microbatch_size = self.batch_size // self.num_microbatches self.batch_data = batch_data def _get_data_slice(self, data, offset): if isinstance(data, torch.Tensor): - return data[offset:offset + self.microbatch_size] + return data[offset : offset + self.microbatch_size] elif isinstance(data, (list, tuple)): data_dict = {} for element in data: if isinstance(element, dict): - data_dict.update({k: v[offset:offset + self.microbatch_size] for k, v in element.items()}) + data_dict.update({k: v[offset : offset + self.microbatch_size] for k, v in element.items()}) elif data_dict: - data_dict['label'] = element[offset:offset + self.microbatch_size] + data_dict["label"] = element[offset : offset + self.microbatch_size] if data_dict: return data_dict - return [val[offset:offset + self.microbatch_size] for val in data] + return [val[offset : offset + self.microbatch_size] for val in data] elif isinstance(data, dict): - return {k: v[offset:offset + self.microbatch_size] for k, v in data.items()} + return {k: v[offset : offset + self.microbatch_size] for k, v in data.items()} else: raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}") @@ -180,8 +189,8 @@ def _call_engine(model, data): return model(*data) elif isinstance(data, dict): stage_output = None - if 'stage_output' in data: - stage_output = data.pop('stage_output') + if "stage_output" in data: + stage_output = data.pop("stage_output") if stage_output is None: return model(**data) elif isinstance(stage_output, torch.Tensor): @@ -198,7 +207,7 @@ def _call_engine(model, data): def _get_actual_forward_func(self, module): if isinstance(module, NaiveAMPModel): sig = inspect.signature(module.model.forward) - elif hasattr(module, 'colo_attr'): + elif hasattr(module, "colo_attr"): sig = inspect.signature(module.module.forward) else: sig = inspect.signature(module.forward) @@ -221,9 +230,9 @@ def _get_data_label_for_current_step(self, stage_output, micro_batch_data, crite _, label = micro_batch_data elif isinstance(micro_batch_data, dict): data = {} - data['stage_output'] = stage_output - if 'label' in micro_batch_data: - label = micro_batch_data.pop('label') + data["stage_output"] = stage_output + if "label" in micro_batch_data: + label = micro_batch_data.pop("label") else: label = None load_data = micro_batch_data @@ -263,7 +272,7 @@ def _forward_step(self, engine, input_obj, return_tensors, return_output_label=T else: if isinstance(output_obj, torch.Tensor): self._logger.debug( - f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}' + f"Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}" ) return output_obj @@ -325,12 +334,13 @@ def 
forward_backward_step(self, engine, data_iter, forward_only=False, return_lo Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None. """ - assert forward_only or return_loss, \ - 'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.' + assert ( + forward_only or return_loss + ), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False." self.load_batch(data_iter) - num_warmup_microbatches = \ - (gpc.get_world_size(ParallelMode.PIPELINE) - - gpc.get_local_rank(ParallelMode.PIPELINE) - 1) + num_warmup_microbatches = ( + gpc.get_world_size(ParallelMode.PIPELINE) - gpc.get_local_rank(ParallelMode.PIPELINE) - 1 + ) num_warmup_microbatches = min(num_warmup_microbatches, self.num_microbatches) num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches @@ -354,14 +364,12 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo for i in range(num_warmup_microbatches): if not gpc.is_first_rank(ParallelMode.PIPELINE): ft_shapes = comm.recv_obj_meta(ft_shapes) - input_obj = comm.recv_forward(ft_shapes, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors) - output_obj = self._forward_step(engine, - input_obj, - return_tensors, - return_output_label=return_output_label, - accum_loss=accum_loss) + input_obj = comm.recv_forward( + ft_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors + ) + output_obj = self._forward_step( + engine, input_obj, return_tensors, return_output_label=return_output_label, accum_loss=accum_loss + ) if not gpc.is_last_rank(ParallelMode.PIPELINE): if isinstance(output_obj, torch.Tensor): bt_shapes = output_obj.shape @@ -382,32 +390,29 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo if num_microbatches_remaining > 0: if not gpc.is_first_rank(ParallelMode.PIPELINE): ft_shapes = comm.recv_obj_meta(ft_shapes) - input_obj = comm.recv_forward(ft_shapes, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors) + input_obj = comm.recv_forward( + ft_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors + ) # Run 1F1B in steady state. for i in range(num_microbatches_remaining): - last_iteration = (i == (num_microbatches_remaining - 1)) + last_iteration = i == (num_microbatches_remaining - 1) - output_obj = self._forward_step(engine, - input_obj, - return_tensors, - return_output_label=return_output_label, - accum_loss=accum_loss) + output_obj = self._forward_step( + engine, input_obj, return_tensors, return_output_label=return_output_label, accum_loss=accum_loss + ) if forward_only: comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors) if not last_iteration: - input_obj = comm.recv_forward(ft_shapes, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors) + input_obj = comm.recv_forward( + ft_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors + ) else: - output_obj_grad = comm.send_forward_recv_backward(output_obj, - bt_shapes, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors) + output_obj_grad = comm.send_forward_recv_backward( + output_obj, bt_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors + ) # Add input_obj and output_obj to end of list. 
input_objs.append(input_obj) @@ -424,10 +429,9 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo input_obj = None comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors) else: - input_obj = comm.send_backward_recv_forward(input_obj_grad, - ft_shapes, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors) + input_obj = comm.send_backward_recv_forward( + input_obj_grad, ft_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors + ) # Run cooldown backward passes. if not forward_only: @@ -435,9 +439,9 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo input_obj = input_objs.pop(0) output_obj = output_objs.pop(0) - output_obj_grad = comm.recv_backward(bt_shapes, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors) + output_obj_grad = comm.recv_backward( + bt_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors + ) input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad) @@ -451,13 +455,14 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo class InterleavedPipelineSchedule(PipelineSchedule): - - def __init__(self, - num_microbatches: int, - num_model_chunks: int, - data_process_func: Callable = None, - tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None, - scatter_gather_tensors: bool = False): + def __init__( + self, + num_microbatches: int, + num_model_chunks: int, + data_process_func: Callable = None, + tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None, + scatter_gather_tensors: bool = False, + ): """A helper schedule class for pipeline parallelism running environment. It uses interleaved 1F1B strategy. Other properties are similar as :class:`NonPipelineSchedule`. @@ -471,20 +476,25 @@ def __init__(self, scatter_gather_tensors (bool, optional): If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization. 
""" - assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \ - 'num_microbatches must be an integer multiple of pipeline parallel world size' - assert isinstance(num_model_chunks, int) and num_model_chunks > 0, \ - f'expected num_model_chunks to be an integer and larger than 0, but got {num_model_chunks}' - super().__init__(num_microbatches, - data_process_func=data_process_func, - tensor_shape=tensor_shape, - scatter_gather_tensors=scatter_gather_tensors) + assert ( + num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0 + ), "num_microbatches must be an integer multiple of pipeline parallel world size" + assert ( + isinstance(num_model_chunks, int) and num_model_chunks > 0 + ), f"expected num_model_chunks to be an integer and larger than 0, but got {num_model_chunks}" + super().__init__( + num_microbatches, + data_process_func=data_process_func, + tensor_shape=tensor_shape, + scatter_gather_tensors=scatter_gather_tensors, + ) gpc.set_virtual_pipeline_parallel_size(num_model_chunks) gpc.set_virtual_pipeline_parallel_rank(0) self.num_model_chunks = num_model_chunks def pre_processing(self, engine): from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 + if isinstance(engine.model, ShardedModelV2): self.dtype = torch.half elif isinstance(engine.model[0], NaiveAMPModel): @@ -494,7 +504,7 @@ def pre_processing(self, engine): model = model.model sig = inspect.signature(model.forward) for p in sig.parameters.values(): - assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported' + assert p.kind != inspect.Parameter.VAR_POSITIONAL, "*args is not supported" def load_batch(self, data_iter): super().load_batch(data_iter) @@ -506,13 +516,9 @@ def load_micro_batch(self, model_chunk_id): self.microbatch_offset[model_chunk_id] += self.microbatch_size return self._move_to_device(data) - def _forward_step(self, - engine, - model_chunk_id, - input_obj, - return_tensors, - return_output_label=True, - accum_loss=None): + def _forward_step( + self, engine, model_chunk_id, input_obj, return_tensors, return_output_label=True, accum_loss=None + ): """Forward step for passed-in model. If it is the first stage, the input tensor is obtained from data_iterator, otherwise the passed-in input_obj is used. Returns output tensor. This is a helper function and can be ignored by users. @@ -528,8 +534,9 @@ def _forward_step(self, Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage. 
""" micro_batch_data = self.load_micro_batch(model_chunk_id) - data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data, engine.criterion, - engine.model[model_chunk_id]) + data, label = self._get_data_label_for_current_step( + input_obj, micro_batch_data, engine.criterion, engine.model[model_chunk_id] + ) output_obj = self._call_engine(engine.model[model_chunk_id], data) @@ -546,7 +553,7 @@ def _forward_step(self, else: if isinstance(output_obj, torch.Tensor): self._logger.debug( - f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}' + f"Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}" ) return output_obj @@ -566,8 +573,9 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None. The loss would be returned only in the last stage. """ - assert forward_only or return_loss, \ - 'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.' + assert ( + forward_only or return_loss + ), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False." self.load_batch(data_iter) model = engine.model input_objs = [[] for _ in range(len(model))] @@ -605,19 +613,17 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo num_warmup_microbatches = num_microbatches all_warmup_microbatches = True else: - num_warmup_microbatches = \ - (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2 + num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2 num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) - num_microbatches_remaining = \ - num_microbatches - num_warmup_microbatches + num_microbatches_remaining = num_microbatches - num_warmup_microbatches def get_model_chunk_id(microbatch_id, forward): """Helper method to get the model chunk ID given the iteration number.""" microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks) model_chunk_id = microbatch_id_in_group // pipeline_parallel_size if not forward: - model_chunk_id = (num_model_chunks - model_chunk_id - 1) + model_chunk_id = num_model_chunks - model_chunk_id - 1 return model_chunk_id def _forward_step_helper(microbatch_id): @@ -629,16 +635,17 @@ def _forward_step_helper(microbatch_id): # forward step if gpc.is_pipeline_first_stage(): - if len(input_objs[model_chunk_id]) == \ - len(output_objs[model_chunk_id]): + if len(input_objs[model_chunk_id]) == len(output_objs[model_chunk_id]): input_objs[model_chunk_id].append(None) input_obj = input_objs[model_chunk_id][-1] - output_obj = self._forward_step(engine, - model_chunk_id, - input_obj, - return_tensors, - return_output_label=return_output_label, - accum_loss=accum_loss) + output_obj = self._forward_step( + engine, + model_chunk_id, + input_obj, + return_tensors, + return_output_label=return_output_label, + accum_loss=accum_loss, + ) output_objs[model_chunk_id].append(output_obj) # if forward-only, no need to save tensors for a backward pass @@ -670,8 +677,8 @@ def _backward_step_helper(microbatch_id): if not gpc.is_pipeline_first_stage(): input_obj_shapes[0] = comm.recv_obj_meta(input_obj_shapes[0]) 
input_objs[0].append( - comm.recv_forward(input_obj_shapes[0], dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors)) + comm.recv_forward(input_obj_shapes[0], dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors) + ) for k in range(num_warmup_microbatches): model_chunk_id = get_model_chunk_id(k, forward=True) @@ -683,8 +690,9 @@ def _backward_step_helper(microbatch_id): output_obj_shapes[model_chunk_id] = [] for out_tensor in output_obj: output_obj_shapes[model_chunk_id].append(out_tensor.shape) - send_tensor_shape_flags[model_chunk_id] = comm.send_obj_meta(output_obj, - send_tensor_shape_flags[model_chunk_id]) + send_tensor_shape_flags[model_chunk_id] = comm.send_obj_meta( + output_obj, send_tensor_shape_flags[model_chunk_id] + ) # Determine if tensor should be received from previous stage. next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True) recv_prev = True @@ -701,34 +709,36 @@ def _backward_step_helper(microbatch_id): with switch_virtual_pipeline_parallel_rank(next_forward_model_chunk_id): if not gpc.is_pipeline_first_stage(): input_obj_shapes[next_forward_model_chunk_id] = comm.recv_obj_meta( - input_obj_shapes[next_forward_model_chunk_id]) + input_obj_shapes[next_forward_model_chunk_id] + ) # Send and receive tensors as appropriate (send tensors computed # in this iteration; receive tensors for next iteration). input_shape = input_obj_shapes[next_forward_model_chunk_id] if recv_prev else None - if k == (num_warmup_microbatches - 1) and not forward_only and \ - not all_warmup_microbatches: + if k == (num_warmup_microbatches - 1) and not forward_only and not all_warmup_microbatches: input_obj_grad = None recv_next = True if gpc.is_pipeline_last_stage(ignore_virtual=True): recv_next = False output_shape = output_obj_shapes[num_model_chunks - 1] if recv_next else None - input_obj, output_obj_grad = \ - comm.send_forward_backward_recv_forward_backward( - output_obj, input_obj_grad, - input_shape, - output_shape, - recv_prev=recv_prev, recv_next=recv_next, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors) + input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward( + output_obj, + input_obj_grad, + input_shape, + output_shape, + recv_prev=recv_prev, + recv_next=recv_next, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors, + ) output_obj_grads[num_model_chunks - 1].append(output_obj_grad) else: - input_obj = \ - comm.send_forward_recv_forward( - output_obj, - input_shape, - recv_prev=recv_prev, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors) + input_obj = comm.send_forward_recv_forward( + output_obj, + input_shape, + recv_prev=recv_prev, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors, + ) input_objs[next_forward_model_chunk_id].append(input_obj) # Run 1F1B in steady state. @@ -771,8 +781,9 @@ def _backward_step_helper(microbatch_id): recv_next = True if gpc.is_pipeline_last_stage(ignore_virtual=True): # Last stage is ahead of first stage by (pipeline_parallel_size - 1). 
- next_backward_model_chunk_id = get_model_chunk_id(backward_k - (pipeline_parallel_size - 1), - forward=False) + next_backward_model_chunk_id = get_model_chunk_id( + backward_k - (pipeline_parallel_size - 1), forward=False + ) if next_backward_model_chunk_id == 0: recv_next = False next_backward_model_chunk_id -= 1 @@ -787,14 +798,16 @@ def _backward_step_helper(microbatch_id): input_shape = input_obj_shapes[next_forward_model_chunk_id] if recv_prev else None output_shape = output_obj_shapes[next_backward_model_chunk_id] if recv_next else None # Communicate objs. - input_obj, output_obj_grad = \ - comm.send_forward_backward_recv_forward_backward( - output_obj, input_obj_grad, - input_shape, - output_shape, - recv_prev=recv_prev, recv_next=recv_next, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors) + input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward( + output_obj, + input_obj_grad, + input_shape, + output_shape, + recv_prev=recv_prev, + recv_next=recv_next, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors, + ) # Put input_obj and output_obj_grad in data structures in the # right location. @@ -807,8 +820,10 @@ def _backward_step_helper(microbatch_id): if not forward_only: if all_warmup_microbatches: output_obj_grads[num_model_chunks - 1].append( - comm.recv_backward(output_obj_shapes[num_model_chunks - 1], - scatter_gather_tensors=self.scatter_gather_tensors)) + comm.recv_backward( + output_obj_shapes[num_model_chunks - 1], scatter_gather_tensors=self.scatter_gather_tensors + ) + ) for k in range(num_microbatches_remaining, num_microbatches): input_obj_grad = _backward_step_helper(k) next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False) @@ -820,11 +835,14 @@ def _backward_step_helper(microbatch_id): recv_next = False output_shape = output_obj_shapes[next_backward_model_chunk_id] if recv_next else None output_obj_grads[next_backward_model_chunk_id].append( - comm.send_backward_recv_backward(input_obj_grad, - output_shape, - recv_next=recv_next, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors)) + comm.send_backward_recv_backward( + input_obj_grad, + output_shape, + recv_next=recv_next, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors, + ) + ) if len(return_tensors) > 0: output, label = pack_return_tensors(return_tensors) diff --git a/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py b/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py index bf8b599a81ae..867c3dfa819b 100644 --- a/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py +++ b/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py @@ -21,7 +21,7 @@ def pack_return_tensors(return_tensors): elif isinstance(output[0], (list, tuple)): output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*output)) else: - raise TypeError(f'Output of model must be tensor or list/tuple of tensors') + raise TypeError(f"Output of model must be tensor or list/tuple of tensors") if isinstance(label[0], torch.Tensor): label = torch.cat(label, dim=0) else: @@ -59,12 +59,9 @@ def data_process_func(stage_output, dataloader_output): """ - def forward_backward_step(self, - engine: Engine, - data_iter: Iterable, - forward_only=False, - return_loss=True, - return_output_label=True) -> Tuple[torch.Tensor]: + def forward_backward_step( + self, engine: Engine, data_iter: Iterable, forward_only=False, return_loss=True, return_output_label=True + ) -> Tuple[torch.Tensor]: """Runs non-interleaved 
1F1B schedule, with communication between pipeline stages. Returns a tuple with losses if the last stage, an empty tuple otherwise. @@ -80,14 +77,15 @@ def forward_backward_step(self, Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None. """ - assert forward_only or return_loss, \ - 'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.' + assert ( + forward_only or return_loss + ), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False." self.load_batch(data_iter) # num_warmup_microbatches is the step when not all the processes are working - num_warmup_microbatches = \ - (gpc.get_world_size(ParallelMode.PIPELINE) - - gpc.get_local_rank(ParallelMode.PIPELINE) - 1) + num_warmup_microbatches = ( + gpc.get_world_size(ParallelMode.PIPELINE) - gpc.get_local_rank(ParallelMode.PIPELINE) - 1 + ) num_warmup_microbatches = min(num_warmup_microbatches, self.num_microbatches) num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches @@ -109,11 +107,9 @@ def forward_backward_step(self, for i in range(num_warmup_microbatches): input_obj = comm.recv_forward() - output_obj = self._forward_step(engine, - input_obj, - return_tensors, - return_output_label=return_output_label, - accum_loss=accum_loss) + output_obj = self._forward_step( + engine, input_obj, return_tensors, return_output_label=return_output_label, accum_loss=accum_loss + ) comm.send_forward(output_obj) @@ -129,13 +125,11 @@ def forward_backward_step(self, # Run 1F1B in steady state. for i in range(num_microbatches_remaining): - last_iteration = (i == (num_microbatches_remaining - 1)) + last_iteration = i == (num_microbatches_remaining - 1) - output_obj = self._forward_step(engine, - input_obj, - return_tensors, - return_output_label=return_output_label, - accum_loss=accum_loss) + output_obj = self._forward_step( + engine, input_obj, return_tensors, return_output_label=return_output_label, accum_loss=accum_loss + ) if forward_only: comm.send_forward(output_obj) diff --git a/colossalai/legacy/global_variables.py b/colossalai/legacy/global_variables.py index 61b31965e2e6..93cd5e60fa61 100644 --- a/colossalai/legacy/global_variables.py +++ b/colossalai/legacy/global_variables.py @@ -12,19 +12,21 @@ def __new__(cls, *args, **kwargs): def __init__(self, *args, **kwargs): self.load(*args, **kwargs) - def load(self, - mode: Optional[str] = None, - vocab_parallel: bool = False, - parallel_input_1d: bool = False, - summa_dim: int = None, - tesseract_dim: int = None, - tesseract_dep: int = None, - depth_3d: int = None, - input_group_3d=None, - weight_group_3d=None, - output_group_3d=None, - input_x_weight_group_3d=None, - output_x_weight_group_3d=None): + def load( + self, + mode: Optional[str] = None, + vocab_parallel: bool = False, + parallel_input_1d: bool = False, + summa_dim: int = None, + tesseract_dim: int = None, + tesseract_dep: int = None, + depth_3d: int = None, + input_group_3d=None, + weight_group_3d=None, + output_group_3d=None, + input_x_weight_group_3d=None, + output_x_weight_group_3d=None, + ): self.mode = mode self.vocab_parallel = vocab_parallel self.parallel_input_1d = parallel_input_1d @@ -39,18 +41,20 @@ def load(self, self.output_x_weight_group_3d = output_x_weight_group_3d def save(self): - return dict(mode=self.mode, - vocab_parallel=self.vocab_parallel, - parallel_input_1d=self.parallel_input_1d, - summa_dim=self.summa_dim, - tesseract_dim=self.tesseract_dim, - tesseract_dep=self.tesseract_dep, - 
depth_3d=self.depth_3d, - input_group_3d=self.input_group_3d, - weight_group_3d=self.weight_group_3d, - output_group_3d=self.output_group_3d, - input_x_weight_group_3d=self.input_x_weight_group_3d, - output_x_weight_group_3d=self.output_x_weight_group_3d) + return dict( + mode=self.mode, + vocab_parallel=self.vocab_parallel, + parallel_input_1d=self.parallel_input_1d, + summa_dim=self.summa_dim, + tesseract_dim=self.tesseract_dim, + tesseract_dep=self.tesseract_dep, + depth_3d=self.depth_3d, + input_group_3d=self.input_group_3d, + weight_group_3d=self.weight_group_3d, + output_group_3d=self.output_group_3d, + input_x_weight_group_3d=self.input_x_weight_group_3d, + output_x_weight_group_3d=self.output_x_weight_group_3d, + ) tensor_parallel_env = TensorParallelEnv() diff --git a/colossalai/legacy/initialize.py b/colossalai/legacy/initialize.py index 2c253adbaf38..ce9c626553bf 100644 --- a/colossalai/legacy/initialize.py +++ b/colossalai/legacy/initialize.py @@ -47,25 +47,27 @@ def get_default_parser(): Namespace: Returns the parser with the default arguments, the user may add customized arguments into this parser. """ parser = argparse.ArgumentParser() - parser.add_argument('--config', type=str, help='path to the config file') - parser.add_argument('--host', type=str, help='the master address for distributed training') - parser.add_argument('--port', type=int, help='the master port for distributed training') - parser.add_argument('--world_size', type=int, help='world size for distributed training') - parser.add_argument('--rank', type=int, help='rank for the default process group') - parser.add_argument('--local_rank', type=int, help='local rank on the node') - parser.add_argument('--backend', type=str, default='nccl', help='backend for distributed communication') + parser.add_argument("--config", type=str, help="path to the config file") + parser.add_argument("--host", type=str, help="the master address for distributed training") + parser.add_argument("--port", type=int, help="the master port for distributed training") + parser.add_argument("--world_size", type=int, help="world size for distributed training") + parser.add_argument("--rank", type=int, help="rank for the default process group") + parser.add_argument("--local_rank", type=int, help="local rank on the node") + parser.add_argument("--backend", type=str, default="nccl", help="backend for distributed communication") return parser -def launch(config: Union[str, Path, Config, Dict], - rank: int, - world_size: int, - host: str, - port: int, - backend: str = 'nccl', - local_rank: int = None, - seed: int = 1024, - verbose: bool = True): +def launch( + config: Union[str, Path, Config, Dict], + rank: int, + world_size: int, + host: str, + port: int, + backend: str = "nccl", + local_rank: int = None, + seed: int = 1024, + verbose: bool = True, +): """This function first parses the configuration arguments, using :func:`parse_args()` in case one of the input arguments are not given. Then initialize and set distributed environment by calling global_context's functions. 
@@ -88,8 +90,9 @@ def launch(config: Union[str, Path, Config, Dict], gpc.verbose = verbose # set config - assert isinstance(config, (Config, str, Path, dict)), \ - f'expected argument config to be Config, str or Path, but got {type(config)}' + assert isinstance( + config, (Config, str, Path, dict) + ), f"expected argument config to be Config, str or Path, but got {type(config)}" if not isinstance(config, Config) and isinstance(config, dict): config = Config(config) if isinstance(config, (str, Path)): @@ -115,18 +118,21 @@ def launch(config: Union[str, Path, Config, Dict], if verbose: logger = get_dist_logger() logger.info( - f'Distributed environment is initialized, ' - f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, ' - f'tensor parallel size: {gpc.tensor_parallel_size}', - ranks=[0]) - - -def launch_from_slurm(config: Union[str, Path, Config, Dict], - host: str, - port: int, - backend: str = 'nccl', - seed: int = 1024, - verbose: bool = True): + f"Distributed environment is initialized, " + f"data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, " + f"tensor parallel size: {gpc.tensor_parallel_size}", + ranks=[0], + ) + + +def launch_from_slurm( + config: Union[str, Path, Config, Dict], + host: str, + port: int, + backend: str = "nccl", + seed: int = 1024, + verbose: bool = True, +): """A wrapper for colossalai.launch for SLURM launcher by reading rank and world size from the environment variables set by SLURM @@ -139,29 +145,33 @@ def launch_from_slurm(config: Union[str, Path, Config, Dict], verbose (bool, optional): Whether to print logs. Defaults to True. """ try: - rank = int(os.environ['SLURM_PROCID']) - world_size = int(os.environ['SLURM_NPROCS']) + rank = int(os.environ["SLURM_PROCID"]) + world_size = int(os.environ["SLURM_NPROCS"]) except KeyError as e: raise RuntimeError( f"Could not find {e} in the SLURM environment, visit https://www.colossalai.org/ for more information on launching with SLURM" ) - launch(config=config, - rank=rank, - world_size=world_size, - host=host, - port=port, - backend=backend, - seed=seed, - verbose=verbose) - - -def launch_from_openmpi(config: Union[str, Path, Config, Dict], - host: str, - port: int, - backend: str = 'nccl', - seed: int = 1024, - verbose: bool = True): + launch( + config=config, + rank=rank, + world_size=world_size, + host=host, + port=port, + backend=backend, + seed=seed, + verbose=verbose, + ) + + +def launch_from_openmpi( + config: Union[str, Path, Config, Dict], + host: str, + port: int, + backend: str = "nccl", + seed: int = 1024, + verbose: bool = True, +): """A wrapper for colossalai.launch for OpenMPI launcher by reading rank and world size from the environment variables set by OpenMPI @@ -174,29 +184,30 @@ def launch_from_openmpi(config: Union[str, Path, Config, Dict], verbose (bool, optional): Whether to print logs. Defaults to True. 
""" try: - rank = int(os.environ['OMPI_COMM_WORLD_RANK']) - local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) - world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) + rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) + local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) + world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) except KeyError as e: raise RuntimeError( f"Could not find {e} in the OpenMPI environment, visit https://www.colossalai.org/ for more information on launching with OpenMPI" ) - launch(config=config, - local_rank=local_rank, - rank=rank, - world_size=world_size, - host=host, - port=port, - backend=backend, - seed=seed, - verbose=verbose) - - -def launch_from_torch(config: Union[str, Path, Config, Dict], - backend: str = 'nccl', - seed: int = 1024, - verbose: bool = True): + launch( + config=config, + local_rank=local_rank, + rank=rank, + world_size=world_size, + host=host, + port=port, + backend=backend, + seed=seed, + verbose=verbose, + ) + + +def launch_from_torch( + config: Union[str, Path, Config, Dict], backend: str = "nccl", seed: int = 1024, verbose: bool = True +): """A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size from the environment variables set by PyTorch @@ -207,35 +218,39 @@ def launch_from_torch(config: Union[str, Path, Config, Dict], verbose (bool, optional): Whether to print logs. Defaults to True. """ try: - rank = int(os.environ['RANK']) - local_rank = int(os.environ['LOCAL_RANK']) - world_size = int(os.environ['WORLD_SIZE']) - host = os.environ['MASTER_ADDR'] - port = int(os.environ['MASTER_PORT']) + rank = int(os.environ["RANK"]) + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + host = os.environ["MASTER_ADDR"] + port = int(os.environ["MASTER_PORT"]) except KeyError as e: raise RuntimeError( f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch" ) - launch(config=config, - local_rank=local_rank, - rank=rank, - world_size=world_size, - host=host, - port=port, - backend=backend, - seed=seed, - verbose=verbose) - - -def initialize(model: nn.Module, - optimizer: Optimizer, - criterion: Optional[_Loss] = None, - train_dataloader: Optional[Iterable] = None, - test_dataloader: Optional[Iterable] = None, - lr_scheduler: Optional[_LRScheduler] = None, - ophooks: Optional[List[BaseOpHook]] = None, - verbose: bool = True) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]: + launch( + config=config, + local_rank=local_rank, + rank=rank, + world_size=world_size, + host=host, + port=port, + backend=backend, + seed=seed, + verbose=verbose, + ) + + +def initialize( + model: nn.Module, + optimizer: Optimizer, + criterion: Optional[_Loss] = None, + train_dataloader: Optional[Iterable] = None, + test_dataloader: Optional[Iterable] = None, + lr_scheduler: Optional[_LRScheduler] = None, + ophooks: Optional[List[BaseOpHook]] = None, + verbose: bool = True, +) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]: """Core function to wrap the essential training components with our functionality based on the config which is loaded into gpc.config. 
@@ -267,30 +282,30 @@ def initialize(model: nn.Module, f"\n========== Your Config ========\n" f"{pprint.pformat(gpc.config)}\n" f"================================\n", - ranks=[0]) + ranks=[0], + ) # cudnn - cudnn_benchmark = config.get('cudnn_benchmark', False) - cudnn_deterministic = config.get('cudnn_deterministic', False) + cudnn_benchmark = config.get("cudnn_benchmark", False) + cudnn_deterministic = config.get("cudnn_deterministic", False) torch.backends.cudnn.benchmark = cudnn_benchmark torch.backends.cudnn.deterministic = cudnn_deterministic if verbose: logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0]) # zero - use_zero = hasattr(gpc.config, 'zero') + use_zero = hasattr(gpc.config, "zero") if use_zero: - zero_cfg = gpc.config.get('zero', None) + zero_cfg = gpc.config.get("zero", None) if zero_cfg is not None: cfg_ = zero_cfg.copy() else: cfg_ = {} - optimizer_config = zero_cfg.get('optimizer_config', None) - model_config = zero_cfg.get('model_config', None) - model, optimizer = convert_to_zero_v2(model, - optimizer, - model_config=model_config, - optimizer_config=optimizer_config) + optimizer_config = zero_cfg.get("optimizer_config", None) + model_config = zero_cfg.get("model_config", None) + model, optimizer = convert_to_zero_v2( + model, optimizer, model_config=model_config, optimizer_config=optimizer_config + ) logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0]) else: @@ -316,38 +331,38 @@ def initialize(model: nn.Module, logger.warning( "The parameters of models is not automatically synchronized.\n" "Please make sure that all parameters are the same in data parallel group.", - ranks=[0]) + ranks=[0], + ) # check amp and zero - fp16_cfg = gpc.config.get('fp16', None) + fp16_cfg = gpc.config.get("fp16", None) if fp16_cfg is not None and fp16_cfg.mode is not None and use_zero: raise ConfigException( - "It is not allowed to set fp16 and zero configuration in your config file at the same time") + "It is not allowed to set fp16 and zero configuration in your config file at the same time" + ) # clip grad norm - clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0) + clip_grad_norm = gpc.config.get("clip_grad_norm", 0.0) # initialize amp amp_mode = None if fp16_cfg is not None and fp16_cfg.mode is not None: cfg_ = fp16_cfg.copy() - amp_mode = cfg_.pop('mode') + amp_mode = cfg_.pop("mode") if is_using_pp(): - assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline only support NaiveAMP currently' + assert amp_mode == AMP_TYPE.NAIVE, "Pipeline only support NaiveAMP currently" if amp_mode == AMP_TYPE.NAIVE: - cfg_['clip_grad_norm'] = clip_grad_norm - model, optimizer, criterion = convert_to_amp(model=model, - optimizer=optimizer, - criterion=criterion, - mode=amp_mode, - amp_config=cfg_) + cfg_["clip_grad_norm"] = clip_grad_norm + model, optimizer, criterion = convert_to_amp( + model=model, optimizer=optimizer, criterion=criterion, mode=amp_mode, amp_config=cfg_ + ) # get torch ddp config - torch_ddp_cfg = gpc.config.get('torch_ddp', dict()) + torch_ddp_cfg = gpc.config.get("torch_ddp", dict()) # gradient handler - gradient_handler_cfg = gpc.config.get('gradient_handler', None) + gradient_handler_cfg = gpc.config.get("gradient_handler", None) if gradient_handler_cfg is None: # if gradient handler is not specified in the configuration file, # check in the following order @@ -355,54 +370,63 @@ def initialize(model: nn.Module, # 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp # 3. 
if using pipeline and dp size larger than 1, use data parallel grad handler if isinstance(optimizer, ShardedOptimizerV2): - gradient_handler_cfg = [dict(type='ZeROGradientHandler')] + gradient_handler_cfg = [dict(type="ZeROGradientHandler")] if verbose: logger.info( "Training with zero is detected, ZeROGradientHandler is automatically " "added even though not specified in the configuration", - ranks=[0]) + ranks=[0], + ) elif is_using_ddp() and MOE_CONTEXT.is_initialized: - gradient_handler_cfg = [dict(type='MoeGradientHandler')] + gradient_handler_cfg = [dict(type="MoeGradientHandler")] if verbose: logger.info( "Data parallel training is detected with moe parallel, MoeGradientHandler is automatically " "added even though not specified in the configuration", - ranks=[0]) + ranks=[0], + ) elif is_using_sequence(): - model = DDP(model, - process_group=gpc.get_group(ParallelMode.SEQUENCE_DP), - device_ids=[torch.cuda.current_device()], - **torch_ddp_cfg) + model = DDP( + model, + process_group=gpc.get_group(ParallelMode.SEQUENCE_DP), + device_ids=[torch.cuda.current_device()], + **torch_ddp_cfg, + ) if verbose: - logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism', - ranks=[0]) + logger.info( + "Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism", ranks=[0] + ) elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE: - model = DDP(model, - process_group=gpc.get_group(ParallelMode.DATA), - device_ids=[torch.cuda.current_device()], - **torch_ddp_cfg) + model = DDP( + model, + process_group=gpc.get_group(ParallelMode.DATA), + device_ids=[torch.cuda.current_device()], + **torch_ddp_cfg, + ) if verbose: - logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0]) + logger.info("Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism", ranks=[0]) elif is_using_ddp(): - gradient_handler_cfg = [dict(type='DataParallelGradientHandler')] + gradient_handler_cfg = [dict(type="DataParallelGradientHandler")] if verbose: logger.info( "Data parallel training is detected when using pipeline parallel, " "DataParallelGradientHandler is automatically " "added even though not specified in the configuration", - ranks=[0]) + ranks=[0], + ) # add pipeline parallel gradient handler, if pipeline shared module is detected for param in model.parameters(): - if getattr(param, 'pipeline_shared_module_pg', None) is not None: + if getattr(param, "pipeline_shared_module_pg", None) is not None: if gradient_handler_cfg is None: - gradient_handler_cfg = [dict(type='PipelineSharedModuleGradientHandler')] + gradient_handler_cfg = [dict(type="PipelineSharedModuleGradientHandler")] else: - gradient_handler_cfg.append(dict(type='PipelineSharedModuleGradientHandler')) + gradient_handler_cfg.append(dict(type="PipelineSharedModuleGradientHandler")) if verbose: logger.info( "pipeline_shared_module is detected, PipelineSharedModuleGradientHandler is automatically " "added even though not specified in the configuration", - ranks=[0]) + ranks=[0], + ) break else: if not isinstance(gradient_handler_cfg, list): @@ -418,7 +442,7 @@ def initialize(model: nn.Module, # initialize schedule for engine if is_using_pp(): tensor_shape = get_tensor_shape() - use_interleaved = hasattr(gpc.config, 'model') and hasattr(gpc.config.model, 'num_chunks') + use_interleaved = hasattr(gpc.config, "model") and hasattr(gpc.config.model, "num_chunks") if gpc.is_initialized(ParallelMode.PARALLEL_1D): 
scatter_gather = True else: @@ -426,14 +450,16 @@ def initialize(model: nn.Module, if use_interleaved: if isinstance(model, nn.Sequential): model = nn.ModuleList([model]) - schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES, - gpc.config.model.num_chunks, - tensor_shape=tensor_shape, - scatter_gather_tensors=scatter_gather) + schedule = InterleavedPipelineSchedule( + gpc.config.NUM_MICRO_BATCHES, + gpc.config.model.num_chunks, + tensor_shape=tensor_shape, + scatter_gather_tensors=scatter_gather, + ) else: - schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES, - tensor_shape=tensor_shape, - scatter_gather_tensors=scatter_gather) + schedule = PipelineSchedule( + gpc.config.NUM_MICRO_BATCHES, tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather + ) else: schedule = NonPipelineSchedule() @@ -443,7 +469,8 @@ def initialize(model: nn.Module, logger.warning( "No PyTorch DDP or gradient handler is set up, please make sure you do not need " "to all-reduce the gradients after a training step.", - ranks=[0]) + ranks=[0], + ) else: gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg] @@ -452,7 +479,7 @@ def initialize(model: nn.Module, optimizer = OptimizerWrapper(optim=optimizer) # gradient accumulation - grad_accum_size = gpc.config.get('gradient_accumulation', None) + grad_accum_size = gpc.config.get("gradient_accumulation", None) if grad_accum_size is not None: optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient( model=model, @@ -460,13 +487,16 @@ def initialize(model: nn.Module, dataloader=train_dataloader, accumulate_size=grad_accum_size, gradient_handlers=gradient_handlers, - lr_scheduler=lr_scheduler) - engine = Engine(model=model, - optimizer=optimizer, - criterion=criterion, - gradient_handlers=gradient_handlers, - clip_grad_norm=clip_grad_norm, - ophook_list=ophooks, - schedule=schedule) + lr_scheduler=lr_scheduler, + ) + engine = Engine( + model=model, + optimizer=optimizer, + criterion=criterion, + gradient_handlers=gradient_handlers, + clip_grad_norm=clip_grad_norm, + ophook_list=ophooks, + schedule=schedule, + ) return engine, train_dataloader, test_dataloader, lr_scheduler diff --git a/colossalai/legacy/nn/_ops/_utils.py b/colossalai/legacy/nn/_ops/_utils.py index a4228fa2116e..b6a99f855a4c 100644 --- a/colossalai/legacy/nn/_ops/_utils.py +++ b/colossalai/legacy/nn/_ops/_utils.py @@ -41,7 +41,7 @@ def _reduce(input_, pg: ProcessGroup): # skip if only one rank involved if pg.tp_world_size() == 1: return input_ - assert input_.device.type == 'cuda' + assert input_.device.type == "cuda" group = pg.tp_process_group() dist.all_reduce(input_, group=group) @@ -56,9 +56,10 @@ def _split(input_, pg: ProcessGroup, dim=-1): # Split along last dimension. 
dim_size = input_.size(dim) - assert dim_size % world_size == 0, \ - f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \ - f'cannot split tensor evenly' + assert dim_size % world_size == 0, ( + f"The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), " + f"cannot split tensor evenly" + ) tensor_list = torch.split(input_, dim_size // world_size, dim=dim) rank = pg.tp_local_rank() @@ -77,7 +78,7 @@ def _gather(input_, pg: ProcessGroup, dim=-1): rank = pg.tp_local_rank() tensor_list = [torch.empty_like(input_) for _ in range(world_size)] tensor_list[rank] = input_ - assert input_.device.type == 'cuda' + assert input_.device.type == "cuda" group = pg.tp_process_group() torch.distributed.all_gather(tensor_list, input_, group=group) @@ -203,7 +204,7 @@ def _all_to_all(x: torch.Tensor, pg: ProcessGroup, scatter_dim: int, gather_dim: return x # TODO: enabling mpi backend to support CPU all_to_all - assert x.device.type == 'cuda', f"Currently, the collective function dual_all_to_all only supports nccl backend" + assert x.device.type == "cuda", f"Currently, the collective function dual_all_to_all only supports nccl backend" shapes = list(x.size()) shapes[scatter_dim] = shapes[scatter_dim] // world_size @@ -216,7 +217,6 @@ def _all_to_all(x: torch.Tensor, pg: ProcessGroup, scatter_dim: int, gather_dim: class _DualAllToAll(torch.autograd.Function): - @staticmethod def forward(ctx, x, pg, scatter_dim, gather_dim): ctx.scatter_dim = scatter_dim @@ -236,16 +236,14 @@ def dual_all_to_all(x, pg, scatter_dim: int, gather_dim: int): # table wise embedding shard -def _all_to_all_for_tablewise(x: torch.Tensor, - pg: ProcessGroup, - scatter_strides: List[int], - gather_strides: List[int], - forward=True) -> torch.Tensor: +def _all_to_all_for_tablewise( + x: torch.Tensor, pg: ProcessGroup, scatter_strides: List[int], gather_strides: List[int], forward=True +) -> torch.Tensor: world_size = pg.tp_world_size() rank = pg.tp_local_rank() if world_size == 1: return x - assert x.device.type == 'cuda', f"Currently, the collective function dual_all_to_all only supports nccl backend" + assert x.device.type == "cuda", f"Currently, the collective function dual_all_to_all only supports nccl backend" if forward: scatter_list = list(x.split(scatter_strides, 0)) gather_list = [ @@ -266,7 +264,6 @@ def _all_to_all_for_tablewise(x: torch.Tensor, class _DualAllToAllForTablewise(torch.autograd.Function): - @staticmethod def forward(ctx, x, pg, scatter_strides, gather_strides): ctx.pg = pg @@ -276,8 +273,12 @@ def forward(ctx, x, pg, scatter_strides, gather_strides): @staticmethod def backward(ctx, grad): - return _all_to_all_for_tablewise(grad, ctx.pg, ctx.gather_strides, ctx.scatter_strides, - forward=False), None, None, None + return ( + _all_to_all_for_tablewise(grad, ctx.pg, ctx.gather_strides, ctx.scatter_strides, forward=False), + None, + None, + None, + ) def dual_all_to_all_tablewise(x, pg, scatter_strides, gather_strides): diff --git a/colossalai/legacy/nn/layer/base_layer.py b/colossalai/legacy/nn/layer/base_layer.py index 01fd9b3e8943..66abc6fb1fd1 100644 --- a/colossalai/legacy/nn/layer/base_layer.py +++ b/colossalai/legacy/nn/layer/base_layer.py @@ -10,44 +10,54 @@ class ParallelLayer(nn.Module): - global_state_dict: bool = True def __init__(self): super().__init__() - self.data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank( - ParallelMode.DATA) - self.data_parallel_size = 1 if not 
gpc.is_initialized(ParallelMode.DATA) else gpc.get_world_size( - ParallelMode.DATA) + self.data_parallel_rank = ( + 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) + ) + self.data_parallel_size = ( + 1 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_world_size(ParallelMode.DATA) + ) - self.tensor_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_local_rank( - ParallelMode.TENSOR) - self.tensor_parallel_size = 1 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_world_size( - ParallelMode.TENSOR) + self.tensor_parallel_rank = ( + 0 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_local_rank(ParallelMode.TENSOR) + ) + self.tensor_parallel_size = ( + 1 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_world_size(ParallelMode.TENSOR) + ) - self.pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( - ParallelMode.PIPELINE) - self.pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( - ParallelMode.PIPELINE) + self.pipeline_parallel_rank = ( + 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + ) + self.pipeline_parallel_size = ( + 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(ParallelMode.PIPELINE) + ) - def _load_from_global_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs): - return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs) + def _load_from_global_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + return super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) def _save_to_global_state_dict(self, destination, prefix, keep_vars): return super()._save_to_state_dict(destination, prefix, keep_vars) - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs): + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): if self.global_state_dict: if gpc.get_local_rank(ParallelMode.TENSOR) != 0: missing_keys.clear() unexpected_keys.clear() - return self._load_from_global_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, - unexpected_keys, error_msgs) - return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs) + return self._load_from_global_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + return super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) def _save_to_state_dict(self, destination, prefix, keep_vars): if self.global_state_dict: diff --git a/colossalai/legacy/nn/layer/colossalai_layer/__init__.py b/colossalai/legacy/nn/layer/colossalai_layer/__init__.py index ed743820ddbc..7c5449ff5578 100644 --- a/colossalai/legacy/nn/layer/colossalai_layer/__init__.py +++ b/colossalai/legacy/nn/layer/colossalai_layer/__init__.py @@ -4,4 +4,4 @@ from .linear import Classifier, Linear from .normalization import LayerNorm -__all__ = ['Linear', 'Classifier', 'Embedding', 'PatchEmbedding', 'LayerNorm', 'Dropout', 'partition_batch'] +__all__ = 
["Linear", "Classifier", "Embedding", "PatchEmbedding", "LayerNorm", "Dropout", "partition_batch"] diff --git a/colossalai/legacy/nn/layer/colossalai_layer/_utils.py b/colossalai/legacy/nn/layer/colossalai_layer/_utils.py index 677cb0e7ac42..98255142a846 100644 --- a/colossalai/legacy/nn/layer/colossalai_layer/_utils.py +++ b/colossalai/legacy/nn/layer/colossalai_layer/_utils.py @@ -6,7 +6,7 @@ from ..parallel_3d._operation import split_batch_3d from ..utils import get_tensor_parallel_mode -_parallel_split_batch = {'2d': split_batch_2d, '2.5d': split_batch_2p5d, '3d': split_batch_3d} +_parallel_split_batch = {"2d": split_batch_2d, "2.5d": split_batch_2p5d, "3d": split_batch_3d} def partition_batch(input_) -> Tensor: @@ -21,7 +21,6 @@ def partition_batch(input_) -> Tensor: class ColossalaiModule(nn.Module): - def __init__(self, module: nn.Module, **kwargs): super().__init__() self.module = module @@ -29,7 +28,7 @@ def __init__(self, module: nn.Module, **kwargs): setattr(self, k, v) def __getattr__(self, name: str): - if name == 'module': + if name == "module": return super().__getattr__(name) elif hasattr(self.module, name): return getattr(self.module, name) diff --git a/colossalai/legacy/nn/layer/colossalai_layer/dropout.py b/colossalai/legacy/nn/layer/colossalai_layer/dropout.py index 7b0481a3f53c..ad6fcc2d8bf4 100644 --- a/colossalai/legacy/nn/layer/colossalai_layer/dropout.py +++ b/colossalai/legacy/nn/layer/colossalai_layer/dropout.py @@ -24,7 +24,7 @@ def __init__(self, p: float = 0.5, inplace: bool = False) -> None: super().__init__(drop, tensor_parallel=tensor_parallel) def forward(self, *args): - if self.tensor_parallel in [None, '1d']: + if self.tensor_parallel in [None, "1d"]: return super().forward(*args) else: with seed(ParallelMode.TENSOR): diff --git a/colossalai/legacy/nn/layer/colossalai_layer/embedding.py b/colossalai/legacy/nn/layer/colossalai_layer/embedding.py index 28bcb7ffefb0..e1db0fe98a02 100644 --- a/colossalai/legacy/nn/layer/colossalai_layer/embedding.py +++ b/colossalai/legacy/nn/layer/colossalai_layer/embedding.py @@ -15,25 +15,25 @@ from ._utils import ColossalaiModule _parallel_embedding = { - '1d': Embedding1D, - '2d': Embedding2D, - '2.5d': Embedding2p5D, - '3d': Embedding3D, + "1d": Embedding1D, + "2d": Embedding2D, + "2.5d": Embedding2p5D, + "3d": Embedding3D, } _vocab_parallel_embedding = { - '1d': VocabParallelEmbedding1D, - '2d': VocabParallelEmbedding2D, - '2.5d': VocabParallelEmbedding2p5D, - '3d': VocabParallelEmbedding3D + "1d": VocabParallelEmbedding1D, + "2d": VocabParallelEmbedding2D, + "2.5d": VocabParallelEmbedding2p5D, + "3d": VocabParallelEmbedding3D, } _parallel_patchembedding = { None: VanillaPatchEmbedding, - '1d': PatchEmbedding1D, - '2d': PatchEmbedding2D, - '2.5d': PatchEmbedding2p5D, - '3d': PatchEmbedding3D + "1d": PatchEmbedding1D, + "2d": PatchEmbedding2D, + "2.5d": PatchEmbedding2p5D, + "3d": PatchEmbedding3D, } @@ -67,19 +67,24 @@ class Embedding(ColossalaiModule): `init `_ """ - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: dtype = None, - weight_initializer: Callable = init.normal_(), - vocab_parallel_limit: int = 2048, - *args, - **kwargs) -> None: + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: dtype = None, + weight_initializer: Callable = init.normal_(), + vocab_parallel_limit: int = 2048, + *args, + **kwargs, + ) -> None: tensor_parallel = get_tensor_parallel_mode() if tensor_parallel is None: - embed = 
nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args, - **kwargs).to(dtype).to(get_current_device()) + embed = ( + nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args, **kwargs) + .to(dtype) + .to(get_current_device()) + ) weight_initializer(embed.weight, fan_in=num_embeddings, fan_out=embedding_dim) elif num_embeddings <= vocab_parallel_limit: embed = _parallel_embedding[tensor_parallel]( @@ -135,7 +140,7 @@ def __init__( flatten: bool = True, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - position_embed_initializer: Callable = init.zeros_() + position_embed_initializer: Callable = init.zeros_(), ) -> None: tensor_parallel = get_tensor_parallel_mode() embed = _parallel_patchembedding[tensor_parallel]( diff --git a/colossalai/legacy/nn/layer/colossalai_layer/linear.py b/colossalai/legacy/nn/layer/colossalai_layer/linear.py index c05ceb66ce25..aa4863e28b81 100644 --- a/colossalai/legacy/nn/layer/colossalai_layer/linear.py +++ b/colossalai/legacy/nn/layer/colossalai_layer/linear.py @@ -5,7 +5,6 @@ from torch import dtype, nn from colossalai.nn import init -from colossalai.utils import get_current_device from ..parallel_1d import * from ..parallel_2d import * @@ -15,21 +14,21 @@ from ..vanilla import * from ._utils import ColossalaiModule -_parallel_linear = {None: VanillaLinear, '1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D} +_parallel_linear = {None: VanillaLinear, "1d": Linear1D, "2d": Linear2D, "2.5d": Linear2p5D, "3d": Linear3D} _parallel_classifier = { None: VanillaClassifier, - '1d': Classifier1D, - '2d': Classifier2D, - '2.5d': Classifier2p5D, - '3d': Classifier3D + "1d": Classifier1D, + "2d": Classifier2D, + "2.5d": Classifier2p5D, + "3d": Classifier3D, } _vocab_parallel_classifier = { - '1d': VocabParallelClassifier1D, - '2d': VocabParallelClassifier2D, - '2.5d': VocabParallelClassifier2p5D, - '3d': VocabParallelClassifier3D + "1d": VocabParallelClassifier1D, + "2d": VocabParallelClassifier2D, + "2.5d": VocabParallelClassifier2p5D, + "3d": VocabParallelClassifier3D, } @@ -65,19 +64,21 @@ class Linear(ColossalaiModule): `init `_. """ - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - **kwargs) -> None: + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + **kwargs, + ) -> None: tensor_parallel = get_tensor_parallel_mode() linear_cls = _parallel_linear[tensor_parallel] - gather_output = kwargs.pop('gather_output', None) - if 'gather_output' in inspect.signature(linear_cls.__init__).parameters.keys(): # gather_out arg is available - kwargs['gather_output'] = gather_output + gather_output = kwargs.pop("gather_output", None) + if "gather_output" in inspect.signature(linear_cls.__init__).parameters.keys(): # gather_out arg is available + kwargs["gather_output"] = gather_output layer = linear_cls( in_features, out_features, @@ -108,15 +109,17 @@ class Classifier(ColossalaiModule): `init `_. 
""" - def __init__(self, - in_features: int, - num_classes: int, - weight: nn.Parameter = None, - bias: bool = True, - dtype: dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - vocab_parallel_limit: int = 2048) -> None: + def __init__( + self, + in_features: int, + num_classes: int, + weight: nn.Parameter = None, + bias: bool = True, + dtype: dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + vocab_parallel_limit: int = 2048, + ) -> None: tensor_parallel = get_tensor_parallel_mode() if num_classes <= vocab_parallel_limit or tensor_parallel is None: layer = _parallel_classifier[tensor_parallel]( diff --git a/colossalai/legacy/nn/layer/parallel_1d/__init__.py b/colossalai/legacy/nn/layer/parallel_1d/__init__.py index 9cffd4d339f5..35e9ec40d100 100644 --- a/colossalai/legacy/nn/layer/parallel_1d/__init__.py +++ b/colossalai/legacy/nn/layer/parallel_1d/__init__.py @@ -12,6 +12,14 @@ ) __all__ = [ - 'Linear1D', 'Linear1D_Col', 'Linear1D_Row', 'Embedding1D', 'Dropout1D', 'Classifier1D', 'VocabParallelClassifier1D', - 'VocabParallelEmbedding1D', 'LayerNorm1D', 'PatchEmbedding1D' + "Linear1D", + "Linear1D_Col", + "Linear1D_Row", + "Embedding1D", + "Dropout1D", + "Classifier1D", + "VocabParallelClassifier1D", + "VocabParallelEmbedding1D", + "LayerNorm1D", + "PatchEmbedding1D", ] diff --git a/colossalai/legacy/nn/layer/parallel_1d/_operation.py b/colossalai/legacy/nn/layer/parallel_1d/_operation.py index db9dfa3667b4..f01da97ba39a 100644 --- a/colossalai/legacy/nn/layer/parallel_1d/_operation.py +++ b/colossalai/legacy/nn/layer/parallel_1d/_operation.py @@ -21,7 +21,7 @@ class FusedLayerNormAffineFunction1D(torch.autograd.Function): If a single integer is used, it is treated as a singleton list, and this module will normalize over the last dimension which is expected to be of that specific size. 
eps: a value added to the denominator for numerical stability - """ + """ @staticmethod def forward(ctx, input, weight, bias, normalized_shape, eps): @@ -30,8 +30,9 @@ def forward(ctx, input, weight, bias, normalized_shape, eps): input_ = input.contiguous() weight_ = weight.contiguous() bias_ = bias.contiguous() - output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_, - bias_, ctx.eps) + output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine( + input_, ctx.normalized_shape, weight_, bias_, ctx.eps + ) ctx.save_for_backward(input_, weight_, bias_, mean, invvar) return output @@ -39,11 +40,9 @@ def forward(ctx, input, weight, bias, normalized_shape, eps): def backward(ctx, grad_output): input_, weight_, bias_, mean, invvar = ctx.saved_tensors grad_input = grad_weight = grad_bias = None - grad_input, grad_weight, grad_bias \ - = fused_mix_prec_layer_norm_cuda.backward_affine( - grad_output.contiguous(), mean, invvar, - input_, ctx.normalized_shape, - weight_, bias_, ctx.eps) + grad_input, grad_weight, grad_bias = fused_mix_prec_layer_norm_cuda.backward_affine( + grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, weight_, bias_, ctx.eps + ) return grad_input, grad_weight, grad_bias, None, None diff --git a/colossalai/legacy/nn/layer/parallel_1d/_utils.py b/colossalai/legacy/nn/layer/parallel_1d/_utils.py index 15b41e305cba..93b476e811a4 100644 --- a/colossalai/legacy/nn/layer/parallel_1d/_utils.py +++ b/colossalai/legacy/nn/layer/parallel_1d/_utils.py @@ -47,9 +47,10 @@ def _split(input_, parallel_mode, dim=-1): # Split along last dimension. dim_size = input_.size(dim) - assert dim_size % world_size == 0, \ - f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \ - f'cannot split tensor evenly' + assert dim_size % world_size == 0, ( + f"The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), " + f"cannot split tensor evenly" + ) tensor_list = torch.split(input_, dim_size // world_size, dim=dim) rank = gpc.get_local_rank(parallel_mode) diff --git a/colossalai/legacy/nn/layer/parallel_1d/layers.py b/colossalai/legacy/nn/layer/parallel_1d/layers.py index db7986b8e8e5..8304cd2e1eb7 100644 --- a/colossalai/legacy/nn/layer/parallel_1d/layers.py +++ b/colossalai/legacy/nn/layer/parallel_1d/layers.py @@ -27,7 +27,7 @@ from ..base_layer import ParallelLayer from ..colossalai_layer._utils import ColossalaiModule from ..utils import divide, set_tensor_parallel_attribute_by_partition -from ..vanilla import VanillaLayerNorm, VanillaPatchEmbedding +from ..vanilla import VanillaPatchEmbedding from ._operation import linear_with_async_comm from ._utils import ( gather_forward_split_backward, @@ -41,6 +41,7 @@ Fast_LN = None try: from apex.contrib.layer_norm.layer_norm import FastLayerNorm + Fast_LN = FastLayerNorm except ImportError: pass @@ -67,33 +68,39 @@ class Linear1D(ColossalaiModule): `init `_. 
""" - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - gather_output: bool = False, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + gather_output: bool = False, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): parallel_input = get_parallel_input() if not parallel_input and not gather_output: - layer = Linear1D_Col(in_features, - out_features, - bias=bias, - dtype=dtype, - skip_bias_add=skip_bias_add, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer) + layer = Linear1D_Col( + in_features, + out_features, + bias=bias, + dtype=dtype, + skip_bias_add=skip_bias_add, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + ) else: - layer = Linear1D_Row(in_features, - out_features, - bias=bias, - dtype=dtype, - parallel_input=parallel_input, - skip_bias_add=skip_bias_add, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer) + layer = Linear1D_Row( + in_features, + out_features, + bias=bias, + dtype=dtype, + parallel_input=parallel_input, + skip_bias_add=skip_bias_add, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + ) super().__init__(layer) @@ -114,8 +121,30 @@ class LayerNorm1D(ColossalaiModule): """ _fast_ln_supported_sizes = [ - 1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, - 24576, 25600, 30720, 32768, 40960, 49152, 65536 + 1024, + 1536, + 2048, + 2304, + 3072, + 3840, + 4096, + 5120, + 6144, + 8192, + 10240, + 12288, + 12800, + 15360, + 16384, + 18432, + 20480, + 24576, + 25600, + 30720, + 32768, + 40960, + 49152, + 65536, ] def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): @@ -125,6 +154,7 @@ def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): norm = None try: from apex.normalization import FusedLayerNorm + norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype) except ImportError: norm = LayerNorm(normalized_shape, eps=eps).to(dtype) @@ -132,8 +162,8 @@ def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): def _load_from_state_dict(self, state_dict, prefix, *args): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -171,14 +201,16 @@ class Classifier1D(ParallelLayer): `init `_. 
""" - def __init__(self, - in_features: int, - num_classes: int, - weight: Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): super().__init__() self.in_features = in_features self.num_classes = num_classes @@ -189,7 +221,7 @@ def __init__(self, # Parameters. # Initialize weight. - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + factory_kwargs = {"device": get_current_device(), "dtype": dtype} if weight is not None: self.weight = weight self.has_weight = False @@ -221,8 +253,8 @@ def _set_tensor_parallel_attributes(self): def _load_from_global_state_dict(self, state_dict, prefix, *args): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight if self.has_weight: @@ -235,50 +267,46 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args): if bias is not None: local_state[bias_key] = bias - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }) + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, + ) super()._load_from_global_state_dict(local_state, prefix, *args) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" local_state = OrderedDict() if self.has_weight: local_state[weight_key] = self.weight if self.bias is not None: local_state[bias_key] = self.bias - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, - keep_vars=keep_vars) + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, + keep_vars=keep_vars, + ) destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: # Set up backprop all-reduce. if self.parallel_input: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) + assert ( + input_.shape[-1] == self.weight.shape[-1] + ), "Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] + ) input_ = input_ else: - assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \ - 'Invalid shapes in Classifier1D forward: input={}, weight={}. 
Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size) + assert ( + divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1] + ), "Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size + ) input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) output_parallel = F.linear(input_, self.weight) @@ -307,15 +335,17 @@ class VocabParallelClassifier1D(ParallelLayer): `init `_. """ - def __init__(self, - in_features: int, - num_classes: int, - weight: Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - gather_output: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + gather_output: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): super().__init__() self.in_features = in_features self.num_classes = num_classes @@ -327,7 +357,7 @@ def __init__(self, # Parameters. # Initialize weight. - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + factory_kwargs = {"device": get_current_device(), "dtype": dtype} if weight is not None: self.weight = weight self.has_weight = False @@ -360,8 +390,8 @@ def _set_tensor_parallel_attributes(self): def _load_from_global_state_dict(self, state_dict, prefix, *args): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight if self.has_weight: @@ -374,43 +404,37 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args): if bias is not None: local_state[bias_key] = bias - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }) + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + ) super()._load_from_global_state_dict(local_state, prefix, *args) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" local_state = OrderedDict() if self.has_weight: local_state[weight_key] = self.weight if self.bias is not None: local_state[bias_key] = self.bias - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - keep_vars=keep_vars) + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + keep_vars=keep_vars, + ) destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in VocabParallelClassifier1D forward: input={}, weight={}. 
Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) + assert ( + input_.shape[-1] == self.weight.shape[-1] + ), "Invalid shapes in VocabParallelClassifier1D forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] + ) # Set up backprop all-reduce. input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) # Matrix multiply. @@ -449,15 +473,17 @@ class Linear1D_Col(ParallelLayer): `init `_. """ - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - gather_output: bool = False, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + gather_output: bool = False, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): super().__init__() # Keep input parameters @@ -467,13 +493,13 @@ def __init__(self, self.skip_bias_add = skip_bias_add if skip_bias_add and not bias: - raise ValueError('cannot skip bias addition if bias is None') + raise ValueError("cannot skip bias addition if bias is None") self.out_features_per_partition = divide(out_features, gpc.tensor_parallel_size) # Parameters. # Initialize weight. - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + factory_kwargs = {"device": get_current_device(), "dtype": dtype} self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs)) if bias: @@ -500,8 +526,8 @@ def _set_tensor_parallel_attributes(self): def _load_from_global_state_dict(self, state_dict, prefix, *args): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -513,41 +539,35 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args): if bias is not None: local_state[bias_key] = bias - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }) + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + ) super()._load_from_global_state_dict(local_state, prefix, *args) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" local_state = OrderedDict({weight_key: self.weight}) if self.bias is not None: local_state[bias_key] = self.bias - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - keep_vars=keep_vars) + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + keep_vars=keep_vars, + ) destination.update(local_state) 
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) + assert ( + input_.shape[-1] == self.weight.shape[-1] + ), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] + ) # Set up backprop all-reduce. # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) input_parallel = input_ @@ -569,7 +589,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: @LAYERS.register_module class Linear1D_Row(ParallelLayer): - r""" Linear layer with row parallelism + r"""Linear layer with row parallelism Args: in_features (int): size of each input sample. @@ -588,16 +608,18 @@ class Linear1D_Row(ParallelLayer): `init `_. """ - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - parallel_input: bool = True, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - stream_chunk_num: int = 1): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + parallel_input: bool = True, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + stream_chunk_num: int = 1, + ): super().__init__() self.stream_chunk_num = stream_chunk_num @@ -609,14 +631,14 @@ def __init__(self, self.skip_bias_add = skip_bias_add if skip_bias_add and not bias: - raise ValueError('cannot skip bias addition if bias is None') + raise ValueError("cannot skip bias addition if bias is None") # Divide the weight matrix along the last dimension. self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size) # Parameters. # Initialize weight. 
- factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + factory_kwargs = {"device": get_current_device(), "dtype": dtype} self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) if self.stream_chunk_num > 1: @@ -647,8 +669,8 @@ def _set_tensor_parallel_attributes(self): def _load_from_global_state_dict(self, state_dict, prefix, *args): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -660,48 +682,44 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args): if bias is not None: local_state[bias_key] = bias - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }) + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, + ) super()._load_from_global_state_dict(local_state, prefix, *args) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" local_state = OrderedDict({weight_key: self.weight}) if self.bias is not None: local_state[bias_key] = self.bias - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, - keep_vars=keep_vars) + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, + keep_vars=keep_vars, + ) destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: # Set up backprop all-reduce. if self.parallel_input: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) + assert ( + input_.shape[-1] == self.weight.shape[-1] + ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] + ) input_ = input_ else: - assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size) + assert ( + divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1] + ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. 
Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size + ) input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) if self.stream_chunk_num > 1: @@ -712,9 +730,9 @@ def forward(self, input_: Tensor) -> Tensor: handle_list = [] for i in range(self.stream_chunk_num): output_parallel_list[i] = F.linear(input_, self.weight_list[i]) - handle = torch.distributed.all_reduce(output_parallel_list[i], - group=gpc.get_group(ParallelMode.PARALLEL_1D), - async_op=True) + handle = torch.distributed.all_reduce( + output_parallel_list[i], group=gpc.get_group(ParallelMode.PARALLEL_1D), async_op=True + ) handle_list.append(handle) # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) for handle in handle_list: @@ -763,14 +781,16 @@ class Embedding1D(ParallelLayer): `init `_ """ - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs, + ): super().__init__() self.num_embeddings = num_embeddings @@ -782,7 +802,8 @@ def __init__(self, self.embed_kwargs = kwargs self.weight = Parameter( - torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)) + torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype) + ) self.reset_parameters(weight_initializer) self._set_tensor_parallel_attributes() @@ -804,31 +825,31 @@ def _fill_padding_idx_with_zero(self) -> None: def _load_from_global_state_dict(self, state_dict, prefix, *args): local_state = OrderedDict() - weight_key = prefix + 'weight' + weight_key = prefix + "weight" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) if weight is not None: local_state[weight_key] = weight - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={weight_key: -1}, - partition_states={weight_key: True}) + local_state = partition_tensor_parallel_state_dict( + local_state, ParallelMode.PARALLEL_1D, dims={weight_key: -1}, partition_states={weight_key: True} + ) super()._load_from_global_state_dict(local_state, prefix, *args) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' + weight_key = prefix + "weight" local_state = OrderedDict({weight_key: self.weight}) - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={weight_key: -1}, - partition_states={weight_key: True}, - keep_vars=keep_vars) + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: -1}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: - output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) @@ -867,14 +888,16 @@ class VocabParallelEmbedding1D(ParallelLayer): `init `_. 
""" - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs, + ): super().__init__() self.num_embeddings = num_embeddings self.embed_dim = embedding_dim @@ -889,7 +912,8 @@ def __init__(self, self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition self.weight = Parameter( - torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype)) + torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype) + ) self.reset_parameters(weight_initializer) self._set_tensor_parallel_attributes() @@ -906,34 +930,38 @@ def reset_parameters(self, weight_initializer) -> None: self._fill_padding_idx_with_zero() def _fill_padding_idx_with_zero(self) -> None: - if self.padding_idx is not None and \ - self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index: + if ( + self.padding_idx is not None + and self.padding_idx >= self.vocab_start_index + and self.padding_idx < self.vocab_end_index + ): with torch.no_grad(): self.weight[self.padding_idx - self.vocab_start_index].fill_(0) def _load_from_global_state_dict(self, state_dict, prefix, *args): local_state = OrderedDict() - weight_key = prefix + 'weight' + weight_key = prefix + "weight" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) if weight is not None: local_state[weight_key] = weight - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={weight_key: 0}, - partition_states={weight_key: True}) + local_state = partition_tensor_parallel_state_dict( + local_state, ParallelMode.PARALLEL_1D, dims={weight_key: 0}, partition_states={weight_key: True} + ) super()._load_from_global_state_dict(local_state, prefix, *args) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' + weight_key = prefix + "weight" local_state = OrderedDict({weight_key: self.weight}) - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={weight_key: 0}, - partition_states={weight_key: True}, - keep_vars=keep_vars) + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: 0}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: @@ -943,11 +971,12 @@ def forward(self, input_: Tensor) -> Tensor: masked_input = input_.clone() - self.vocab_start_index masked_input[input_mask] = 0 - output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, - **self.embed_kwargs) + output_parallel = F.embedding( + masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs + ) # Mask the output embedding. - output_parallel[input_mask, :] = 0. + output_parallel[input_mask, :] = 0.0 # Reduce across all the model parallel GPUs. 
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) return output @@ -1002,30 +1031,34 @@ class PatchEmbedding1D(ColossalaiModule): :type position_embed_initializer: typing.Callable, optional """ - def __init__(self, - img_size: int, - patch_size: int, - in_chans: int, - embed_size: int, - dtype: torch.dtype = None, - flatten: bool = True, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - position_embed_initializer: Callable = init.zeros_()): - embed = VanillaPatchEmbedding(img_size, - patch_size, - in_chans, - embed_size, - dtype=dtype, - flatten=flatten, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer, - position_embed_initializer=position_embed_initializer) + def __init__( + self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + dtype: torch.dtype = None, + flatten: bool = True, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_(), + ): + embed = VanillaPatchEmbedding( + img_size, + patch_size, + in_chans, + embed_size, + dtype=dtype, + flatten=flatten, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + position_embed_initializer=position_embed_initializer, + ) super().__init__(embed) def _load_from_state_dict(self, state_dict, prefix, *args): local_state = OrderedDict() - param_keys = [prefix + 'weight', prefix + 'bias', prefix + 'cls_token', prefix + 'pos_embed'] + param_keys = [prefix + "weight", prefix + "bias", prefix + "cls_token", prefix + "pos_embed"] if gpc.get_local_rank(ParallelMode.TENSOR) == 0: for key in param_keys: param = state_dict.pop(key, None) diff --git a/colossalai/legacy/nn/layer/parallel_2d/__init__.py b/colossalai/legacy/nn/layer/parallel_2d/__init__.py index 9c65f3608710..8d29c66b3a24 100644 --- a/colossalai/legacy/nn/layer/parallel_2d/__init__.py +++ b/colossalai/legacy/nn/layer/parallel_2d/__init__.py @@ -10,6 +10,13 @@ ) __all__ = [ - 'split_batch_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D', - 'Embedding2D', 'VocabParallelEmbedding2D', 'VocabParallelClassifier2D' + "split_batch_2d", + "reduce_by_batch_2d", + "Linear2D", + "LayerNorm2D", + "Classifier2D", + "PatchEmbedding2D", + "Embedding2D", + "VocabParallelEmbedding2D", + "VocabParallelClassifier2D", ] diff --git a/colossalai/legacy/nn/layer/parallel_2d/_operation.py b/colossalai/legacy/nn/layer/parallel_2d/_operation.py index 43e14d4a47a5..f1eff7128e7a 100644 --- a/colossalai/legacy/nn/layer/parallel_2d/_operation.py +++ b/colossalai/legacy/nn/layer/parallel_2d/_operation.py @@ -5,10 +5,9 @@ from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd -from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce, reduce_scatter +from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce_scatter from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc -from colossalai.legacy.global_variables import tensor_parallel_env as env from colossalai.utils import get_current_device @@ -49,17 +48,30 @@ def matmul_2d( col_rank = gpc.get_local_rank(row_parallel_mode) data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) - pipeline_parallel_rank = 0 if not 
gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( - ParallelMode.PIPELINE) - pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( - ParallelMode.PIPELINE) + pipeline_parallel_rank = ( + 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + ) + pipeline_parallel_size = ( + 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(ParallelMode.PIPELINE) + ) tensor_parallel_size = summa_dim**2 - return Matmul_AB_2D(a, b, summa_dim, out_shape, row_rank, col_rank, row_parallel_mode, col_parallel_mode, - data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size) + return Matmul_AB_2D( + a, + b, + summa_dim, + out_shape, + row_rank, + col_rank, + row_parallel_mode, + col_parallel_mode, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) class _Classifier2D(torch.autograd.Function): - @staticmethod @custom_fwd(cast_inputs=torch.float16) def forward( @@ -132,10 +144,21 @@ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: return A_grad, B_grad, bias_grad, None, None, None, None, None, None, None, None, None, None -def classifier_2d(A: Tensor, B: Tensor, bias: Optional[Tensor], summa_dim: int, out_shape: Tuple[int, ...], - row_rank: int, col_rank: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, - data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, - tensor_parallel_size: int) -> Tensor: +def classifier_2d( + A: Tensor, + B: Tensor, + bias: Optional[Tensor], + summa_dim: int, + out_shape: Tuple[int, ...], + row_rank: int, + col_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, +) -> Tensor: r"""2D parallel classifier. Args: @@ -157,9 +180,21 @@ def classifier_2d(A: Tensor, B: Tensor, bias: Optional[Tensor], summa_dim: int, The parallel_mode should be concluded in ``ParallelMode``. 
More details about ``ParallelMode`` could be found in `parallel_mode `_ """ - return _Classifier2D.apply(A, B, bias, summa_dim, out_shape, row_rank, col_rank, row_parallel_mode, - col_parallel_mode, data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, - tensor_parallel_size) + return _Classifier2D.apply( + A, + B, + bias, + summa_dim, + out_shape, + row_rank, + col_rank, + row_parallel_mode, + col_parallel_mode, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) class Matmul_AB_2D(torch.autograd.Function): @@ -205,8 +240,7 @@ def forward( # B: [h / q, s / q] # C: [b / q, s, s / q] -> [(b * s) / q, s / q] - assert A.shape[-1] == B.shape[-2], \ - 'Invalid shapes: A={}, B={} for AB.'.format(A.shape, B.shape) + assert A.shape[-1] == B.shape[-2], "Invalid shapes: A={}, B={} for AB.".format(A.shape, B.shape) if ctx: ctx.save_for_backward(A, B) @@ -226,10 +260,16 @@ def forward( row_group = gpc.get_group(row_parallel_mode) col_group = gpc.get_group(col_parallel_mode) - src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size - src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size + src_a = ( + summa_dim * row_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) + src_b = ( + col_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) opa = [None] * 2 opb = [None] * 2 @@ -278,14 +318,34 @@ def forward( def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: A, B = ctx.saved_tensors with torch.no_grad(): - A_grad = Matmul_ABT_2D.apply(output_grad, B, ctx.summa_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank, - ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank, - ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size, - ctx.tensor_parallel_size) - B_grad = Matmul_ATB_2D.apply(A, output_grad, ctx.summa_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank, - ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank, - ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size, - ctx.tensor_parallel_size) + A_grad = Matmul_ABT_2D.apply( + output_grad, + B, + ctx.summa_dim, + ctx.A_shape, + ctx.row_rank, + ctx.col_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) + B_grad = Matmul_ATB_2D.apply( + A, + output_grad, + ctx.summa_dim, + ctx.B_shape, + ctx.row_rank, + ctx.col_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None @@ -329,9 +389,7 @@ def forward( pipeline_parallel_size: int, tensor_parallel_size: int, ) -> Tensor: - - assert A.shape[-1] == B.shape[-1], \ - 'Invalid shapes: A={}, B={} for ABT.'.format(A.shape, B.shape) + assert A.shape[-1] == B.shape[-1], "Invalid shapes: A={}, B={} for ABT.".format(A.shape, B.shape) if ctx: ctx.save_for_backward(A, B) @@ -351,10 +409,16 @@ def forward( row_group = gpc.get_group(row_parallel_mode) col_group = gpc.get_group(col_parallel_mode) - src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + 
\ - pipeline_parallel_rank * tensor_parallel_size - src_c = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size + src_b = ( + col_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) + src_c = ( + summa_dim * row_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) opb = [None] * 2 opr = [None] * 2 @@ -412,14 +476,34 @@ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: A, B = ctx.saved_tensors with torch.no_grad(): - A_grad = Matmul_AB_2D.apply(output_grad, B, ctx.summa_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank, - ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank, - ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size, - ctx.tensor_parallel_size) - B_grad = Matmul_ATB_2D.apply(output_grad, A, ctx.summa_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank, - ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank, - ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size, - ctx.tensor_parallel_size) + A_grad = Matmul_AB_2D.apply( + output_grad, + B, + ctx.summa_dim, + ctx.A_shape, + ctx.row_rank, + ctx.col_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) + B_grad = Matmul_ATB_2D.apply( + output_grad, + A, + ctx.summa_dim, + ctx.B_shape, + ctx.row_rank, + ctx.col_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None @@ -462,9 +546,7 @@ def forward( pipeline_parallel_size: int, tensor_parallel_size: int, ) -> Tensor: - - assert A.shape[-2] == B.shape[-2], \ - 'Invalid shapes: A={}, B={} for ATB.'.format(A.shape, B.shape) + assert A.shape[-2] == B.shape[-2], "Invalid shapes: A={}, B={} for ATB.".format(A.shape, B.shape) if ctx: ctx.save_for_backward(A, B) @@ -484,10 +566,16 @@ def forward( row_group = gpc.get_group(row_parallel_mode) col_group = gpc.get_group(col_parallel_mode) - src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size - src_c = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size + src_a = ( + summa_dim * row_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) + src_c = ( + col_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) opa = [None] * 2 opr = [None] * 2 @@ -545,19 +633,38 @@ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: A, B = ctx.saved_tensors with torch.no_grad(): - A_grad = Matmul_ABT_2D.apply(B, output_grad, ctx.summa_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank, - ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank, - ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size, - ctx.tensor_parallel_size) - B_grad = Matmul_AB_2D.apply(A, output_grad, ctx.summa_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank, - ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank, - ctx.pipeline_parallel_rank, 
ctx.pipeline_parallel_size, - ctx.tensor_parallel_size) + A_grad = Matmul_ABT_2D.apply( + B, + output_grad, + ctx.summa_dim, + ctx.A_shape, + ctx.row_rank, + ctx.col_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) + B_grad = Matmul_AB_2D.apply( + A, + output_grad, + ctx.summa_dim, + ctx.B_shape, + ctx.row_rank, + ctx.col_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None class _Add_Bias_2D(torch.autograd.Function): - @staticmethod @custom_fwd(cast_inputs=torch.float16) def forward( @@ -608,10 +715,20 @@ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: return output_grad, grad, None, None, None, None, None, None, None, None, None, None -def add_bias_2d(input_: Tensor, bias: Tensor, output_size_per_partition: int, row_rank: int, col_rank: int, - row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, skip_bias_add: bool, - data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, - tensor_parallel_size: int) -> Tensor: +def add_bias_2d( + input_: Tensor, + bias: Tensor, + output_size_per_partition: int, + row_rank: int, + col_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + skip_bias_add: bool, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, +) -> Tensor: r"""Matrix add bias: :math:`C = A + b`. Args: @@ -633,17 +750,34 @@ def add_bias_2d(input_: Tensor, bias: Tensor, output_size_per_partition: int, ro The parallel_mode should be concluded in ``ParallelMode``. 
More details about ``ParallelMode`` could be found in `parallel_mode `_ """ - return _Add_Bias_2D.apply(input_, bias, output_size_per_partition, row_rank, col_rank, row_parallel_mode, - col_parallel_mode, skip_bias_add, data_parallel_rank, pipeline_parallel_rank, - pipeline_parallel_size, tensor_parallel_size) + return _Add_Bias_2D.apply( + input_, + bias, + output_size_per_partition, + row_rank, + col_rank, + row_parallel_mode, + col_parallel_mode, + skip_bias_add, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) class _Layernorm_2D(torch.autograd.Function): - @staticmethod @custom_fwd(cast_inputs=torch.float32) - def forward(ctx: Any, input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode, - col_parallel_mode: ParallelMode) -> Tensor: + def forward( + ctx: Any, + input_: Tensor, + E_x: Tensor, + Var_x: Tensor, + hidden_size: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + ) -> Tensor: input_ = input_ - E_x # in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps) ctx.normalized_shape = hidden_size @@ -657,7 +791,7 @@ def forward(ctx: Any, input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: i @custom_bwd def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: row_parallel_mode = ctx.row_parallel_mode - col_parallel_mode = ctx.col_parallel_mode + ctx.col_parallel_mode x, Var_x = ctx.saved_tensors # in here, Var_x = 1 / sqrt(Var[x] + eps), x = (x - E[x]) * Var_x output_grad_sum = torch.sum(output_grad, dim=-1, keepdim=True) @@ -676,8 +810,14 @@ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: return input_grad, None, None, None, None, None -def layernorm_2d(input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode, - col_parallel_mode: ParallelMode) -> Tensor: +def layernorm_2d( + input_: Tensor, + E_x: Tensor, + Var_x: Tensor, + hidden_size: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, +) -> Tensor: r"""Layernorm. Args: @@ -696,7 +836,6 @@ def layernorm_2d(input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, r class _AllGatherTensor2D(torch.autograd.Function): - @staticmethod @custom_fwd(cast_inputs=torch.float16) def forward(ctx: Any, inputs: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: @@ -744,15 +883,14 @@ def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor: if world_size <= 1: return input_ - assert dim_size % world_size == 0, \ - f'The batch size ({dim_size}) is not a multiple of 2D size ({world_size}).' + assert dim_size % world_size == 0, f"The batch size ({dim_size}) is not a multiple of 2D size ({world_size})." 
- return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2D_COL), - dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)].contiguous() + return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2D_COL), dim=dim)[ + gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + ].contiguous() class _ReduceTensor2D(torch.autograd.Function): - @staticmethod def forward(ctx, input_, parallel_mode): return all_reduce(input_, parallel_mode) @@ -777,7 +915,6 @@ def reduce_tensor_2d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor: class _ReduceScatterTensor2D(torch.autograd.Function): - @staticmethod def forward(ctx, input_, dim, parallel_mode): ctx.dim = dim @@ -803,14 +940,12 @@ def reduce_scatter_tensor_2d(tensor: Tensor, dim: int, parallel_mode: ParallelMo """ dim_size = tensor.size(dim) world_size = gpc.get_world_size(parallel_mode) - assert dim_size % world_size == 0, \ - f'The batch size ({dim_size}) is not a multiple of 2D size ({world_size}).' + assert dim_size % world_size == 0, f"The batch size ({dim_size}) is not a multiple of 2D size ({world_size})." return _ReduceScatterTensor2D.apply(tensor, dim, parallel_mode) class _ReduceByBatch2D(torch.autograd.Function): - @staticmethod def symbolic(graph, input_, reduce_mean: bool = False): output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL) diff --git a/colossalai/legacy/nn/layer/parallel_2d/_utils.py b/colossalai/legacy/nn/layer/parallel_2d/_utils.py index 87ba1bf69691..fe18af26f88f 100644 --- a/colossalai/legacy/nn/layer/parallel_2d/_utils.py +++ b/colossalai/legacy/nn/layer/parallel_2d/_utils.py @@ -6,15 +6,17 @@ def get_summa_dim_from_env() -> int: try: summa_dim = env.summa_dim - assert summa_dim > 0, 'SUMMA_DIM must be larger than zero' + assert summa_dim > 0, "SUMMA_DIM must be larger than zero" return summa_dim - except KeyError as e: - raise EnvironmentError('SUMMA_DIM is not found in the current environment, ' - 'please make sure that you have used the correct process group initializer') + except KeyError: + raise EnvironmentError( + "SUMMA_DIM is not found in the current environment, " + "please make sure that you have used the correct process group initializer" + ) def assert_summa_initialization(): - assert gpc.is_initialized(ParallelMode.PARALLEL_2D_COL) and \ - gpc.is_initialized(ParallelMode.PARALLEL_2D_ROW), \ - 'Both TWO_DIMENSION_COL and TWO_DIMENSION_ROW must be initialized by the process group initializer' + assert gpc.is_initialized(ParallelMode.PARALLEL_2D_COL) and gpc.is_initialized( + ParallelMode.PARALLEL_2D_ROW + ), "Both TWO_DIMENSION_COL and TWO_DIMENSION_ROW must be initialized by the process group initializer" diff --git a/colossalai/legacy/nn/layer/parallel_2d/layers.py b/colossalai/legacy/nn/layer/parallel_2d/layers.py index 893bc74b57d9..3b2e032e5127 100644 --- a/colossalai/legacy/nn/layer/parallel_2d/layers.py +++ b/colossalai/legacy/nn/layer/parallel_2d/layers.py @@ -55,14 +55,16 @@ class Linear2D(ParallelLayer): `init `_. 
""" - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): super().__init__() self.in_features = in_features @@ -80,15 +82,16 @@ def __init__(self, self.hidden_size_per_partition = divide(self.out_features, self.summa_dim) # create weight, shape: [k/q, h/q] - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + factory_kwargs = {"device": get_current_device(), "dtype": dtype} self.weight = Parameter( - torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs)) + torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs) + ) # create bias, shape: [h/q] if bias: self.bias = Parameter(torch.empty(divide(self.out_features, self.summa_dim**2), **factory_kwargs)) else: - self.register_parameter('bias', None) + self.register_parameter("bias", None) # initialize parameters with seed(ParallelMode.TENSOR): @@ -108,8 +111,8 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -126,34 +129,22 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, ) # partition in column groups local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" local_state = OrderedDict({weight_key: self.weight}) if self.bias is not None: local_state[bias_key] = self.bias @@ -162,14 +153,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, keep_vars=keep_vars, ) # gather in row groups @@ -177,14 +162,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, - 
dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -196,22 +175,53 @@ def forward(self, x: Tensor) -> Tensor: # output: [m/q, n/q, h/q] out_shape = x.shape[:-1] + (self.hidden_size_per_partition,) - output = Matmul_AB_2D.apply(x, self.weight, self.summa_dim, out_shape, self.row_rank, self.col_rank, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, self.data_parallel_rank, - self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size) + output = Matmul_AB_2D.apply( + x, + self.weight, + self.summa_dim, + out_shape, + self.row_rank, + self.col_rank, + ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) if self.bias is not None: if self.skip_bias_add: - bias = add_bias_2d(None, self.bias, self.hidden_size_per_partition, self.row_rank, self.col_rank, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) + bias = add_bias_2d( + None, + self.bias, + self.hidden_size_per_partition, + self.row_rank, + self.col_rank, + ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + True, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) return output, bias else: - output = add_bias_2d(output, self.bias, self.hidden_size_per_partition, self.row_rank, self.col_rank, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, False, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) + output = add_bias_2d( + output, + self.bias, + self.hidden_size_per_partition, + self.row_rank, + self.col_rank, + ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + False, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) return output else: return output @@ -249,7 +259,7 @@ def __init__(self, normalized_shape: int, eps: float = 1e-05, bias=True, dtype=N self.partitioned_partition = divide(normalized_shape, self.summa_dim**2) # create parameters - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + factory_kwargs = {"device": get_current_device(), "dtype": dtype} self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs)) if bias: @@ -266,8 +276,8 @@ def _set_tensor_parallel_attributes(self): def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -283,34 +293,22 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, ) # partition in column groups local_state = 
partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" local_state = OrderedDict({weight_key: self.weight}) if self.bias is not None: local_state[bias_key] = self.bias @@ -319,14 +317,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, keep_vars=keep_vars, ) # gather in row groups @@ -334,14 +326,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -349,29 +335,51 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): def forward(self, x: Tensor) -> Tensor: with torch.no_grad(): - E_x = torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1] + E_x = torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1] torch.distributed.all_reduce(E_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) E_x /= self.normalized_shape # Var_x in the block below is the sum of input^2 - Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1] + Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1] torch.distributed.all_reduce(Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) Var_x /= self.normalized_shape - Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1] + Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1] # this time 1/sqrt(Var_x + epsilon) Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon) - output = layernorm_2d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2D_ROW, - ParallelMode.PARALLEL_2D_COL) - scale = add_bias_2d(None, self.weight, self.partitioned_partition, self.row_rank, self.col_rank, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank, - self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size) + output = layernorm_2d( + x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL + ) + scale = add_bias_2d( + None, + self.weight, + self.partitioned_partition, + self.row_rank, + self.col_rank, + ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + True, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) if self.bias is not None: - bias = add_bias_2d(None, self.bias, self.partitioned_partition, self.row_rank, self.col_rank, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - 
self.tensor_parallel_size) + bias = add_bias_2d( + None, + self.bias, + self.partitioned_partition, + self.row_rank, + self.col_rank, + ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + True, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) output = torch.addcmul(bias, scale, output) else: output = torch.mul(scale, output) @@ -400,16 +408,18 @@ class PatchEmbedding2D(ParallelLayer): `init `_. """ - def __init__(self, - img_size: int, - patch_size: int, - in_chans: int, - embed_size: int, - flatten: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - position_embed_initializer: Callable = init.zeros_()): + def __init__( + self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + flatten: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_(), + ): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) @@ -426,17 +436,22 @@ def __init__(self, with seed(ParallelMode.TENSOR): self.weight = Parameter( - torch.empty((self.embed_size_per_partition, in_chans, *self.patch_size), - device=get_current_device(), - dtype=dtype)) + torch.empty( + (self.embed_size_per_partition, in_chans, *self.patch_size), + device=get_current_device(), + dtype=dtype, + ) + ) self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype)) self.cls_token = Parameter( - torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype)) + torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype) + ) self.pos_embed = Parameter( - torch.zeros((1, self.num_patches + 1, self.embed_size_per_partition), - device=get_current_device(), - dtype=dtype)) + torch.zeros( + (1, self.num_patches + 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype + ) + ) self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) self._set_tensor_parallel_attribute() @@ -457,10 +472,10 @@ def reset_parameters(self, weight_initializer, bias_initializer, position_embed_ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - cls_token_key = prefix + 'cls_token' - pos_embed_key = prefix + 'pos_embed' + weight_key = prefix + "weight" + bias_key = prefix + "bias" + cls_token_key = prefix + "cls_token" + pos_embed_key = prefix + "pos_embed" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -484,67 +499,34 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, - dims={ - weight_key: 0, - bias_key: 0, - cls_token_key: -1, - pos_embed_key: -1 - }, - partition_states={ - weight_key: True, - bias_key: True, - cls_token_key: True, - pos_embed_key: True - }, + dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1}, + partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True}, ) # partition in column groups local_state = 
partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, - dims={ - weight_key: 0, - bias_key: 0, - cls_token_key: -1, - pos_embed_key: -1 - }, - partition_states={ - weight_key: True, - bias_key: True, - cls_token_key: True, - pos_embed_key: True - }, + dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1}, + partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True}, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - cls_token_key = prefix + 'cls_token' - pos_embed_key = prefix + 'pos_embed' - local_state = OrderedDict({ - weight_key: self.weight, - bias_key: self.bias, - cls_token_key: self.cls_token, - pos_embed_key: self.pos_embed - }) + weight_key = prefix + "weight" + bias_key = prefix + "bias" + cls_token_key = prefix + "cls_token" + pos_embed_key = prefix + "pos_embed" + local_state = OrderedDict( + {weight_key: self.weight, bias_key: self.bias, cls_token_key: self.cls_token, pos_embed_key: self.pos_embed} + ) # gather in column groups local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, - dims={ - weight_key: 0, - bias_key: 0, - cls_token_key: -1, - pos_embed_key: -1 - }, - partition_states={ - weight_key: True, - bias_key: True, - cls_token_key: True, - pos_embed_key: True - }, + dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1}, + partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True}, keep_vars=keep_vars, ) # gather in row groups @@ -552,18 +534,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, - dims={ - weight_key: 0, - bias_key: 0, - cls_token_key: -1, - pos_embed_key: -1 - }, - partition_states={ - weight_key: True, - bias_key: True, - cls_token_key: True, - pos_embed_key: True - }, + dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1}, + partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True}, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -573,15 +545,16 @@ def forward(self, input_: Tensor) -> Tensor: input_ = split_batch_2d(input_) B, C, H, W = input_.shape - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
weight = all_gather_tensor_2d(self.weight, 0, ParallelMode.PARALLEL_2D_COL) bias = all_gather_tensor_2d(self.bias, 0, ParallelMode.PARALLEL_2D_COL) output = F.conv2d(input_, weight, bias, stride=self.patch_size) if self.flatten: - output = output.flatten(2).transpose(1, 2) # BCHW -> BNC + output = output.flatten(2).transpose(1, 2) # BCHW -> BNC cls_token = all_gather_tensor_2d(self.cls_token, -1, ParallelMode.PARALLEL_2D_COL) pos_embed = all_gather_tensor_2d(self.pos_embed, -1, ParallelMode.PARALLEL_2D_COL) @@ -623,14 +596,16 @@ class Embedding2D(ParallelLayer): `init `_ """ - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs, + ): super().__init__() assert_summa_initialization() @@ -644,7 +619,8 @@ def __init__(self, self.embed_kwargs = kwargs self.weight = Parameter( - torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)) + torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype) + ) self.reset_parameters(weight_initializer) self._set_tensor_parallel_attributes() @@ -665,7 +641,7 @@ def _fill_padding_idx_with_zero(self) -> None: def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' + weight_key = prefix + "weight" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -691,7 +667,7 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' + weight_key = prefix + "weight" local_state = OrderedDict({weight_key: self.weight}) # gather in column groups @@ -754,14 +730,16 @@ class VocabParallelEmbedding2D(ParallelLayer): `init `_. 
""" - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs, + ): super().__init__() self.num_embeddings = num_embeddings self.embed_dim = embedding_dim @@ -778,9 +756,12 @@ def __init__(self, self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition self.weight = Parameter( - torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), - device=get_current_device(), - dtype=dtype)) + torch.empty( + (self.num_embeddings_per_partition, self.embed_dim_per_partition), + device=get_current_device(), + dtype=dtype, + ) + ) self.reset_parameters(weight_initializer) self._set_tensor_parallel_attributes() @@ -796,14 +777,17 @@ def reset_parameters(self, weight_initializer) -> None: self._fill_padding_idx_with_zero() def _fill_padding_idx_with_zero(self) -> None: - if self.padding_idx is not None and \ - self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index: + if ( + self.padding_idx is not None + and self.padding_idx >= self.vocab_start_index + and self.padding_idx < self.vocab_end_index + ): with torch.no_grad(): self.weight[self.padding_idx - self.vocab_start_index].fill_(0) def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' + weight_key = prefix + "weight" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -829,7 +813,7 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' + weight_key = prefix + "weight" local_state = OrderedDict({weight_key: self.weight}) # gather in column groups @@ -857,10 +841,11 @@ def forward(self, input_: Tensor) -> Tensor: masked_input = input_.clone() - self.vocab_start_index masked_input[input_mask] = 0 - output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, - **self.embed_kwargs) + output_parallel = F.embedding( + masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs + ) - output_parallel[input_mask, :] = 0. + output_parallel[input_mask, :] = 0.0 output = reduce_scatter_tensor_2d(output_parallel, 0, ParallelMode.PARALLEL_2D_COL) return output @@ -884,14 +869,16 @@ class Classifier2D(ParallelLayer): `init `_. 
""" - def __init__(self, - in_features: int, - num_classes: int, - weight: Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): super().__init__() self.in_features = in_features self.num_classes = num_classes @@ -908,7 +895,8 @@ def __init__(self, self.has_weight = False else: self.weight = Parameter( - torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype)) + torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype) + ) self.has_weight = True if bias: self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) @@ -938,8 +926,8 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight if self.has_weight: @@ -957,34 +945,22 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, ) # partition in column groups local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" local_state = OrderedDict() if self.has_weight: local_state[weight_key] = self.weight @@ -995,14 +971,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, keep_vars=keep_vars, ) # gather in row groups @@ -1010,14 +980,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -1026,9 +990,21 @@ def _save_to_global_state_dict(self, destination, prefix, 
keep_vars): def forward(self, input_: Tensor) -> Tensor: out_shape = input_.shape[:-1] + (self.num_classes,) - return classifier_2d(input_, self.weight, self.bias, self.summa_dim, out_shape, self.row_rank, self.col_rank, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, self.data_parallel_rank, - self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size) + return classifier_2d( + input_, + self.weight, + self.bias, + self.summa_dim, + out_shape, + self.row_rank, + self.col_rank, + ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) @LAYERS.register_module @@ -1050,14 +1026,16 @@ class VocabParallelClassifier2D(ParallelLayer): `init `_. """ - def __init__(self, - in_features: int, - num_classes: int, - weight: Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): super().__init__() self.in_features = in_features @@ -1074,13 +1052,14 @@ def __init__(self, self.output_size_per_partition = divide(num_classes, self.summa_dim) # create weight, shape: [k/q, h/q] - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + factory_kwargs = {"device": get_current_device(), "dtype": dtype} if weight is not None: self.weight = weight self.has_weight = False else: self.weight = Parameter( - torch.empty(self.output_size_per_partition, self.input_size_per_partition, **factory_kwargs)) + torch.empty(self.output_size_per_partition, self.input_size_per_partition, **factory_kwargs) + ) self.has_weight = True # create bias, shape: [h/q] if bias: @@ -1109,8 +1088,8 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight if self.has_weight: @@ -1128,34 +1107,22 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, ) # partition in column groups local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" local_state = OrderedDict() if self.has_weight: local_state[weight_key] = self.weight 
@@ -1166,14 +1133,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, keep_vars=keep_vars, ) # gather in row groups @@ -1181,14 +1142,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -1200,14 +1155,34 @@ def forward(self, x: Tensor) -> Tensor: # output: [m/q, n/q, h/q] out_shape = x.shape[:-1] + (self.output_size_per_partition,) - output = Matmul_ABT_2D.apply(x, self.weight, self.summa_dim, out_shape, self.row_rank, self.col_rank, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) + output = Matmul_ABT_2D.apply( + x, + self.weight, + self.summa_dim, + out_shape, + self.row_rank, + self.col_rank, + ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) if self.bias is not None: - output = add_bias_2d(output, self.bias, self.output_size_per_partition, self.row_rank, self.col_rank, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, False, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) + output = add_bias_2d( + output, + self.bias, + self.output_size_per_partition, + self.row_rank, + self.col_rank, + ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + False, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) return output diff --git a/colossalai/legacy/nn/layer/parallel_2p5d/__init__.py b/colossalai/legacy/nn/layer/parallel_2p5d/__init__.py index 23e47e6ed06b..46b4d3f3b782 100644 --- a/colossalai/legacy/nn/layer/parallel_2p5d/__init__.py +++ b/colossalai/legacy/nn/layer/parallel_2p5d/__init__.py @@ -10,6 +10,13 @@ ) __all__ = [ - 'split_batch_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D', - 'Embedding2p5D', 'VocabParallelClassifier2p5D', 'VocabParallelEmbedding2p5D' + "split_batch_2p5d", + "reduce_by_batch_2p5d", + "Linear2p5D", + "LayerNorm2p5D", + "Classifier2p5D", + "PatchEmbedding2p5D", + "Embedding2p5D", + "VocabParallelClassifier2p5D", + "VocabParallelEmbedding2p5D", ] diff --git a/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py b/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py index 1226162ae399..50900c135cab 100644 --- a/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py +++ b/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py @@ -24,7 +24,6 @@ def get_parallel_rank(parallel_mode: ParallelMode): class _Classifier2p5D(torch.autograd.Function): - @staticmethod @custom_fwd(cast_inputs=torch.float16) def forward( @@ -98,10 +97,21 @@ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: return 
A_grad, B_grad, bias_grad, None, None, None, None, None, None, None, None, None, None -def classifier_2p5d(A: Tensor, B: Tensor, bias, tesseract_dim: int, out_shape: Tuple[int, - ...], row_rank: int, col_rank: int, - row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, data_parallel_rank: int, - pipeline_parallel_rank: int, pipeline_parallel_size: int, tensor_parallel_size: int) -> Tensor: +def classifier_2p5d( + A: Tensor, + B: Tensor, + bias, + tesseract_dim: int, + out_shape: Tuple[int, ...], + row_rank: int, + col_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, +) -> Tensor: r"""Classifier. Args: @@ -123,9 +133,21 @@ def classifier_2p5d(A: Tensor, B: Tensor, bias, tesseract_dim: int, out_shape: T The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found in `parallel_mode `_ """ - return _Classifier2p5D.apply(A, B, bias, tesseract_dim, out_shape, row_rank, col_rank, row_parallel_mode, - col_parallel_mode, data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, - tensor_parallel_size) + return _Classifier2p5D.apply( + A, + B, + bias, + tesseract_dim, + out_shape, + row_rank, + col_rank, + row_parallel_mode, + col_parallel_mode, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) class Matmul_AB_2p5D(torch.autograd.Function): @@ -153,16 +175,27 @@ class Matmul_AB_2p5D(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float16) - def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int, - col_rank: int, dep_rank: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, - data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, - tensor_parallel_size: int) -> Tensor: + def forward( + ctx: Any, + A: Tensor, + B: Tensor, + tesseract_dim: int, + out_shape: Tuple[int, ...], + row_rank: int, + col_rank: int, + dep_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ) -> Tensor: # A: [b / dq, s, h / q] -> [(b * s) / dq, h / q] # B: [h / dq, s / q] # C: [b / dq, s, s / q] -> [(b * s) / dq, s / q] - assert A.shape[-1] == B.shape[-2], \ - 'Invalid shapes: A={}, B={} for AB.'.format(A.shape, B.shape) + assert A.shape[-1] == B.shape[-2], "Invalid shapes: A={}, B={} for AB.".format(A.shape, B.shape) if ctx: ctx.save_for_backward(A, B) @@ -182,14 +215,18 @@ def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple row_group = gpc.get_group(row_parallel_mode) col_group = gpc.get_group(col_parallel_mode) - src_a = \ - tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + \ - data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size - src_b = \ - col_rank + tesseract_dim ** 2 * dep_rank + \ - data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size + src_a = ( + tesseract_dim * row_rank + + tesseract_dim**2 * dep_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) + src_b = ( + col_rank + + tesseract_dim**2 * dep_rank + + data_parallel_rank * 
pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) opa = [None] * 2 opb = [None] * 2 @@ -205,10 +242,9 @@ def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple A_list[1 - cur].copy_(A) opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True) B_list[1 - cur].copy_(B) - opb[1 - cur] = dist.broadcast(B_list[1 - cur], - src=src_b + tesseract_dim, - group=col_group, - async_op=True) + opb[1 - cur] = dist.broadcast( + B_list[1 - cur], src=src_b + tesseract_dim, group=col_group, async_op=True + ) if opa[cur] is not None: opa[cur].wait() @@ -242,14 +278,36 @@ def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: A, B = ctx.saved_tensors with torch.no_grad(): - A_grad = Matmul_ABT_2p5D.apply(output_grad, B, ctx.tesseract_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank, - ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode, - ctx.data_parallel_rank, ctx.pipeline_parallel_rank, - ctx.pipeline_parallel_size, ctx.tensor_parallel_size) - B_grad = Matmul_ATB_2p5D.apply(A, output_grad, ctx.tesseract_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank, - ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode, - ctx.data_parallel_rank, ctx.pipeline_parallel_rank, - ctx.pipeline_parallel_size, ctx.tensor_parallel_size) + A_grad = Matmul_ABT_2p5D.apply( + output_grad, + B, + ctx.tesseract_dim, + ctx.A_shape, + ctx.row_rank, + ctx.col_rank, + ctx.dep_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) + B_grad = Matmul_ATB_2p5D.apply( + A, + output_grad, + ctx.tesseract_dim, + ctx.B_shape, + ctx.row_rank, + ctx.col_rank, + ctx.dep_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None @@ -278,13 +336,23 @@ class Matmul_ABT_2p5D(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float16) - def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int, - col_rank: int, dep_rank: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, - data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, - tensor_parallel_size: int) -> Tensor: - - assert A.shape[-1] == B.shape[-1], \ - 'Invalid shapes: A={}, B={} for ABT.'.format(A.shape, B.shape) + def forward( + ctx: Any, + A: Tensor, + B: Tensor, + tesseract_dim: int, + out_shape: Tuple[int, ...], + row_rank: int, + col_rank: int, + dep_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ) -> Tensor: + assert A.shape[-1] == B.shape[-1], "Invalid shapes: A={}, B={} for ABT.".format(A.shape, B.shape) if ctx: ctx.save_for_backward(A, B) @@ -304,14 +372,18 @@ def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple row_group = gpc.get_group(row_parallel_mode) col_group = gpc.get_group(col_parallel_mode) - src_b = \ - col_rank + tesseract_dim ** 2 * dep_rank + \ - data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * 
tensor_parallel_size - src_c = \ - tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + \ - data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size + src_b = ( + col_rank + + tesseract_dim**2 * dep_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) + src_c = ( + tesseract_dim * row_rank + + tesseract_dim**2 * dep_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) opb = [None] * 2 opr = [None] * 2 @@ -323,10 +395,9 @@ def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple for i in range(tesseract_dim): if i != tesseract_dim - 1: B_list[1 - cur].copy_(B) - opb[1 - cur] = dist.broadcast(B_list[1 - cur], - src=src_b + tesseract_dim, - group=col_group, - async_op=True) + opb[1 - cur] = dist.broadcast( + B_list[1 - cur], src=src_b + tesseract_dim, group=col_group, async_op=True + ) if opr[cur] is not None: opr[cur].wait() @@ -372,14 +443,36 @@ def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: A, B = ctx.saved_tensors with torch.no_grad(): - A_grad = Matmul_AB_2p5D.apply(output_grad, B, ctx.tesseract_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank, - ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode, - ctx.data_parallel_rank, ctx.pipeline_parallel_rank, - ctx.pipeline_parallel_size, ctx.tensor_parallel_size) - B_grad = Matmul_ATB_2p5D.apply(output_grad, A, ctx.tesseract_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank, - ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode, - ctx.data_parallel_rank, ctx.pipeline_parallel_rank, - ctx.pipeline_parallel_size, ctx.tensor_parallel_size) + A_grad = Matmul_AB_2p5D.apply( + output_grad, + B, + ctx.tesseract_dim, + ctx.A_shape, + ctx.row_rank, + ctx.col_rank, + ctx.dep_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) + B_grad = Matmul_ATB_2p5D.apply( + output_grad, + A, + ctx.tesseract_dim, + ctx.B_shape, + ctx.row_rank, + ctx.col_rank, + ctx.dep_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None @@ -408,13 +501,23 @@ class Matmul_ATB_2p5D(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float16) - def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int, - col_rank: int, dep_rank: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, - data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, - tensor_parallel_size: int): - - assert A.shape[-2] == B.shape[-2], \ - 'Invalid shapes: A={}, B={} for ATB.'.format(A.shape, B.shape) + def forward( + ctx: Any, + A: Tensor, + B: Tensor, + tesseract_dim: int, + out_shape: Tuple[int, ...], + row_rank: int, + col_rank: int, + dep_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ): + assert A.shape[-2] == B.shape[-2], "Invalid shapes: A={}, B={} for 
ATB.".format(A.shape, B.shape) if ctx: ctx.save_for_backward(A, B) @@ -434,14 +537,18 @@ def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple row_group = gpc.get_group(row_parallel_mode) col_group = gpc.get_group(col_parallel_mode) - src_a = \ - tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + \ - data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size - src_c = \ - col_rank + tesseract_dim ** 2 * dep_rank + \ - data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size + src_a = ( + tesseract_dim * row_rank + + tesseract_dim**2 * dep_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) + src_c = ( + col_rank + + tesseract_dim**2 * dep_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) opa = [None] * 2 opr = [None] * 2 @@ -499,33 +606,68 @@ def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: A, B = ctx.saved_tensors with torch.no_grad(): - A_grad = Matmul_ABT_2p5D.apply(B, output_grad, ctx.tesseract_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank, - ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode, - ctx.data_parallel_rank, ctx.pipeline_parallel_rank, - ctx.pipeline_parallel_size, ctx.tensor_parallel_size) - B_grad = Matmul_AB_2p5D.apply(A, output_grad, ctx.tesseract_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank, - ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode, - ctx.data_parallel_rank, ctx.pipeline_parallel_rank, - ctx.pipeline_parallel_size, ctx.tensor_parallel_size) + A_grad = Matmul_ABT_2p5D.apply( + B, + output_grad, + ctx.tesseract_dim, + ctx.A_shape, + ctx.row_rank, + ctx.col_rank, + ctx.dep_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) + B_grad = Matmul_AB_2p5D.apply( + A, + output_grad, + ctx.tesseract_dim, + ctx.B_shape, + ctx.row_rank, + ctx.col_rank, + ctx.dep_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None class _Add_Bias_2p5D(torch.autograd.Function): - @staticmethod @custom_fwd(cast_inputs=torch.float16) - def forward(ctx: Any, input: Tensor, bias: Tensor, output_size_per_partition: int, tesseract_dim: int, - row_rank: int, col_rank: int, dep_rank: int, col_parallel_mode: ParallelMode, skip_bias_add: bool, - data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, - tensor_parallel_size: int) -> Tensor: + def forward( + ctx: Any, + input: Tensor, + bias: Tensor, + output_size_per_partition: int, + tesseract_dim: int, + row_rank: int, + col_rank: int, + dep_rank: int, + col_parallel_mode: ParallelMode, + skip_bias_add: bool, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ) -> Tensor: if row_rank == 0: bias_temp = bias.clone() else: bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device()) - src_rank = \ - col_rank + dep_rank * tesseract_dim ** 2 + \ - 
data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size + src_rank = ( + col_rank + + dep_rank * tesseract_dim**2 + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) dist.broadcast(bias_temp, src=src_rank, group=get_parallel_group(col_parallel_mode)) ctx.row_rank = row_rank @@ -559,43 +701,120 @@ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: tensor_parallel_size = ctx.tensor_parallel_size if ctx.bias: - dst_rank = \ - col_rank + dep_rank * (tesseract_dim ** 2) + \ - data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size + dst_rank = ( + col_rank + + dep_rank * (tesseract_dim**2) + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) dist.reduce(output_grad, dst=dst_rank, group=get_parallel_group(col_parallel_mode)) if row_rank == 0: - return \ - None, output_grad, None, None, None, None, None, None, \ - None, None, None, None, None, None, None, None + return ( + None, + output_grad, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) else: grad_tmp = torch.zeros_like(output_grad) - return \ - None, grad_tmp, None, None, None, None, None, None, \ - None, None, None, None, None, None, None, None + return ( + None, + grad_tmp, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) else: reduce_dim = tuple(range(output_grad.ndim - 1)) reduce = torch.sum(output_grad, dim=reduce_dim) - dst_rank = \ - col_rank + dep_rank * (tesseract_dim ** 2) + \ - data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size + dst_rank = ( + col_rank + + dep_rank * (tesseract_dim**2) + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) dist.reduce(reduce, dst=dst_rank, group=get_parallel_group(col_parallel_mode)) if row_rank == 0: - return \ - output_grad, reduce, None, None, None, None, None, None, None, \ - None, None, None, None, None, None, None, None + return ( + output_grad, + reduce, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) else: reduce_tmp = torch.zeros_like(reduce) - return \ - output_grad, reduce_tmp, None, None, None, None, None, None, \ - None, None, None, None, None, None, None, None, None - - -def add_bias_2p5d(input: Tensor, bias: Tensor, output_size_per_partition: int, tesseract_dim: int, row_rank: int, - col_rank: int, dep_rank: int, col_parallel_mode: ParallelMode, skip_bias_add: bool, - data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, - tensor_parallel_size: int) -> Tensor: + return ( + output_grad, + reduce_tmp, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +def add_bias_2p5d( + input: Tensor, + bias: Tensor, + output_size_per_partition: int, + tesseract_dim: int, + row_rank: int, + col_rank: int, + dep_rank: int, + col_parallel_mode: ParallelMode, + skip_bias_add: bool, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, +) -> Tensor: r"""Matrix add bias: :math:`C = A + b`. 
Args: @@ -618,9 +837,21 @@ def add_bias_2p5d(input: Tensor, bias: Tensor, output_size_per_partition: int, t The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found in `parallel_mode `_ """ - return _Add_Bias_2p5D.apply(input, bias, output_size_per_partition, tesseract_dim, row_rank, col_rank, dep_rank, - col_parallel_mode, skip_bias_add, data_parallel_rank, pipeline_parallel_rank, - pipeline_parallel_size, tensor_parallel_size) + return _Add_Bias_2p5D.apply( + input, + bias, + output_size_per_partition, + tesseract_dim, + row_rank, + col_rank, + dep_rank, + col_parallel_mode, + skip_bias_add, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) class _Layernorm2p5D(torch.autograd.Function): @@ -640,8 +871,9 @@ class _Layernorm2p5D(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) - def forward(ctx: Any, input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, - row_parallel_mode: ParallelMode) -> Tensor: + def forward( + ctx: Any, input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode + ) -> Tensor: input = input - E_x # in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps) ctx.hidden_size = hidden_size @@ -673,8 +905,9 @@ def backward(ctx, output_grad): return input_grad, None, None, None, None, None, None -def layernorm_2p5d(input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, - row_parallel_mode: ParallelMode) -> Tensor: +def layernorm_2p5d( + input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode +) -> Tensor: r"""Layernorm. Args: @@ -692,7 +925,6 @@ def layernorm_2p5d(input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, class _AllGatherTensor2p5D(torch.autograd.Function): - @staticmethod @custom_fwd(cast_inputs=torch.float16) def forward(ctx: Any, inputs: Tensor, dim: int, col_parallel_mode: ParallelMode) -> Tensor: @@ -753,9 +985,9 @@ def forward(ctx: Any, inputs: Tensor, tesseract_dim: int, col_parallel_mode: Par def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: grad_shape = (ctx.batch_size,) + output_grad.shape[1:] grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_current_device()) - dist.all_gather(list(grad.chunk(ctx.tesseract_dim, dim=0)), - output_grad.contiguous(), - group=gpc.get_group(ctx.para_mode)) + dist.all_gather( + list(grad.chunk(ctx.tesseract_dim, dim=0)), output_grad.contiguous(), group=gpc.get_group(ctx.para_mode) + ) return grad, None, None @@ -775,15 +1007,16 @@ def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor: if world_size <= 1: return input_ - assert dim_size % world_size == 0, \ - f'The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size}).' + assert ( + dim_size % world_size == 0 + ), f"The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size})." 
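
As an illustrative aside (not part of the patch), the batch split guarded by this assertion boils down to chunking along the batch dimension and keeping the chunk owned by the local column rank. The sketch below uses hypothetical `world_size` and `rank` values in place of the `PARALLEL_2P5D_COL` group query:

```python
import torch

# Hypothetical stand-ins for gpc.get_world_size / gpc.get_local_rank on the
# PARALLEL_2P5D_COL group; in the real layer these come from the process group.
world_size, rank = 4, 1

x = torch.randn(8, 16, 32)          # [batch, seq, hidden]
assert x.size(0) % world_size == 0  # same divisibility check as above

local_x = torch.chunk(x, world_size, dim=0)[rank].contiguous()
print(local_x.shape)                # torch.Size([2, 16, 32])
```
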
- return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL), - dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous() + return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL), dim=dim)[ + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + ].contiguous() class _ReduceTensor2p5D(torch.autograd.Function): - @staticmethod def forward(ctx, input_, parallel_mode): return all_reduce(input_, parallel_mode) @@ -808,7 +1041,6 @@ def reduce_tensor_2p5d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor: class _ReduceScatterTensor2p5D(torch.autograd.Function): - @staticmethod def forward(ctx, input_, dim, parallel_mode): ctx.dim = dim @@ -834,14 +1066,14 @@ def reduce_scatter_tensor_2p5d(input_: Tensor, dim: int, parallel_mode: Parallel """ dim_size = input_.size(dim) world_size = gpc.get_world_size(parallel_mode) - assert dim_size % world_size == 0, \ - f'The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size}).' + assert ( + dim_size % world_size == 0 + ), f"The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size})." return _ReduceScatterTensor2p5D.apply(input_, dim, parallel_mode) class _RreduceByBatch2p5D(torch.autograd.Function): - @staticmethod def symbolic(graph, input_, reduce_mean: bool = False): output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL) diff --git a/colossalai/legacy/nn/layer/parallel_2p5d/_utils.py b/colossalai/legacy/nn/layer/parallel_2p5d/_utils.py index 69a350a977ac..8cda15aed2a7 100644 --- a/colossalai/legacy/nn/layer/parallel_2p5d/_utils.py +++ b/colossalai/legacy/nn/layer/parallel_2p5d/_utils.py @@ -7,19 +7,24 @@ def get_tesseract_dim_dep_from_env(): try: tesseract_dim = env.tesseract_dim tesseract_dep = env.tesseract_dep - assert tesseract_dim > 0, 'TESSERACT_DIM must be larger than zero' - assert tesseract_dep > 0, 'TESSERACT_DEP must be larger than zero' + assert tesseract_dim > 0, "TESSERACT_DIM must be larger than zero" + assert tesseract_dep > 0, "TESSERACT_DEP must be larger than zero" return tesseract_dim, tesseract_dep - except KeyError as e: - raise EnvironmentError('TESSERACT_DIM or TESSERACT_DEP is not found in the current environment, ' - 'please make sure that you have used the correct process group initializer') + except KeyError: + raise EnvironmentError( + "TESSERACT_DIM or TESSERACT_DEP is not found in the current environment, " + "please make sure that you have used the correct process group initializer" + ) def assert_tesseract_initialization(): - assert gpc.is_initialized(ParallelMode.PARALLEL_2P5D_COL) and \ - gpc.is_initialized(ParallelMode.PARALLEL_2P5D_ROW) and \ - gpc.is_initialized(ParallelMode.PARALLEL_2P5D_DEP) and \ - gpc.is_initialized(ParallelMode.PARALLEL_2P5D_XZ), \ - 'Both PARALLEL_2P5D_COL, PARALLEL_2P5D_ROW, PARALLEL_2P5D_DEP and PARALLEL_2P5D_XZ ' \ - 'must be initialized by the process group initializer' + assert ( + gpc.is_initialized(ParallelMode.PARALLEL_2P5D_COL) + and gpc.is_initialized(ParallelMode.PARALLEL_2P5D_ROW) + and gpc.is_initialized(ParallelMode.PARALLEL_2P5D_DEP) + and gpc.is_initialized(ParallelMode.PARALLEL_2P5D_XZ) + ), ( + "Both PARALLEL_2P5D_COL, PARALLEL_2P5D_ROW, PARALLEL_2P5D_DEP and PARALLEL_2P5D_XZ " + "must be initialized by the process group initializer" + ) diff --git a/colossalai/legacy/nn/layer/parallel_2p5d/layers.py b/colossalai/legacy/nn/layer/parallel_2p5d/layers.py index b4aa9f16ddf0..fc2e35f36cbc 100644 --- a/colossalai/legacy/nn/layer/parallel_2p5d/layers.py +++ 
b/colossalai/legacy/nn/layer/parallel_2p5d/layers.py @@ -56,14 +56,16 @@ class Linear2p5D(ParallelLayer): `init `_. """ - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): super().__init__() self.in_features = in_features @@ -82,15 +84,16 @@ def __init__(self, self.hidden_size_per_partition = divide(out_features, self.tesseract_dim) # create weight, shape: [k/q, h/q] - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + factory_kwargs = {"device": get_current_device(), "dtype": dtype} self.weight = Parameter( - torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs)) + torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs) + ) # create bias, shape: [h/q] if bias: self.bias = Parameter(torch.empty(self.hidden_size_per_partition, **factory_kwargs)) else: - self.register_parameter('bias', None) + self.register_parameter("bias", None) # initialize parameters with seed(ParallelMode.TENSOR): @@ -110,8 +113,8 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -124,43 +127,33 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state[bias_key] = bias # broadcast in dep groups - if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0 and \ - gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) == 0: + if ( + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0 + and gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) == 0 + ): broadcast_state_dict(local_state, ParallelMode.PARALLEL_2P5D_DEP) # partition in column groups if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) == 0: local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, ) # partition in row groups local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_ROW, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) == 0: - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" local_state = OrderedDict({weight_key: self.weight}) if self.bias is not None: 
local_state[bias_key] = self.bias @@ -169,14 +162,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_ROW, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, keep_vars=keep_vars, ) # gather in column groups @@ -184,14 +171,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -221,16 +202,38 @@ def forward(self, x: Tensor) -> Tensor: if self.bias is not None: if self.skip_bias_add: - bias = add_bias_2p5d(None, self.bias, self.hidden_size_per_partition, self.tesseract_dim, self.row_rank, - self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) + bias = add_bias_2p5d( + None, + self.bias, + self.hidden_size_per_partition, + self.tesseract_dim, + self.row_rank, + self.col_rank, + self.dep_rank, + ParallelMode.PARALLEL_2P5D_COL, + True, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) return output, bias else: - output = add_bias_2p5d(output, self.bias, self.hidden_size_per_partition, self.tesseract_dim, - self.row_rank, self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, - False, self.data_parallel_rank, self.pipeline_parallel_rank, - self.pipeline_parallel_size, self.tensor_parallel_size) + output = add_bias_2p5d( + output, + self.bias, + self.hidden_size_per_partition, + self.tesseract_dim, + self.row_rank, + self.col_rank, + self.dep_rank, + ParallelMode.PARALLEL_2P5D_COL, + False, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) return output else: return output @@ -266,10 +269,10 @@ def __init__(self, normalized_shape: int, eps: float = 1e-05, bias=True, dtype=N self.tesseract_dim, _ = get_tesseract_dim_dep_from_env() # partitioning dimension - self.partitioned_partition = divide(normalized_shape, self.tesseract_dim) # * + self.partitioned_partition = divide(normalized_shape, self.tesseract_dim) # * # create parameters - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + factory_kwargs = {"device": get_current_device(), "dtype": dtype} self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs)) if bias: @@ -286,8 +289,8 @@ def _set_tensor_parallel_attribute(self): def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -303,34 +306,22 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_ROW, - dims={ - weight_key: 0, - bias_key: 0 - }, - 
partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, ) # partition in column groups local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" local_state = OrderedDict({weight_key: self.weight}) if self.bias is not None: local_state[bias_key] = self.bias @@ -339,14 +330,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, keep_vars=keep_vars, ) # gather in row groups @@ -354,14 +339,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_ROW, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -369,29 +348,51 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): def forward(self, x: Tensor) -> Tensor: with torch.no_grad(): - E_x = torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1] + E_x = torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1] torch.distributed.all_reduce(E_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW)) E_x /= self.normalized_shape # Var_x in the block below is the sum of input^2 - Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1] + Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1] torch.distributed.all_reduce(Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW)) Var_x /= self.normalized_shape - Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1] + Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1] # this time 1/sqrt(Var_x + epsilon) Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon) output = layernorm_2p5d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2P5D_ROW) - scale = add_bias_2p5d(None, self.weight, self.partitioned_partition, self.tesseract_dim, self.row_rank, - self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) + scale = add_bias_2p5d( + None, + self.weight, + self.partitioned_partition, + self.tesseract_dim, + self.row_rank, + self.col_rank, + self.dep_rank, + ParallelMode.PARALLEL_2P5D_COL, + True, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) if self.bias is not None: - bias = add_bias_2p5d(None, self.bias, self.partitioned_partition, self.tesseract_dim, self.row_rank, - self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True, - 
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) + bias = add_bias_2p5d( + None, + self.bias, + self.partitioned_partition, + self.tesseract_dim, + self.row_rank, + self.col_rank, + self.dep_rank, + ParallelMode.PARALLEL_2P5D_COL, + True, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) output = torch.addcmul(bias, scale, output) else: output = torch.mul(scale, output) @@ -420,16 +421,18 @@ class PatchEmbedding2p5D(ParallelLayer): `init `_. """ - def __init__(self, - img_size: int, - patch_size: int, - in_chans: int, - embed_size: int, - flatten: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - position_embed_initializer: Callable = init.zeros_()): + def __init__( + self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + flatten: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_(), + ): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) @@ -446,17 +449,22 @@ def __init__(self, with seed(ParallelMode.TENSOR): self.weight = Parameter( - torch.empty((self.embed_size_per_partition, in_chans, *self.patch_size), - device=get_current_device(), - dtype=dtype)) + torch.empty( + (self.embed_size_per_partition, in_chans, *self.patch_size), + device=get_current_device(), + dtype=dtype, + ) + ) self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype)) self.cls_token = Parameter( - torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype)) + torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype) + ) self.pos_embed = Parameter( - torch.zeros((1, self.num_patches + 1, self.embed_size_per_partition), - device=get_current_device(), - dtype=dtype)) + torch.zeros( + (1, self.num_patches + 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype + ) + ) self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) self._set_tensor_parallel_attribute() @@ -477,10 +485,10 @@ def reset_parameters(self, weight_initializer, bias_initializer, position_embed_ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - cls_token_key = prefix + 'cls_token' - pos_embed_key = prefix + 'pos_embed' + weight_key = prefix + "weight" + bias_key = prefix + "bias" + cls_token_key = prefix + "cls_token" + pos_embed_key = prefix + "pos_embed" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -504,67 +512,34 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_ROW, - dims={ - weight_key: 0, - bias_key: 0, - cls_token_key: -1, - pos_embed_key: -1 - }, - partition_states={ - weight_key: True, - bias_key: True, - cls_token_key: True, - pos_embed_key: True - }, + dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1}, + partition_states={weight_key: True, bias_key: True, 
cls_token_key: True, pos_embed_key: True}, ) # partition in column groups local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_COL, - dims={ - weight_key: 0, - bias_key: 0, - cls_token_key: -1, - pos_embed_key: -1 - }, - partition_states={ - weight_key: True, - bias_key: True, - cls_token_key: True, - pos_embed_key: True - }, + dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1}, + partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True}, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - cls_token_key = prefix + 'cls_token' - pos_embed_key = prefix + 'pos_embed' - local_state = OrderedDict({ - weight_key: self.weight, - bias_key: self.bias, - cls_token_key: self.cls_token, - pos_embed_key: self.pos_embed - }) + weight_key = prefix + "weight" + bias_key = prefix + "bias" + cls_token_key = prefix + "cls_token" + pos_embed_key = prefix + "pos_embed" + local_state = OrderedDict( + {weight_key: self.weight, bias_key: self.bias, cls_token_key: self.cls_token, pos_embed_key: self.pos_embed} + ) # gather in column groups local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_COL, - dims={ - weight_key: 0, - bias_key: 0, - cls_token_key: -1, - pos_embed_key: -1 - }, - partition_states={ - weight_key: True, - bias_key: True, - cls_token_key: True, - pos_embed_key: True - }, + dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1}, + partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True}, keep_vars=keep_vars, ) # gather in row groups @@ -572,18 +547,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_ROW, - dims={ - weight_key: 0, - bias_key: 0, - cls_token_key: -1, - pos_embed_key: -1 - }, - partition_states={ - weight_key: True, - bias_key: True, - cls_token_key: True, - pos_embed_key: True - }, + dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1}, + partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True}, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -593,15 +558,16 @@ def forward(self, input_: Tensor) -> Tensor: input_ = split_batch_2p5d(input_, 0) B, C, H, W = input_.shape - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
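
As a side note (again outside the patch), the patch-embedding math in this forward pass can be sanity-checked on a single device; the 2.5D partitioning of the embedding dimension and the all-gather of weight/bias are deliberately omitted, and all sizes below are made up for illustration:

```python
import torch
import torch.nn.functional as F

# Single-device sketch of the conv-based patch embedding used above:
# a conv with stride == patch size, then BCHW -> BNC via flatten/transpose.
B, C, H, W, patch, embed = 2, 3, 224, 224, 16, 96
weight = torch.randn(embed, C, patch, patch)
bias = torch.zeros(embed)

out = F.conv2d(torch.randn(B, C, H, W), weight, bias, stride=patch)
out = out.flatten(2).transpose(1, 2)   # [B, num_patches, embed]
assert out.shape == (B, (H // patch) * (W // patch), embed)
```
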
weight = all_gather_tensor_2p5d(self.weight, 0, ParallelMode.PARALLEL_2P5D_COL) bias = all_gather_tensor_2p5d(self.bias, 0, ParallelMode.PARALLEL_2P5D_COL) output = F.conv2d(input_, weight, bias, stride=self.patch_size) if self.flatten: - output = output.flatten(2).transpose(1, 2) # BCHW -> BNC + output = output.flatten(2).transpose(1, 2) # BCHW -> BNC cls_token = all_gather_tensor_2p5d(self.cls_token, -1, ParallelMode.PARALLEL_2P5D_COL) pos_embed = all_gather_tensor_2p5d(self.pos_embed, -1, ParallelMode.PARALLEL_2P5D_COL) @@ -643,14 +609,16 @@ class Embedding2p5D(ParallelLayer): `init `_ """ - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs, + ): super().__init__() assert_tesseract_initialization() @@ -664,7 +632,8 @@ def __init__(self, self.embed_kwargs = kwargs self.weight = Parameter( - torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)) + torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype) + ) self.reset_parameters(weight_initializer) self._set_tensor_parallel_attributes() @@ -685,7 +654,7 @@ def _fill_padding_idx_with_zero(self) -> None: def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' + weight_key = prefix + "weight" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -711,7 +680,7 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' + weight_key = prefix + "weight" local_state = OrderedDict({weight_key: self.weight}) # gather in column groups @@ -775,14 +744,16 @@ class VocabParallelEmbedding2p5D(ParallelLayer): `init `_. 
""" - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs, + ): super().__init__() self.num_embeddings = num_embeddings self.embed_dim = embedding_dim @@ -799,9 +770,12 @@ def __init__(self, self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition self.weight = Parameter( - torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), - device=get_current_device(), - dtype=dtype)) + torch.empty( + (self.num_embeddings_per_partition, self.embed_dim_per_partition), + device=get_current_device(), + dtype=dtype, + ) + ) self.reset_parameters(weight_initializer) self._set_tensor_parallel_attributes() @@ -817,14 +791,13 @@ def reset_parameters(self, weight_initializer) -> None: self._fill_padding_idx_with_zero() def _fill_padding_idx_with_zero(self) -> None: - if self.padding_idx is not None and \ - self.vocab_start_index <= self.padding_idx < self.vocab_end_index: + if self.padding_idx is not None and self.vocab_start_index <= self.padding_idx < self.vocab_end_index: with torch.no_grad(): self.weight[self.padding_idx - self.vocab_start_index].fill_(0) def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' + weight_key = prefix + "weight" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -850,7 +823,7 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' + weight_key = prefix + "weight" local_state = OrderedDict({weight_key: self.weight}) # gather in column groups @@ -880,11 +853,12 @@ def forward(self, input_: Tensor) -> Tensor: masked_input = input_.clone() - self.vocab_start_index masked_input[input_mask] = 0 - output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, - **self.embed_kwargs) + output_parallel = F.embedding( + masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs + ) # Mask the output embedding. - output_parallel[input_mask, :] = 0. + output_parallel[input_mask, :] = 0.0 # Reduce across all the model parallel GPUs. output = reduce_scatter_tensor_2p5d(output_parallel, 0, ParallelMode.PARALLEL_2P5D_COL) return output @@ -909,14 +883,16 @@ class Classifier2p5D(ParallelLayer): `init `_. 
""" - def __init__(self, - in_features: int, - num_classes: int, - weight: Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): super().__init__() self.in_features = in_features self.num_classes = num_classes @@ -934,7 +910,8 @@ def __init__(self, self.has_weight = False else: self.weight = Parameter( - torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype)) + torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype) + ) self.has_weight = True if bias: self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) @@ -964,8 +941,8 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight if self.has_weight: @@ -983,34 +960,22 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_ROW, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, ) # partition in column groups local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_COL, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" local_state = OrderedDict() if self.has_weight: local_state[weight_key] = self.weight @@ -1021,14 +986,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_COL, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, keep_vars=keep_vars, ) # gather in row groups @@ -1036,14 +995,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_ROW, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -1052,10 +1005,21 @@ def _save_to_global_state_dict(self, destination, 
prefix, keep_vars): def forward(self, input_: Tensor) -> Tensor: out_shape = input_.shape[:-1] + (self.num_classes,) - return classifier_2p5d(input_, self.weight, self.bias, self.tesseract_dim, out_shape, self.row_rank, - self.col_rank, ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) + return classifier_2p5d( + input_, + self.weight, + self.bias, + self.tesseract_dim, + out_shape, + self.row_rank, + self.col_rank, + ParallelMode.PARALLEL_2P5D_ROW, + ParallelMode.PARALLEL_2P5D_COL, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) @LAYERS.register_module @@ -1077,14 +1041,16 @@ class VocabParallelClassifier2p5D(ParallelLayer): `init `_. """ - def __init__(self, - in_features: int, - num_classes: int, - weight: Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): super().__init__() self.in_features = in_features @@ -1102,13 +1068,14 @@ def __init__(self, self.hidden_size_per_partition = divide(num_classes, self.tesseract_dim) # create weight, shape: [k/q, h/q] - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + factory_kwargs = {"device": get_current_device(), "dtype": dtype} if weight is not None: self.weight = weight self.has_weight = False else: self.weight = Parameter( - torch.empty(self.hidden_size_per_partition, self.input_size_per_partition, **factory_kwargs)) + torch.empty(self.hidden_size_per_partition, self.input_size_per_partition, **factory_kwargs) + ) self.has_weight = True # create bias, shape: [h/q] if bias: @@ -1137,8 +1104,8 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight if self.has_weight: @@ -1156,27 +1123,15 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_ROW, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, ) # partition in column groups local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) @@ -1203,8 +1158,19 @@ def forward(self, x: Tensor) -> Tensor: ) if self.bias is not None: - output = add_bias_2p5d(output, self.bias, self.hidden_size_per_partition, self.tesseract_dim, self.row_rank, - self.col_rank, self.dep_rank, 
ParallelMode.PARALLEL_2P5D_COL, False, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) + output = add_bias_2p5d( + output, + self.bias, + self.hidden_size_per_partition, + self.tesseract_dim, + self.row_rank, + self.col_rank, + self.dep_rank, + ParallelMode.PARALLEL_2P5D_COL, + False, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) return output diff --git a/colossalai/legacy/nn/layer/parallel_3d/__init__.py b/colossalai/legacy/nn/layer/parallel_3d/__init__.py index 17fe8403c585..5d38f6a56874 100644 --- a/colossalai/legacy/nn/layer/parallel_3d/__init__.py +++ b/colossalai/legacy/nn/layer/parallel_3d/__init__.py @@ -10,6 +10,14 @@ ) __all__ = [ - 'reduce_by_batch_3d', 'split_tensor_3d', 'split_batch_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D', - 'Classifier3D', 'Embedding3D', 'VocabParallelEmbedding3D', 'VocabParallelClassifier3D' + "reduce_by_batch_3d", + "split_tensor_3d", + "split_batch_3d", + "Linear3D", + "LayerNorm3D", + "PatchEmbedding3D", + "Classifier3D", + "Embedding3D", + "VocabParallelEmbedding3D", + "VocabParallelClassifier3D", ] diff --git a/colossalai/legacy/nn/layer/parallel_3d/_operation.py b/colossalai/legacy/nn/layer/parallel_3d/_operation.py index c6374efb7124..fe42d8e28111 100755 --- a/colossalai/legacy/nn/layer/parallel_3d/_operation.py +++ b/colossalai/legacy/nn/layer/parallel_3d/_operation.py @@ -16,7 +16,6 @@ class _Linear3D(torch.autograd.Function): - @staticmethod @custom_fwd(cast_inputs=torch.float16) def forward( @@ -52,7 +51,8 @@ def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True) weight_grad = torch.matmul( - input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1])) + input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]) + ) weight_grad, op = reduce_scatter(weight_grad, 0, ctx.weight_parallel_mode, async_op=True) weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) @@ -92,7 +92,6 @@ def linear_3d( class _Classifier3D(torch.autograd.Function): - @staticmethod @custom_fwd(cast_inputs=torch.float16) def forward( @@ -131,7 +130,8 @@ def forward( def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: input_, weight = ctx.saved_tensors weight_grad = torch.matmul( - output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1])) + output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1]) + ) weight_grad = reduce(weight_grad, ctx.src_rank, ctx.input_parallel_mode) if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode): weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True) @@ -187,7 +187,6 @@ def classifier_3d( class _VocabParallelClassifier3D(torch.autograd.Function): - @staticmethod @custom_fwd(cast_inputs=torch.float16) def forward( @@ -230,7 +229,8 @@ def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True) weight_grad = torch.matmul( - input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1])) + input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]) + ) weight_grad, op = 
reduce_scatter(weight_grad.transpose(0, 1), 0, ctx.weight_parallel_mode, async_op=True) weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) @@ -296,7 +296,7 @@ def norm_backward(grad: Tensor, mu: Tensor, sigma: Tensor, weight: Tensor): # dbias, dweight = grad, grad * mu / sigma dz = grad * weight dmu = dz / sigma - dvar = dz * mu * (-0.5) * sigma**(-3) + dvar = dz * mu * (-0.5) * sigma ** (-3) dmean = -dmu dvar = torch.sum(dvar, -1, keepdim=True) dmean = torch.sum(dmean, -1, keepdim=True) @@ -305,7 +305,6 @@ def norm_backward(grad: Tensor, mu: Tensor, sigma: Tensor, weight: Tensor): class _Layernorm3D(torch.autograd.Function): - @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward( @@ -415,20 +414,24 @@ def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Te """ dim_size = tensor.size(dim) world_size = gpc.get_world_size(parallel_mode) - assert dim_size % world_size == 0, \ - f'The dimension {dim} to split, size ({dim_size}) is not a multiple of world size ({world_size}), ' \ - f'cannot split tensor evenly' + assert dim_size % world_size == 0, ( + f"The dimension {dim} to split, size ({dim_size}) is not a multiple of world size ({world_size}), " + f"cannot split tensor evenly" + ) if tensor.size(dim) <= 1: return tensor - output = torch.chunk(tensor, gpc.get_world_size(parallel_mode), - dim=dim)[gpc.get_local_rank(parallel_mode)].contiguous() + output = torch.chunk(tensor, gpc.get_world_size(parallel_mode), dim=dim)[ + gpc.get_local_rank(parallel_mode) + ].contiguous() return output -def split_batch_3d(input_: Tensor, - dim: int = 0, - input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT, - weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor: +def split_batch_3d( + input_: Tensor, + dim: int = 0, + input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT, + weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT, +) -> Tensor: r"""Splits 3D tensor in batch. Args: @@ -456,7 +459,6 @@ def split_batch_3d(input_: Tensor, class _ReduceTensor3D(torch.autograd.Function): - @staticmethod def forward(ctx, input_, parallel_mode): return all_reduce(input_, parallel_mode) @@ -481,7 +483,6 @@ def reduce_tensor_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor: class _AllGatherTensor3D(torch.autograd.Function): - @staticmethod def forward(ctx, input_, dim, parallel_mode): ctx.dim = dim @@ -511,7 +512,6 @@ def all_gather_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) class _ReduceScatterTensor3D(torch.autograd.Function): - @staticmethod def forward(ctx, input_, dim, parallel_mode): ctx.dim = dim @@ -538,21 +538,23 @@ def reduce_scatter_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMo """ dim_size = tensor.size(dim) world_size = gpc.get_world_size(parallel_mode) - assert dim_size % world_size == 0, \ - f'The batch size ({dim_size}) is not a multiple of square of 3D depth ({world_size}).' + assert ( + dim_size % world_size == 0 + ), f"The batch size ({dim_size}) is not a multiple of square of 3D depth ({world_size})." 
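
For intuition only (not part of the patch), the divisibility requirement checked here exists because reduce-scatter hands each rank one equal chunk of the reduced tensor. A single-process emulation with a hypothetical group size:

```python
import torch

# Single-process emulation of reduce-scatter along dim 0: reduce (sum) the
# per-rank tensors, then scatter one chunk of the result to each rank.
# The real op uses torch.distributed collectives over the 3D parallel group.
world_size = 4
per_rank = [torch.full((8, 5), float(r)) for r in range(world_size)]

reduced = torch.stack(per_rank).sum(dim=0)        # the "reduce" half
chunks = torch.chunk(reduced, world_size, dim=0)  # the "scatter" half
assert len(chunks) == world_size and chunks[0].shape == (2, 5)
```
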
return _ReduceScatterTensor3D.apply(tensor, dim, parallel_mode) class _ReduceByBatch3D(torch.autograd.Function): - @staticmethod @custom_fwd(cast_inputs=torch.float32) - def forward(ctx, - input_: Tensor, - input_parallel_mode: ParallelMode, - weight_parallel_mode: ParallelMode, - reduce_mean: bool = False) -> Tensor: + def forward( + ctx, + input_: Tensor, + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + reduce_mean: bool = False, + ) -> Tensor: output = all_reduce(input_, input_parallel_mode) output = all_reduce(output, weight_parallel_mode) ctx.reduce_mean = reduce_mean @@ -571,10 +573,9 @@ def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: return output_grad, None, None, None -def reduce_by_batch_3d(tensor: Tensor, - input_parallel_mode: ParallelMode, - weight_parallel_mode: ParallelMode, - reduce_mean: bool = False) -> Tensor: +def reduce_by_batch_3d( + tensor: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode, reduce_mean: bool = False +) -> Tensor: r"""All-reduce the input from the model parallel region. Args: diff --git a/colossalai/legacy/nn/layer/parallel_3d/_utils.py b/colossalai/legacy/nn/layer/parallel_3d/_utils.py index cb300c2a9684..8c967da74e67 100644 --- a/colossalai/legacy/nn/layer/parallel_3d/_utils.py +++ b/colossalai/legacy/nn/layer/parallel_3d/_utils.py @@ -18,17 +18,24 @@ def get_depth_from_env() -> int: try: depth = env.depth_3d - assert depth > 0, 'DEPTH must be greater than zero' + assert depth > 0, "DEPTH must be greater than zero" return depth - except KeyError as e: - raise EnvironmentError('DEPTH is not found in the current environment, ' - 'please make sure that you have used the correct process group initializer') + except KeyError: + raise EnvironmentError( + "DEPTH is not found in the current environment, " + "please make sure that you have used the correct process group initializer" + ) def get_parallel_mode_from_env(group): - assert group in [INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_X_WEIGHT_3D], \ - f'{group} is not valid for 3D tensor parallelism.' + assert group in [ + INPUT_GROUP_3D, + WEIGHT_GROUP_3D, + OUTPUT_GROUP_3D, + INPUT_X_WEIGHT_3D, + OUTPUT_X_WEIGHT_3D, + ], f"{group} is not valid for 3D tensor parallelism." 
return getattr(env, group) @@ -44,12 +51,10 @@ def dbg_check_shape(tensor: Tensor, shape: tuple): rank = gpc.get_global_rank() if rank == 0: print(tensor.shape) - assert tensor.shape == shape, \ - '{} does not match {}'.format(tensor.shape, shape) + assert tensor.shape == shape, "{} does not match {}".format(tensor.shape, shape) class AsyncGradientBucket(object): - def __init__(self): self.bucket = OrderedDict() diff --git a/colossalai/legacy/nn/layer/parallel_3d/layers.py b/colossalai/legacy/nn/layer/parallel_3d/layers.py index d6aaa427b9e6..196679994197 100644 --- a/colossalai/legacy/nn/layer/parallel_3d/layers.py +++ b/colossalai/legacy/nn/layer/parallel_3d/layers.py @@ -59,7 +59,6 @@ class LayerNorm3D(ParallelLayer): """ def __init__(self, normalized_shape: int, eps: float = 1e-12, bias=True, dtype=None): - super().__init__() self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) @@ -70,10 +69,12 @@ def __init__(self, normalized_shape: int, eps: float = 1e-12, bias=True, dtype=N self.normalized_shape_per_partition = divide(normalized_shape, self.depth) self.weight = Parameter( - torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)) + torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype) + ) if bias: self.bias = Parameter( - torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)) + torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype) + ) else: self.bias = None self.variance_epsilon = eps @@ -94,8 +95,8 @@ def reset_parameters(self) -> None: def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -107,15 +108,11 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state[bias_key] = bias # partition in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = partition_tensor_parallel_state_dict( local_state, self.output_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0 - }, + dims={weight_key: 0, bias_key: 0}, partition_states={ weight_key: True, bias_key: True, @@ -130,26 +127,19 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" local_state = OrderedDict({weight_key: self.weight}) if self.bias is not None: local_state[bias_key] = self.bias # gather in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = gather_tensor_parallel_state_dict( local_state, self.output_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - 
bias_key: True - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -185,14 +175,16 @@ class Linear3D(ParallelLayer): `init `_. """ - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): super().__init__() self.in_features = in_features self.out_features = out_features @@ -207,13 +199,17 @@ def __init__(self, self.bias_features_per_partition = divide(out_features, self.depth) self.weight = Parameter( - torch.empty(self.in_features_per_partition, - self.out_features_per_partition, - device=get_current_device(), - dtype=dtype)) + torch.empty( + self.in_features_per_partition, + self.out_features_per_partition, + device=get_current_device(), + dtype=dtype, + ) + ) if bias: self.bias = Parameter( - torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype)) + torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype) + ) else: self.bias = None @@ -239,15 +235,17 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: if self.bias is not None: bias_initializer(self.bias, fan_in=fan_in) - broadcast(self.bias, - gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0], - self.output_x_weight_parallel_mode) + broadcast( + self.bias, + gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0], + self.output_x_weight_parallel_mode, + ) self.bias.register_hook(self._sync_grad_hook) def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -260,53 +258,34 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state[bias_key] = bias # partition in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = partition_tensor_parallel_state_dict( local_state, self.output_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, ) # partition in input groups if gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = partition_tensor_parallel_state_dict( local_state, self.input_parallel_mode, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, ) # partition in weight groups local_state = partition_tensor_parallel_state_dict( local_state, self.weight_parallel_mode, - dims={ - weight_key: 0, - 
bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" local_state = OrderedDict({weight_key: self.weight}) if self.bias is not None: local_state[bias_key] = self.bias @@ -315,14 +294,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, self.weight_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, keep_vars=keep_vars, ) # gather in input groups @@ -330,30 +303,17 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, self.input_parallel_mode, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, keep_vars=keep_vars, ) # gather in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = gather_tensor_parallel_state_dict( local_state, self.output_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -396,14 +356,16 @@ class Classifier3D(ParallelLayer): `init `_. 
""" - def __init__(self, - in_features: int, - num_classes: int, - weight: Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): super().__init__() self.in_features = in_features self.num_classes = num_classes @@ -418,7 +380,8 @@ def __init__(self, self.has_weight = False else: self.weight = Parameter( - torch.empty(self.num_classes, self.in_features_per_partition, device=get_current_device(), dtype=dtype)) + torch.empty(self.num_classes, self.in_features_per_partition, device=get_current_device(), dtype=dtype) + ) self.has_weight = True if bias: self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) @@ -449,8 +412,8 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight if self.has_weight: @@ -464,19 +427,12 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state[bias_key] = bias # partition in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = partition_tensor_parallel_state_dict( local_state, self.output_parallel_mode, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, ) # broadcast in input groups if gpc.get_local_rank(self.weight_parallel_mode) == 0: @@ -487,8 +443,8 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" local_state = OrderedDict() if self.has_weight: local_state[weight_key] = self.weight @@ -496,19 +452,12 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state[bias_key] = self.bias # gather in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = gather_tensor_parallel_state_dict( local_state, self.output_parallel_mode, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -544,14 +493,16 @@ class VocabParallelClassifier3D(ParallelLayer): `init `_. 
""" - def __init__(self, - in_features: int, - num_classes: int, - weight: Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): super().__init__() self.in_features = in_features self.num_classes = num_classes @@ -569,14 +520,18 @@ def __init__(self, self.has_weight = False else: self.weight = Parameter( - torch.empty(self.out_features_per_partition, - self.in_features_per_partition, - device=get_current_device(), - dtype=dtype)) + torch.empty( + self.out_features_per_partition, + self.in_features_per_partition, + device=get_current_device(), + dtype=dtype, + ) + ) self.has_weight = True if bias: self.bias = Parameter( - torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype)) + torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype) + ) else: self.bias = None @@ -602,15 +557,17 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: if self.bias is not None: bias_initializer(self.bias, fan_in=fan_in) - broadcast(self.bias, - gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0], - self.output_x_weight_parallel_mode) + broadcast( + self.bias, + gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0], + self.output_x_weight_parallel_mode, + ) register_async_grad_hook(self.bias) def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight if self.has_weight: @@ -624,53 +581,34 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state[bias_key] = bias # partition in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = partition_tensor_parallel_state_dict( local_state, self.output_parallel_mode, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, ) # partition in input groups if gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = partition_tensor_parallel_state_dict( local_state, self.input_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, ) # partition in weight groups local_state = partition_tensor_parallel_state_dict( local_state, self.weight_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, 
prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" local_state = OrderedDict({weight_key: self.weight}) if self.bias is not None: local_state[bias_key] = self.bias @@ -679,14 +617,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, self.weight_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, keep_vars=keep_vars, ) # gather in input groups @@ -694,30 +626,17 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, self.input_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, keep_vars=keep_vars, ) # gather in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = gather_tensor_parallel_state_dict( local_state, self.output_parallel_mode, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -756,16 +675,18 @@ class PatchEmbedding3D(ParallelLayer): `init `_. """ - def __init__(self, - img_size: int, - patch_size: int, - in_chans: int, - embed_size: int, - flatten: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - position_embed_initializer: Callable = init.zeros_()): + def __init__( + self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + flatten: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_(), + ): super().__init__() self.depth = get_depth_from_env() self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) @@ -783,15 +704,18 @@ def __init__(self, self.flatten = flatten self.weight = nn.Parameter( - torch.empty((embed_size_per_partition, in_chans, *self.patch_size), - device=get_current_device(), - dtype=dtype)) + torch.empty( + (embed_size_per_partition, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype + ) + ) self.bias = nn.Parameter(torch.empty(embed_size_per_partition, device=get_current_device(), dtype=dtype)) self.cls_token = nn.Parameter( - torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)) + torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype) + ) self.pos_embed = nn.Parameter( - torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)) + torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype) + ) self.reset_parameters(weight_initializer, 
bias_initializer, position_embed_initializer) self._set_tensor_parallel_attributes() @@ -826,10 +750,10 @@ def reset_parameters(self, weight_initializer, bias_initializer, position_embed_ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - cls_token_key = prefix + 'cls_token' - pos_embed_key = prefix + 'pos_embed' + weight_key = prefix + "weight" + bias_key = prefix + "bias" + cls_token_key = prefix + "cls_token" + pos_embed_key = prefix + "pos_embed" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -849,23 +773,12 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state[pos_embed_key] = pos_embed # partition in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = partition_tensor_parallel_state_dict( local_state, self.output_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0, - cls_token_key: -1, - pos_embed_key: -1 - }, - partition_states={ - weight_key: True, - bias_key: True, - cls_token_key: True, - pos_embed_key: True - }, + dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1}, + partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True}, ) # broadcast in input groups if gpc.get_local_rank(self.weight_parallel_mode) == 0: @@ -876,47 +789,33 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - cls_token_key = prefix + 'cls_token' - pos_embed_key = prefix + 'pos_embed' - local_state = OrderedDict({ - weight_key: self.weight, - bias_key: self.bias, - cls_token_key: self.cls_token, - pos_embed_key: self.pos_embed - }) + weight_key = prefix + "weight" + bias_key = prefix + "bias" + cls_token_key = prefix + "cls_token" + pos_embed_key = prefix + "pos_embed" + local_state = OrderedDict( + {weight_key: self.weight, bias_key: self.bias, cls_token_key: self.cls_token, pos_embed_key: self.pos_embed} + ) # gather in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = gather_tensor_parallel_state_dict( local_state, self.output_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0, - cls_token_key: -1, - pos_embed_key: -1 - }, - partition_states={ - weight_key: True, - bias_key: True, - cls_token_key: True, - pos_embed_key: True - }, + dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1}, + partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True}, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: - input_ = split_batch_3d(input_, - input_parallel_mode=self.input_parallel_mode, - weight_parallel_mode=self.weight_parallel_mode) + input_ = split_batch_3d( + input_, input_parallel_mode=self.input_parallel_mode, weight_parallel_mode=self.weight_parallel_mode + ) 
output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size) if self.flatten: - output = output.flatten(2).transpose(1, 2) # BCHW -> BNC + output = output.flatten(2).transpose(1, 2) # BCHW -> BNC cls_token = self.cls_token.expand(output.shape[0], -1, -1) output = torch.cat((cls_token, output), dim=1) @@ -956,14 +855,16 @@ class Embedding3D(ParallelLayer): `init `_ """ - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs, + ): super().__init__() self.depth = get_depth_from_env() self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) @@ -979,7 +880,8 @@ def __init__(self, self.embed_kwargs = kwargs self.weight = nn.Parameter( - torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)) + torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype) + ) self.reset_parameters(weight_initializer) self._set_tensor_parallel_attributes() @@ -996,8 +898,9 @@ def reset_parameters(self, weight_initializer) -> None: fan_in, fan_out = self.num_embeddings, self.embed_dim weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) self._fill_padding_idx_with_zero() - broadcast(self.weight, - gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0], self.input_x_weight_parallel_mode) + broadcast( + self.weight, gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0], self.input_x_weight_parallel_mode + ) self.weight.register_hook(self._sync_grad_hook) def _fill_padding_idx_with_zero(self) -> None: @@ -1007,7 +910,7 @@ def _fill_padding_idx_with_zero(self) -> None: def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' + weight_key = prefix + "weight" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -1015,8 +918,7 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state[weight_key] = weight # partition in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = partition_tensor_parallel_state_dict( local_state, self.output_parallel_mode, @@ -1032,12 +934,11 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' + weight_key = prefix + "weight" local_state = OrderedDict({weight_key: self.weight}) # gather in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = gather_tensor_parallel_state_dict( local_state, self.output_parallel_mode, @@ -1049,9 +950,9 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: - input_ = 
split_batch_3d(input_, - input_parallel_mode=self.input_parallel_mode, - weight_parallel_mode=self.weight_parallel_mode) + input_ = split_batch_3d( + input_, input_parallel_mode=self.input_parallel_mode, weight_parallel_mode=self.weight_parallel_mode + ) output = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) return output @@ -1088,14 +989,16 @@ class VocabParallelEmbedding3D(ParallelLayer): `init `_. """ - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs, + ): super().__init__() self.num_embeddings = num_embeddings self.embed_dim = embedding_dim @@ -1114,9 +1017,12 @@ def __init__(self, self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition * self.depth self.weight = Parameter( - torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), - device=get_current_device(), - dtype=dtype)) + torch.empty( + (self.num_embeddings_per_partition, self.embed_dim_per_partition), + device=get_current_device(), + dtype=dtype, + ) + ) self.reset_parameters(weight_initializer) self._set_tensor_parallel_attributes() @@ -1132,14 +1038,17 @@ def reset_parameters(self, weight_initializer) -> None: self._fill_padding_idx_with_zero() def _fill_padding_idx_with_zero(self) -> None: - if self.padding_idx is not None and \ - self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index: + if ( + self.padding_idx is not None + and self.padding_idx >= self.vocab_start_index + and self.padding_idx < self.vocab_end_index + ): with torch.no_grad(): self.weight[self.padding_idx - self.vocab_start_index].fill_(0) def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' + weight_key = prefix + "weight" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -1147,8 +1056,7 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state[weight_key] = weight # partition in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = partition_tensor_parallel_state_dict( local_state, self.output_parallel_mode, @@ -1174,7 +1082,7 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' + weight_key = prefix + "weight" local_state = OrderedDict({weight_key: self.weight}) # gather in weight groups @@ -1195,8 +1103,7 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): keep_vars=keep_vars, ) # gather in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = gather_tensor_parallel_state_dict( local_state, self.output_parallel_mode, @@ 
-1218,7 +1125,7 @@ def forward(self, input_: Tensor) -> Tensor: output_parallel = F.embedding(masked_input, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) - output_parallel[input_mask, :] = 0. + output_parallel[input_mask, :] = 0.0 output = reduce_scatter_tensor_3d(output_parallel, 0, self.input_parallel_mode) return output diff --git a/colossalai/legacy/nn/layer/parallel_sequence/__init__.py b/colossalai/legacy/nn/layer/parallel_sequence/__init__.py index d92d66d40a8e..d64aba6bafe4 100644 --- a/colossalai/legacy/nn/layer/parallel_sequence/__init__.py +++ b/colossalai/legacy/nn/layer/parallel_sequence/__init__.py @@ -1,4 +1,4 @@ from ._operation import RingAV, RingQK from .layers import TransformerSelfAttentionRing -__all__ = ['TransformerSelfAttentionRing', 'RingAV', 'RingQK'] +__all__ = ["TransformerSelfAttentionRing", "RingAV", "RingQK"] diff --git a/colossalai/legacy/nn/layer/parallel_sequence/_operation.py b/colossalai/legacy/nn/layer/parallel_sequence/_operation.py index ea1863f0b474..24d5499e3a5f 100644 --- a/colossalai/legacy/nn/layer/parallel_sequence/_operation.py +++ b/colossalai/legacy/nn/layer/parallel_sequence/_operation.py @@ -25,11 +25,13 @@ def forward(ctx, sub_q, sub_k, batch_size, num_attention_heads, sub_seq_length): ctx.sub_seq_length = sub_seq_length # create local segment of attention score - attention_score = torch.empty(batch_size * num_attention_heads, - sub_seq_length, - sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE), - dtype=sub_q.dtype, - device=get_current_device()) + attention_score = torch.empty( + batch_size * num_attention_heads, + sub_seq_length, + sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE), + dtype=sub_q.dtype, + device=get_current_device(), + ) # compute local QK^T part_a = torch.matmul(sub_q, sub_k.transpose(2, 1)) @@ -51,7 +53,10 @@ def forward(ctx, sub_q, sub_k, batch_size, num_attention_heads, sub_seq_length): @staticmethod @custom_bwd def backward(ctx, grad_output): - sub_q, sub_k, = ctx.saved_tensors + ( + sub_q, + sub_k, + ) = ctx.saved_tensors local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE) @@ -59,7 +64,7 @@ def backward(ctx, grad_output): grad_k = torch.matmul(grad_output.transpose(2, 1), sub_q) dist.all_reduce(grad_k, group=gpc.get_group(ParallelMode.SEQUENCE)) - grad_k = grad_k[:, local_rank * ctx.sub_seq_length:(local_rank + 1) * ctx.sub_seq_length] + grad_k = grad_k[:, local_rank * ctx.sub_seq_length : (local_rank + 1) * ctx.sub_seq_length] grad_k /= local_world_size # calculate gradient for sub_q @@ -96,11 +101,13 @@ def forward(ctx, attention_score, sub_v, batch_size, num_attention_heads, attent local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE) local_start_idx, local_end_idx = _calc_current_device_range(local_rank, sub_seq_length) - sub_attention_result = torch.zeros(batch_size * num_attention_heads, - sub_seq_length, - attention_head_size, - device=get_current_device(), - dtype=attention_score.dtype) + sub_attention_result = torch.zeros( + batch_size * num_attention_heads, + sub_seq_length, + attention_head_size, + device=get_current_device(), + dtype=attention_score.dtype, + ) # save tensors for backward ctx.save_for_backward(attention_score, sub_v) diff --git a/colossalai/legacy/nn/layer/parallel_sequence/layers.py b/colossalai/legacy/nn/layer/parallel_sequence/layers.py index 033c1be962ae..063b0cd8e2b2 100644 --- a/colossalai/legacy/nn/layer/parallel_sequence/layers.py +++ 
b/colossalai/legacy/nn/layer/parallel_sequence/layers.py @@ -8,7 +8,6 @@ import torch.nn.functional as F from torch.nn import Parameter -import colossalai from colossalai.kernel import FusedScaleMaskSoftmax from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType from colossalai.legacy.context import seed @@ -33,18 +32,20 @@ class TransformerSelfAttentionRing(nn.Module): """ - def __init__(self, - hidden_size, - num_attention_heads, - attention_dropout, - attention_mask_func, - layer_number, - apply_query_key_layer_scaling: bool = False, - convert_fp16_to_fp32_in_softmax: bool = False, - attn_mask_type=AttnMaskType.padding, - masked_softmax_fusion=True, - fp16=False, - bf16=False): + def __init__( + self, + hidden_size, + num_attention_heads, + attention_dropout, + attention_mask_func, + layer_number, + apply_query_key_layer_scaling: bool = False, + convert_fp16_to_fp32_in_softmax: bool = False, + attn_mask_type=AttnMaskType.padding, + masked_softmax_fusion=True, + fp16=False, + bf16=False, + ): super().__init__() self.convert_fp16_to_fp32_in_softmax = convert_fp16_to_fp32_in_softmax self.apply_query_key_layer_scaling = apply_query_key_layer_scaling @@ -59,8 +60,9 @@ def __init__(self, if self.apply_query_key_layer_scaling: self.convert_fp16_to_fp32_in_softmax = True - assert self.hidden_size % self.num_attention_heads == 0, \ - 'hidden size is not divisible by the number of attention heads' + assert ( + self.hidden_size % self.num_attention_heads == 0 + ), "hidden size is not divisible by the number of attention heads" self.hidden_size_per_attention_head = self.hidden_size // num_attention_heads @@ -79,9 +81,15 @@ def __init__(self, self.coeff = layer_number self.norm_factor *= self.coeff - self.scale_mask_softmax = FusedScaleMaskSoftmax(fp16, bf16, self.attn_mask_type, masked_softmax_fusion, - self.attention_mask_func, self.convert_fp16_to_fp32_in_softmax, - self.coeff) + self.scale_mask_softmax = FusedScaleMaskSoftmax( + fp16, + bf16, + self.attn_mask_type, + masked_softmax_fusion, + self.attention_mask_func, + self.convert_fp16_to_fp32_in_softmax, + self.coeff, + ) self.attention_dropout = nn.Dropout(attention_dropout) @@ -102,21 +110,28 @@ def forward(self, hidden_states, attention_mask): mixed_x_layer = self.query_key_value(hidden_states) # [sub_seq_len, batch_size, num_heads, 3 * head_size] --> 3 [sub_seq_len, batch_size, num_heads, head_size] - new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads, - 3 * self.hidden_size_per_attention_head) + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_attention_heads, + 3 * self.hidden_size_per_attention_head, + ) mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) # split into query, key and value last_dim = mixed_x_layer.dim() - 1 last_dim_value = mixed_x_layer.size(-1) - assert last_dim_value % 3 == 0, 'the last dimension is not a multiple of 3, ' \ - 'cannot be divided into query, key and value' + assert last_dim_value % 3 == 0, ( + "the last dimension is not a multiple of 3, " "cannot be divided into query, key and value" + ) partition_size = last_dim_value // 3 (query_layer, key_layer, value_layer) = torch.split(mixed_x_layer, partition_size, dim=last_dim) # attention scores: [batch_size, num_heads, sub_seq_len, seq_len] - output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), - key_layer.size(0) * self.world_size) + output_size = ( + query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0) * self.world_size, + ) # [sub_seq_len, batch_size, 
num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size] query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) @@ -125,11 +140,12 @@ def forward(self, hidden_states, attention_mask): # attention_scores: [batch_size * num_heads, sub_seq_len, seq_len] attention_scores = RingQK.apply( - query_layer.transpose(0, 1).contiguous(), # [batch_size * num_heads, sub_seq_len, head_size] - key_layer.transpose(0, 1).contiguous(), # [batch_size * num_heads, sub_seq_len, head_size], + query_layer.transpose(0, 1).contiguous(), # [batch_size * num_heads, sub_seq_len, head_size] + key_layer.transpose(0, 1).contiguous(), # [batch_size * num_heads, sub_seq_len, head_size], batch_size, self.num_attention_heads, - sub_seq_length) + sub_seq_length, + ) attention_scores /= self.norm_factor @@ -151,12 +167,18 @@ def forward(self, hidden_states, attention_mask): # # change view [b * num_heads, sub_seq_len, seq_len] attention_probs = attention_probs.view( - attention_probs.size(0) * attention_probs.size(1), attention_probs.size(2), attention_probs.size(3)) + attention_probs.size(0) * attention_probs.size(1), attention_probs.size(2), attention_probs.size(3) + ) # matmul: [batch_size * num_heads, sub_seq_len, head_size] - context_layer = RingAV.apply(attention_probs, - value_layer.transpose(0, 1).contiguous(), batch_size, self.num_attention_heads, - self.hidden_size_per_attention_head, sub_seq_length) + context_layer = RingAV.apply( + attention_probs, + value_layer.transpose(0, 1).contiguous(), + batch_size, + self.num_attention_heads, + self.hidden_size_per_attention_head, + sub_seq_length, + ) # change view [batch_size, num_heads, sub_seq_len, head_size] context_layer = context_layer.view(*output_size) @@ -165,8 +187,9 @@ def forward(self, hidden_states, attention_mask): context_layer = context_layer.permute(2, 0, 1, 3).contiguous() # [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size, hidden_size] - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_attention_head * - self.num_attention_heads,) + new_context_layer_shape = context_layer.size()[:-2] + ( + self.hidden_size_per_attention_head * self.num_attention_heads, + ) context_layer = context_layer.view(*new_context_layer_shape) output, bias = self.dense(context_layer) @@ -174,11 +197,13 @@ def forward(self, hidden_states, attention_mask): return output, bias def __repr__(self): - return f'TransformerSelfAttentionRing(apply_query_key_layer_scaling={self.apply_query_key_layer_scaling}, ' \ - f'layer_number={self.layer_number}, hidden_size:{self.hidden_size}, attention_dropout={self.attention_dropout}, ' \ - f'attn_mask_type={self.attn_mask_type}, num_attention_heads={self.num_attention_heads}, ' \ - f'hidden_size_per_attention_head={self.hidden_size_per_attention_head}, coeff={self.coeff}, norm_factor={self.norm_factor}, ' \ - f'convert_fp16_to_fp32_in_softmax={self.convert_fp16_to_fp32_in_softmax})' + return ( + f"TransformerSelfAttentionRing(apply_query_key_layer_scaling={self.apply_query_key_layer_scaling}, " + f"layer_number={self.layer_number}, hidden_size:{self.hidden_size}, attention_dropout={self.attention_dropout}, " + f"attn_mask_type={self.attn_mask_type}, num_attention_heads={self.num_attention_heads}, " + f"hidden_size_per_attention_head={self.hidden_size_per_attention_head}, coeff={self.coeff}, norm_factor={self.norm_factor}, " + f"convert_fp16_to_fp32_in_softmax={self.convert_fp16_to_fp32_in_softmax})" + ) class _Linear(nn.Module): @@ -208,10 +233,12 @@ 
def __init__(self, input_size, output_size, bias=True, skip_bias_add=False): self.output_size = output_size self.skip_bias_add = skip_bias_add - self.weight = Parameter(torch.empty( - self.output_size, - self.input_size, - )) + self.weight = Parameter( + torch.empty( + self.output_size, + self.input_size, + ) + ) nn.init.xavier_normal_(self.weight) if bias: @@ -220,7 +247,7 @@ def __init__(self, input_size, output_size, bias=True, skip_bias_add=False): with torch.no_grad(): self.bias.zero_() else: - self.register_parameter('bias', None) + self.register_parameter("bias", None) def forward(self, input_): # Matrix multiply. @@ -233,5 +260,7 @@ def forward(self, input_): return output def __repr__(self): - return f'Linear(in_features={self.input_size}, out_features={self.output_size}, ' + \ - f'bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})' + return ( + f"Linear(in_features={self.input_size}, out_features={self.output_size}, " + + f"bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})" + ) diff --git a/colossalai/legacy/nn/layer/utils/__init__.py b/colossalai/legacy/nn/layer/utils/__init__.py index 56e969bfd0bd..4e78b228eb4f 100644 --- a/colossalai/legacy/nn/layer/utils/__init__.py +++ b/colossalai/legacy/nn/layer/utils/__init__.py @@ -10,6 +10,12 @@ ) __all__ = [ - 'CheckpointModule', 'divide', 'ACT2FN', 'set_tensor_parallel_attribute_by_size', - 'set_tensor_parallel_attribute_by_partition', 'get_tensor_parallel_mode', '_ntuple', 'to_2tuple' + "CheckpointModule", + "divide", + "ACT2FN", + "set_tensor_parallel_attribute_by_size", + "set_tensor_parallel_attribute_by_partition", + "get_tensor_parallel_mode", + "_ntuple", + "to_2tuple", ] diff --git a/colossalai/legacy/nn/layer/utils/common.py b/colossalai/legacy/nn/layer/utils/common.py index 3148a0bed570..fd6a5b38d60a 100644 --- a/colossalai/legacy/nn/layer/utils/common.py +++ b/colossalai/legacy/nn/layer/utils/common.py @@ -14,7 +14,6 @@ class CheckpointModule(nn.Module): - def __init__(self, checkpoint: bool = True, offload: bool = False): super().__init__() self.checkpoint = checkpoint @@ -22,7 +21,7 @@ def __init__(self, checkpoint: bool = True, offload: bool = False): self._offload = offload def _forward(self, *args, **kwargs): - raise NotImplementedError('CheckpointModule should implement _forward method instead of origin forward') + raise NotImplementedError("CheckpointModule should implement _forward method instead of origin forward") def forward(self, *args, **kwargs): if self._use_checkpoint: @@ -49,9 +48,8 @@ def divide(numerator, denominator): Returns: int: the result of exact division. 
""" - assert denominator != 0, 'denominator can not be zero' - assert numerator % denominator == 0, \ - '{} is not divisible by {}'.format(numerator, denominator) + assert denominator != 0, "denominator can not be zero" + assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator) return numerator // denominator @@ -80,7 +78,6 @@ def get_tensor_parallel_mode(): def _ntuple(n): - def parse(x): if isinstance(x, collections.abc.Iterable): return x diff --git a/colossalai/legacy/nn/layer/vanilla/__init__.py b/colossalai/legacy/nn/layer/vanilla/__init__.py index 3d767b8886f5..5785bbef33d7 100644 --- a/colossalai/legacy/nn/layer/vanilla/__init__.py +++ b/colossalai/legacy/nn/layer/vanilla/__init__.py @@ -9,6 +9,11 @@ ) __all__ = [ - "VanillaLayerNorm", "VanillaPatchEmbedding", "VanillaClassifier", "DropPath", "WrappedDropout", "WrappedDropPath", - "VanillaLinear" + "VanillaLayerNorm", + "VanillaPatchEmbedding", + "VanillaClassifier", + "DropPath", + "WrappedDropout", + "WrappedDropPath", + "VanillaLinear", ] diff --git a/colossalai/legacy/nn/layer/vanilla/layers.py b/colossalai/legacy/nn/layer/vanilla/layers.py index 71ca1d421de6..12965a4a6409 100644 --- a/colossalai/legacy/nn/layer/vanilla/layers.py +++ b/colossalai/legacy/nn/layer/vanilla/layers.py @@ -15,7 +15,7 @@ from ..utils import to_2tuple -def drop_path(x, drop_prob: float = 0., training: bool = False): +def drop_path(x, drop_prob: float = 0.0, training: bool = False): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, @@ -28,12 +28,12 @@ def drop_path(x, drop_prob: float = 0., training: bool = False): drop_prob (float, optional): probability of dropping path, defaults 0.0. training (bool, optional): whether in training progress, defaults False. """ - if drop_prob == 0. or not training: + if drop_prob == 0.0 or not training: return x keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) - random_tensor.floor_() # binarize + random_tensor.floor_() # binarize output = x.div(keep_prob) * random_tensor return output @@ -74,8 +74,7 @@ class WrappedDropout(nn.Module): def __init__(self, p: float = 0.5, inplace: bool = False, mode=None): super().__init__() if p < 0 or p > 1: - raise ValueError("dropout probability has to be between 0 and 1, " - "but got {}".format(p)) + raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) self.p = p self.inplace = inplace if mode is None: @@ -108,7 +107,7 @@ class WrappedDropPath(nn.Module): in `parallel_mode `_ """ - def __init__(self, p: float = 0., mode=None): + def __init__(self, p: float = 0.0, mode=None): super().__init__() self.p = p self.mode = mode @@ -152,16 +151,18 @@ class VanillaPatchEmbedding(nn.Module): `init `_. 
""" - def __init__(self, - img_size: int, - patch_size: int, - in_chans: int, - embed_size: int, - flatten: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - position_embed_initializer: Callable = init.zeros_()): + def __init__( + self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + flatten: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_(), + ): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) @@ -172,11 +173,13 @@ def __init__(self, self.flatten = flatten self.weight = nn.Parameter( - torch.empty((embed_size, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype)) + torch.empty((embed_size, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype) + ) self.bias = nn.Parameter(torch.empty(embed_size, device=get_current_device(), dtype=dtype)) self.cls_token = nn.Parameter(torch.zeros((1, 1, embed_size), device=get_current_device(), dtype=dtype)) self.pos_embed = nn.Parameter( - torch.zeros((1, self.num_patches + 1, embed_size), device=get_current_device(), dtype=dtype)) + torch.zeros((1, self.num_patches + 1, embed_size), device=get_current_device(), dtype=dtype) + ) self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) @@ -188,11 +191,12 @@ def reset_parameters(self, weight_initializer, bias_initializer, position_embed_ def forward(self, input_: Tensor) -> Tensor: B, C, H, W = input_.shape - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size) if self.flatten: - output = output.flatten(2).transpose(1, 2) # BCHW -> BNC + output = output.flatten(2).transpose(1, 2) # BCHW -> BNC cls_token = self.cls_token.expand(output.shape[0], -1, -1) output = torch.cat((cls_token, output), dim=1) @@ -219,14 +223,16 @@ class VanillaClassifier(nn.Module): `init `_. 
""" - def __init__(self, - in_features: int, - num_classes: int, - weight: nn.Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + num_classes: int, + weight: nn.Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): super().__init__() self.in_features = in_features self.num_classes = num_classes @@ -236,7 +242,8 @@ def __init__(self, self.has_weight = False else: self.weight = nn.Parameter( - torch.empty(self.num_classes, self.in_features, device=get_current_device(), dtype=dtype)) + torch.empty(self.num_classes, self.in_features, device=get_current_device(), dtype=dtype) + ) self.has_weight = True if bias: self.bias = nn.Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) @@ -280,7 +287,7 @@ def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): self.normalized_shape = (normalized_shape,) self.variance_epsilon = eps - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + factory_kwargs = {"device": get_current_device(), "dtype": dtype} self.weight = nn.Parameter(torch.ones(normalized_shape, **factory_kwargs)) if bias: @@ -311,20 +318,22 @@ class VanillaLinear(nn.Module): `init `_. """ - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - **kwargs) -> None: + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + **kwargs, + ) -> None: super().__init__() self.in_features = in_features self.out_features = out_features self.skip_bias_add = skip_bias_add - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + factory_kwargs = {"device": get_current_device(), "dtype": dtype} self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) if bias: self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) diff --git a/colossalai/legacy/nn/layer/wrapper/__init__.py b/colossalai/legacy/nn/layer/wrapper/__init__.py index c7d90d887ec6..4f3a33645344 100644 --- a/colossalai/legacy/nn/layer/wrapper/__init__.py +++ b/colossalai/legacy/nn/layer/wrapper/__init__.py @@ -1,3 +1,3 @@ from .pipeline_wrapper import PipelineSharedModuleWrapper -__all__ = ['PipelineSharedModuleWrapper'] +__all__ = ["PipelineSharedModuleWrapper"] diff --git a/colossalai/legacy/nn/layer/wrapper/pipeline_wrapper.py b/colossalai/legacy/nn/layer/wrapper/pipeline_wrapper.py index ec19d1b707d8..55445eb4d35a 100644 --- a/colossalai/legacy/nn/layer/wrapper/pipeline_wrapper.py +++ b/colossalai/legacy/nn/layer/wrapper/pipeline_wrapper.py @@ -8,9 +8,8 @@ class PipelineSharedModuleWrapper: - def __init__(self, pipeline_ranks: Union[List[int], Tuple[int]]) -> None: - assert len(pipeline_ranks) > 1, f'Expect len(pipeline_ranks) > 1, got {len(pipeline_ranks)}' + assert len(pipeline_ranks) > 1, f"Expect 
len(pipeline_ranks) > 1, got {len(pipeline_ranks)}" self.pipeline_ranks = pipeline_ranks self.group = None self.ranks_in_group = None @@ -33,16 +32,18 @@ def _init_group(self): self.ranks_in_group = sub_ranks def register_module(self, module: nn.Module): - assert self.ranks_in_group is not None,\ - f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}' + assert ( + self.ranks_in_group is not None + ), f"Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}" src = self.ranks_in_group[self.pipeline_ranks[0]] for p in module.parameters(): - setattr(p, 'pipeline_shared_module_pg', self.group) + setattr(p, "pipeline_shared_module_pg", self.group) dist.broadcast(p, src, group=self.group) def register_parameter(self, param: nn.Parameter): - assert self.ranks_in_group is not None,\ - f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}' + assert ( + self.ranks_in_group is not None + ), f"Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}" src = self.ranks_in_group[self.pipeline_ranks[0]] - setattr(param, 'pipeline_shared_module_pg', self.group) + setattr(param, "pipeline_shared_module_pg", self.group) dist.broadcast(param, src, group=self.group) diff --git a/colossalai/legacy/nn/loss/__init__.py b/colossalai/legacy/nn/loss/__init__.py index abb7ec3ef824..43e5a5a2e2aa 100644 --- a/colossalai/legacy/nn/loss/__init__.py +++ b/colossalai/legacy/nn/loss/__init__.py @@ -11,28 +11,27 @@ from .loss_3d import CrossEntropyLoss3D, VocabParallelCrossEntropyLoss3D _parallel_cross_entropy = { - '2d': CrossEntropyLoss2D, - '2.5d': CrossEntropyLoss2p5D, - '3d': CrossEntropyLoss3D, + "2d": CrossEntropyLoss2D, + "2.5d": CrossEntropyLoss2p5D, + "3d": CrossEntropyLoss3D, } _vocab_parallel_cross_entropy = { - '1d': VocabParallelCrossEntropyLoss1D, - '2d': VocabParallelCrossEntropyLoss2D, - '2.5d': VocabParallelCrossEntropyLoss2p5D, - '3d': VocabParallelCrossEntropyLoss3D, + "1d": VocabParallelCrossEntropyLoss1D, + "2d": VocabParallelCrossEntropyLoss2D, + "2.5d": VocabParallelCrossEntropyLoss2p5D, + "3d": VocabParallelCrossEntropyLoss3D, } class CrossEntropyLoss(_Loss): - def __init__(self, reduction: bool = True, *args, **kwargs): super().__init__() tensor_parallel = get_tensor_parallel_mode() if tensor_parallel is not None and env.vocab_parallel: self.loss = _vocab_parallel_cross_entropy[tensor_parallel](reduction=reduction, *args, **kwargs) - elif tensor_parallel is None or tensor_parallel == '1d': - reduction = 'mean' if reduction else 'none' + elif tensor_parallel is None or tensor_parallel == "1d": + reduction = "mean" if reduction else "none" self.loss = nn.CrossEntropyLoss(reduction=reduction, *args, **kwargs) else: self.loss = _parallel_cross_entropy[tensor_parallel](reduction=reduction, *args, **kwargs) diff --git a/colossalai/legacy/nn/loss/loss_1d.py b/colossalai/legacy/nn/loss/loss_1d.py index 2582e8b359d5..fae9c929b788 100644 --- a/colossalai/legacy/nn/loss/loss_1d.py +++ b/colossalai/legacy/nn/loss/loss_1d.py @@ -9,7 +9,6 @@ class _VocabParallelCrossEntropy1D(torch.autograd.Function): - @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward(ctx, vocab_parallel_logits, targets, process_group): @@ -61,7 +60,6 @@ def forward(ctx, vocab_parallel_logits, targets, process_group): @staticmethod @custom_bwd def backward(ctx, grad_output): - # Retrieve tensors from the forward path. 
softmax, target_mask, masked_target_1d = ctx.saved_tensors @@ -73,7 +71,7 @@ def backward(ctx, grad_output): # Add the gradient from matching classes. arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) - grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float()) + grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float() # Finally elementwise multiplication with the output gradients. grad_input.mul_(grad_output.unsqueeze(dim=-1)) diff --git a/colossalai/legacy/nn/loss/loss_2d.py b/colossalai/legacy/nn/loss/loss_2d.py index 7ab58415608a..44f39a6db262 100644 --- a/colossalai/legacy/nn/loss/loss_2d.py +++ b/colossalai/legacy/nn/loss/loss_2d.py @@ -50,7 +50,7 @@ def forward(self, logits, targets): float: the loss between logits and targets. """ targets = split_batch_2d(targets) - loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs) + loss = cross_entropy(logits, targets, reduction="none", *self.loss_args, **self.loss_kwargs) if self.reduction_mean: loss = loss.mean() loss = reduce_by_batch_2d(loss, True) @@ -69,9 +69,9 @@ def forward(ctx, logits, targets): # vocab_parallel_logits: [b/q, s, v/q] # target: [b/q, s] logits_max = torch.max(logits, dim=-1)[0] - torch.distributed.all_reduce(logits_max, - op=torch.distributed.ReduceOp.MAX, - group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) + torch.distributed.all_reduce( + logits_max, op=torch.distributed.ReduceOp.MAX, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW) + ) # Subtract the maximum value. # vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) logits = logits - logits_max.unsqueeze(dim=-1) @@ -90,7 +90,7 @@ def forward(ctx, logits, targets): end=logits.size()[0], ) predicted_logits = logits[arange_1d, masked_target] - predicted_logits[target_mask] = 0. + predicted_logits[target_mask] = 0.0 dist.all_reduce(predicted_logits, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) exp_logits = torch.exp(logits) @@ -119,7 +119,7 @@ def backward(ctx, output_grad): # Add the gradient from matching classes. arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) - grad_2d[arange_1d, masked_target] -= (1.0 - target_mask.view(-1).float()) + grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float() # Finally elementwise multiplication with the output gradients. grad_input.mul_(output_grad.unsqueeze(dim=-1)) diff --git a/colossalai/legacy/nn/loss/loss_2p5d.py b/colossalai/legacy/nn/loss/loss_2p5d.py index 8a5d04a8c788..c57bf26e9139 100644 --- a/colossalai/legacy/nn/loss/loss_2p5d.py +++ b/colossalai/legacy/nn/loss/loss_2p5d.py @@ -47,7 +47,7 @@ def forward(self, logits, targets): targets (:class:`torch.tensor`): Ground truth class indices or class probabilities. """ targets = split_batch_2p5d(targets) - loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs) + loss = cross_entropy(logits, targets, reduction="none", *self.loss_args, **self.loss_kwargs) if self.reduction_mean: loss = loss.mean() loss = reduce_by_batch_2p5d(loss, True) @@ -64,9 +64,9 @@ def forward(ctx, logits, targets): # loss: [b/dq] # targets: [b/dq, h/q] logits_max = torch.max(logits, dim=-1)[0] - torch.distributed.all_reduce(logits_max, - op=torch.distributed.ReduceOp.MAX, - group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW)) + torch.distributed.all_reduce( + logits_max, op=torch.distributed.ReduceOp.MAX, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW) + ) # Subtract the maximum value. 
logits = logits - logits_max.unsqueeze(dim=-1) @@ -84,7 +84,7 @@ def forward(ctx, logits, targets): end=logits.size()[0], ) predicted_logits = logits[arange_1d, masked_target] - predicted_logits[target_mask] = 0. + predicted_logits[target_mask] = 0.0 dist.all_reduce(predicted_logits, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW)) exp_logits = torch.exp(logits) @@ -113,7 +113,7 @@ def backward(ctx, output_grad): # Add the gradient from matching classes. arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) - grad_2d[arange_1d, masked_target] -= (1.0 - target_mask.view(-1).float()) + grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float() # Finally elementwise multiplication with the output gradients. grad_input.mul_(output_grad.unsqueeze(dim=-1)) diff --git a/colossalai/legacy/nn/loss/loss_3d.py b/colossalai/legacy/nn/loss/loss_3d.py index a576d84f71cd..988317cae3eb 100644 --- a/colossalai/legacy/nn/loss/loss_3d.py +++ b/colossalai/legacy/nn/loss/loss_3d.py @@ -49,7 +49,7 @@ def forward(self, logits, targets): """ targets = split_tensor_3d(targets, 0, self.weight_parallel_mode) targets = split_tensor_3d(targets, 0, self.input_parallel_mode) - loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs) + loss = cross_entropy(logits, targets, reduction="none", *self.loss_args, **self.loss_kwargs) if self.reduction_mean: loss = loss.mean() loss = reduce_by_batch_3d(loss, self.input_parallel_mode, self.weight_parallel_mode, True) @@ -83,7 +83,7 @@ def forward(ctx, logits, targets, output_parallel_mode): arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_current_device()) predicted_logits = logits[arange_1d, masked_target] predicted_logits = predicted_logits.clone().contiguous().view_as(targets) - predicted_logits[target_mask] = 0. + predicted_logits[target_mask] = 0.0 dist.all_reduce(predicted_logits, group=gpc.get_group(output_parallel_mode)) # Loss = log(sum(exp(logits))) - predicted-logit. @@ -111,7 +111,7 @@ def backward(ctx, output_grad): # Add the gradient from matching classes. 
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) - grad_2d[arange_1d, masked_target] -= (1.0 - target_mask.view(-1).float()) + grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float() input_grad.mul_(output_grad.unsqueeze(dim=-1)) return input_grad, None, None, None diff --git a/colossalai/legacy/nn/metric/__init__.py b/colossalai/legacy/nn/metric/__init__.py index 76c6dac89c5b..cc2b2c5d0254 100644 --- a/colossalai/legacy/nn/metric/__init__.py +++ b/colossalai/legacy/nn/metric/__init__.py @@ -8,14 +8,13 @@ from .accuracy_3d import Accuracy3D _parallel_accuracy = { - '2d': Accuracy2D, - '2.5d': Accuracy2p5D, - '3d': Accuracy3D, + "2d": Accuracy2D, + "2.5d": Accuracy2p5D, + "3d": Accuracy3D, } class Accuracy(nn.Module): - def __init__(self): super().__init__() tensor_parallel = get_tensor_parallel_mode() diff --git a/colossalai/legacy/nn/metric/accuracy_2d.py b/colossalai/legacy/nn/metric/accuracy_2d.py index 838c48834a96..59ddd5d66e20 100644 --- a/colossalai/legacy/nn/metric/accuracy_2d.py +++ b/colossalai/legacy/nn/metric/accuracy_2d.py @@ -7,8 +7,7 @@ class Accuracy2D(nn.Module): - """Accuracy for 2D parallelism - """ + """Accuracy for 2D parallelism""" def __init__(self): super().__init__() diff --git a/colossalai/legacy/nn/metric/accuracy_2p5d.py b/colossalai/legacy/nn/metric/accuracy_2p5d.py index 183380cd9846..948eae989d48 100644 --- a/colossalai/legacy/nn/metric/accuracy_2p5d.py +++ b/colossalai/legacy/nn/metric/accuracy_2p5d.py @@ -7,8 +7,7 @@ class Accuracy2p5D(nn.Module): - """Accuracy for 2p5D parallelism - """ + """Accuracy for 2p5D parallelism""" def __init__(self): super().__init__() diff --git a/colossalai/legacy/nn/metric/accuracy_3d.py b/colossalai/legacy/nn/metric/accuracy_3d.py index 675f5c2b5120..aee6118413ef 100644 --- a/colossalai/legacy/nn/metric/accuracy_3d.py +++ b/colossalai/legacy/nn/metric/accuracy_3d.py @@ -9,8 +9,7 @@ class Accuracy3D(nn.Module): - """Accuracy for 3D parallelism - """ + """Accuracy for 3D parallelism""" def __init__(self): super().__init__() @@ -26,7 +25,7 @@ def forward(self, logits, targets): Returns: float: the accuracy of prediction. - """ + """ with torch.no_grad(): targets = split_tensor_3d(targets, 0, self.weight_parallel_mode) targets = split_tensor_3d(targets, 0, self.input_parallel_mode) diff --git a/colossalai/legacy/nn/parallel/__init__.py b/colossalai/legacy/nn/parallel/__init__.py index 17e010f478c9..19ad8404de18 100644 --- a/colossalai/legacy/nn/parallel/__init__.py +++ b/colossalai/legacy/nn/parallel/__init__.py @@ -1,5 +1,5 @@ from .data_parallel import ColoDDP __all__ = [ - 'ColoDDP', + "ColoDDP", ] diff --git a/colossalai/legacy/nn/parallel/data_parallel.py b/colossalai/legacy/nn/parallel/data_parallel.py index 2b2ad36a74f4..9634cb46a12a 100644 --- a/colossalai/legacy/nn/parallel/data_parallel.py +++ b/colossalai/legacy/nn/parallel/data_parallel.py @@ -49,11 +49,13 @@ class ColoDDP(torch.nn.Module): If it's None, the default data parallel group will be used. Defaults to None. 
""" - def __init__(self, - module: torch.nn.Module, - process_group: ColoProcessGroup, - bucket_cap_mb: int = 25, - rebuild_bucket: bool = True) -> None: + def __init__( + self, + module: torch.nn.Module, + process_group: ColoProcessGroup, + bucket_cap_mb: int = 25, + rebuild_bucket: bool = True, + ) -> None: assert not isinstance(module, ColoDDP) super().__init__() self.module = module @@ -74,19 +76,18 @@ def __init__(self, def parameters(self, recurse: bool = True): return self.module.parameters(recurse) - def named_parameters(self, prefix: str = '', recurse: bool = True): + def named_parameters(self, prefix: str = "", recurse: bool = True): return self.module.named_parameters(prefix, recurse) - def named_buffers(self, prefix: str = '', recurse: bool = True): + def named_buffers(self, prefix: str = "", recurse: bool = True): return self.module.named_buffers(prefix, recurse) def named_children(self): return self.module.named_children() - def named_modules(self, - memo: Optional[Set[torch.nn.Module]] = None, - prefix: str = '', - remove_duplicate: bool = True): + def named_modules( + self, memo: Optional[Set[torch.nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True + ): return self.module.named_modules(memo, prefix, remove_duplicate) def forward(self, *args, **kwargs): @@ -114,9 +115,9 @@ def grad_handle(self, p, grad): grad = grad / self.dp_world_size self.comm_stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(self.comm_stream): - self.reducer.all_reduce_async(grad, - group=self.process_group.dp_process_group(), - callback_fn=partial(self._save_grad, p)) + self.reducer.all_reduce_async( + grad, group=self.process_group.dp_process_group(), callback_fn=partial(self._save_grad, p) + ) grad.record_stream(self.comm_stream) else: ColoDDP._save_grad(p, grad) @@ -130,7 +131,7 @@ def grad_handle(self, p, grad): @staticmethod def _save_grad(p, grad): - if hasattr(p, '_saved_grad'): + if hasattr(p, "_saved_grad"): p._saved_grad.add_(grad) else: p._saved_grad = grad @@ -138,7 +139,7 @@ def _save_grad(p, grad): def zero_grad(self, set_to_none: bool = False) -> None: self.module.zero_grad(set_to_none=True) for p in self.module.parameters(): - if getattr(p, '_saved_grad', None) is not None: + if getattr(p, "_saved_grad", None) is not None: if set_to_none: p._saved_grad = None else: @@ -167,8 +168,8 @@ def set_params_to_ignore(params_to_ignore: Iterable[torch.Tensor]) -> None: for p in params_to_ignore: p._ddp_to_ignore = True - def state_dict(self, destination=None, prefix='', keep_vars=False): + def state_dict(self, destination=None, prefix="", keep_vars=False): return self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) - def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True): + def load_state_dict(self, state_dict: "OrderedDict[str, torch.Tensor]", strict: bool = True): return self.module.load_state_dict(state_dict, strict) diff --git a/colossalai/legacy/nn/parallel/layers/__init__.py b/colossalai/legacy/nn/parallel/layers/__init__.py index f38124efedf7..2663076c6992 100644 --- a/colossalai/legacy/nn/parallel/layers/__init__.py +++ b/colossalai/legacy/nn/parallel/layers/__init__.py @@ -14,8 +14,20 @@ from .module_utils import check_colo_module, get_colo_module, init_colo_module, is_colo_module, register_colo_module __all__ = [ - 'ColoModule', 'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module', - 'ColoLinear', 'ColoEmbedding', 'CachedEmbeddingBag', 
'ParallelCachedEmbeddingBag', 'CachedParamMgr', - 'LimitBuffIndexCopyer', 'EvictionStrategy', 'ParallelCachedEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig', - 'ParallelCachedEmbeddingBagTablewiseSpiltCache' + "ColoModule", + "register_colo_module", + "is_colo_module", + "get_colo_module", + "init_colo_module", + "check_colo_module", + "ColoLinear", + "ColoEmbedding", + "CachedEmbeddingBag", + "ParallelCachedEmbeddingBag", + "CachedParamMgr", + "LimitBuffIndexCopyer", + "EvictionStrategy", + "ParallelCachedEmbeddingBagTablewise", + "TablewiseEmbeddingBagConfig", + "ParallelCachedEmbeddingBagTablewiseSpiltCache", ] diff --git a/colossalai/legacy/nn/parallel/layers/cache_embedding/__init__.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/__init__.py index d87930c1c6b3..aad6dcc5d7d8 100644 --- a/colossalai/legacy/nn/parallel/layers/cache_embedding/__init__.py +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/__init__.py @@ -7,7 +7,12 @@ from .parallel_cached_embedding_tablewise_split_cache import ParallelCachedEmbeddingBagTablewiseSpiltCache __all__ = [ - 'CachedParamMgr', 'LimitBuffIndexCopyer', 'CachedEmbeddingBag', 'ParallelCachedEmbeddingBag', 'EvictionStrategy', - 'ParallelCachedEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig', - 'ParallelCachedEmbeddingBagTablewiseSpiltCache' + "CachedParamMgr", + "LimitBuffIndexCopyer", + "CachedEmbeddingBag", + "ParallelCachedEmbeddingBag", + "EvictionStrategy", + "ParallelCachedEmbeddingBagTablewise", + "TablewiseEmbeddingBagConfig", + "ParallelCachedEmbeddingBagTablewiseSpiltCache", ] diff --git a/colossalai/legacy/nn/parallel/layers/cache_embedding/base_embedding.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/base_embedding.py index 9558c541e703..3f825f11fe51 100644 --- a/colossalai/legacy/nn/parallel/layers/cache_embedding/base_embedding.py +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/base_embedding.py @@ -4,17 +4,16 @@ class BaseEmbeddingBag(abc.ABC, nn.Module): - def __init__( self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None, - norm_type=2., + norm_type=2.0, scale_grad_by_freq=False, sparse=False, - mode='mean', + mode="mean", include_last_offset=False, ): super(BaseEmbeddingBag, self).__init__() @@ -22,9 +21,9 @@ def __init__( self.embedding_dim = embedding_dim if padding_idx is not None: if padding_idx > 0: - assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert padding_idx < self.num_embeddings, "Padding_idx must be within num_embeddings" elif padding_idx < 0: - assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert padding_idx >= -self.num_embeddings, "Padding_idx must be within num_embeddings" padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx self.max_norm = max_norm diff --git a/colossalai/legacy/nn/parallel/layers/cache_embedding/cache_mgr.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/cache_mgr.py index 16530c4ce7b8..e23864071e66 100644 --- a/colossalai/legacy/nn/parallel/layers/cache_embedding/cache_mgr.py +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/cache_mgr.py @@ -83,15 +83,16 @@ def __init__( if self._async_copy: self._memcpy_stream = torch.cuda.Stream() - print('use async copy') + print("use async copy") if self._evict_strategy == EvictionStrategy.LFU: # cache_row_idx -> frequency, freq of the cache rows. # classic lfu cache. evict the minimal freq value row in cuda cache. 
- self.register_buffer("freq_cnter", - torch.empty(self.cuda_row_num, device=torch.cuda.current_device(), - dtype=torch.long).fill_(sys.maxsize), - persistent=False) + self.register_buffer( + "freq_cnter", + torch.empty(self.cuda_row_num, device=torch.cuda.current_device(), dtype=torch.long).fill_(sys.maxsize), + persistent=False, + ) self._elapsed_dict = {} self._show_cache_miss = True self._reset_comm_stats() @@ -142,10 +143,10 @@ def _init_weight(self, weight): if self.cuda_row_num > 0: # Enable cache with introducing auxiliary data structures self.cuda_cached_weight = torch.nn.Parameter( - torch.zeros(self.cuda_row_num, - self.embedding_dim, - device=torch.cuda.current_device(), - dtype=weight.dtype)) + torch.zeros( + self.cuda_row_num, self.embedding_dim, device=torch.cuda.current_device(), dtype=weight.dtype + ) + ) # pin memory cpu for higher CPU-GPU copy bandwidth self.weight = weight.pin_memory() if self.pin_weight else weight @@ -158,17 +159,19 @@ def _init_weight(self, weight): ) # cached_idx_map: gpu_row_idx -> cpu_row_idx - self.register_buffer("cached_idx_map", - torch.empty(self.cuda_row_num, device=torch.cuda.current_device(), - dtype=torch.long).fill_(-1), - persistent=False) + self.register_buffer( + "cached_idx_map", + torch.empty(self.cuda_row_num, device=torch.cuda.current_device(), dtype=torch.long).fill_(-1), + persistent=False, + ) # cpu_row_id -> gpu_row_idx. # gpu_row_idx as -1 means cpu_row_id not in CUDA. - self.register_buffer("inverted_cached_idx", - torch.zeros(self.num_embeddings, device=torch.cuda.current_device(), - dtype=torch.long).fill_(-1), - persistent=False) + self.register_buffer( + "inverted_cached_idx", + torch.zeros(self.num_embeddings, device=torch.cuda.current_device(), dtype=torch.long).fill_(-1), + persistent=False, + ) self.evict_backlist = torch.tensor([], device=torch.cuda.current_device()) @@ -191,9 +194,11 @@ def cpu_weight_data(self, row_idx: int) -> torch.Tensor: torch.Tensor: a piece of memory in CPU weight corresponding to row id's payload. The tensor is 1-D. 
""" - return self.weight.data.view(-1).narrow(0, - int(row_idx) * self.embedding_dim, - self.embedding_dim).view(1, self.embedding_dim) + return ( + self.weight.data.view(-1) + .narrow(0, int(row_idx) * self.embedding_dim, self.embedding_dim) + .view(1, self.embedding_dim) + ) @property def cuda_available_row_num(self): @@ -238,15 +243,18 @@ def reorder(self, ids_freq_mapping: Optional[List[int]] = None, warmup_ratio=0.7 preload_cpu_ids = torch.arange(preload_row_num) preload_cuda_row_idxs = preload_cpu_ids.cuda() if self.buffer_size > 0: - self.limit_buff_index_copyer.index_copy(0, - src_index=preload_cpu_ids, - tgt_index=preload_cuda_row_idxs, - src=self.weight.view(self.num_embeddings, -1), - tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1)) + self.limit_buff_index_copyer.index_copy( + 0, + src_index=preload_cpu_ids, + tgt_index=preload_cuda_row_idxs, + src=self.weight.view(self.num_embeddings, -1), + tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1), + ) else: preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_cpu_ids).cuda() - self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_cuda_row_idxs, - preload_rows) + self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_( + 0, preload_cuda_row_idxs, preload_rows + ) # update auxiliary info self.cached_idx_map[preload_cuda_row_idxs] = preload_cpu_ids.cuda() @@ -260,7 +268,7 @@ def reorder(self, ids_freq_mapping: Optional[List[int]] = None, warmup_ratio=0.7 else: self.freq_cnter[preload_cuda_row_idxs] = freq_value.cuda() - print(f'Cache warmup finished cost {timer.elapsed} sec.') + print(f"Cache warmup finished cost {timer.elapsed} sec.") def flush(self): """flush all CUDA rows to CPU. @@ -290,18 +298,18 @@ def print_comm_stats(self): print( f"CUDA->CPU BWD {self._cuda_to_cpu_numel * self.elem_size_in_byte / 1e6 / elapsed} MB/s {self._cuda_to_cpu_numel / 1e6} M elem" ) - print(f'cuda_to_cpu_elapse {elapsed} sec') + print(f"cuda_to_cpu_elapse {elapsed} sec") if self._cpu_to_cuda_numel > 0 and "5_evict_in" in self._elapsed_dict: elapsed = self._elapsed_dict["5_evict_in"] print( f"CPU->CUDA BWD {self._cpu_to_cuda_numel * self.elem_size_in_byte / 1e6 / elapsed} MB/s {self._cpu_to_cuda_numel / 1e6} M elem" ) - print(f'cpu_to_cuda_elapse {elapsed} sec') + print(f"cpu_to_cuda_elapse {elapsed} sec") for k, v in self._elapsed_dict.items(): - print(f'{k}: {v}') + print(f"{k}: {v}") - print(f'cache miss ratio {self._cache_miss / self._total_cache}') + print(f"cache miss ratio {self._cache_miss / self._total_cache}") @torch.no_grad() def _id_to_cached_cuda_id(self, ids: torch.Tensor) -> torch.Tensor: @@ -336,10 +344,11 @@ def prepare_ids(self, ids: torch.Tensor) -> torch.Tensor: else: cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts=True) - assert len(cpu_row_idxs) <= self.cuda_row_num, \ - f"You move {len(cpu_row_idxs)} embedding rows from CPU to CUDA. " \ - f"It is larger than the capacity of the cache, which at most contains {self.cuda_row_num} rows, " \ + assert len(cpu_row_idxs) <= self.cuda_row_num, ( + f"You move {len(cpu_row_idxs)} embedding rows from CPU to CUDA. " + f"It is larger than the capacity of the cache, which at most contains {self.cuda_row_num} rows, " f"Please increase cuda_row_num or decrease the training batch size." 
+ ) self.evict_backlist = cpu_row_idxs tmp = torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True) comm_cpu_row_idxs = cpu_row_idxs[tmp] @@ -386,8 +395,9 @@ def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None: # move evict in rows to gpu if self._async_copy: if self.buffer_size == 0: - evict_in_rows_gpu = self.weight.view(self.num_embeddings, - -1).index_select(0, cpu_row_idxs_copy).pin_memory() + evict_in_rows_gpu = ( + self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs_copy).pin_memory() + ) with torch.cuda.stream(self._memcpy_stream): evict_in_rows_gpu = evict_in_rows_gpu.to(torch.cuda.current_device(), non_blocking=True) else: @@ -409,9 +419,10 @@ def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None: # move evict out rows to cpu if self._async_copy: - evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num, - -1).index_select(0, evict_gpu_row_idxs) - evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True) + evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select( + 0, evict_gpu_row_idxs + ) + evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device="cpu", pin_memory=True) with torch.cuda.stream(None): evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True) self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs) @@ -425,9 +436,10 @@ def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None: evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num) if self._async_copy: - evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num, - -1).index_select(0, evict_gpu_row_idxs) - evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True) + evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select( + 0, evict_gpu_row_idxs + ) + evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device="cpu", pin_memory=True) with torch.cuda.stream(None): evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True) @@ -438,11 +450,13 @@ def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None: with self.timer("3_evict_out") as timer: if self.buffer_size > 0: - self.limit_buff_index_copyer.index_copy(0, - src_index=evict_gpu_row_idxs, - tgt_index=evict_info.cpu(), - src=self.cuda_cached_weight.view(self.cuda_row_num, -1), - tgt=self.weight.view(self.num_embeddings, -1)) + self.limit_buff_index_copyer.index_copy( + 0, + src_index=evict_gpu_row_idxs, + tgt_index=evict_info.cpu(), + src=self.cuda_cached_weight.view(self.cuda_row_num, -1), + tgt=self.weight.view(self.num_embeddings, -1), + ) else: # allocate tmp memory on CPU and copy rows on CUDA to CPU. 
# TODO async gpu -> cpu @@ -450,8 +464,9 @@ def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None: _wait_for_data(evict_out_rows_cpu, None) else: with self.timer("3_1_evict_out_index_select") as timer: - evict_out_rows_cpu = self.cuda_cached_weight.view(self.cuda_row_num, - -1).index_select(0, evict_gpu_row_idxs) + evict_out_rows_cpu = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select( + 0, evict_gpu_row_idxs + ) with self.timer("3_2_evict_out_gpu_to_cpu_copy") as timer: evict_out_rows_cpu = evict_out_rows_cpu.cpu() @@ -469,17 +484,19 @@ def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None: # slots of cuda weight to evict in with self.timer("4_identify_cuda_slot") as timer: - slots = torch.nonzero(self.cached_idx_map == -1).squeeze(1)[:cpu_row_idxs.numel()] + slots = torch.nonzero(self.cached_idx_map == -1).squeeze(1)[: cpu_row_idxs.numel()] # TODO wait for optimize with self.timer("5_evict_in") as timer: # Here also allocate extra memory on CUDA. #cpu_row_idxs if self.buffer_size > 0: - self.limit_buff_index_copyer.index_copy(0, - src_index=cpu_row_idxs_copy, - tgt_index=slots, - src=self.weight.view(self.num_embeddings, -1), - tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1)) + self.limit_buff_index_copyer.index_copy( + 0, + src_index=cpu_row_idxs_copy, + tgt_index=slots, + src=self.weight.view(self.num_embeddings, -1), + tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1), + ) else: if self._async_copy: _wait_for_data(evict_in_rows_gpu, self._memcpy_stream) @@ -488,8 +505,9 @@ def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None: # narrow index select to a subset of self.weight # tmp = torch.narrow(self.weight.view(self.num_embeddings, -1), 0, min(cpu_row_idxs).cpu(), max(cpu_row_idxs) - min(cpu_row_idxs) + 1) # evict_in_rows_gpu = tmp.index_select(0, cpu_row_idxs_copy - min(cpu_row_idxs).cpu()) - evict_in_rows_gpu = self.weight.view(self.num_embeddings, - -1).index_select(0, cpu_row_idxs_copy).pin_memory() + evict_in_rows_gpu = ( + self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs_copy).pin_memory() + ) with self.timer("5_2_evict_in_gpu_to_cpu_copy") as timer: evict_in_rows_gpu = evict_in_rows_gpu.cuda() @@ -537,8 +555,9 @@ def _evict(self) -> int: self.cached_idx_map.index_copy_(0, idx, buf) with Timer() as timer: - cuda_tensor = torch.narrow(self.cuda_cached_weight.view(-1), 0, max_offset * self.embedding_dim, - self.embedding_dim).view(1, self.embedding_dim) + cuda_tensor = torch.narrow( + self.cuda_cached_weight.view(-1), 0, max_offset * self.embedding_dim, self.embedding_dim + ).view(1, self.embedding_dim) self.cpu_weight_data(max_gpu_row_idx).data.copy_(cuda_tensor) # update inverted_cached_idx, min_slot_id is evicted from cuda @@ -570,8 +589,9 @@ def _admit(self, row_id: int): slot_offset = slot_id # copy payload from cpu to cuda with Timer() as timer: - cuda_tensor = torch.narrow(self.cuda_cached_weight.view(-1), 0, slot_offset * self.embedding_dim, - self.embedding_dim).view(1, self.embedding_dim) + cuda_tensor = torch.narrow( + self.cuda_cached_weight.view(-1), 0, slot_offset * self.embedding_dim, self.embedding_dim + ).view(1, self.embedding_dim) cuda_tensor.data.copy_(self.cpu_weight_data(row_id)) # update the inverted_cached_idx diff --git a/colossalai/legacy/nn/parallel/layers/cache_embedding/cached_embedding.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/cached_embedding.py index bc7d178906da..03667857b1ac 100644 --- 
a/colossalai/legacy/nn/parallel/layers/cache_embedding/cached_embedding.py +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/cached_embedding.py @@ -36,27 +36,38 @@ class CachedEmbeddingBag(BaseEmbeddingBag): evict_strategy (EvictionStrategy, optional): evict strategy of the software cache. Defaults to EvictionStrategy.DATASET. """ - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - max_norm: float = None, - norm_type: float = 2., - scale_grad_by_freq: bool = False, - sparse: bool = False, - _weight: Optional[torch.Tensor] = None, - mode: str = 'mean', - include_last_offset: bool = False, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, - cache_ratio: float = 0.01, - ids_freq_mapping: Optional[Union[List, torch.Tensor]] = None, - warmup_ratio: float = 0.7, - buffer_size: int = 0, - pin_weight: bool = False, - evict_strategy: EvictionStrategy = EvictionStrategy.LFU): - super(CachedEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, - scale_grad_by_freq, sparse, mode, include_last_offset) + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + max_norm: float = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + _weight: Optional[torch.Tensor] = None, + mode: str = "mean", + include_last_offset: bool = False, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + cache_ratio: float = 0.01, + ids_freq_mapping: Optional[Union[List, torch.Tensor]] = None, + warmup_ratio: float = 0.7, + buffer_size: int = 0, + pin_weight: bool = False, + evict_strategy: EvictionStrategy = EvictionStrategy.LFU, + ): + super(CachedEmbeddingBag, self).__init__( + num_embeddings, + embedding_dim, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse, + mode, + include_last_offset, + ) assert cache_ratio <= 1.0, f"cache ratio {cache_ratio} must less than 1.0" self.evict_strategy = evict_strategy @@ -78,13 +89,15 @@ def _weight_alloc(self, dtype, device): weight[self.padding_idx].fill_(0) return weight - def _preprocess(self, - weight, - cuda_row_num: int, - ids_freq_mapping: Optional[List[int]] = None, - warmup_ratio=0.7, - buffer_size=50_000, - pin_weight=False): + def _preprocess( + self, + weight, + cuda_row_num: int, + ids_freq_mapping: Optional[List[int]] = None, + warmup_ratio=0.7, + buffer_size=50_000, + pin_weight=False, + ): """ Called after initialized. Reorder the weight rows according to the ids_freq_mapping. 
@@ -95,11 +108,9 @@ def _preprocess(self, ids_freq_mapping (List[int]): a list, idx is id number, value is freq warmup_ratio (float): the amount of rows preloaded in cuda cache """ - self.cache_weight_mgr = CachedParamMgr(weight, - cuda_row_num, - buffer_size, - pin_weight, - evict_strategy=self.evict_strategy) + self.cache_weight_mgr = CachedParamMgr( + weight, cuda_row_num, buffer_size, pin_weight, evict_strategy=self.evict_strategy + ) self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio) def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None): @@ -107,9 +118,19 @@ def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None) with torch.no_grad(): input = self.cache_weight_mgr.prepare_ids(input) - embeddings = F.embedding_bag(input.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm, - self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse, - per_sample_weights, self.include_last_offset, self.padding_idx) + embeddings = F.embedding_bag( + input.cuda(), + self.cache_weight_mgr.cuda_cached_weight, + offsets, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.mode, + self.sparse, + per_sample_weights, + self.include_last_offset, + self.padding_idx, + ) if shape_hook is not None: embeddings = shape_hook(embeddings) return embeddings @@ -118,8 +139,8 @@ def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None) def weight(self): return self.cache_weight_mgr.weight - def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]: - yield 'weight', self.cache_weight_mgr.cuda_cached_weight + def named_parameters(self, prefix: str = "", recurse: bool = True) -> Iterator[Tuple[str, Parameter]]: + yield "weight", self.cache_weight_mgr.cuda_cached_weight def parameters(self, recurse: bool = True) -> Iterator[Parameter]: yield self.cache_weight_mgr.cuda_cached_weight @@ -127,8 +148,7 @@ def parameters(self, recurse: bool = True) -> Iterator[Parameter]: def set_cache_op(self, cache_op: bool = True): self.cache_op = cache_op - -############################# Perf Log ################################### + ############################# Perf Log ################################### @property def num_hits_history(self): @@ -145,14 +165,22 @@ def num_write_back_history(self): @property def swap_in_bandwidth(self): if self.cache_weight_mgr._cpu_to_cuda_numel > 0: - return self.cache_weight_mgr._cpu_to_cuda_numel * self.cache_weight_mgr.elem_size_in_byte / 1e6 / \ - self.cache_weight_mgr._cpu_to_cuda_elapse + return ( + self.cache_weight_mgr._cpu_to_cuda_numel + * self.cache_weight_mgr.elem_size_in_byte + / 1e6 + / self.cache_weight_mgr._cpu_to_cuda_elapse + ) else: return 0 @property def swap_out_bandwidth(self): if self.cache_weight_mgr._cuda_to_cpu_numel > 0: - return self.cache_weight_mgr._cuda_to_cpu_numel * self.cache_weight_mgr.elem_size_in_byte / 1e6 / \ - self.cache_weight_mgr._cuda_to_cpu_elapse + return ( + self.cache_weight_mgr._cuda_to_cpu_numel + * self.cache_weight_mgr.elem_size_in_byte + / 1e6 + / self.cache_weight_mgr._cuda_to_cpu_elapse + ) return 0 diff --git a/colossalai/legacy/nn/parallel/layers/cache_embedding/copyer.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/copyer.py index 804a07f88207..5e3a8df05cfe 100644 --- a/colossalai/legacy/nn/parallel/layers/cache_embedding/copyer.py +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/copyer.py @@ -39,7 +39,7 @@ def index_copy(self, dim: int, src_index: LongTensor, 
tgt_index: LongTensor, src for begin_pos in range(0, dim_size, self._buff_size): cur_len = min(self._buff_size, dim_size - begin_pos) src_idx_piece = src_index.narrow(0, begin_pos, cur_len) - if src_device.type == 'cpu' and tgt_device.type == 'cuda': + if src_device.type == "cpu" and tgt_device.type == "cuda": cpu_tmp_buffer = src.index_select(dim, src_idx_piece).pin_memory() tmp_buffer = torch.empty_like(cpu_tmp_buffer, device=tgt_device) tmp_buffer.copy_(cpu_tmp_buffer) diff --git a/colossalai/legacy/nn/parallel/layers/cache_embedding/embedding_config.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/embedding_config.py index 36e04c833feb..ceaa9081c724 100644 --- a/colossalai/legacy/nn/parallel/layers/cache_embedding/embedding_config.py +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/embedding_config.py @@ -2,22 +2,24 @@ class TablewiseEmbeddingBagConfig: - ''' + """ example: def prepare_tablewise_config(args, cache_ratio, ...): embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = [] ... return embedding_bag_config_list - ''' + """ - def __init__(self, - num_embeddings: int, - cuda_row_num: int, - assigned_rank: int = 0, - buffer_size=50_000, - ids_freq_mapping=None, - initial_weight: torch.tensor = None, - name: str = ""): + def __init__( + self, + num_embeddings: int, + cuda_row_num: int, + assigned_rank: int = 0, + buffer_size=50_000, + ids_freq_mapping=None, + initial_weight: torch.tensor = None, + name: str = "", + ): self.num_embeddings = num_embeddings self.cuda_row_num = cuda_row_num self.assigned_rank = assigned_rank diff --git a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py index 522fb4f4497f..ee739935fef2 100644 --- a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py @@ -1,13 +1,13 @@ -from typing import Iterator, List, Optional, Tuple +from typing import List, Optional, Tuple import torch import torch.nn.functional as F from colossalai.legacy.nn._ops._utils import dual_all_to_all from colossalai.legacy.tensor import ColoTensorSpec, ComputePattern, ProcessGroup, ShardSpec -from colossalai.tensor import ColoParameter, ColoTensor +from colossalai.tensor import ColoTensor -from .cache_mgr import CachedParamMgr, EvictionStrategy +from .cache_mgr import EvictionStrategy from .cached_embedding import CachedEmbeddingBag @@ -15,9 +15,9 @@ def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]: if world_size == 1: return 0, embedding_dim, True - assert embedding_dim >= world_size, \ - f"Embedding dimension {embedding_dim} must be larger than the world size " \ - f"{world_size} of the process group" + assert embedding_dim >= world_size, ( + f"Embedding dimension {embedding_dim} must be larger than the world size " f"{world_size} of the process group" + ) chunk_size = embedding_dim // world_size threshold = embedding_dim % world_size # if embedding dim is divisible by world size @@ -31,37 +31,55 @@ def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]: class ParallelCachedEmbeddingBag(CachedEmbeddingBag): - - def __init__(self, - num_embeddings, - embedding_dim, - padding_idx=None, - max_norm=None, - norm_type=2., - scale_grad_by_freq=False, - sparse=False, - _weight=None, - mode='mean', - include_last_offset=False, - dtype=None, - device=None, - cache_ratio=0.01, - 
ids_freq_mapping=None, - warmup_ratio=0.7, - buffer_size=50_000, - pin_weight=False, - evict_strategy: EvictionStrategy = EvictionStrategy.DATASET): + def __init__( + self, + num_embeddings, + embedding_dim, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + _weight=None, + mode="mean", + include_last_offset=False, + dtype=None, + device=None, + cache_ratio=0.01, + ids_freq_mapping=None, + warmup_ratio=0.7, + buffer_size=50_000, + pin_weight=False, + evict_strategy: EvictionStrategy = EvictionStrategy.DATASET, + ): self.rank = torch.distributed.get_rank() self.world_size = torch.distributed.get_world_size() self.partition_start_index, self.partition_end_index, divisible = get_partition( - embedding_dim, self.rank, self.world_size) + embedding_dim, self.rank, self.world_size + ) self.embedding_dim_per_partition = self.partition_end_index - self.partition_start_index - super(ParallelCachedEmbeddingBag, - self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, - sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping, - warmup_ratio, buffer_size, pin_weight, evict_strategy) + super(ParallelCachedEmbeddingBag, self).__init__( + num_embeddings, + embedding_dim, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse, + _weight, + mode, + include_last_offset, + dtype, + device, + cache_ratio, + ids_freq_mapping, + warmup_ratio, + buffer_size, + pin_weight, + evict_strategy, + ) self.cache_op = True def _weight_alloc(self, dtype, device): @@ -70,9 +88,11 @@ def _weight_alloc(self, dtype, device): weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings) if self.padding_idx is not None: weight[self.padding_idx].fill_(0) - colo_tensor_spec = ColoTensorSpec(pg=ProcessGroup(tp_degree=self.world_size), - dist_attr=ShardSpec(dims=[-1], num_partitions=[self.world_size]), - compute_attr=ComputePattern.TP1D) + colo_tensor_spec = ColoTensorSpec( + pg=ProcessGroup(tp_degree=self.world_size), + dist_attr=ShardSpec(dims=[-1], num_partitions=[self.world_size]), + compute_attr=ComputePattern.TP1D, + ) return ColoTensor.from_torch_tensor(weight, spec=colo_tensor_spec) def forward( @@ -87,15 +107,24 @@ def forward( if self.cache_op: with torch.no_grad(): indices = self.cache_weight_mgr.prepare_ids(indices) - output_shard = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm, - self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse, - per_sample_weights, self.include_last_offset, self.padding_idx) + output_shard = F.embedding_bag( + indices.cuda(), + self.cache_weight_mgr.cuda_cached_weight, + offsets, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.mode, + self.sparse, + per_sample_weights, + self.include_last_offset, + self.padding_idx, + ) if shape_hook is not None: output_shard = shape_hook(output_shard) - output_full = dual_all_to_all(output_shard, - self.weight.get_process_group(), - scatter_dim=scatter_dim, - gather_dim=gather_dim) + output_full = dual_all_to_all( + output_shard, self.weight.get_process_group(), scatter_dim=scatter_dim, gather_dim=gather_dim + ) return output_full def set_cache_op(self, cache_op: bool = True): @@ -108,31 +137,33 @@ def from_pretrained( freeze: bool = True, padding_idx: Optional[int] = None, max_norm: Optional[float] = None, - norm_type: float = 2., + norm_type: float = 2.0, scale_grad_by_freq: bool = False, sparse: bool = False, - mode: str = 
'mean', + mode: str = "mean", include_last_offset: bool = False, cuda_row_num: int = 100_000, ids_freq_mapping: Optional[List[int]] = None, warmup_ratio: float = 0.7, buffer_size: int = 0, - ) -> 'ParallelCachedEmbeddingBag': + ) -> "ParallelCachedEmbeddingBag": rows, cols = embedding.shape - embedding_bag = cls(rows, - cols, - padding_idx, - max_norm, - norm_type, - scale_grad_by_freq, - sparse, - embedding, - mode, - include_last_offset, - cuda_row_num=cuda_row_num, - ids_freq_mapping=ids_freq_mapping, - warmup_ratio=warmup_ratio, - buffer_size=buffer_size) + embedding_bag = cls( + rows, + cols, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse, + embedding, + mode, + include_last_offset, + cuda_row_num=cuda_row_num, + ids_freq_mapping=ids_freq_mapping, + warmup_ratio=warmup_ratio, + buffer_size=buffer_size, + ) embedding_bag.cache_weight_mgr.cuda_cached_weight.requires_grad_ = not freeze return embedding_bag diff --git a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py index a1feda2bdb0e..7d21f5b68ce6 100644 --- a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py @@ -1,4 +1,3 @@ -import time from typing import List import torch @@ -19,24 +18,26 @@ class ParallelCachedEmbeddingBagTablewise(CachedEmbeddingBag): Those parameters in TablewiseEmbeddingBagConfig are ignored: cuda_row_num, buffer_size, initial_weight. """ - def __init__(self, - embedding_bag_config_list: List[TablewiseEmbeddingBagConfig], - embedding_dim: int, - padding_idx=None, - max_norm=None, - norm_type=2., - scale_grad_by_freq=False, - sparse=False, - _weight=None, - mode='mean', - include_last_offset=False, - dtype=None, - device=None, - cache_ratio=0.01, - warmup_ratio=0.7, - buffer_size=50_000, - pin_weight=False, - evict_strategy: EvictionStrategy = EvictionStrategy.LFU): + def __init__( + self, + embedding_bag_config_list: List[TablewiseEmbeddingBagConfig], + embedding_dim: int, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + _weight=None, + mode="mean", + include_last_offset=False, + dtype=None, + device=None, + cache_ratio=0.01, + warmup_ratio=0.7, + buffer_size=50_000, + pin_weight=False, + evict_strategy: EvictionStrategy = EvictionStrategy.LFU, + ): self.rank = dist.get_rank() self.world_size = dist.get_world_size() self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list] @@ -62,11 +63,27 @@ def __init__(self, break self.cache_ratio = cache_ratio # table-associate cache - cuda_row_num = int(cache_ratio * self.num_embeddings) - super(ParallelCachedEmbeddingBagTablewise, - self).__init__(self.num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, - sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping, - warmup_ratio, buffer_size, pin_weight, evict_strategy) + int(cache_ratio * self.num_embeddings) + super(ParallelCachedEmbeddingBagTablewise, self).__init__( + self.num_embeddings, + embedding_dim, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse, + _weight, + mode, + include_last_offset, + dtype, + device, + cache_ratio, + ids_freq_mapping, + warmup_ratio, + buffer_size, + pin_weight, + evict_strategy, + ) # for assigned tables reconnection: 
self.idx_offset_list = [] @@ -96,7 +113,8 @@ def forward( # not recommanded. it takes time. batch_size = (offsets.shape[0]) // self.global_tables_num local_indices, local_offsets, local_per_sample_weights = self.split_along_rank( - batch_size, indices, offsets, per_sample_weights) + batch_size, indices, offsets, per_sample_weights + ) else: # recommanded. batch_size = (offsets.shape[0]) // len(self.assigned_table_list) @@ -104,9 +122,19 @@ def forward( if self.cache_op: with torch.no_grad(): indices = self.cache_weight_mgr.prepare_ids(local_indices) - local_output = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets, - self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse, - local_per_sample_weights, self.include_last_offset, self.padding_idx) + local_output = F.embedding_bag( + indices.cuda(), + self.cache_weight_mgr.cuda_cached_weight, + local_offsets, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.mode, + self.sparse, + local_per_sample_weights, + self.include_last_offset, + self.padding_idx, + ) local_output = torch.cat(local_output.split(batch_size), 1) remains = batch_size % self.world_size scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)] @@ -115,21 +143,19 @@ def forward( output_full = shape_hook(output_full) return output_full - def split_along_rank(self, - batch_size, - indices: torch.Tensor, - offsets: torch.Tensor = None, - per_sample_weights=None): - ''' + def split_along_rank( + self, batch_size, indices: torch.Tensor, offsets: torch.Tensor = None, per_sample_weights=None + ): + """ if input indices and offsets haven't been splitted along assigned rank, this function will do it. it takes time. please consider splitting data during batch loading. - ''' + """ local_indices_list: List(torch.Tensor) = [] local_offsets_list: List(torch.Tensor) = [] if per_sample_weights != None: local_per_sample_weights_list: List(torch.Tensor) = [] - offset_pre_end = 0 # local_offsets trick + offset_pre_end = 0 # local_offsets trick for i, handle_table in enumerate(self.assigned_table_list): indices_start_position = offsets[batch_size * handle_table] if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]): @@ -138,7 +164,7 @@ def split_along_rank(self, else: indices_end_position = offsets[batch_size * (handle_table + 1)] # alternative approach: reduce malloc - ''' + """ # 1. local_indices_list: local_indices = indices.narrow(0, indices_start_position, indices_end_position - indices_start_position) torch.sub(local_indices, self.idx_offset_list[i], out=local_indices) @@ -158,25 +184,29 @@ def split_along_rank(self, torch.add(local_offsets, offset_pre_end - offsets[batch_size * handle_table], out=local_offsets) offset_pre_end = offsets[batch_size * (handle_table + 1)] + offset_pre_end - temp_holder local_offsets_list.append(local_offsets) - ''' + """ # 1. local_indices_list: local_indices_list.append( - indices.narrow(0, indices_start_position, - indices_end_position - indices_start_position).sub(self.idx_offset_list[i])) + indices.narrow(0, indices_start_position, indices_end_position - indices_start_position).sub( + self.idx_offset_list[i] + ) + ) # 2. 
local_offsets_list: if i + 1 == len(self.assigned_table_list): # till-the-end special case if not self.include_last_offset: - local_offsets = offsets.narrow(0, batch_size * handle_table, - batch_size).add(offset_pre_end - offsets[batch_size * - (handle_table)]) + local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size).add( + offset_pre_end - offsets[batch_size * (handle_table)] + ) else: - local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + - 1).add(offset_pre_end - offsets[batch_size * (handle_table)]) + local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + 1).add( + offset_pre_end - offsets[batch_size * (handle_table)] + ) local_offsets_list.append(local_offsets) else: - local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + - 1).add(offset_pre_end - offsets[batch_size * (handle_table)]) + local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + 1).add( + offset_pre_end - offsets[batch_size * (handle_table)] + ) offset_pre_end = local_offsets[-1] local_offsets_list.append(local_offsets[:-1]) # 3. local_per_sample_weights_list: diff --git a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py index 8017ee72b0b4..94a27a8673da 100644 --- a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py @@ -19,21 +19,23 @@ class ParallelCachedEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module): every table assigned to this class instance is managed by a CachedEmbeddingBag. 
""" - def __init__(self, - embedding_bag_config_list: List[TablewiseEmbeddingBagConfig], - embedding_dim: int, - padding_idx=None, - max_norm=None, - norm_type=2., - scale_grad_by_freq=False, - sparse=False, - mode='mean', - include_last_offset=False, - dtype=None, - device=None, - warmup_ratio=0.7, - pin_weight=False, - evict_strategy: EvictionStrategy = EvictionStrategy.LFU): + def __init__( + self, + embedding_bag_config_list: List[TablewiseEmbeddingBagConfig], + embedding_dim: int, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + mode="mean", + include_last_offset=False, + dtype=None, + device=None, + warmup_ratio=0.7, + pin_weight=False, + evict_strategy: EvictionStrategy = EvictionStrategy.LFU, + ): super(ParallelCachedEmbeddingBagTablewiseSpiltCache, self).__init__() self.rank = dist.get_rank() self.world_size = dist.get_world_size() @@ -56,24 +58,27 @@ def __init__(self, if config.assigned_rank != self.rank: continue self.cached_embedding_bag_list.append( - CachedEmbeddingBag(num_embeddings=config.num_embeddings, - embedding_dim=embedding_dim, - padding_idx=padding_idx, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse, - _weight=config.initial_weight, - mode=mode, - include_last_offset=include_last_offset, - dtype=dtype, - device=device, - cuda_row_num=config.cuda_row_num, - ids_freq_mapping=config.ids_freq_mapping, - warmup_ratio=warmup_ratio, - buffer_size=config.buffer_size, - pin_weight=pin_weight, - evict_strategy=evict_strategy)) + CachedEmbeddingBag( + num_embeddings=config.num_embeddings, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + _weight=config.initial_weight, + mode=mode, + include_last_offset=include_last_offset, + dtype=dtype, + device=device, + cuda_row_num=config.cuda_row_num, + ids_freq_mapping=config.ids_freq_mapping, + warmup_ratio=warmup_ratio, + buffer_size=config.buffer_size, + pin_weight=pin_weight, + evict_strategy=evict_strategy, + ) + ) # prepare list shape for all_to_all output self.embedding_dim_per_rank = [0 for i in range(self.world_size)] @@ -95,22 +100,26 @@ def forward(self, indices: torch.Tensor, offsets: torch.Tensor = None, per_sampl indices_end_position = offsets[batch_size * (handle_table + 1)] with record_function("part 2"): # local_indices = indices[indices_start_position:indices_end_position] - self.global_tables_offsets[handle_table] - local_indices = indices.narrow(0, indices_start_position, indices_end_position - - indices_start_position).sub(self.global_tables_offsets[handle_table]) + local_indices = indices.narrow( + 0, indices_start_position, indices_end_position - indices_start_position + ).sub(self.global_tables_offsets[handle_table]) if self.include_last_offset: # local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1) + 1] - offsets[batch_size * (handle_table)] - local_offsets = offsets.narrow(0, batch_size * handle_table, - batch_size + 1).sub(offsets[batch_size * (handle_table)]) + local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + 1).sub( + offsets[batch_size * (handle_table)] + ) else: # local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1)] - offsets[batch_size * (handle_table)] - local_offsets = offsets.narrow(0, batch_size * handle_table, - batch_size).sub(offsets[batch_size * (handle_table)]) + local_offsets = offsets.narrow(0, 
batch_size * handle_table, batch_size).sub( + offsets[batch_size * (handle_table)] + ) local_per_sample_weights = None if per_sample_weights != None: local_per_sample_weights = per_sample_weights[indices_start_position:indices_end_position] with record_function("(tablewise) tablewise forward"): - local_output_list.append(self.cached_embedding_bag_list[i](local_indices, local_offsets, - local_per_sample_weights)) + local_output_list.append( + self.cached_embedding_bag_list[i](local_indices, local_offsets, local_per_sample_weights) + ) # get result of shape = (batch_size, (len(assigned_table_list)*embedding_dim)) local_output = torch.cat(local_output_list, 1) diff --git a/colossalai/legacy/nn/parallel/layers/colo_module.py b/colossalai/legacy/nn/parallel/layers/colo_module.py index 69d92afaaa94..df0b324eeeb8 100644 --- a/colossalai/legacy/nn/parallel/layers/colo_module.py +++ b/colossalai/legacy/nn/parallel/layers/colo_module.py @@ -5,7 +5,6 @@ class ColoModule(object): - def __init__(self): self._shard_params: List[str] = [] self._allowed_patterns: Dict[ComputePattern, Dict[str, Dict[str, _DistSpec]]] = {} @@ -13,18 +12,18 @@ def __init__(self): def _register_shard_params(self, params: List[str]): self._shard_params = params - def _register_allowed_patterns(self, - compute_pattern: ComputePattern, - dist_specs: Dict[str, _DistSpec], - mode='default'): - assert list( - dist_specs.keys()).sort() == self._shard_params.sort(), 'Every registered param should have dist_spec.' + def _register_allowed_patterns( + self, compute_pattern: ComputePattern, dist_specs: Dict[str, _DistSpec], mode="default" + ): + assert ( + list(dist_specs.keys()).sort() == self._shard_params.sort() + ), "Every registered param should have dist_spec." if not compute_pattern in self._allowed_patterns: self._allowed_patterns[compute_pattern] = {} self._allowed_patterns[compute_pattern][mode] = dist_specs def _set_default(self, compute_pattern: ComputePattern, target_mode): - self._allowed_patterns[compute_pattern]['default'] = self._allowed_patterns[compute_pattern][target_mode] + self._allowed_patterns[compute_pattern]["default"] = self._allowed_patterns[compute_pattern][target_mode] def has_compute_pattern(self, compute_pattern: ComputePattern): return compute_pattern in self._allowed_patterns @@ -33,10 +32,10 @@ def get_dist_specs(self, compute_pattern: ComputePattern): assert self.has_compute_pattern(compute_pattern) return self._allowed_patterns[compute_pattern] - def has_compute_pattern_with_mode(self, compute_pattern: ComputePattern, mode='default'): + def has_compute_pattern_with_mode(self, compute_pattern: ComputePattern, mode="default"): return compute_pattern in self._allowed_patterns and mode in self._allowed_patterns[compute_pattern] - def get_dist_specs_with_mode(self, compute_pattern: ComputePattern, mode='default'): + def get_dist_specs_with_mode(self, compute_pattern: ComputePattern, mode="default"): assert self.has_compute_pattern_with_mode(compute_pattern, mode) return self._allowed_patterns[compute_pattern][mode] diff --git a/colossalai/legacy/nn/parallel/layers/embedding.py b/colossalai/legacy/nn/parallel/layers/embedding.py index 4796699fc57f..f204f3fb71f0 100644 --- a/colossalai/legacy/nn/parallel/layers/embedding.py +++ b/colossalai/legacy/nn/parallel/layers/embedding.py @@ -1,13 +1,12 @@ -from colossalai.legacy.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec +from colossalai.legacy.tensor import ComputePattern, ProcessGroup, ShardSpec from .colo_module import ColoModule class 
ColoEmbedding(ColoModule): - def __init__(self): super(ColoEmbedding, self).__init__() - self._register_shard_params(['weight']) + self._register_shard_params(["weight"]) def register(self, compute_pattern, pg: ProcessGroup): if not compute_pattern in self._allowed_patterns: @@ -20,18 +19,18 @@ def _set_TP1D(self, pg: ProcessGroup): self._register_allowed_patterns( compute_pattern=_compute_pattern, dist_specs={ - 'weight': ShardSpec([0], [pg.tp_world_size()]), + "weight": ShardSpec([0], [pg.tp_world_size()]), }, - mode='row', + mode="row", ) # TP1D Col Linear self._register_allowed_patterns( compute_pattern=_compute_pattern, dist_specs={ - 'weight': ShardSpec([-1], [pg.tp_world_size()]), + "weight": ShardSpec([-1], [pg.tp_world_size()]), }, - mode='col', + mode="col", ) - self._set_default(compute_pattern=_compute_pattern, target_mode='row') + self._set_default(compute_pattern=_compute_pattern, target_mode="row") diff --git a/colossalai/legacy/nn/parallel/layers/linear.py b/colossalai/legacy/nn/parallel/layers/linear.py index 51a8d4c976a6..c3b6df1ec9da 100644 --- a/colossalai/legacy/nn/parallel/layers/linear.py +++ b/colossalai/legacy/nn/parallel/layers/linear.py @@ -1,13 +1,12 @@ -from colossalai.legacy.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec +from colossalai.legacy.tensor import ComputePattern, ProcessGroup, ShardSpec from .colo_module import ColoModule class ColoLinear(ColoModule): - def __init__(self): super(ColoLinear, self).__init__() - self._register_shard_params(['weight', 'bias']) + self._register_shard_params(["weight", "bias"]) def register(self, compute_pattern, pg: ProcessGroup): if not compute_pattern in self._allowed_patterns: @@ -19,21 +18,15 @@ def _set_TP1D(self, pg): _compute_pattern = ComputePattern.TP1D self._register_allowed_patterns( compute_pattern=_compute_pattern, - dist_specs={ - 'weight': ShardSpec([-1], [pg.tp_world_size()]), - 'bias': None - }, - mode='row', + dist_specs={"weight": ShardSpec([-1], [pg.tp_world_size()]), "bias": None}, + mode="row", ) # TP1D Col Linear self._register_allowed_patterns( compute_pattern=_compute_pattern, - dist_specs={ - 'weight': ShardSpec([0], [pg.tp_world_size()]), - 'bias': ShardSpec([0], [pg.tp_world_size()]) - }, - mode='col', + dist_specs={"weight": ShardSpec([0], [pg.tp_world_size()]), "bias": ShardSpec([0], [pg.tp_world_size()])}, + mode="col", ) - self._set_default(compute_pattern=_compute_pattern, target_mode='row') + self._set_default(compute_pattern=_compute_pattern, target_mode="row") diff --git a/colossalai/legacy/nn/parallel/layers/module_utils.py b/colossalai/legacy/nn/parallel/layers/module_utils.py index 09326d2d6f9a..4dbce7e09f37 100644 --- a/colossalai/legacy/nn/parallel/layers/module_utils.py +++ b/colossalai/legacy/nn/parallel/layers/module_utils.py @@ -2,7 +2,7 @@ import torch -from colossalai.legacy.tensor import ComputeSpec, ProcessGroup, distspec +from colossalai.legacy.tensor import ComputeSpec, ProcessGroup from colossalai.tensor import ColoParameter from . 
import ColoModule @@ -41,7 +41,7 @@ def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True) for param_name in param_names: param = module.get_parameter(param_name) if not isinstance(param, ColoParameter): - raise Exception(f'Invalid ColoParameter spec: {param} in {module} is not a ColoParameter.') + raise Exception(f"Invalid ColoParameter spec: {param} in {module} is not a ColoParameter.") if param.has_compute_spec(): cur_compute_pattern = param.compute_spec.compute_pattern if compute_pattern is None: @@ -49,7 +49,8 @@ def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True) else: if cur_compute_pattern != compute_pattern: raise Exception( - f'Invalid ColoParameter spec: Params in {module} have different compute_pattern.') + f"Invalid ColoParameter spec: Params in {module} have different compute_pattern." + ) else: continue @@ -57,7 +58,8 @@ def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True) colo_module.register(compute_pattern, pg) if not colo_module.has_compute_pattern(compute_pattern): raise Exception( - f'Invalid ColoParameter spec: ComputePattern {compute_pattern} in {module} is not allowed.') + f"Invalid ColoParameter spec: ComputePattern {compute_pattern} in {module} is not allowed." + ) match_specs = False allowed_specs = colo_module.get_dist_specs(compute_pattern) @@ -77,17 +79,15 @@ def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True) match_specs = True break if match_specs == False: - raise Exception(f'Invalid ColoParameter spec: Params in {module} are incorrectly sharded.') + raise Exception(f"Invalid ColoParameter spec: Params in {module} are incorrectly sharded.") if recursive == True: for submodule in module.children(): check_colo_module(submodule, pg=pg, recursive=True) -def init_colo_module(module: torch.nn.Module, - compute_spec: ComputeSpec, - pg: ProcessGroup, - recursive=True, - mode='default'): +def init_colo_module( + module: torch.nn.Module, compute_spec: ComputeSpec, pg: ProcessGroup, recursive=True, mode="default" +): compute_pattern = compute_spec.compute_pattern if is_colo_module(module): # for each param diff --git a/colossalai/legacy/nn/parallel/reducer.py b/colossalai/legacy/nn/parallel/reducer.py index 5687055819fe..7b3d283e47dd 100644 --- a/colossalai/legacy/nn/parallel/reducer.py +++ b/colossalai/legacy/nn/parallel/reducer.py @@ -13,7 +13,6 @@ class Bucket: - def __init__(self, size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup): self.buffer = torch.zeros(size, dtype=dtype, device=device) self.group = group @@ -26,7 +25,7 @@ def flush(self) -> None: assert len(self.callbacks) == 0 return # reduce-scatter bucket - dist.all_reduce(self.buffer[:self.offset], group=self.group) + dist.all_reduce(self.buffer[: self.offset], group=self.group) # execute post-reduction callbacks for callback_fn in self.callbacks: @@ -37,24 +36,22 @@ def flush(self) -> None: self.buffer = torch.zeros_like(self.buffer) def alloc(self) -> None: - if self.buffer.storage().size() == 0: self.buffer.storage().resize_(self.buffer.numel()) def free(self) -> None: - assert self.offset == 0 and self.callbacks == [], "Incorrect call of teardown" self.buffer.storage().resize_(0) def append(self, tensor: Tensor, callback_fn: Callable): tensor_size = tensor.numel() offset = self.offset - self.buffer[offset:offset + tensor_size].copy_(tensor.flatten()) + self.buffer[offset : offset + tensor_size].copy_(tensor.flatten()) self.offset += tensor_size # callback will be given 
the reduced result if callback_fn is not None: - result_view = self.buffer[offset:offset + tensor_size].view(tensor.shape) + result_view = self.buffer[offset : offset + tensor_size].view(tensor.shape) self.callbacks.append(functools.partial(callback_fn, result_view)) @property @@ -63,7 +60,6 @@ def avail_size(self) -> int: class Reducer: - def __init__(self, bucket_size_mb: int = 25): self.bucket_size_mb = bucket_size_mb self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {} @@ -101,7 +97,7 @@ def free(self) -> None: @functools.lru_cache() def _get_bucket_size(self, element_size: int) -> int: - if self.bucket_size_mb <= 0: # Values <= 0 disable bucketing. + if self.bucket_size_mb <= 0: # Values <= 0 disable bucketing. return 0 MB = 1024 * 1024 bucket_size = self.bucket_size_mb * MB / element_size diff --git a/colossalai/legacy/pipeline/__init__.py b/colossalai/legacy/pipeline/__init__.py index f36f54ac9307..9f1a5ec7fd1f 100644 --- a/colossalai/legacy/pipeline/__init__.py +++ b/colossalai/legacy/pipeline/__init__.py @@ -1,4 +1,4 @@ from .layer_spec import LayerSpec from .pipelinable import PipelinableContext, PipelinableModel -__all__ = ['PipelinableModel', 'PipelinableContext', 'LayerSpec'] +__all__ = ["PipelinableModel", "PipelinableContext", "LayerSpec"] diff --git a/colossalai/legacy/pipeline/layer_spec.py b/colossalai/legacy/pipeline/layer_spec.py index 3960debd7f72..825816e1c032 100644 --- a/colossalai/legacy/pipeline/layer_spec.py +++ b/colossalai/legacy/pipeline/layer_spec.py @@ -4,9 +4,7 @@ class LayerSpec: - """ - - """ + """ """ def __init__(self, typename, *module_args, **module_kwargs): self.typename = typename @@ -16,7 +14,7 @@ def __init__(self, typename, *module_args, **module_kwargs): self._param_count = 0 if not issubclass(typename, torch.nn.Module): - raise RuntimeError('LayerSpec only supports torch.nn.Module types.') + raise RuntimeError("LayerSpec only supports torch.nn.Module types.") def __repr__(self): return call_to_str(self.typename.__name__, self.module_args, self.module_kwargs) diff --git a/colossalai/legacy/pipeline/middleware/__init__.py b/colossalai/legacy/pipeline/middleware/__init__.py index 481741bfee31..8a678b7b4c87 100644 --- a/colossalai/legacy/pipeline/middleware/__init__.py +++ b/colossalai/legacy/pipeline/middleware/__init__.py @@ -1,3 +1,3 @@ from .topo import Partition, PartitionInputVal, PartitionOutputVal, Topo -__all__ = ['Topo', 'Partition', 'PartitionOutputVal', 'PartitionInputVal'] +__all__ = ["Topo", "Partition", "PartitionOutputVal", "PartitionInputVal"] diff --git a/colossalai/legacy/pipeline/middleware/adaptor/__init__.py b/colossalai/legacy/pipeline/middleware/adaptor/__init__.py index 0b0d36d2ffe5..7f2b18670a76 100644 --- a/colossalai/legacy/pipeline/middleware/adaptor/__init__.py +++ b/colossalai/legacy/pipeline/middleware/adaptor/__init__.py @@ -1,3 +1,3 @@ from .fx import get_topology as get_fx_topology -__all__ = ['get_fx_topology'] +__all__ = ["get_fx_topology"] diff --git a/colossalai/legacy/pipeline/middleware/adaptor/fx.py b/colossalai/legacy/pipeline/middleware/adaptor/fx.py index 8cc40f120f15..34b21f8be1bb 100644 --- a/colossalai/legacy/pipeline/middleware/adaptor/fx.py +++ b/colossalai/legacy/pipeline/middleware/adaptor/fx.py @@ -10,7 +10,7 @@ def partition_name_to_id(partition_name, is_input=False, is_output=False): elif is_output: partition_id = 1 else: - prefix = 'submod_' + prefix = "submod_" partition_id = int(partition_name.split(prefix)[-1]) + 2 return partition_id @@ -27,10 +27,10 @@ def 
partition_name_to_id(partition_name, is_input=False, is_output=False): def find_input_in_partition(node, partitions, input_partitions=None): p_input_val = None - direct_def = not node.name.startswith('getitem') + direct_def = not node.name.startswith("getitem") # search in input if direct_def and input_partitions is not None: - partition_id = partition_name_to_id('', is_input=True) + partition_id = partition_name_to_id("", is_input=True) for i, input_node in enumerate(input_partitions): if input_node == node: p_input_val = PartitionInputVal(partition_id=partition_id, offset=i) @@ -57,7 +57,7 @@ def find_input_in_partition(node, partitions, input_partitions=None): def find_output_in_partition(node, partitions, output_partitions=None): p_output_val = PartitionOutputVal() for user in node.users: - direct_use = not user.name.startswith('getitem') + direct_use = not user.name.startswith("getitem") # user is mid partition for partition in partitions: # direct call @@ -82,7 +82,7 @@ def find_output_in_partition(node, partitions, output_partitions=None): output_node = output_partitions[0] if user.op == output_node.op: output_keys = {} - partition_id = partition_name_to_id('', is_output=True) + partition_id = partition_name_to_id("", is_output=True) torch.fx.graph.map_arg(output_node.args[0], lambda n: output_keys.setdefault(n)) for i, arg in enumerate(output_keys): if arg == node: @@ -99,11 +99,11 @@ def get_topology(gm: GraphModule): partitions = [] output_partitions = [] for node in gm.graph.nodes: - if node.op == 'placeholder': + if node.op == "placeholder": input_partitions.append(node) - elif node.name.startswith('submod_'): + elif node.name.startswith("submod_"): partitions.append(node) - elif node.op == 'output': + elif node.op == "output": output_partitions.append(node) else: continue @@ -127,7 +127,7 @@ def get_topology(gm: GraphModule): # set output for submodule direct_use = True for user in partition.users: - if user.name.startswith('getitem'): + if user.name.startswith("getitem"): direct_use = False break if direct_use: @@ -146,7 +146,8 @@ def get_topology(gm: GraphModule): topo_output_partition = Partition() torch.fx.graph.map_arg( partition.args[0], - lambda n: topo_output_partition.add_input_val(find_input_in_partition(n, partitions, input_partitions))) + lambda n: topo_output_partition.add_input_val(find_input_in_partition(n, partitions, input_partitions)), + ) topo.set_partitions(partition_id=1, partition=topo_output_partition) topo.set_output_partition_id(partition_id=1) diff --git a/colossalai/legacy/pipeline/middleware/topo.py b/colossalai/legacy/pipeline/middleware/topo.py index 3c21cce6dc0e..d0e3d2c3dedf 100644 --- a/colossalai/legacy/pipeline/middleware/topo.py +++ b/colossalai/legacy/pipeline/middleware/topo.py @@ -10,7 +10,7 @@ class ValPosition: offset: int def __str__(self) -> str: - res = f'[partition_id:{self.partition_id},offset:{self.offset}]' + res = f"[partition_id:{self.partition_id},offset:{self.offset}]" return res def __repr__(self) -> str: @@ -18,7 +18,6 @@ def __repr__(self) -> str: class PartitionInputVal(object): - def __init__(self, partition_id, offset) -> None: # every input from which partition_id and which offset val_pos = ValPosition(partition_id, offset) @@ -28,8 +27,8 @@ def get(self): return self._from_partition_and_offset def __str__(self) -> str: - res = '' - res += f'<-({self._from_partition_and_offset})' + res = "" + res += f"<-({self._from_partition_and_offset})" return res def __repr__(self) -> str: @@ -37,7 +36,6 @@ def __repr__(self) -> 
str: class PartitionOutputVal(object): - def __init__(self) -> None: # every output to which partition_id and which offset self._to_partition_and_offset: List[ValPosition] = [] @@ -50,11 +48,11 @@ def get(self): return self._to_partition_and_offset def __str__(self) -> str: - res = '' - res += '->(' + res = "" + res += "->(" for val_pos in self._to_partition_and_offset: - res += f'{val_pos},' - res += ')' + res += f"{val_pos}," + res += ")" return res def __repr__(self) -> str: @@ -62,7 +60,6 @@ def __repr__(self) -> str: class Partition(object): - def __init__(self) -> None: self._input_vals: List[PartitionInputVal] = [] self._output_vals: List[PartitionOutputVal] = [] @@ -110,16 +107,16 @@ def get_output_partition_ids(self): return res def __str__(self) -> str: - res = '' - res += f' input:\n' - res += f' length:{len(self._input_vals)}\n' + res = "" + res += f" input:\n" + res += f" length:{len(self._input_vals)}\n" for i, input_val in enumerate(self._input_vals): - res += f' offset={i}:{input_val}\n' + res += f" offset={i}:{input_val}\n" - res += f' output:\n' - res += f' length:{len(self._output_vals)}\n' + res += f" output:\n" + res += f" length:{len(self._output_vals)}\n" for i, output_val in enumerate(self._output_vals): - res += f' offset={i}:{output_val}\n' + res += f" offset={i}:{output_val}\n" return res @@ -140,7 +137,6 @@ def __repr__(self) -> str: # _input_partition_id: the key represents input_partition # _output_partition_id: the key represents output_partition class Topo(object): - def __init__(self, input_partition_id=None, output_partition_id=None) -> None: self._partitions: Dict[int, Partition] = {} self._input_partition_id = input_partition_id @@ -162,7 +158,7 @@ def set_partitions(self, partition_id: int, partition: Partition): self._partitions[partition_id] = partition def get_mid_partitions(self): - res = {} #{partition_id: Partition} + res = {} # {partition_id: Partition} for partition_id, partition in self._partitions.items(): if self._input_partition_id == partition_id or self._output_partition_id == partition_id: continue @@ -186,27 +182,27 @@ def get_partition_by_id(self, partition_id): return self._partitions[partition_id] def __str__(self) -> str: - res = '' + res = "" if len(self._partitions) == 0: - return 'Empty Topo Graph.' + return "Empty Topo Graph." 
input_part = self.get_input_partition() if input_part is not None: - res += '{\n' - res += f'InputPartition:\n partition_id={self._input_partition_id}\n{input_part}' - res += '}\n' + res += "{\n" + res += f"InputPartition:\n partition_id={self._input_partition_id}\n{input_part}" + res += "}\n" mid_parts = self.get_mid_partitions() for i, (partition_id, part) in enumerate(mid_parts.items()): - res += '{\n' - res += f'SubPartition_{i}:\n partition_id={partition_id}\n {part}' - res += '}\n' + res += "{\n" + res += f"SubPartition_{i}:\n partition_id={partition_id}\n {part}" + res += "}\n" output_part = self.get_output_partition() if output_part is not None: - res += '{\n' - res += f'OutputPartition:\n partition_id={self._output_partition_id}\n{output_part}' - res += '}\n' + res += "{\n" + res += f"OutputPartition:\n partition_id={self._output_partition_id}\n{output_part}" + res += "}\n" return res diff --git a/colossalai/legacy/pipeline/pipelinable.py b/colossalai/legacy/pipeline/pipelinable.py index e74cad0ad1b0..82ccdb554527 100644 --- a/colossalai/legacy/pipeline/pipelinable.py +++ b/colossalai/legacy/pipeline/pipelinable.py @@ -132,8 +132,8 @@ def to_layer_list(self, exec_seq=None): for child in self._root_children: layer_spec = self._layer_spec_dict[id(child)] if layer_spec.typename in ( - torch.nn.modules.container.ModuleList, - torch.nn.modules.container.Sequential, + torch.nn.modules.container.ModuleList, + torch.nn.modules.container.Sequential, ): for child_in_container in layer_spec.children: self._layer_spec_list.append(self._layer_spec_dict[id(child_in_container)]) @@ -198,8 +198,9 @@ def partition(self, num_chunks, pipeline_size, rank): param_counts.append(layer_spec.count_params()) parts = partition_balanced(param_counts, pipeline_size, num_chunks)[rank] elif self._policy == "customized": - assert (self._exec_seq - is not None), f"An explicit exec_seq must be defined by user in customized policy mode." + assert ( + self._exec_seq is not None + ), f"An explicit exec_seq must be defined by user in customized policy mode." 
self.customized_parts = customized_partition(self._exec_seq) assert len(self.customized_parts) == gpc.get_world_size( ParallelMode.PIPELINE @@ -226,14 +227,14 @@ def partition(self, num_chunks, pipeline_size, rank): elif (layer, "behind") in self._func_dict: behind_func_dict_in_partition[id(module)] = self._func_dict[(layer, "behind")] module_list_in_partition = torch.nn.ModuleList(module_list_in_partition) - pipeline_model = PipelinableModel(module_list_in_partition, front_func_dict_in_partition, - behind_func_dict_in_partition) + pipeline_model = PipelinableModel( + module_list_in_partition, front_func_dict_in_partition, behind_func_dict_in_partition + ) return pipeline_model class PipelinableModel(torch.nn.Module): - def __init__(self, module_list, front_func_dict, behind_func_dict): super().__init__() self._module_list = module_list diff --git a/colossalai/legacy/pipeline/pipeline_process_group.py b/colossalai/legacy/pipeline/pipeline_process_group.py index 1168158defaf..2d0d5be87cac 100644 --- a/colossalai/legacy/pipeline/pipeline_process_group.py +++ b/colossalai/legacy/pipeline/pipeline_process_group.py @@ -1,6 +1,5 @@ -import os import threading -from typing import Dict, List, Tuple +from typing import List import torch.distributed as dist from torch.distributed import rpc @@ -14,14 +13,15 @@ class PipelineProcessGroup: def __init__(self) -> None: self.is_initialize = False - def set_global_info(self, - rank: int, - world_size: int, - dp_degree: int = 1, - tp_degree: int = 1, - num_worker_threads: int = 1, - device: str = "cuda") -> None: - + def set_global_info( + self, + rank: int, + world_size: int, + dp_degree: int = 1, + tp_degree: int = 1, + num_worker_threads: int = 1, + device: str = "cuda", + ) -> None: device_mesh_size = dp_degree * tp_degree assert world_size % device_mesh_size == 0, "world_size must be the multiple of dp_degree * tp_degree !!!" self._num_worker_threads = num_worker_threads @@ -60,8 +60,8 @@ def _initialize_process_group(self): device = self.device world_size = self.get_world_size() rank = self.get_global_rank() - backend = 'nccl' if device == 'cuda' else 'gloo' - dist.init_process_group(backend, world_size=world_size, rank=rank, group_name='main_group') + backend = "nccl" if device == "cuda" else "gloo" + dist.init_process_group(backend, world_size=world_size, rank=rank, group_name="main_group") def _initialize_pp_process_group(self) -> None: rank = self.get_global_rank() @@ -71,9 +71,9 @@ def _initialize_pp_process_group(self) -> None: options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=self._num_worker_threads) for pp_rank in self._pp_ranks: - options.set_device_map(f'work{pp_rank}', {rank: pp_rank}) + options.set_device_map(f"work{pp_rank}", {rank: pp_rank}) - rpc.init_rpc(name=f'work{rank}', rank=rank, world_size=world_size, rpc_backend_options=options) + rpc.init_rpc(name=f"work{rank}", rank=rank, world_size=world_size, rpc_backend_options=options) def _initialize_tp_dp_process_group(self) -> None: rank = self.get_global_rank() @@ -147,10 +147,10 @@ def get_tp_global_ranks(self): def get_chimera_all_reduce_group(self, pp_rank: int): with self.chimera_lock: - if not hasattr(self, 'chimera_groups'): + if not hasattr(self, "chimera_groups"): world_size = self.get_world_size() stage_num = self.get_stage_num() - assert world_size % 2 == 0, 'world_size must be even in chimera!' + assert world_size % 2 == 0, "world_size must be even in chimera!" 
self.chimera_groups = {} for rank in range(world_size // 2): pair = [rank, world_size - 1 - rank] diff --git a/colossalai/legacy/pipeline/rpc/__init__.py b/colossalai/legacy/pipeline/rpc/__init__.py index 15b65a4138a8..791b9d530673 100644 --- a/colossalai/legacy/pipeline/rpc/__init__.py +++ b/colossalai/legacy/pipeline/rpc/__init__.py @@ -1,4 +1,4 @@ from ._pipeline_schedule import ChimeraPipelineEngine, FillDrainPipelineEngine, OneFOneBPipelineEngine from .utils import pytree_map -__all__ = ['FillDrainPipelineEngine', 'OneFOneBPipelineEngine', 'ChimeraPipelineEngine', 'pytree_map'] +__all__ = ["FillDrainPipelineEngine", "OneFOneBPipelineEngine", "ChimeraPipelineEngine", "pytree_map"] diff --git a/colossalai/legacy/pipeline/rpc/_pipeline_base.py b/colossalai/legacy/pipeline/rpc/_pipeline_base.py index 88ddb9e98eb2..d203e1a11180 100644 --- a/colossalai/legacy/pipeline/rpc/_pipeline_base.py +++ b/colossalai/legacy/pipeline/rpc/_pipeline_base.py @@ -12,17 +12,9 @@ from torch._C._distributed_rpc import PyRRef from torch.futures import Future -from colossalai.legacy.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo +from colossalai.legacy.pipeline.middleware import Partition, Topo from colossalai.legacy.pipeline.pipeline_process_group import ppg -from colossalai.legacy.pipeline.rpc.utils import ( - get_batch_lengths, - pyobj_map, - pytree_filter, - pytree_map, - split_batch, - tensor_shape_list, - type_detail, -) +from colossalai.legacy.pipeline.rpc.utils import get_batch_lengths, pyobj_map, pytree_filter, pytree_map, split_batch class Phase(Enum): @@ -33,7 +25,7 @@ class Phase(Enum): class UniqueKey: - __slots__ = ('microbatch_id', 'phase') + __slots__ = ("microbatch_id", "phase") microbatch_id: int phase: Phase @@ -48,12 +40,22 @@ def __hash__(self) -> int: return tuple.__hash__((self.microbatch_id, self.phase)) def __repr__(self) -> str: - return f'Key(microbatch_id={self.microbatch_id}, phase={self.phase})' + return f"Key(microbatch_id={self.microbatch_id}, phase={self.phase})" class WorkItem: - __slots__ = ('stage_id', 'phase', 'args', 'kwargs', 'output', 'refcount', 'microbatch_id', 'batch_id', - 'num_microbatches', 'forward_only') + __slots__ = ( + "stage_id", + "phase", + "args", + "kwargs", + "output", + "refcount", + "microbatch_id", + "batch_id", + "num_microbatches", + "forward_only", + ) stage_id: int phase: Phase @@ -66,50 +68,45 @@ class WorkItem: num_microbatches: int forward_only: bool - def __init__(self, - stage_id, - phase, - args, - kwargs, - output, - microbatch_id, - batch_id, - num_microbatches, - forward_only, - refcount=0) -> None: + def __init__( + self, stage_id, phase, args, kwargs, output, microbatch_id, batch_id, num_microbatches, forward_only, refcount=0 + ) -> None: for attr_name in self.__slots__: setattr(self, attr_name, locals()[attr_name]) class BackwardCache: - __slots__ = ('checkpoint', 'stage_input_args', 'stage_input_kwargs', 'stage_outputs') + __slots__ = ("checkpoint", "stage_input_args", "stage_input_kwargs", "stage_outputs") checkpoint: bool stage_input_args: Tuple[Any] stage_input_kwargs: Dict[Any, Any] stage_outputs: Tuple[Any] - def __init__(self, - stage_input_args: Tuple[Any], - stage_input_kwargs: Dict[Any, Any] = None, - stage_outputs: Tuple[Any] = None, - checkpoint: bool = False) -> None: + def __init__( + self, + stage_input_args: Tuple[Any], + stage_input_kwargs: Dict[Any, Any] = None, + stage_outputs: Tuple[Any] = None, + checkpoint: bool = False, + ) -> None: for arg_name in self.__slots__: setattr(self, 
arg_name, locals()[arg_name]) class WorkerBase(ABC): - - def __init__(self, - partition_fn: Callable, - partition_args: tuple, - pp_rank: int, - actual_stage_num: int, - num_microbatches: int, - device: str, - criterion: Callable = None, - metric: Callable = None, - checkpoint: bool = False, - data_process_func: Callable = None) -> None: + def __init__( + self, + partition_fn: Callable, + partition_args: tuple, + pp_rank: int, + actual_stage_num: int, + num_microbatches: int, + device: str, + criterion: Callable = None, + metric: Callable = None, + checkpoint: bool = False, + data_process_func: Callable = None, + ) -> None: super().__init__() self.pp_rank = pp_rank @@ -150,11 +147,11 @@ def __init__(self, self._initialize_context_container() # main loop - self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{pp_rank}', daemon=True) + self.main_loop_thread = threading.Thread(target=self._work_loop, name=f"rank_{pp_rank}", daemon=True) self.main_loop_thread.start() def _get_future_by_device(self): - return torch.futures.Future(devices=None if self.device in (None, 'cpu') else [self.device]) + return torch.futures.Future(devices=None if self.device in (None, "cpu") else [self.device]) def _initialize_outstanding_range(self): outstanding_range = None @@ -199,12 +196,13 @@ def _get_output_all(self, key: UniqueKey, ref_use=False, rank=None): # lifecycle management for DAG scheduler if output_work_item.phase == Phase.FORWARD: lifecycle = len(self.get_consumer_stage_ids()) - if self.is_model_output(): # an extra reference for scheduler collecting results + if self.is_model_output(): # an extra reference for scheduler collecting results lifecycle += 1 elif output_work_item.phase == Phase.BACKWARD: lifecycle = len(self.get_producer_stage_ids()) if self.is_model_input() and self._is_last_step( - output_work_item): # an extra reference for ensure_backward + output_work_item + ): # an extra reference for ensure_backward lifecycle += 1 else: lifecycle = 0 @@ -234,9 +232,9 @@ def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> # offset supports get partial output to reduce comm costs. 
def get_output_by_key(self, key: UniqueKey, ref_use=False, rank=None, offsets=None) -> Any: output = self._get_output_all(key, ref_use, rank) - if offsets is None: # get all for non iterable output + if offsets is None: # get all for non iterable output return output - else: # get part for iterable output + else: # get part for iterable output output = [output[i] for i in offsets] return output @@ -252,12 +250,12 @@ def get_parameter_gradients(self) -> List[torch.Tensor]: def get_partition(self): with self.partition_condition_lock: - self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition')) + self.partition_condition_lock.wait_for(lambda: hasattr(self, "module_partition")) return self.module_partition def get_partition_state_dict(self): with self.partition_condition_lock: - self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition')) + self.partition_condition_lock.wait_for(lambda: hasattr(self, "module_partition")) return self.module_partition.state_dict() def _make_args_kwargs(self, microbatch, merge=False): @@ -293,8 +291,17 @@ def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bo # make args and kwargs args, kwargs = self._make_args_kwargs(microbatch) - work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None, - self.num_microbatches, forward_only) + work_item = WorkItem( + self.pp_rank, + Phase.FORWARD, + args, + kwargs, + output, + microbatch_id, + None, + self.num_microbatches, + forward_only, + ) with self.work_list_condition_lock: self.work_list[key] = work_item self.work_list_condition_lock.notify_all() @@ -314,15 +321,25 @@ def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bo for off in self_input_offsets: self_arg_lst.append(arg_lst[off]) - work_item = WorkItem(self.pp_rank, Phase.FORWARD, self_arg_lst, {}, output, microbatch_id, None, - self.num_microbatches, forward_only) + work_item = WorkItem( + self.pp_rank, + Phase.FORWARD, + self_arg_lst, + {}, + output, + microbatch_id, + None, + self.num_microbatches, + forward_only, + ) with self.work_list_condition_lock: self.work_list[key] = work_item self.work_list_condition_lock.notify_all() # put input tensor which other nodes need into output_list as Phase.INPUT - work_item_remote = WorkItem(self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None, - self.num_microbatches, forward_only) + work_item_remote = WorkItem( + self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None, self.num_microbatches, forward_only + ) with self.output_list_condition_lock: self.output_list[recv_input_key] = work_item_remote @@ -343,8 +360,17 @@ def _begin_backward(self, microbatch_id: int): output = self._get_future_by_device() grad_wrt_loss = None - work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None, - self.num_microbatches, False) + work_item = WorkItem( + self.pp_rank, + Phase.BACKWARD, + grad_wrt_loss, + {}, + output, + microbatch_id, + None, + self.num_microbatches, + False, + ) self.work_list[key] = work_item self.work_list_condition_lock.notify_all() @@ -367,7 +393,7 @@ def _subscribe_producer(self, microbatch_id: int, forward_only: bool): producer_stage_ids = self.get_producer_stage_ids() producer_num = len(producer_stage_ids) if self.need_model_input(): - producer_num += 1 # for input partition + producer_num += 1 # for input partition subscribe_forward_futures: List[Future] = [None] * producer_num # TODO(jiangziyue) get single value instead of the 
whole output @@ -376,9 +402,9 @@ def _subscribe_producer(self, microbatch_id: int, forward_only: bool): producer_output_key = UniqueKey(microbatch_id, Phase.INPUT) producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id] offsets = self._get_input_offsets_by_index(target_index=0) - subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, - rank=self.pp_rank, - offsets=offsets) + subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key( + producer_output_key, rank=self.pp_rank, offsets=offsets + ) for i in range(0, producer_num - 1): producer_stage_id = producer_stage_ids[i] @@ -386,11 +412,12 @@ def _subscribe_producer(self, microbatch_id: int, forward_only: bool): producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id] target_index = i + 1 offsets = self._get_input_offsets_by_index(target_index=target_index) - if offsets is not None and len(offsets) == 0: # no need to do rpc + if offsets is not None and len(offsets) == 0: # no need to do rpc subscribe_forward_futures[target_index] = [] else: subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key( - producer_output_key, rank=self.pp_rank, offsets=offsets) + producer_output_key, rank=self.pp_rank, offsets=offsets + ) else: for i in range(producer_num): @@ -399,14 +426,24 @@ def _subscribe_producer(self, microbatch_id: int, forward_only: bool): producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id] target_index = i offsets = self._get_input_offsets_by_index(target_index=target_index) - if offsets is not None and len(offsets) == 0: # no need to do rpc + if offsets is not None and len(offsets) == 0: # no need to do rpc subscribe_forward_futures[target_index] = [] else: subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key( - producer_output_key, rank=self.pp_rank, offsets=offsets) - - work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output, - microbatch_id, None, self.num_microbatches, forward_only) + producer_output_key, rank=self.pp_rank, offsets=offsets + ) + + work_item_from_producer = WorkItem( + stage_id, + Phase.FORWARD, + subscribe_forward_futures, + {}, + output, + microbatch_id, + None, + self.num_microbatches, + forward_only, + ) return work_item_from_producer @@ -441,15 +478,25 @@ def _subscribe_consumer(self, microbatch_id: int): consumer_worker_rref = self.pp_rank_to_worker_rref[consumer_stage_id] target_index = i offsets = self._get_output_offsets_by_index(target_index=target_index) - if offsets is not None and len(offsets) == 0: # no need to do rpc + if offsets is not None and len(offsets) == 0: # no need to do rpc subscribe_backward_futures[target_index] = [] else: subscribe_backward_futures[target_index] = consumer_worker_rref.rpc_async().get_output_by_key( - consumer_output_key, rank=self.pp_rank, offsets=offsets) + consumer_output_key, rank=self.pp_rank, offsets=offsets + ) # flatten args - work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output, - microbatch_id, None, self.num_microbatches, False) + work_item_from_consumer = WorkItem( + stage_id, + Phase.BACKWARD, + subscribe_backward_futures, + {}, + output, + microbatch_id, + None, + self.num_microbatches, + False, + ) return work_item_from_consumer @@ -524,8 +571,8 @@ def partition_id_to_pp_rank(self, partition_id: int, topo: Topo): def get_topo(self): with self.partition_condition_lock: - 
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition')) - if hasattr(self.module_partition, '_topo'): + self.partition_condition_lock.wait_for(lambda: hasattr(self, "module_partition")) + if hasattr(self.module_partition, "_topo"): return self.module_partition._topo else: return None @@ -564,12 +611,12 @@ def _get_input_offsets_by_index(self, target_index): if stage_id == src_stage_id: src_index += i break - else: # data from input partition + else: # data from input partition src_index = 0 # when output_len = 1, not iterable if target_index == src_index: if output_len == 1: - res = None # offset = None to get all outputs + res = None # offset = None to get all outputs return res else: res.append(src_offset) @@ -584,7 +631,6 @@ def _get_output_offsets_by_index(self, target_index): consumer_stage_ids = self.get_consumer_stage_ids() for val_list in output_vals: # An output may be passed to many down stages. - target = None for val_pos in val_list.get(): dst_partition_id = val_pos.partition_id dst_offset = val_pos.offset @@ -597,7 +643,7 @@ def _get_output_offsets_by_index(self, target_index): break if target_index == dst_index: if input_len == 1: - res = None # offset = None to get all outputs + res = None # offset = None to get all outputs return res else: res.append(dst_offset) @@ -623,7 +669,7 @@ def _get_real_args_kwargs_fwd(self, args_or_kwargs): flatten_args = [] if self.is_first_stage(): pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True) - else: # get by offset + else: # get by offset topo: Topo = self.get_topo() self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo) self_partition: Partition = topo.get_partition_by_id(self_partition_id) @@ -652,7 +698,7 @@ def _get_real_args_kwargs_fwd(self, args_or_kwargs): if stage_id == src_stage_id: src_index += i break - else: # data from input partition + else: # data from input partition src_index = 0 # when output_len = 1, not iterable if output_len == 1: @@ -679,7 +725,7 @@ def _get_real_args_kwargs_bwd(self, args_or_kwargs): else: for i, arg in enumerate(args_or_kwargs): args_or_kwargs[i] = arg.wait() - if args_or_kwargs is not None: # get by offset + if args_or_kwargs is not None: # get by offset flatten_args = [] topo: Topo = self.get_topo() self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo) @@ -719,7 +765,7 @@ def _get_real_args_kwargs_bwd(self, args_or_kwargs): @abstractmethod def _get_work_item_key(self) -> UniqueKey: """ - this method control the order of the microbatch to consume + this method control the order of the microbatch to consume """ def is_first_stage(self): @@ -761,7 +807,7 @@ def _consume_work_item_by_phase(self, work_item: WorkItem): kwargs = work_item.kwargs microbatch_id = work_item.microbatch_id forward_only = work_item.forward_only - data_process_func = getattr(self, 'data_process_func', self._default_data_process_func) + data_process_func = getattr(self, "data_process_func", self._default_data_process_func) consume_result = None is_first_stage = self.is_first_stage() @@ -787,10 +833,12 @@ def _consume_work_item_by_phase(self, work_item: WorkItem): else: args_kwargs = self._get_real_args_kwargs_fwd(args) - args_kwargs = pyobj_map(args_kwargs, fn=lambda x: x.to(self.device).detach(), - process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU - args_kwargs = pyobj_map(args_kwargs, fn=lambda x: self.device, - process_types=torch.device) # change devices from last stage to current device + args_kwargs = pyobj_map( + 
args_kwargs, fn=lambda x: x.to(self.device).detach(), process_types=torch.Tensor + ) # torch rpc doesn't support args or rets in GPU + args_kwargs = pyobj_map( + args_kwargs, fn=lambda x: self.device, process_types=torch.device + ) # change devices from last stage to current device args, kwargs = data_process_func(args_kwargs) @@ -851,16 +899,16 @@ def _consume_work_item_by_phase(self, work_item: WorkItem): use_checkpoint = False if not forward_only: - self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_input_args, - stage_input_kwargs, - stage_outputs, - checkpoint=use_checkpoint) - consume_result = pyobj_map(consume_result, fn=lambda x: x.to('cpu'), - process_types=torch.Tensor) # torch rpc doesn't support args or rets in + self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache( + stage_input_args, stage_input_kwargs, stage_outputs, checkpoint=use_checkpoint + ) + consume_result = pyobj_map( + consume_result, fn=lambda x: x.to("cpu"), process_types=torch.Tensor + ) # torch rpc doesn't support args or rets in # if not forward_only, do the backward if not forward_only: - if is_last_stage: # if it is the last stage, trigger backward automatic + if is_last_stage: # if it is the last stage, trigger backward automatic self._begin_backward(microbatch_id) elif phase == Phase.BACKWARD: @@ -872,7 +920,9 @@ def _consume_work_item_by_phase(self, work_item: WorkItem): self.backward_times += 1 self.outstanding -= 1 - assert microbatch_id in self.microbatch_id_to_backward_cache, f"microbatch_id {microbatch_id} not in backward cache" + assert ( + microbatch_id in self.microbatch_id_to_backward_cache + ), f"microbatch_id {microbatch_id} not in backward cache" backward_cache = self.microbatch_id_to_backward_cache.pop(microbatch_id) stage_outputs = backward_cache.stage_outputs @@ -906,8 +956,9 @@ def _consume_work_item_by_phase(self, work_item: WorkItem): filtered_grads.append(grad) stage_outputs = filtered_outputs - grad_tensors = pyobj_map(filtered_grads, fn=lambda x: x.to(self.device), - process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU + grad_tensors = pyobj_map( + filtered_grads, fn=lambda x: x.to(self.device), process_types=torch.Tensor + ) # torch rpc doesn't support args or rets in GPU autograd.backward(stage_outputs, grad_tensors=grad_tensors) # collect grad of input tensor @@ -920,8 +971,8 @@ def _consume_work_item_by_phase(self, work_item: WorkItem): else: consume_result.append(None) consume_result = pyobj_map( - consume_result, fn=lambda x: x.to('cpu'), - process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU + consume_result, fn=lambda x: x.to("cpu"), process_types=torch.Tensor + ) # torch rpc doesn't support args or rets in GPU else: raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}") @@ -929,7 +980,7 @@ def _consume_work_item_by_phase(self, work_item: WorkItem): return consume_result def _get_store_len(self): - return f'work_list:{len(self.work_list)} output_list:{len(self.output_list)} backward_cache:{len(self.microbatch_id_to_backward_cache)} label_cache:{len(self.microbatch_id_to_labels)}' + return f"work_list:{len(self.work_list)} output_list:{len(self.output_list)} backward_cache:{len(self.microbatch_id_to_backward_cache)} label_cache:{len(self.microbatch_id_to_labels)}" def _get_parameter_grad_sum(self): grad_sum = 0 @@ -1014,19 +1065,20 @@ def step(self): class PipelineEngineBase(ABC, nn.Module): - - def __init__(self, - worker_type, - partition_fn: Callable, - stage_num, - 
num_microbatches, - device: str, - use_1F1B=False, - chunk: int = 1, - criterion: Callable = None, - metric: Callable = None, - checkpoint: bool = False, - data_process_func: Callable = None) -> None: + def __init__( + self, + worker_type, + partition_fn: Callable, + stage_num, + num_microbatches, + device: str, + use_1F1B=False, + chunk: int = 1, + criterion: Callable = None, + metric: Callable = None, + checkpoint: bool = False, + data_process_func: Callable = None, + ) -> None: super().__init__() self.worker_type = worker_type self.partition_fn: Callable = partition_fn @@ -1056,12 +1108,12 @@ def _check_argument(self) -> None: data_process_func = self.data_process_func if data_process_func is not None: assert callable(data_process_func), "data_process_func must be a function" - assert '' not in data_process_func.__repr__(), "data_process_func must be a global function" - assert '' not in data_process_func.__repr__(), "data_process_func cannot be a lambda expression" + assert "" not in data_process_func.__repr__(), "data_process_func must be a global function" + assert "" not in data_process_func.__repr__(), "data_process_func cannot be a lambda expression" sig = inspect.signature(data_process_func) - assert len( - sig.parameters - ) == 2, f"length of data_process_func' arguments must be 2, receive {len(sig.parameters)} arguments instead" + assert ( + len(sig.parameters) == 2 + ), f"length of data_process_func' arguments must be 2, receive {len(sig.parameters)} arguments instead" def _get_actual_stage_num(self) -> int: return self.stage_num if self.chunk == 1 else self.virtual_stage_num @@ -1104,19 +1156,33 @@ def _init_worker(self) -> None: partition_id = self.pp_rank_to_module_partition_id[pp_rank] partition_args = (partition_id, chunk, actual_stage_num) rpc_worker_id = self.pp_rank_to_rpc_worker_id[pp_rank] - if device[:4] == 'cuda': - device = f'cuda:{rpc_worker_id}' - self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id, - worker_type, - args=(partition_fn, partition_args, pp_rank, - actual_stage_num, num_microbatches, device, - criterion, metric, checkpoint, data_process_func)) + if device[:4] == "cuda": + device = f"cuda:{rpc_worker_id}" + self.pp_rank_to_worker_rref[pp_rank] = rpc.remote( + rpc_worker_id, + worker_type, + args=( + partition_fn, + partition_args, + pp_rank, + actual_stage_num, + num_microbatches, + device, + criterion, + metric, + checkpoint, + data_process_func, + ), + ) # let each worker know global worker rref (include itself) sync_futs = [] for pp_rank in self.pp_rank_to_worker_rref: - fut = self.pp_rank_to_worker_rref[pp_rank].rpc_async(timeout=0).sync_global_worker_rrefs( - self.pp_rank_to_worker_rref) + fut = ( + self.pp_rank_to_worker_rref[pp_rank] + .rpc_async(timeout=0) + .sync_global_worker_rrefs(self.pp_rank_to_worker_rref) + ) sync_futs.append(fut) for fut in sync_futs: @@ -1157,8 +1223,9 @@ def get_input_pp_ranks(self) -> List[int]: def get_output_pp_ranks(self) -> List[int]: return [self._get_actual_stage_num() - 1] - def _consume_constraint(self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int], - output_pp_ranks: List[int], ret_future): + def _consume_constraint( + self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int], output_pp_ranks: List[int], ret_future + ): actual_stage_num = self._get_actual_stage_num() use_1F1B = self.use_1F1B if microbatch_id >= actual_stage_num: @@ -1206,7 +1273,8 @@ def _ensure_backward(self, forward_only: bool, input_pp_ranks: List[int]): worker_rref = 
self.pp_rank_to_worker_rref[pp_rank] key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD) fut = worker_rref.rpc_async().get_output_by_key( - key, offsets=[]) # only ensure the res exists, no need for real data. + key, offsets=[] + ) # only ensure the res exists, no need for real data. backward_result.append(fut) for fut in backward_result: @@ -1244,11 +1312,14 @@ def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, for if labels is not None and not forward_only: assert hasattr( - self, 'optimizer_class'), "call `initialize_optimizer` to initialize optimizer before forward_backward" + self, "optimizer_class" + ), "call `initialize_optimizer` to initialize optimizer before forward_backward" num_microbatches = self.num_microbatches - assert batch_length >= num_microbatches, "num_microbatches is greater than the size of a batch, which is illegal" + assert ( + batch_length >= num_microbatches + ), "num_microbatches is greater than the size of a batch, which is illegal" microbatch_size = math.ceil(batch_length / num_microbatches) device = self.device @@ -1285,10 +1356,10 @@ def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, for # collect forward result forward_result = self._collect_forward_result(output_pp_ranks, ret_future) - if not forward_only and hasattr(self, 'optimizer_class'): + if not forward_only and hasattr(self, "optimizer_class"): self.step() - self._reset_worker() # reset worker attributes for next batch + self._reset_worker() # reset worker attributes for next batch return forward_result def initialize_optimizer(self, optimizer_class: type, **kwargs): diff --git a/colossalai/legacy/pipeline/rpc/_pipeline_schedule.py b/colossalai/legacy/pipeline/rpc/_pipeline_schedule.py index f53a4835edf2..56da2a954225 100644 --- a/colossalai/legacy/pipeline/rpc/_pipeline_schedule.py +++ b/colossalai/legacy/pipeline/rpc/_pipeline_schedule.py @@ -2,7 +2,6 @@ from typing import Callable, Dict, List import torch -import torch.distributed as dist from torch._C._distributed_rpc import PyRRef from torch.futures import Future @@ -15,7 +14,6 @@ class FillDrainWorker(WorkerBase): - def _get_work_item_key(self) -> UniqueKey: # execute backward first (if backward phase in work_list) num_microbatches = self.num_microbatches @@ -33,29 +31,40 @@ def _get_work_item_key(self) -> UniqueKey: class FillDrainPipelineEngine(PipelineEngineBase): - - def __init__(self, - partition_fn: Callable, - stage_num: int, - num_microbatches: int, - device: str, - chunk: int = 1, - criterion: Callable = None, - metric: Callable = None, - checkpoint: bool = False, - data_process_func: Callable = None) -> None: - + def __init__( + self, + partition_fn: Callable, + stage_num: int, + num_microbatches: int, + device: str, + chunk: int = 1, + criterion: Callable = None, + metric: Callable = None, + checkpoint: bool = False, + data_process_func: Callable = None, + ) -> None: if chunk > 1: - assert num_microbatches % stage_num == 0, \ - "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!" + assert ( + num_microbatches % stage_num == 0 + ), "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!" 
use_1F1B = False - super().__init__(FillDrainWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, - metric, checkpoint, data_process_func) + super().__init__( + FillDrainWorker, + partition_fn, + stage_num, + num_microbatches, + device, + use_1F1B, + chunk, + criterion, + metric, + checkpoint, + data_process_func, + ) class OneFOneBWorker(WorkerBase): - def _get_work_item_key(self) -> UniqueKey: # execute backward first (if backward phase in work_list) pp_rank = self.pp_rank @@ -77,8 +86,7 @@ def _get_work_item_key(self) -> UniqueKey: # change outstanding_range at: # 1. forward times reach actual_stage_num, this is the end of continuous forward # 2. forward times reach num_microbatches, this is the end of 1F1B mode - if not is_last_stage and \ - target_key.phase == Phase.FORWARD: + if not is_last_stage and target_key.phase == Phase.FORWARD: if target_key.microbatch_id == actual_stage_num - 1 and num_microbatches > 2: # Why need num_microbatches > 2 ? Because there is no steady stage when num_microbatches <= 2 outstanding_min = actual_stage_num - pp_rank - 1 @@ -91,30 +99,41 @@ def _get_work_item_key(self) -> UniqueKey: class OneFOneBPipelineEngine(PipelineEngineBase): - - def __init__(self, - partition_fn: Callable, - stage_num: int, - num_microbatches: int, - device: str, - chunk: int = 1, - criterion: Callable = None, - metric: Callable = None, - checkpoint: bool = False, - data_process_func: Callable = None) -> None: - + def __init__( + self, + partition_fn: Callable, + stage_num: int, + num_microbatches: int, + device: str, + chunk: int = 1, + criterion: Callable = None, + metric: Callable = None, + checkpoint: bool = False, + data_process_func: Callable = None, + ) -> None: if chunk > 1: - assert num_microbatches % stage_num == 0, \ - "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!" + assert ( + num_microbatches % stage_num == 0 + ), "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!" 
# assert num_microbatches > stage_num * chunk, "num_microbatches must be greater than stage_num * chunk" use_1F1B = True - super().__init__(OneFOneBWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, - metric, checkpoint, data_process_func) + super().__init__( + OneFOneBWorker, + partition_fn, + stage_num, + num_microbatches, + device, + use_1F1B, + chunk, + criterion, + metric, + checkpoint, + data_process_func, + ) class ChimeraWorker(WorkerBase): - def _get_producer_consumer(self) -> None: rank = self.pp_rank min_pp_rank = (rank // self.actual_stage_num) * self.actual_stage_num @@ -143,11 +162,12 @@ def _get_work_item_key(self) -> UniqueKey: forward_block_size = 1 if self.num_microbatches < stage_num else self.num_microbatches // stage_num forward_block_num = self.forward_times // forward_block_size - if self.forward_times >= real_microbatch_num or \ - ((pp_rank + 1) % stage_num == 0 and forward_block_num > self.backward_times): + if self.forward_times >= real_microbatch_num or ( + (pp_rank + 1) % stage_num == 0 and forward_block_num > self.backward_times + ): target_phase = Phase.BACKWARD target_microbatch_id = self.backward_times - else: # others + else: # others target_phase = Phase.FORWARD target_microbatch_id = self.forward_times @@ -168,7 +188,7 @@ def _initialize_partition(self): # from corresponding up stage pp_rank = self.pp_rank stage_num = self.actual_stage_num - device = self.device + self.device if pp_rank < stage_num: super()._initialize_partition() else: @@ -242,27 +262,38 @@ def _hook_before_step(self): class ChimeraPipelineEngine(PipelineEngineBase): - - def __init__(self, - partition_fn: Callable, - stage_num: int, - num_microbatches: int, - device: str, - criterion: Callable = None, - metric: Callable = None, - checkpoint: bool = False, - data_process_func: Callable = None) -> None: - - assert num_microbatches % stage_num == 0, \ - "In Chimera, num_microbatches must be the multiply of stage_num!" + def __init__( + self, + partition_fn: Callable, + stage_num: int, + num_microbatches: int, + device: str, + criterion: Callable = None, + metric: Callable = None, + checkpoint: bool = False, + data_process_func: Callable = None, + ) -> None: + assert num_microbatches % stage_num == 0, "In Chimera, num_microbatches must be the multiply of stage_num!" 
use_1F1B = False chunk = 1 - super().__init__(ChimeraWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, - metric, checkpoint, data_process_func) - - def _consume_constraint(self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int], - output_pp_ranks: List[int], ret_future): + super().__init__( + ChimeraWorker, + partition_fn, + stage_num, + num_microbatches, + device, + use_1F1B, + chunk, + criterion, + metric, + checkpoint, + data_process_func, + ) + + def _consume_constraint( + self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int], output_pp_ranks: List[int], ret_future + ): pass def _create_pp_rank_to_rpc_worker_id(self) -> None: diff --git a/colossalai/legacy/pipeline/rpc/utils.py b/colossalai/legacy/pipeline/rpc/utils.py index d1033fbde920..808de301a2a0 100644 --- a/colossalai/legacy/pipeline/rpc/utils.py +++ b/colossalai/legacy/pipeline/rpc/utils.py @@ -1,7 +1,7 @@ import argparse import os import warnings -from typing import Any, Callable, Dict, List, Tuple, Type, Union +from typing import Any, Callable, Tuple, Type, Union import torch import torch.distributed.rpc as rpc @@ -61,7 +61,7 @@ def get_batch_lengths(batch): def split_batch(batch: Any, start, stop, device: str): - if device == 'cuda': + if device == "cuda": fn = lambda x: x[start:stop].cuda() else: fn = lambda x: x[start:stop] @@ -102,8 +102,8 @@ def get_real_args_kwargs(args_or_kwargs): def run_worker(rank, args, master_func): - os.environ['MASTER_ADDR'] = args.master_addr - os.environ['MASTER_PORT'] = args.master_port + os.environ["MASTER_ADDR"] = args.master_addr + os.environ["MASTER_PORT"] = args.master_port device = args.device world_size = args.world_size @@ -112,15 +112,17 @@ def run_worker(rank, args, master_func): num_worker_threads = args.num_worker_threads host = args.master_addr port = args.master_port - backend = 'nccl' if device == 'cuda' else 'gloo' + backend = "nccl" if device == "cuda" else "gloo" launch(dict(), rank, world_size, host, int(port), backend, verbose=False) - ppg.set_global_info(rank=rank, - world_size=world_size, - dp_degree=dp_degree, - tp_degree=tp_degree, - num_worker_threads=num_worker_threads, - device=device) + ppg.set_global_info( + rank=rank, + world_size=world_size, + dp_degree=dp_degree, + tp_degree=tp_degree, + num_worker_threads=num_worker_threads, + device=device, + ) ppg.args = args # in rpc mode, only rank 0 is needed to be coded if rank == 0: @@ -139,17 +141,17 @@ def rpc_run(args, master_func): def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('--epoch', type=int, default=1) - parser.add_argument('--world_size', type=int, default=2) - parser.add_argument('--batch_size', type=int, default=16) - parser.add_argument('--dp_degree', type=int, default=1) - parser.add_argument('--tp_degree', type=int, default=1) - parser.add_argument('--num_microbatches', type=int, default=2) - parser.add_argument('--chunk', type=int, default=1) - parser.add_argument('--use_checkpoint', action='store_true') - parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'RMSprop'], default='SGD') - parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda') - parser.add_argument('--master_addr', type=str, default='localhost') - parser.add_argument('--master_port', type=str, default='29020') - parser.add_argument('--num_worker_threads', type=int, default=128) + parser.add_argument("--epoch", type=int, default=1) + parser.add_argument("--world_size", type=int, default=2) + 
parser.add_argument("--batch_size", type=int, default=16) + parser.add_argument("--dp_degree", type=int, default=1) + parser.add_argument("--tp_degree", type=int, default=1) + parser.add_argument("--num_microbatches", type=int, default=2) + parser.add_argument("--chunk", type=int, default=1) + parser.add_argument("--use_checkpoint", action="store_true") + parser.add_argument("--optimizer", type=str, choices=["SGD", "Adam", "RMSprop"], default="SGD") + parser.add_argument("--device", type=str, choices=["cpu", "cuda"], default="cuda") + parser.add_argument("--master_addr", type=str, default="localhost") + parser.add_argument("--master_port", type=str, default="29020") + parser.add_argument("--num_worker_threads", type=int, default=128) return parser.parse_args() diff --git a/colossalai/legacy/pipeline/utils.py b/colossalai/legacy/pipeline/utils.py index be8428692756..182af677c047 100644 --- a/colossalai/legacy/pipeline/utils.py +++ b/colossalai/legacy/pipeline/utils.py @@ -38,8 +38,7 @@ def _binary_partition(weights: List, start: int, end: int): def _heap_addition(weights: List, intervals: int, add_cnt: int): - """ - """ + """ """ def _heap_push(heap, st, ed): value = weights[ed - 1] @@ -113,8 +112,9 @@ def _binary_search(weights, num): def partition_uniform(num_items, pipeline_parallel_size, num_chunks): - assert num_items % num_chunks == 0, \ - "Layer length should be divided by the number of chunks, otherwise parameter method is recommended" + assert ( + num_items % num_chunks == 0 + ), "Layer length should be divided by the number of chunks, otherwise parameter method is recommended" logger = get_dist_logger() parts = [[] for _ in range(pipeline_parallel_size)] @@ -162,7 +162,7 @@ def build_kwargs_for_module(function, input_tensor, kw_dict): elif isinstance(input_tensor, torch.Tensor): kwargs_offset = 1 elif isinstance(input_tensor, (tuple, OrderedDict)): - #assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.' + # assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.' # Huggingface will take their own structures based on OrderedDict as the output # between layers so we've to close this check. kwargs_offset = len(input_tensor) @@ -204,21 +204,21 @@ def foo(attention_mask=None): kwargs[k] = rst return input_tensor if isinstance(input_tensor, tuple): - assert len(input_tensor) > 0, f'input_tensor should not be empty, when kw_dict is None.' + assert len(input_tensor) > 0, f"input_tensor should not be empty, when kw_dict is None." sig = inspect.signature(func) func_args_num = len(sig.parameters) assert func_args_num <= len( - input_tensor), f'func requires {func_args_num} arguments, but input_tensors only have {len(input_tensor)}.' + input_tensor + ), f"func requires {func_args_num} arguments, but input_tensors only have {len(input_tensor)}." if func_args_num < len(input_tensor): return func(*input_tensor[:func_args_num]) else: return func(*input_tensor) - assert isinstance(input_tensor, torch.Tensor), 'input_tensor should be a type of torch.Tensor or tuple.' + assert isinstance(input_tensor, torch.Tensor), "input_tensor should be a type of torch.Tensor or tuple." return func(input_tensor) def exec_funcs_with_kwargs(func_dict, func_key, input_tensor, kwargs): - assert func_key in func_dict, f"{func_key} is not in the function_dict." 
funcs_to_exec = func_dict[func_key] if isinstance(funcs_to_exec, list): @@ -243,7 +243,7 @@ def call_module(module, args=None, kwargs=None): forward_func = module.forward sig = inspect.signature(forward_func) param_nums = len(sig.parameters) - feed_nums = len(args) + len(kwargs) + len(args) + len(kwargs) args_needed_nums = param_nums - len(kwargs) args_needed = args[:args_needed_nums] if isinstance(module, CheckpointModule): @@ -256,17 +256,17 @@ def call_module(module, args=None, kwargs=None): def customized_partition(exec_seq): - ''' + """ This function will analyze the exec_seq. In the exec_seq, users will use 'SPLIT_NODE' as an annotation to note the partition point. - ''' + """ customized_parts = {} start = 0 stop = 0 rank = 0 for element in exec_seq: if isinstance(element, str): - if element == 'SPLIT_NODE': + if element == "SPLIT_NODE": customized_parts[rank] = [(start, stop)] start = stop rank += 1 diff --git a/colossalai/legacy/registry/registry.py b/colossalai/legacy/registry/registry.py index 50d6b74c5617..43644f8a9e73 100644 --- a/colossalai/legacy/registry/registry.py +++ b/colossalai/legacy/registry/registry.py @@ -59,7 +59,7 @@ def get_module(self, module_name: str): for lib in self._third_party_lib: if hasattr(lib, module_name): return getattr(lib, module_name) - raise NameError(f'Module {module_name} not found in the registry {self.name}') + raise NameError(f"Module {module_name} not found in the registry {self.name}") def has(self, module_name: str): """Searches for a module with name `module_name` and returns a boolean value indicating diff --git a/colossalai/legacy/tensor/__init__.py b/colossalai/legacy/tensor/__init__.py index d3278bf1e420..a34870eba068 100644 --- a/colossalai/legacy/tensor/__init__.py +++ b/colossalai/legacy/tensor/__init__.py @@ -6,12 +6,12 @@ from .tensor_spec import ColoTensorSpec __all__ = [ - 'ComputePattern', - 'ComputeSpec', - 'distspec', - 'DistSpecManager', - 'ProcessGroup', - 'ColoTensorSpec', - 'ShardSpec', - 'ReplicaSpec', + "ComputePattern", + "ComputeSpec", + "distspec", + "DistSpecManager", + "ProcessGroup", + "ColoTensorSpec", + "ShardSpec", + "ReplicaSpec", ] diff --git a/colossalai/legacy/tensor/compute_spec.py b/colossalai/legacy/tensor/compute_spec.py index 12f8f36bc613..820aafab687f 100644 --- a/colossalai/legacy/tensor/compute_spec.py +++ b/colossalai/legacy/tensor/compute_spec.py @@ -23,7 +23,7 @@ def __init__(self, compute_pattern: ComputePattern) -> None: self.output_replicate = True def __repr__(self): - return f'ComputeSpec(pattern={self.compute_pattern}, replicate_output={self.output_replicate})' + return f"ComputeSpec(pattern={self.compute_pattern}, replicate_output={self.output_replicate})" def set_output_replicate(self, flag: bool = True): self.output_replicate = flag diff --git a/colossalai/legacy/tensor/const.py b/colossalai/legacy/tensor/const.py index 356e8ecc885a..cbc2b29d66a8 100644 --- a/colossalai/legacy/tensor/const.py +++ b/colossalai/legacy/tensor/const.py @@ -3,4 +3,4 @@ class TensorType(Enum): MODEL = 0 - NONMODEL = 1 # mainly activations + NONMODEL = 1 # mainly activations diff --git a/colossalai/legacy/tensor/dist_spec_mgr.py b/colossalai/legacy/tensor/dist_spec_mgr.py index d97308b04bef..3942b5b7a33c 100644 --- a/colossalai/legacy/tensor/dist_spec_mgr.py +++ b/colossalai/legacy/tensor/dist_spec_mgr.py @@ -20,14 +20,12 @@ def divide(numerator, denominator): Returns: int: the result of exact division. 
""" - assert denominator != 0, 'denominator can not be zero' - assert numerator % denominator == 0, \ - '{} is not divisible by {}'.format(numerator, denominator) + assert denominator != 0, "denominator can not be zero" + assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator) return numerator // denominator class TransformDistSpec(torch.autograd.Function): - @staticmethod def forward(ctx, tensor, old_dist_spec, dist_spec, pg, forward_trans_func, backward_trans_func): ctx.old_dist_spec = old_dist_spec @@ -38,12 +36,17 @@ def forward(ctx, tensor, old_dist_spec, dist_spec, pg, forward_trans_func, backw @staticmethod def backward(ctx, grad_outputs): - return ctx.backward_trans_func(grad_outputs, ctx.dist_spec, ctx.old_dist_spec, - ctx.pg), None, None, None, None, None + return ( + ctx.backward_trans_func(grad_outputs, ctx.dist_spec, ctx.old_dist_spec, ctx.pg), + None, + None, + None, + None, + None, + ) class DistSpecManager: - _use_autograd_function: bool = True @staticmethod @@ -51,8 +54,9 @@ def _sanity_check(old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> None: pass @staticmethod - def _shard_as(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, - pg: ProcessGroup) -> torch.Tensor: + def _shard_as( + tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup + ) -> torch.Tensor: """_shard_as: shard the tensor w.r.t a distributed specification. Assuming the tensor passed in is a global (replicated) tensor. Args: @@ -62,7 +66,9 @@ def _shard_as(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSp Returns: torch.Tensor: a torch tensor after sharded. """ - assert old_dist_spec.placement.value == 'r', f"The old_dist_spec of DistSpecManager._shard_as must be REPLICATE!" + assert ( + old_dist_spec.placement.value == "r" + ), f"The old_dist_spec of DistSpecManager._shard_as must be REPLICATE!" DistSpecManager._sanity_check(old_dist_spec, dist_spec) chunk = tensor @@ -86,9 +92,9 @@ def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec, pg: ProcessGroup) -> Returns: torch.Tensor: a replicated tensor. """ - assert old_dist_spec.placement.value == 's', f"The old_dist_spec of DistSpecManager._gather must be SHARD!" + assert old_dist_spec.placement.value == "s", f"The old_dist_spec of DistSpecManager._gather must be SHARD!" is_cpu_tensor = False - if tensor.device.type == 'cpu': + if tensor.device.type == "cpu": # pytorch lower than 1.11 dose not support gather a cpu tensor. # Therefore, we transfer tensor to GPU before gather. 
saved_dev = tensor.device @@ -96,14 +102,14 @@ def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec, pg: ProcessGroup) -> is_cpu_tensor = True buffer = [torch.empty_like(tensor) for _ in range(pg.tp_world_size())] - assert tensor.device.type == 'cuda' + assert tensor.device.type == "cuda" dist.all_gather(buffer, tensor, group=pg.tp_process_group()) for i in range(len(old_dist_spec.dims) - 1, -1, -1): new_buffer = [] dim = old_dist_spec.dims[i] num_parts = old_dist_spec.num_partitions[i] for start in range(0, len(buffer), num_parts): - new_buffer.append(torch.cat(buffer[start:start + num_parts], dim)) + new_buffer.append(torch.cat(buffer[start : start + num_parts], dim)) buffer = new_buffer assert len(buffer) == 1 @@ -112,15 +118,17 @@ def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec, pg: ProcessGroup) -> return buffer[0] @staticmethod - def _all_to_all(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, - pg: ProcessGroup) -> torch.Tensor: + def _all_to_all( + tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup + ) -> torch.Tensor: world_size = pg.tp_world_size() if world_size == 1: return tensor - assert tensor.device.type == "cuda", \ - "Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll " \ + assert tensor.device.type == "cuda", ( + "Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll " f"collective function, however, we got {tensor.device.type} device" + ) gather_dim = old_dist_spec.dims[0] scatter_dim = dist_spec.dims[0] @@ -164,8 +172,9 @@ def _s2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, p return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec, pg) @staticmethod - def handle_trans_spec(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, - pg: ProcessGroup) -> torch.Tensor: + def handle_trans_spec( + tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup + ) -> torch.Tensor: assert isinstance(old_dist_spec, _DistSpec), f"{type(old_dist_spec)} should be _DistSpec" assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)} should be _DistSpec" @@ -174,7 +183,7 @@ def handle_trans_spec(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: (DistPlacementPattern.REPLICATE, DistPlacementPattern.REPLICATE): DistSpecManager._r2r, (DistPlacementPattern.REPLICATE, DistPlacementPattern.SHARD): DistSpecManager._r2s, (DistPlacementPattern.SHARD, DistPlacementPattern.REPLICATE): DistSpecManager._s2r, - (DistPlacementPattern.SHARD, DistPlacementPattern.SHARD): DistSpecManager._s2s + (DistPlacementPattern.SHARD, DistPlacementPattern.SHARD): DistSpecManager._s2s, } forward_trans_handle = trans_funcs[trans_func_key] @@ -183,8 +192,9 @@ def handle_trans_spec(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: backward_trans_handle = trans_funcs[(dist_spec.placement, old_dist_spec.placement)] - return TransformDistSpec.apply(tensor, old_dist_spec, dist_spec, pg, forward_trans_handle, - backward_trans_handle) + return TransformDistSpec.apply( + tensor, old_dist_spec, dist_spec, pg, forward_trans_handle, backward_trans_handle + ) @staticmethod @contextmanager diff --git a/colossalai/legacy/tensor/distspec.py b/colossalai/legacy/tensor/distspec.py index 3a09f1426e31..efef9904ec10 100644 --- a/colossalai/legacy/tensor/distspec.py +++ b/colossalai/legacy/tensor/distspec.py @@ -1,12 +1,12 @@ from enum import Enum from typing import List -__all__ = ['ReplicaSpec', 
'ShardSpec'] +__all__ = ["ReplicaSpec", "ShardSpec"] class DistPlacementPattern(Enum): - REPLICATE = 'r' - SHARD = 's' + REPLICATE = "r" + SHARD = "s" class _DistSpec: @@ -25,7 +25,6 @@ class _DistSpec: """ def __init__(self, dist_placement_pattern: DistPlacementPattern, **meta_info): - self.placement = dist_placement_pattern for k, v in meta_info.items(): setattr(self, k, v) @@ -34,15 +33,15 @@ def __eq__(self, other: "_DistSpec") -> bool: if dir(self) != dir(other): return False for attr in dir(self): - if not attr.startswith('__') and getattr(self, attr) != getattr(other, attr): + if not attr.startswith("__") and getattr(self, attr) != getattr(other, attr): return False return True def __repr__(self) -> str: attr_list = [] for attr in dir(self): - if not attr.startswith('__'): - attr_list.append(f'{attr}={str(getattr(self, attr))}') + if not attr.startswith("__"): + attr_list.append(f"{attr}={str(getattr(self, attr))}") attr_str = ", ".join(attr_list) return "DistSpec(" + attr_str + ")" diff --git a/colossalai/legacy/tensor/process_group.py b/colossalai/legacy/tensor/process_group.py index 8d2e9a616d76..ec6043163336 100644 --- a/colossalai/legacy/tensor/process_group.py +++ b/colossalai/legacy/tensor/process_group.py @@ -7,13 +7,12 @@ class PyTorchProcessGroupDict(metaclass=SingletonMeta): - def __init__(self): # distributed settings # use this dict to record all Pytorch ProcessGroups self.dict = {} # set a distributed logger - self.logger = get_dist_logger('ProcessGroup') + self.logger = get_dist_logger("ProcessGroup") def log_pg_init(self, rank_list: List[int], backend: str): str_list = ["Pytorch ProcessGroup Init:"] @@ -21,9 +20,8 @@ def log_pg_init(self, rank_list: List[int], backend: str): str_list.append(f"ranks: {rank_list}") self.logger.info("\n\t".join(str_list), ranks=[0]) - def get(self, rank_list: List[int], backend: str = 'nccl'): - """Reuse Pytorch ProcessGroup when such a group is initialized - """ + def get(self, rank_list: List[int], backend: str = "nccl"): + """Reuse Pytorch ProcessGroup when such a group is initialized""" # we need to convert the passed list to a tuple # since List is unhashable processgroup_key = (backend, tuple(rank_list)) @@ -51,11 +49,13 @@ class ProcessGroup: dp_degree: Optional[int], data parallelism degree. How many processes are inside a dp process group. . default None means len(ranks). 
""" - def __init__(self, - rank: Optional[int] = None, - ranks: Optional[List[int]] = None, - tp_degree: Optional[int] = None, - dp_degree: Optional[int] = None) -> None: + def __init__( + self, + rank: Optional[int] = None, + ranks: Optional[List[int]] = None, + tp_degree: Optional[int] = None, + dp_degree: Optional[int] = None, + ) -> None: if not torch.distributed.is_initialized(): self.is_init = False return @@ -64,13 +64,13 @@ def __init__(self, self._rank = torch.distributed.get_rank() if rank is not None: - assert self._rank == rank # make sure that the global rank is correct + assert self._rank == rank # make sure that the global rank is correct if ranks is None: self._rank_list = list(range(torch.distributed.get_world_size())) else: self._rank_list = ranks - self._rank_list.sort() # ensure that the list is in order + self._rank_list.sort() # ensure that the list is in order self._world_size = len(self._rank_list) @@ -79,31 +79,36 @@ def __init__(self, self._tp_degree = 1 elif dp_degree and not tp_degree: self._dp_degree = dp_degree - assert self._world_size % self._dp_degree == 0, f"DP degree {dp_degree} should be divisible by {self._world_size} hen DP degree is None" + assert ( + self._world_size % self._dp_degree == 0 + ), f"DP degree {dp_degree} should be divisible by {self._world_size} hen DP degree is None" self._tp_degree = self._world_size // dp_degree elif not dp_degree and tp_degree: self._tp_degree = tp_degree - assert self._world_size % self._tp_degree == 0, f"TP degree {tp_degree} should be divisible by {self._world_size} when DP degree is None" + assert ( + self._world_size % self._tp_degree == 0 + ), f"TP degree {tp_degree} should be divisible by {self._world_size} when DP degree is None" self._dp_degree = self._world_size // tp_degree else: self._dp_degree = dp_degree self._tp_degree = tp_degree - assert self._dp_degree * self._tp_degree == self._world_size, \ - f"the world size {self._world_size} should equals to the product of DP degree {self._dp_degree}" \ + assert self._dp_degree * self._tp_degree == self._world_size, ( + f"the world size {self._world_size} should equals to the product of DP degree {self._dp_degree}" f"and TP degree {self._tp_degree}" + ) self._tp_rank_list = None self._dp_rank_list = None for i in range(self._dp_degree): i_tp_list = [self._rank_list[i * self._tp_degree + j] for j in range(self._tp_degree)] - PYTORCHPGDICT_.get(i_tp_list, 'nccl') + PYTORCHPGDICT_.get(i_tp_list, "nccl") if self._rank in i_tp_list: self._tp_rank_list = i_tp_list for j in range(self._tp_degree): j_dp_list = [self._rank_list[i * self._tp_degree + j] for i in range(self._dp_degree)] - PYTORCHPGDICT_.get(j_dp_list, 'nccl') + PYTORCHPGDICT_.get(j_dp_list, "nccl") if self._rank in j_dp_list: self._dp_rank_list = j_dp_list @@ -119,11 +124,11 @@ def set_cpu_groups(self): for i in range(self._dp_degree): i_tp_list = [self._rank_list[i * self._tp_degree + j] for j in range(self._tp_degree)] - PYTORCHPGDICT_.get(i_tp_list, 'gloo') + PYTORCHPGDICT_.get(i_tp_list, "gloo") for j in range(self._tp_degree): j_dp_list = [self._rank_list[i * self._tp_degree + j] for i in range(self._dp_degree)] - PYTORCHPGDICT_.get(j_dp_list, 'gloo') + PYTORCHPGDICT_.get(j_dp_list, "gloo") self._has_cpu_groups = True @@ -145,7 +150,7 @@ def __repr__(self): else: return "ProcessGroup not initialized" - def __eq__(self, obj: 'ProcessGroup') -> bool: + def __eq__(self, obj: "ProcessGroup") -> bool: if not isinstance(obj, ProcessGroup): return False if self._rank != obj._rank: @@ -260,7 +265,7 @@ def 
dp_process_group(self): Returns: `torch._C._distributed_c10d.ProcessGroup`: the pytorch DP process group. """ - return PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl') + return PYTORCHPGDICT_.get(self._dp_rank_list, "nccl") def tp_process_group(self): """tp_process_group @@ -270,7 +275,7 @@ def tp_process_group(self): Returns: `torch._C._distributed_c10d.ProcessGroup`: the pytorch TP process group. """ - return PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl') + return PYTORCHPGDICT_.get(self._tp_rank_list, "nccl") def cpu_dp_process_group(self): """cpu_dp_process_group @@ -283,7 +288,7 @@ def cpu_dp_process_group(self): `torch._C._distributed_c10d.ProcessGroup`: the pytorch DP process group. """ assert self._has_cpu_groups - return PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo') + return PYTORCHPGDICT_.get(self._dp_rank_list, "gloo") def cpu_tp_process_group(self): """cpu_tp_process_group @@ -296,7 +301,7 @@ def cpu_tp_process_group(self): `torch._C._distributed_c10d.ProcessGroup`: the pytorch TP process group. """ assert self._has_cpu_groups - return PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo') + return PYTORCHPGDICT_.get(self._tp_rank_list, "gloo") def get_ranks_in_dp(self) -> List[int]: """get_ranks_in_dp diff --git a/colossalai/legacy/tensor/tensor_spec.py b/colossalai/legacy/tensor/tensor_spec.py index aa792e507639..5bdd384e5e15 100644 --- a/colossalai/legacy/tensor/tensor_spec.py +++ b/colossalai/legacy/tensor/tensor_spec.py @@ -9,12 +9,13 @@ @dataclass class ColoTensorSpec: - """ ColoTensorSpec + """ColoTensorSpec A data class for specifications of the `ColoTensor`. It contains attributes of `ProcessGroup`, `_DistSpec`, `ComputeSpec`. The latter two attributes are optional. If not set, they are default value is `Replicate()` and `None`. """ + pg: ProcessGroup dist_attr: Optional[_DistSpec] = _DistSpec(DistPlacementPattern.REPLICATE) compute_attr: Optional[ComputeSpec] = None diff --git a/colossalai/legacy/trainer/__init__.py b/colossalai/legacy/trainer/__init__.py index 84e53dc4e87a..e4fddc7c1c9f 100644 --- a/colossalai/legacy/trainer/__init__.py +++ b/colossalai/legacy/trainer/__init__.py @@ -1,3 +1,3 @@ from ._trainer import Trainer -__all__ = ['Trainer'] +__all__ = ["Trainer"] diff --git a/colossalai/legacy/trainer/_trainer.py b/colossalai/legacy/trainer/_trainer.py index 1cb99fcc90ed..46e708622237 100644 --- a/colossalai/legacy/trainer/_trainer.py +++ b/colossalai/legacy/trainer/_trainer.py @@ -151,7 +151,7 @@ def _call_hooks(self, func, output=None): @staticmethod def _should_display_progress(display_progress: bool): """Only display progress on DP rank 0, TP rank 0 and PP last rank""" - return (display_progress and is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()) + return display_progress and is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage() def _train_epoch( self, @@ -293,8 +293,7 @@ def fit( assert isinstance(hooks, list), f"expected argument hooks be to list, but got {type(hooks)}" for hook in hooks: - assert isinstance(hook, BaseHook), \ - f'expected the hook to be of type BaseHook, but got {type(hook)}' + assert isinstance(hook, BaseHook), f"expected the hook to be of type BaseHook, but got {type(hook)}" else: hooks = [] self.hooks = hooks diff --git a/colossalai/legacy/trainer/hooks/__init__.py b/colossalai/legacy/trainer/hooks/__init__.py index bf9cc6421b67..290aeb64a04d 100644 --- a/colossalai/legacy/trainer/hooks/__init__.py +++ b/colossalai/legacy/trainer/hooks/__init__.py @@ -11,7 +11,16 @@ from ._metric_hook import AccuracyHook, LossHook, 
MetricHook, ThroughputHook __all__ = [ - 'BaseHook', 'MetricHook', 'LossHook', 'AccuracyHook', 'LogMetricByEpochHook', 'TensorboardHook', - 'LogTimingByEpochHook', 'LogMemoryByEpochHook', 'LRSchedulerHook', 'ThroughputHook', 'LogMetricByStepHook', - 'SaveCheckpointHook' + "BaseHook", + "MetricHook", + "LossHook", + "AccuracyHook", + "LogMetricByEpochHook", + "TensorboardHook", + "LogTimingByEpochHook", + "LogMemoryByEpochHook", + "LRSchedulerHook", + "ThroughputHook", + "LogMetricByStepHook", + "SaveCheckpointHook", ] diff --git a/colossalai/legacy/trainer/hooks/_base_hook.py b/colossalai/legacy/trainer/hooks/_base_hook.py index cca8e081ec88..fc883134203f 100644 --- a/colossalai/legacy/trainer/hooks/_base_hook.py +++ b/colossalai/legacy/trainer/hooks/_base_hook.py @@ -18,24 +18,16 @@ def __init__(self, priority: int) -> None: self.priority = priority def after_hook_is_attached(self, trainer): - """Actions after hooks are attached to trainer. - """ - pass + """Actions after hooks are attached to trainer.""" def before_train(self, trainer): - """Actions before training. - """ - pass + """Actions before training.""" def after_train(self, trainer): - """Actions after training. - """ - pass + """Actions after training.""" def before_train_iter(self, trainer): - """Actions before running a training iteration. - """ - pass + """Actions before running a training iteration.""" def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor): """Actions after running a training iteration. @@ -46,42 +38,27 @@ def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor) label (:class:`torch.Tensor`): Labels of the input data. loss (:class:`torch.Tensor`): Loss between the output and input data. """ - pass def before_train_epoch(self, trainer): - """Actions before starting a training epoch. - """ - pass + """Actions before starting a training epoch.""" def after_train_epoch(self, trainer): - """Actions after finishing a training epoch. - """ - pass + """Actions after finishing a training epoch.""" def before_test(self, trainer): - """Actions before evaluation. - """ - pass + """Actions before evaluation.""" def after_test(self, trainer): - """Actions after evaluation. - """ - pass + """Actions after evaluation.""" def before_test_epoch(self, trainer): - """Actions before starting a testing epoch. - """ - pass + """Actions before starting a testing epoch.""" def after_test_epoch(self, trainer): - """Actions after finishing a testing epoch. - """ - pass + """Actions after finishing a testing epoch.""" def before_test_iter(self, trainer): - """Actions before running a testing iteration. - """ - pass + """Actions before running a testing iteration.""" def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor): """Actions after running a testing iteration. @@ -92,7 +69,6 @@ def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor): label (:class:`torch.Tensor`): Labels of the input data loss (:class:`torch.Tensor`): Loss between the output and input data """ - pass def init_runner_states(self, trainer, key, val): """Initializes trainer's state. diff --git a/colossalai/legacy/trainer/hooks/_checkpoint_hook.py b/colossalai/legacy/trainer/hooks/_checkpoint_hook.py index cda10030bf65..50c80759867e 100644 --- a/colossalai/legacy/trainer/hooks/_checkpoint_hook.py +++ b/colossalai/legacy/trainer/hooks/_checkpoint_hook.py @@ -27,12 +27,14 @@ class SaveCheckpointHook(BaseHook): depend on the hooks order in the hook list. 
""" - def __init__(self, - interval: int = 1, - checkpoint_dir: str = None, - model: torch.nn.Module = None, - save_by_iter: bool = False, - priority: int = 10): + def __init__( + self, + interval: int = 1, + checkpoint_dir: str = None, + model: torch.nn.Module = None, + save_by_iter: bool = False, + priority: int = 10, + ): super().__init__(priority=priority) self.interval = interval self.checkpoint_dir = checkpoint_dir @@ -52,22 +54,23 @@ def after_hook_is_attached(self, trainer): self.model = self.model if self.model is not None else trainer.engine.model def after_train_iter(self, trainer, output, label, loss): - """Saves the model after a training iter. - """ + """Saves the model after a training iter.""" # save by interval if self.save_by_iter and trainer.cur_step % self.interval == 0: - save_checkpoint(self.checkpoint_dir, trainer.cur_epoch, self.model, trainer.engine.optimizer, - self._lr_scheduler) - self.logger.info(f'checkpoint for iteration {trainer.cur_step} is saved to {self.checkpoint_dir}', - ranks=[0]) + save_checkpoint( + self.checkpoint_dir, trainer.cur_epoch, self.model, trainer.engine.optimizer, self._lr_scheduler + ) + self.logger.info( + f"checkpoint for iteration {trainer.cur_step} is saved to {self.checkpoint_dir}", ranks=[0] + ) else: pass def after_train_epoch(self, trainer): - """Saves the model after a training epoch. - """ + """Saves the model after a training epoch.""" # save by interval if trainer.cur_epoch % self.interval == 0: - save_checkpoint(self.checkpoint_dir, trainer.cur_epoch, self.model, trainer.engine.optimizer, - self._lr_scheduler) - self.logger.info(f'checkpoint for epoch {trainer.cur_epoch} is saved to {self.checkpoint_dir}', ranks=[0]) + save_checkpoint( + self.checkpoint_dir, trainer.cur_epoch, self.model, trainer.engine.optimizer, self._lr_scheduler + ) + self.logger.info(f"checkpoint for epoch {trainer.cur_epoch} is saved to {self.checkpoint_dir}", ranks=[0]) diff --git a/colossalai/legacy/trainer/hooks/_commons_.py b/colossalai/legacy/trainer/hooks/_commons_.py index 4923b8cba6c0..18da38298704 100644 --- a/colossalai/legacy/trainer/hooks/_commons_.py +++ b/colossalai/legacy/trainer/hooks/_commons_.py @@ -3,7 +3,7 @@ def _format_number(val, prec=5): if isinstance(val, float): - return f'{val:.{prec}g}' + return f"{val:.{prec}g}" elif torch.is_tensor(val) and torch.is_floating_point(val): - return f'{val.item():.{prec}g}' + return f"{val.item():.{prec}g}" return val diff --git a/colossalai/legacy/trainer/hooks/_log_hook.py b/colossalai/legacy/trainer/hooks/_log_hook.py index b1a398ce7f71..c1cf0ca5228b 100644 --- a/colossalai/legacy/trainer/hooks/_log_hook.py +++ b/colossalai/legacy/trainer/hooks/_log_hook.py @@ -51,20 +51,20 @@ def __init__(self, priority: int = 10): super().__init__(priority) def after_train_iter(self, trainer, *args): - trainer.states['step_metrics'] = dict() - for metric_name, metric_calculator in trainer.states['metrics']['train'].items(): + trainer.states["step_metrics"] = dict() + for metric_name, metric_calculator in trainer.states["metrics"]["train"].items(): if isinstance(metric_calculator, ThroughputMetric): - trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_info() + trainer.states["step_metrics"][metric_name.lower()] = metric_calculator.get_last_step_info() else: - trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_value() + trainer.states["step_metrics"][metric_name.lower()] = metric_calculator.get_last_step_value() def after_test_iter(self, 
trainer, *args): - trainer.states['step_metrics'] = dict() - for metric_name, metric_calculator in trainer.states['metrics']['test'].items(): + trainer.states["step_metrics"] = dict() + for metric_name, metric_calculator in trainer.states["metrics"]["test"].items(): if isinstance(metric_calculator, ThroughputMetric): - trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_info() + trainer.states["step_metrics"][metric_name.lower()] = metric_calculator.get_last_step_info() else: - trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_value() + trainer.states["step_metrics"][metric_name.lower()] = metric_calculator.get_last_step_value() @HOOKS.register_module @@ -85,24 +85,24 @@ def __init__(self, logger, interval: int = 1, priority: int = 10) -> None: def _get_str(self, trainer, mode): msg = [] - for metric_name, metric_calculator in trainer.states['metrics'][mode].items(): - msg.append(f'{metric_name} = {_format_number(metric_calculator.get_accumulated_value())}') - msg = ' | '.join(msg) + for metric_name, metric_calculator in trainer.states["metrics"][mode].items(): + msg.append(f"{metric_name} = {_format_number(metric_calculator.get_accumulated_value())}") + msg = " | ".join(msg) return msg def after_train_epoch(self, trainer): if self._is_epoch_to_log(trainer): - msg = self._get_str(trainer=trainer, mode='train') + msg = self._get_str(trainer=trainer, mode="train") if self._is_rank_to_log: - self.logger.info(f'[Epoch {trainer.cur_epoch} / Train]: {msg}') + self.logger.info(f"[Epoch {trainer.cur_epoch} / Train]: {msg}") # f'Training - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}') def after_test_epoch(self, trainer): if self._is_epoch_to_log(trainer): - msg = self._get_str(trainer=trainer, mode='test') + msg = self._get_str(trainer=trainer, mode="test") if self._is_rank_to_log: - self.logger.info(f'[Epoch {trainer.cur_epoch} / Test]: {msg}') + self.logger.info(f"[Epoch {trainer.cur_epoch} / Test]: {msg}") # f'Testing - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}') @@ -145,8 +145,11 @@ def __init__( self._is_valid_rank_to_log = True # check for - if gpc.is_initialized(ParallelMode.PIPELINE) and \ - not gpc.is_last_rank(ParallelMode.PIPELINE) and self._is_valid_rank_to_log: + if ( + gpc.is_initialized(ParallelMode.PIPELINE) + and not gpc.is_last_rank(ParallelMode.PIPELINE) + and self._is_valid_rank_to_log + ): raise ValueError("Tensorboard hook can only log on the last rank of pipeline process group") if self._is_valid_rank_to_log: @@ -157,38 +160,38 @@ def __init__( rank = 0 # create workspace - log_dir = osp.join(log_dir, f'{parallel_mode}_rank_{rank}') + log_dir = osp.join(log_dir, f"{parallel_mode}_rank_{rank}") os.makedirs(log_dir, exist_ok=True) - self.writer = SummaryWriter(log_dir=log_dir, filename_suffix=f'_rank_{rank}') + self.writer = SummaryWriter(log_dir=log_dir, filename_suffix=f"_rank_{rank}") def _log_by_iter(self, trainer, mode: str): - for metric_name, metric_calculator in trainer.states['metrics'][mode].items(): + for metric_name, metric_calculator in trainer.states["metrics"][mode].items(): if metric_calculator.epoch_only: continue val = metric_calculator.get_last_step_value() if self._is_valid_rank_to_log: - self.writer.add_scalar(f'{metric_name}/{mode}', val, trainer.cur_step) + self.writer.add_scalar(f"{metric_name}/{mode}", val, trainer.cur_step) def _log_by_epoch(self, trainer, mode: str): - for metric_name, metric_calculator in trainer.states['metrics'][mode].items(): 
+ for metric_name, metric_calculator in trainer.states["metrics"][mode].items(): if metric_calculator.epoch_only: val = metric_calculator.get_accumulated_value() if self._is_valid_rank_to_log: - self.writer.add_scalar(f'{metric_name}/{mode}', val, trainer.cur_step) + self.writer.add_scalar(f"{metric_name}/{mode}", val, trainer.cur_step) def after_test_iter(self, trainer, *args): - self._log_by_iter(trainer, mode='test') + self._log_by_iter(trainer, mode="test") def after_test_epoch(self, trainer): - self._log_by_epoch(trainer, mode='test') + self._log_by_epoch(trainer, mode="test") def after_train_iter(self, trainer, *args): - self._log_by_iter(trainer, mode='train') + self._log_by_iter(trainer, mode="train") def after_train_epoch(self, trainer): - self._log_by_epoch(trainer, mode='train') + self._log_by_epoch(trainer, mode="train") @HOOKS.register_module @@ -206,13 +209,15 @@ class LogTimingByEpochHook(LogByEpochHook): ignore_num_train_steps (int, optional): Number of training steps to ignore, defaults to 0. """ - def __init__(self, - timer: MultiTimer, - logger: DistributedLogger, - interval: int = 1, - priority: int = 10, - log_eval: bool = True, - ignore_num_train_steps: int = 0) -> None: + def __init__( + self, + timer: MultiTimer, + logger: DistributedLogger, + interval: int = 1, + priority: int = 10, + log_eval: bool = True, + ignore_num_train_steps: int = 0, + ) -> None: super().__init__(logger=logger, interval=interval, priority=priority) self._timer = timer self._log_eval = log_eval @@ -229,33 +234,31 @@ def _get_message(self, mode): if timer_name.startswith(mode): last_elapsed_time = timer.get_elapsed_time() if timer.has_history: - if timer_name == 'Train-step' and not self._is_train_step_history_trimmed: - timer._history = timer._history[self._ignore_num_train_steps:] + if timer_name == "Train-step" and not self._is_train_step_history_trimmed: + timer._history = timer._history[self._ignore_num_train_steps :] self._is_train_step_history_trimmed = True history_mean = timer.get_history_mean() - history_sum = timer.get_history_sum() + timer.get_history_sum() msg.append( - f'{timer_name}: last = {_format_number(last_elapsed_time)} s, mean = {_format_number(history_mean)} s' + f"{timer_name}: last = {_format_number(last_elapsed_time)} s, mean = {_format_number(history_mean)} s" ) else: - msg.append(f'{timer_name}: last = {_format_number(last_elapsed_time)} s') + msg.append(f"{timer_name}: last = {_format_number(last_elapsed_time)} s") - msg = ' | '.join(msg) + msg = " | ".join(msg) return msg def after_train_epoch(self, trainer): - """Writes log after finishing a training epoch. - """ + """Writes log after finishing a training epoch.""" if self._is_epoch_to_log(trainer) and self._is_rank_to_log: - msg = self._get_message('Train') - self.logger.info(f'[Epoch {trainer.cur_epoch} / Train]: {msg} | #steps/epoch = {trainer.steps_per_epoch}') + msg = self._get_message("Train") + self.logger.info(f"[Epoch {trainer.cur_epoch} / Train]: {msg} | #steps/epoch = {trainer.steps_per_epoch}") def after_test_epoch(self, trainer): - """Writes log after finishing a testing epoch. 
- """ + """Writes log after finishing a testing epoch.""" if self._is_epoch_to_log(trainer) and self._is_rank_to_log and self._log_eval: - msg = self._get_message('Test') - self.logger.info(f'[Epoch {trainer.cur_epoch} / Test]: {msg}') + msg = self._get_message("Test") + self.logger.info(f"[Epoch {trainer.cur_epoch} / Test]: {msg}") @HOOKS.register_module @@ -272,31 +275,28 @@ class LogMemoryByEpochHook(LogByEpochHook): """ def __init__( - self, - logger: DistributedLogger, - interval: int = 1, - priority: int = 10, - log_eval: bool = True, - report_cpu: bool = False, # no reference + self, + logger: DistributedLogger, + interval: int = 1, + priority: int = 10, + log_eval: bool = True, + report_cpu: bool = False, # no reference ) -> None: super().__init__(logger=logger, interval=interval, priority=priority) self._log_eval = log_eval self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() def before_train(self, trainer): - """Resets before training. - """ + """Resets before training.""" if self._is_epoch_to_log(trainer) and self._is_rank_to_log: - report_memory_usage('Before-train', self.logger) + report_memory_usage("Before-train", self.logger) def after_train_epoch(self, trainer): - """Writes log after finishing a training epoch. - """ + """Writes log after finishing a training epoch.""" if self._is_epoch_to_log(trainer) and self._is_rank_to_log: - report_memory_usage(f'[Epoch {trainer.cur_epoch} / Train]', self.logger) + report_memory_usage(f"[Epoch {trainer.cur_epoch} / Train]", self.logger) def after_test(self, trainer): - """Reports after testing. - """ + """Reports after testing.""" if self._is_epoch_to_log(trainer) and self._is_rank_to_log and self._log_eval: - report_memory_usage(f'[Epoch {trainer.cur_epoch} / Test]', self.logger) + report_memory_usage(f"[Epoch {trainer.cur_epoch} / Test]", self.logger) diff --git a/colossalai/legacy/trainer/hooks/_lr_scheduler_hook.py b/colossalai/legacy/trainer/hooks/_lr_scheduler_hook.py index 6d60966da12a..d14db563473c 100644 --- a/colossalai/legacy/trainer/hooks/_lr_scheduler_hook.py +++ b/colossalai/legacy/trainer/hooks/_lr_scheduler_hook.py @@ -34,15 +34,16 @@ def __init__( def after_hook_is_attached(self, trainer): self._check_metric_states_initialization(trainer) - trainer.states['metrics']['train']['LR'] = LearningRateMetric(epoch_only=self.by_epoch, - initial_lr=self.lr_scheduler.get_last_lr()[0]) + trainer.states["metrics"]["train"]["LR"] = LearningRateMetric( + epoch_only=self.by_epoch, initial_lr=self.lr_scheduler.get_last_lr()[0] + ) def after_train_epoch(self, trainer): if self.by_epoch: self.lr_scheduler.step() - trainer.states['metrics']['train']['LR'].update(self.lr_scheduler.get_last_lr()[0]) + trainer.states["metrics"]["train"]["LR"].update(self.lr_scheduler.get_last_lr()[0]) def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor): if not self.by_epoch: self.lr_scheduler.step() - trainer.states['metrics']['train']['LR'].update(self.lr_scheduler.get_last_lr()[0]) + trainer.states["metrics"]["train"]["LR"].update(self.lr_scheduler.get_last_lr()[0]) diff --git a/colossalai/legacy/trainer/hooks/_metric_hook.py b/colossalai/legacy/trainer/hooks/_metric_hook.py index 899e4d08a5c9..35a7f0a156ab 100644 --- a/colossalai/legacy/trainer/hooks/_metric_hook.py +++ b/colossalai/legacy/trainer/hooks/_metric_hook.py @@ -35,8 +35,7 @@ def __init__(self, epoch_only: bool): @property def epoch_only(self): - """Returns :attr:`epoch_only`. 
- """ + """Returns :attr:`epoch_only`.""" return self._epoch_only @abstractmethod @@ -44,20 +43,16 @@ def reset(self) -> None: """Resets the metric to it's initial state. By default, this is called at the start of each epoch. """ - pass @abstractmethod def update(self, *args, **kwargs) -> None: """Updates the metric's state using the passed batch output. By default, this is called once for each batch. """ - pass @abstractmethod def get_last_step_value(self) -> float: - """Returns the metric value in the last iteration. - """ - pass + """Returns the metric value in the last iteration.""" @abstractmethod def get_accumulated_value(self): @@ -67,7 +62,6 @@ def get_accumulated_value(self): :return: the actual quantity of interest :rtype: Any """ - pass @staticmethod @abstractmethod @@ -77,7 +71,6 @@ def is_better(a, b) -> bool: :return: The result of comparison :rtype: bool """ - pass class LossMetric(Metric): @@ -94,8 +87,7 @@ def __init__(self, epoch_only): self.count = 0 def reset(self) -> None: - """Sets :attr:`last_step_loss` and :attr:`accum_loss` to zero. - """ + """Sets :attr:`last_step_loss` and :attr:`accum_loss` to zero.""" self.last_step_loss.zero_() self.accum_loss.zero_() self.count = 0 @@ -114,8 +106,7 @@ def update(self, loss) -> None: self.count += 1 def get_accumulated_value(self): - """Returns accumulated loss. - """ + """Returns accumulated loss.""" if gpc.is_initialized(ParallelMode.DATA): dist.all_reduce(self.accum_loss, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.DATA)) self.accum_loss.div_(gpc.get_world_size(ParallelMode.DATA)) @@ -124,8 +115,7 @@ def get_accumulated_value(self): return self.accum_loss.item() def get_last_step_value(self) -> float: - """Returns :attr:`last_step_loss`. - """ + """Returns :attr:`last_step_loss`.""" return self.last_step_loss.cpu().item() @staticmethod @@ -141,7 +131,7 @@ class LearningRateMetric(Metric): initial_lr (float, optional): Initial learning rate, defaults to 0.0. 
""" - def __init__(self, epoch_only: bool, initial_lr: float = 0.): + def __init__(self, epoch_only: bool, initial_lr: float = 0.0): super().__init__(epoch_only=epoch_only) self.lr = initial_lr @@ -241,8 +231,8 @@ def __init__( self._is_stage_to_compute = is_no_pp_or_last_stage() def _check_metric_states_initialization(self, trainer): - if 'metrics' not in trainer.states: - self.init_runner_states(trainer, 'metrics', dict(train={}, test={})) + if "metrics" not in trainer.states: + self.init_runner_states(trainer, "metrics", dict(train={}, test={})) @HOOKS.register_module @@ -266,8 +256,8 @@ def after_hook_is_attached(self, trainer): self.test_loss = LossMetric(epoch_only=True) # register the metric calculator - trainer.states['metrics']['train']['Loss'] = self.train_loss - trainer.states['metrics']['test']['Loss'] = self.test_loss + trainer.states["metrics"]["train"]["Loss"] = self.train_loss + trainer.states["metrics"]["test"]["Loss"] = self.test_loss def before_train_epoch(self, trainer): if self._is_stage_to_compute: @@ -307,7 +297,7 @@ def after_hook_is_attached(self, trainer): self.metric = AccuracyMetric(epoch_only=True, accuracy_func=self.accuracy_func) # register the metric - trainer.states['metrics']['test']['Accuracy'] = self.metric + trainer.states["metrics"]["test"]["Accuracy"] = self.metric def before_test(self, trainer): if self._is_stage_to_compute: @@ -356,8 +346,9 @@ def get_last_step_value(self) -> float: if self._use_local: self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA) else: - self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \ - gpc.get_world_size(ParallelMode.DATA) + self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / gpc.get_world_size( + ParallelMode.DATA + ) self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA) sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item()) @@ -367,8 +358,9 @@ def get_last_step_info(self) -> str: if self._use_local: self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA) else: - self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \ - gpc.get_world_size(ParallelMode.DATA) + self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / gpc.get_world_size( + ParallelMode.DATA + ) self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA) sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item()) @@ -379,8 +371,9 @@ def get_last_step_info(self) -> str: return f"{sample_per_sec} sample_per_sec" def get_accumulated_value(self) -> float: - self.accumulated_used_time = all_reduce(self.accumulated_used_time, ParallelMode.DATA) / \ - gpc.get_world_size(ParallelMode.DATA) + self.accumulated_used_time = all_reduce(self.accumulated_used_time, ParallelMode.DATA) / gpc.get_world_size( + ParallelMode.DATA + ) self.accumulated_num_samples = all_reduce(self.accumulated_num_samples, ParallelMode.DATA) return (self.accumulated_num_samples / (self.accumulated_used_time + 1e-12)).item() @@ -411,14 +404,16 @@ def __init__(self, ignored_steps: int = 0, priority: int = 10, tflop_per_step: i def after_hook_is_attached(self, trainer): self._check_metric_states_initialization(trainer) if self._is_stage_to_compute: - self.metric = ThroughputMetric(epoch_only=True, - ignored_steps=self.ignored_steps, - tflop_per_step=self._tflop_per_step, - 
use_local=self._use_local) + self.metric = ThroughputMetric( + epoch_only=True, + ignored_steps=self.ignored_steps, + tflop_per_step=self._tflop_per_step, + use_local=self._use_local, + ) # register the metric - trainer.states['metrics']['train']['Throughput'] = self.metric - trainer.states['metrics']['test']['Throughput'] = self.metric + trainer.states["metrics"]["train"]["Throughput"] = self.metric + trainer.states["metrics"]["test"]["Throughput"] = self.metric def before_train_epoch(self, trainer): if self._is_stage_to_compute: @@ -426,8 +421,9 @@ def before_train_epoch(self, trainer): def after_train_iter(self, trainer, *args): if self._is_stage_to_compute: - self.metric.update(trainer.engine.schedule.batch_size, - trainer._timer.get_timer('Train-step').get_elapsed_time()) + self.metric.update( + trainer.engine.schedule.batch_size, trainer._timer.get_timer("Train-step").get_elapsed_time() + ) def before_test(self, trainer): if self._is_stage_to_compute: @@ -435,5 +431,6 @@ def before_test(self, trainer): def after_test_iter(self, trainer, *args): if self._is_stage_to_compute: - self.metric.update(trainer.engine.schedule.batch_size, - trainer._timer.get_timer('Test-step').get_elapsed_time()) + self.metric.update( + trainer.engine.schedule.batch_size, trainer._timer.get_timer("Test-step").get_elapsed_time() + ) diff --git a/colossalai/legacy/utils/__init__.py b/colossalai/legacy/utils/__init__.py index ae358f8bebcb..86984edeec65 100644 --- a/colossalai/legacy/utils/__init__.py +++ b/colossalai/legacy/utils/__init__.py @@ -26,28 +26,28 @@ ) __all__ = [ - 'DataParallelSampler', - 'get_dataloader', - 'save_checkpoint', - 'load_checkpoint', - 'colo_device_memory_capacity', - 'colo_device_memory_used', - 'colo_get_cpu_memory_capacity', - 'colo_set_cpu_memory_capacity', - 'colo_set_process_memory_fraction', - 'report_memory_usage', - 'clip_grad_norm_fp32', - 'copy_tensor_parallel_attributes', - 'count_zeros_fp32', - 'is_dp_rank_0', - 'is_model_parallel_parameter', - 'is_no_pp_or_last_stage', - 'is_tp_rank_0', - 'is_using_ddp', - 'is_using_pp', - 'is_using_sequence', - 'param_is_not_tensor_parallel_duplicate', - 'print_rank_0', - 'switch_virtual_pipeline_parallel_rank', - 'sync_model_param', + "DataParallelSampler", + "get_dataloader", + "save_checkpoint", + "load_checkpoint", + "colo_device_memory_capacity", + "colo_device_memory_used", + "colo_get_cpu_memory_capacity", + "colo_set_cpu_memory_capacity", + "colo_set_process_memory_fraction", + "report_memory_usage", + "clip_grad_norm_fp32", + "copy_tensor_parallel_attributes", + "count_zeros_fp32", + "is_dp_rank_0", + "is_model_parallel_parameter", + "is_no_pp_or_last_stage", + "is_tp_rank_0", + "is_using_ddp", + "is_using_pp", + "is_using_sequence", + "param_is_not_tensor_parallel_duplicate", + "print_rank_0", + "switch_virtual_pipeline_parallel_rank", + "sync_model_param", ] diff --git a/colossalai/legacy/utils/activation_checkpoint.py b/colossalai/legacy/utils/activation_checkpoint.py index add690f28cc0..387e1c54ec87 100644 --- a/colossalai/legacy/utils/activation_checkpoint.py +++ b/colossalai/legacy/utils/activation_checkpoint.py @@ -28,7 +28,6 @@ def copy_to_device(obj, device): class CheckpointFunction(torch.autograd.Function): - @staticmethod def forward(ctx, run_function, activation_offload=False, *args): check_backward_validity(args) @@ -42,7 +41,7 @@ def forward(ctx, run_function, activation_offload=False, *args): ctx.fwd_seed_states = get_states(copy=True) ctx.fwd_current_mode = get_current_mode() - if hasattr(torch, 
'is_autocast_enabled'): + if hasattr(torch, "is_autocast_enabled"): ctx.had_autocast_in_fwd = torch.is_autocast_enabled() else: ctx.had_autocast_in_fwd = False @@ -62,7 +61,7 @@ def forward(ctx, run_function, activation_offload=False, *args): for i, arg in enumerate(args): if torch.is_tensor(arg): if activation_offload: - tensor_inputs.append(copy_to_device(arg, 'cpu')) + tensor_inputs.append(copy_to_device(arg, "cpu")) else: tensor_inputs.append(arg) ctx.tensor_indices.append(i) @@ -79,8 +78,10 @@ def forward(ctx, run_function, activation_offload=False, *args): @staticmethod def backward(ctx, *args): if not torch.autograd._is_checkpoint_valid(): - raise RuntimeError("Checkpointing is not compatible with .grad() or when an `inputs` parameter is " - "passed to .backward(). Please use .backward() and do not pass its `inputs` argument.") + raise RuntimeError( + "Checkpointing is not compatible with .grad() or when an `inputs` parameter is " + "passed to .backward(). Please use .backward() and do not pass its `inputs` argument." + ) # Copy the list to avoid modifying original list. inputs = list(ctx.inputs) tensor_indices = ctx.tensor_indices @@ -131,8 +132,7 @@ def backward(ctx, *args): outputs_with_grad.append(outputs[i]) args_with_grad.append(args[i]) if len(outputs_with_grad) == 0: - raise RuntimeError("none of output has requires_grad=True," - " this checkpoint() is not necessary") + raise RuntimeError("none of output has requires_grad=True," " this checkpoint() is not necessary") torch.autograd.backward(outputs_with_grad, args_with_grad) grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs) return (None, None) + grads @@ -169,7 +169,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args): fwd_current_mode = get_current_mode() # check if use autocast - if hasattr(torch, 'is_autocast_enabled'): + if hasattr(torch, "is_autocast_enabled"): has_autocast_in_fwd = torch.is_autocast_enabled() else: has_autocast_in_fwd = False @@ -179,7 +179,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args): weak_holder_list = [] # class for weakref.ref - class Holder(): + class Holder: pass # return a Holder object for later unpack process @@ -226,19 +226,20 @@ def inner_unpack(packed): # rerun forward, the inner_pack will store all the activations in storage if has_autocast_in_fwd: - with torch.enable_grad(), \ - torch.cuda.amp.autocast(), \ - torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack): + with torch.enable_grad(), torch.cuda.amp.autocast(), torch.autograd.graph.saved_tensors_hooks( + inner_pack, inner_unpack + ): _unused = function(*args) else: - with torch.enable_grad(), \ - torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack): + with torch.enable_grad(), torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack): _unused = function(*args) if x not in storage: - raise RuntimeError("Attempt to retrieve a tensor saved by autograd multiple times without checkpoint" - " recomputation being triggered in between, this is not currently supported. Please" - " open an issue with details on your use case so that we can prioritize adding this.") + raise RuntimeError( + "Attempt to retrieve a tensor saved by autograd multiple times without checkpoint" + " recomputation being triggered in between, this is not currently supported. Please" + " open an issue with details on your use case so that we can prioritize adding this." 
+ ) return storage[x] diff --git a/colossalai/legacy/utils/checkpoint/__init__.py b/colossalai/legacy/utils/checkpoint/__init__.py index 558a956b31ac..35ce19ea1c69 100644 --- a/colossalai/legacy/utils/checkpoint/__init__.py +++ b/colossalai/legacy/utils/checkpoint/__init__.py @@ -1,3 +1,3 @@ from .module_checkpoint import load_checkpoint, save_checkpoint -__all__ = ['save_checkpoint', 'load_checkpoint'] +__all__ = ["save_checkpoint", "load_checkpoint"] diff --git a/colossalai/legacy/utils/checkpoint/module_checkpoint.py b/colossalai/legacy/utils/checkpoint/module_checkpoint.py index 9bd2907abf9d..1d691e5c8f97 100644 --- a/colossalai/legacy/utils/checkpoint/module_checkpoint.py +++ b/colossalai/legacy/utils/checkpoint/module_checkpoint.py @@ -9,13 +9,15 @@ from .utils import gather_tensor, scatter_tensor -def save_checkpoint(path: str, - epoch: int, - model: torch.nn.Module, - optimizer: Optional[OptimizerWrapper] = None, - lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, - *args, - **kwargs): +def save_checkpoint( + path: str, + epoch: int, + model: torch.nn.Module, + optimizer: Optional[OptimizerWrapper] = None, + lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, + *args, + **kwargs, +): """save_checkpoint save a model, whose parameters are `ColoTensor`s. Args: @@ -30,7 +32,7 @@ def save_checkpoint(path: str, # save the dist context about the tensors in a new dict, while still maintain the original dict. for k, v in model_state.items(): if isinstance(v, ColoTensor): - gather_tensor(v) # gather shared tensors to rank0 + gather_tensor(v) # gather shared tensors to rank0 # don't recover tensors in rank0, since the dict is only a copy of model if rank == 0: @@ -39,10 +41,10 @@ def save_checkpoint(path: str, if isinstance(v, ColoTensor): assert v.save_ready assert v.is_replicate() - delattr(v, 'save_ready') + delattr(v, "save_ready") # model saving - save_state = {'epoch': epoch, 'model': model_state} - torch.save(save_state, path + '/epoch_{}_model.pth'.format(epoch), *args, **kwargs) + save_state = {"epoch": epoch, "model": model_state} + torch.save(save_state, path + "/epoch_{}_model.pth".format(epoch), *args, **kwargs) # delete old dicts del model_state @@ -52,35 +54,37 @@ def save_checkpoint(path: str, if optimizer is not None: mapping = dict() optim_state = optimizer.state_dict() - for k, v in optim_state['state'].items(): + for k, v in optim_state["state"].items(): for n, t in v.items(): if isinstance(t, ColoTensor): mapping[(k, n)] = t.dist_spec gather_tensor(t) if rank == 0: - save_state = {'epoch': epoch, 'optim': optim_state} - torch.save(save_state, path + '/epoch_{}_optim.pth'.format(epoch), *args, **kwargs) + save_state = {"epoch": epoch, "optim": optim_state} + torch.save(save_state, path + "/epoch_{}_optim.pth".format(epoch), *args, **kwargs) # recover colo tensors in rank0 - for k, v in optimizer.state_dict()['state'].items(): + for k, v in optimizer.state_dict()["state"].items(): for n, t in v.items(): if isinstance(t, ColoTensor): - assert hasattr(t, 'save_ready') + assert hasattr(t, "save_ready") t.set_dist_spec(mapping[(k, n)]) - delattr(t, 'save_ready') + delattr(t, "save_ready") del optim_state del mapping dist.barrier() -def load_checkpoint(path: str, - epoch: int, - model: torch.nn.Module, - optimizer: Optional[OptimizerWrapper] = None, - lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, - torch_load_kwargs: Optional[Dict] = None, - load_state_dict_kwargs: Optional[Dict] = None): +def load_checkpoint( + path: str, + epoch: int, + model: 
torch.nn.Module, + optimizer: Optional[OptimizerWrapper] = None, + lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, + torch_load_kwargs: Optional[Dict] = None, + load_state_dict_kwargs: Optional[Dict] = None, +): """load_checkpoint load a model, whose parameters are `ColoTensor`s. Args: @@ -106,8 +110,8 @@ def load_checkpoint(path: str, gather_tensor(p) if rank == 0: - load_state = torch.load(path + '/epoch_{}_model.pth'.format(epoch), **torch_load_kwargs) - model.load_state_dict(load_state['model'], **load_state_dict_kwargs) + load_state = torch.load(path + "/epoch_{}_model.pth".format(epoch), **torch_load_kwargs) + model.load_state_dict(load_state["model"], **load_state_dict_kwargs) dist.barrier() # scatter loaded parameters @@ -115,24 +119,24 @@ def load_checkpoint(path: str, if isinstance(p, ColoTensor): scatter_tensor(p, mapping[n]) if rank == 0: - assert hasattr(p, 'save_ready') - delattr(p, 'save_ready') + assert hasattr(p, "save_ready") + delattr(p, "save_ready") del mapping if optimizer is not None: mapping = dict() - for k, v in optimizer.state_dict()['state'].items(): + for k, v in optimizer.state_dict()["state"].items(): for n, t in v.items(): if isinstance(t, ColoTensor): mapping[(k, n)] = t.dist_spec gather_tensor(t) if rank == 0: - colo_checkpoint = torch.load(path + '/epoch_{}_optim.pth'.format(epoch), **torch_load_kwargs) - optimizer.load_state_dict(colo_checkpoint['optim'], **load_state_dict_kwargs) + colo_checkpoint = torch.load(path + "/epoch_{}_optim.pth".format(epoch), **torch_load_kwargs) + optimizer.load_state_dict(colo_checkpoint["optim"], **load_state_dict_kwargs) dist.barrier() - for k, v in optimizer.state_dict()['state'].items(): + for k, v in optimizer.state_dict()["state"].items(): for n, t in v.items(): if isinstance(t, ColoTensor): scatter_tensor(t, mapping[(k, n)]) diff --git a/colossalai/legacy/utils/checkpoint/utils.py b/colossalai/legacy/utils/checkpoint/utils.py index c830d4811463..c56848cf06c4 100644 --- a/colossalai/legacy/utils/checkpoint/utils.py +++ b/colossalai/legacy/utils/checkpoint/utils.py @@ -8,7 +8,7 @@ def robust_broadcast(tensor): with torch.no_grad(): - is_cpu_ten = tensor.device.type == 'cpu' + is_cpu_ten = tensor.device.type == "cpu" if is_cpu_ten: b_data = tensor.cuda() else: @@ -21,8 +21,7 @@ def robust_broadcast(tensor): def gather_tensor(colo_tensor: ColoTensor) -> None: - """Make colo_tensor replicated when the rank is 0 - """ + """Make colo_tensor replicated when the rank is 0""" if not colo_tensor.is_replicate(): pg = colo_tensor.get_process_group() # for the group which contains rank 0 @@ -36,12 +35,11 @@ def gather_tensor(colo_tensor: ColoTensor) -> None: dist.barrier() if dist.get_rank() == 0: - setattr(colo_tensor, 'save_ready', True) # set saving signature + setattr(colo_tensor, "save_ready", True) # set saving signature def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None: - """Reversal operation of `gather_tensor`. 
- """ + """Reversal operation of `gather_tensor`.""" if dist_spec.placement == DistPlacementPattern.REPLICATE: robust_broadcast(colo_tensor.data) else: @@ -57,7 +55,8 @@ def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None: colo_tensor.set_dist_spec(dist_spec) else: rep_tensor = ColoTensor( - entire_data, ColoTensorSpec(pg=colo_tensor.get_process_group(), compute_attr=colo_tensor.compute_spec)) + entire_data, ColoTensorSpec(pg=colo_tensor.get_process_group(), compute_attr=colo_tensor.compute_spec) + ) rep_tensor.set_dist_spec(dist_spec) with torch.no_grad(): colo_tensor.data.copy_(rep_tensor.data) diff --git a/colossalai/legacy/utils/checkpointing.py b/colossalai/legacy/utils/checkpointing.py index b7b29cc984d6..c068faafbf44 100644 --- a/colossalai/legacy/utils/checkpointing.py +++ b/colossalai/legacy/utils/checkpointing.py @@ -11,7 +11,7 @@ try: from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX except ImportError: - _EXTRA_STATE_KEY_SUFFIX = '_extra_state' + _EXTRA_STATE_KEY_SUFFIX = "_extra_state" from .common import is_using_pp @@ -25,10 +25,9 @@ def broadcast_state_dict(state_dict, parallel_mode): return state_dict[0] -def partition_tensor_parallel_state_dict(state_dict: OrderedDict, - parallel_mode: ParallelMode, - dims: dict = dict(), - partition_states: dict = dict()): +def partition_tensor_parallel_state_dict( + state_dict: OrderedDict, parallel_mode: ParallelMode, dims: dict = dict(), partition_states: dict = dict() +): src_rank = gpc.get_ranks_in_group(parallel_mode)[0] depth = gpc.get_world_size(parallel_mode) group = gpc.get_cpu_group(parallel_mode) @@ -65,11 +64,11 @@ def partition_tensor_parallel_state_dict(state_dict: OrderedDict, def gather_tensor_parallel_state_dict( - state_dict: OrderedDict, - parallel_mode: ParallelMode, - dims: dict = dict(), - partition_states: dict = dict(), - keep_vars: bool = False, + state_dict: OrderedDict, + parallel_mode: ParallelMode, + dims: dict = dict(), + partition_states: dict = dict(), + keep_vars: bool = False, ): dst_rank = gpc.get_ranks_in_group(parallel_mode)[0] depth = gpc.get_world_size(parallel_mode) @@ -138,8 +137,11 @@ def partition_pipeline_parallel_state_dict(model, state_dict): def gather_pipeline_parallel_state_dict(state_dict): - gathered_states = ([None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))] - if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else None) + gathered_states = ( + [None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))] + if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 + else None + ) dist.gather_object( state_dict, gathered_states, @@ -147,18 +149,23 @@ def gather_pipeline_parallel_state_dict(state_dict): group=gpc.get_cpu_group(ParallelMode.PIPELINE), ) - state_dict = (OrderedDict(chain.from_iterable(state.items() for state in gathered_states)) - if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else OrderedDict()) + state_dict = ( + OrderedDict(chain.from_iterable(state.items() for state in gathered_states)) + if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 + else OrderedDict() + ) return state_dict -def save_checkpoint(file, - epoch: int, - model: torch.nn.Module, - optimizer: torch.optim.Optimizer = None, - lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, - **kwargs): +def save_checkpoint( + file, + epoch: int, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer = None, + lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, + **kwargs, +): """Stores the checkpoint to disk. 
Saves all the training components' parameters or buffers, such as model, optimizer, lr_scheduler etc. into a checkpoint dictionary. @@ -196,8 +203,11 @@ def broadcast_model(model: torch.nn.Module): src_rank = gpc.get_ranks_in_group(ParallelMode.TENSOR)[0] for p in model.parameters(): if not getattr(p, IS_TENSOR_PARALLEL, False) and p.storage().size() > 0: - group = gpc.get_group(ParallelMode.TENSOR) if p.device.type == 'cuda' else gpc.get_cpu_group( - ParallelMode.TENSOR) + group = ( + gpc.get_group(ParallelMode.TENSOR) + if p.device.type == "cuda" + else gpc.get_cpu_group(ParallelMode.TENSOR) + ) dist.broadcast(p, src_rank, group=group) @@ -226,8 +236,9 @@ def load_checkpoint( Raises: RuntimeError: Raise error if the model/optimizer cannot successfully be recuperated """ - state_dict = (torch.load(file, map_location=torch.device("cpu")) - if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None) + state_dict = ( + torch.load(file, map_location=torch.device("cpu")) if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None + ) # model states model_state = state_dict.pop("model") if state_dict is not None else dict() @@ -246,8 +257,11 @@ def load_checkpoint( dist.gather_object(error_msgs, all_error_msgs, dst=dst_rank, group=gpc.get_cpu_group(ParallelMode.MODEL)) if gpc.get_global_rank() == 0: all_error_msgs = list(chain.from_iterable(all_error_msgs)) - raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format( - model.__class__.__name__, "\n\t".join(all_error_msgs))) + raise RuntimeError( + "Error(s) in loading state_dict for {}:\n\t{}".format( + model.__class__.__name__, "\n\t".join(all_error_msgs) + ) + ) else: raise e diff --git a/colossalai/legacy/utils/common.py b/colossalai/legacy/utils/common.py index 35095161c2f2..671bcc3d6ad7 100644 --- a/colossalai/legacy/utils/common.py +++ b/colossalai/legacy/utils/common.py @@ -80,7 +80,6 @@ def is_using_sequence(): class model_branch_context(object): - def __enter__(self): self.env_status = env.save() @@ -98,16 +97,14 @@ def _calc_l2_norm(grads): if fused_optim is None: from colossalai.kernel.op_builder import FusedOptimBuilder + fused_optim = FusedOptimBuilder().load() norm = 0.0 if len(grads) > 0: dummy_overflow_buf = torch.cuda.IntTensor([0]) norm, _ = multi_tensor_applier( - fused_optim.multi_tensor_l2norm, - dummy_overflow_buf, - [grads], - False # no per-parameter norm + fused_optim.multi_tensor_l2norm, dummy_overflow_buf, [grads], False # no per-parameter norm ) return norm @@ -121,7 +118,7 @@ def _calc_lp(grads, norm_type): def _move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]: - if torch.is_tensor(norm) and norm.device.type != 'cuda': + if torch.is_tensor(norm) and norm.device.type != "cuda": norm = norm.to(torch.cuda.current_device()) return norm @@ -141,11 +138,11 @@ def _compute_local_lp(params: List[ColoParameter], norm_type: float) -> float: if len(params) == 0: return 0.0 grads = [p.grad for p in params] - use_cuda_kernel = grads[0].device.type == 'cuda' + use_cuda_kernel = grads[0].device.type == "cuda" if norm_type == inf: local_lp = max([g.abs().max() for g in grads]) elif norm_type == 2.0 and use_cuda_kernel: - local_lp = _calc_l2_norm(grads)**norm_type + local_lp = _calc_l2_norm(grads) ** norm_type else: local_lp = _calc_lp(grads, norm_type) if isinstance(local_lp, torch.Tensor): @@ -202,8 +199,8 @@ def _compute_grad_lp(parameters, norm_type: float = 2.0) -> float: assert isinstance(p, ColoParameter) if grad_dtype is None: grad_dtype = p.grad.dtype - assert p.grad.dtype == 
grad_dtype, f'Expected all grads are {grad_dtype}, got {p.grad.dtype}' - if p.grad.device.type == 'cuda': + assert p.grad.dtype == grad_dtype, f"Expected all grads are {grad_dtype}, got {p.grad.dtype}" + if p.grad.device.type == "cuda": cuda_grad_params.append(p) else: cpu_grad_params.append(p) @@ -221,7 +218,7 @@ def compute_grad_norm(parameters, norm_type: float = 2.0) -> float: norm_type = float(norm_type) total_norm = _compute_grad_lp(parameters, norm_type) if norm_type != inf: - total_norm = total_norm**(1 / norm_type) + total_norm = total_norm ** (1 / norm_type) return total_norm @@ -235,14 +232,15 @@ def _clip_grad_norm(parameters, max_norm: float, total_norm: float) -> None: for p in parameters: if p.grad is None: continue - if p.grad.device.type == 'cuda': + if p.grad.device.type == "cuda": cuda_grads.append(p.grad.detach()) else: cpu_grads.append(p.grad.detach()) if len(cuda_grads) > 0: dummy_overflow_buf = torch.cuda.IntTensor([0]) - multi_tensor_applier(fused_optim.multi_tensor_scale, dummy_overflow_buf, [cuda_grads, cuda_grads], - clip_coef) + multi_tensor_applier( + fused_optim.multi_tensor_scale, dummy_overflow_buf, [cuda_grads, cuda_grads], clip_coef + ) for g in cpu_grads: g.mul_(clip_coef) @@ -284,16 +282,17 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): for param in parameters: if param.grad is not None: # Make sure the grads are in fp32 - assert param.grad.dtype == torch.float, \ - f'expected gradient to be dtype torch.float, but got {param.grad.type()}' - if hasattr(param, 'colo_attr') and param.colo_attr.sharded_data_tensor.is_sharded: + assert ( + param.grad.dtype == torch.float + ), f"expected gradient to be dtype torch.float, but got {param.grad.type()}" + if hasattr(param, "colo_attr") and param.colo_attr.sharded_data_tensor.is_sharded: has_zero_shared_param = True params.append(param) if len(params) == 0: enable_cuda_kernels = False else: - enable_cuda_kernels = params[0].grad.device.type == 'cuda' + enable_cuda_kernels = params[0].grad.device.type == "cuda" # Norm parameters. max_norm = float(max_norm) norm_type = float(norm_type) @@ -307,15 +306,13 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) # Take max across all model-parallel GPUs. 
if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1: - dist.all_reduce(total_norm_cuda, - op=dist.ReduceOp.MAX, - group=gpc.get_group(ParallelMode.MODEL), - async_op=False) + dist.all_reduce( + total_norm_cuda, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL), async_op=False + ) if has_zero_shared_param: - dist.all_reduce(total_norm_cuda, - op=dist.ReduceOp.MAX, - group=gpc.get_group(ParallelMode.DATA), - async_op=False) + dist.all_reduce( + total_norm_cuda, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.DATA), async_op=False + ) total_norm = total_norm_cuda[0].item() else: tensor_parallel_grads = [] @@ -323,17 +320,17 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): zero_sharded_grads = [] for p in params: if is_model_parallel_parameter(p): - reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS))**(1 / norm_type) + reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS)) ** (1 / norm_type) tensor_parallel_grads.append(p.grad.data / reductor) - elif hasattr(p, 'colo_attr') and p.colo_attr.sharded_data_tensor.is_sharded: + elif hasattr(p, "colo_attr") and p.colo_attr.sharded_data_tensor.is_sharded: zero_sharded_grads.append(p.grad.data) else: no_tensor_parallel_grads.append(p.grad.data) if norm_type == 2.0 and enable_cuda_kernels: - tensor_parallel_norm = _calc_l2_norm(tensor_parallel_grads)**norm_type - no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads)**norm_type - zero_sharded_norm = _calc_l2_norm(zero_sharded_grads)**norm_type + tensor_parallel_norm = _calc_l2_norm(tensor_parallel_grads) ** norm_type + no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads) ** norm_type + zero_sharded_norm = _calc_l2_norm(zero_sharded_grads) ** norm_type else: tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type) no_tensor_parallel_norm = _calc_lp(no_tensor_parallel_grads, norm_type) @@ -358,7 +355,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): total_norm = tensor_parallel_norm + no_tensor_parallel_norm if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1: dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE)) - total_norm = total_norm**(1.0 / norm_type) + total_norm = total_norm ** (1.0 / norm_type) if torch.is_tensor(total_norm): total_norm = total_norm.item() @@ -397,13 +394,14 @@ def count_zeros_fp32(parameters): # Sum across all model-parallel GPUs. 
ops = [] ops.append( - dist.all_reduce(total_num_zeros, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR), async_op=True)) + dist.all_reduce(total_num_zeros, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR), async_op=True) + ) if gpc.is_initialized(ParallelMode.PIPELINE): ops.append( - dist.all_reduce(total_num_zeros, - op=dist.ReduceOp.SUM, - group=gpc.get_group(ParallelMode.PIPELINE), - async_op=True)) + dist.all_reduce( + total_num_zeros, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE), async_op=True + ) + ) for req in ops: req.wait() @@ -420,8 +418,9 @@ def copy_tensor_parallel_attributes(src_tensor, dst_tensor): def param_is_not_tensor_parallel_duplicate(param): - return (hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL)) or (gpc.get_local_rank( - ParallelMode.TENSOR) == 0) + return (hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL)) or ( + gpc.get_local_rank(ParallelMode.TENSOR) == 0 + ) @contextmanager diff --git a/colossalai/legacy/utils/data_sampler/__init__.py b/colossalai/legacy/utils/data_sampler/__init__.py index 12798a94c2d0..677d767667f2 100644 --- a/colossalai/legacy/utils/data_sampler/__init__.py +++ b/colossalai/legacy/utils/data_sampler/__init__.py @@ -1,4 +1,4 @@ from .base_sampler import BaseSampler from .data_parallel_sampler import DataParallelSampler, get_dataloader -__all__ = ['BaseSampler', 'DataParallelSampler', 'get_dataloader'] +__all__ = ["BaseSampler", "DataParallelSampler", "get_dataloader"] diff --git a/colossalai/legacy/utils/data_sampler/base_sampler.py b/colossalai/legacy/utils/data_sampler/base_sampler.py index 89f3bca5b1b5..c6b916fc4870 100644 --- a/colossalai/legacy/utils/data_sampler/base_sampler.py +++ b/colossalai/legacy/utils/data_sampler/base_sampler.py @@ -5,7 +5,6 @@ class BaseSampler(ABC): - def __init__(self, dataset, batch_size): self.dataset = dataset self.batch_size = batch_size diff --git a/colossalai/legacy/utils/data_sampler/data_parallel_sampler.py b/colossalai/legacy/utils/data_sampler/data_parallel_sampler.py index 66a5fdd3694d..41d0861e2249 100644 --- a/colossalai/legacy/utils/data_sampler/data_parallel_sampler.py +++ b/colossalai/legacy/utils/data_sampler/data_parallel_sampler.py @@ -13,7 +13,7 @@ from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc -T_co = TypeVar('T_co', covariant=True) +T_co = TypeVar("T_co", covariant=True) class DataParallelSampler(Sampler): @@ -44,11 +44,11 @@ def __init__(self, dataset: Dataset, shuffle: bool = False, seed: int = 0, drop_ self.num_samples = math.ceil( # `type:ignore` is required because Dataset cannot provide a default __len__ # see NOTE in pytorch/torch/utils/data/sampler.py - (len(self.dataset) - self.num_replicas) / \ - self.num_replicas # type: ignore[arg-type] + (len(self.dataset) - self.num_replicas) + / self.num_replicas # type: ignore[arg-type] ) else: - self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type] + self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type] self.total_size = self.num_samples * self.num_replicas self.shuffle = shuffle self.seed = seed @@ -65,7 +65,7 @@ def __iter__(self) -> Iterator[T_co]: # set_epoch manually self.epoch += 1 else: - indices = list(range(len(self.dataset))) # type: ignore[arg-type] + indices = list(range(len(self.dataset))) # type: ignore[arg-type] if not self.drop_last: # add extra samples to make it evenly 
divisible @@ -76,11 +76,11 @@ def __iter__(self) -> Iterator[T_co]: indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] else: # remove tail of data to make it evenly divisible. - indices = indices[:self.total_size] + indices = indices[: self.total_size] assert len(indices) == self.total_size # subsample - indices = indices[self.rank:self.total_size:self.num_replicas] + indices = indices[self.rank : self.total_size : self.num_replicas] assert len(indices) == self.num_samples return iter(indices) @@ -99,14 +99,9 @@ def set_epoch(self, epoch: int) -> None: self.epoch = epoch -def get_dataloader(dataset, - shuffle=False, - seed=1024, - add_sampler=True, - drop_last=False, - pin_memory=False, - num_workers=0, - **kwargs): +def get_dataloader( + dataset, shuffle=False, seed=1024, add_sampler=True, drop_last=False, pin_memory=False, num_workers=0, **kwargs +): r"""Set up a deterministic dataloader (also configure seed workers, samplers and whether shuffle or not) Note: @@ -144,18 +139,22 @@ def seed_worker(worker_id): random.seed(worker_seed) if sampler is None: - return DataLoader(dataset, - worker_init_fn=seed_worker, - shuffle=shuffle, - drop_last=drop_last, - pin_memory=pin_memory, - num_workers=num_workers, - **_kwargs) + return DataLoader( + dataset, + worker_init_fn=seed_worker, + shuffle=shuffle, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs, + ) else: - return DataLoader(dataset, - sampler=sampler, - worker_init_fn=seed_worker, - drop_last=drop_last, - pin_memory=pin_memory, - num_workers=num_workers, - **_kwargs) + return DataLoader( + dataset, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs, + ) diff --git a/colossalai/legacy/utils/memory.py b/colossalai/legacy/utils/memory.py index 360bf0da4a77..2f99a7d2f72e 100644 --- a/colossalai/legacy/utils/memory.py +++ b/colossalai/legacy/utils/memory.py @@ -76,8 +76,10 @@ def report_memory_usage(message, logger=None, report_cpu=False): gpu_cached = _bytes_to_MB(torch.cuda.memory_reserved()) gpu_max_cached = _bytes_to_MB(torch.cuda.max_memory_reserved()) - full_log = f"{message}: GPU: allocated {gpu_allocated} MB, max allocated {gpu_max_allocated} MB, " \ + full_log = ( + f"{message}: GPU: allocated {gpu_allocated} MB, max allocated {gpu_max_allocated} MB, " + f"cached: {gpu_cached} MB, max cached: {gpu_max_cached} MB" + ) if report_cpu: # python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports @@ -91,7 +93,7 @@ def report_memory_usage(message, logger=None, report_cpu=False): logger.info(full_log) # get the peak memory to report correct data, so reset the counter for the next call - if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+ + if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+ torch.cuda.reset_peak_memory_stats() @@ -106,10 +108,10 @@ def colo_device_memory_capacity(device: torch.device) -> int: int: size in byte """ assert isinstance(device, torch.device) - if device.type == 'cpu': + if device.type == "cpu": # In the context of 1-CPU-N-GPU, the memory capacity of the current process is 1/N overall CPU memory. 
return colo_get_cpu_memory_capacity() / gpc.num_processes_on_current_node - if device.type == 'cuda': + if device.type == "cuda": return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION @@ -123,16 +125,16 @@ def colo_device_memory_used(device: torch.device) -> int: Returns: int: memory size in bytes """ - if device.type == 'cpu': + if device.type == "cpu": mem_info = _get_cpu_memory_info() # In the context of 1-CPU-N-GPU, the memory usage of the current process is 1/N CPU memory used. # Each process consumes the same amount of memory. ret = mem_info.used / gpc.num_processes_on_current_node return ret - elif device.type == 'cuda': + elif device.type == "cuda": ret: int = torch.cuda.memory_allocated(device) # get the peak memory to report correct data, so reset the counter for the next call - if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+ + if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+ torch.cuda.reset_peak_memory_stats(device) return ret @@ -145,9 +147,9 @@ def colo_set_process_memory_fraction(ratio: float) -> None: Args: ratio (float): a ratio between 0. ~ 1. """ - if version.parse(torch.__version__) < version.parse('1.8'): - logger = get_dist_logger('colo_set_process_memory_fraction') - logger.warning('colo_set_process_memory_fraction failed because torch version is less than 1.8') + if version.parse(torch.__version__) < version.parse("1.8"): + logger = get_dist_logger("colo_set_process_memory_fraction") + logger.warning("colo_set_process_memory_fraction failed because torch version is less than 1.8") return global _GLOBAL_CUDA_MEM_FRACTION _GLOBAL_CUDA_MEM_FRACTION = ratio diff --git a/colossalai/legacy/utils/profiler/extention.py b/colossalai/legacy/utils/profiler/extention.py index 6726a683cc05..c07c6046bb1c 100644 --- a/colossalai/legacy/utils/profiler/extention.py +++ b/colossalai/legacy/utils/profiler/extention.py @@ -2,7 +2,6 @@ class ProfilerExtension(ABC): - @abstractmethod def prepare_trace(self): pass diff --git a/colossalai/legacy/utils/profiler/legacy/__init__.py b/colossalai/legacy/utils/profiler/legacy/__init__.py index 88beed86d7de..88b4201d8bf3 100644 --- a/colossalai/legacy/utils/profiler/legacy/__init__.py +++ b/colossalai/legacy/utils/profiler/legacy/__init__.py @@ -3,4 +3,4 @@ from .pcie_profiler import PcieProfiler from .prof_utils import BaseProfiler, ProfilerContext -__all__ = ['BaseProfiler', 'CommProfiler', 'PcieProfiler', 'MemProfiler', 'ProfilerContext'] +__all__ = ["BaseProfiler", "CommProfiler", "PcieProfiler", "MemProfiler", "ProfilerContext"] diff --git a/colossalai/legacy/utils/profiler/legacy/comm_profiler.py b/colossalai/legacy/utils/profiler/legacy/comm_profiler.py index bb7e2654c740..ad54b989f412 100644 --- a/colossalai/legacy/utils/profiler/legacy/comm_profiler.py +++ b/colossalai/legacy/utils/profiler/legacy/comm_profiler.py @@ -20,14 +20,14 @@ def _get_code_location(depth: int): upper_frame = inspect.stack()[i] function_name = inspect.stack()[i - 1].function ret.append(upper_frame.filename) - ret.append('(') + ret.append("(") ret.append(str(upper_frame.lineno)) - ret.append('): ') + ret.append("): ") ret.append(function_name) if i != length - 1: - ret.append('\n') + ret.append("\n") - return ''.join(ret) + return "".join(ret) torch_all_reduce = dist.all_reduce @@ -42,7 +42,7 @@ class CommEvent(object): volume recording. 
""" - def __init__(self, count: int = 0, comm_vol: float = 0., cuda_time: int = 0): + def __init__(self, count: int = 0, comm_vol: float = 0.0, cuda_time: int = 0): self.self_count = count self.self_comm_vol = comm_vol self.self_cuda_time = cuda_time @@ -54,8 +54,7 @@ def add(self, rhs): class CommProfiler(BaseProfiler): - """Communication profiler. Records all communication events. - """ + """Communication profiler. Records all communication events.""" def __init__(self, depth: int = 0, total_count: int = 0, total_comm_vol: float = 0, total_cuda_time: int = 0): super().__init__(profiler_name="Collective_Communication", priority=0) @@ -114,8 +113,10 @@ def append(s: str = None): res.append(sep) if self.warn_flag: - append("Warning: there exists multiple communication operations in the same time. As a result, " - "the profiling result is not accurate.") + append( + "Warning: there exists multiple communication operations in the same time. As a result, " + "the profiling result is not accurate." + ) if self.total_cuda_time == 0: return "No collective communication has been called yet!" @@ -126,24 +127,29 @@ def append(s: str = None): append("total number of calls: {}".format(self.total_count)) append("All events:") - separation = '-' * 74 - row_format = '{:^10}' + '{:^12}' * 2 + '{:^16}' + '{:^12}' * 2 + separation = "-" * 74 + row_format = "{:^10}" + "{:^12}" * 2 + "{:^16}" + "{:^12}" * 2 append(separation) - append(row_format.format('Location', 'GPU time', 'Percentage', 'Comm volume', 'Bandwidth', 'Num of calls')) + append(row_format.format("Location", "GPU time", "Percentage", "Comm volume", "Bandwidth", "Num of calls")) append(separation) show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time) for location, event in show_list: append(location) append( - row_format.format('', _format_time(event.self_cuda_time), - '{:.1f}%'.format(event.self_cuda_time / self.total_cuda_time * 100.0), - _format_memory(event.self_comm_vol), - _format_bandwidth(event.self_comm_vol, event.self_cuda_time), event.self_count)) + row_format.format( + "", + _format_time(event.self_cuda_time), + "{:.1f}%".format(event.self_cuda_time / self.total_cuda_time * 100.0), + _format_memory(event.self_comm_vol), + _format_bandwidth(event.self_comm_vol, event.self_cuda_time), + event.self_count, + ) + ) append() - return ''.join(res) + return "".join(res) @property def has_aync_op(self): @@ -195,8 +201,7 @@ def wait_async_op(self): class CommHandler(object): - """Communication handler. A dummy handler to wait aync operations. - """ + """Communication handler. 
A dummy handler to wait aync operations.""" def __init__(self, profiler: CommProfiler): super().__init__() @@ -212,11 +217,9 @@ def async_check(profiler: CommProfiler): profiler.wait_async_op() -def all_reduce(tensor: torch.Tensor, - op: ReduceOp = ReduceOp.SUM, - group=None, - async_op: bool = False, - profiler: CommProfiler = None) -> Optional[CommHandler]: +def all_reduce( + tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, group=None, async_op: bool = False, profiler: CommProfiler = None +) -> Optional[CommHandler]: async_check(profiler) comm_size = dist.get_world_size(group) @@ -231,12 +234,14 @@ def all_reduce(tensor: torch.Tensor, profiler.close_profiler(group) -def reduce_scatter(output: torch.Tensor, - input_list: List[torch.Tensor], - op: ReduceOp = ReduceOp.SUM, - group=None, - async_op: bool = False, - profiler: CommProfiler = None) -> Optional[CommHandler]: +def reduce_scatter( + output: torch.Tensor, + input_list: List[torch.Tensor], + op: ReduceOp = ReduceOp.SUM, + group=None, + async_op: bool = False, + profiler: CommProfiler = None, +) -> Optional[CommHandler]: async_check(profiler) comm_size = dist.get_world_size(group) @@ -254,11 +259,13 @@ def reduce_scatter(output: torch.Tensor, profiler.close_profiler(group) -def all_gather(tensor_list: List[torch.Tensor], - tensor: torch.Tensor, - group=None, - async_op: bool = False, - profiler: CommProfiler = None) -> Optional[CommHandler]: +def all_gather( + tensor_list: List[torch.Tensor], + tensor: torch.Tensor, + group=None, + async_op: bool = False, + profiler: CommProfiler = None, +) -> Optional[CommHandler]: async_check(profiler) comm_size = dist.get_world_size(group) @@ -276,11 +283,9 @@ def all_gather(tensor_list: List[torch.Tensor], profiler.close_profiler(group) -def broadcast(tensor: torch.Tensor, - src: int, - group=None, - async_op: bool = False, - profiler: CommProfiler = None) -> Optional[CommHandler]: +def broadcast( + tensor: torch.Tensor, src: int, group=None, async_op: bool = False, profiler: CommProfiler = None +) -> Optional[CommHandler]: async_check(profiler) comm_vol = 1.0 * tensor.element_size() * tensor.numel() @@ -293,12 +298,14 @@ def broadcast(tensor: torch.Tensor, profiler.close_profiler(group) -def reduce(tensor: torch.Tensor, - dst: int, - op: ReduceOp = ReduceOp.SUM, - group=None, - async_op: bool = False, - profiler: CommProfiler = None) -> Optional[CommHandler]: +def reduce( + tensor: torch.Tensor, + dst: int, + op: ReduceOp = ReduceOp.SUM, + group=None, + async_op: bool = False, + profiler: CommProfiler = None, +) -> Optional[CommHandler]: async_check(profiler) comm_vol = 1.0 * tensor.element_size() * tensor.numel() diff --git a/colossalai/legacy/utils/profiler/legacy/pcie_profiler.py b/colossalai/legacy/utils/profiler/legacy/pcie_profiler.py index 514d3c6fabfa..10a3f8dfc43b 100644 --- a/colossalai/legacy/utils/profiler/legacy/pcie_profiler.py +++ b/colossalai/legacy/utils/profiler/legacy/pcie_profiler.py @@ -18,6 +18,7 @@ def _get_size(dtype: str): def _get_numel(my_list: List[int]) -> int: from functools import reduce from operator import mul + return reduce(mul, my_list) @@ -27,12 +28,11 @@ def _reduce_location(locations: List[str]) -> str: ret.append(lo) ret.append("\n") ret = ret[:-1] - return ''.join(ret) + return "".join(ret) class PcieEvent(object): - """Pcie Event. 
- """ + """Pcie Event.""" def __init__(self, count: int = 0, pcie_vol: int = 0, cuda_time: int = 0): self.count = count @@ -73,12 +73,9 @@ def reset(self): self.profiler = None def enable(self): - self.profiler = profile(enabled=True, - use_cuda=True, - use_cpu=True, - use_kineto=True, - record_shapes=True, - with_stack=True) + self.profiler = profile( + enabled=True, use_cuda=True, use_cpu=True, use_kineto=True, record_shapes=True, with_stack=True + ) self.profiler.__enter__() def disable(self): @@ -92,15 +89,15 @@ def disable(self): if len(t_shape) == 0 or event.cuda_time_total == 0 or len(event.stack) == 0: continue current_comm_event = PcieEvent(1, self.data_size * _get_numel(t_shape), event.cuda_time_total) - code_location = _reduce_location(event.stack[:self.depth]) + code_location = _reduce_location(event.stack[: self.depth]) if code_location in self.ops_record: self.ops_record[code_location].add(current_comm_event) else: self.ops_record[code_location] = current_comm_event - elif 'Memcpy HtoD' in event.name: + elif "Memcpy HtoD" in event.name: self.h2d_count += 1 self.h2d_time += event.cuda_time_total - elif 'Memcpy DtoH' in event.name: + elif "Memcpy DtoH" in event.name: self.d2h_count += 1 self.d2h_time += event.cuda_time_total @@ -132,19 +129,25 @@ def append(s: str = None): append("Possible data transmission events in PCIE:") - separation = '-' * 62 - row_format = '{:^10}' + '{:^12}' + '{:^16}' + '{:^12}' * 2 + separation = "-" * 62 + row_format = "{:^10}" + "{:^12}" + "{:^16}" + "{:^12}" * 2 append(separation) - append(row_format.format('Location', 'GPU time', 'Trans volume', 'Bandwidth', 'Num of calls')) + append(row_format.format("Location", "GPU time", "Trans volume", "Bandwidth", "Num of calls")) append(separation) show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].cuda_time) for location, event in show_list: append(location) append( - row_format.format('', _format_time(event.cuda_time), _format_memory(event.pcie_vol), - _format_bandwidth(event.pcie_vol, event.cuda_time), event.count)) + row_format.format( + "", + _format_time(event.cuda_time), + _format_memory(event.pcie_vol), + _format_bandwidth(event.pcie_vol, event.cuda_time), + event.count, + ) + ) append() - return ''.join(res) + return "".join(res) diff --git a/colossalai/legacy/utils/profiler/legacy/prof_utils.py b/colossalai/legacy/utils/profiler/legacy/prof_utils.py index 9b948c9ec1cd..95eecf0715e7 100644 --- a/colossalai/legacy/utils/profiler/legacy/prof_utils.py +++ b/colossalai/legacy/utils/profiler/legacy/prof_utils.py @@ -11,10 +11,10 @@ def _format_time(time_us): US_IN_SECOND = 1000.0 * 1000.0 US_IN_MS = 1000.0 if time_us >= US_IN_SECOND: - return '{:.3f}s'.format(time_us / US_IN_SECOND) + return "{:.3f}s".format(time_us / US_IN_SECOND) if time_us >= US_IN_MS: - return '{:.3f}ms'.format(time_us / US_IN_MS) - return '{:.3f}us'.format(time_us) + return "{:.3f}ms".format(time_us / US_IN_MS) + return "{:.3f}us".format(time_us) # copied from high version pytorch to support low version @@ -23,28 +23,27 @@ def _format_memory(nbytes): KB = 1024 MB = 1024 * KB GB = 1024 * MB - if (abs(nbytes) >= GB): - return '{:.2f} GB'.format(nbytes * 1.0 / GB) - elif (abs(nbytes) >= MB): - return '{:.2f} MB'.format(nbytes * 1.0 / MB) - elif (abs(nbytes) >= KB): - return '{:.2f} KB'.format(nbytes * 1.0 / KB) + if abs(nbytes) >= GB: + return "{:.2f} GB".format(nbytes * 1.0 / GB) + elif abs(nbytes) >= MB: + return "{:.2f} MB".format(nbytes * 1.0 / MB) + elif abs(nbytes) >= KB: + return "{:.2f} KB".format(nbytes * 1.0 / 
KB) else: - return str(nbytes) + ' B' + return str(nbytes) + " B" def _format_bandwidth(volume: float or int, time_us: int): - sec_div_mb = (1000.0 / 1024.0)**2 + sec_div_mb = (1000.0 / 1024.0) ** 2 mb_per_sec = volume / time_us * sec_div_mb if mb_per_sec >= 1024.0: - return '{:.3f} GB/s'.format(mb_per_sec / 1024.0) + return "{:.3f} GB/s".format(mb_per_sec / 1024.0) else: - return '{:.3f} MB/s'.format(mb_per_sec) + return "{:.3f} MB/s".format(mb_per_sec) class BaseProfiler(ABC): - def __init__(self, profiler_name: str, priority: int): self.name = profiler_name self.priority = priority @@ -111,8 +110,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): def to_tensorboard(self, writer): from torch.utils.tensorboard import SummaryWriter - assert isinstance(writer, SummaryWriter), \ - f'torch.utils.tensorboard.SummaryWriter is required, but found {type(writer)}.' + assert isinstance( + writer, SummaryWriter + ), f"torch.utils.tensorboard.SummaryWriter is required, but found {type(writer)}." for prof in self.profilers: prof.to_tensorboard(writer) @@ -124,7 +124,7 @@ def to_file(self, log_dir: Union[str, Path]): if not log_dir.exists(): log_dir.mkdir(parents=True, exist_ok=True) for prof in self.profilers: - log_file = log_dir.joinpath(f'{prof.name}_rank_{gpc.get_global_rank()}.log') + log_file = log_dir.joinpath(f"{prof.name}_rank_{gpc.get_global_rank()}.log") prof.to_file(log_file) def show(self): diff --git a/colossalai/legacy/utils/profiler/profiler.py b/colossalai/legacy/utils/profiler/profiler.py index 0827f06b586c..b7a75f25d951 100644 --- a/colossalai/legacy/utils/profiler/profiler.py +++ b/colossalai/legacy/utils/profiler/profiler.py @@ -120,26 +120,30 @@ def trace_handler(prof): p.step() """ - def __init__(self, - *, - activities: Optional[Iterable[ProfilerActivity]] = None, - schedule: Optional[Callable[[int], ProfilerAction]] = None, - on_trace_ready: Optional[Callable[..., Any]] = None, - engine: Optional[Engine] = None, - record_shapes: bool = False, - profile_memory: bool = False, - with_stack: bool = False, - with_flops: bool = False, - with_modules: bool = False, - profile_stateful_tensor_memory: bool = False) -> None: - super().__init__(activities=activities, - schedule=schedule, - on_trace_ready=on_trace_ready, - record_shapes=record_shapes, - profile_memory=profile_memory, - with_stack=with_stack, - with_flops=with_flops, - with_modules=with_modules) + def __init__( + self, + *, + activities: Optional[Iterable[ProfilerActivity]] = None, + schedule: Optional[Callable[[int], ProfilerAction]] = None, + on_trace_ready: Optional[Callable[..., Any]] = None, + engine: Optional[Engine] = None, + record_shapes: bool = False, + profile_memory: bool = False, + with_stack: bool = False, + with_flops: bool = False, + with_modules: bool = False, + profile_stateful_tensor_memory: bool = False, + ) -> None: + super().__init__( + activities=activities, + schedule=schedule, + on_trace_ready=on_trace_ready, + record_shapes=record_shapes, + profile_memory=profile_memory, + with_stack=with_stack, + with_flops=with_flops, + with_modules=with_modules, + ) self._logger = get_dist_logger() self.extentions: List[ProfilerExtension] = [] if profile_stateful_tensor_memory: @@ -149,9 +153,9 @@ def __init__(self, self.extentions.append(StatefulTensorMemoryProfilerExtention(engine)) def prepare_trace(self) -> None: - if hasattr(super(), 'prepare_trace'): + if hasattr(super(), "prepare_trace"): super().prepare_trace() - elif hasattr(super(), '_start_warmup'): + elif hasattr(super(), "_start_warmup"): 
super()._start_warmup() for ext in self.extentions: ext.prepare_trace() @@ -160,9 +164,9 @@ def _start_warmup(self): self.prepare_trace() def start_trace(self): - if hasattr(super(), '_start_trace'): + if hasattr(super(), "_start_trace"): super()._start_trace() - elif hasattr(super(), 'start_trace'): + elif hasattr(super(), "start_trace"): super().start_trace() for ext in self.extentions: ext.start_trace() @@ -171,9 +175,9 @@ def _start_trace(self): self.start_trace() def stop_trace(self): - if hasattr(super(), '_stop_trace'): + if hasattr(super(), "_stop_trace"): super()._stop_trace() - elif hasattr(super(), 'stop_trace'): + elif hasattr(super(), "stop_trace"): super().stop_trace() for ext in self.extentions: ext.stop_trace() @@ -186,15 +190,15 @@ def export_chrome_trace(self, path: str): Exports the collected trace in Chrome JSON format. """ assert self.profiler - fp = tempfile.NamedTemporaryFile('w+t', suffix='.json', delete=False) + fp = tempfile.NamedTemporaryFile("w+t", suffix=".json", delete=False) fp.close() retvalue = self.profiler.export_chrome_trace(fp.name) with open(fp.name) as fin: trace = json.load(fin) for ext in self.extentions: trace = ext.extend_chrome_trace(trace) - open_func = gzip.open if path.endswith('.gz') else open - with open_func(path, 'wt') as fout: + open_func = gzip.open if path.endswith(".gz") else open + with open_func(path, "wt") as fout: json.dump(trace, fout) os.remove(fp.name) diff --git a/colossalai/legacy/utils/profiler/stateful_tensor_mem_extention.py b/colossalai/legacy/utils/profiler/stateful_tensor_mem_extention.py index f3bb66ced583..9247a9b80772 100644 --- a/colossalai/legacy/utils/profiler/stateful_tensor_mem_extention.py +++ b/colossalai/legacy/utils/profiler/stateful_tensor_mem_extention.py @@ -22,11 +22,11 @@ def get_timestamp_us(): def generic_instant_event(name, pid, tid, timestamp, args): - return {'ph': 'i', 's': 't', 'name': name, 'pid': pid, 'tid': tid, 'ts': timestamp, 'args': args} + return {"ph": "i", "s": "t", "name": name, "pid": pid, "tid": tid, "ts": timestamp, "args": args} class StatefulTensorMemoryEvent: - EVENT_NAME = '[statefulTensorMemory]' + EVENT_NAME = "[statefulTensorMemory]" def __init__(self, timestamp: int, device_type: DeviceType, bytes_: int) -> None: self.pid = os.getpid() @@ -37,22 +37,23 @@ def __init__(self, timestamp: int, device_type: DeviceType, bytes_: int) -> None self.bytes = bytes_ def state_dict(self): - return generic_instant_event(StatefulTensorMemoryEvent.EVENT_NAME, self.pid, self.tid, self.timestamp, { - 'Device Type': self.device_type.value, - 'Device Id': self.device_id, - 'Bytes': self.bytes - }) + return generic_instant_event( + StatefulTensorMemoryEvent.EVENT_NAME, + self.pid, + self.tid, + self.timestamp, + {"Device Type": self.device_type.value, "Device Id": self.device_id, "Bytes": self.bytes}, + ) class StatefulTensorMemoryTracer: - def __init__(self) -> None: self.events: List[StatefulTensorMemoryEvent] = [] self._tracing = False def sample(self): - cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda'] - cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu'] + cuda_mem = StatefulTensor.GST_MGR.total_mem["cuda"] + cpu_mem = StatefulTensor.GST_MGR.total_mem["cpu"] timestamp = get_timestamp_us() if self._tracing: self.events.append(StatefulTensorMemoryEvent(timestamp, DeviceType.CUDA, cuda_mem)) @@ -70,7 +71,6 @@ def state_dict(self): class StatefulTensorMemoryTracerHook(BaseOpHook): - def __init__(self, tracer: StatefulTensorMemoryTracer): super().__init__() self.tracer = tracer @@ -104,7 +104,6 
@@ def disable(self): class StatefulTensorMemoryProfilerExtention(ProfilerExtension): - def __init__(self, engine: Engine) -> None: self.engine = engine self.tracer = StatefulTensorMemoryTracer() @@ -131,5 +130,5 @@ def stop_trace(self): # self.hook_registered = False def extend_chrome_trace(self, trace: dict) -> dict: - trace['traceEvents'].extend(self.tracer.state_dict()) + trace["traceEvents"].extend(self.tracer.state_dict()) return trace diff --git a/colossalai/legacy/zero/__init__.py b/colossalai/legacy/zero/__init__.py index 3783d38e61b2..760fd529f3a6 100644 --- a/colossalai/legacy/zero/__init__.py +++ b/colossalai/legacy/zero/__init__.py @@ -11,8 +11,9 @@ from .sharded_optim import ShardedOptimizerV2 -def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config, - optimizer_config) -> Tuple[ShardedModelV2, ShardedOptimizerV2]: +def convert_to_zero_v2( + model: nn.Module, optimizer: torch.optim.Optimizer, model_config, optimizer_config +) -> Tuple[ShardedModelV2, ShardedOptimizerV2]: """ A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading @@ -25,12 +26,12 @@ def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model :rtype: Tuple """ - logger = get_dist_logger('convert_to_zero_v2') + logger = get_dist_logger("convert_to_zero_v2") - logger.info(f'optimizer_config is {optimizer_config}', ranks=[0]) + logger.info(f"optimizer_config is {optimizer_config}", ranks=[0]) if optimizer_config is None: optimizer_config = dict() - logger.info(f'model_config is {model_config}', ranks=[0]) + logger.info(f"model_config is {model_config}", ranks=[0]) if model_config is None: model_config = dict() @@ -40,6 +41,12 @@ def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model __all__ = [ - 'convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroInitContext', 'no_shard_zero_context', - 'no_shard_zero_decrator', 'TensorShardStrategy', 'BucketTensorShardStrategy' + "convert_to_zero_v2", + "ShardedModelV2", + "ShardedOptimizerV2", + "ZeroInitContext", + "no_shard_zero_context", + "no_shard_zero_decrator", + "TensorShardStrategy", + "BucketTensorShardStrategy", ] diff --git a/colossalai/legacy/zero/gemini/__init__.py b/colossalai/legacy/zero/gemini/__init__.py index 754ae9bc0044..b272980d34d8 100644 --- a/colossalai/legacy/zero/gemini/__init__.py +++ b/colossalai/legacy/zero/gemini/__init__.py @@ -4,6 +4,11 @@ from .tensor_placement_policy import AutoTensorPlacementPolicy, CPUTensorPlacementPolicy, CUDATensorPlacementPolicy __all__ = [ - 'StatefulTensorMgr', 'StatefulTensor', 'CPUTensorPlacementPolicy', 'CUDATensorPlacementPolicy', - 'AutoTensorPlacementPolicy', 'register_ophooks_recursively', 'BaseOpHook' + "StatefulTensorMgr", + "StatefulTensor", + "CPUTensorPlacementPolicy", + "CUDATensorPlacementPolicy", + "AutoTensorPlacementPolicy", + "register_ophooks_recursively", + "BaseOpHook", ] diff --git a/colossalai/legacy/zero/gemini/gemini_context.py b/colossalai/legacy/zero/gemini/gemini_context.py index 9a7da6b80fba..9e82d948fba7 100644 --- a/colossalai/legacy/zero/gemini/gemini_context.py +++ b/colossalai/legacy/zero/gemini/gemini_context.py @@ -2,16 +2,15 @@ class GeminiMemoryManager(object): - def __init__(self, states_cls: EnumMeta): super().__init__() self.states_cls = states_cls - self._cnter = 0 # the counter of instances + self._cnter = 0 # the counter of instances self.total_mem = dict() self.state_mem = dict() - self.state_mem['cpu'] = dict() - self.state_mem['cuda'] = dict() + 
self.state_mem["cpu"] = dict() + self.state_mem["cuda"] = dict() self.reset() @@ -20,15 +19,15 @@ def total_number(self): return self._cnter def reset(self): - self._cnter = 0 # the counter of instances + self._cnter = 0 # the counter of instances - self.total_mem['cpu'] = 0 # memory occupation of instances in cpu - self.total_mem['cuda'] = 0 # memory of occupation of instances in cuda + self.total_mem["cpu"] = 0 # memory occupation of instances in cpu + self.total_mem["cuda"] = 0 # memory of occupation of instances in cuda # memory conditions for all states for state in self.states_cls: - self.state_mem['cpu'][state] = 0 - self.state_mem['cuda'][state] = 0 + self.state_mem["cpu"][state] = 0 + self.state_mem["cuda"][state] = 0 def register_new_instance(self): self._cnter += 1 @@ -37,12 +36,16 @@ def delete_instance(self): self._cnter -= 1 def print_info(self): - print(f"Total number: {self.total_number}", - f"Total CPU memory occupation: {self.total_mem['cpu']}", - f"Total CUDA memory occupation: {self.total_mem['cuda']}\n", - sep='\n') + print( + f"Total number: {self.total_number}", + f"Total CPU memory occupation: {self.total_mem['cpu']}", + f"Total CUDA memory occupation: {self.total_mem['cuda']}\n", + sep="\n", + ) for state in self.states_cls: - print(f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}", - f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n", - sep='\n') + print( + f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}", + f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n", + sep="\n", + ) diff --git a/colossalai/legacy/zero/gemini/ophooks/_shard_grad_ophook.py b/colossalai/legacy/zero/gemini/ophooks/_shard_grad_ophook.py index d68a9dc6458f..4129b14bcae9 100644 --- a/colossalai/legacy/zero/gemini/ophooks/_shard_grad_ophook.py +++ b/colossalai/legacy/zero/gemini/ophooks/_shard_grad_ophook.py @@ -22,7 +22,7 @@ def post_fwd_exec(self, module: torch.nn.Module, *args): def pre_bwd_exec(self, module: torch.nn.Module, input, output): for param in module.parameters(): - assert hasattr(param, '_sharded_grad') + assert hasattr(param, "_sharded_grad") param._sharded_grad.setup() def post_bwd_exec(self, module: torch.nn.Module, input): diff --git a/colossalai/legacy/zero/gemini/ophooks/_shard_param_ophook.py b/colossalai/legacy/zero/gemini/ophooks/_shard_param_ophook.py index 6b76a2116a49..e0c83eec0445 100644 --- a/colossalai/legacy/zero/gemini/ophooks/_shard_param_ophook.py +++ b/colossalai/legacy/zero/gemini/ophooks/_shard_param_ophook.py @@ -19,25 +19,25 @@ def niter(self): def pre_fwd_exec(self, module: torch.nn.Module, *args): for param in module.parameters(): - assert hasattr(param, 'ca_attr') + assert hasattr(param, "ca_attr") param.ca_attr.gather() param.data = param.ca_attr.payload() def post_fwd_exec(self, module: torch.nn.Module, *args): for param in module.parameters(): - assert hasattr(param, 'ca_attr') + assert hasattr(param, "ca_attr") param.ca_attr.shard() param.data = param.ca_attr.payload() def pre_bwd_exec(self, module: torch.nn.Module, input, output): for param in module.parameters(): - assert hasattr(param, 'ca_attr') + assert hasattr(param, "ca_attr") param.ca_attr.gather() param.data = param.ca_attr.payload() def post_bwd_exec(self, module: torch.nn.Module, input): for param in module.parameters(): - assert hasattr(param, 'ca_attr') + assert hasattr(param, "ca_attr") param.ca_attr.shard() param.data = param.ca_attr.payload() diff --git 
a/colossalai/legacy/zero/gemini/ophooks/runtime_mem_tracer_hook.py b/colossalai/legacy/zero/gemini/ophooks/runtime_mem_tracer_hook.py index eebcf86e0e58..57076063cb3f 100644 --- a/colossalai/legacy/zero/gemini/ophooks/runtime_mem_tracer_hook.py +++ b/colossalai/legacy/zero/gemini/ophooks/runtime_mem_tracer_hook.py @@ -15,8 +15,7 @@ class TrainingPhase(Enum): BACKWARD = 1 -class GradMemStats(): - +class GradMemStats: def __init__(self) -> None: self.unreleased_grad_flag = {} self.unreleased_grad_volume = 0 @@ -26,8 +25,7 @@ def clear(self): self.unreleased_grad_volume = 0 -class GradMemTracerHook(): - +class GradMemTracerHook: def __init__(self, grad_stats: GradMemStats): self.grad_hook_list = [] self._grad_stats = grad_stats @@ -50,7 +48,6 @@ def remove_grad_hook(self): class ParamMemTracerHook(ColoParamOpHook): - def __init__(self, memstats: MemStats, gradstats: GradMemStats) -> None: super().__init__() self._training_phase = TrainingPhase.FORWARD @@ -79,10 +76,9 @@ def _allocate_params_on_cuda(self, params: List[torch.nn.Parameter]): if cur_dev == "cpu": if p.grad is not None and p.grad.device.type == "cpu": raise NotImplementedError("Only run in forward propagation") - p.data = torch.empty(p.data.shape, - device="cuda", - dtype=p.data.dtype, - requires_grad=p.data.requires_grad) + p.data = torch.empty( + p.data.shape, device="cuda", dtype=p.data.dtype, requires_grad=p.data.requires_grad + ) elif cur_dev == "cuda": alloc_storage(p.data) diff --git a/colossalai/legacy/zero/gemini/ophooks/utils.py b/colossalai/legacy/zero/gemini/ophooks/utils.py index f88ad2b00e9e..057906156d8d 100644 --- a/colossalai/legacy/zero/gemini/ophooks/utils.py +++ b/colossalai/legacy/zero/gemini/ophooks/utils.py @@ -48,7 +48,6 @@ def _apply_to_tensors_only(module, functional, backward_function, outputs): class PreBackwardFunction(torch.autograd.Function): - @staticmethod def forward(ctx, module, pre_backward_function, outputs): ctx.module = module @@ -64,7 +63,6 @@ def backward(ctx, *args): class PostBackwardFunction(torch.autograd.Function): - @staticmethod def forward(ctx, module, pre_backward_function, output): ctx.module = module @@ -84,16 +82,15 @@ def backward(ctx, *args): return (None, None) + args -def register_ophooks_recursively(module: torch.nn.Module, - ophook_list: List[BaseOpHook], - name: str = "", - filter_fn: Optional[Callable] = None): +def register_ophooks_recursively( + module: torch.nn.Module, ophook_list: List[BaseOpHook], name: str = "", filter_fn: Optional[Callable] = None +): r"""Recursively register pre/post hooks for all submodules in the module in FWD and BWD.""" assert isinstance(module, torch.nn.Module) assert isinstance(ophook_list, (list, tuple)) - assert len(ophook_list) > 0, 'expected at least 1 hook in the argument ophook_list but found 0' + assert len(ophook_list) > 0, "expected at least 1 hook in the argument ophook_list but found 0" for hook in ophook_list: - assert (isinstance(hook, BaseOpHook)) + assert isinstance(hook, BaseOpHook) # Add hooks for submodules for child_name, child in module.named_children(): @@ -118,7 +115,6 @@ def _post_forward_module_hook(submodule, *args): hook.post_fwd_exec(submodule, *args) def _pre_backward_module_hook(submodule, inputs, output): - def _run_before_backward_function(submodule): for hook in ophook_list: assert isinstance(submodule, torch.nn.Module) @@ -127,7 +123,6 @@ def _run_before_backward_function(submodule): return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output) def 
_post_backward_module_hook(submodule, inputs): - def _run_after_backward_function(submodule): for hook in ophook_list: assert isinstance(submodule, torch.nn.Module) diff --git a/colossalai/legacy/zero/gemini/paramhooks/_param_hookmgr.py b/colossalai/legacy/zero/gemini/paramhooks/_param_hookmgr.py index 84f32be358e3..91c7bdc2961b 100644 --- a/colossalai/legacy/zero/gemini/paramhooks/_param_hookmgr.py +++ b/colossalai/legacy/zero/gemini/paramhooks/_param_hookmgr.py @@ -5,7 +5,6 @@ class BaseParamHookMgr(object): - def __init__(self, param_list: List[torch.nn.Parameter]) -> None: r""" register backward hook on every parameters of module @@ -23,9 +22,9 @@ def register_backward_hooks(self, hook_call: Callable) -> None: ``` """ if not torch.is_grad_enabled(): - return # don't register grad hooks if grad isn't enabled + return # don't register grad hooks if grad isn't enabled for p in self._param_list: - if p.requires_grad and not hasattr(p, '_base_param_hook'): + if p.requires_grad and not hasattr(p, "_base_param_hook"): handle = p.register_hook(functools.partial(hook_call, p)) p._base_param_hook = handle @@ -35,5 +34,5 @@ def remove_hooks(self) -> None: """ for p in self._param_list: - if p.requires_grad and hasattr(p, '_base_param_hook'): + if p.requires_grad and hasattr(p, "_base_param_hook"): p._base_param_hook.remove() diff --git a/colossalai/legacy/zero/gemini/stateful_tensor.py b/colossalai/legacy/zero/gemini/stateful_tensor.py index 1619ae40798d..668d344132d0 100644 --- a/colossalai/legacy/zero/gemini/stateful_tensor.py +++ b/colossalai/legacy/zero/gemini/stateful_tensor.py @@ -25,13 +25,14 @@ class StatefulTensor(object): https://arxiv.org/abs/2108.05818 """ + # Global Stateful Tensor Manager GST_MGR = GeminiMemoryManager(TensorState) def __init__(self, maybe_tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None: self._state = state self._payload = None - self._payload_size = 0 # byte size of current payload + self._payload_size = 0 # byte size of current payload StatefulTensor.GST_MGR.register_new_instance() @@ -47,7 +48,7 @@ def __init__(self, maybe_tensor: Optional[torch.Tensor], state: Optional[TensorS def data_ptr(self): if self._payload is None: - return 0 # if a tensor has no storage, 0 should be returned + return 0 # if a tensor has no storage, 0 should be returned return self._payload.data_ptr() def set_null(self) -> None: @@ -80,7 +81,7 @@ def move_to(self, device: Union[torch.device, int]): assert self.state is not TensorState.FREE, "Can't move free stateful tensor" if not isinstance(device, torch.device): - to_device = torch.device('cuda', device) + to_device = torch.device("cuda", device) else: to_device = device @@ -97,7 +98,6 @@ def payload_copy(self, tensor) -> None: self._payload.view(-1).copy_(tensor.view(-1)) def payload_reset(self, tensor) -> None: - assert tensor is not None, "Can't reset None for stateful tensors, please use set_null() instead" if self.payload is not None: @@ -168,8 +168,7 @@ def __release(self): self._payload_size = 0 def __trans_state_update(self, from_state: TensorState, to_state: TensorState): - """Update global manager when changing the state of a tensor - """ + """Update global manager when changing the state of a tensor""" manager = StatefulTensor.GST_MGR size = self.payload_size device_type = self.device.type @@ -189,8 +188,7 @@ def __trans_state_update(self, from_state: TensorState, to_state: TensorState): manager.total_mem[device_type] -= size def __trans_device_update(self, from_type: str, to_type: 
str): - """Update global manager when changing the device of a tensor - """ + """Update global manager when changing the device of a tensor""" manager = StatefulTensor.GST_MGR size = self.payload_size state = self.state diff --git a/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py b/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py index 4f9ea7c6d520..19f77d4305af 100644 --- a/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py +++ b/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py @@ -3,14 +3,11 @@ from time import time from typing import List -import torch - -from colossalai.logging import get_dist_logger from colossalai.utils.cuda import get_current_device from .stateful_tensor import StatefulTensor, TensorState from .tensor_placement_policy import TensorPlacementPolicy -from .tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage +from .tensor_utils import colo_model_data_tensor_move_inline class StatefulTensorMgr(object): @@ -44,8 +41,7 @@ def start_iter(self): pass def finish_iter(self): - """This function must be called when each iteration finishes - """ + """This function must be called when each iteration finishes""" self._warmup = False self._compute_idx = -1 self._cpu_gpu_move_volume = 0 @@ -53,19 +49,21 @@ def finish_iter(self): self._evict_time = 0 def adjust_layout(self) -> None: - """ Adjust the layout of stateful tensor according to the information provided + """Adjust the layout of stateful tensor according to the information provided by mem_stats_collector, which should belongs to a Sharded Model. """ # find stateful tensor in state COMPUTE - cuda_demand = StatefulTensor.GST_MGR.state_mem['cpu'][TensorState.COMPUTE] + cuda_demand = StatefulTensor.GST_MGR.state_mem["cpu"][TensorState.COMPUTE] start = time() move_to_cuda_tensor_list, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup) self._layout_time += time() - start - vol, evict_time = self._tensor_placement_policy.evict_tensors(hold_cuda_tensor_list, - cuda_demand=cuda_demand, - warmup=self._warmup, - compute_list=self._compute_list, - compute_idx=self._compute_idx) + vol, evict_time = self._tensor_placement_policy.evict_tensors( + hold_cuda_tensor_list, + cuda_demand=cuda_demand, + warmup=self._warmup, + compute_list=self._compute_list, + compute_idx=self._compute_idx, + ) self._cpu_gpu_move_volume += vol self._evict_time += evict_time # move COMPUTE tensors to CUDA @@ -92,10 +90,10 @@ def _get_layout_info(self, compute_idx: int, warmup: bool): if tensor.state == TensorState.FREE: continue - if tensor.device.type == 'cuda': + if tensor.device.type == "cuda": if tensor.state in [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.HOLD_AFTER_FWD]: hold_cuda_tensor_list.append(tensor) - elif tensor.device.type == 'cpu': + elif tensor.device.type == "cpu": if tensor.state == TensorState.COMPUTE: move_to_cuda_tensor_list.append(tensor) else: diff --git a/colossalai/legacy/zero/gemini/tensor_placement_policy.py b/colossalai/legacy/zero/gemini/tensor_placement_policy.py index 275933ec2cfb..3aca80cfe56a 100644 --- a/colossalai/legacy/zero/gemini/tensor_placement_policy.py +++ b/colossalai/legacy/zero/gemini/tensor_placement_policy.py @@ -10,11 +10,10 @@ from colossalai.zero.gemini.memory_tracer import MemStatsCollector from .stateful_tensor import StatefulTensor -from .tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage +from .tensor_utils import colo_model_data_tensor_move_inline class TensorPlacementPolicy(ABC): - def 
__init__(self, device: Optional[torch.device], mem_stats_collector: Optional[MemStatsCollector] = None) -> None: self.device: Optional[torch.device] = device self.mem_stats_collector: Optional[MemStatsCollector] = mem_stats_collector @@ -25,9 +24,8 @@ def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) - class CPUTensorPlacementPolicy(TensorPlacementPolicy): - def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None: - super().__init__(torch.device('cpu'), mem_stats_collector=mem_stats_collector) + super().__init__(torch.device("cpu"), mem_stats_collector=mem_stats_collector) def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int: volume = 0 @@ -38,9 +36,8 @@ def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) - class CUDATensorPlacementPolicy(TensorPlacementPolicy): - def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None: - assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available' + assert torch.cuda.is_available(), "Cannot use CUDATensorPlacementPolicy when CUDA is not available" super().__init__(get_current_device(), mem_stats_collector=mem_stats_collector) def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int: @@ -48,7 +45,6 @@ def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) - class AutoTensorPlacementPolicy(TensorPlacementPolicy): - def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None: super().__init__(None, mem_stats_collector=mem_stats_collector) # model data will use 1-self._warmup_non_model_data_ratio CUDA memory in warmup phase @@ -56,13 +52,15 @@ def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> N self._warmup_non_model_data_ratio: float = 0.8 self._steady_cuda_cap_ratio: float = 0.9 - def evict_tensors(self, - hold_cuda_tensor_list: List[StatefulTensor], - cuda_demand: int = 0, - warmup: bool = True, - compute_list: List[StatefulTensor] = [], - compute_idx: int = 0, - **kwargs) -> int: + def evict_tensors( + self, + hold_cuda_tensor_list: List[StatefulTensor], + cuda_demand: int = 0, + warmup: bool = True, + compute_list: List[StatefulTensor] = [], + compute_idx: int = 0, + **kwargs, + ) -> int: """ Evict tensors from CUDA device. @@ -81,13 +79,13 @@ def evict_tensors(self, """ start = time() cuda_capacity = colo_device_memory_capacity(get_current_device()) - used_cuda_model_data = StatefulTensor.GST_MGR.total_mem['cuda'] + used_cuda_model_data = StatefulTensor.GST_MGR.total_mem["cuda"] if warmup: # We designate a part of CUDA memory for model data in warmup iterations. max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio else: # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment. 
- max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda') + max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage("cuda") cuda_capacity *= self._steady_cuda_cap_ratio total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data @@ -99,15 +97,16 @@ def evict_tensors(self, to_free_cuda_model_data = cuda_demand - avail_cuda_model_data to_free_tensor_list = hold_cuda_tensor_list if not warmup: - to_free_tensor_list = self._sort_hold_cuda_tensors(tuple(hold_cuda_tensor_list), compute_idx, - tuple(compute_list)) + to_free_tensor_list = self._sort_hold_cuda_tensors( + tuple(hold_cuda_tensor_list), compute_idx, tuple(compute_list) + ) # print(self._sort_hold_cuda_tensors.cache_info()) end = time() for t in to_free_tensor_list: if freed_cuda_model_data >= to_free_cuda_model_data: break freed_cuda_model_data += t.payload_size - colo_model_data_tensor_move_inline(t, torch.device('cpu')) + colo_model_data_tensor_move_inline(t, torch.device("cpu")) if freed_cuda_model_data < to_free_cuda_model_data: raise RuntimeError( f"Adjust layout failed! Not enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}" ) @@ -126,14 +125,13 @@ def _sort_hold_cuda_tensors(hold_cuda_tensors: tuple, compute_idx: int, compute_ class TensorPlacementPolicyFactory: - @staticmethod def create(policy_name: str) -> Type[TensorPlacementPolicy]: - if policy_name == 'cpu': + if policy_name == "cpu": return CPUTensorPlacementPolicy - elif policy_name == 'cuda': + elif policy_name == "cuda": return CUDATensorPlacementPolicy - elif policy_name == 'auto': + elif policy_name == "auto": return AutoTensorPlacementPolicy else: raise TypeError(f"Unknown tensor placement policy {policy_name}") diff --git a/colossalai/legacy/zero/gemini/tensor_utils.py b/colossalai/legacy/zero/gemini/tensor_utils.py index 843e330ee2c6..6e51dee6ef94 100644 --- a/colossalai/legacy/zero/gemini/tensor_utils.py +++ b/colossalai/legacy/zero/gemini/tensor_utils.py @@ -30,16 +30,17 @@ def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[ cuda_use, cpu_use = 0, 0 mem_use = t.storage().size() * t.element_size() - if t.device.type == 'cuda': + if t.device.type == "cuda": cuda_use += mem_use - elif t.device.type == 'cpu': + elif t.device.type == "cpu": cpu_use += mem_use return cuda_use, cpu_use -def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_t: Union[StatefulTensor, - torch.Tensor]) -> None: +def colo_model_data_tensor_move( + src_t: Union[StatefulTensor, torch.Tensor], tgt_t: Union[StatefulTensor, torch.Tensor] +) -> None: """ A colossal API for model data tensor move. The src and target tensors could be resident on both CPU and GPU. @@ -71,8 +72,9 @@ def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_ src_t.data = torch.empty(0, device=src_dev, dtype=src_t_payload.dtype) -def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], target_device: Union[torch.device, - int]) -> None: +def colo_model_data_tensor_move_inline( + t: Union[StatefulTensor, torch.Tensor], target_device: Union[torch.device, int] +) -> None: """ move a tensor to the target_device Args: @@ -80,14 +82,14 @@ def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], t target_device: a target device, if type is int, it is the index of the cuda card.
""" if not isinstance(target_device, torch.device): - target_device = torch.device(f'cuda:{target_device}') + target_device = torch.device(f"cuda:{target_device}") if isinstance(t, torch.Tensor): t.data = t.data.to(target_device) elif isinstance(t, StatefulTensor): t.move_to(target_device) else: - raise TypeError(f'colo_model_data_tensor_move_inline does not accept type {type(t)}') + raise TypeError(f"colo_model_data_tensor_move_inline does not accept type {type(t)}") def colo_model_data_move_to_cpu(t: Union[StatefulTensor, torch.Tensor]) -> None: @@ -100,9 +102,9 @@ def colo_model_data_move_to_cpu(t: Union[StatefulTensor, torch.Tensor]) -> None: if isinstance(t, torch.Tensor): t.data = t.data.cpu() elif isinstance(t, StatefulTensor): - t.move_to(torch.device('cpu')) + t.move_to(torch.device("cpu")) else: - raise TypeError(f'colo_model_data_move_to_cpu does not accept type {type(t)}') + raise TypeError(f"colo_model_data_move_to_cpu does not accept type {type(t)}") def colo_model_tensor_clone(t: Union[StatefulTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor: diff --git a/colossalai/legacy/zero/init_ctx/__init__.py b/colossalai/legacy/zero/init_ctx/__init__.py index 0a6f81566a9d..28ce72a18b31 100644 --- a/colossalai/legacy/zero/init_ctx/__init__.py +++ b/colossalai/legacy/zero/init_ctx/__init__.py @@ -1,3 +1,3 @@ from .init_context import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator -__all__ = ['ZeroInitContext', 'no_shard_zero_context', 'no_shard_zero_decrator'] +__all__ = ["ZeroInitContext", "no_shard_zero_context", "no_shard_zero_decrator"] diff --git a/colossalai/legacy/zero/init_ctx/init_context.py b/colossalai/legacy/zero/init_ctx/init_context.py index 4a7e46408583..6c5a8122ef80 100644 --- a/colossalai/legacy/zero/init_ctx/init_context.py +++ b/colossalai/legacy/zero/init_ctx/init_context.py @@ -39,7 +39,7 @@ def __post_init__(self): assert self.is_replicated, "Non-replicated parameters can't be sharded." if self.is_replicated and not self.shard_param: - assert self.target_device.type == 'cuda', "Replicated no-shard parameters should be located in cuda." + assert self.target_device.type == "cuda", "Replicated no-shard parameters should be located in cuda." class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): @@ -59,15 +59,16 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of model. Defaults to torch.zeros(1, dtype=torch.int).
""" - def __init__(self, - target_device: torch.device, - shard_strategy: BaseShardStrategy, - seed: int = 2**10 - 1, - shard_param: bool = False, - default_dtype: Optional[torch.dtype] = None, - bf16: bool = False, - model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.long)): - + def __init__( + self, + target_device: torch.device, + shard_strategy: BaseShardStrategy, + seed: int = 2**10 - 1, + shard_param: bool = False, + default_dtype: Optional[torch.dtype] = None, + bf16: bool = False, + model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.long), + ): super().__init__(default_dtype=default_dtype) self.shard_strategy = shard_strategy self.param_list = [] @@ -103,7 +104,7 @@ def calc_fanin_fanout(tensor: torch.Tensor): assert isinstance(tensor, nn.Parameter), "Sharded tensor initialization is only allowed for parameters" # get correct shape of input tensor - if not hasattr(tensor, 'colo_attr') or not tensor.colo_attr.param_is_sharded: + if not hasattr(tensor, "colo_attr") or not tensor.colo_attr.param_is_sharded: tensor_shape = tensor.shape else: tensor_shape = tensor.colo_attr.sharded_data_tensor.origin_shape @@ -137,13 +138,16 @@ def _pre_context_exec(self): self.module_load_from_state_dict = nn.Module._load_from_state_dict shard_strategy = self.shard_strategy if self.config.shard_param else None - nn.Module._load_from_state_dict = functools.partialmethod(ShardedModelV2._colo_load_from_state_dict, - shard_strategy=shard_strategy) + nn.Module._load_from_state_dict = functools.partialmethod( + ShardedModelV2._colo_load_from_state_dict, shard_strategy=shard_strategy + ) self.module_state_dict = nn.Module.state_dict - nn.Module.state_dict = functools.partialmethod(ShardedModelV2._colo_state_dict, - shard_strategy=shard_strategy, - state_dict_func=self.module_state_dict, - process_group=self.dp_process_group) + nn.Module.state_dict = functools.partialmethod( + ShardedModelV2._colo_state_dict, + shard_strategy=shard_strategy, + state_dict_func=self.module_state_dict, + process_group=self.dp_process_group, + ) # reserve rng states self.cpu_rng_state = torch.get_rng_state() @@ -152,16 +156,15 @@ def _pre_context_exec(self): # set new seed for initialization, since we initialize sharded tensor separately # we don't want all processes have the same seed # otherwise all sharded tensors are same after init - offset = self.seed + 1 # we want to have more 1 in binary format seed + offset = self.seed + 1 # we want to have more 1 in binary format seed torch.manual_seed(self.seed + offset * dist.get_rank()) def _post_context_exec(self): - """The callback function when exiting context. 
- """ + """The callback function when exiting context.""" # broadcast replicated no-shard parameters src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0] for param in self.param_list: - assert hasattr(param, 'colo_attr') + assert hasattr(param, "colo_attr") if not param.colo_attr.param_is_sharded and param.colo_attr.is_replicated: dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group) param.colo_attr.set_data_none() @@ -193,7 +196,7 @@ def half_fn(t: torch.Tensor): for param in module.parameters(recurse=False): # avoid adapting a param to ShardedParam twice - if hasattr(param, 'colo_attr'): + if hasattr(param, "colo_attr"): continue self.param_numel[param] = param.numel() @@ -216,7 +219,7 @@ def half_fn(t: torch.Tensor): if self.shard_param: self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group) - param.data = param.colo_attr.data_payload # set param.data to payload + param.data = param.colo_attr.data_payload # set param.data to payload # mark whether the param is replicated param.colo_attr.is_replicated = self.is_replicated @@ -251,15 +254,13 @@ def hijack_context_config(self, **kwargs): def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager: - return ZeroContextMgr().hijack_context_config(target_device=torch.device('cuda', torch.cuda.current_device()), - is_replicated=is_replicated, - shard_param=False) + return ZeroContextMgr().hijack_context_config( + target_device=torch.device("cuda", torch.cuda.current_device()), is_replicated=is_replicated, shard_param=False + ) def no_shard_zero_decrator(is_replicated: bool = True): - def _wrapper(init_func): - def _no_shard(*args, **kwargs): with no_shard_zero_context(is_replicated): ret = init_func(*args, **kwargs) diff --git a/colossalai/legacy/zero/shard_utils/__init__.py b/colossalai/legacy/zero/shard_utils/__init__.py index 5e5d63a7e768..945c77a412c1 100644 --- a/colossalai/legacy/zero/shard_utils/__init__.py +++ b/colossalai/legacy/zero/shard_utils/__init__.py @@ -2,4 +2,4 @@ from .bucket_tensor_shard_strategy import BucketTensorShardStrategy from .tensor_shard_strategy import TensorShardStrategy -__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy'] +__all__ = ["BaseShardStrategy", "TensorShardStrategy", "BucketTensorShardStrategy"] diff --git a/colossalai/legacy/zero/shard_utils/base_shard_strategy.py b/colossalai/legacy/zero/shard_utils/base_shard_strategy.py index 9fb80f57ae77..13e6f0e48298 100644 --- a/colossalai/legacy/zero/shard_utils/base_shard_strategy.py +++ b/colossalai/legacy/zero/shard_utils/base_shard_strategy.py @@ -7,10 +7,8 @@ class BaseShardStrategy(ABC): - def __init__(self) -> None: - """Abstract Shard Strategy. Use to shard a tensors on multiple GPUs. - """ + """Abstract Shard Strategy. 
Used to shard tensors on multiple GPUs."""
        super().__init__()

    @abstractmethod
diff --git a/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py b/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py
index 1f7baad57816..b9d3071a877e 100644
--- a/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py
+++ b/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py
@@ -18,7 +18,6 @@ class BucketTensorShardStrategy(TensorShardStrategy):
    """

    def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
-
        tensor_list: List[ShardedTensor] = [t for t in tensor_list if t.is_sharded]
        if len(tensor_list) == 0:
            return
@@ -40,8 +39,8 @@ def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.
        buffer_list = [buffer.to(target_device) for buffer in buffer_list]
        offset = 0
        for i, t in enumerate(tensor_list):
-            gathered_payload = [buffer[offset:offset + tensor_numels[i]] for buffer in buffer_list]
-            gathered_payload = torch.cat(gathered_payload)[:t.origin_numel].view(t.origin_shape)
+            gathered_payload = [buffer[offset : offset + tensor_numels[i]] for buffer in buffer_list]
+            gathered_payload = torch.cat(gathered_payload)[: t.origin_numel].view(t.origin_shape)
            t.payload_reset(gathered_payload)
            t.is_sharded = False
            offset += tensor_numels[i]
diff --git a/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py b/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py
index cc43907f6655..ebaef774bd06 100644
--- a/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py
+++ b/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py
@@ -24,7 +24,7 @@ def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.
            self._gather_tensor(t, process_group)

    def _shard_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
-        """ Shard tensor among processes.
+        """Shard tensor among processes.

        Args:
            t (ShardedTensor): a tensor to be sharded.
@@ -33,9 +33,11 @@ def _shard_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGr """ if t.is_sharded: return - if t.payload.device.type == 'cuda': - assert t.payload.device == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\ + if t.payload.device.type == "cuda": + assert t.payload.device == get_current_device(), ( + f"shard tensor on cuda device index {t.payload.device.index}," f" but current cuda device is {get_current_device()}" + ) sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group)) t.payload_reset(sharded_payload) t.is_sharded = True diff --git a/colossalai/legacy/zero/sharded_model/__init__.py b/colossalai/legacy/zero/sharded_model/__init__.py index 93120bdc34b4..ecead2f6a657 100644 --- a/colossalai/legacy/zero/sharded_model/__init__.py +++ b/colossalai/legacy/zero/sharded_model/__init__.py @@ -1,3 +1,3 @@ from .sharded_model_v2 import ShardedModelV2 -__all__ = ['ShardedModelV2'] +__all__ = ["ShardedModelV2"] diff --git a/colossalai/legacy/zero/sharded_model/_utils.py b/colossalai/legacy/zero/sharded_model/_utils.py index b8a618ef5a0d..100762318593 100644 --- a/colossalai/legacy/zero/sharded_model/_utils.py +++ b/colossalai/legacy/zero/sharded_model/_utils.py @@ -25,7 +25,7 @@ def free_storage(data: torch.Tensor) -> None: @torch.no_grad() def alloc_storage(data: torch.Tensor, size: torch.Size) -> None: """Allocate storage for a tensor.""" - if data.storage().size() == size.numel(): # no need to reallocate + if data.storage().size() == size.numel(): # no need to reallocate return assert data.storage().size() == 0 data.storage().resize_(size.numel()) diff --git a/colossalai/legacy/zero/sharded_model/reduce_scatter.py b/colossalai/legacy/zero/sharded_model/reduce_scatter.py index 4fb507382df9..0f11365515d2 100644 --- a/colossalai/legacy/zero/sharded_model/reduce_scatter.py +++ b/colossalai/legacy/zero/sharded_model/reduce_scatter.py @@ -20,7 +20,6 @@ class Bucket: - def __init__(self, shard_size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup): self.buffer = torch.zeros((group.size(), shard_size), dtype=dtype, device=device) self.group = group @@ -35,18 +34,18 @@ def flush(self) -> None: return # reduce-scatter bucket if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives: - dist._reduce_scatter_base(self.output_shard[:self.offset], - self.buffer[:, :self.offset].contiguous(), - group=self.group) + dist._reduce_scatter_base( + self.output_shard[: self.offset], self.buffer[:, : self.offset].contiguous(), group=self.group + ) else: - dist.reduce_scatter(self.output_shard[:self.offset], - list(self.buffer[:, :self.offset].unbind(0)), - group=self.group) + dist.reduce_scatter( + self.output_shard[: self.offset], list(self.buffer[:, : self.offset].unbind(0)), group=self.group + ) # execute post-reduction callbacks for callback_fn in self.callbacks: callback_fn() # reuse input bucket but allocate a fresh output shard - self.buffer[:, :self.offset].zero_() + self.buffer[:, : self.offset].zero_() self.offset = 0 self.callbacks.clear() self.output_shard = torch.zeros_like(self.buffer[0]) @@ -74,12 +73,12 @@ def append(self, tensor_list: List[Tensor], callback_fn: Callable): tensor_size = tensor_list[0].numel() stacked_input = torch.stack(tensor_list).view(self.group.size(), tensor_size) offset = self.offset - self.buffer[:, offset:offset + tensor_size].copy_(stacked_input) + self.buffer[:, offset : offset + tensor_size].copy_(stacked_input) 
self.offset += tensor_size # callback will be given the reduced result if callback_fn is not None: - result_view = self.output_shard[offset:offset + tensor_size].view_as(tensor_list[0]) + result_view = self.output_shard[offset : offset + tensor_size].view_as(tensor_list[0]) self.callbacks.append(functools.partial(callback_fn, result_view)) @@ -142,8 +141,9 @@ def reduce_scatter_async( """ world_size = group.size() - assert (len(input_list) == world_size - ), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})" + assert ( + len(input_list) == world_size + ), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})" first_input = input_list[0] first_input_size = first_input.numel() @@ -183,7 +183,7 @@ def free(self) -> None: @functools.lru_cache() def _get_shard_size(self, element_size: int, num_shards: int) -> int: - if self.bucket_size_mb <= 0: # Values <= 0 disable bucketing. + if self.bucket_size_mb <= 0: # Values <= 0 disable bucketing. return 0 MB = 1024 * 1024 bucket_size = self.bucket_size_mb * MB / element_size diff --git a/colossalai/legacy/zero/sharded_model/sharded_model_v2.py b/colossalai/legacy/zero/sharded_model/sharded_model_v2.py index 91c21ccf9516..85f2ac2159f4 100644 --- a/colossalai/legacy/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/legacy/zero/sharded_model/sharded_model_v2.py @@ -2,7 +2,6 @@ import functools import itertools from collections import OrderedDict -from copy import deepcopy from typing import Any, Iterator, Optional, Tuple import torch @@ -40,7 +39,7 @@ try: from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX except ImportError: - _EXTRA_STATE_KEY_SUFFIX = '_extra_state' + _EXTRA_STATE_KEY_SUFFIX = "_extra_state" class ShardedModelV2(nn.Module): @@ -78,20 +77,22 @@ class ShardedModelV2(nn.Module): bf16 (bool, optional): Whether to use bfloat16 for param and grad. Defaults to False. """ - def __init__(self, - module: nn.Module, - shard_strategy: BaseShardStrategy, - process_group: Optional[ProcessGroup] = None, - reduce_scatter_process_group: Optional[ProcessGroup] = None, - reduce_scatter_bucket_size_mb: int = 25, - fp32_reduce_scatter: bool = False, - tensor_placement_policy: str = 'cuda', - gradient_predivide_factor: Optional[float] = 1.0, - reuse_fp16_shard: bool = False, - bf16: bool = False, - *args, - **kwargs): - assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.' + def __init__( + self, + module: nn.Module, + shard_strategy: BaseShardStrategy, + process_group: Optional[ProcessGroup] = None, + reduce_scatter_process_group: Optional[ProcessGroup] = None, + reduce_scatter_bucket_size_mb: int = 25, + fp32_reduce_scatter: bool = False, + tensor_placement_policy: str = "cuda", + gradient_predivide_factor: Optional[float] = 1.0, + reuse_fp16_shard: bool = False, + bf16: bool = False, + *args, + **kwargs, + ): + assert not isinstance(module, ShardedModelV2), "Nested ShardedModelV2 is not supported." super().__init__() self.logger = get_dist_logger() self.bf16 = bf16 @@ -101,13 +102,13 @@ def __init__(self, sharded_cnt = 0 unshard_cnt = 0 for param in submodule.parameters(recurse=False): - assert hasattr(param, 'colo_attr'), 'You must use ZeroInitContext to init your module first.' + assert hasattr(param, "colo_attr"), "You must use ZeroInitContext to init your module first." 
if param.colo_attr.param_is_sharded: sharded_cnt += 1 else: unshard_cnt += 1 - assert (not sharded_cnt) or (not unshard_cnt), 'nn.Module can not both have shard param and unshard param' - submodule.param_is_sharded = (sharded_cnt > 0) + assert (not sharded_cnt) or (not unshard_cnt), "nn.Module can not both have shard param and unshard param" + submodule.param_is_sharded = sharded_cnt > 0 self.sharded_params = [] self.unshard_params = [] @@ -124,7 +125,7 @@ def __init__(self, self.rank = dist.get_rank(self.process_group) self.shard_strategy = shard_strategy - self._use_memory_tracer = tensor_placement_policy == 'auto' + self._use_memory_tracer = tensor_placement_policy == "auto" if self._use_memory_tracer: self._memstats_collector = MemStatsCollector() self._start_collect_memstats = disposable(self._memstats_collector.start_collection) @@ -132,18 +133,19 @@ def __init__(self, else: self._memstats_collector = None self._tensor_placement_policy: TensorPlacementPolicy = TensorPlacementPolicyFactory.create( - tensor_placement_policy)(mem_stats_collector=self._memstats_collector) + tensor_placement_policy + )(mem_stats_collector=self._memstats_collector) - if 'warmup_non_model_data_ratio' in kwargs: - if tensor_placement_policy != 'auto': - self.logger.warning('setting warmup_non_model_data_ratio is useless if not use auto placement') + if "warmup_non_model_data_ratio" in kwargs: + if tensor_placement_policy != "auto": + self.logger.warning("setting warmup_non_model_data_ratio is useless if not use auto placement") else: - ratio = kwargs['warmup_non_model_data_ratio'] + ratio = kwargs["warmup_non_model_data_ratio"] self._tensor_placement_policy._warmup_non_model_data_ratio = ratio - self.logger.info(f'setting warmup_non_model_data_ratio as {ratio} for auto placement') + self.logger.info(f"setting warmup_non_model_data_ratio as {ratio} for auto placement") self._stateful_tensor_mgr = StatefulTensorMgr(self._tensor_placement_policy) - param_tensor_list = [p.colo_attr.sharded_data_tensor for p in module.parameters() if hasattr(p, 'colo_attr')] + param_tensor_list = [p.colo_attr.sharded_data_tensor for p in module.parameters() if hasattr(p, "colo_attr")] self._stateful_tensor_mgr.register_stateful_tensor_list(param_tensor_list) # Register hooks @@ -155,7 +157,7 @@ def __init__(self, self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook) self.fp32_reduce_scatter = fp32_reduce_scatter - self._cpu_offload: bool = tensor_placement_policy != 'cuda' + self._cpu_offload: bool = tensor_placement_policy != "cuda" for param in module.parameters(): # Init `offload_grad` param.colo_attr.offload_grad = self._cpu_offload @@ -164,9 +166,11 @@ def __init__(self, # So we use 1.0 as the default gradient_predivide_factor # However, if you set gradient_predivide_factor to None, we will set # gradient_predivide_factor to a value >= 1.0 automatically - self.gradient_predivide_factor: float = gradient_predivide_factor if \ - gradient_predivide_factor is not None else \ - get_gradient_predivide_factor(self.world_size) + self.gradient_predivide_factor: float = ( + gradient_predivide_factor + if gradient_predivide_factor is not None + else get_gradient_predivide_factor(self.world_size) + ) self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor self.comm_stream: torch.cuda.Stream = torch.cuda.Stream() @@ -194,7 +198,7 @@ def cuda_margin_space(self): def cpu_offload(self): return self._cpu_offload - def dump_memory_stats(self, filename: Optional[str] = 
'dump_mem_stats.log') -> None: + def dump_memory_stats(self, filename: Optional[str] = "dump_mem_stats.log") -> None: """ dummy memory tracer collected information to a file. try: @@ -205,18 +209,18 @@ def dump_memory_stats(self, filename: Optional[str] = 'dump_mem_stats.log') -> N exit(0) """ if self._use_memory_tracer: - self.logger.error(f'dump memory tracer collected information to a {filename}', ranks=[0]) + self.logger.error(f"dump memory tracer collected information to a {filename}", ranks=[0]) if gpc.get_global_rank() == 0: - with open(filename, 'w+') as f: - f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n') - f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n') - f.write('CUDA model data (GB)\n') - f.write('\n') - f.write('CUDA non model data (GB)\n') - f.write(str(self._memstats_collector._memstats.non_model_data_list('cuda'))) - f.write('CPU non model data (GB)\n') - f.write(str(self._memstats_collector._memstats.non_model_data_list('cpu'))) - f.write('\n') + with open(filename, "w+") as f: + f.write(f"cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n") + f.write(f"cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n") + f.write("CUDA model data (GB)\n") + f.write("\n") + f.write("CUDA non model data (GB)\n") + f.write(str(self._memstats_collector._memstats.non_model_data_list("cuda"))) + f.write("CPU non model data (GB)\n") + f.write(str(self._memstats_collector._memstats.non_model_data_list("cpu"))) + f.write("\n") def _pre_forward_operations(self, *args): # the operation will affect the memory tracer behavior in ZeroHook @@ -224,14 +228,14 @@ def _pre_forward_operations(self, *args): self._start_collect_memstats() for p in self.module.parameters(): - if hasattr(p, 'colo_attr'): + if hasattr(p, "colo_attr"): p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD) self._stateful_tensor_mgr.start_iter() def _post_forward_operations(self): for p in self.module.parameters(): - if hasattr(p, 'colo_attr'): + if hasattr(p, "colo_attr"): p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD) def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: @@ -261,8 +265,9 @@ def _update_memstats(self): # the way to calculate margin space is based on the assumption that # model data is fixed in cuda during training. # cuda margin space can be used to store OS. 
- self._cuda_margin_space = colo_device_memory_capacity( - get_current_device()) - self._memstats_collector._memstats.max_overall_cuda + self._cuda_margin_space = ( + colo_device_memory_capacity(get_current_device()) - self._memstats_collector._memstats.max_overall_cuda + ) @torch.no_grad() def _post_backward_operations(self) -> None: @@ -330,7 +335,7 @@ def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Opti """ if grad is None: return - assert not grad.requires_grad, 'ShardedModel only works with gradients that don\'t require gradients' + assert not grad.requires_grad, "ShardedModel only works with gradients that don't require gradients" if not self._require_backward_grad_sync: return # used to cheat Pytorch, since we can't return None @@ -354,16 +359,19 @@ def _reduce_scatter_handler(self, param: Parameter, grad: torch.Tensor) -> None: grad.data.div_(self.gradient_predivide_factor) if self.world_size > 1: grad_chunks = chunk_and_pad(grad, self.reduce_scatter_process_group.size()) - self.reducer.reduce_scatter_async(grad_chunks, - group=self.reduce_scatter_process_group, - callback_fn=functools.partial(self._reduce_scatter_callback, param)) + self.reducer.reduce_scatter_async( + grad_chunks, + group=self.reduce_scatter_process_group, + callback_fn=functools.partial(self._reduce_scatter_callback, param), + ) else: self._reduce_scatter_callback(param, grad) torch.cuda.current_stream().wait_stream(self.comm_stream) def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None: - assert isinstance(reduced_grad, - torch.Tensor), f"_reduce_scatter_callback accept reduced_grad as {type(reduced_grad)}" + assert isinstance( + reduced_grad, torch.Tensor + ), f"_reduce_scatter_callback accept reduced_grad as {type(reduced_grad)}" reduced_grad.data = reduced_grad.data.contiguous().view(-1) if self.gradient_postdivide_factor > 1: # Average grad by world_size for consistency with PyTorch DDP. 
@@ -372,7 +380,6 @@ def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) # FIXME(ver217): refactor the below line when impl eviction policy def _save_grad(self, param: Parameter, grad: torch.Tensor): - # record whether we have overflow self.overflow_counter += torch.isinf(grad).any().item() self.overflow_counter += torch.isnan(grad).any().item() @@ -384,8 +391,9 @@ def _save_grad(self, param: Parameter, grad: torch.Tensor): if self.reuse_fp16_shard: # make parameters point to gradient - assert param.colo_attr.saved_grad.is_null( - ), 'Gradient accumulation is not supported when reuse_fp16_shard=True' + assert ( + param.colo_attr.saved_grad.is_null() + ), "Gradient accumulation is not supported when reuse_fp16_shard=True" param.colo_attr.grad_payload_reset(grad.data) # release the memory of param @@ -396,7 +404,6 @@ def _save_grad(self, param: Parameter, grad: torch.Tensor): if param.colo_attr.is_replicated: param.colo_attr.sharded_data_tensor.is_sharded = True else: - fp32_grad = cast_tensor_to_fp32(grad) if param.colo_attr.saved_grad.is_null(): @@ -410,39 +417,44 @@ def _save_grad(self, param: Parameter, grad: torch.Tensor): def parameters(self, recurse: bool = True) -> Iterator[Parameter]: return self.module.parameters(recurse=recurse) - def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]: + def named_parameters(self, prefix: str = "", recurse: bool = True) -> Iterator[Tuple[str, Parameter]]: return self.module.named_parameters(prefix, recurse) - def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]': - return self._colo_state_dict(destination, - prefix, - keep_vars, - shard_strategy=self.shard_strategy, - state_dict_func=nn.Module.state_dict, - module_to_load=self.module, - sharded_params=self.sharded_params, - process_group=self.process_group) - - def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True) -> None: + def state_dict(self, destination=None, prefix="", keep_vars=False) -> "OrderedDict[str, torch.Tensor]": + return self._colo_state_dict( + destination, + prefix, + keep_vars, + shard_strategy=self.shard_strategy, + state_dict_func=nn.Module.state_dict, + module_to_load=self.module, + sharded_params=self.sharded_params, + process_group=self.process_group, + ) + + def load_state_dict(self, state_dict: "OrderedDict[str, torch.Tensor]", strict: bool = True) -> None: for name, p in self.named_parameters(): if name in state_dict: - p.colo_attr.data_payload_reset(state_dict[name].to(dtype=p.colo_attr.data_payload.dtype, - device=p.colo_attr.data_payload.device)) + p.colo_attr.data_payload_reset( + state_dict[name].to(dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device) + ) # Force re-shard p.colo_attr.sharded_data_tensor.is_sharded = False self.shard_strategy.shard([p.colo_attr.sharded_data_tensor]) elif strict: - raise RuntimeError(f'Missing key in state_dict: {name}') - - def _colo_state_dict(self, - destination=None, - prefix='', - keep_vars=False, - shard_strategy: Optional[BaseShardStrategy] = None, - state_dict_func=None, - module_to_load=None, - sharded_params=[], - process_group=None) -> 'OrderedDict[str, torch.Tensor]': + raise RuntimeError(f"Missing key in state_dict: {name}") + + def _colo_state_dict( + self, + destination=None, + prefix="", + keep_vars=False, + shard_strategy: Optional[BaseShardStrategy] = None, + state_dict_func=None, + module_to_load=None, + sharded_params=[], + 
process_group=None, + ) -> "OrderedDict[str, torch.Tensor]": if len(sharded_params) == 0: for param in self.parameters(): if param.colo_attr.param_is_sharded: @@ -460,15 +472,9 @@ def _colo_state_dict(self, p.colo_attr.set_data_none() return gathered_state_dict - def _colo_load_from_state_dict(self, - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - shard_strategy=None): + def _colo_load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs, shard_strategy=None + ): r"""Copies parameters and buffers from :attr:`state_dict` into only this module, but not its descendants. This is called on every submodule in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this @@ -512,10 +518,12 @@ def _colo_load_from_state_dict(self, key = prefix + name if key in state_dict: input_param = state_dict[key] - if hasattr(param, 'colo_attr'): + if hasattr(param, "colo_attr"): param.colo_attr.data_payload_reset( - input_param.to(dtype=param.colo_attr.data_payload.dtype, - device=param.colo_attr.data_payload.device)) + input_param.to( + dtype=param.colo_attr.data_payload.dtype, device=param.colo_attr.data_payload.device + ) + ) if shard_strategy is not None: # Force re-shard param.colo_attr.sharded_data_tensor.is_sharded = False @@ -531,19 +539,21 @@ def _colo_load_from_state_dict(self, if not is_param_lazy and input_param.shape != param.shape: # local shape should match the one in checkpoint - error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' - 'the shape in current model is {}.'.format( - key, input_param.shape, param.shape)) + error_msgs.append( + "size mismatch for {}: copying a param with shape {} from checkpoint, " + "the shape in current model is {}.".format(key, input_param.shape, param.shape) + ) continue try: with torch.no_grad(): param.copy_(input_param) except Exception as ex: - error_msgs.append('While copying the parameter named "{}", ' - 'whose dimensions in the model are {} and ' - 'whose dimensions in the checkpoint are {}, ' - 'an exception occurred : {}.'.format(key, param.size(), input_param.size(), - ex.args)) + error_msgs.append( + 'While copying the parameter named "{}", ' + "whose dimensions in the model are {} and " + "whose dimensions in the checkpoint are {}, " + "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args) + ) elif strict: missing_keys.append(key) @@ -559,8 +569,8 @@ def _colo_load_from_state_dict(self, if strict: for key in state_dict.keys(): if key.startswith(prefix) and key != extra_state_key: - input_name = key[len(prefix):] - input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child + input_name = key[len(prefix) :] + input_name = input_name.split(".", 1)[0] # get the name of param/buffer/child if input_name not in self._modules and input_name not in local_state: unexpected_keys.append(key) diff --git a/colossalai/legacy/zero/sharded_model/utils.py b/colossalai/legacy/zero/sharded_model/utils.py index 7a411669900b..cb085f19e6b2 100644 --- a/colossalai/legacy/zero/sharded_model/utils.py +++ b/colossalai/legacy/zero/sharded_model/utils.py @@ -11,7 +11,7 @@ def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Modu Note the other_model has to be the same as self. 
""" for zero_param, param in zip(sharded_model.parameters(), other_model.parameters()): - assert hasattr(zero_param, 'colo_attr') + assert hasattr(zero_param, "colo_attr") shard_flag = zero_param.colo_attr.sharded_data_tensor.is_sharded if shard_flag: sharded_model.shard_strategy.gather([zero_param.colo_attr.sharded_data_tensor]) diff --git a/colossalai/legacy/zero/sharded_model/zero_hook.py b/colossalai/legacy/zero/sharded_model/zero_hook.py index 3fc373e5ca44..892e9f31ded4 100644 --- a/colossalai/legacy/zero/sharded_model/zero_hook.py +++ b/colossalai/legacy/zero/sharded_model/zero_hook.py @@ -20,11 +20,13 @@ class ZeroHook(BaseOpHook): Warning: this class has been deprecated after version 0.1.12 """ - def __init__(self, - shard_strategy: BaseShardStrategy, - memstarts_collector: Optional[MemStatsCollector] = None, - stateful_tensor_mgr: Optional[StatefulTensorMgr] = None, - process_group: Optional[dist.ProcessGroup] = None): + def __init__( + self, + shard_strategy: BaseShardStrategy, + memstarts_collector: Optional[MemStatsCollector] = None, + stateful_tensor_mgr: Optional[StatefulTensorMgr] = None, + process_group: Optional[dist.ProcessGroup] = None, + ): super().__init__() self.logger = get_dist_logger("ZeROHook") self.shard_strategy = shard_strategy @@ -41,7 +43,7 @@ def gather_parameters(self, module: torch.nn.Module): if module.param_is_sharded: tensor_list = [] for param in module.parameters(recurse=False): - assert hasattr(param, 'colo_attr') + assert hasattr(param, "colo_attr") tensor_list.append(param.colo_attr.sharded_data_tensor) self.shard_strategy.gather(tensor_list, self.process_group) @@ -50,7 +52,7 @@ def shard_parameters(self, module: torch.nn.Module): if module.param_is_sharded: tensor_list = [] for param in module.parameters(recurse=False): - assert hasattr(param, 'colo_attr') + assert hasattr(param, "colo_attr") tensor_list.append(param.colo_attr.sharded_data_tensor) self.shard_strategy.shard(tensor_list, self.process_group) @@ -74,10 +76,9 @@ def pre_fwd_exec(self, module: torch.nn.Module, *args): self.gather_parameters(module) for param in module.parameters(recurse=False): param.data = param.colo_attr.data_payload - assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA" + assert param.data.device.type == "cuda", f"PRE FWD param.data must be on CUDA" def post_fwd_exec(self, module: torch.nn.Module, *args): - # change tensor state to HOLD_AFTER_FWD for param in module.parameters(recurse=False): param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD) @@ -93,10 +94,9 @@ def pre_bwd_exec(self, module: torch.nn.Module, input, output): self.gather_parameters(module) for param in module.parameters(recurse=False): param.data = param.colo_attr.data_payload - assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA" + assert param.data.device.type == "cuda", f"PRE BWD param.data must be on CUDA" def post_bwd_exec(self, module: torch.nn.Module, input): - # change tensor state to HOLD_AFTER_BWD for param in module.parameters(recurse=False): param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD) @@ -114,5 +114,6 @@ def post_iter(self): if self._stateful_tensor_mgr: self.logger.debug( f"CPU-GPU data moving this iteration {self._stateful_tensor_mgr.cpu_gpu_move_volume/1e9} GB, get layout info time: {self._stateful_tensor_mgr._layout_time}, evict cpu time: {self._stateful_tensor_mgr._evict_time}", - ranks=[0]) + ranks=[0], + ) self._stateful_tensor_mgr.finish_iter() diff --git 
a/colossalai/legacy/zero/sharded_optim/__init__.py b/colossalai/legacy/zero/sharded_optim/__init__.py index b71a70aeffa4..700fb0eb91d3 100644 --- a/colossalai/legacy/zero/sharded_optim/__init__.py +++ b/colossalai/legacy/zero/sharded_optim/__init__.py @@ -1,3 +1,3 @@ from .sharded_optim_v2 import ShardedOptimizerV2 -__all__ = ['ShardedOptimizerV2'] +__all__ = ["ShardedOptimizerV2"] diff --git a/colossalai/legacy/zero/sharded_optim/sharded_optim_v2.py b/colossalai/legacy/zero/sharded_optim/sharded_optim_v2.py index e21f1cea04df..e73679163fab 100644 --- a/colossalai/legacy/zero/sharded_optim/sharded_optim_v2.py +++ b/colossalai/legacy/zero/sharded_optim/sharded_optim_v2.py @@ -1,6 +1,5 @@ # this code is inspired by the DeepSpeed library and implemented with our own design from scratch from enum import Enum -from os import stat from typing import Dict, Optional, Tuple import torch @@ -74,22 +73,24 @@ class ShardedOptimizerV2(OptimizerWrapper): https://arxiv.org/abs/2108.05818 """ - def __init__(self, - sharded_model: ShardedModelV2, - optimizer: Optimizer, - gpu_margin_mem_ratio: float = 0.0, - initial_scale: float = 2**32, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: float = 2**32, - dp_process_group: Optional[ProcessGroup] = None, - mp_process_group: Optional[ProcessGroup] = None, - verbose: bool = False) -> None: - assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel' - assert not isinstance(optimizer, ShardedOptimizerV2), 'Nested ShardedOptimizerV2 is not supported.' + def __init__( + self, + sharded_model: ShardedModelV2, + optimizer: Optimizer, + gpu_margin_mem_ratio: float = 0.0, + initial_scale: float = 2**32, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + dp_process_group: Optional[ProcessGroup] = None, + mp_process_group: Optional[ProcessGroup] = None, + verbose: bool = False, + ) -> None: + assert isinstance(sharded_model, ShardedModelV2), "model must be wrapped with ShardedModel" + assert not isinstance(optimizer, ShardedOptimizerV2), "Nested ShardedOptimizerV2 is not supported." 
super().__init__(optimizer) self.shard_strategy = sharded_model.shard_strategy @@ -97,39 +98,49 @@ def __init__(self, self.bf16 = sharded_model.bf16 self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio) - assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0' + assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f"gpu_margin_mem_ratio must >=0.0 and <=1.0" # Only move fp32 shards from CPU to GPU when user allows and inner optimizer is valid # Inner optimizer must support optimizing hybrid (CPU and CUDA) tensors, # and it must set `num_fp32_shards_per_param` correctly - self._should_move_fp32_shards_h2d: bool = sharded_model.cpu_offload and self.gpu_margin_mem_ratio > 0.0 and getattr( - optimizer, 'num_fp32_shards_per_param', 0) >= 2 - self.device = sharded_model._tensor_placement_policy.device or torch.device('cpu') + self._should_move_fp32_shards_h2d: bool = ( + sharded_model.cpu_offload + and self.gpu_margin_mem_ratio > 0.0 + and getattr(optimizer, "num_fp32_shards_per_param", 0) >= 2 + ) + self.device = sharded_model._tensor_placement_policy.device or torch.device("cpu") self.optim_state: OptimState = OptimState.UNSCALED self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA) self.mp_process_group = mp_process_group or gpc.get_group(ParallelMode.MODEL) # Grad scaler - self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale, - min_scale=min_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - max_scale=max_scale) + self.grad_scaler = DynamicGradScaler( + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + ) self._found_overflow: Tensor = torch.IntTensor([0]).to(torch.cuda.current_device()) self._logger = get_dist_logger("ShardedOptimizerV2") self._verbose = verbose - self._grad_prepared: bool = False # this should be set to true when _prepare_grads() and reset to false when backward + self._grad_prepared: bool = ( + False # this should be set to true when _prepare_grads() and reset to false when backward + ) # Store fp32 param shards self._register_master_weight() - if self.gpu_margin_mem_ratio != 0.0 and not isinstance(sharded_model._tensor_placement_policy, - AutoTensorPlacementPolicy): - self._logger.warning(f'gpu_margin_mem_ratio is meaningless when tensor_placement_policy is not "auto"', - ranks=[0]) + if self.gpu_margin_mem_ratio != 0.0 and not isinstance( + sharded_model._tensor_placement_policy, AutoTensorPlacementPolicy + ): + self._logger.warning( + f'gpu_margin_mem_ratio is meaningless when tensor_placement_policy is not "auto"', ranks=[0] + ) if self._verbose: self._logger.debug( - f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory!", ranks=[0]) + f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory!", ranks=[0] + ) self._use_memory_tracer = self.model.use_memory_tracer @@ -138,7 +149,7 @@ def loss_scale(self): return self.grad_scaler.scale.item() def get_memory_usage(self) -> Tuple[int, int]: - """ Get the memory usage of the optimizer. Including master_params (param fp32), + """Get the memory usage of the optimizer. 
Including master_params (param fp32), momentum (``self.state[p]['exp_avg']``) variance (``self.state[p]['exp_avg_sq']``) Returns: @@ -157,7 +168,7 @@ def update_mem_use(t): for _, p_fp32 in self.master_params.items(): update_mem_use(p_fp32) for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: state = self.optim.state[p] for k, v in state.items(): update_mem_use(v) @@ -191,7 +202,6 @@ def clip_grad_norm(self, model: nn.Module, max_norm: float): return super().clip_grad_norm(model, max_norm) def step(self, *args, **kwargs): - self._prepare_grads() # unscale grads if scaled if not self.bf16 and self.optim_state == OptimState.SCALED: @@ -203,7 +213,7 @@ def step(self, *args, **kwargs): self.grad_scaler.update(found_inf) if found_inf: - self._logger.warning('found inf during ShardedOptimV2 step') + self._logger.warning("found inf during ShardedOptimV2 step") self._zero_grad(recover_data=True) return @@ -213,14 +223,16 @@ def step(self, *args, **kwargs): gpu_mem, cpu_mem = self.get_memory_usage() self._logger.debug( f"Before step ShardedOptimizerV2 consumes {gpu_mem / 1e6} MB CUDA Memory, {cpu_mem / 1e6} MB CUDA Memory!", - ranks=[0]) + ranks=[0], + ) ret = self.optim.step(*args, **kwargs) if self._verbose: gpu_mem, cpu_mem = self.get_memory_usage() self._logger.debug( f"After step ShardedOptimizerV2 consumes {gpu_mem / 1e6} MB CUDA Memory, {cpu_mem / 1e6} MB CUDA Memory!", - ranks=[0]) + ranks=[0], + ) self._copy_master_model_to_model_fp16() return ret @@ -240,7 +252,7 @@ def _check_overflow(self): def _unscale_grads(self): assert self.optim_state == OptimState.SCALED for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: if p.grad is not None: p.grad.data.div_(self.loss_scale) self.optim_state = OptimState.UNSCALED @@ -260,16 +272,16 @@ def _zero_grad(self, recover_data: bool = False): # Which leads to wrong accumulation self.optim.zero_grad(set_to_none=True) for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: # p.colo_attr.sharded_data_tensor stores grad now # we have to recover fp16 param - reuse_fp16_shard = (p.colo_attr.sharded_data_tensor.payload_size == 0) + reuse_fp16_shard = p.colo_attr.sharded_data_tensor.payload_size == 0 if recover_data and reuse_fp16_shard: self._copy_master_param_to_param_fp16(p) else: # release saved gradient p.colo_attr.saved_grad.set_null() - self.model.overflow_counter = 0 # set overflow counter to zero + self.model.overflow_counter = 0 # set overflow counter to zero def sync_grad(self): pass @@ -277,8 +289,8 @@ def sync_grad(self): def _register_master_weight(self): self.master_params: Dict[Parameter, StatefulTensor] = {} for group in self.optim.param_groups: - for p in group['params']: - assert hasattr(p, 'colo_attr'), 'The parameter must be wrapped with ShardedParam' + for p in group["params"]: + assert hasattr(p, "colo_attr"), "The parameter must be wrapped with ShardedParam" shard_flag = not p.colo_attr.sharded_data_tensor.is_sharded and p.colo_attr.is_replicated if shard_flag: # we always shard replicated parameters @@ -296,7 +308,7 @@ def _maybe_move_fp32_shards(self): fp32_shards_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param fp32_shards_used_cuda_margin_mem = 0 for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: if p.colo_attr.saved_grad.is_null(): continue shard_mem = self.master_params[p].payload.numel() * 
self.master_params[p].payload.element_size() @@ -314,7 +326,7 @@ def _prepare_grads(self): if self._grad_prepared: return for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: if p.colo_attr.saved_grad.is_null(): continue p.colo_attr.saved_grad.trans_state(TensorState.COMPUTE) @@ -335,7 +347,7 @@ def _point_param_fp16_to_master_param(self): # assign master param pointers to p.data. # We will not trigger data copy here. for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: self.master_params[p].trans_state(TensorState.COMPUTE) p.data = self.master_params[p].payload # Now p.data is sharded @@ -346,7 +358,7 @@ def _copy_master_model_to_model_fp16(self): # TODO() improve efficiency by gathering tensors into a chunk and transferring # a chunk. for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: self._copy_master_param_to_param_fp16(p) def _copy_master_param_to_param_fp16(self, p): @@ -364,7 +376,8 @@ def _copy_master_param_to_param_fp16(self, p): # in order to use copy, otherwise, the sizes of tensor is not compatible if p.colo_attr.data_payload.numel() != p.data.numel(): p.colo_attr.data_payload_reset( - torch.empty(p.data.shape, dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device)) + torch.empty(p.data.shape, dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device) + ) # TODO() optimize this line CPU (fp32) -> GPU (fp16) half_dtype = torch.bfloat16 if self.bf16 else torch.float16 @@ -373,7 +386,7 @@ def _copy_master_param_to_param_fp16(self, p): if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated: # We gather full fp16 param here - p.colo_attr.sharded_data_tensor.is_sharded = True # since only gradient is sharded, we should set to True + p.colo_attr.sharded_data_tensor.is_sharded = True # since only gradient is sharded, we should set to True self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group) self.master_params[p].trans_state(TensorState.HOLD) @@ -381,18 +394,18 @@ def _copy_master_param_to_param_fp16(self, p): def state_dict(self): optim_state_dict = super().state_dict() scaler_state_dict = self.grad_scaler.state_dict() - optim_state_dict['scaler'] = scaler_state_dict + optim_state_dict["scaler"] = scaler_state_dict return optim_state_dict def load_state_dict(self, *args, **kwargs): - if 'scaler' not in args[0]: - self._logger.warning('Missing scaler when loading optimizer state dict', ranks=[0]) + if "scaler" not in args[0]: + self._logger.warning("Missing scaler when loading optimizer state dict", ranks=[0]) else: - scaler_state_dict = args[0].pop('scaler') + scaler_state_dict = args[0].pop("scaler") self.grad_scaler.load_state_dict(scaler_state_dict) super().load_state_dict(*args, **kwargs) for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: state = self.optim.state[p] for k, v in state.items(): if isinstance(v, Tensor): diff --git a/colossalai/legacy/zero/sharded_param/__init__.py b/colossalai/legacy/zero/sharded_param/__init__.py index 47e2ce2fa0e0..c7afb95391a4 100644 --- a/colossalai/legacy/zero/sharded_param/__init__.py +++ b/colossalai/legacy/zero/sharded_param/__init__.py @@ -1,4 +1,4 @@ from .sharded_param import ShardedParamV2 from .sharded_tensor import ShardedTensor -__all__ = ['ShardedTensor', 'ShardedParamV2'] +__all__ = ["ShardedTensor", "ShardedParamV2"] diff --git a/colossalai/legacy/zero/sharded_param/sharded_param.py 
b/colossalai/legacy/zero/sharded_param/sharded_param.py index 454a722cf7e7..22b09d5ff4bb 100644 --- a/colossalai/legacy/zero/sharded_param/sharded_param.py +++ b/colossalai/legacy/zero/sharded_param/sharded_param.py @@ -19,7 +19,6 @@ def get_empty_tensor(device: torch.device, dtype: torch.dtype): class ShardedParamV2(object): - def __init__(self, param: torch.nn.Parameter, set_data_none: bool = False) -> None: self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data) self.saved_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE) @@ -36,8 +35,7 @@ def __init__(self, param: torch.nn.Parameter, set_data_none: bool = False) -> No self.set_data_none() def get_payload_tensors(self) -> List[StatefulTensor]: - """returns stateful tensors kept by this class. - """ + """returns stateful tensors kept by this class.""" return [self._sharded_data_tensor] def set_data_none(self): diff --git a/colossalai/legacy/zero/sharded_param/sharded_tensor.py b/colossalai/legacy/zero/sharded_param/sharded_tensor.py index 43c7576b93b5..262682d44645 100644 --- a/colossalai/legacy/zero/sharded_param/sharded_tensor.py +++ b/colossalai/legacy/zero/sharded_param/sharded_tensor.py @@ -4,7 +4,6 @@ class ShardedTensor(StatefulTensor): - def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD) -> None: r""" A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance. diff --git a/colossalai/logging/__init__.py b/colossalai/logging/__init__.py index 97fe4f89ded3..521eafa74c30 100644 --- a/colossalai/logging/__init__.py +++ b/colossalai/logging/__init__.py @@ -3,23 +3,23 @@ from .logger import DistributedLogger -__all__ = ['get_dist_logger', 'DistributedLogger', 'disable_existing_loggers'] +__all__ = ["get_dist_logger", "DistributedLogger", "disable_existing_loggers"] -def get_dist_logger(name: str = 'colossalai') -> DistributedLogger: +def get_dist_logger(name: str = "colossalai") -> DistributedLogger: """Get logger instance based on name. The DistributedLogger will create singleton instances, which means that only one logger instance is created per name. Args: name (str): name of the logger, name must be unique - + Returns: :class:`colossalai.logging.DistributedLogger`: A distributed logger singleton instance. """ return DistributedLogger.get_instance(name=name) -def disable_existing_loggers(include: Optional[List[str]] = None, exclude: List[str] = ['colossalai']) -> None: +def disable_existing_loggers(include: Optional[List[str]] = None, exclude: List[str] = ["colossalai"]) -> None: """Set the level of existing loggers to `WARNING`. By default, it will "disable" all existing loggers except the logger named "colossalai". 
Args: diff --git a/colossalai/logging/logger.py b/colossalai/logging/logger.py index fd05ddf1d50f..eb5f28e2a3cf 100644 --- a/colossalai/logging/logger.py +++ b/colossalai/logging/logger.py @@ -42,12 +42,14 @@ def get_instance(name: str): def __init__(self, name): if name in DistributedLogger.__instances: raise Exception( - 'Logger with the same name has been created, you should use colossalai.logging.get_dist_logger') + "Logger with the same name has been created, you should use colossalai.logging.get_dist_logger" + ) else: handler = None - formatter = logging.Formatter('colossalai - %(name)s - %(levelname)s: %(message)s') + formatter = logging.Formatter("colossalai - %(name)s - %(levelname)s: %(message)s") try: from rich.logging import RichHandler + handler = RichHandler(show_path=False, markup=True, rich_tracebacks=True) handler.setFormatter(formatter) except ImportError: @@ -79,7 +81,7 @@ def __get_call_info(): @staticmethod def _check_valid_logging_level(level: str): - assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR'], 'found invalid logging level' + assert level in ["INFO", "DEBUG", "WARNING", "ERROR"], "found invalid logging level" def set_level(self, level: str) -> None: """Set the logging level @@ -90,7 +92,7 @@ def set_level(self, level: str) -> None: self._check_valid_logging_level(level) self._logger.setLevel(getattr(logging, level)) - def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INFO', suffix: str = None) -> None: + def log_to_file(self, path: Union[str, Path], mode: str = "a", level: str = "INFO", suffix: str = None) -> None: """Save the logs to file Args: @@ -99,8 +101,7 @@ def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INF level (str): Can only be INFO, DEBUG, WARNING and ERROR. suffix (str): The suffix string of log's name. """ - assert isinstance(path, (str, Path)), \ - f'expected argument path to be type str or Path, but got {type(path)}' + assert isinstance(path, (str, Path)), f"expected argument path to be type str or Path, but got {type(path)}" self._check_valid_logging_level(level) if isinstance(path, str): @@ -110,15 +111,15 @@ def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INF path.mkdir(parents=True, exist_ok=True) if suffix is not None: - log_file_name = f'rank_{self.rank}_{suffix}.log' + log_file_name = f"rank_{self.rank}_{suffix}.log" else: - log_file_name = f'rank_{self.rank}.log' + log_file_name = f"rank_{self.rank}.log" path = path.joinpath(log_file_name) # add file handler file_handler = logging.FileHandler(path, mode) file_handler.setLevel(getattr(logging, level)) - formatter = logging.Formatter('colossalai - %(name)s - %(levelname)s: %(message)s') + formatter = logging.Formatter("colossalai - %(name)s - %(levelname)s: %(message)s") file_handler.setFormatter(formatter) self._logger.addHandler(file_handler) @@ -137,8 +138,8 @@ def info(self, message: str, ranks: List[int] = None) -> None: ranks (List[int]): List of parallel ranks. """ message_prefix = "{}:{} {}".format(*self.__get_call_info()) - self._log('info', message_prefix, ranks) - self._log('info', message, ranks) + self._log("info", message_prefix, ranks) + self._log("info", message, ranks) def warning(self, message: str, ranks: List[int] = None) -> None: """Log a warning message. @@ -148,8 +149,8 @@ def warning(self, message: str, ranks: List[int] = None) -> None: ranks (List[int]): List of parallel ranks. 
""" message_prefix = "{}:{} {}".format(*self.__get_call_info()) - self._log('warning', message_prefix, ranks) - self._log('warning', message, ranks) + self._log("warning", message_prefix, ranks) + self._log("warning", message, ranks) def debug(self, message: str, ranks: List[int] = None) -> None: """Log a debug message. @@ -159,8 +160,8 @@ def debug(self, message: str, ranks: List[int] = None) -> None: ranks (List[int]): List of parallel ranks. """ message_prefix = "{}:{} {}".format(*self.__get_call_info()) - self._log('debug', message_prefix, ranks) - self._log('debug', message, ranks) + self._log("debug", message_prefix, ranks) + self._log("debug", message, ranks) def error(self, message: str, ranks: List[int] = None) -> None: """Log an error message. @@ -170,5 +171,5 @@ def error(self, message: str, ranks: List[int] = None) -> None: ranks (List[int]): List of parallel ranks. """ message_prefix = "{}:{} {}".format(*self.__get_call_info()) - self._log('error', message_prefix, ranks) - self._log('error', message, ranks) + self._log("error", message_prefix, ranks) + self._log("error", message, ranks) diff --git a/colossalai/nn/init.py b/colossalai/nn/init.py index 559b7038fc35..2637aa8eaaeb 100644 --- a/colossalai/nn/init.py +++ b/colossalai/nn/init.py @@ -1,8 +1,8 @@ import math import warnings -from torch import Tensor import torch.nn as nn +from torch import Tensor def zeros_(): @@ -23,7 +23,7 @@ def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): return initializer -def uniform_(a: float = 0., b: float = 1.): +def uniform_(a: float = 0.0, b: float = 1.0): r"""Return the initializer filling the input Tensor with values drawn from the uniform distribution :math:`\mathcal{U}(a, b)`. @@ -38,7 +38,7 @@ def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): return initializer -def normal_(mean: float = 0., std: float = 1.): +def normal_(mean: float = 0.0, std: float = 1.0): r"""Return the initializer filling the input Tensor with values drawn from the normal distribution .. math:: @@ -47,7 +47,7 @@ def normal_(mean: float = 0., std: float = 1.): Args: mean (float): the mean of the normal distribution. Defaults 0.0. std (float): the standard deviation of the normal distribution. Defaults 1.0. - """ + """ def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): return nn.init.normal_(tensor, mean, std) @@ -55,7 +55,7 @@ def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): return initializer -def trunc_normal_(mean: float = 0., std: float = 1., a: float = -2., b: float = 2.): +def trunc_normal_(mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0): r"""Return the initializer filling the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` @@ -76,7 +76,7 @@ def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): return initializer -def kaiming_uniform_(a=0, mode='fan_in', nonlinearity='leaky_relu'): +def kaiming_uniform_(a=0, mode="fan_in", nonlinearity="leaky_relu"): r"""Return the initializer filling the input `Tensor` with values according to the method described in `Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification` - He, K. et al. 
(2015), using a @@ -104,23 +104,23 @@ def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): warnings.warn("Initializing zero-element tensors is a no-op") return tensor - if mode == 'fan_in': - assert fan_in is not None, 'Fan_in is not provided.' + if mode == "fan_in": + assert fan_in is not None, "Fan_in is not provided." fan = fan_in - elif mode == 'fan_out': - assert fan_out is not None, 'Fan_out is not provided.' + elif mode == "fan_out": + assert fan_out is not None, "Fan_out is not provided." fan = fan_out else: - raise ValueError(f'Invalid initialization mode \'{mode}\'') + raise ValueError(f"Invalid initialization mode '{mode}'") std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan) - bound = math.sqrt(3.) * std + bound = math.sqrt(3.0) * std return nn.init.uniform_(tensor, -bound, bound) return initializer -def kaiming_normal_(a=0, mode='fan_in', nonlinearity='leaky_relu'): +def kaiming_normal_(a=0, mode="fan_in", nonlinearity="leaky_relu"): r"""Return the initializer filling the input `Tensor` with values according to the method described in `Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification` - He, K. et al. (2015), using a @@ -148,14 +148,14 @@ def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): warnings.warn("Initializing zero-element tensors is a no-op") return tensor - if mode == 'fan_in': - assert fan_in is not None, 'Fan_in is not provided.' + if mode == "fan_in": + assert fan_in is not None, "Fan_in is not provided." fan = fan_in - elif mode == 'fan_out': - assert fan_out is not None, 'Fan_out is not provided.' + elif mode == "fan_out": + assert fan_out is not None, "Fan_out is not provided." fan = fan_out else: - raise ValueError(f'Invalid initialization mode \'{mode}\'') + raise ValueError(f"Invalid initialization mode '{mode}'") std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan) return nn.init.normal_(tensor, 0, std) @@ -163,7 +163,7 @@ def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): return initializer -def xavier_uniform_(a: float = math.sqrt(3.), scale: float = 2., gain: float = 1.): +def xavier_uniform_(a: float = math.sqrt(3.0), scale: float = 2.0, gain: float = 1.0): r"""Return the initializer filling the input `Tensor` with values according to the method described in `Understanding the difficulty of training deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010), using a uniform @@ -184,7 +184,7 @@ def xavier_uniform_(a: float = math.sqrt(3.), scale: float = 2., gain: float = 1 # adapted from torch.nn.init def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): - assert fan_in is not None, 'Fan_in is not provided.' + assert fan_in is not None, "Fan_in is not provided." fan = fan_in if fan_out is not None: @@ -197,7 +197,7 @@ def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): return initializer -def xavier_normal_(scale: float = 2., gain: float = 1.): +def xavier_normal_(scale: float = 2.0, gain: float = 1.0): r"""Return the initializer filling the input `Tensor` with values according to the method described in `Understanding the difficulty of training deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010), using a normal @@ -216,7 +216,7 @@ def xavier_normal_(scale: float = 2., gain: float = 1.): # adapted from torch.nn.init def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): - assert fan_in is not None, 'Fan_in is not provided.' 
+ assert fan_in is not None, "Fan_in is not provided." fan = fan_in if fan_out is not None: @@ -224,7 +224,7 @@ def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): std = gain * math.sqrt(scale / float(fan)) - return nn.init.normal_(tensor, 0., std) + return nn.init.normal_(tensor, 0.0, std) return initializer @@ -232,7 +232,7 @@ def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): def lecun_uniform_(): # adapted from jax.nn.initializers def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): - assert fan_in is not None, 'Fan_in is not provided.' + assert fan_in is not None, "Fan_in is not provided." var = 1.0 / fan_in bound = math.sqrt(3 * var) @@ -244,9 +244,9 @@ def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): def lecun_normal_(): # adapted from jax.nn.initializers def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): - assert fan_in is not None, 'Fan_in is not provided.' + assert fan_in is not None, "Fan_in is not provided." std = math.sqrt(1.0 / fan_in) - return nn.init.trunc_normal_(tensor, std=std / .87962566103423978) + return nn.init.trunc_normal_(tensor, std=std / 0.87962566103423978) return initializer diff --git a/colossalai/nn/layer/moe/__init__.py b/colossalai/nn/layer/moe/__init__.py index 05333fe965f1..6a5ccff510be 100644 --- a/colossalai/nn/layer/moe/__init__.py +++ b/colossalai/nn/layer/moe/__init__.py @@ -5,6 +5,17 @@ from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts __all__ = [ - 'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator', - 'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule', 'MoeRouter', 'save_moe_model', 'load_moe_model' + "Experts", + "FFNExperts", + "TPExperts", + "Top1Router", + "Top2Router", + "MoeLayer", + "NormalNoiseGenerator", + "UniformNoiseGenerator", + "build_ffn_experts", + "MoeModule", + "MoeRouter", + "save_moe_model", + "load_moe_model", ] diff --git a/colossalai/nn/layer/moe/_operation.py b/colossalai/nn/layer/moe/_operation.py index 37f31c16709b..2f0b7e43673a 100644 --- a/colossalai/nn/layer/moe/_operation.py +++ b/colossalai/nn/layer/moe/_operation.py @@ -18,18 +18,18 @@ def build_moe_if_not_prebuilt(): global moe if moe is None: from colossalai.kernel.op_builder import MOEBuilder + moe = MOEBuilder().load() class AllGather(torch.autograd.Function): - @staticmethod def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: - global moe if moe is None: from colossalai.kernel.op_builder import MOEBuilder + moe = MOEBuilder().load() if ctx is not None: @@ -51,7 +51,6 @@ def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]: class ReduceScatter(torch.autograd.Function): - @staticmethod def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: if ctx is not None: @@ -98,7 +97,6 @@ def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]: class MoeDispatch(torch.autograd.Function): - @staticmethod def forward(ctx, tokens, mask, dest_idx, ec): s = tokens.size(0) @@ -124,7 +122,6 @@ def backward(ctx, output_grad): class MoeCombine(torch.autograd.Function): - @staticmethod def forward(ctx, expert_tokens, logits, mask, dest_idx, ec): assert logits.dtype == torch.float32 @@ -137,7 +134,7 @@ def forward(ctx, expert_tokens, logits, mask, dest_idx, ec): # load moe kernel during runtime if not pre-built build_moe_if_not_prebuilt() - fp16_flag = (expert_tokens.dtype == torch.float16) + 
fp16_flag = expert_tokens.dtype == torch.float16 cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens ctokens = moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx) output = ctokens.to(torch.float16) if fp16_flag else ctokens @@ -155,8 +152,7 @@ def forward(ctx, expert_tokens, logits, mask, dest_idx, ec): def backward(ctx, tokens_grad): expert_tokens, logits, mask, dest_idx = ctx.saved_tensors - cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 \ - else tokens_grad + cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 else tokens_grad cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens d_expert, d_logits = moe.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits, mask, dest_idx) d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert diff --git a/colossalai/nn/layer/moe/checkpoint.py b/colossalai/nn/layer/moe/checkpoint.py index efda1f22252d..adad19d581ef 100644 --- a/colossalai/nn/layer/moe/checkpoint.py +++ b/colossalai/nn/layer/moe/checkpoint.py @@ -16,7 +16,7 @@ def load_moe_model(model: nn.Module, load_path: str): state_dict = torch.load(load_path) for prefix, module in model.named_modules(): - if prefix.endswith('.moe_layer.experts'): + if prefix.endswith(".moe_layer.experts"): # this module should be an Experts instance assert isinstance(module, MoeExperts) @@ -25,16 +25,16 @@ def load_moe_model(model: nn.Module, load_path: str): for i in range(num_local): expert_id = ep_rank * num_local + i for name, _ in module.experts[i].named_parameters(): - cur_key = f'{prefix}.experts.{i}.{name}' - param_key = f'{prefix}.experts.{expert_id}.{name}' + cur_key = f"{prefix}.experts.{i}.{name}" + param_key = f"{prefix}.experts.{expert_id}.{name}" load_param = state_dict[param_key] state_dict[cur_key] = load_param for name, _ in module.experts[0].named_parameters(): - pop_pre = f'{prefix}.experts.' - pop_suf = f'.{name}' + pop_pre = f"{prefix}.experts." + pop_suf = f".{name}" for i in range(num_local, module.num_total_experts): - pop_key = f'{pop_pre}{i}{pop_suf}' + pop_key = f"{pop_pre}{i}{pop_suf}" state_dict.pop(pop_key) model.load_state_dict(state_dict) diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/nn/layer/moe/experts.py index 712d872bb921..4b2ecb241702 100644 --- a/colossalai/nn/layer/moe/experts.py +++ b/colossalai/nn/layer/moe/experts.py @@ -20,8 +20,10 @@ class MoeExperts(nn.Module): def __init__(self, comm_name: str, num_experts: int): super().__init__() - assert comm_name in {"all_to_all", "all_gather"}, \ - "This kind of communication has not been implemented yet.\n Please use Experts build function." + assert comm_name in { + "all_to_all", + "all_gather", + }, "This kind of communication has not been implemented yet.\n Please use Experts build function." 
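> Editorial aside (not part of the patch): the `MoeCombine` hunk above keeps a small dtype-guard pattern while being reformatted — fp16 expert tokens and gradients are cast to fp32 before the fused combine kernel runs and the result is cast back. A minimal standalone sketch of that pattern is below; `fake_combine_kernel` is a hypothetical stand-in for the compiled `moe.combine_forward`, used here only for illustration.

```python
# Sketch of the fp16 -> fp32 -> fp16 guard used around the fused MoE combine kernel.
# `fake_combine_kernel` is a placeholder; the real kernel is a CUDA extension that
# only accepts fp32 inputs.
import torch


def fake_combine_kernel(expert_tokens: torch.Tensor) -> torch.Tensor:
    # placeholder for moe.combine_forward: returns a tensor of the same shape/dtype
    return expert_tokens * 1.0


def combine_with_dtype_guard(expert_tokens: torch.Tensor) -> torch.Tensor:
    fp16_flag = expert_tokens.dtype == torch.float16
    cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
    ctokens = fake_combine_kernel(cb_input)
    # cast the kernel output back to the caller's dtype
    return ctokens.to(torch.float16) if fp16_flag else ctokens
```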
self.comm_name = comm_name self.num_total_experts = num_experts # Get the configuration of experts' deployment and parallel information from moe context @@ -50,7 +52,7 @@ def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args) # Attach parallel information for all parameters in Experts for exp in self.experts: for param in exp.parameters(): - param.__setattr__('moe_info', self.dist_info) + param.__setattr__("moe_info", self.dist_info) def forward(self, inputs: torch.Tensor): # Split inputs for each expert @@ -65,7 +67,7 @@ def forward(self, inputs: torch.Tensor): output = torch.cat(expert_output, dim=1).contiguous() return output - def state_dict(self, destination=None, prefix='', keep_vars=False): + def state_dict(self, destination=None, prefix="", keep_vars=False): assert keep_vars == False, "Only support keep_vars=False now" dp_rank = dist.get_rank(self.dist_info.dp_group) ep_rank = dist.get_rank(self.dist_info.ep_group) @@ -79,11 +81,11 @@ def state_dict(self, destination=None, prefix='', keep_vars=False): example_submodule = subm if dp_rank == 0: - local_prefix = prefix + 'experts.' + local_prefix = prefix + "experts." buffer_module = deepcopy(example_submodule) for i in range(self.num_total_experts): source_rank = i // self.num_local_experts - current_prefix = local_prefix + str(i) + '.' + current_prefix = local_prefix + str(i) + "." comm_module = submodule_dict.get(i, buffer_module) for name, param in comm_module.named_parameters(): dist.broadcast(param.data, src=source_rank, group=self.dist_info.ep_group) @@ -94,8 +96,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False): class FFNExperts(MoeExperts): - """Use torch.bmm to speed up for multiple experts. - """ + """Use torch.bmm to speed up for multiple experts.""" def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): super().__init__("all_to_all", num_experts) @@ -119,10 +120,9 @@ def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, d self.drop = nn.Dropout(p=drop_rate) for param in self.parameters(): - param.__setattr__('moe_info', self.dist_info) - - def forward(self, inputs): # inputs [g, el, c, h] + param.__setattr__("moe_info", self.dist_info) + def forward(self, inputs): # inputs [g, el, c, h] el = inputs.size(1) h = inputs.size(-1) @@ -137,7 +137,7 @@ def forward(self, inputs): # inputs [g, el, c, h] out_model = torch.baddbmm(self.b2, out_inter, self.w2) with seed(ParallelMode.TENSOR): - outputs = self.drop(out_model) # outputs [el, gc, h] + outputs = self.drop(out_model) # outputs [el, gc, h] outputs = outputs.reshape(inshape) outputs = outputs.transpose(0, 1).contiguous() @@ -153,8 +153,7 @@ class TPExperts(MoeExperts): def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): super().__init__("all_gather", MOE_CONTEXT.max_ep_size) - assert d_ff % MOE_CONTEXT.max_ep_size == 0, \ - "d_ff should be divide by maximum expert parallel size" + assert d_ff % MOE_CONTEXT.max_ep_size == 0, "d_ff should be divide by maximum expert parallel size" p_ff = d_ff // MOE_CONTEXT.max_ep_size @@ -177,12 +176,11 @@ def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, d self.act = nn.GELU() if activation is None else activation self.drop = nn.Dropout(p=drop_rate) - self.w1.__setattr__('moe_info', self.dist_info) - self.w2.__setattr__('moe_info', self.dist_info) - self.b1.__setattr__('moe_info', self.dist_info) - - def forward(self, inputs): # inputs [g, e, c, 
h] + self.w1.__setattr__("moe_info", self.dist_info) + self.w2.__setattr__("moe_info", self.dist_info) + self.b1.__setattr__("moe_info", self.dist_info) + def forward(self, inputs): # inputs [g, e, c, h] e = inputs.size(1) h = inputs.size(-1) @@ -196,8 +194,8 @@ def forward(self, inputs): # inputs [g, e, c, h] out_inter = self.drop(out_act) out_model = torch.baddbmm(self.b2, out_inter, self.w2) - outputs = self.drop(out_model) # outputs [e, gc, h] + outputs = self.drop(out_model) # outputs [e, gc, h] outputs = outputs.reshape(inshape) outputs = outputs.transpose(0, 1).contiguous() - return outputs # outputs [g, e, c, h] + return outputs # outputs [g, e, c, h] diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py index 9293d3208f11..23d483e6a17a 100644 --- a/colossalai/nn/layer/moe/layers.py +++ b/colossalai/nn/layer/moe/layers.py @@ -89,8 +89,9 @@ def forward(self, inputs: torch.Tensor) -> Tuple: elif self.experts.comm_name == "all_gather": expert_output = self.tp_process(dispatch_data) else: - raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts " - "build function.") + raise NotImplementedError( + "This kind of communication has not been implemented yet.\n Please use Experts " "build function." + ) # expert_output [e, c, h] if self.use_kernel: expert_output = expert_output.reshape(-1, self.d_model) @@ -135,27 +136,29 @@ class MoeModule(nn.Module): https://arxiv.org/abs/2201.05596 """ - def __init__(self, - dim_model: int, - num_experts: int, - top_k: int = 1, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_policy: Optional[str] = None, - drop_tks: bool = True, - use_residual: bool = False, - residual_instance: Optional[nn.Module] = None, - expert_instance: Optional[MoeExperts] = None, - expert_cls: Optional[Type[nn.Module]] = None, - **expert_args): + def __init__( + self, + dim_model: int, + num_experts: int, + top_k: int = 1, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_policy: Optional[str] = None, + drop_tks: bool = True, + use_residual: bool = False, + residual_instance: Optional[nn.Module] = None, + expert_instance: Optional[MoeExperts] = None, + expert_cls: Optional[Type[nn.Module]] = None, + **expert_args, + ): super().__init__() noisy_func = None if noisy_policy is not None: - if noisy_policy == 'Jitter': + if noisy_policy == "Jitter": noisy_func = UniformNoiseGenerator() - elif noisy_policy == 'Gaussian': + elif noisy_policy == "Gaussian": noisy_func = NormalNoiseGenerator(num_experts) else: raise NotImplementedError("Unsupported input noisy policy") @@ -167,18 +170,19 @@ def __init__(self, else: raise NotImplementedError("top_k > 2 is not supported yet") - self.moe_router = moe_router_cls(capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks) + self.moe_router = moe_router_cls( + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks, + ) self.use_residual = use_residual if use_residual: if residual_instance is not None: self.residual_module = residual_instance else: - assert expert_cls is not None, \ - "Expert class can't be None when residual instance is not given" + assert expert_cls is not None, "Expert class can't be None when residual instance is not 
given" self.residual_module = expert_cls(**expert_args) with no_shard_zero_context(): @@ -187,14 +191,12 @@ def __init__(self, if expert_instance is not None: my_experts = expert_instance else: - assert expert_cls is not None, \ - "Expert class can't be None when experts instance is not given" + assert expert_cls is not None, "Expert class can't be None when experts instance is not given" my_experts = Experts(expert_cls, num_experts, **expert_args) - self.moe_layer = MoeLayer(dim_model=dim_model, - num_experts=num_experts, - router=self.moe_router, - experts=my_experts) + self.moe_layer = MoeLayer( + dim_model=dim_model, num_experts=num_experts, router=self.moe_router, experts=my_experts + ) def forward(self, inputs: torch.Tensor): moe_output, l_aux = self.moe_layer(inputs) diff --git a/colossalai/nn/layer/moe/routers.py b/colossalai/nn/layer/moe/routers.py index c5b8390bf047..7ba83b2787a0 100644 --- a/colossalai/nn/layer/moe/routers.py +++ b/colossalai/nn/layer/moe/routers.py @@ -1,226 +1,235 @@ -import math -from abc import ABC - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.distributed as dist -from colossalai.utils import get_current_device -from colossalai.context import MOE_CONTEXT -from colossalai.nn.layer.moe._operation import moe_cumsum -from typing import Callable, Optional -from torch.distributed import ProcessGroup - - -class MoeRouter(nn.Module, ABC): - """Base class for all MoE routers. - Args: - k_value (int): The value of top_k. - capacity_factor_train (float): Capacity factor in routing of training. - capacity_factor_eval (float): Capacity factor in routing of evaluation. - min_capacity (int): The minimum number of the capacity of each expert. - noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. - drop_tks (bool, optional): Whether drops tokens in evaluation - """ - - def __init__(self, - k_value: int, - capacity_factor_train: float, - capacity_factor_eval: float, - min_capacity: int, - noisy_func: Callable = None, - drop_tks: bool = True): - super().__init__() - self.k_value = k_value - self.capacity_factor_train = capacity_factor_train - self.capacity_factor_eval = capacity_factor_eval - self.min_capacity = min_capacity - self.noisy_func = noisy_func - self.drop_tks = drop_tks - self._routing_loss = None - - def get_capacity(self, logits_shape): - capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval - capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1]) - capacity += capacity % 2 - capacity = max(capacity, self.min_capacity) - assert capacity > 0 - return capacity - - def set_routing_loss(self, aux_loss: torch.Tensor) -> None: - assert self._routing_loss is None - self._routing_loss = aux_loss - - def pop_routing_loss(self) -> torch.Tensor: - assert self._routing_loss is not None - reservation = self._routing_loss - self._routing_loss = None - return reservation - - -class Top1Router(MoeRouter): - """Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c] - for routing usage. More detailed function can be found in the paper about Switch Transformer - of Google. - Args: - capacity_factor_train (float, optional): Capacity factor in routing of training. - capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. - min_capacity (int, optional): The minimum number of the capacity of each expert. - select_policy (str, optional): The policy about tokens selection. 
- noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. - drop_tks (bool, optional): Whether drops tokens in evaluation - """ - - def __init__(self, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - select_policy: str = "first", - noisy_func: Callable = None, - drop_tks: bool = True): - super().__init__(k_value=1, - capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks) - self.select_policy = select_policy - assert select_policy in {"first", "random"} - if select_policy == "random": - self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()), - high=torch.tensor(1.0, - device=get_current_device())).rsample - - def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None): - - if self.noisy_func is not None and self.training: - inputs = self.noisy_func(inputs) - - assert inputs.dtype == torch.float - logits = F.softmax(inputs, dim=-1) - num_experts = logits.size(-1) - capacity = self.get_capacity(logits.shape) - - top1_idx = torch.argmax(inputs, dim=-1) - mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) - - # caculate the auxiliary loss - me = torch.mean(logits, dim=0) - ce = torch.mean(mask.float(), dim=0) - l_aux = num_experts * torch.sum(me * ce) - self.set_routing_loss(l_aux) - - if not self.training and not self.drop_tks: - max_num = torch.max(torch.sum(mask, dim=0)) - dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) - capacity = max_num.item() - - if self.select_policy == "random": - rand_mask = mask * self.uniform(mask.shape) - _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0) - mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1) - ranks = moe_cumsum(mask) - elif self.select_policy == "first": - ranks = moe_cumsum(mask) - mask = mask * torch.lt(ranks, capacity) - else: - raise NotImplementedError("Not support such select policy yet.") - - ranks = torch.sum(mask * ranks, dim=-1) - - if use_kernel: - mask = torch.sum(mask, dim=-1) - mask = torch.stack([mask], dim=0).to(torch.int32) - dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32) - return logits, mask, dest_idx, num_experts * capacity - else: - ranks = F.one_hot(ranks, num_classes=capacity) - weight = mask * logits.type_as(inputs) - combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1) - sec_mask = combine_weights.bool() - return combine_weights, sec_mask - - -class Top2Router(MoeRouter): - """Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c] - for routing usage. More detailed function can be found in the paper about ViT-MoE. - Args: - capacity_factor_train (float, optional): Capacity factor in routing of training. - capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. - min_capacity (int, optional): The minimum number of the capacity of each expert - noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. - drop_tks (bool, optional): Whether drops tokens in evaluation. 
- """ - - def __init__(self, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_func: Callable = None, - drop_tks: bool = True): - super().__init__(k_value=2, - capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks) - - def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None): - # inputs: [s, h] - if self.noisy_func is not None and self.training: - inputs = self.noisy_func(inputs) - - assert inputs.dtype == torch.float - logits = F.softmax(inputs, dim=-1) # logits: [s, e] - num_experts = logits.size(-1) - capacity = self.get_capacity(logits.shape) - - top1_idx = torch.argmax(logits, dim=-1) - mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) - logits_except1 = logits.masked_fill(mask1.bool(), float("-inf")) - top2_idx = torch.argmax(logits_except1, dim=-1) - mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32) - - cmask = (mask1 + mask2) # loss: [s, e] - - # caculate the auxiliary loss - me = torch.mean(logits, dim=0) - ce = torch.mean(cmask.float(), dim=0) - l_aux = num_experts * torch.sum(me * ce) / 2.0 # div 2 to normalize it to 1 - self.set_routing_loss(l_aux) - - if not self.training and not self.drop_tks: - max_num = torch.max(torch.sum(cmask, dim=0)) - dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) - capacity = max_num.item() - - rank1 = moe_cumsum(mask1) # rank1: [s, e] - rank2 = moe_cumsum(mask2) - rank2 += torch.sum(mask1, dim=-2, keepdim=True) - - mask1 *= torch.lt(rank1, capacity) - mask2 *= torch.lt(rank2, capacity) - - rank1 = torch.sum(mask1 * rank1, dim=-1) - rank2 = torch.sum(mask2 * rank2, dim=-1) - - if use_kernel: - mask1 = torch.sum(mask1, dim=-1) - mask2 = torch.sum(mask2, dim=-1) - - mask = torch.stack([mask1, mask2], dim=0).to(torch.int32) - dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32) - - return logits, mask, dest_idx, num_experts * capacity - else: - weight1 = mask1 * logits.type_as(inputs) - weight2 = mask2 * logits.type_as(inputs) - rank1_sc = F.one_hot(rank1, num_classes=capacity) - rank2_sc = F.one_hot(rank2, num_classes=capacity) - - cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1) - cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1) - cb_weight = cb_weight1 + cb_weight2 - sec_mask = cb_weight.bool() - - return cb_weight, sec_mask +import math +from abc import ABC +from typing import Callable, Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch.distributed import ProcessGroup + +from colossalai.nn.layer.moe._operation import moe_cumsum +from colossalai.utils import get_current_device + + +class MoeRouter(nn.Module, ABC): + """Base class for all MoE routers. + Args: + k_value (int): The value of top_k. + capacity_factor_train (float): Capacity factor in routing of training. + capacity_factor_eval (float): Capacity factor in routing of evaluation. + min_capacity (int): The minimum number of the capacity of each expert. + noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. 
+ drop_tks (bool, optional): Whether drops tokens in evaluation + """ + + def __init__( + self, + k_value: int, + capacity_factor_train: float, + capacity_factor_eval: float, + min_capacity: int, + noisy_func: Callable = None, + drop_tks: bool = True, + ): + super().__init__() + self.k_value = k_value + self.capacity_factor_train = capacity_factor_train + self.capacity_factor_eval = capacity_factor_eval + self.min_capacity = min_capacity + self.noisy_func = noisy_func + self.drop_tks = drop_tks + self._routing_loss = None + + def get_capacity(self, logits_shape): + capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval + capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1]) + capacity += capacity % 2 + capacity = max(capacity, self.min_capacity) + assert capacity > 0 + return capacity + + def set_routing_loss(self, aux_loss: torch.Tensor) -> None: + assert self._routing_loss is None + self._routing_loss = aux_loss + + def pop_routing_loss(self) -> torch.Tensor: + assert self._routing_loss is not None + reservation = self._routing_loss + self._routing_loss = None + return reservation + + +class Top1Router(MoeRouter): + """Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c] + for routing usage. More detailed function can be found in the paper about Switch Transformer + of Google. + Args: + capacity_factor_train (float, optional): Capacity factor in routing of training. + capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. + min_capacity (int, optional): The minimum number of the capacity of each expert. + select_policy (str, optional): The policy about tokens selection. + noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. 
+ drop_tks (bool, optional): Whether drops tokens in evaluation + """ + + def __init__( + self, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + select_policy: str = "first", + noisy_func: Callable = None, + drop_tks: bool = True, + ): + super().__init__( + k_value=1, + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks, + ) + self.select_policy = select_policy + assert select_policy in {"first", "random"} + if select_policy == "random": + self.uniform = torch.distributions.uniform.Uniform( + low=torch.tensor(0.0, device=get_current_device()), high=torch.tensor(1.0, device=get_current_device()) + ).rsample + + def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None): + if self.noisy_func is not None and self.training: + inputs = self.noisy_func(inputs) + + assert inputs.dtype == torch.float + logits = F.softmax(inputs, dim=-1) + num_experts = logits.size(-1) + capacity = self.get_capacity(logits.shape) + + top1_idx = torch.argmax(inputs, dim=-1) + mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) + + # caculate the auxiliary loss + me = torch.mean(logits, dim=0) + ce = torch.mean(mask.float(), dim=0) + l_aux = num_experts * torch.sum(me * ce) + self.set_routing_loss(l_aux) + + if not self.training and not self.drop_tks: + max_num = torch.max(torch.sum(mask, dim=0)) + dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) + capacity = max_num.item() + + if self.select_policy == "random": + rand_mask = mask * self.uniform(mask.shape) + _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0) + mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1) + ranks = moe_cumsum(mask) + elif self.select_policy == "first": + ranks = moe_cumsum(mask) + mask = mask * torch.lt(ranks, capacity) + else: + raise NotImplementedError("Not support such select policy yet.") + + ranks = torch.sum(mask * ranks, dim=-1) + + if use_kernel: + mask = torch.sum(mask, dim=-1) + mask = torch.stack([mask], dim=0).to(torch.int32) + dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32) + return logits, mask, dest_idx, num_experts * capacity + else: + ranks = F.one_hot(ranks, num_classes=capacity) + weight = mask * logits.type_as(inputs) + combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1) + sec_mask = combine_weights.bool() + return combine_weights, sec_mask + + +class Top2Router(MoeRouter): + """Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c] + for routing usage. More detailed function can be found in the paper about ViT-MoE. + Args: + capacity_factor_train (float, optional): Capacity factor in routing of training. + capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. + min_capacity (int, optional): The minimum number of the capacity of each expert + noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. + drop_tks (bool, optional): Whether drops tokens in evaluation. 
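> Editorial aside (not part of the patch): both routers size each expert's buffer with `MoeRouter.get_capacity`, reformatted above. A small self-contained sketch of that computation follows, mirroring the capacity-factor formula, the round-up-to-even step, and the `min_capacity` floor; the example numbers are illustrative only.

```python
# Sketch of the per-expert capacity computation used by Top1Router/Top2Router.
import math


def get_capacity(k_value: int, capacity_factor: float, num_tokens: int,
                 num_experts: int, min_capacity: int = 4) -> int:
    # expected tokens per expert, scaled by the capacity factor
    capacity = math.floor(k_value * capacity_factor * num_tokens / num_experts)
    capacity += capacity % 2              # round up to an even number
    capacity = max(capacity, min_capacity)
    assert capacity > 0
    return capacity


# e.g. top-1 routing of 1024 tokens over 8 experts with a 1.25 capacity factor
print(get_capacity(k_value=1, capacity_factor=1.25, num_tokens=1024, num_experts=8))  # 160
```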
+ """ + + def __init__( + self, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_func: Callable = None, + drop_tks: bool = True, + ): + super().__init__( + k_value=2, + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks, + ) + + def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None): + # inputs: [s, h] + if self.noisy_func is not None and self.training: + inputs = self.noisy_func(inputs) + + assert inputs.dtype == torch.float + logits = F.softmax(inputs, dim=-1) # logits: [s, e] + num_experts = logits.size(-1) + capacity = self.get_capacity(logits.shape) + + top1_idx = torch.argmax(logits, dim=-1) + mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) + logits_except1 = logits.masked_fill(mask1.bool(), float("-inf")) + top2_idx = torch.argmax(logits_except1, dim=-1) + mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32) + + cmask = mask1 + mask2 # loss: [s, e] + + # caculate the auxiliary loss + me = torch.mean(logits, dim=0) + ce = torch.mean(cmask.float(), dim=0) + l_aux = num_experts * torch.sum(me * ce) / 2.0 # div 2 to normalize it to 1 + self.set_routing_loss(l_aux) + + if not self.training and not self.drop_tks: + max_num = torch.max(torch.sum(cmask, dim=0)) + dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) + capacity = max_num.item() + + rank1 = moe_cumsum(mask1) # rank1: [s, e] + rank2 = moe_cumsum(mask2) + rank2 += torch.sum(mask1, dim=-2, keepdim=True) + + mask1 *= torch.lt(rank1, capacity) + mask2 *= torch.lt(rank2, capacity) + + rank1 = torch.sum(mask1 * rank1, dim=-1) + rank2 = torch.sum(mask2 * rank2, dim=-1) + + if use_kernel: + mask1 = torch.sum(mask1, dim=-1) + mask2 = torch.sum(mask2, dim=-1) + + mask = torch.stack([mask1, mask2], dim=0).to(torch.int32) + dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32) + + return logits, mask, dest_idx, num_experts * capacity + else: + weight1 = mask1 * logits.type_as(inputs) + weight2 = mask2 * logits.type_as(inputs) + rank1_sc = F.one_hot(rank1, num_classes=capacity) + rank2_sc = F.one_hot(rank2, num_classes=capacity) + + cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1) + cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1) + cb_weight = cb_weight1 + cb_weight2 + sec_mask = cb_weight.bool() + + return cb_weight, sec_mask diff --git a/colossalai/nn/layer/moe/utils.py b/colossalai/nn/layer/moe/utils.py index 4ca8bd703386..4f31dd5579dc 100644 --- a/colossalai/nn/layer/moe/utils.py +++ b/colossalai/nn/layer/moe/utils.py @@ -1,68 +1,71 @@ -import torch -import torch.nn.functional as F -from colossalai.utils import get_current_device -from colossalai.context.moe_context import MOE_CONTEXT -from .experts import FFNExperts, TPExperts - - -class ForceFP32Parameter(torch.nn.Parameter): - - def half(self, memory_format=None): - return self.data.clone() - - -class NormalNoiseGenerator: - """Generates a random noisy mask for logits tensor. - - All noise is generated from a normal distribution :math:`(0, 1 / E^2)`, where - `E = the number of experts`. - - Args: - num_experts (int): The number of experts. 
- """ - - def __init__(self, num_experts: int): - self.normal = torch.distributions.normal.Normal(loc=torch.tensor(0.0, device=get_current_device()), - scale=torch.tensor(1.0 / num_experts**2, - device=get_current_device())).rsample - - def __call__(self, inputs: torch.Tensor): - noisy = self.normal(inputs.shape) - return inputs + noisy - - -class UniformNoiseGenerator: - """Generates a random noisy mask for logits tensor. - copied from mesh tensorflow: - Multiply values by a random number between :math:`1-epsilon` and :math:`1+epsilon`. - Makes models more resilient to rounding errors introduced by bfloat16. - This seems particularly important for logits. - - Args: - eps (float, optional): Epsilon in generator, defaults 1e-2. - """ - - def __init__(self, eps: float = 1e-2): - self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - eps, device=get_current_device()), - high=torch.tensor(1.0 + eps, - device=get_current_device())).rsample - - def __call__(self, inputs: torch.Tensor): - noisy = self.uniform(inputs.shape) - return inputs * noisy - - -def autocast_softmax(logit: torch.Tensor, dim: int): - if logit.dtype != torch.float32: - logit = logit.float() - return F.softmax(logit, dim=dim) - - -def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): - mep_size = MOE_CONTEXT.max_ep_size - if num_experts % mep_size == 0 or mep_size % num_experts == 0: - return FFNExperts(num_experts, d_model, d_ff, activation, drop_rate) - elif d_ff % mep_size == 0: - return TPExperts(num_experts, d_model, d_ff, activation, drop_rate) - else: - raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.") +import torch +import torch.nn.functional as F + +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.utils import get_current_device + +from .experts import FFNExperts, TPExperts + + +class ForceFP32Parameter(torch.nn.Parameter): + def half(self, memory_format=None): + return self.data.clone() + + +class NormalNoiseGenerator: + """Generates a random noisy mask for logits tensor. + + All noise is generated from a normal distribution :math:`(0, 1 / E^2)`, where + `E = the number of experts`. + + Args: + num_experts (int): The number of experts. + """ + + def __init__(self, num_experts: int): + self.normal = torch.distributions.normal.Normal( + loc=torch.tensor(0.0, device=get_current_device()), + scale=torch.tensor(1.0 / num_experts**2, device=get_current_device()), + ).rsample + + def __call__(self, inputs: torch.Tensor): + noisy = self.normal(inputs.shape) + return inputs + noisy + + +class UniformNoiseGenerator: + """Generates a random noisy mask for logits tensor. + copied from mesh tensorflow: + Multiply values by a random number between :math:`1-epsilon` and :math:`1+epsilon`. + Makes models more resilient to rounding errors introduced by bfloat16. + This seems particularly important for logits. + + Args: + eps (float, optional): Epsilon in generator, defaults 1e-2. 
+ """ + + def __init__(self, eps: float = 1e-2): + self.uniform = torch.distributions.uniform.Uniform( + low=torch.tensor(1.0 - eps, device=get_current_device()), + high=torch.tensor(1.0 + eps, device=get_current_device()), + ).rsample + + def __call__(self, inputs: torch.Tensor): + noisy = self.uniform(inputs.shape) + return inputs * noisy + + +def autocast_softmax(logit: torch.Tensor, dim: int): + if logit.dtype != torch.float32: + logit = logit.float() + return F.softmax(logit, dim=dim) + + +def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): + mep_size = MOE_CONTEXT.max_ep_size + if num_experts % mep_size == 0 or mep_size % num_experts == 0: + return FFNExperts(num_experts, d_model, d_ff, activation, drop_rate) + elif d_ff % mep_size == 0: + return TPExperts(num_experts, d_model, d_ff, activation, drop_rate) + else: + raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.") diff --git a/colossalai/nn/layer/utils.py b/colossalai/nn/layer/utils.py index dc12ff8daa4e..ff9b5c8f2b5b 100644 --- a/colossalai/nn/layer/utils.py +++ b/colossalai/nn/layer/utils.py @@ -8,7 +8,6 @@ def divide(numerator, denominator): Returns: int: the result of exact division. """ - assert denominator != 0, 'denominator can not be zero' - assert numerator % denominator == 0, \ - '{} is not divisible by {}'.format(numerator, denominator) + assert denominator != 0, "denominator can not be zero" + assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator) return numerator // denominator diff --git a/colossalai/nn/lr_scheduler/__init__.py b/colossalai/nn/lr_scheduler/__init__.py index 34731ee901a0..783f12f8c7c4 100644 --- a/colossalai/nn/lr_scheduler/__init__.py +++ b/colossalai/nn/lr_scheduler/__init__.py @@ -3,10 +3,21 @@ from .multistep import MultiStepLR, MultiStepWarmupLR from .onecycle import OneCycleLR from .poly import PolynomialLR, PolynomialWarmupLR -from .torch import LambdaLR, MultiplicativeLR, StepLR, ExponentialLR +from .torch import ExponentialLR, LambdaLR, MultiplicativeLR, StepLR __all__ = [ - 'CosineAnnealingLR', 'CosineAnnealingWarmupLR', 'FlatAnnealingLR', 'FlatAnnealingWarmupLR', 'LinearWarmupLR', - 'MultiStepLR', 'MultiStepWarmupLR', 'OneCycleLR', 'PolynomialLR', 'PolynomialWarmupLR', 'LambdaLR', - 'MultiplicativeLR', 'StepLR', 'ExponentialLR' + "CosineAnnealingLR", + "CosineAnnealingWarmupLR", + "FlatAnnealingLR", + "FlatAnnealingWarmupLR", + "LinearWarmupLR", + "MultiStepLR", + "MultiStepWarmupLR", + "OneCycleLR", + "PolynomialLR", + "PolynomialWarmupLR", + "LambdaLR", + "MultiplicativeLR", + "StepLR", + "ExponentialLR", ] diff --git a/colossalai/nn/lr_scheduler/cosine.py b/colossalai/nn/lr_scheduler/cosine.py index fb587e1a1341..a896d3acba6c 100644 --- a/colossalai/nn/lr_scheduler/cosine.py +++ b/colossalai/nn/lr_scheduler/cosine.py @@ -58,11 +58,10 @@ class CosineAnnealingWarmupLR(WarmupScheduler): the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. 
""" - def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: float = 0., last_epoch: int = -1): - base_scheduler = _CosineAnnealingLR(optimizer, - total_steps - warmup_steps, - eta_min=eta_min, - last_epoch=last_epoch) + def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: float = 0.0, last_epoch: int = -1): + base_scheduler = _CosineAnnealingLR( + optimizer, total_steps - warmup_steps, eta_min=eta_min, last_epoch=last_epoch + ) super().__init__(optimizer, warmup_steps, base_scheduler) @@ -79,7 +78,7 @@ class FlatAnnealingLR(DelayerScheduler): def __init__(self, optimizer, total_steps: int, pct_start: float = 0.72, last_epoch: int = -1, **kwargs): if not (0.0 <= pct_start <= 1.0): - raise ValueError(f'pct_start must >= 0.0 and <= 1.0, got {pct_start}') + raise ValueError(f"pct_start must >= 0.0 and <= 1.0, got {pct_start}") flat_steps = int(total_steps * pct_start) anneal_steps = total_steps - flat_steps base_scheduler = _CosineAnnealingLR(optimizer, anneal_steps) @@ -100,16 +99,18 @@ class FlatAnnealingWarmupLR(WarmupDelayerScheduler): the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. """ - def __init__(self, - optimizer, - total_steps: int, - warmup_steps: int = 0, - pct_start: float = 0.72, - eta_min: int = 0, - last_epoch: int = -1, - **kwargs): + def __init__( + self, + optimizer, + total_steps: int, + warmup_steps: int = 0, + pct_start: float = 0.72, + eta_min: int = 0, + last_epoch: int = -1, + **kwargs, + ): if not (0.0 <= pct_start <= 1.0): - raise ValueError(f'pct_start must >= 0.0 and <= 1.0, got {pct_start}') + raise ValueError(f"pct_start must >= 0.0 and <= 1.0, got {pct_start}") flat_steps = int((total_steps - warmup_steps) * pct_start) anneal_steps = total_steps - warmup_steps - flat_steps base_scheduler = _CosineAnnealingLR(optimizer, anneal_steps, eta_min=eta_min) diff --git a/colossalai/nn/lr_scheduler/delayed.py b/colossalai/nn/lr_scheduler/delayed.py index a73ff8ae37ac..ce7f126d6101 100644 --- a/colossalai/nn/lr_scheduler/delayed.py +++ b/colossalai/nn/lr_scheduler/delayed.py @@ -2,7 +2,6 @@ class _enable_get_lr_call: - def __init__(self, o): self.o = o @@ -28,18 +27,18 @@ class DelayerScheduler(_LRScheduler): def __init__(self, optimizer, delay_epochs, after_scheduler, last_epoch=-1): if delay_epochs < 0: - raise ValueError(f'delay_epochs must >= 0, got {delay_epochs}') + raise ValueError(f"delay_epochs must >= 0, got {delay_epochs}") self.delay_epochs = delay_epochs self.after_scheduler = after_scheduler self.finished = False super().__init__(optimizer, last_epoch) def state_dict(self): - state_dict = {key: value for key, value in self.__dict__.items() if key not in 'optimizer'} - if isinstance(state_dict['after_scheduler'], _LRScheduler): - state_dict['after_scheduler_type'] = type(state_dict['after_scheduler']).__name__ - state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict() - del state_dict['after_scheduler'] + state_dict = {key: value for key, value in self.__dict__.items() if key not in "optimizer"} + if isinstance(state_dict["after_scheduler"], _LRScheduler): + state_dict["after_scheduler_type"] = type(state_dict["after_scheduler"]).__name__ + state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dict() + del state_dict["after_scheduler"] else: raise NotImplementedError() return state_dict @@ -85,11 +84,11 @@ def __init__(self, optimizer, warmup_epochs, after_scheduler, last_epoch=-1): super().__init__(optimizer, last_epoch) def 
state_dict(self): - state_dict = {key: value for key, value in self.__dict__.items() if key not in 'optimizer'} - if isinstance(state_dict['after_scheduler'], _LRScheduler): - state_dict['after_scheduler_type'] = type(state_dict['after_scheduler']).__name__ - state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict() - del state_dict['after_scheduler'] + state_dict = {key: value for key, value in self.__dict__.items() if key not in "optimizer"} + if isinstance(state_dict["after_scheduler"], _LRScheduler): + state_dict["after_scheduler_type"] = type(state_dict["after_scheduler"]).__name__ + state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dict() + del state_dict["after_scheduler"] else: raise NotImplementedError() return state_dict @@ -130,9 +129,9 @@ class WarmupDelayerScheduler(_LRScheduler): def __init__(self, optimizer, warmup_epochs, delay_epochs, after_scheduler, last_epoch=-1): if delay_epochs < 0: - raise ValueError(f'delay_epochs must >= 0, got {delay_epochs}') + raise ValueError(f"delay_epochs must >= 0, got {delay_epochs}") if warmup_epochs < 0: - raise ValueError(f'warmup_epochs must >= 0, got {warmup_epochs}') + raise ValueError(f"warmup_epochs must >= 0, got {warmup_epochs}") self.warmup_epochs = warmup_epochs self.delay_epochs = delay_epochs self.after_scheduler = after_scheduler @@ -140,11 +139,11 @@ def __init__(self, optimizer, warmup_epochs, delay_epochs, after_scheduler, last super().__init__(optimizer, last_epoch) def state_dict(self): - state_dict = {key: value for key, value in self.__dict__.items() if key not in 'optimizer'} - if isinstance(state_dict['after_scheduler'], _LRScheduler): - state_dict['after_scheduler_type'] = type(state_dict['after_scheduler']).__name__ - state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict() - del state_dict['after_scheduler'] + state_dict = {key: value for key, value in self.__dict__.items() if key not in "optimizer"} + if isinstance(state_dict["after_scheduler"], _LRScheduler): + state_dict["after_scheduler_type"] = type(state_dict["after_scheduler"]).__name__ + state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dict() + del state_dict["after_scheduler"] else: raise NotImplementedError() return state_dict @@ -155,7 +154,7 @@ def get_lr(self): self.after_scheduler.base_lrs = self.base_lrs # reset lr to base_lr for group, base_lr in zip(self.optimizer.param_groups, self.base_lrs): - group['lr'] = base_lr + group["lr"] = base_lr self.finished = True with _enable_get_lr_call(self.after_scheduler): return self.after_scheduler.get_lr() diff --git a/colossalai/nn/lr_scheduler/linear.py b/colossalai/nn/lr_scheduler/linear.py index 21a865e4c12b..1251c261d51f 100644 --- a/colossalai/nn/lr_scheduler/linear.py +++ b/colossalai/nn/lr_scheduler/linear.py @@ -21,5 +21,7 @@ def get_lr(self): if self.last_epoch < self.warmup_steps: return [(self.last_epoch + 1) / (self.warmup_steps + 1) * lr for lr in self.base_lrs] else: - return [(self.total_steps - self.last_epoch) / (self.total_steps - self.warmup_steps) * lr - for lr in self.base_lrs] + return [ + (self.total_steps - self.last_epoch) / (self.total_steps - self.warmup_steps) * lr + for lr in self.base_lrs + ] diff --git a/colossalai/nn/lr_scheduler/multistep.py b/colossalai/nn/lr_scheduler/multistep.py index c428c911c94d..86589d74662d 100644 --- a/colossalai/nn/lr_scheduler/multistep.py +++ b/colossalai/nn/lr_scheduler/multistep.py @@ -20,13 +20,15 @@ class MultiStepLR(_MultiStepLR): the schedule is 
started from the beginning or When last_epoch=-1, sets initial lr as lr. """ - def __init__(self, - optimizer, - total_steps: int, - milestones: List[int] = None, - gamma: float = 0.1, - last_epoch: int = -1, - **kwargs): + def __init__( + self, + optimizer, + total_steps: int, + milestones: List[int] = None, + gamma: float = 0.1, + last_epoch: int = -1, + **kwargs, + ): super().__init__(optimizer, milestones, gamma=gamma, last_epoch=last_epoch) @@ -44,16 +46,18 @@ class MultiStepWarmupLR(WarmupScheduler): the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. """ - def __init__(self, - optimizer, - total_steps: int, - warmup_steps: int = 0, - milestones: List[int] = None, - gamma: float = 0.1, - last_epoch: int = -1, - **kwargs): + def __init__( + self, + optimizer, + total_steps: int, + warmup_steps: int = 0, + milestones: List[int] = None, + gamma: float = 0.1, + last_epoch: int = -1, + **kwargs, + ): if len(milestones) == 0: - raise ValueError('milestones cannot be empty') + raise ValueError("milestones cannot be empty") milestones = [v - warmup_steps for v in milestones if v >= warmup_steps] base_scheduler = _MultiStepLR(optimizer, milestones=milestones, gamma=gamma) super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch) diff --git a/colossalai/nn/lr_scheduler/onecycle.py b/colossalai/nn/lr_scheduler/onecycle.py index 6835b3ee1cf2..a8e551526dbd 100644 --- a/colossalai/nn/lr_scheduler/onecycle.py +++ b/colossalai/nn/lr_scheduler/onecycle.py @@ -65,27 +65,31 @@ class OneCycleLR(_OneCycleLR): https://arxiv.org/abs/1708.07120 """ - def __init__(self, - optimizer, - total_steps: int, - pct_start=0.3, - anneal_strategy='cos', - cycle_momentum=True, - base_momentum=0.85, - max_momentum=0.95, - div_factor=25.0, - final_div_factor=10000.0, - last_epoch=-1, - **kwargs): - max_lrs = list(map(lambda group: group['lr'], optimizer.param_groups)) - super().__init__(optimizer, - max_lrs, - total_steps=total_steps, - pct_start=pct_start, - anneal_strategy=anneal_strategy, - cycle_momentum=cycle_momentum, - base_momentum=base_momentum, - max_momentum=max_momentum, - div_factor=div_factor, - final_div_factor=final_div_factor, - last_epoch=last_epoch) + def __init__( + self, + optimizer, + total_steps: int, + pct_start=0.3, + anneal_strategy="cos", + cycle_momentum=True, + base_momentum=0.85, + max_momentum=0.95, + div_factor=25.0, + final_div_factor=10000.0, + last_epoch=-1, + **kwargs, + ): + max_lrs = list(map(lambda group: group["lr"], optimizer.param_groups)) + super().__init__( + optimizer, + max_lrs, + total_steps=total_steps, + pct_start=pct_start, + anneal_strategy=anneal_strategy, + cycle_momentum=cycle_momentum, + base_momentum=base_momentum, + max_momentum=max_momentum, + div_factor=div_factor, + final_div_factor=final_div_factor, + last_epoch=last_epoch, + ) diff --git a/colossalai/nn/lr_scheduler/poly.py b/colossalai/nn/lr_scheduler/poly.py index 4f2249720ef6..4a3814461ea9 100644 --- a/colossalai/nn/lr_scheduler/poly.py +++ b/colossalai/nn/lr_scheduler/poly.py @@ -15,15 +15,11 @@ class PolynomialLR(_LRScheduler): the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. 
""" - def __init__(self, - optimizer, - total_steps: int, - end_lr: float = 0.0001, - power: float = 1.0, - last_epoch: int = -1, - **kwargs): + def __init__( + self, optimizer, total_steps: int, end_lr: float = 0.0001, power: float = 1.0, last_epoch: int = -1, **kwargs + ): if end_lr < 0: - raise ValueError(f'end_lr must >= 0, got {end_lr}') + raise ValueError(f"end_lr must >= 0, got {end_lr}") self.total_steps = total_steps self.end_lr = end_lr self.power = power @@ -33,9 +29,11 @@ def get_lr(self): return self._get_closed_form_lr() def _get_closed_form_lr(self): - return [(base_lr - self.end_lr) * - ((1 - min(self.last_epoch, self.total_steps) / self.total_steps)**self.power) + self.end_lr - for base_lr in self.base_lrs] + return [ + (base_lr - self.end_lr) * ((1 - min(self.last_epoch, self.total_steps) / self.total_steps) ** self.power) + + self.end_lr + for base_lr in self.base_lrs + ] class PolynomialWarmupLR(WarmupScheduler): @@ -51,13 +49,15 @@ class PolynomialWarmupLR(WarmupScheduler): the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. """ - def __init__(self, - optimizer, - total_steps: int, - warmup_steps: int = 0, - end_lr: float = 0.0001, - power: float = 1.0, - last_epoch: int = -1, - **kwargs): + def __init__( + self, + optimizer, + total_steps: int, + warmup_steps: int = 0, + end_lr: float = 0.0001, + power: float = 1.0, + last_epoch: int = -1, + **kwargs, + ): base_scheduler = PolynomialLR(optimizer, total_steps - warmup_steps, end_lr=end_lr, power=power) super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch) diff --git a/colossalai/nn/optimizer/README.md b/colossalai/nn/optimizer/README.md index d839753d6c44..c4afc6128d43 100644 --- a/colossalai/nn/optimizer/README.md +++ b/colossalai/nn/optimizer/README.md @@ -3,7 +3,7 @@ ## Introduction Welcome to the large-scale deep learning optimization techniques of [Colossal-AI](https://github.com/hpcaitech/ColossalAI), -which has been accepted as official tutorials by top conference [NeurIPS](https://nips.cc/), [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), +which has been accepted as official tutorials by top conference [NeurIPS](https://nips.cc/), [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), [NVIDIA GTC](https://www.nvidia.com/en-us/on-demand/session/gtcspring23-S51482/) ,etc. 
diff --git a/colossalai/nn/optimizer/__init__.py b/colossalai/nn/optimizer/__init__.py index 7e310793f515..26f152da20d3 100644 --- a/colossalai/nn/optimizer/__init__.py +++ b/colossalai/nn/optimizer/__init__.py @@ -6,4 +6,4 @@ from .lamb import Lamb from .lars import Lars -__all__ = ['FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars', 'CPUAdam', 'HybridAdam'] +__all__ = ["FusedLAMB", "FusedAdam", "FusedSGD", "Lamb", "Lars", "CPUAdam", "HybridAdam"] diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py index 9767fcb8b1e2..f35dc0200237 100644 --- a/colossalai/nn/optimizer/cpu_adam.py +++ b/colossalai/nn/optimizer/cpu_adam.py @@ -61,36 +61,39 @@ class CPUAdam(NVMeOptimizer): # Param weight, grad, momentum and variance num_fp32_shards_per_param = 4 - def __init__(self, - model_params, - lr=1e-3, - bias_correction=True, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0, - adamw_mode=True, - nvme_offload_fraction: float = 0.0, - nvme_offload_dir: Optional[str] = None): - + def __init__( + self, + model_params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + adamw_mode=True, + nvme_offload_fraction: float = 0.0, + nvme_offload_dir: Optional[str] = None, + ): default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction) super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir) self.adamw_mode = adamw_mode cpu_adam = CPUAdamBuilder().load() self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode) - def torch_adam_update(self, - data, - grad, - exp_avg, - exp_avg_sq, - lr, - beta1, - beta2, - eps, - weight_decay, - bias_correction1, - bias_correction2, - use_adamw=False): + def torch_adam_update( + self, + data, + grad, + exp_avg, + exp_avg_sq, + lr, + beta1, + beta2, + eps, + weight_decay, + bias_correction1, + bias_correction2, + use_adamw=False, + ): grad = grad.to(data.dtype) if weight_decay != 0: @@ -117,10 +120,9 @@ def step(self, closure=None, div_scale: float = -1): with torch.enable_grad(): loss = closure() - self._pre_step('exp_avg', 'exp_avg_sq') + self._pre_step("exp_avg", "exp_avg_sq") for _, group in enumerate(self.param_groups): - for _, p in enumerate(group['params']): - + for _, p in enumerate(group["params"]): if p.grad is None: continue @@ -128,48 +130,81 @@ def step(self, closure=None, div_scale: float = -1): target_device = p.device if len(state) == 0: - state['step'] = 0 + state["step"] = 0 # FIXME(ver217): CPU adam kernel only supports fp32 states now assert p.dtype is torch.float, "CPUAdam only support fp32 parameters" # gradient momentums - state['exp_avg'] = torch.zeros_like(p, device=target_device) + state["exp_avg"] = torch.zeros_like(p, device=target_device) # gradient variances - state['exp_avg_sq'] = torch.zeros_like(p, device=target_device) + state["exp_avg_sq"] = torch.zeros_like(p, device=target_device) self._post_state_init(p) - state['step'] += 1 - beta1, beta2 = group['betas'] + state["step"] += 1 + beta1, beta2 = group["betas"] - if target_device.type == 'cpu': + if target_device.type == "cpu": assert p.data.numel() == p.grad.data.numel(), "parameter and gradient should have the same size" - assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu" - assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu" - self._pre_update(p, 'exp_avg', 'exp_avg_sq') + assert state["exp_avg"].device.type == "cpu", "exp_avg should stay 
on cpu" + assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu" + self._pre_update(p, "exp_avg", "exp_avg_sq") if p.grad.dtype is torch.bfloat16: # cpu adam kernel does not support bf16 now - bias_correction1 = 1 - beta1**state['step'] - bias_correction2 = 1 - beta2**state['step'] - self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'], - beta1, beta2, group['eps'], group['weight_decay'], bias_correction1, - bias_correction2, self.adamw_mode) + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + self.torch_adam_update( + p.data, + p.grad.data, + state["exp_avg"], + state["exp_avg_sq"], + group["lr"], + beta1, + beta2, + group["eps"], + group["weight_decay"], + bias_correction1, + bias_correction2, + self.adamw_mode, + ) else: - self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], - group['weight_decay'], group['bias_correction'], p.data, p.grad.data, - state['exp_avg'], state['exp_avg_sq'], div_scale) - self._post_update(p, 'exp_avg', 'exp_avg_sq') - elif target_device.type == 'cuda': + self.cpu_adam_op.step( + state["step"], + group["lr"], + beta1, + beta2, + group["eps"], + group["weight_decay"], + group["bias_correction"], + p.data, + p.grad.data, + state["exp_avg"], + state["exp_avg_sq"], + div_scale, + ) + self._post_update(p, "exp_avg", "exp_avg_sq") + elif target_device.type == "cuda": assert div_scale == -1, "div_scale should remain default" - assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda" - assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda" + assert state["exp_avg"].device.type == "cuda", "exp_avg should stay on cuda" + assert state["exp_avg_sq"].device.type == "cuda", "exp_avg should stay on cuda" - bias_correction1 = 1 - beta1**state['step'] - bias_correction2 = 1 - beta2**state['step'] + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] # adam on cuda - self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'], - beta1, beta2, group['eps'], group['weight_decay'], bias_correction1, - bias_correction2, self.adamw_mode) + self.torch_adam_update( + p.data, + p.grad.data, + state["exp_avg"], + state["exp_avg_sq"], + group["lr"], + beta1, + beta2, + group["eps"], + group["weight_decay"], + bias_correction1, + bias_correction2, + self.adamw_mode, + ) else: raise RuntimeError self._post_step() diff --git a/colossalai/nn/optimizer/fused_adam.py b/colossalai/nn/optimizer/fused_adam.py index 3a05a34f52d2..fcdd3257d700 100644 --- a/colossalai/nn/optimizer/fused_adam.py +++ b/colossalai/nn/optimizer/fused_adam.py @@ -1,11 +1,11 @@ # modified from https://github.com/NVIDIA/apex/blob/master/apex/optimizers/fused_adam.py -''' +""" Copyright 2020 The Microsoft DeepSpeed Team Copyright NVIDIA/apex This file is adapted from fused adam in NVIDIA/apex, commit a109f85 Licensed under the MIT License. 
-''' +""" import torch from colossalai.utils import multi_tensor_applier @@ -51,37 +51,39 @@ class FusedAdam(torch.optim.Optimizer): https://openreview.net/forum?id=ryQu7f-RZ """ - def __init__(self, - params, - lr=1e-3, - bias_correction=True, - betas=(0.9, 0.999), - eps=1e-8, - adamw_mode=True, - weight_decay=0., - amsgrad=False, - set_grad_none=True): - + def __init__( + self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + adamw_mode=True, + weight_decay=0.0, + amsgrad=False, + set_grad_none=True, + ): if amsgrad: - raise RuntimeError('FusedAdam does not support the AMSGrad variant.') + raise RuntimeError("FusedAdam does not support the AMSGrad variant.") defaults = dict(lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay) super(FusedAdam, self).__init__(params, defaults) self.adamw_mode = 1 if adamw_mode else 0 self.set_grad_none = set_grad_none if multi_tensor_applier.available: from colossalai.kernel.op_builder import FusedOptimBuilder + fused_optim = FusedOptimBuilder().load() # Skip buffer self._dummy_overflow_buf = torch.cuda.IntTensor([0]) self.multi_tensor_adam = fused_optim.multi_tensor_adam else: - raise RuntimeError('FusedAdam requires cuda extensions') + raise RuntimeError("FusedAdam requires cuda extensions") def zero_grad(self, set_to_none=False): if set_to_none: for group in self.param_groups: - for p in group['params']: + for p in group["params"]: p.grad = None else: super(FusedAdam, self).zero_grad() @@ -97,51 +99,63 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no """ if any(p is not None for p in [grads, output_params, scale, grad_norms]): raise RuntimeError( - 'FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.' + "FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments." 
) loss = None if closure is not None: loss = closure() for group in self.param_groups: - bias_correction = 1 if group['bias_correction'] else 0 - beta1, beta2 = group['betas'] + bias_correction = 1 if group["bias_correction"] else 0 + beta1, beta2 = group["betas"] # assume same step across group now to simplify things # per parameter step can be easily support by making it tensor, or pass list into kernel - if 'step' in group: - group['step'] += 1 + if "step" in group: + group["step"] += 1 else: - group['step'] = 1 + group["step"] = 1 # create lists for multi-tensor apply g_l, p_l, m_l, v_l = [], [], [], [] - for p in group['params']: + for p in group["params"]: if p.grad is None: continue if p.grad.data.is_sparse: raise RuntimeError( - 'FusedAdam does not support sparse gradients, please consider SparseAdam instead') + "FusedAdam does not support sparse gradients, please consider SparseAdam instead" + ) state = self.state[p] # State initialization if len(state) == 0: # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p) + state["exp_avg"] = torch.zeros_like(p) # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(p) + state["exp_avg_sq"] = torch.zeros_like(p) if p.dtype not in [torch.float16, torch.float32, torch.bfloat16]: - raise RuntimeError('FusedAdam only support fp16, fp32 and bf16.') + raise RuntimeError("FusedAdam only support fp16, fp32 and bf16.") g_l.append(p.grad.data) p_l.append(p.data) - m_l.append(state['exp_avg']) - v_l.append(state['exp_avg_sq']) - - multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'], - beta1, beta2, group['eps'], group['step'], self.adamw_mode, bias_correction, - group['weight_decay'], div_scale) + m_l.append(state["exp_avg"]) + v_l.append(state["exp_avg_sq"]) + + multi_tensor_applier( + self.multi_tensor_adam, + self._dummy_overflow_buf, + [g_l, p_l, m_l, v_l], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + self.adamw_mode, + bias_correction, + group["weight_decay"], + div_scale, + ) return loss diff --git a/colossalai/nn/optimizer/fused_lamb.py b/colossalai/nn/optimizer/fused_lamb.py index a2807d70f454..3e1d5a7ba539 100644 --- a/colossalai/nn/optimizer/fused_lamb.py +++ b/colossalai/nn/optimizer/fused_lamb.py @@ -49,41 +49,46 @@ class FusedLAMB(torch.optim.Optimizer): https://openreview.net/forum?id=ryQu7f-RZ """ - def __init__(self, - params, - lr=1e-3, - bias_correction=True, - betas=(0.9, 0.999), - eps=1e-6, - weight_decay=0.01, - amsgrad=False, - adam_w_mode=True, - grad_averaging=True, - set_grad_none=True, - max_grad_norm=1.0, - use_nvlamb=False): + def __init__( + self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-6, + weight_decay=0.01, + amsgrad=False, + adam_w_mode=True, + grad_averaging=True, + set_grad_none=True, + max_grad_norm=1.0, + use_nvlamb=False, + ): if amsgrad: - raise RuntimeError('FusedLAMB does not support the AMSGrad variant.') - defaults = dict(lr=lr, - bias_correction=bias_correction, - betas=betas, - eps=eps, - weight_decay=weight_decay, - grad_averaging=grad_averaging, - max_grad_norm=max_grad_norm) + raise RuntimeError("FusedLAMB does not support the AMSGrad variant.") + defaults = dict( + lr=lr, + bias_correction=bias_correction, + betas=betas, + eps=eps, + weight_decay=weight_decay, + grad_averaging=grad_averaging, + max_grad_norm=max_grad_norm, + ) super(FusedLAMB, self).__init__(params, defaults) if multi_tensor_applier.available: from 
colossalai.kernel.op_builder import FusedOptimBuilder + fused_optim = FusedOptimBuilder().load() self.multi_tensor_l2norm = fused_optim.multi_tensor_l2norm # Skip buffer - self._dummy_overflow_buf = torch.tensor([0], - dtype=torch.int, - device=self.param_groups[0]["params"][0].device) + self._dummy_overflow_buf = torch.tensor( + [0], dtype=torch.int, device=self.param_groups[0]["params"][0].device + ) self.multi_tensor_lamb = fused_optim.multi_tensor_lamb else: - raise RuntimeError('FusedLAMB requires cuda extensions') + raise RuntimeError("FusedLAMB requires cuda extensions") self.adam_w_mode = 1 if adam_w_mode else 0 self.set_grad_none = set_grad_none @@ -92,7 +97,7 @@ def __init__(self, def zero_grad(self): if self.set_grad_none: for group in self.param_groups: - for p in group['params']: + for p in group["params"]: p.grad = None else: super(FusedLAMB, self).zero_grad() @@ -111,7 +116,7 @@ def step(self, closure=None): # create separate grad lists for fp32 and fp16 params g_all_32, g_all_16 = [], [] for group in self.param_groups: - for p in group['params']: + for p in group["params"]: if p.grad is None: continue if p.dtype == torch.float32: @@ -119,7 +124,7 @@ def step(self, closure=None): elif p.dtype == torch.float16: g_all_16.append(p.grad.data) else: - raise RuntimeError('FusedLAMB only support fp16 and fp32.') + raise RuntimeError("FusedLAMB only support fp16 and fp32.") device = self.param_groups[0]["params"][0].device g_norm_32, g_norm_16 = torch.zeros(1, device=device), torch.zeros(1, device=device) @@ -130,63 +135,91 @@ def step(self, closure=None): g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, [g_all_16], False)[0] # blend two grad norms to get global grad norm - global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, - [[g_norm_32, g_norm_16]], False)[0] - max_grad_norm = self.defaults['max_grad_norm'] + global_grad_norm = multi_tensor_applier( + self.multi_tensor_l2norm, self._dummy_overflow_buf, [[g_norm_32, g_norm_16]], False + )[0] + max_grad_norm = self.defaults["max_grad_norm"] for group in self.param_groups: - bias_correction = 1 if group['bias_correction'] else 0 - beta1, beta2 = group['betas'] - grad_averaging = 1 if group['grad_averaging'] else 0 + bias_correction = 1 if group["bias_correction"] else 0 + beta1, beta2 = group["betas"] + grad_averaging = 1 if group["grad_averaging"] else 0 # assume same step across group now to simplify things # per parameter step can be easily support by making it tensor, or pass list into kernel - if 'step' in group: - group['step'] += 1 + if "step" in group: + group["step"] += 1 else: - group['step'] = 1 + group["step"] = 1 # create lists for multi-tensor apply g_16, p_16, m_16, v_16 = [], [], [], [] g_32, p_32, m_32, v_32 = [], [], [], [] - for p in group['params']: + for p in group["params"]: if p.grad is None: continue if p.grad.data.is_sparse: raise RuntimeError( - 'FusedLAMB does not support sparse gradients, please consider SparseAdam instead') + "FusedLAMB does not support sparse gradients, please consider SparseAdam instead" + ) state = self.state[p] # State initialization if len(state) == 0: # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p) + state["exp_avg"] = torch.zeros_like(p) # Exponential moving average of gradient values - state['exp_avg_sq'] = torch.zeros_like(p) + state["exp_avg_sq"] = torch.zeros_like(p) if p.dtype == torch.float16: g_16.append(p.grad.data) p_16.append(p.data) - 
m_16.append(state['exp_avg']) - v_16.append(state['exp_avg_sq']) + m_16.append(state["exp_avg"]) + v_16.append(state["exp_avg_sq"]) elif p.dtype == torch.float32: g_32.append(p.grad.data) p_32.append(p.data) - m_32.append(state['exp_avg']) - v_32.append(state['exp_avg_sq']) + m_32.append(state["exp_avg"]) + v_32.append(state["exp_avg_sq"]) else: - raise RuntimeError('FusedLAMB only support fp16 and fp32.') - - if (len(g_16) > 0): - multi_tensor_applier(self.multi_tensor_lamb, self._dummy_overflow_buf, [g_16, p_16, m_16, v_16], - group['lr'], beta1, beta2, group['eps'], group['step'], bias_correction, - group['weight_decay'], grad_averaging, self.adam_w_mode, global_grad_norm, - max_grad_norm, self.use_nvlamb) - if (len(g_32) > 0): - multi_tensor_applier(self.multi_tensor_lamb, self._dummy_overflow_buf, [g_32, p_32, m_32, v_32], - group['lr'], beta1, beta2, group['eps'], group['step'], bias_correction, - group['weight_decay'], grad_averaging, self.adam_w_mode, global_grad_norm, - max_grad_norm, self.use_nvlamb) + raise RuntimeError("FusedLAMB only support fp16 and fp32.") + + if len(g_16) > 0: + multi_tensor_applier( + self.multi_tensor_lamb, + self._dummy_overflow_buf, + [g_16, p_16, m_16, v_16], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + bias_correction, + group["weight_decay"], + grad_averaging, + self.adam_w_mode, + global_grad_norm, + max_grad_norm, + self.use_nvlamb, + ) + if len(g_32) > 0: + multi_tensor_applier( + self.multi_tensor_lamb, + self._dummy_overflow_buf, + [g_32, p_32, m_32, v_32], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + bias_correction, + group["weight_decay"], + grad_averaging, + self.adam_w_mode, + global_grad_norm, + max_grad_norm, + self.use_nvlamb, + ) return loss diff --git a/colossalai/nn/optimizer/fused_sgd.py b/colossalai/nn/optimizer/fused_sgd.py index 59a93a8be9c7..95a6354208a8 100644 --- a/colossalai/nn/optimizer/fused_sgd.py +++ b/colossalai/nn/optimizer/fused_sgd.py @@ -54,14 +54,9 @@ class FusedSGD(Optimizer): The Nesterov version is analogously modified. """ - def __init__(self, - params, - lr=required, - momentum=0, - dampening=0, - weight_decay=0, - nesterov=False, - wd_after_momentum=False): + def __init__( + self, params, lr=required, momentum=0, dampening=0, weight_decay=0, nesterov=False, wd_after_momentum=False + ): if lr is not required and lr < 0.0: raise ValueError("Invalid learning rate: {}".format(lr)) if momentum < 0.0: @@ -78,20 +73,21 @@ def __init__(self, if multi_tensor_applier.available: from colossalai.kernel.op_builder import FusedOptimBuilder + fused_optim = FusedOptimBuilder().load() # Skip buffer - self._dummy_overflow_buf = torch.tensor([0], - dtype=torch.int, - device=self.param_groups[0]["params"][0].device) + self._dummy_overflow_buf = torch.tensor( + [0], dtype=torch.int, device=self.param_groups[0]["params"][0].device + ) self.multi_tensor_sgd = fused_optim.multi_tensor_sgd else: - raise RuntimeError('FusedSGD requires cuda extensions') + raise RuntimeError("FusedSGD requires cuda extensions") def __setstate__(self, state): super(FusedSGD, self).__setstate__(state) for group in self.param_groups: - group.setdefault('nesterov', False) + group.setdefault("nesterov", False) def get_momentums(self, params): momentums = [] @@ -101,13 +97,13 @@ def get_momentums(self, params): # torch.optim.SGD initializes momentum in the main loop, we have # to do it here, and track whether or not we've done so, so that # momentum application can be skipped in the main kernel. 
- if 'momentum_buffer' not in param_state: + if "momentum_buffer" not in param_state: first_run = True - buf = param_state['momentum_buffer'] = torch.zeros_like(p) + buf = param_state["momentum_buffer"] = torch.zeros_like(p) momentums.append(buf) else: first_run = False - momentums.append(param_state['momentum_buffer']) + momentums.append(param_state["momentum_buffer"]) return momentums, first_run def step(self, closure=None): @@ -122,10 +118,10 @@ def step(self, closure=None): loss = closure() for group in self.param_groups: - weight_decay = group['weight_decay'] - momentum = group['momentum'] - dampening = group['dampening'] - nesterov = group['nesterov'] + weight_decay = group["weight_decay"] + momentum = group["momentum"] + dampening = group["dampening"] + nesterov = group["nesterov"] # For each group, there are 3 possible combinations we need to consider: # grad_type, param_to_update_type, momentum_type @@ -133,15 +129,26 @@ def step(self, closure=None): # 2. fp32, fp32, fp32 # 3. fp16, fp32, fp32 g_l, p_l = [], [] - for p in group['params']: + for p in group["params"]: if p.grad is None: continue if p.grad.data.is_sparse: - raise RuntimeError('FusedSGD does not support sparse gradients') + raise RuntimeError("FusedSGD does not support sparse gradients") g_l.append(p.grad) p_l.append(p) m_l, first_run = self.get_momentums(p_l) - multi_tensor_applier(self.multi_tensor_sgd, self._dummy_overflow_buf, [g_l, p_l, m_l], weight_decay, - momentum, dampening, group['lr'], nesterov, first_run, self.wd_after_momentum, 1.0) + multi_tensor_applier( + self.multi_tensor_sgd, + self._dummy_overflow_buf, + [g_l, p_l, m_l], + weight_decay, + momentum, + dampening, + group["lr"], + nesterov, + first_run, + self.wd_after_momentum, + 1.0, + ) return loss diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py index e08df410effe..32fc6136c4e6 100644 --- a/colossalai/nn/optimizer/hybrid_adam.py +++ b/colossalai/nn/optimizer/hybrid_adam.py @@ -1,7 +1,6 @@ from typing import Any, Optional import torch -from torch.optim import Adam from colossalai.kernel.op_builder import FusedOptimBuilder from colossalai.utils import multi_tensor_applier @@ -61,20 +60,30 @@ class HybridAdam(CPUAdam): # Param weight, grad, momentum and variance num_fp32_shards_per_param = 4 - def __init__(self, - model_params, - lr=1e-3, - bias_correction=True, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0, - adamw_mode=True, - nvme_offload_fraction: float = 0.0, - nvme_offload_dir: Optional[str] = None, - **defaults: Any): - - super().__init__(model_params, lr, bias_correction, betas, eps, weight_decay, adamw_mode, nvme_offload_fraction, - nvme_offload_dir) + def __init__( + self, + model_params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + adamw_mode=True, + nvme_offload_fraction: float = 0.0, + nvme_offload_dir: Optional[str] = None, + **defaults: Any, + ): + super().__init__( + model_params, + lr, + bias_correction, + betas, + eps, + weight_decay, + adamw_mode, + nvme_offload_fraction, + nvme_offload_dir, + ) fused_optim = FusedOptimBuilder().load() self.gpu_adam_op = fused_optim.multi_tensor_adam self._dummy_overflow_buf = torch.cuda.IntTensor([0]) @@ -86,12 +95,11 @@ def step(self, closure=None, div_scale: float = -1): with torch.enable_grad(): loss = closure() - self._pre_step('exp_avg', 'exp_avg_sq') + self._pre_step("exp_avg", "exp_avg_sq") for _, group in enumerate(self.param_groups): g_l, p_l, m_l, v_l = [], [], [], [] group_step = 0 - for _, p in 
enumerate(group['params']): - + for _, p in enumerate(group["params"]): if p.grad is None: continue @@ -99,54 +107,87 @@ def step(self, closure=None, div_scale: float = -1): target_device = p.device if len(state) == 0: - state['step'] = 0 + state["step"] = 0 # FIXME(ver217): CPU adam kernel only supports fp32 states now assert p.dtype is torch.float, "HybridAdam only support fp32 parameters" # gradient momentums - state['exp_avg'] = torch.zeros_like(p, device=target_device) + state["exp_avg"] = torch.zeros_like(p, device=target_device) # gradient variances - state['exp_avg_sq'] = torch.zeros_like(p, device=target_device) + state["exp_avg_sq"] = torch.zeros_like(p, device=target_device) self._post_state_init(p) - state['step'] += 1 - group_step = state['step'] - beta1, beta2 = group['betas'] + state["step"] += 1 + group_step = state["step"] + beta1, beta2 = group["betas"] - if target_device.type == 'cpu': - assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu" - assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu" - self._pre_update(p, 'exp_avg', 'exp_avg_sq') + if target_device.type == "cpu": + assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu" + assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu" + self._pre_update(p, "exp_avg", "exp_avg_sq") if p.grad.dtype is torch.bfloat16: # cpu adam kernel does not support bf16 now - bias_correction1 = 1 - beta1**state['step'] - bias_correction2 = 1 - beta2**state['step'] - self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'], - beta1, beta2, group['eps'], group['weight_decay'], bias_correction1, - bias_correction2, self.adamw_mode) + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + self.torch_adam_update( + p.data, + p.grad.data, + state["exp_avg"], + state["exp_avg_sq"], + group["lr"], + beta1, + beta2, + group["eps"], + group["weight_decay"], + bias_correction1, + bias_correction2, + self.adamw_mode, + ) else: - self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], - group['weight_decay'], group['bias_correction'], p.data, p.grad.data, - state['exp_avg'], state['exp_avg_sq'], div_scale) - self._post_update(p, 'exp_avg', 'exp_avg_sq') - - elif target_device.type == 'cuda': - assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda" - assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda" + self.cpu_adam_op.step( + state["step"], + group["lr"], + beta1, + beta2, + group["eps"], + group["weight_decay"], + group["bias_correction"], + p.data, + p.grad.data, + state["exp_avg"], + state["exp_avg_sq"], + div_scale, + ) + self._post_update(p, "exp_avg", "exp_avg_sq") + + elif target_device.type == "cuda": + assert state["exp_avg"].device.type == "cuda", "exp_avg should stay on cuda" + assert state["exp_avg_sq"].device.type == "cuda", "exp_avg should stay on cuda" # record the state by group and update at once g_l.append(p.grad.data) p_l.append(p.data) - m_l.append(state['exp_avg']) - v_l.append(state['exp_avg_sq']) + m_l.append(state["exp_avg"]) + v_l.append(state["exp_avg_sq"]) else: raise RuntimeError if len(g_l) > 0: adamw_mode = 1 if self.adamw_mode else 0 - bias_correction = 1 if group['bias_correction'] else 0 - multi_tensor_applier(self.gpu_adam_op, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'], - group['betas'][0], group['betas'][1], group['eps'], group_step, adamw_mode, - 
bias_correction, group['weight_decay'], div_scale) + bias_correction = 1 if group["bias_correction"] else 0 + multi_tensor_applier( + self.gpu_adam_op, + self._dummy_overflow_buf, + [g_l, p_l, m_l, v_l], + group["lr"], + group["betas"][0], + group["betas"][1], + group["eps"], + group_step, + adamw_mode, + bias_correction, + group["weight_decay"], + div_scale, + ) self._post_step() return loss diff --git a/colossalai/nn/optimizer/lamb.py b/colossalai/nn/optimizer/lamb.py index d5de267f73ee..0d742487f473 100644 --- a/colossalai/nn/optimizer/lamb.py +++ b/colossalai/nn/optimizer/lamb.py @@ -51,27 +51,27 @@ def step(self, closure=None): loss = closure() for group in self.param_groups: - for p in group['params']: + for p in group["params"]: if p.grad is None: continue grad = p.grad.data if grad.is_sparse: - raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.') + raise RuntimeError("Lamb does not support sparse gradients, consider SparseAdam instead.") state = self.state[p] # State initialization if len(state) == 0: - state['step'] = 0 + state["step"] = 0 # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p) + state["exp_avg"] = torch.zeros_like(p) # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(p) + state["exp_avg_sq"] = torch.zeros_like(p) - exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] - beta1, beta2 = group['betas'] + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] - state['step'] += 1 + state["step"] += 1 # Decay the first and second moment running average coefficient # m_t @@ -84,22 +84,22 @@ def step(self, closure=None): # bias_correction2 = 1 - beta2 ** state['step'] # Apply bias to lr to avoid broadcast. 
# * math.sqrt(bias_correction2) / bias_correction1 - step_size = group['lr'] + step_size = group["lr"] weight_norm = p.data.pow(2).sum().sqrt() - adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps']) - if group['weight_decay'] != 0: - adam_step.add_(p.data, alpha=group['weight_decay']) + adam_step = exp_avg / exp_avg_sq.sqrt().add(group["eps"]) + if group["weight_decay"] != 0: + adam_step.add_(p.data, alpha=group["weight_decay"]) adam_norm = adam_step.pow(2).sum().sqrt() if weight_norm == 0 or adam_norm == 0: trust_ratio = 1 else: trust_ratio = weight_norm / adam_norm - state['weight_norm'] = weight_norm - state['adam_norm'] = adam_norm - state['trust_ratio'] = trust_ratio + state["weight_norm"] = weight_norm + state["adam_norm"] = adam_norm + state["trust_ratio"] = trust_ratio if self.adam: trust_ratio = 1 diff --git a/colossalai/nn/optimizer/lars.py b/colossalai/nn/optimizer/lars.py index 58393fdae4bf..b117c00846d1 100644 --- a/colossalai/nn/optimizer/lars.py +++ b/colossalai/nn/optimizer/lars.py @@ -19,13 +19,9 @@ class Lars(Optimizer): weight_decay (float, optional): weight decay (L2 penalty) (default: 0) """ - def __init__(self, - params: Iterable[torch.nn.Parameter], - lr=1e-3, - momentum=0, - eeta=1e-3, - weight_decay=0, - epsilon=0.0) -> None: + def __init__( + self, params: Iterable[torch.nn.Parameter], lr=1e-3, momentum=0, eeta=1e-3, weight_decay=0, epsilon=0.0 + ) -> None: if not isinstance(lr, float) or lr < 0.0: raise ValueError("Invalid learning rate: {}".format(lr)) if momentum < 0.0: @@ -54,14 +50,14 @@ def step(self, closure=None): loss = closure() for group in self.param_groups: - weight_decay = group['weight_decay'] - momentum = group['momentum'] - eeta = group['eeta'] - lr = group['lr'] - lars = group['lars'] - eps = group['epsilon'] + weight_decay = group["weight_decay"] + momentum = group["momentum"] + eeta = group["eeta"] + lr = group["lr"] + lars = group["lars"] + eps = group["epsilon"] - for p in group['params']: + for p in group["params"]: if p.grad is None: continue decayed_grad = p.grad @@ -69,9 +65,11 @@ def step(self, closure=None): if lars: w_norm = torch.norm(p) g_norm = torch.norm(p.grad) - trust_ratio = torch.where(w_norm > 0 and g_norm > 0, - eeta * w_norm / (g_norm + weight_decay * w_norm + eps), - torch.ones_like(w_norm)) + trust_ratio = torch.where( + w_norm > 0 and g_norm > 0, + eeta * w_norm / (g_norm + weight_decay * w_norm + eps), + torch.ones_like(w_norm), + ) trust_ratio.clamp_(0.0, 50) scaled_lr *= trust_ratio.item() if weight_decay != 0: @@ -80,10 +78,10 @@ def step(self, closure=None): if momentum != 0: param_state = self.state[p] - if 'momentum_buffer' not in param_state: - buf = param_state['momentum_buffer'] = torch.clone(decayed_grad).detach() + if "momentum_buffer" not in param_state: + buf = param_state["momentum_buffer"] = torch.clone(decayed_grad).detach() else: - buf = param_state['momentum_buffer'] + buf = param_state["momentum_buffer"] buf.mul_(momentum).add_(decayed_grad) decayed_grad = buf diff --git a/colossalai/nn/optimizer/nvme_optimizer.py b/colossalai/nn/optimizer/nvme_optimizer.py index fb3a4d87be60..fd02bfb683e1 100644 --- a/colossalai/nn/optimizer/nvme_optimizer.py +++ b/colossalai/nn/optimizer/nvme_optimizer.py @@ -19,13 +19,11 @@ class NVMeOptimizer(torch.optim.Optimizer): Raises: ImportError: Raise if ``tensornvme`` is not installed. 
- """ + """ - def __init__(self, - params, - defaults: dict, - nvme_offload_fraction: float = 0.0, - offload_dir: Optional[str] = None) -> None: + def __init__( + self, params, defaults: dict, nvme_offload_fraction: float = 0.0, offload_dir: Optional[str] = None + ) -> None: assert 0.0 <= nvme_offload_fraction <= 1.0 super().__init__(params, defaults) self.nvme_offload_fraction = float(nvme_offload_fraction) @@ -34,9 +32,9 @@ def __init__(self, from tensornvme import DiskOffloader from tensornvme._C import get_backends except ModuleNotFoundError: - raise ModuleNotFoundError('Please install tensornvme to use NVMeOptimizer') + raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer") self.offload_dir = offload_dir or tempfile.mkdtemp() - backend = 'uring' if 'uring' in get_backends() else 'aio' + backend = "uring" if "uring" in get_backends() else "aio" self.offloader = DiskOffloader(self.offload_dir, 8, backend=backend) else: self.offload_dir = None @@ -53,13 +51,17 @@ def __init__(self, def _get_numel(self) -> int: numel = 0 for group in self.param_groups: - for p in group['params']: + for p in group["params"]: numel += p.storage().size() return numel def _post_state_init(self, param: Parameter) -> None: numel = param.storage().size() - if self.offloader is not None and param.device.type == 'cpu' and numel + self.offloaded_numel <= self.can_offload_numel: + if ( + self.offloader is not None + and param.device.type == "cpu" + and numel + self.offloaded_numel <= self.can_offload_numel + ): self.is_on_nvme[param] = True self.offloaded_numel += numel else: @@ -70,11 +72,11 @@ def _setup_prefetch_params(self) -> List[Parameter]: return assert len(self.prefetch_params) == 0 and len(self.param_to_prefetch_idx) == 0 for group in self.param_groups: - for p in group['params']: + for p in group["params"]: if p.grad is None: continue if len(self.state[p]) > 0 and self.is_on_nvme[p]: - assert p.device.type == 'cpu' + assert p.device.type == "cpu" self.param_to_prefetch_idx[p] = len(self.prefetch_params) self.prefetch_params.append(p) @@ -156,7 +158,7 @@ def load_state_dict(self, state_dict: dict) -> None: super().load_state_dict(state_dict) def __del__(self) -> None: - if getattr(self, 'offloader', None) is not None: + if getattr(self, "offloader", None) is not None: del self.offloader if os.path.exists(self.offload_dir): try: diff --git a/colossalai/pipeline/__init__.py b/colossalai/pipeline/__init__.py index e88a1f00a1b7..4754212c1914 100644 --- a/colossalai/pipeline/__init__.py +++ b/colossalai/pipeline/__init__.py @@ -3,9 +3,9 @@ from .stage_manager import PipelineStageManager __all__ = [ - 'PipelineSchedule', - 'OneForwardOneBackwardSchedule', - 'InterleavedSchedule', - 'PipelineP2PCommunication', - 'PipelineStageManager', + "PipelineSchedule", + "OneForwardOneBackwardSchedule", + "InterleavedSchedule", + "PipelineP2PCommunication", + "PipelineStageManager", ] diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index aed85cf91512..c69bbe6e8521 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -29,11 +29,11 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) - Any: object after unpickled """ buf = tensor.numpy().tobytes()[:tensor_size] - if b'cuda' in buf: + if b"cuda" in buf: buf_array = bytearray(buf) device_index = torch.cuda.current_device() # There might be more than one output tensors during forward - for cuda_str in re.finditer(b'cuda', buf_array): + for cuda_str in re.finditer(b"cuda", buf_array): pos 
= cuda_str.start() buf_array[pos + 5] = 48 + device_index buf = bytes(buf_array) @@ -45,10 +45,9 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) - return unpickle -def _broadcast_object_list(object_list: List[Any], - src: int, - group: ProcessGroup, - device: Optional[Union[torch.device, str, int]] = None): +def _broadcast_object_list( + object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None +): """This is a modified version of the broadcast_object_list in torch.distribution The only difference is that object will be move to correct device after unpickled. If local_rank = src, then object list will be sent to rank src. Otherwise, object list will @@ -99,8 +98,8 @@ def _broadcast_object_list(object_list: List[Any], if my_rank == src: object_tensor = torch.cat(tensor_list) else: - object_tensor = torch.empty( # type: ignore[call-overload] - torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] + object_tensor = torch.empty( # type: ignore[call-overload] + torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] dtype=torch.uint8, ) @@ -114,7 +113,7 @@ def _broadcast_object_list(object_list: List[Any], if my_rank != src: for i, obj_size in enumerate(object_sizes_tensor): - obj_view = object_tensor[offset:offset + obj_size] + obj_view = object_tensor[offset : offset + obj_size] obj_view = obj_view.type(torch.uint8) if obj_view.device != torch.device("cpu"): obj_view = obj_view.cpu() @@ -123,8 +122,10 @@ def _broadcast_object_list(object_list: List[Any], unpickle_object = _cuda_safe_tensor_to_object(obj_view, obj_size) # unconsistence in device - if isinstance(unpickle_object, - torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device(): + if ( + isinstance(unpickle_object, torch.Tensor) + and unpickle_object.device.index != torch.cuda.current_device() + ): unpickle_object = unpickle_object.cuda() object_list[i] = unpickle_object @@ -160,7 +161,6 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any: class PipelineP2PCommunication: - def __init__(self, stage_manager: PipelineStageManager) -> None: self.stage_manager = stage_manager @@ -192,8 +192,9 @@ def recv_backward(self, next_rank: int = None) -> Any: if next_rank is None: next_rank = self.stage_manager.get_next_rank() cur_rank = self.stage_manager.get_rank() - output_tensor_grad = _recv_object(next_rank, cur_rank, - self.stage_manager.get_p2p_process_group(next_rank, cur_rank)) + output_tensor_grad = _recv_object( + next_rank, cur_rank, self.stage_manager.get_p2p_process_group(next_rank, cur_rank) + ) return output_tensor_grad diff --git a/colossalai/pipeline/schedule/__init__.py b/colossalai/pipeline/schedule/__init__.py index 07c0f5927060..6845dc23753b 100644 --- a/colossalai/pipeline/schedule/__init__.py +++ b/colossalai/pipeline/schedule/__init__.py @@ -3,7 +3,7 @@ from .one_f_one_b import OneForwardOneBackwardSchedule __all__ = [ - 'PipelineSchedule', - 'OneForwardOneBackwardSchedule', - 'InterleavedSchedule', + "PipelineSchedule", + "OneForwardOneBackwardSchedule", + "InterleavedSchedule", ] diff --git a/colossalai/pipeline/schedule/_utils.py b/colossalai/pipeline/schedule/_utils.py index 583558551b3c..271b3238f5c4 100644 --- a/colossalai/pipeline/schedule/_utils.py +++ b/colossalai/pipeline/schedule/_utils.py @@ -4,24 +4,15 @@ import torch import torch.cuda from torch.nn import Module -from torch.utils._pytree import ( - SUPPORTED_NODES, - LeafSpec, - TreeSpec, - _is_leaf, - 
_register_pytree_node, - tree_flatten, - tree_map, - tree_unflatten, -) +from torch.utils._pytree import SUPPORTED_NODES, TreeSpec, _register_pytree_node, tree_flatten, tree_map, tree_unflatten # this register are for torch under version 1.13.1, maybe removed in the future -def _odict_flatten(d: 'OrderedDict[Any, Any]') -> Tuple[List[Any], Any]: +def _odict_flatten(d: "OrderedDict[Any, Any]") -> Tuple[List[Any], Any]: return list(d.values()), list(d.keys()) -def _odict_unflatten(values: List[Any], context: Any) -> 'OrderedDict[Any, Any]': +def _odict_unflatten(values: List[Any], context: Any) -> "OrderedDict[Any, Any]": return OrderedDict((key, value) for key, value in zip(context, values)) @@ -45,7 +36,7 @@ def tree_flatten_hf(pytree: Any) -> Tuple[List[Any], TreeSpec]: # Recursively flatten the children result: List[Any] = [] - children_specs: List['TreeSpec'] = [] + children_specs: List["TreeSpec"] = [] for child in child_pytrees: flat, child_spec = tree_flatten_hf(child) result += flat @@ -87,7 +78,7 @@ def get_batch_size(batch: Any) -> int: for data in data_list: if isinstance(data, torch.Tensor): return data.size(0) - raise RuntimeError('No tensor found in the batch') + raise RuntimeError("No tensor found in the batch") def get_micro_batch(batch: Any, start: int, micro_batch_size: int) -> Any: @@ -104,7 +95,7 @@ def get_micro_batch(batch: Any, start: int, micro_batch_size: int) -> Any: def _get_tensor_slice(x: Any): if isinstance(x, torch.Tensor): - return x[start:start + micro_batch_size] + return x[start : start + micro_batch_size] return x return tree_map(_get_tensor_slice, batch) @@ -175,7 +166,7 @@ def merge_batch(data: List[Any], batch_size_dim=0) -> Any: for elem_batch in zip(*flattened_data): if isinstance(elem_batch[0], torch.Tensor): - if len(elem_batch[0].shape) == 0: # set loss to None in pipeline outputs + if len(elem_batch[0].shape) == 0: # set loss to None in pipeline outputs merged_data.append(None) else: merged_data.append(torch.cat(elem_batch, dim=batch_size_dim)) diff --git a/colossalai/pipeline/schedule/base.py b/colossalai/pipeline/schedule/base.py index b0fa6e6ad2b8..1bce297862c8 100644 --- a/colossalai/pipeline/schedule/base.py +++ b/colossalai/pipeline/schedule/base.py @@ -8,17 +8,18 @@ class PipelineSchedule: - def __init__(self, stage_manager: PipelineStageManager) -> None: self.stage_manager = stage_manager - def forward_backward_step(self, - model: Module, - data_iter: Iterable, - criterion: Callable[[Any, Any], Tensor], - optimizer: Optional[OptimizerWrapper] = None, - return_loss: bool = False, - return_outputs: bool = False) -> dict: + def forward_backward_step( + self, + model: Module, + data_iter: Iterable, + criterion: Callable[[Any, Any], Tensor], + optimizer: Optional[OptimizerWrapper] = None, + return_loss: bool = False, + return_outputs: bool = False, + ) -> dict: """Forward and backward step for pipeline training. 
Args: diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index 6fdb09be5f32..780437155c61 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -16,11 +16,11 @@ class InterleavedSchedule(PipelineSchedule): - def __init__(self, num_microbatches: int, num_model_chunks: int, stage_manager: PipelineStageManager) -> None: self.num_model_chunks = num_model_chunks - assert num_microbatches % self.num_model_chunks == 0, \ - "Number of microbatches should be an integer multiple of number of model chunks" + assert ( + num_microbatches % self.num_model_chunks == 0 + ), "Number of microbatches should be an integer multiple of number of model chunks" super().__init__(stage_manager) self.comm = PipelineP2PCommunication(stage_manager) self.num_microbatches = num_microbatches @@ -42,8 +42,7 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) self.batch = batch self.batch_size = get_batch_size(batch) self.microbatch_offset = [0 for _ in range(self.num_model_chunks)] - assert self.batch_size % self.num_microbatches == 0, \ - "Batch size should divided by the number of microbatches" + assert self.batch_size % self.num_microbatches == 0, "Batch size should divided by the number of microbatches" self.microbatch_size = self.batch_size // self.num_microbatches def load_micro_batch(self, model_chunk_id: int) -> Any: @@ -72,7 +71,7 @@ def get_model_chunk_id(self, microbatch_id: int, forward: bool) -> int: microbatch_id_in_group = (microbatch_id) % (self.stage_manager.num_stages * self.num_model_chunks) model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages if not forward: - model_chunk_id = (self.num_model_chunks - model_chunk_id - 1) + model_chunk_id = self.num_model_chunks - model_chunk_id - 1 return model_chunk_id def is_first_stage(self, model_chunk_id: int) -> bool: @@ -161,13 +160,15 @@ def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None if not self.is_first_stage(model_chunk_id): self.comm.send_backward(input_object, prev_rank) - def forward_step(self, - model_chunk: Module, - model_chunk_id: int, - input_obj: Optional[dict], - criterion: Callable, - accum_loss: Optional[torch.Tensor] = None, - outputs: Optional[List[Any]] = None) -> Union[torch.Tensor, dict]: + def forward_step( + self, + model_chunk: Module, + model_chunk_id: int, + input_obj: Optional[dict], + criterion: Callable, + accum_loss: Optional[torch.Tensor] = None, + outputs: Optional[List[Any]] = None, + ) -> Union[torch.Tensor, dict]: """Forward one step of the pipeline Args: model (Module): Model Chunk to be run @@ -195,8 +196,13 @@ def forward_step(self, else: return output_obj - def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict], - output_obj: Union[dict, torch.Tensor], output_obj_grad: Optional[dict]) -> Optional[dict]: + def backward_step( + self, + optimizer: OptimizerWrapper, + input_obj: Optional[dict], + output_obj: Union[dict, torch.Tensor], + output_obj_grad: Optional[dict], + ) -> Optional[dict]: """Backward one step of the pipeline Args: @@ -235,13 +241,15 @@ def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict], input_obj_grad[k] = v.grad return input_obj_grad - def forward_backward_step(self, - model_chunk: Module, - data_iter: Iterable, - criterion: Callable[..., Any], - optimizer: Optional[OptimizerWrapper] = None, - return_loss: bool = False, - return_outputs: bool = False) -> 
dict: + def forward_backward_step( + self, + model_chunk: Module, + data_iter: Iterable, + criterion: Callable[..., Any], + optimizer: Optional[OptimizerWrapper] = None, + return_loss: bool = False, + return_outputs: bool = False, + ) -> dict: """Runs interleaved 1F1B schedule, with communication between pipeline stages. Args: @@ -321,7 +329,7 @@ def forward_backward_step(self, # Run 1F1B in steady state. for i in range(num_microbatches_remaining): model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches, forward=True) - last_iteration = (i == (num_microbatches_remaining - 1)) + last_iteration = i == (num_microbatches_remaining - 1) output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) if forward_only: @@ -369,4 +377,4 @@ def forward_backward_step(self, if outputs is not None: outputs = merge_batch(outputs) - return {'loss': accum_loss, 'outputs': outputs} + return {"loss": accum_loss, "outputs": outputs} diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index fbd0f9f0d4c0..4eaf135fd5db 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -25,11 +25,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): - - def __init__(self, - stage_manager: PipelineStageManager, - num_microbatches: Optional[int] = None, - microbatch_size: Optional[int] = None) -> None: + def __init__( + self, + stage_manager: PipelineStageManager, + num_microbatches: Optional[int] = None, + microbatch_size: Optional[int] = None, + ) -> None: """1F1B pipeline schedule. Args: @@ -38,8 +39,9 @@ def __init__(self, microbatch_size (Optional[int], optional): Microbatch size. If num_microbatches is provided, this will be ignored. Defaults to None. 
""" super().__init__(stage_manager) - assert num_microbatches is not None or microbatch_size is not None, \ - "Either num_microbatches or microbatch_size should be provided" + assert ( + num_microbatches is not None or microbatch_size is not None + ), "Either num_microbatches or microbatch_size should be provided" self.comm = PipelineP2PCommunication(stage_manager) self.num_microbatches = num_microbatches self.microbatch_size = microbatch_size @@ -62,12 +64,12 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) self.batch_size = get_batch_size(batch) self.microbatch_offset = 0 if not self._use_microbatch_size: - assert self.batch_size % self.num_microbatches == 0, \ - "Batch size should divided by the number of microbatches" + assert ( + self.batch_size % self.num_microbatches == 0 + ), "Batch size should divided by the number of microbatches" self.microbatch_size = self.batch_size // self.num_microbatches else: - assert self.batch_size % self.microbatch_size == 0, \ - "Batch size should divided by the microbatch size" + assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size" self.num_microbatches = self.batch_size // self.microbatch_size def load_micro_batch(self) -> Any: @@ -136,12 +138,14 @@ def send_backward(self, input_object: Any, prev_rank: int = None) -> None: if not self.stage_manager.is_first_stage(): self.comm.send_backward(input_object, prev_rank) - def forward_step(self, - model: Module, - input_obj: Optional[dict], - criterion: Callable, - accum_loss: Optional[torch.Tensor] = None, - outputs: Optional[List[Any]] = None) -> Union[torch.Tensor, dict]: + def forward_step( + self, + model: Module, + input_obj: Optional[dict], + criterion: Callable, + accum_loss: Optional[torch.Tensor] = None, + outputs: Optional[List[Any]] = None, + ) -> Union[torch.Tensor, dict]: """Forward one step of the pipeline Args: @@ -159,7 +163,6 @@ def forward_step(self, # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict output_obj = model_forward(model, micro_batch, input_obj) if self.stage_manager.is_last_stage(): - loss = criterion(output_obj, micro_batch) / self.num_microbatches if accum_loss is not None: accum_loss.add_(loss.detach()) @@ -169,8 +172,13 @@ def forward_step(self, else: return output_obj - def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict], - output_obj: Union[dict, torch.Tensor], output_obj_grad: Optional[dict]) -> Optional[dict]: + def backward_step( + self, + optimizer: OptimizerWrapper, + input_obj: Optional[dict], + output_obj: Union[dict, torch.Tensor], + output_obj_grad: Optional[dict], + ) -> Optional[dict]: """Backward one step of the pipeline Args: @@ -208,13 +216,15 @@ def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict], input_obj_grad[k] = v.grad return input_obj_grad - def forward_backward_step(self, - model: Module, - data_iter: Iterable, - criterion: Callable[..., Any], - optimizer: Optional[OptimizerWrapper] = None, - return_loss: bool = False, - return_outputs: bool = False) -> dict: + def forward_backward_step( + self, + model: Module, + data_iter: Iterable, + criterion: Callable[..., Any], + optimizer: Optional[OptimizerWrapper] = None, + return_loss: bool = False, + return_outputs: bool = False, + ) -> dict: """Runs non-interleaved 1F1B schedule, with communication between pipeline stages. Args: @@ -273,7 +283,7 @@ def forward_backward_step(self, # Run 1F1B in steady state. 
for i in range(num_microbatches_remaining): - last_iteration = (i == (num_microbatches_remaining - 1)) + last_iteration = i == (num_microbatches_remaining - 1) output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs) if forward_only: @@ -316,5 +326,5 @@ def forward_backward_step(self, if outputs is not None: if isinstance(model, ModelWrapper): model = model.unwrap() - outputs = merge_batch(outputs, getattr(model, 'batch_size_dim', 0)) - return {'loss': accum_loss, 'outputs': outputs} + outputs = merge_batch(outputs, getattr(model, "batch_size_dim", 0)) + return {"loss": accum_loss, "outputs": outputs} diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index 6ba7dc629958..b79867a2c651 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -1,4 +1,3 @@ -from contextlib import contextmanager from typing import Dict, List, Optional, Tuple import torch.distributed as dist @@ -28,13 +27,11 @@ def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int, is_virtual: bo # init prev and next coord coord = self.pg_mesh.coordinate() # the prev rank of rank0 is the last rank - prev_coord = coord[: self.pipeline_axis] + \ - (coord[self.pipeline_axis] - 1,) + coord[self.pipeline_axis + 1:] - self.prev_rank = self.pg_mesh.ravel(prev_coord, self.pg_mesh.shape, mode='wrap') + prev_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] - 1,) + coord[self.pipeline_axis + 1 :] + self.prev_rank = self.pg_mesh.ravel(prev_coord, self.pg_mesh.shape, mode="wrap") # the next rank of the last rank is rank0 - next_coord = coord[: self.pipeline_axis] + \ - (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1:] - self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode='wrap') + next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :] + self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap") # init p2p process groups stages = list(range(self.num_stages)) diff --git a/colossalai/shardformer/_utils.py b/colossalai/shardformer/_utils.py index c553080de0a0..96d6cea21075 100644 --- a/colossalai/shardformer/_utils.py +++ b/colossalai/shardformer/_utils.py @@ -13,14 +13,14 @@ def get_obj_list_element(obj, attr: str): attr (str): The suffix of the attribute to get """ - re_pattern = r'\[\d+\]' + re_pattern = r"\[\d+\]" prog = re.compile(re_pattern) result = prog.search(attr) if result: matched_brackets = result.group() - matched_index = matched_brackets.replace('[', '') - matched_index = matched_index.replace(']', '') - attr_ = attr.replace(matched_brackets, '') + matched_index = matched_brackets.replace("[", "") + matched_index = matched_index.replace("]", "") + attr_ = attr.replace(matched_brackets, "") container_obj = getattr(obj, attr_) obj = container_obj[int(matched_index)] else: @@ -38,14 +38,14 @@ def set_obj_list_element(obj, attr: str, value): obj (object): The object to set attr (str): the string including a list index like `layers[0]` """ - re_pattern = r'\[\d+\]' + re_pattern = r"\[\d+\]" prog = re.compile(re_pattern) result = prog.search(attr) if result: matched_brackets = result.group() - matched_index = matched_brackets.replace('[', '') - matched_index = matched_index.replace(']', '') - attr_ = attr.replace(matched_brackets, '') + matched_index = matched_brackets.replace("[", "") + matched_index = matched_index.replace("]", "") + attr_ = attr.replace(matched_brackets, "") container_obj = 
getattr(obj, attr_) container_obj[int(matched_index)] = value else: @@ -60,7 +60,7 @@ def hasattr_(obj, attr: str): obj (object): The object to check attr (str): The multi level attr to check """ - attrs = attr.split('.') + attrs = attr.split(".") for a in attrs: try: obj = get_obj_list_element(obj, a) @@ -80,7 +80,7 @@ def setattr_(obj, attr: str, value, ignore: bool = False): ignore (bool): Whether to ignore when the attr doesn't exist """ - attrs = attr.split('.') + attrs = attr.split(".") for a in attrs[:-1]: try: obj = get_obj_list_element(obj, a) @@ -101,7 +101,7 @@ def getattr_(obj, attr: str, ignore: bool = False): ignore (bool): Whether to ignore when the attr doesn't exist """ - attrs = attr.split('.') + attrs = attr.split(".") for a in attrs: try: obj = get_obj_list_element(obj, a) diff --git a/colossalai/shardformer/examples/convergence_benchmark.py b/colossalai/shardformer/examples/convergence_benchmark.py index 81be2017855c..b03e6201dce8 100644 --- a/colossalai/shardformer/examples/convergence_benchmark.py +++ b/colossalai/shardformer/examples/convergence_benchmark.py @@ -7,7 +7,7 @@ import torch.distributed as dist from data import GLUEDataBuilder from torch import nn -from torch.optim import Adam, AdamW, Optimizer +from torch.optim import Adam, Optimizer from torch.utils._pytree import tree_map from torch.utils.data import DataLoader from tqdm import tqdm @@ -15,12 +15,10 @@ import colossalai from colossalai.cluster import DistCoordinator -from colossalai.nn.optimizer import HybridAdam from colossalai.shardformer import ShardConfig, ShardFormer def to_device(x: Any, device: torch.device) -> Any: - def _to(t: Any): if isinstance(t, torch.Tensor): return t.to(device) @@ -34,10 +32,12 @@ def train(args): coordinator = DistCoordinator() # prepare for data and dataset - data_builder = GLUEDataBuilder(model_name_or_path=args.pretrain, - task_name=args.task, - train_batch_size=args.batch_size, - eval_batch_size=args.batch_size) + data_builder = GLUEDataBuilder( + model_name_or_path=args.pretrain, + task_name=args.task, + train_batch_size=args.batch_size, + eval_batch_size=args.batch_size, + ) train_dataloader = data_builder.train_dataloader() test_dataloader = data_builder.test_dataloader() @@ -49,10 +49,10 @@ def train(args): # if multiple GPUs, shard the model if dist.get_world_size() > 1: - tp_group = dist.new_group(backend='nccl') - shard_config = ShardConfig(tensor_parallel_process_group=tp_group, - enable_tensor_parallelism=True, - enable_all_optimization=True) + tp_group = dist.new_group(backend="nccl") + shard_config = ShardConfig( + tensor_parallel_process_group=tp_group, enable_tensor_parallelism=True, enable_all_optimization=True + ) shard_former = ShardFormer(shard_config=shard_config) model, _ = shard_former.optimize(model) @@ -64,21 +64,40 @@ def train(args): num_warmup_steps=math.ceil(max_steps * args.warmup_fraction), num_training_steps=max_steps, ) - fit(model, optim, lr_scheduler, train_dataloader, args.max_epochs, args.accumulation_steps, args.batch_size, - coordinator) - results = evaluate_model(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits, - coordinator) + fit( + model, + optim, + lr_scheduler, + train_dataloader, + args.max_epochs, + args.accumulation_steps, + args.batch_size, + coordinator, + ) + results = evaluate_model( + model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits, coordinator + ) if coordinator.is_master(): print(results) - if args.target_f1 is not None and 'f1' in results: 
- assert results['f1'] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}' - - -def fit(model: nn.Module, optimizer: Optimizer, scheduler, train_dataloader, max_epochs, accumulation_steps, batch_size, - coordinator): - step_bar = tqdm(range(len(train_dataloader) // accumulation_steps * max_epochs), - desc=f'steps', - disable=not coordinator.is_master()) + if args.target_f1 is not None and "f1" in results: + assert results["f1"] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}' + + +def fit( + model: nn.Module, + optimizer: Optimizer, + scheduler, + train_dataloader, + max_epochs, + accumulation_steps, + batch_size, + coordinator, +): + step_bar = tqdm( + range(len(train_dataloader) // accumulation_steps * max_epochs), + desc=f"steps", + disable=not coordinator.is_master(), + ) total_loss = 0 for epoch in range(max_epochs): model.train() @@ -93,19 +112,23 @@ def fit(model: nn.Module, optimizer: Optimizer, scheduler, train_dataloader, max optimizer.step() scheduler.step() optimizer.zero_grad() - step_bar.set_postfix({ - 'epoch': epoch, - 'loss': total_loss / batch_size, - 'lr': scheduler.get_last_lr()[0] - }) + step_bar.set_postfix( + {"epoch": epoch, "loss": total_loss / batch_size, "lr": scheduler.get_last_lr()[0]} + ) total_loss = 0 step_bar.update() # evaluate @torch.no_grad() -def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int, - task_name: str, eval_splits: List[str], coordinator: DistCoordinator): +def evaluate_model( + model: nn.Module, + test_dataloader: Union[DataLoader, List[DataLoader]], + num_labels: int, + task_name: str, + eval_splits: List[str], + coordinator: DistCoordinator, +): metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size) model.eval() @@ -127,7 +150,7 @@ def evaluate_subset(dataloader: DataLoader): results = metric.compute() if coordinator.is_master(): - results['loss'] = accum_loss.item() / (len(dataloader) * dataloader.batch_size) + results["loss"] = accum_loss.item() / (len(dataloader) * dataloader.batch_size) return results if isinstance(test_dataloader, DataLoader): @@ -137,21 +160,21 @@ def evaluate_subset(dataloader: DataLoader): final_results = {} for split, sub_loader in zip(eval_splits, test_dataloader): results = evaluate_subset(sub_loader) - final_results.update({f'{k}_{split}': v for k, v in results.items()}) + final_results.update({f"{k}_{split}": v for k, v in results.items()}) return final_results -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('-t', '--task', default='mrpc', help="GLUE task to run") - parser.add_argument('--model', type=str, default="bert") - parser.add_argument('--pretrain', type=str, default="bert-base-uncased") - parser.add_argument('--max_epochs', type=int, default=1) - parser.add_argument('--batch_size', type=int, default=4) - parser.add_argument('--lr', type=float, default=2.4e-5) - parser.add_argument('--fused_layernorm', type=bool, default=False) - parser.add_argument('--accumulation_steps', type=int, default=8) - parser.add_argument('--warmup_fraction', type=float, default=0.03) - parser.add_argument('--target_f1', type=float, default=None) + parser.add_argument("-t", "--task", default="mrpc", help="GLUE task to run") + parser.add_argument("--model", type=str, default="bert") + parser.add_argument("--pretrain", type=str, default="bert-base-uncased") + 
parser.add_argument("--max_epochs", type=int, default=1) + parser.add_argument("--batch_size", type=int, default=4) + parser.add_argument("--lr", type=float, default=2.4e-5) + parser.add_argument("--fused_layernorm", type=bool, default=False) + parser.add_argument("--accumulation_steps", type=int, default=8) + parser.add_argument("--warmup_fraction", type=float, default=0.03) + parser.add_argument("--target_f1", type=float, default=None) args = parser.parse_args() train(args) diff --git a/colossalai/shardformer/examples/data.py b/colossalai/shardformer/examples/data.py index 6296d4be4eb0..ddf44a874659 100644 --- a/colossalai/shardformer/examples/data.py +++ b/colossalai/shardformer/examples/data.py @@ -6,7 +6,6 @@ class GLUEDataBuilder: - task_text_field_map = { "cola": ["sentence"], "sst2": ["sentence"], @@ -86,14 +85,12 @@ def prepare_data(self): def train_dataloader(self): if self.plugin == None: - return self.native_prepare_dataloader(self.dataset["train"], - batch_size=self.train_batch_size, - shuffle=True, - drop_last=True) - return self.plugin.prepare_dataloader(self.dataset["train"], - batch_size=self.train_batch_size, - shuffle=True, - drop_last=True) + return self.native_prepare_dataloader( + self.dataset["train"], batch_size=self.train_batch_size, shuffle=True, drop_last=True + ) + return self.plugin.prepare_dataloader( + self.dataset["train"], batch_size=self.train_batch_size, shuffle=True, drop_last=True + ) def val_dataloader(self): if self.plugin == None: @@ -118,7 +115,6 @@ def test_dataloader(self): ] def convert_to_features(self, example_batch): - # Either encode single sentence or sentence pairs if len(self.text_fields) > 1: texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]])) @@ -126,10 +122,9 @@ def convert_to_features(self, example_batch): texts_or_text_pairs = example_batch[self.text_fields[0]] # Tokenize the text/text pairs - features = self.tokenizer.batch_encode_plus(texts_or_text_pairs, - max_length=self.max_seq_length, - padding='max_length', - truncation=True) + features = self.tokenizer.batch_encode_plus( + texts_or_text_pairs, max_length=self.max_seq_length, padding="max_length", truncation=True + ) # Rename label to labels to make it easier to pass to model forward features["labels"] = example_batch["label"] @@ -137,10 +132,6 @@ def convert_to_features(self, example_batch): return features def native_prepare_dataloader(self, dataset, batch_size, shuffle=False, drop_last=False, pin_memory=False): - - return DataLoader(dataset, - batch_size=batch_size, - sampler=None, - shuffle=shuffle, - drop_last=drop_last, - pin_memory=pin_memory) + return DataLoader( + dataset, batch_size=batch_size, sampler=None, shuffle=shuffle, drop_last=drop_last, pin_memory=pin_memory + ) diff --git a/colossalai/shardformer/examples/performance_benchmark.py b/colossalai/shardformer/examples/performance_benchmark.py index 2f186709d946..81215dcdf5d4 100644 --- a/colossalai/shardformer/examples/performance_benchmark.py +++ b/colossalai/shardformer/examples/performance_benchmark.py @@ -20,35 +20,35 @@ def data_gen_for_sequence_classification(batch_size, seq_length): # LM data gen # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` data = data_gen(batch_size, seq_length) - data['labels'] = torch.ones((batch_size), dtype=torch.long) + data["labels"] = torch.ones((batch_size), dtype=torch.long) return data -MODEL_CONFIG = transformers.LlamaConfig(num_hidden_layers=4, - hidden_size=128, - 
intermediate_size=256, - num_attention_heads=4, - max_position_embeddings=128, - num_labels=16, - pad_token_id=2) +MODEL_CONFIG = transformers.LlamaConfig( + num_hidden_layers=4, + hidden_size=128, + intermediate_size=256, + num_attention_heads=4, + max_position_embeddings=128, + num_labels=16, + pad_token_id=2, +) BATCH, N_HEADS, N_CTX, D_HEAD = 4, 8, 4096, 64 model_func = lambda: transformers.LlamaForSequenceClassification(MODEL_CONFIG) # vary seq length for fixed head and batch=4 configs = [ - triton.testing.Benchmark(x_names=['N_CTX'], - x_vals=[2**i for i in range(8, 13)], - line_arg='provider', - line_vals=['org_model', 'shard_model'], - line_names=['org_model', 'shard_model'], - styles=[('red', '-'), ('blue', '-')], - ylabel='ms', - plot_name=f'lama_for_sequence_classification-batch-{BATCH}', - args={ - 'BATCH': BATCH, - 'dtype': torch.float16, - 'model_func': model_func - }) + triton.testing.Benchmark( + x_names=["N_CTX"], + x_vals=[2**i for i in range(8, 13)], + line_arg="provider", + line_vals=["org_model", "shard_model"], + line_names=["org_model", "shard_model"], + styles=[("red", "-"), ("blue", "-")], + ylabel="ms", + plot_name=f"lama_for_sequence_classification-batch-{BATCH}", + args={"BATCH": BATCH, "dtype": torch.float16, "model_func": model_func}, + ) ] @@ -85,4 +85,4 @@ def bench_shardformer(BATCH, N_CTX, provider, model_func, dtype=torch.float32, d # torchrun --standalone --nproc_per_node=2 performance_benchmark.py if __name__ == "__main__": colossalai.launch_from_torch({}) - bench_shardformer.run(save_path='.', print_data=dist.get_rank() == 0) + bench_shardformer.run(save_path=".", print_data=dist.get_rank() == 0) diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index c4586d18b90c..a134a2cbd21c 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -7,7 +7,17 @@ from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row __all__ = [ - "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col', - 'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d", - 'FusedLayerNorm', 'FusedRMSNorm', 'FusedLinear1D_Col', 'ParallelModule' + "Embedding1D", + "VocabParallelEmbedding1D", + "Linear1D_Col", + "Linear1D_Row", + "GPT2FusedLinearConv1D_Col", + "GPT2FusedLinearConv1D_Row", + "DropoutForParallelInput", + "DropoutForReplicatedInput", + "cross_entropy_1d", + "FusedLayerNorm", + "FusedRMSNorm", + "FusedLinear1D_Col", + "ParallelModule", ] diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 45b305733813..5ec48096183b 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -1,5 +1,3 @@ -from typing import Any - import torch import torch.distributed as dist import torch.nn.functional as F @@ -22,7 +20,7 @@ class FusedLayerNormAffineFunction1D(torch.autograd.Function): If a single integer is used, it is treated as a singleton list, and this module will normalize over the last dimension which is expected to be of that specific size. 
eps: a value added to the denominator for numerical stability - """ + """ @staticmethod def forward(ctx, input, weight, bias, normalized_shape, eps): @@ -31,8 +29,9 @@ def forward(ctx, input, weight, bias, normalized_shape, eps): input_ = input.contiguous() weight_ = weight.contiguous() bias_ = bias.contiguous() - output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_, - bias_, ctx.eps) + output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine( + input_, ctx.normalized_shape, weight_, bias_, ctx.eps + ) ctx.save_for_backward(input_, weight_, bias_, mean, invvar) return output @@ -40,11 +39,9 @@ def forward(ctx, input, weight, bias, normalized_shape, eps): def backward(ctx, grad_output): input_, weight_, bias_, mean, invvar = ctx.saved_tensors grad_input = grad_weight = grad_bias = None - grad_input, grad_weight, grad_bias \ - = fused_mix_prec_layer_norm_cuda.backward_affine( - grad_output.contiguous(), mean, invvar, - input_, ctx.normalized_shape, - weight_, bias_, ctx.eps) + grad_input, grad_weight, grad_bias = fused_mix_prec_layer_norm_cuda.backward_affine( + grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, weight_, bias_, ctx.eps + ) return grad_input, grad_weight, grad_bias, None, None @@ -195,8 +192,9 @@ def backward(ctx, grad_output): input_list = [ item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) ] - output = torch.empty(input_.shape, dtype=input_parallel.dtype, - device=input_parallel.device).contiguous() + output = torch.empty( + input_.shape, dtype=input_parallel.dtype, device=input_parallel.device + ).contiguous() handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) # Delay the start of weight gradient computation shortly (3us) to have # reduce-scatter scheduled first and have GPU resources allocated @@ -260,8 +258,9 @@ def forward(ctx, input_, process_group, dim): # do reduce-scatter new_shape = list(input_.shape) - assert new_shape[dim] % dist.get_world_size(process_group) == 0, \ - f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). ' + assert ( + new_shape[dim] % dist.get_world_size(process_group) == 0 + ), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). " new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group) input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)] output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device) @@ -329,8 +328,9 @@ def backward(ctx, grad_output): input_list = [ item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) ] - output = torch.empty(input_.shape, dtype=input_parallel.dtype, - device=input_parallel.device).contiguous() + output = torch.empty( + input_.shape, dtype=input_parallel.dtype, device=input_parallel.device + ).contiguous() handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) # Delay the start of weight gradient computation shortly (3us) to have # reduce-scatter scheduled first and have GPU resources allocated @@ -473,9 +473,10 @@ def _split(input_, dim=-1, process_group=None): # Split along last dimension. 
dim_size = input_.size(dim) - assert dim_size % world_size == 0, \ - f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \ - f'cannot split tensor evenly' + assert dim_size % world_size == 0, ( + f"The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), " + f"cannot split tensor evenly" + ) tensor_list = torch.split(input_, dim_size // world_size, dim=dim) rank = dist.get_rank(process_group) @@ -502,7 +503,7 @@ def _gather(input_, dim=-1, process_group=None): def _reduce_scatter(input_, dim=1, process_group=None): - """ Do reduce-scatter operation. + """Do reduce-scatter operation. Args: input_ (`torch.Tensor`): The input tensor from sequence parallel region. @@ -515,8 +516,9 @@ def _reduce_scatter(input_, dim=1, process_group=None): # reduce-scatter new_shape = list(input_.shape) - assert new_shape[dim] % dist.get_world_size(process_group) == 0, \ - f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). ' + assert ( + new_shape[dim] % dist.get_world_size(process_group) == 0 + ), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). " new_shape[dim] = new_shape[dim] // world_size output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device) dist.reduce_scatter(output, input_, group=process_group) @@ -532,20 +534,24 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) -def linear_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim, - overlap): - return _LinearWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group, - async_grad_reduce_scatter, dim, overlap) +def linear_gather_forward_reducescatter_backward( + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap +): + return _LinearWithGatherForwardReduceScatterBackward.apply( + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap + ) def linear_reducescatter_forward_gather_backward(input_, process_group, dim): return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim) -def matmul_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim, - overlap): - return _MatmulWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group, - async_grad_reduce_scatter, dim, overlap) +def matmul_gather_forward_reducescatter_backward( + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap +): + return _MatmulWithGatherForwardReduceScatterBackward.apply( + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap + ) def gather_forward_split_backward(input_, dim, process_group): diff --git a/colossalai/shardformer/layer/dropout.py b/colossalai/shardformer/layer/dropout.py index 2625fe97889a..8771913ee62f 100644 --- a/colossalai/shardformer/layer/dropout.py +++ b/colossalai/shardformer/layer/dropout.py @@ -7,7 +7,7 @@ from .parallel_module import ParallelModule from .utils import create_randomizer_with_offset -__all__ = ['DropoutForParallelInput', 'DropoutForReplicatedInput'] +__all__ = ["DropoutForParallelInput", "DropoutForReplicatedInput"] class DropoutForParallelInput(ParallelModule, nn.Dropout): @@ -31,8 +31,9 @@ def __init__(self, p: float = 0.5, 
inplace: bool = False, process_group: Process self.randomizer = create_randomizer_with_offset(seed, process_group=process_group) @staticmethod - def from_native_module(module: nn.Dropout, - process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "DropoutForParallelInput": + def from_native_module( + module: nn.Dropout, process_group: Union[ProcessGroup, List[ProcessGroup]] = None + ) -> "DropoutForParallelInput": """ Create a DropoutForParallelInput layer from a native dropout layer. """ @@ -68,8 +69,8 @@ def __init__(self, p: float = 0.5, inplace: bool = False, process_group: Process @staticmethod def from_native_module( - module: nn.Dropout, - process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "DropoutForReplicatedInput": + module: nn.Dropout, process_group: Union[ProcessGroup, List[ProcessGroup]] = None + ) -> "DropoutForReplicatedInput": """ Create a Dropout1D layer from a native dropout layer. """ diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py index 847ca175ad57..62163cb009aa 100644 --- a/colossalai/shardformer/layer/embedding.py +++ b/colossalai/shardformer/layer/embedding.py @@ -24,7 +24,7 @@ from .parallel_module import ParallelModule from .utils import create_randomizer_with_offset -__all__ = ['Embedding1D', 'VocabParallelEmbedding1D'] +__all__ = ["Embedding1D", "VocabParallelEmbedding1D"] class Embedding1D(ParallelModule): @@ -57,18 +57,20 @@ class Embedding1D(ParallelModule): `init `_ """ - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - device: torch.device = None, - process_group: ProcessGroup = None, - gather_output: bool = True, - weight: Optional[nn.Parameter] = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + gather_output: bool = True, + weight: Optional[nn.Parameter] = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs, + ): super().__init__() self.num_embeddings = num_embeddings @@ -86,7 +88,7 @@ def __init__(self, # Parameters. if weight is None: - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)) else: weight.data = weight.data.to(device=device, dtype=dtype) @@ -100,10 +102,9 @@ def __init__(self, self.reset_parameters(weight_initializer) @staticmethod - def from_native_module(module: nn.Embedding, - process_group: Union[ProcessGroup, List[ProcessGroup]] = None, - *args, - **kwargs) -> "Embedding1D": + def from_native_module( + module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]] = None, *args, **kwargs + ) -> "Embedding1D": r""" Build a 1D parallelized Embedding from a native nn.Embedding module. 
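For orientation, the sketch below simulates on a single process the partitioning that `Embedding1D` applies: the embedding dimension is sharded across ranks and `gather_output` concatenates the slices back. The toy sizes and the two-way split are assumptions for illustration; the real module works on a `ProcessGroup` with distributed tensors.

```python
# Single-process illustration (assumed tp_size=2, toy sizes) of sharding the
# embedding dimension, which is what Embedding1D's gather_output undoes.
import torch
import torch.nn.functional as F

vocab_size, embedding_dim, tp_size = 10, 8, 2
weight = torch.randn(vocab_size, embedding_dim)
input_ids = torch.tensor([[1, 3, 5], [7, 0, 2]])

# Each "rank" keeps a (vocab_size, embedding_dim // tp_size) slice of the table.
shards = torch.chunk(weight, tp_size, dim=-1)
partials = [F.embedding(input_ids, shard) for shard in shards]

# gather_output=True corresponds to concatenating the slices along the last dim.
gathered = torch.cat(partials, dim=-1)
assert torch.allclose(gathered, F.embedding(input_ids, weight))
```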
""" @@ -123,19 +124,21 @@ def from_native_module(module: nn.Embedding, if sparse: raise NotImplementedError("The Embedding1D module does not support sparse embedding yet.") - embedding = Embedding1D(num_embeddings=num_embedding, - embedding_dim=embedding_dim, - padding_idx=padding_idx, - process_group=process_group, - dtype=dtype, - device=device, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse, - weight=module.weight, - *args, - **kwargs) + embedding = Embedding1D( + num_embeddings=num_embedding, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + process_group=process_group, + dtype=dtype, + device=device, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + weight=module.weight, + *args, + **kwargs, + ) return embedding @@ -188,17 +191,19 @@ class VocabParallelEmbedding1D(ParallelModule): `init `_. """ - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - device: torch.device = None, - process_group: ProcessGroup = None, - weight: Optional[nn.Parameter] = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + weight: Optional[nn.Parameter] = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs, + ): super().__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim @@ -223,7 +228,7 @@ def __init__(self, # parameter if weight is None: - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)) else: weight.data = weight.data.to(device=device, dtype=dtype) @@ -236,8 +241,9 @@ def __init__(self, self.reset_parameters(weight_initializer) @staticmethod - def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, - **kwargs) -> ParallelModule: + def from_native_module( + module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: r""" Convert a native pytorch embedding module to a parallel module. """ @@ -250,19 +256,20 @@ def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, # ensure only one process group is used if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, \ - f'Expected only one process group, got {len(process_group)}.' + assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." 
process_group = process_group[0] # create the parallel module - vocab_embedding_1d = VocabParallelEmbedding1D(num_embeddings=num_embeddings, - embedding_dim=embedding_dim, - padding_idx=padding_idx, - device=device, - process_group=process_group, - weight=module.weight, - *args, - **kwargs) + vocab_embedding_1d = VocabParallelEmbedding1D( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + device=device, + process_group=process_group, + weight=module.weight, + *args, + **kwargs, + ) return vocab_embedding_1d @@ -273,8 +280,11 @@ def reset_parameters(self, weight_initializer) -> None: self._fill_padding_idx_with_zero() def _fill_padding_idx_with_zero(self) -> None: - if self.padding_idx is not None and \ - self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index: + if ( + self.padding_idx is not None + and self.padding_idx >= self.vocab_start_index + and self.padding_idx < self.vocab_end_index + ): with torch.no_grad(): self.weight[self.padding_idx - self.vocab_start_index].fill_(0) @@ -294,11 +304,12 @@ def forward(self, input_: Tensor) -> Tensor: masked_input = input_.clone() - self.vocab_start_index masked_input[input_mask] = 0 - output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, - **self.embed_kwargs) + output_parallel = F.embedding( + masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs + ) # Mask the output embedding. - output_parallel[input_mask, :] = 0. + output_parallel[input_mask, :] = 0.0 # Reduce across all the model parallel GPUs. output = reduce_forward(output_parallel, self.process_group) return output diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 111d51b3f8d8..cf2003877d3c 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -33,7 +33,7 @@ from .parallel_module import ParallelModule from .utils import create_randomizer_with_offset -__all__ = ['Linear1D_Col', 'Linear1D_Row'] +__all__ = ["Linear1D_Col", "Linear1D_Row"] class Linear1D_Col(ParallelModule): @@ -65,22 +65,24 @@ class Linear1D_Col(ParallelModule): `init `_. 
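The `Linear1D_Col` docstring above describes column-style tensor parallelism: each rank holds a slice of the output dimension and, with `gather_output=True`, the slices are all-gathered back. A single-process sketch of that split, with hypothetical sizes and a simulated tp_size of 2 instead of a real process group:

```python
# Column-parallel linear simulated on one process; names and sizes are illustrative.
import torch
import torch.nn.functional as F

in_features, out_features, tp_size = 8, 6, 2
x = torch.randn(4, in_features)
weight = torch.randn(out_features, in_features)   # full (unsharded) weight
bias = torch.randn(out_features)

# Each rank keeps a slice of the output dimension (the "columns" of y = x W^T + b).
weight_shards = torch.chunk(weight, tp_size, dim=0)
bias_shards = torch.chunk(bias, tp_size, dim=0)

# Every rank sees the full input and produces its slice of the output.
partial_outputs = [F.linear(x, w, b) for w, b in zip(weight_shards, bias_shards)]

# gather_output=True corresponds to an all-gather along the last dimension.
gathered = torch.cat(partial_outputs, dim=-1)
assert torch.allclose(gathered, F.linear(x, weight, bias), atol=1e-5)
```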
""" - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - device: torch.device = None, - process_group: ProcessGroup = None, - gather_output: bool = False, - seq_parallel: bool = False, - seq_parallel_dim: int = 1, - overlap: torch.cuda.Stream = None, - skip_bias_add: bool = False, - weight: Optional[Parameter] = None, - bias_: Optional[Parameter] = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + gather_output: bool = False, + seq_parallel: bool = False, + seq_parallel_dim: int = 1, + overlap: torch.cuda.Stream = None, + skip_bias_add: bool = False, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): super().__init__() # Keep input parameters @@ -95,7 +97,7 @@ def __init__(self, self.process_group = process_group if skip_bias_add and not bias: - raise ValueError('cannot skip bias addition if bias is None') + raise ValueError("cannot skip bias addition if bias is None") # offset the seed with randomizer index and rank seed = torch.random.initial_seed() @@ -103,13 +105,13 @@ def __init__(self, # sanity check if weight is not None: - assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None' + assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None" else: - assert bias_ is None, 'bias_ must be None if weight is None' + assert bias_ is None, "bias_ must be None if weight is None" # Parameters. if weight is None: - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) else: weight.data = weight.data.to(device=device, dtype=dtype) @@ -135,8 +137,9 @@ def __init__(self, self.reset_parameters(weight_initializer, bias_initializer) @staticmethod - def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, - **kwargs) -> ParallelModule: + def from_native_module( + module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: r""" Convert a native PyTorch linear layer to a parallelized linear layer. """ @@ -149,8 +152,7 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis # ensure only one process group is passed if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, \ - f'Expected only one process group, got {len(process_group)}.' + assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." 
process_group = process_group[0] tp_size = dist.get_world_size(process_group) @@ -159,17 +161,20 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis if out_features % tp_size != 0: raise ValueError( - f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!") - - linear_1d = Linear1D_Col(in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - process_group=process_group, - weight=module.weight, - bias_=module.bias, - *args, - **kwargs) + f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!" + ) + + linear_1d = Linear1D_Col( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + weight=module.weight, + bias_=module.bias, + *args, + **kwargs, + ) return linear_1d @@ -181,9 +186,11 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: bias_initializer(self.bias, fan_in=fan_in) def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) + assert ( + input_.shape[-1] == self.weight.shape[-1] + ), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] + ) # Set up backprop all-reduce. input_parallel = input_ @@ -191,9 +198,9 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: # Matrix multiply. bias = self.bias if not self.skip_bias_add else None if self.seq_parallel: - output_parallel = linear_gather_forward_reducescatter_backward(input_parallel, self.weight, bias, - self.process_group, True, - self.seq_parallel_dim, self.overlap) + output_parallel = linear_gather_forward_reducescatter_backward( + input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap + ) else: output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) @@ -210,7 +217,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: class Linear1D_Row(ParallelModule): - r""" Linear layer with row parallelism + r"""Linear layer with row parallelism Args: in_features (int): size of each input sample. @@ -231,22 +238,24 @@ class Linear1D_Row(ParallelModule): `init `_. 
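`Linear1D_Row` is the complementary row-parallel scheme: the input features and the matching slice of the weight are sharded, each rank computes a partial product, and the all-reduce in `forward` sums them before the bias is added once. A single-process sketch under the same toy assumptions as above:

```python
# Row-parallel linear simulated on one process; the stacked sum stands in for
# the all-reduce that the real layer issues over its ProcessGroup.
import torch
import torch.nn.functional as F

in_features, out_features, tp_size = 8, 6, 2
x = torch.randn(4, in_features)
weight = torch.randn(out_features, in_features)
bias = torch.randn(out_features)

# The input feature dimension and the matching weight columns are sharded.
x_shards = torch.chunk(x, tp_size, dim=-1)
weight_shards = torch.chunk(weight, tp_size, dim=-1)

# Each rank produces a partial result; summing them plays the role of all-reduce.
partials = [F.linear(xs, ws) for xs, ws in zip(x_shards, weight_shards)]
reduced = torch.stack(partials).sum(dim=0) + bias   # bias added once, after the reduce
assert torch.allclose(reduced, F.linear(x, weight, bias), atol=1e-5)
```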
""" - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - device: torch.device = None, - process_group: ProcessGroup = None, - seq_parallel: bool = False, - seq_parallel_dim: int = 1, - parallel_input: bool = True, - skip_bias_add: bool = False, - weight: Optional[Parameter] = None, - bias_: Optional[Parameter] = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - stream_chunk_num: int = 1): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + seq_parallel: bool = False, + seq_parallel_dim: int = 1, + parallel_input: bool = True, + skip_bias_add: bool = False, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + stream_chunk_num: int = 1, + ): super().__init__() self.stream_chunk_num = stream_chunk_num @@ -262,7 +271,7 @@ def __init__(self, self.num_partitions = dist.get_world_size(self.process_group) if skip_bias_add and not bias: - raise ValueError('cannot skip bias addition if bias is None') + raise ValueError("cannot skip bias addition if bias is None") # offset the seed with randomizer index and rank seed = torch.random.initial_seed() @@ -270,14 +279,14 @@ def __init__(self, # sanity check if weight is not None: - assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None' + assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None" else: - assert bias_ is None, 'bias_ must be None if weight is None' + assert bias_ is None, "bias_ must be None if weight is None" # Parameters. if weight is None: # Initialize weight. - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) else: weight.data = weight.data.to(device=device, dtype=dtype) @@ -304,8 +313,9 @@ def __init__(self, self.reset_parameters(weight_initializer, bias_initializer) @staticmethod - def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, - **kwargs) -> ParallelModule: + def from_native_module( + module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: r""" Convert a native PyTorch linear layer to a parallelized linear layer. """ @@ -318,8 +328,7 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis # ensure only one process group is passed if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, \ - f'Expected only one process group, got {len(process_group)}.' + assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." 
process_group = process_group[0] tp_size = dist.get_world_size(process_group) @@ -328,17 +337,20 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis if in_features % tp_size != 0: raise ValueError( - f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!") - - linear_1d = Linear1D_Row(in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - process_group=process_group, - weight=module.weight, - bias_=module.bias, - *args, - **kwargs) + f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!" + ) + + linear_1d = Linear1D_Row( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + weight=module.weight, + bias_=module.bias, + *args, + **kwargs, + ) return linear_1d @@ -366,14 +378,18 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: def forward(self, input_: Tensor) -> Tensor: # Set up backprop all-reduce. if self.parallel_input: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) + assert ( + input_.shape[-1] == self.weight.shape[-1] + ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] + ) input_ = input_ else: - assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions) + assert ( + divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1] + ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. 
Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions + ) input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group) if self.stream_chunk_num > 1: @@ -384,9 +400,9 @@ def forward(self, input_: Tensor) -> Tensor: handle_list = [] for i in range(self.stream_chunk_num): output_parallel_list[i] = F.linear(input_, self.weight_list[i]) - handle = torch.distributed.all_reduce(output_parallel_list[i], - group=self.process_group, - async_op=True) + handle = torch.distributed.all_reduce( + output_parallel_list[i], group=self.process_group, async_op=True + ) handle_list.append(handle) # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) for handle in handle_list: @@ -395,8 +411,9 @@ def forward(self, input_: Tensor) -> Tensor: else: output_parallel = F.linear(input_, self.weight) if self.seq_parallel: - output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, - self.seq_parallel_dim) + output = linear_reducescatter_forward_gather_backward( + output_parallel, self.process_group, self.seq_parallel_dim + ) else: output = reduce_forward(output_parallel, self.process_group) diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index 7e3f6926b6d4..848e4a3a1f7d 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -3,7 +3,7 @@ from torch.autograd import Function from torch.distributed import ProcessGroup -__all__ = ['DistCrossEntropy', 'cross_entropy_1d'] +__all__ = ["DistCrossEntropy", "cross_entropy_1d"] class DistCrossEntropy(Function): @@ -61,8 +61,9 @@ def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: masked_target_1d = masked_target.view(-1) # extract the x[class] and set the x[other device] to zero - pred_logits_1d = logits_2d[torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device), - masked_target_1d] + pred_logits_1d = logits_2d[ + torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device), masked_target_1d + ] pred_logits_1d = pred_logits_1d.clone().contiguous() pred_logits = pred_logits_1d.view_as(target) pred_logits[mask] = 0.0 @@ -102,8 +103,7 @@ def backward(ctx, grad_output): return grad_logits, None, None -def cross_entropy_1d(vocab_logits: torch.Tensor, - labels: torch.Tensor, - ignore_index: int = -100, - process_group: ProcessGroup = None) -> torch.Tensor: +def cross_entropy_1d( + vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100, process_group: ProcessGroup = None +) -> torch.Tensor: return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group) diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index 0aea295664a7..19b973be8679 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -1,28 +1,49 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -import torch import torch.nn as nn from colossalai.lazy import LazyInitContext -__all__ = ['FusedLayerNorm', 'FusedRMSNorm'] +__all__ = ["FusedLayerNorm", "FusedRMSNorm"] FAST_LAYERNORM_SUPPORTED_SIZE = [ - 1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, 24576, - 25600, 30720, 32768, 40960, 49152, 65536 + 1024, + 1536, + 2048, + 2304, + 3072, + 3840, + 4096, + 5120, + 6144, + 8192, + 10240, + 12288, + 12800, + 15360, + 16384, + 18432, 
+ 20480, + 24576, + 25600, + 30720, + 32768, + 40960, + 49152, + 65536, ] -class FusedLayerNorm(): +class FusedLayerNorm: r""" This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface. """ def __init__(self) -> None: raise NotImplementedError( - 'FusedLayerNorm is not implemented as a physical class. ' - 'It is meant to be used only with the from_native_module interface to wrap the fused layernorm implementation provided by apex.' + "FusedLayerNorm is not implemented as a physical class. " + "It is meant to be used only with the from_native_module interface to wrap the fused layernorm implementation provided by apex." ) @staticmethod @@ -32,10 +53,11 @@ def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module: """ # check if apex is installed try: - import apex + pass except ImportError: raise ImportError( - 'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel') + "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel" + ) LazyInitContext.materialize(module) # get the attributes of the module @@ -57,23 +79,24 @@ def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module: else: from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm - layernorm = ApexFusedLayerNorm(normalized_shape, eps=eps, - elementwise_affine=elementwise_affine).to(dtype).to(device) + layernorm = ( + ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device) + ) layernorm.weight = module.weight layernorm.bias = module.bias return layernorm -class FusedRMSNorm(): +class FusedRMSNorm: """ This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface. """ def __init__(self) -> None: raise NotImplementedError( - 'FusedRMSNorm is not implemented as a physical class. ' - 'It is meant to be used only with the from_native_module interface to wrap the fused rms norm implementation provided by apex.' + "FusedRMSNorm is not implemented as a physical class. " + "It is meant to be used only with the from_native_module interface to wrap the fused rms norm implementation provided by apex." 
) @staticmethod @@ -82,7 +105,7 @@ def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module: from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm except ImportError: raise ImportError( - 'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel' + "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel" ) LazyInitContext.materialize(module) diff --git a/colossalai/shardformer/layer/parallel_module.py b/colossalai/shardformer/layer/parallel_module.py index 4f391920e29b..6c0d83cc7a20 100644 --- a/colossalai/shardformer/layer/parallel_module.py +++ b/colossalai/shardformer/layer/parallel_module.py @@ -19,18 +19,16 @@ is_customized_distributed_tensor, is_distributed_tensor, sharded_tensor_to_param, - to_global, - to_global_for_customized_distributed_tensor, ) -__all__ = ['ParallelModule'] +__all__ = ["ParallelModule"] class ParallelModule(nn.Module, ABC): - @abstractmethod - def from_native_module(module: nn.Module, - process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "ParallelModule": + def from_native_module( + module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]] = None + ) -> "ParallelModule": """ Convert a native PyTorch module to a parallelized module. @@ -40,7 +38,6 @@ def from_native_module(module: nn.Module, If this is a list, the process group at the ith index of the list will correspond to the process group in the ith axis of the device mesh. Defaults to None, which means the global process group. """ - pass def _save_to_state_dict(self, destination, prefix, keep_vars): r"""Saves module state to `destination` dictionary, containing a state @@ -66,8 +63,9 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state: destination[extra_state_key] = self.get_extra_state() - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs): + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): r"""Copies parameters and buffers from :attr:`state_dict` into only this module, but not its descendants. This is called on every submodule in :meth:`~torch.nn.Module.load_state_dict`. 
Metadata saved for this @@ -112,9 +110,11 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss if key in state_dict: input_param = state_dict[key] if not torch.overrides.is_tensor_like(input_param): - error_msgs.append('While copying the parameter named "{}", ' - 'expected torch.Tensor or Tensor-like object from checkpoint but ' - 'received {}'.format(key, type(input_param))) + error_msgs.append( + 'While copying the parameter named "{}", ' + "expected torch.Tensor or Tensor-like object from checkpoint but " + "received {}".format(key, type(input_param)) + ) continue if is_distributed_tensor(param): @@ -136,19 +136,22 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss if not is_param_lazy and input_param.shape != param.shape: # local shape should match the one in checkpoint - error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' - 'the shape in current model is {}.'.format(key, input_param.shape, param.shape)) + error_msgs.append( + "size mismatch for {}: copying a param with shape {} from checkpoint, " + "the shape in current model is {}.".format(key, input_param.shape, param.shape) + ) continue try: with torch.no_grad(): param.copy_(input_param) except Exception as ex: - error_msgs.append('While copying the parameter named "{}", ' - 'whose dimensions in the model are {} and ' - 'whose dimensions in the checkpoint are {}, ' - 'an exception occurred : {}.'.format(key, param.size(), input_param.size(), - ex.args)) + error_msgs.append( + 'While copying the parameter named "{}", ' + "whose dimensions in the model are {} and " + "whose dimensions in the checkpoint are {}, " + "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args) + ) elif strict: missing_keys.append(key) @@ -164,7 +167,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss if strict: for key in state_dict.keys(): if key.startswith(prefix) and key != extra_state_key: - input_name = key[len(prefix):] - input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child + input_name = key[len(prefix) :] + input_name = input_name.split(".", 1)[0] # get the name of param/buffer/child if input_name not in self._modules and input_name not in local_state: unexpected_keys.append(key) diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index 5ce77805f9b8..12476d050600 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -36,17 +36,16 @@ from .parallel_module import ParallelModule from .utils import create_randomizer_with_offset -__all__ = ['FusedLinear1D_Col', 'FusedLinear1D_Row', 'GPT2FusedLinearConv1D_Col', 'GPT2FusedLinearConv1D_Row'] +__all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col", "GPT2FusedLinearConv1D_Row"] # ==================================== # For GPT Only # ==================================== -def split_fused_qkv_in_gpt2_style(qkv: torch.Tensor, - n_fused: int, - process_group: ProcessGroup, - is_transposed: bool = False): +def split_fused_qkv_in_gpt2_style( + qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup, is_transposed: bool = False +): """ The fused qkv tensor looks like [Q1, Q2, K1, K2, V1, V2], this function will split them into [Q1, K1, V1] and [Q2, K2, V2]. 
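The splitting rule described above can be pictured with a toy tensor: a fused [Q1, Q2, K1, K2, V1, V2] layout is regrouped into per-rank [Q_r, K_r, V_r] slices. The sketch below does this on a single process; the sizes, the two-way split, and the column layout are illustrative assumptions (the real function additionally handles `is_transposed` and a `ProcessGroup`).

```python
# Toy regrouping of a fused QKV weight in gpt2 style, simulated for two ranks.
import torch

hidden_size, n_fused, tp_size = 4, 3, 2   # n_fused=3 for Q, K, V
# A toy fused Conv1D-style weight: columns laid out as [Q | K | V].
fused = torch.arange(n_fused * hidden_size).repeat(2, 1).float()  # shape (2, 12)

# Split into Q, K, V, shard each along the column dim, then regroup per rank:
# rank r keeps [Q_r | K_r | V_r].
q, k, v = torch.chunk(fused, n_fused, dim=-1)
per_rank = [
    torch.cat([t.chunk(tp_size, dim=-1)[r] for t in (q, k, v)], dim=-1)
    for r in range(tp_size)
]
print(per_rank[0])  # columns [0, 1, 4, 5, 8, 9]   -> Q1, K1, V1
print(per_rank[1])  # columns [2, 3, 6, 7, 10, 11] -> Q2, K2, V2
```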
@@ -85,10 +84,9 @@ def split_fused_qkv_in_gpt2_style(qkv: torch.Tensor, return weight_of_current_rank -def gather_fused_qkv_in_gpt2_style(qkv: torch.Tensor, - n_fused: int, - process_group: ProcessGroup, - is_transposed: bool = False): +def gather_fused_qkv_in_gpt2_style( + qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup, is_transposed: bool = False +): """ The splitted qkv tensor looks like [Q1, K1, V1] and [Q2, K2, V2], this function will gather them into [Q1, Q2, K1, K2, V1, V2]. @@ -167,23 +165,25 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): `init `_. """ - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - device: torch.device = None, - process_group: ProcessGroup = None, - async_communication: bool = False, - gather_output: bool = False, - seq_parallel: bool = False, - overlap: bool = False, - skip_bias_add: bool = False, - n_fused: int = 3, - weight: Optional[Parameter] = None, - bias_: Optional[Parameter] = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + async_communication: bool = False, + gather_output: bool = False, + seq_parallel: bool = False, + overlap: bool = False, + skip_bias_add: bool = False, + n_fused: int = 3, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): super().__init__() # Keep input parameters @@ -199,7 +199,7 @@ def __init__(self, self.async_communication = async_communication if skip_bias_add and not bias: - raise ValueError('cannot skip bias addition if bias is None') + raise ValueError("cannot skip bias addition if bias is None") # offset the seed with randomizer index and rank seed = torch.random.initial_seed() @@ -207,14 +207,14 @@ def __init__(self, # sanity check if weight is not None: - assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None' + assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None" else: - assert bias_ is None, 'bias_ must be None if weight is None' + assert bias_ is None, "bias_ must be None if weight is None" # Parameters. if weight is None: # Initialize weight. - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} self.weight = Parameter(torch.empty(self.in_features, self.out_features, **factory_kwargs)) else: weight.data = weight.data.to(device=device, dtype=dtype) @@ -249,8 +249,9 @@ def gather_fn(tensor): self.reset_parameters(weight_initializer, bias_initializer) @staticmethod - def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, - **kwargs) -> ParallelModule: + def from_native_module( + module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: r""" Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer. 
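One reason gpt2 gets dedicated classes here: HuggingFace's `Conv1D` stores its weight as `(in_features, out_features)`, transposed relative to `nn.Linear`, and computes `x @ weight + bias`. A short sketch with illustrative sizes, assuming the `transformers.pytorch_utils` import path used in recent `transformers` releases:

```python
# HuggingFace Conv1D vs nn.Linear weight layout, shown with toy sizes.
import torch
from transformers.pytorch_utils import Conv1D

in_features, out_features = 8, 6
conv = Conv1D(out_features, in_features)   # Conv1D(nf, nx): weight has shape (nx, nf)
x = torch.randn(4, in_features)

out = conv(x)                              # computes x @ weight + bias
manual = x @ conv.weight + conv.bias
assert torch.allclose(out, manual)
print(conv.weight.shape)                   # torch.Size([8, 6]) -> (in, out), not (out, in)
```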
@@ -268,8 +269,7 @@ def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, Lis # ensure only one process group is passed if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, \ - f'Expected only one process group, got {len(process_group)}.' + assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." process_group = process_group[0] tp_size = dist.get_world_size(process_group) @@ -278,17 +278,20 @@ def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, Lis if out_features % tp_size != 0: raise ValueError( - f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!") - - linear_1d = GPT2FusedLinearConv1D_Col(in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - process_group=process_group, - weight=module.weight, - bias_=module.bias, - *args, - **kwargs) + f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!" + ) + + linear_1d = GPT2FusedLinearConv1D_Col( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + weight=module.weight, + bias_=module.bias, + *args, + **kwargs, + ) return linear_1d @@ -300,22 +303,26 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: bias_initializer(self.bias, fan_in=fan_in) def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: - assert input_.shape[-1] == self.weight.shape[0], \ - 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) + assert ( + input_.shape[-1] == self.weight.shape[0] + ), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] + ) # Matrix multiply. bias = self.bias if not self.skip_bias_add else None if self.seq_parallel: input_parallel = input_ - output_parallel = matmul_gather_forward_reducescatter_backward(input_parallel, self.weight, bias, - self.process_group, True, 1, self.overlap) + output_parallel = matmul_gather_forward_reducescatter_backward( + input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap + ) else: # Set up backprop all-reduce. input_parallel = reduce_backward(input_, self.process_group) - output_parallel = matmul_with_async_comm(input_parallel, self.weight, bias, self.process_group, - self.async_communication) + output_parallel = matmul_with_async_comm( + input_parallel, self.weight, bias, self.process_group, self.async_communication + ) if self.gather_output: # All-gather across the partitions. @@ -330,7 +337,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: class GPT2FusedLinearConv1D_Row(ParallelModule): - r""" Linear layer with row parallelism. + r"""Linear layer with row parallelism. This layer is used to fit `Conv1D` layer (Fused QKV) in gpt2 of huggingface. Args: @@ -351,21 +358,23 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): `init `_. 
""" - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - device: torch.device = None, - process_group: ProcessGroup = None, - seq_parallel: bool = False, - parallel_input: bool = True, - skip_bias_add: bool = False, - weight: Optional[Parameter] = None, - bias_: Optional[Parameter] = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - stream_chunk_num: int = 1): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + seq_parallel: bool = False, + parallel_input: bool = True, + skip_bias_add: bool = False, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + stream_chunk_num: int = 1, + ): super().__init__() self.stream_chunk_num = stream_chunk_num @@ -380,7 +389,7 @@ def __init__(self, self.num_partitions = dist.get_world_size(self.process_group) if skip_bias_add and not bias: - raise ValueError('cannot skip bias addition if bias is None') + raise ValueError("cannot skip bias addition if bias is None") # offset the seed with randomizer index and rank seed = torch.random.initial_seed() @@ -391,14 +400,14 @@ def __init__(self, # sanity check if weight is not None: - assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None' + assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None" else: - assert bias_ is None, 'bias_ must be None if weight is None' + assert bias_ is None, "bias_ must be None if weight is None" # Parameters. if weight is None: # Initialize weight. - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} self.weight = Parameter(torch.empty(self.in_features, self.out_features, **factory_kwargs)) else: weight.data = weight.data.to(device=device, dtype=dtype) @@ -424,8 +433,9 @@ def __init__(self, self.reset_parameters(weight_initializer, bias_initializer) @staticmethod - def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, - **kwargs) -> ParallelModule: + def from_native_module( + module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: r""" Convert a native PyTorch linear layer to a parallelized linear layer. """ @@ -438,8 +448,7 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis # ensure only one process group is passed if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, \ - f'Expected only one process group, got {len(process_group)}.' + assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." 
process_group = process_group[0] tp_size = dist.get_world_size(process_group) @@ -448,17 +457,20 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis if in_features % tp_size != 0: raise ValueError( - f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!") - - linear_1d = GPT2FusedLinearConv1D_Row(in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - process_group=process_group, - weight=module.weight, - bias_=module.bias, - *args, - **kwargs) + f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!" + ) + + linear_1d = GPT2FusedLinearConv1D_Row( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + weight=module.weight, + bias_=module.bias, + *args, + **kwargs, + ) return linear_1d @@ -485,14 +497,18 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: def forward(self, input_: Tensor) -> Tensor: # Set up backprop all-reduce. if self.parallel_input: - assert input_.shape[-1] == self.weight.shape[0], \ - 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[0]) + assert ( + input_.shape[-1] == self.weight.shape[0] + ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[0] + ) input_ = input_ else: - assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[0], \ - 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[0] * self.num_partitions) + assert ( + divide(input_.shape[-1], self.num_partitions) == self.weight.shape[0] + ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[0] * self.num_partitions + ) input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group) if self.stream_chunk_num > 1: @@ -503,9 +519,9 @@ def forward(self, input_: Tensor) -> Tensor: handle_list = [] for i in range(self.stream_chunk_num): output_parallel_list[i] = torch.matmul(input_, self.weight_list[i]) - handle = torch.distributed.all_reduce(output_parallel_list[i], - group=self.process_group, - async_op=True) + handle = torch.distributed.all_reduce( + output_parallel_list[i], group=self.process_group, async_op=True + ) handle_list.append(handle) # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) for handle in handle_list: @@ -559,21 +575,23 @@ class FusedLinear1D_Col(ParallelModule): `init `_. 
""" - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - device: torch.device = None, - process_group: ProcessGroup = None, - async_communication: bool = False, - gather_output: bool = False, - skip_bias_add: bool = False, - n_fused: int = 3, - weight: Optional[Parameter] = None, - bias_: Optional[Parameter] = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + async_communication: bool = False, + gather_output: bool = False, + skip_bias_add: bool = False, + n_fused: int = 3, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): super().__init__() # Keep input parameters self.in_features = in_features @@ -586,7 +604,7 @@ def __init__(self, self.async_communication = async_communication if skip_bias_add and not bias: - raise ValueError('cannot skip bias addition if bias is None') + raise ValueError("cannot skip bias addition if bias is None") # offset the seed with randomizer index and rank seed = torch.random.initial_seed() @@ -594,14 +612,14 @@ def __init__(self, # sanity check if weight is not None: - assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None' + assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None" else: - assert bias_ is None, 'bias_ must be None if weight is None' + assert bias_ is None, "bias_ must be None if weight is None" # Parameters. if weight is None: # Initialize weight. - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) else: weight.data = weight.data.to(device=device, dtype=dtype) @@ -636,8 +654,9 @@ def gather_fn(tensor): self.reset_parameters(weight_initializer, bias_initializer) @staticmethod - def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int, - *args, **kwargs) -> ParallelModule: + def from_native_module( + module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int, *args, **kwargs + ) -> ParallelModule: r""" Convert a fused `torch.nn.linear` layer to a parallelized linear layer. @@ -654,19 +673,20 @@ def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, Lis # ensure only one process group is passed if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, \ - f'Expected only one process group, got {len(process_group)}.' + assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." 
process_group = process_group[0] - linear_1d = FusedLinear1D_Col(in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - process_group=process_group, - weight=module.weight, - bias_=module.bias, - *args, - **kwargs) + linear_1d = FusedLinear1D_Col( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + weight=module.weight, + bias_=module.bias, + *args, + **kwargs, + ) # # TODO: copy the sharded weights # with torch.no_grad(): @@ -693,9 +713,11 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: bias_initializer(self.bias, fan_in=fan_in) def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) + assert ( + input_.shape[-1] == self.weight.shape[-1] + ), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] + ) # Set up backprop all-reduce. # input_parallel = reduce_backward(input_, self.process_group) input_parallel = input_ diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index 577bef076a7e..c3d8501cdeae 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -3,7 +3,6 @@ import torch import torch.distributed as dist from torch.distributed import ProcessGroup -from torch.distributed.distributed_c10d import _get_global_rank class Randomizer: @@ -172,10 +171,9 @@ def synchronize_index(process_group: ProcessGroup = None): Randomizer._INDEX = index_tensor.item() -def create_randomizer_with_offset(seed: int, - process_group: ProcessGroup = None, - offset_by_rank: bool = True, - offset_by_index: bool = True): +def create_randomizer_with_offset( + seed: int, process_group: ProcessGroup = None, offset_by_rank: bool = True, offset_by_index: bool = True +): """ Create a randomizer with an offset. The offset is equal to the rank of the process and the index of the randomizer. @@ -197,9 +195,11 @@ def create_randomizer_with_offset(seed: int, if offset_by_index: # check if the randomizer index is synchronized is_synchronized = Randomizer.is_randomizer_index_synchronized(process_group) - assert is_synchronized, ("We detect that the randomizer index is not synchronized across processes." - "This is not allowed when we want to create a randomizer with offset by index." - "Please call Randomizer.synchronize_index() first.") + assert is_synchronized, ( + "We detect that the randomizer index is not synchronized across processes." + "This is not allowed when we want to create a randomizer with offset by index." + "Please call Randomizer.synchronize_index() first." 
+ ) base_seed += Randomizer.index() Randomizer.increment_index() diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py index 30855a622adb..7411e1d0ec46 100644 --- a/colossalai/shardformer/modeling/bert.py +++ b/colossalai/shardformer/modeling/bert.py @@ -1,6 +1,6 @@ import math import warnings -from typing import Dict, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -34,10 +34,10 @@ class BertPipelineForwards: - ''' + """ This class serves as a micro library for forward function substitution of Bert models under pipeline setting. - ''' + """ @staticmethod def bert_model_forward( @@ -56,36 +56,37 @@ def bert_model_forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage + hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage stage_index: Optional[List[int]] = None, shard_config: ShardConfig = None, ): # TODO(jianghai): add explaination of the output here. r""" - encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - """ + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. 
Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ logger = logging.get_logger(__name__) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if self.config.is_decoder: @@ -118,13 +119,13 @@ def bert_model_forward( # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False # past_key_values_length @@ -173,7 +174,8 @@ def bert_model_forward( if self.encoder.gradient_checkpointing and self.encoder.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) use_cache = False next_decoder_cache = () if use_cache else None @@ -184,12 +186,13 @@ def bert_model_forward( # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] if shard_config is not None and shard_config.enable_sequence_parallelism: - hidden_states = split_forward_gather_backward(hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group) + hidden_states = split_forward_gather_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) if encoder_hidden_states is not None: encoder_hidden_states = split_forward_gather_backward( - encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group) + encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx): if stage_manager.is_first_stage() and idx == 0: @@ -204,7 +207,6 @@ def bert_model_forward( if self.encoder.gradient_checkpointing and self.encoder.training: def create_custom_forward(module): - def custom_forward(*inputs): return module(*inputs, past_key_value, output_attentions) @@ -234,14 +236,13 @@ def custom_forward(*inputs): if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + \ - (layer_outputs[2],) + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) # When sequence parallelism done, gather the output tensor in forward and split it in backward if shard_config is not None and shard_config.enable_sequence_parallelism: - hidden_states = gather_forward_split_backward(hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group) + hidden_states = gather_forward_split_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -268,7 +269,7 @@ def custom_forward(*inputs): else: # intermediate stage always return dict return { - 'hidden_states': hidden_states, + "hidden_states": hidden_states, } @staticmethod @@ -295,10 +296,10 @@ def bert_for_pretraining_forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict # TODO(jianghai) left the recording kv-value tensors as () or None type, this feature may be added in the future. 
if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False outputs = BertPipelineForwards.bert_model_forward( @@ -317,10 +318,6 @@ def bert_for_pretraining_forward( stage_index=stage_index, shard_config=shard_config, ) - past_key_values = None - all_hidden_states = None - all_self_attentions = None - all_cross_attentions = None if stage_manager.is_last_stage(): sequence_output, pooled_output = outputs[:2] @@ -345,11 +342,11 @@ def bert_for_pretraining_forward( attentions=outputs.attentions, ) else: - hidden_states = outputs.get('hidden_states') + hidden_states = outputs.get("hidden_states") # intermediate stage always return dict return { - 'hidden_states': hidden_states, + "hidden_states": hidden_states, } @staticmethod @@ -375,39 +372,39 @@ def bert_lm_head_model_forward( shard_config: ShardConfig = None, ): r""" - encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in - `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are - ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]` - past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - """ + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. 
This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]` + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ logger = logging.get_logger(__name__) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: use_cache = False if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False outputs = BertPipelineForwards.bert_model_forward( @@ -428,11 +425,9 @@ def bert_lm_head_model_forward( stage_manager=stage_manager, hidden_states=hidden_states if hidden_states is not None else None, stage_index=stage_index, - shard_config=shard_config) + shard_config=shard_config, + ) past_key_values = None - all_hidden_states = None - all_self_attentions = None - all_cross_attentions = None if stage_manager.is_last_stage(): sequence_output = outputs[0] @@ -459,9 +454,9 @@ def bert_lm_head_model_forward( cross_attentions=outputs.cross_attentions, ) else: - hidden_states = outputs.get('hidden_states') + hidden_states = outputs.get("hidden_states") # intermediate stage always return dict - return {'hidden_states': hidden_states} + return {"hidden_states": hidden_states} @staticmethod def bert_for_masked_lm_forward( @@ -484,20 +479,20 @@ def bert_for_masked_lm_forward( shard_config: ShardConfig = None, ): r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., - config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the - loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` - """ + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ logger = logging.get_logger(__name__) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False outputs = BertPipelineForwards.bert_model_forward( @@ -525,7 +520,7 @@ def bert_for_masked_lm_forward( masked_lm_loss = None if labels is not None: - loss_fct = CrossEntropyLoss() # -100 index = padding token + loss_fct = CrossEntropyLoss() # -100 index = padding token masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) if not return_dict: @@ -539,8 +534,8 @@ def bert_for_masked_lm_forward( attentions=outputs.attentions, ) else: - hidden_states = outputs.get('hidden_states') - return {'hidden_states': hidden_states} + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} @staticmethod def bert_for_next_sentence_prediction_forward( @@ -563,33 +558,33 @@ def bert_for_next_sentence_prediction_forward( ): # -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]: r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair - (see `input_ids` docstring). Indices should be in `[0, 1]`: + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair + (see `input_ids` docstring). Indices should be in `[0, 1]`: - - 0 indicates sequence B is a continuation of sequence A, - - 1 indicates sequence B is a random sequence. + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. - Returns: + Returns: - Example: + Example: - ```python - >>> from transformers import AutoTokenizer, BertForNextSentencePrediction - >>> import torch + ```python + >>> from transformers import AutoTokenizer, BertForNextSentencePrediction + >>> import torch - >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") - >>> model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased") + >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + >>> model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased") - >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." - >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." - >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt") + >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." 
+ >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt") - >>> outputs = model(**encoding, labels=torch.LongTensor([1])) - >>> logits = outputs.logits - >>> assert logits[0, 0] < logits[0, 1] # next sentence was random - ``` - """ + >>> outputs = model(**encoding, labels=torch.LongTensor([1])) + >>> logits = outputs.logits + >>> assert logits[0, 0] < logits[0, 1] # next sentence was random + ``` + """ logger = logging.get_logger(__name__) if "next_sentence_label" in kwargs: @@ -603,26 +598,28 @@ def bert_for_next_sentence_prediction_forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False - outputs = BertPipelineForwards.bert_model_forward(self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - hidden_states=hidden_states, - stage_manager=stage_manager, - stage_index=stage_index, - shard_config=shard_config) + outputs = BertPipelineForwards.bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=shard_config, + ) if stage_manager.is_last_stage(): pooled_output = outputs[1] @@ -644,9 +641,9 @@ def bert_for_next_sentence_prediction_forward( attentions=outputs.attentions, ) else: - hidden_states = outputs.get('hidden_states') + hidden_states = outputs.get("hidden_states") # intermediate stage always return dict - return {'hidden_states': hidden_states} + return {"hidden_states": hidden_states} @staticmethod def bert_for_sequence_classification_forward( @@ -677,26 +674,28 @@ def bert_for_sequence_classification_forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False - outputs = BertPipelineForwards.bert_model_forward(self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - hidden_states=hidden_states, - 
stage_manager=stage_manager, - stage_index=stage_index, - shard_config=shard_config) + outputs = BertPipelineForwards.bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=shard_config, + ) if stage_manager.is_last_stage(): pooled_output = outputs[1] @@ -737,8 +736,8 @@ def bert_for_sequence_classification_forward( attentions=outputs.attentions, ) else: - hidden_states = outputs.get('hidden_states') - return {'hidden_states': hidden_states} + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} @staticmethod def bert_for_token_classification_forward( @@ -767,26 +766,28 @@ def bert_for_token_classification_forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False - outputs = BertPipelineForwards.bert_model_forward(self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - hidden_states=hidden_states, - stage_manager=stage_manager, - stage_index=stage_index, - shard_config=shard_config) + outputs = BertPipelineForwards.bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=shard_config, + ) if stage_manager.is_last_stage(): sequence_output = outputs[0] @@ -810,8 +811,8 @@ def bert_for_token_classification_forward( attentions=outputs.attentions, ) else: - hidden_states = outputs.get('hidden_states') - return {'hidden_states': hidden_states} + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} @staticmethod def bert_for_multiple_choice_forward( @@ -842,10 +843,10 @@ def bert_for_multiple_choice_forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") 
output_hidden_states = False # in our pipeline design,input ids are copied for every stage and shouldn't be none @@ -857,8 +858,11 @@ def bert_for_multiple_choice_forward( attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None - inputs_embeds = (inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) - if inputs_embeds is not None else None) + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) outputs = BertPipelineForwards.bert_model_forward( self.bert, @@ -898,8 +902,8 @@ def bert_for_multiple_choice_forward( attentions=outputs.attentions, ) else: - hidden_states = outputs.get('hidden_states') - return {'hidden_states': hidden_states} + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} @staticmethod def bert_for_question_answering_forward( @@ -936,26 +940,28 @@ def bert_for_question_answering_forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False - outputs = BertPipelineForwards.bert_model_forward(self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - hidden_states=hidden_states, - stage_manager=stage_manager, - stage_index=stage_index, - shard_config=shard_config) + outputs = BertPipelineForwards.bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=shard_config, + ) if stage_manager.is_last_stage(): sequence_output = outputs[0] @@ -993,12 +999,11 @@ def bert_for_question_answering_forward( attentions=outputs.attentions, ) else: - hidden_states = outputs.get('hidden_states') - return {'hidden_states': hidden_states} + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} def get_bert_flash_attention_forward(): - try: from xformers.ops import memory_efficient_attention as me_attention except: @@ -1064,7 +1069,7 @@ def forward( distance = position_ids_l - position_ids_r positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility if 
self.position_embedding_type == "relative_key": relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) @@ -1084,19 +1089,17 @@ def forward( if final_attention_mask is not None: batch_size, src_len = query_layer.size()[0], query_layer.size()[2] tgt_len = key_layer.size()[2] - final_attention_mask = final_attention_mask.expand(batch_size, self.num_attention_heads, src_len, - tgt_len).contiguous() + final_attention_mask = final_attention_mask.expand( + batch_size, self.num_attention_heads, src_len, tgt_len + ).contiguous() query_layer = query_layer.permute(0, 2, 1, 3).contiguous() key_layer = key_layer.permute(0, 2, 1, 3).contiguous() value_layer = value_layer.permute(0, 2, 1, 3).contiguous() - context_layer = me_attention(query_layer, - key_layer, - value_layer, - attn_bias=final_attention_mask, - p=self.dropout.p, - scale=scale) + context_layer = me_attention( + query_layer, key_layer, value_layer, attn_bias=final_attention_mask, p=self.dropout.p, scale=scale + ) new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) @@ -1110,7 +1113,6 @@ def forward( def get_jit_fused_bert_self_output_forward(): - from transformers.models.bert.modeling_bert import BertSelfOutput def forward(self: BertSelfOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: @@ -1123,7 +1125,6 @@ def forward(self: BertSelfOutput, hidden_states: torch.Tensor, input_tensor: tor def get_jit_fused_bert_output_forward(): - from transformers.models.bert.modeling_bert import BertOutput def forward(self: BertOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: @@ -1136,7 +1137,6 @@ def forward(self: BertOutput, hidden_states: torch.Tensor, input_tensor: torch.T def bert_sequence_parallel_forward_fn(shard_config: ShardConfig): - def forward( self, input_ids: Optional[torch.Tensor] = None, @@ -1174,8 +1174,9 @@ def forward( `past_key_values`). 
""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if self.config.is_decoder: @@ -1241,12 +1242,13 @@ def forward( # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] - embedding_output = split_forward_gather_backward(embedding_output, - dim=1, - process_group=shard_config.tensor_parallel_process_group) + embedding_output = split_forward_gather_backward( + embedding_output, dim=1, process_group=shard_config.tensor_parallel_process_group + ) if encoder_hidden_states is not None: encoder_hidden_states = split_forward_gather_backward( - encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group) + encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) encoder_outputs = self.encoder( embedding_output, @@ -1264,9 +1266,9 @@ def forward( sequence_output = encoder_outputs[0] # When sequence parallelism done, gather the output tensor in forward and split it in backward - sequence_output = gather_forward_split_backward(sequence_output, - dim=1, - process_group=shard_config.tensor_parallel_process_group) + sequence_output = gather_forward_split_backward( + sequence_output, dim=1, process_group=shard_config.tensor_parallel_process_group + ) pooled_output = self.pooler(sequence_output) if self.pooler is not None else None diff --git a/colossalai/shardformer/modeling/blip2.py b/colossalai/shardformer/modeling/blip2.py index 69730fd3d254..00b2037fbdc8 100644 --- a/colossalai/shardformer/modeling/blip2.py +++ b/colossalai/shardformer/modeling/blip2.py @@ -1,12 +1,10 @@ -import math -from typing import Optional, Tuple, Union +from typing import Optional, Tuple import torch import torch.nn as nn def forward_fn(): - def forward( self, hidden_states: torch.Tensor, @@ -62,7 +60,6 @@ def forward( def get_blip2_flash_attention_forward(): - from transformers.models.blip_2.modeling_blip_2 import Blip2Attention from colossalai.kernel.cuda_native import ColoAttention @@ -80,10 +77,9 @@ def forward( mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4) query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2] - attention = ColoAttention(embed_dim=self.embed_dim, - num_heads=self.num_heads, - dropout=self.dropout.p, - scale=self.scale) + attention = ColoAttention( + embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout.p, scale=self.scale + ) context_layer = attention(query_states, key_states, value_states) output = self.projection(context_layer) @@ -95,7 +91,6 @@ def forward( def get_jit_fused_blip2_QFormer_self_output_forward(): - from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerSelfOutput def forward(self: Blip2QFormerSelfOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: @@ -108,7 +103,6 @@ def forward(self: Blip2QFormerSelfOutput, hidden_states: torch.Tensor, input_ten def get_jit_fused_blip2_QFormer_output_forward(): - from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerOutput def forward(self: Blip2QFormerOutput, hidden_states: 
torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index 66f24dc6088b..1bf87e80a461 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -30,9 +30,9 @@ def build_bloom_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor: - - def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, - dtype: torch.dtype) -> torch.Tensor: + def build_bloom_alibi_tensor( + self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype + ) -> torch.Tensor: """ Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value @@ -56,23 +56,23 @@ def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, num_heads = num_heads * world_size batch_size, seq_length = attention_mask.shape - closest_power_of_2 = 2**math.floor(math.log2(num_heads)) - base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))), - device=attention_mask.device, - dtype=torch.float32) + closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) + base = torch.tensor( + 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32 + ) powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32) slopes = torch.pow(base, powers) if closest_power_of_2 != num_heads: - extra_base = torch.tensor(2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), - device=attention_mask.device, - dtype=torch.float32) + extra_base = torch.tensor( + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), + device=attention_mask.device, + dtype=torch.float32, + ) num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) - extra_powers = torch.arange(1, - 1 + 2 * num_remaining_heads, - 2, - device=attention_mask.device, - dtype=torch.int32) + extra_powers = torch.arange( + 1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32 + ) slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) # Note: alibi will added to the attention bias that will be applied to the query, key product of attention @@ -87,7 +87,7 @@ def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, num_heads_per_rank = int(num_heads / dist.get_world_size(process_group)) offset = dist.get_rank(process_group) * num_heads_per_rank alibi = alibi.view(batch_size, num_heads, 1, seq_length) - alibi = alibi[:, offset:num_heads_per_rank + offset, :, :] + alibi = alibi[:, offset : num_heads_per_rank + offset, :, :] return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype) else: return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype) @@ -96,9 +96,9 @@ def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, class BloomPipelineForwards: - ''' + """ This class serves as a micro library for bloom pipeline forwards. 
- ''' + """ @staticmethod def bloom_model_forward( @@ -117,8 +117,7 @@ def bloom_model_forward( stage_index: Optional[List[int]] = None, shard_config: ShardConfig = None, **deprecated_arguments, - ) -> Union[Tuple[torch.Tensor, ...], 'BaseModelOutputWithPastAndCrossAttentions']: - + ) -> Union[Tuple[torch.Tensor, ...], "BaseModelOutputWithPastAndCrossAttentions"]: logger = logging.get_logger(__name__) if deprecated_arguments.pop("position_ids", False) is not False: @@ -132,20 +131,21 @@ def bloom_model_forward( raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # add warnings here if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head @@ -184,7 +184,8 @@ def bloom_model_forward( if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) use_cache = False if past_key_values is None: @@ -193,7 +194,7 @@ def bloom_model_forward( seq_length_with_past = seq_length past_key_values_length = 0 if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[2] # source_len + past_key_values_length = past_key_values[0][0].shape[2] # source_len seq_length_with_past = seq_length_with_past + past_key_values_length if attention_mask is None: @@ -213,20 +214,20 @@ def bloom_model_forward( # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] if shard_config.enable_sequence_parallelism: - hidden_states = split_forward_gather_backward(hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group) + hidden_states = split_forward_gather_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) start_idx, end_idx = stage_index[0], stage_index[1] - for i, (block, layer_past) in enumerate(zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]), - start=start_idx): + for i, (block, layer_past) in enumerate( + zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]), start=start_idx + ): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: def create_custom_forward(module): - def custom_forward(*inputs): # None for past_key_value return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) @@ -257,14 +258,13 @@ def custom_forward(*inputs): if use_cache is True: presents = presents + (outputs[1],) if output_attentions: - all_self_attentions = all_self_attentions + \ - (outputs[2 if use_cache else 1],) + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) # When sequence parallelism done, gather the output tensor in forward and split it in backward if shard_config.enable_sequence_parallelism: - hidden_states = gather_forward_split_backward(hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group) + hidden_states = gather_forward_split_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) if stage_manager.is_last_stage(): # Add last hidden state @@ -277,7 +277,8 @@ def custom_forward(*inputs): if stage_manager.is_last_stage(): if not return_dict: return tuple( - v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None + ) # attention_mask is not returned ; presents = past_key_values return BaseModelOutputWithPastAndCrossAttentions( @@ -288,25 +289,27 @@ def custom_forward(*inputs): ) else: # always return dict for imediate stage - return {'hidden_states': hidden_states} + return {"hidden_states": hidden_states} @staticmethod - def bloom_for_causal_lm_forward(self: BloomForCausalLM, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: 
Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - shard_config: ShardConfig = None, - **deprecated_arguments): + def bloom_for_causal_lm_forward( + self: BloomForCausalLM, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + **deprecated_arguments, + ): r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set @@ -328,30 +331,29 @@ def bloom_for_causal_lm_forward(self: BloomForCausalLM, return_dict = return_dict if return_dict is not None else self.config.use_return_dict # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False - transformer_outputs = BloomPipelineForwards.bloom_model_forward(self.transformer, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - shard_config=shard_config) + transformer_outputs = BloomPipelineForwards.bloom_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) past_key_values = None - all_hidden_states = None - all_self_attentions = None - all_cross_attentions = None if stage_manager.is_last_stage(): hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states) @@ -366,8 +368,9 @@ def bloom_for_causal_lm_forward(self: BloomForCausalLM, batch_size, seq_length, vocab_size = shift_logits.shape # Flatten the tokens loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size), - shift_labels.view(batch_size * seq_length)) + loss = loss_fct( + shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length) + ) if not return_dict: output = (lm_logits,) + transformer_outputs[1:] @@ -381,8 
+384,8 @@ def bloom_for_causal_lm_forward(self: BloomForCausalLM, attentions=transformer_outputs.attentions, ) else: - hidden_states = transformer_outputs.get('hidden_states') - return {'hidden_states': hidden_states} + hidden_states = transformer_outputs.get("hidden_states") + return {"hidden_states": hidden_states} @staticmethod def bloom_for_sequence_classification_forward( @@ -425,10 +428,10 @@ def bloom_for_sequence_classification_forward( # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False transformer_outputs = BloomPipelineForwards.bloom_model_forward( @@ -448,9 +451,6 @@ def bloom_for_sequence_classification_forward( shard_config=shard_config, ) past_key_values = None - all_hidden_states = None - all_self_attentions = None - all_cross_attentions = None if stage_manager.is_last_stage(): batch_size = hidden_states.shape[0] # update batch size @@ -468,7 +468,8 @@ def bloom_for_sequence_classification_forward( sequence_lengths = -1 logger.warning( f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " - "unexpected if using padding tokens in conjunction with `inputs_embeds.`") + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] @@ -506,8 +507,8 @@ def bloom_for_sequence_classification_forward( attentions=transformer_outputs.attentions, ) else: - hidden_states = transformer_outputs.get('hidden_states') - return {'hidden_states': hidden_states} + hidden_states = transformer_outputs.get("hidden_states") + return {"hidden_states": hidden_states} @staticmethod def bloom_for_token_classification_forward( @@ -550,10 +551,10 @@ def bloom_for_token_classification_forward( # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. 
if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False transformer_outputs = BloomPipelineForwards.bloom_model_forward( @@ -573,9 +574,6 @@ def bloom_for_token_classification_forward( shard_config=shard_config, ) past_key_values = None - all_hidden_states = None - all_self_attentions = None - all_cross_attentions = None if stage_manager.is_last_stage(): hidden_states = transformer_outputs[0] @@ -588,8 +586,9 @@ def bloom_for_token_classification_forward( labels = labels.to(logits.device) batch_size, seq_length = labels.shape loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(batch_size * seq_length, self.num_labels), - labels.view(batch_size * seq_length)) + loss = loss_fct( + logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length) + ) if not return_dict: output = (logits,) + transformer_outputs[2:] @@ -602,8 +601,8 @@ def bloom_for_token_classification_forward( attentions=transformer_outputs.attentions, ) else: - hidden_states = transformer_outputs.get('hidden_states') - return {'hidden_states': hidden_states} + hidden_states = transformer_outputs.get("hidden_states") + return {"hidden_states": hidden_states} @staticmethod def bloom_for_question_answering_forward( @@ -638,10 +637,10 @@ def bloom_for_question_answering_forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. 
if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False outputs = BloomPipelineForwards.bloom_model_forward( @@ -659,10 +658,6 @@ def bloom_for_question_answering_forward( stage_index=stage_index, shard_config=shard_config, ) - past_key_values = None - all_hidden_states = None - all_self_attentions = None - all_cross_attentions = None if stage_manager.is_last_stage(): sequence_output = outputs[0] @@ -700,12 +695,11 @@ def bloom_for_question_answering_forward( attentions=outputs.attentions, ) else: - hidden_states = outputs.get('hidden_states') - return {'hidden_states': hidden_states} + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} def get_bloom_flash_attention_forward(enabel_jit_fused=False): - try: from xformers.ops import memory_efficient_attention as me_attention except: @@ -723,7 +717,6 @@ def forward( use_cache: bool = False, output_attentions: bool = False, ): - fused_qkv = self.query_key_value(hidden_states) (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) batch_size, tgt_len, _ = query_layer.size() @@ -750,29 +743,35 @@ def forward( tgt_len = key_layer.size()[1] - attention_numerical_mask = torch.zeros((batch_size, self.num_heads, tgt_len, kv_length), - dtype=torch.float32, - device=query_layer.device, - requires_grad=True) - attention_numerical_mask = attention_numerical_mask + alibi.view(batch_size, self.num_heads, 1, - kv_length) * self.beta - attention_numerical_mask = torch.masked_fill(attention_numerical_mask, attention_mask, - torch.finfo(torch.float32).min) - - context_layer = me_attention(query_layer, - key_layer, - value_layer, - attn_bias=attention_numerical_mask, - scale=self.inv_norm_factor, - p=self.attention_dropout.p) + attention_numerical_mask = torch.zeros( + (batch_size, self.num_heads, tgt_len, kv_length), + dtype=torch.float32, + device=query_layer.device, + requires_grad=True, + ) + attention_numerical_mask = ( + attention_numerical_mask + alibi.view(batch_size, self.num_heads, 1, kv_length) * self.beta + ) + attention_numerical_mask = torch.masked_fill( + attention_numerical_mask, attention_mask, torch.finfo(torch.float32).min + ) + + context_layer = me_attention( + query_layer, + key_layer, + value_layer, + attn_bias=attention_numerical_mask, + scale=self.inv_norm_factor, + p=self.attention_dropout.p, + ) context_layer = context_layer.reshape(-1, kv_length, self.hidden_size) if self.pretraining_tp > 1 and self.slow_but_exact: slices = self.hidden_size / self.pretraining_tp output_tensor = torch.zeros_like(context_layer) for i in range(self.pretraining_tp): output_tensor = output_tensor + F.linear( - context_layer[:, :, int(i * slices):int((i + 1) * slices)], - self.dense.weight[:, int(i * slices):int((i + 1) * slices)], + context_layer[:, :, int(i * slices) : int((i + 1) * slices)], + self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], ) else: output_tensor = self.dense(context_layer) @@ -787,7 +786,6 @@ def forward( def get_jit_fused_bloom_attention_forward(): - from transformers.models.bloom.modeling_bloom import BloomAttention def forward( 
@@ -801,7 +799,7 @@ def forward( use_cache: bool = False, output_attentions: bool = False, ): - fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] # 3 x [batch_size, seq_length, num_heads, head_dim] (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) @@ -867,8 +865,8 @@ def forward( output_tensor = torch.zeros_like(context_layer) for i in range(self.pretraining_tp): output_tensor = output_tensor + F.linear( - context_layer[:, :, int(i * slices):int((i + 1) * slices)], - self.dense.weight[:, int(i * slices):int((i + 1) * slices)], + context_layer[:, :, int(i * slices) : int((i + 1) * slices)], + self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], ) else: output_tensor = self.dense(context_layer) @@ -885,7 +883,6 @@ def forward( def get_jit_fused_bloom_mlp_forward(): - from transformers.models.bloom.modeling_bloom import BloomMLP def forward(self: BloomMLP, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: @@ -896,8 +893,8 @@ def forward(self: BloomMLP, hidden_states: torch.Tensor, residual: torch.Tensor) slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp for i in range(self.pretraining_tp): intermediate_output = intermediate_output + F.linear( - hidden_states[:, :, int(i * slices):int((i + 1) * slices)], - self.dense_4h_to_h.weight[:, int(i * slices):int((i + 1) * slices)], + hidden_states[:, :, int(i * slices) : int((i + 1) * slices)], + self.dense_4h_to_h.weight[:, int(i * slices) : int((i + 1) * slices)], ) else: intermediate_output = self.dense_4h_to_h(hidden_states) @@ -908,7 +905,6 @@ def forward(self: BloomMLP, hidden_states: torch.Tensor, residual: torch.Tensor) def get_jit_fused_bloom_gelu_forward(): - from transformers.models.bloom.modeling_bloom import BloomGelu from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction @@ -924,7 +920,6 @@ def forward(self: BloomGelu, x: torch.Tensor) -> torch.Tensor: def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig): - from transformers import BloomModel def forward( @@ -951,8 +946,9 @@ def forward( raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -986,7 +982,8 @@ def forward( if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) use_cache = False # Compute alibi tensor: check build_alibi_tensor documentation @@ -1009,9 +1006,9 @@ def forward( ) # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] - hidden_states = split_forward_gather_backward(hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group) + hidden_states = split_forward_gather_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): if output_hidden_states: @@ -1020,7 +1017,6 @@ def forward( if self.gradient_checkpointing and self.training: def create_custom_forward(module): - def custom_forward(*inputs): # None for past_key_value return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) @@ -1054,9 +1050,9 @@ def custom_forward(*inputs): all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) # When sequence parallelism done, gather the output tensor in forward and split it in backward - hidden_states = gather_forward_split_backward(hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group) + hidden_states = gather_forward_split_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) # Add last hidden state hidden_states = self.ln_f(hidden_states) diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py index 16dcf87c8cfc..8934068d609c 100644 --- a/colossalai/shardformer/modeling/chatglm2.py +++ b/colossalai/shardformer/modeling/chatglm2.py @@ -1,26 +1,19 @@ """ PyTorch ChatGLM model. """ -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import List, Optional, Tuple import torch -import torch.nn.functional as F import torch.utils.checkpoint -from torch.nn import CrossEntropyLoss, LayerNorm +from torch.nn import CrossEntropyLoss from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward -from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig -from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( - ChatGLMForConditionalGeneration, - ChatGLMModel, - GLMBlock, -) +from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel def get_flash_core_attention_forward(): - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention from .chatglm2_6b.modeling_chatglm import CoreAttention @@ -30,15 +23,15 @@ def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_ if pytorch_major_version >= 2: query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, - key_layer, - value_layer, - is_causal=True) + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, key_layer, value_layer, is_causal=True + ) else: if attention_mask is not None: attention_mask = ~attention_mask - context_layer = 
torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, - attention_mask) + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, key_layer, value_layer, attention_mask + ) context_layer = context_layer.permute(2, 0, 1, 3) new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) context_layer = context_layer.reshape(*new_context_layer_shape) @@ -60,15 +53,15 @@ def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_ flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() attn_mask_type = AttnMaskType.paddedcausal - attention = ColoAttention(embed_dim=self.hidden_size_per_partition, - num_heads=self.num_attention_heads_per_partition, - dropout=self.attention_dropout.p, - scale=scale) - context_layer = attention(query_layer, - key_layer, - value_layer, - attn_mask=flash_attention_mask, - attn_mask_type=attn_mask_type) + attention = ColoAttention( + embed_dim=self.hidden_size_per_partition, + num_heads=self.num_attention_heads_per_partition, + dropout=self.attention_dropout.p, + scale=scale, + ) + context_layer = attention( + query_layer, key_layer, value_layer, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type + ) context_layer = context_layer.permute(1, 0, -1).contiguous() @@ -78,7 +71,6 @@ def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_ def get_jit_fused_glm_block_forward(): - from .chatglm2_6b.modeling_chatglm import GLMBlock def forward( @@ -129,9 +121,9 @@ def forward( class ChatGLMPipelineForwards: - ''' + """ This class serves as a micro library for ChatGLM model forwards under pipeline parallelism. - ''' + """ @staticmethod def chatglm_model_forward( @@ -151,19 +143,20 @@ def chatglm_model_forward( shard_config: ShardConfig = None, ): logger = logging.get_logger(__name__) - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. 
if past_key_values: - logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') + logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.") past_key_values = None if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False if stage_manager.is_first_stage(): batch_size, seq_length = input_ids.shape @@ -174,12 +167,13 @@ def chatglm_model_forward( seq_length, batch_size = hidden_states.shape[:2] if self.pre_seq_len is not None: if past_key_values is None: - past_key_values = self.get_prompt(batch_size=batch_size, - device=input_ids.device, - dtype=inputs_embeds.dtype) + past_key_values = self.get_prompt( + batch_size=batch_size, device=input_ids.device, dtype=inputs_embeds.dtype + ) if attention_mask is not None: - attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], - dim=-1) + attention_mask = torch.cat( + [attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], dim=-1 + ) if full_attention_mask is None: if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) @@ -196,37 +190,41 @@ def chatglm_model_forward( if self.encoder.gradient_checkpointing and self.encoder.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) use_cache = False all_self_attentions = None all_hidden_states = () if output_hidden_states else None start_idx, end_idx = stage_index[0], stage_index[1] if shard_config.enable_sequence_parallelism: - hidden_states = split_forward_gather_backward(hidden_states, - dim=0, - process_group=shard_config.tensor_parallel_process_group) + hidden_states = split_forward_gather_backward( + hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group + ) for idx in range(start_idx, end_idx): layer = self.encoder._get_layer(idx) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if self.encoder.gradient_checkpointing and self.encoder.training: - layer_ret = torch.utils.checkpoint.checkpoint(layer, hidden_states, attention_mask, rotary_pos_emb, - past_key_values[idx], use_cache) + layer_ret = torch.utils.checkpoint.checkpoint( + layer, hidden_states, attention_mask, rotary_pos_emb, past_key_values[idx], use_cache + ) else: - layer_ret = layer(hidden_states, - full_attention_mask, - rotary_pos_emb, - kv_cache=past_key_values[idx], - use_cache=use_cache) + layer_ret = layer( + hidden_states, + full_attention_mask, + rotary_pos_emb, + kv_cache=past_key_values[idx], + use_cache=use_cache, + ) hidden_states, kv_cache = layer_ret if use_cache: presents = presents + (kv_cache,) if shard_config.enable_sequence_parallelism: - hidden_states = gather_forward_split_backward(hidden_states, - dim=0, - process_group=shard_config.tensor_parallel_process_group) + hidden_states = gather_forward_split_backward( + hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group + ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if stage_manager.is_last_stage(): @@ -235,7 +233,8 @@ def chatglm_model_forward( hidden_states = self.encoder.final_layernorm(hidden_states) if not return_dict: return tuple( - v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=presents, @@ -243,28 +242,30 @@ def chatglm_model_forward( attentions=all_self_attentions, ) else: - return {'hidden_states': hidden_states} + return {"hidden_states": hidden_states} @staticmethod - def chatglm_for_conditional_generation_forward(self: ChatGLMForConditionalGeneration, - input_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - return_last_logit: Optional[bool] = False, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - shard_config: ShardConfig = None): - logger = logging.get_logger(__name__) + def chatglm_for_conditional_generation_forward( + self: ChatGLMForConditionalGeneration, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: 
Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_last_logit: Optional[bool] = False, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + logging.get_logger(__name__) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs = ChatGLMPipelineForwards.chatglm_model_forward( self.transformer, input_ids=input_ids, @@ -312,7 +313,6 @@ def chatglm_for_conditional_generation_forward(self: ChatGLMForConditionalGenera def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig): - def forward( self, input_ids, @@ -325,10 +325,11 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict batch_size, seq_length = input_ids.shape @@ -365,9 +366,9 @@ def forward( # Run encoder. # [seq_len, batch_size, hidden_size] -> [seq_len/TP_size, batch_size, hidden_size] - inputs_embeds = split_forward_gather_backward(inputs_embeds, - dim=0, - process_group=shard_config.tensor_parallel_process_group) + inputs_embeds = split_forward_gather_backward( + inputs_embeds, dim=0, process_group=shard_config.tensor_parallel_process_group + ) hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( inputs_embeds, full_attention_mask, @@ -377,17 +378,21 @@ def forward( output_hidden_states=output_hidden_states, ) - hidden_states = gather_forward_split_backward(hidden_states, - dim=0, - process_group=shard_config.tensor_parallel_process_group) + hidden_states = gather_forward_split_backward( + hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group + ) if not return_dict: - return tuple(v for v in [ - hidden_states, - presents, - all_hidden_states, - all_self_attentions, - ] if v is not None) + return tuple( + v + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] + if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, diff --git a/colossalai/shardformer/modeling/chatglm2_6b/configuration_chatglm.py b/colossalai/shardformer/modeling/chatglm2_6b/configuration_chatglm.py index 3e78732be2da..bb774676a4d4 100644 --- a/colossalai/shardformer/modeling/chatglm2_6b/configuration_chatglm.py +++ b/colossalai/shardformer/modeling/chatglm2_6b/configuration_chatglm.py @@ -4,32 +4,34 @@ class ChatGLMConfig(PretrainedConfig): model_type = "chatglm" - def __init__(self, - num_layers=28, - padded_vocab_size=65024, - hidden_size=4096, - ffn_hidden_size=13696, - kv_channels=128, - num_attention_heads=32, - seq_length=2048, - hidden_dropout=0.0, - attention_dropout=0.0, - 
layernorm_epsilon=1e-5, - rmsnorm=True, - apply_residual_connection_post_layernorm=False, - post_layer_norm=True, - add_bias_linear=False, - add_qkv_bias=False, - bias_dropout_fusion=True, - multi_query_attention=False, - multi_query_group_num=1, - apply_query_key_layer_scaling=True, - attention_softmax_in_fp32=True, - fp32_residual_connection=False, - quantization_bit=0, - pre_seq_len=None, - prefix_projection=False, - **kwargs): + def __init__( + self, + num_layers=28, + padded_vocab_size=65024, + hidden_size=4096, + ffn_hidden_size=13696, + kv_channels=128, + num_attention_heads=32, + seq_length=2048, + hidden_dropout=0.0, + attention_dropout=0.0, + layernorm_epsilon=1e-5, + rmsnorm=True, + apply_residual_connection_post_layernorm=False, + post_layer_norm=True, + add_bias_linear=False, + add_qkv_bias=False, + bias_dropout_fusion=True, + multi_query_attention=False, + multi_query_group_num=1, + apply_query_key_layer_scaling=True, + attention_softmax_in_fp32=True, + fp32_residual_connection=False, + quantization_bit=0, + pre_seq_len=None, + prefix_projection=False, + **kwargs, + ): self.num_layers = num_layers self.vocab_size = padded_vocab_size self.padded_vocab_size = padded_vocab_size diff --git a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py index a21ee0231422..3a8d90ec7328 100644 --- a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py +++ b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py @@ -37,10 +37,9 @@ import copy import math -import re import sys import warnings -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import torch.nn.functional as F @@ -80,7 +79,6 @@ def default_init(cls, *args, **kwargs): class InvalidScoreLogitsProcessor(LogitsProcessor): - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: if torch.isnan(scores).any() or torch.isinf(scores).any(): scores.zero_() @@ -100,7 +98,7 @@ def __init__(self, config: ChatGLMConfig): self.prefix_projection = config.prefix_projection if self.prefix_projection: # Use a two-layer MLP to encode the prefix - kv_size = (config.num_layers * config.kv_channels * config.multi_query_group_num * 2) + kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2 self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size) self.trans = torch.nn.Sequential( torch.nn.Linear(kv_size, config.hidden_size), @@ -151,10 +149,9 @@ def split_tensor_along_last_dim( class RotaryEmbedding(nn.Module): - def __init__(self, dim, original_impl=False, device=None, dtype=None): super().__init__() - inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)) + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)) self.register_buffer("inv_freq", inv_freq) self.dim = dim self.original_impl = original_impl @@ -174,7 +171,7 @@ def forward_impl( https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. 
""" # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ - theta = 1.0 / (base**(torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem)) + theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem)) # Create position indexes `[0, 1, ..., seq_len - 1]` seq_idx = torch.arange(seq_len, dtype=dtype, device=device) @@ -220,7 +217,6 @@ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Ten class RMSNorm(torch.nn.Module): - def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): super().__init__() self.elementwise_affine = True @@ -236,7 +232,6 @@ def forward(self, hidden_states: torch.Tensor): class CoreAttention(torch.nn.Module): - def __init__(self, config: ChatGLMConfig, layer_number): super(CoreAttention, self).__init__() @@ -250,7 +245,7 @@ def __init__(self, config: ChatGLMConfig, layer_number): # Per attention head and per partition values. self.hidden_size_per_partition = projection_size - self.hidden_size_per_attention_head = (projection_size // config.num_attention_heads) + self.hidden_size_per_attention_head = projection_size // config.num_attention_heads self.num_attention_heads_per_partition = config.num_attention_heads coeff = None @@ -267,15 +262,15 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask): if pytorch_major_version >= 2: query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, - key_layer, - value_layer, - is_causal=True) + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, key_layer, value_layer, is_causal=True + ) else: if attention_mask is not None: attention_mask = ~attention_mask - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, - attention_mask) + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, key_layer, value_layer, attention_mask + ) context_layer = context_layer.permute(2, 0, 1, 3) new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) context_layer = context_layer.reshape(*new_context_layer_shape) @@ -307,8 +302,8 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask): # Raw attention scores. [b * np, sq, sk] matmul_result = torch.baddbmm( matmul_input_buffer, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] beta=0.0, alpha=(1.0 / self.norm_factor), ) @@ -325,7 +320,7 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask): attention_scores = attention_scores.float() if self.coeff is not None: attention_scores = attention_scores * self.coeff - if (attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]): + if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]: attention_mask = torch.ones( output_size[0], 1, @@ -388,15 +383,16 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None): self.projection_size = config.kv_channels * config.num_attention_heads # Per attention head and per partition values. 
- self.hidden_size_per_attention_head = (self.projection_size // config.num_attention_heads) + self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads self.num_attention_heads_per_partition = config.num_attention_heads self.multi_query_attention = config.multi_query_attention self.qkv_hidden_size = 3 * self.projection_size if self.multi_query_attention: self.num_multi_query_groups_per_partition = config.multi_query_group_num - self.qkv_hidden_size = (self.projection_size + - 2 * self.hidden_size_per_attention_head * config.multi_query_group_num) + self.qkv_hidden_size = ( + self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num + ) self.query_key_value = nn.Linear( config.hidden_size, self.qkv_hidden_size, @@ -459,18 +455,27 @@ def forward( ], dim=-1, ) - query_layer = query_layer.view(query_layer.size()[:-1] + ( - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - )) - key_layer = key_layer.view(key_layer.size()[:-1] + ( - self.num_multi_query_groups_per_partition, - self.hidden_size_per_attention_head, - )) - value_layer = value_layer.view(value_layer.size()[:-1] + ( - self.num_multi_query_groups_per_partition, - self.hidden_size_per_attention_head, - )) + query_layer = query_layer.view( + query_layer.size()[:-1] + + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + ) + key_layer = key_layer.view( + key_layer.size()[:-1] + + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + ) + ) + value_layer = value_layer.view( + value_layer.size()[:-1] + + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + ) + ) else: new_tensor_shape = mixed_x_layer.size()[:-1] + ( self.num_attention_heads_per_partition, @@ -504,10 +509,13 @@ def forward( self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1, ) - key_layer = key_layer.contiguous().view(key_layer.size()[:2] + ( - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - )) + key_layer = key_layer.contiguous().view( + key_layer.size()[:2] + + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + ) value_layer = value_layer.unsqueeze(-2) value_layer = value_layer.expand( -1, @@ -516,10 +524,13 @@ def forward( self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1, ) - value_layer = value_layer.contiguous().view(value_layer.size()[:2] + ( - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - )) + value_layer = value_layer.contiguous().view( + value_layer.size()[:2] + + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + ) # ================================== # core attention computation @@ -600,7 +611,7 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None): super(GLMBlock, self).__init__() self.layer_number = layer_number - self.apply_residual_connection_post_layernorm = (config.apply_residual_connection_post_layernorm) + self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm self.fp32_residual_connection = config.fp32_residual_connection @@ -724,7 +735,8 @@ def forward( if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) use_cache = False all_self_attentions = None @@ -806,7 +818,7 @@ def get_masks(self, input_ids, past_key_values, padding_mask=None): def get_position_ids(self, input_ids, device): batch_size, seq_length = input_ids.shape - position_ids = (torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)) + position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) return position_ids def _set_gradient_checkpointing(self, module, value=False): @@ -843,7 +855,6 @@ def forward(self, input_ids): class ChatGLMModel(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): super().__init__(config) if empty_init: @@ -860,8 +871,9 @@ def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): # Rotary positional embeddings self.seq_length = config.seq_length - rotary_dim = (config.hidden_size // - config.num_attention_heads if config.kv_channels is None else config.kv_channels) + rotary_dim = ( + config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels + ) self.rotary_pos_emb = RotaryEmbedding( rotary_dim // 2, @@ -891,7 +903,7 @@ def get_input_embeddings(self): return self.embedding.word_embeddings def get_prompt(self, batch_size, device, dtype=torch.half): - prefix_tokens = (self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)) + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) past_key_values = past_key_values.view( batch_size, @@ -917,10 +929,11 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict batch_size, seq_length = input_ids.shape @@ -966,12 +979,16 @@ def forward( ) if not return_dict: - return tuple(v for v in [ - hidden_states, - presents, - all_hidden_states, - all_self_attentions, - ] if v is not None) + return tuple( + v + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] + if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, @@ -988,7 +1005,6 @@ def quantize(self, weight_bit_width: int): class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): super().__init__(config) @@ -1009,7 +1025,8 @@ def _update_model_kwargs_for_generation( ) -> Dict[str, Any]: # update past_key_values model_kwargs["past_key_values"] = self._extract_past_from_model_output( - outputs, standardize_cache_format=standardize_cache_format) + outputs, standardize_cache_format=standardize_cache_format + ) # update attention mask if "attention_mask" in model_kwargs: @@ -1067,7 +1084,7 @@ def forward( return_last_logit: Optional[bool] = False, ): use_cache = use_cache if use_cache is not None else 
self.config.use_cache - return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs = self.transformer( input_ids=input_ids, @@ -1113,8 +1130,9 @@ def forward( ) @staticmethod - def _reorder_cache(past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], - beam_idx: torch.LongTensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: + def _reorder_cache( + past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: """ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct @@ -1122,10 +1140,13 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], Output shares the same memory storage as `past`. """ - return tuple(( - layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), - layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), - ) for layer_past in past) + return tuple( + ( + layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), + layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), + ) + for layer_past in past + ) def process_response(self, response): response = response.strip() @@ -1180,7 +1201,7 @@ def chat( } inputs = self.build_inputs(tokenizer, query, history=history) outputs = self.generate(**inputs, **gen_kwargs) - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :] response = tokenizer.decode(outputs) response = self.process_response(response) history = history + [(query, response)] @@ -1227,14 +1248,14 @@ def stream_chat( attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1) inputs["attention_mask"] = attention_mask for outputs in self.stream_generate( - **inputs, - past_key_values=past_key_values, - return_past_key_values=return_past_key_values, - **gen_kwargs, + **inputs, + past_key_values=past_key_values, + return_past_key_values=return_past_key_values, + **gen_kwargs, ): if return_past_key_values: outputs, past_key_values = outputs - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :] response = tokenizer.decode(outputs) if response and response[-1] != "�": response = self.process_response(response) @@ -1269,7 +1290,7 @@ def stream_generate( if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - has_default_max_length = (kwargs.get("max_length") is None and generation_config.max_length is not None) + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None if has_default_max_length and generation_config.max_new_tokens is None: warnings.warn( f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. 
" @@ -1278,7 +1299,7 @@ def stream_generate( UserWarning, ) elif generation_config.max_new_tokens is not None: - generation_config.max_length = (generation_config.max_new_tokens + input_ids_seq_length) + generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length if not has_default_max_length: logger.warn( f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" @@ -1289,14 +1310,16 @@ def stream_generate( ) if input_ids_seq_length >= generation_config.max_length: - input_ids_string = ("decoder_input_ids" if self.config.is_encoder_decoder else "input_ids") - logger.warning(f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" - f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" - " increasing `max_new_tokens`.") + input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + logger.warning( + f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`." + ) # 2. Set generation parameters if not already defined - logits_processor = (logits_processor if logits_processor is not None else LogitsProcessorList()) - stopping_criteria = (stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()) + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() logits_processor = self._get_logits_processor( generation_config=generation_config, @@ -1306,8 +1329,9 @@ def stream_generate( logits_processor=logits_processor, ) - stopping_criteria = self._get_stopping_criteria(generation_config=generation_config, - stopping_criteria=stopping_criteria) + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria + ) logits_warper = self._get_logits_warper(generation_config) unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) @@ -1337,9 +1361,9 @@ def stream_generate( # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - model_kwargs = self._update_model_kwargs_for_generation(outputs, - model_kwargs, - is_encoder_decoder=self.config.is_encoder_decoder) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) if return_past_key_values: yield input_ids, outputs.past_key_values diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 84deafefeadd..21f06393071d 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -26,32 +26,32 @@ class GPT2PipelineForwards: - ''' + """ This class serves as a micro library for forward function substitution of GPT2 models under pipeline setting. 
- ''' + """ @staticmethod def gpt2_model_forward( - self: GPT2Model, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - shard_config: ShardConfig = None) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: - + self: GPT2Model, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward. # Please refer to original code of transformers for more details. @@ -62,16 +62,16 @@ def gpt2_model_forward( # Preprocess passed in arguments # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future. if past_key_values: - logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') + logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.") past_key_values = None if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False if stage_manager.is_first_stage(): @@ -115,7 +115,7 @@ def gpt2_model_forward( # positions we want to attend and the dtype's smallest value for masked positions. # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. 
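The comment above describes the additive masking trick used in `gpt2_model_forward`: padding positions receive the dtype's most negative value before the softmax, which drives their attention weight to zero. A tiny numeric illustration (a sketch, not GPT2 code):

```python
import torch

scores = torch.tensor([[2.0, 1.0, 0.5]])   # raw attention scores for one query
mask = torch.tensor([[1.0, 1.0, 0.0]])     # 1 = attend, 0 = padding
additive = (1.0 - mask) * torch.finfo(scores.dtype).min
probs = torch.softmax(scores + additive, dim=-1)
print(probs)  # padded position gets ~0 probability, e.g. tensor([[0.7311, 0.2689, 0.0000]])
```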
- attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min # If a 2D or 3D attention mask is provided for the cross-attention @@ -156,7 +156,8 @@ def gpt2_model_forward( if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) use_cache = False presents = () if use_cache else None all_self_attentions = () if output_attentions else None @@ -166,9 +167,9 @@ def gpt2_model_forward( # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] if shard_config.enable_sequence_parallelism: - hidden_states = split_forward_gather_backward(hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group) + hidden_states = split_forward_gather_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) # Going through held blocks. start_idx, end_idx = stage_index[0], stage_index[1] @@ -186,7 +187,6 @@ def gpt2_model_forward( if self.gradient_checkpointing and self.training: def create_custom_forward(module): - def custom_forward(*inputs): # None for past_key_value return module(*inputs, use_cache, output_attentions) @@ -225,9 +225,9 @@ def custom_forward(*inputs): # When sequence parallelism done, gather the output tensor in forward and split it in backward if shard_config.enable_sequence_parallelism: - hidden_states = gather_forward_split_backward(hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group) + hidden_states = gather_forward_split_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) if stage_manager.is_last_stage(): hidden_states = self.ln_f(hidden_states) @@ -241,8 +241,10 @@ def custom_forward(*inputs): if stage_manager.is_last_stage(): if not return_dict: return tuple( - v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] - if v is not None) + v + for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, @@ -253,62 +255,65 @@ def custom_forward(*inputs): ) else: # always return dict for intermediate stage - return {'hidden_states': hidden_states} + return {"hidden_states": hidden_states} @staticmethod def gpt2_lmhead_model_forward( - self: GPT2LMHeadModel, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: 
Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - shard_config: ShardConfig = None) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]: + self: GPT2LMHeadModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]: r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set - `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` - are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - - This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward. - Please refer to original code of transformers for more details. - """ + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + + This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward. + Please refer to original code of transformers for more details. 
+ """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - shard_config=shard_config) + outputs = GPT2PipelineForwards.gpt2_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) # If not at the last stage, return hidden_states as in GPT2Model if not stage_manager.is_last_stage(): - return {'hidden_states': outputs['hidden_states']} + return {"hidden_states": outputs["hidden_states"]} hidden_states = outputs[0] lm_logits = self.lm_head(hidden_states) @@ -337,25 +342,26 @@ def gpt2_lmhead_model_forward( @staticmethod def gpt2_double_heads_model_forward( - self: GPT2DoubleHeadsModel, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - mc_token_ids: Optional[torch.LongTensor] = None, - labels: Optional[torch.LongTensor] = None, - mc_labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - shard_config: ShardConfig = None) -> Union[Dict, Tuple, GPT2DoubleHeadsModelOutput]: + self: GPT2DoubleHeadsModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + mc_token_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + mc_labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = 
None, + ) -> Union[Dict, Tuple, GPT2DoubleHeadsModelOutput]: r""" mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - @@ -373,26 +379,28 @@ def gpt2_double_heads_model_forward( ```""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - shard_config=shard_config) + outputs = GPT2PipelineForwards.gpt2_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) # If not at the last stage, return hidden_states as in GPT2Model if not stage_manager.is_last_stage(): - return {'hidden_states': outputs['hidden_states']} + return {"hidden_states": outputs["hidden_states"]} hidden_states = outputs[0] lm_logits = self.lm_head(hidden_states) @@ -428,22 +436,23 @@ def gpt2_double_heads_model_forward( @staticmethod def gpt2_for_question_answering_forward( - self: GPT2ForQuestionAnswering, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - start_positions: Optional[torch.LongTensor] = None, - end_positions: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - shard_config: ShardConfig = None) -> Union[Dict, Tuple, QuestionAnsweringModelOutput]: + self: GPT2ForQuestionAnswering, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Dict, Tuple, QuestionAnsweringModelOutput]: r""" start_positions 
(`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for position (index) of the start of the labelled span for computing the token classification loss. @@ -459,24 +468,26 @@ def gpt2_for_question_answering_forward( """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - shard_config=shard_config) + outputs = GPT2PipelineForwards.gpt2_model_forward( + self.transformer, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) # If not at the last stage, return hidden_states as in GPT2Model if not stage_manager.is_last_stage(): - return {'hidden_states': outputs['hidden_states']} + return {"hidden_states": outputs["hidden_states"]} sequence_output = outputs[0] @@ -516,23 +527,24 @@ def gpt2_for_question_answering_forward( @staticmethod def gpt2_for_token_classification_forward( - self: GPT2ForTokenClassification, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - shard_config: ShardConfig = None) -> Union[Dict, Tuple, TokenClassifierOutput]: + self: GPT2ForTokenClassification, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Dict, Tuple, TokenClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the sequence classification/regression loss. 
Indices should be in `[0, ..., @@ -544,26 +556,28 @@ def gpt2_for_token_classification_forward( """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - shard_config=shard_config) + outputs = GPT2PipelineForwards.gpt2_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) # If not at the last stage, return hidden_states as in GPT2Model if not stage_manager.is_last_stage(): - return {'hidden_states': outputs['hidden_states']} + return {"hidden_states": outputs["hidden_states"]} hidden_states = outputs[0] hidden_states = self.dropout(hidden_states) @@ -588,23 +602,24 @@ def gpt2_for_token_classification_forward( @staticmethod def gpt2_for_sequence_classification_forward( - self: GPT2ForSequenceClassification, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - shard_config: ShardConfig = None) -> Union[Dict, Tuple, SequenceClassifierOutputWithPast]: + self: GPT2ForSequenceClassification, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Dict, Tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. 
Indices should be in `[0, ..., @@ -613,38 +628,41 @@ def gpt2_for_sequence_classification_forward( # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification.forward. # Please refer to original code of transformers for more details. - """ + """ logger = logging.get_logger(__name__) if input_ids is not None: batch_size, _ = input_ids.shape[:2] else: batch_size, _ = hidden_states.shape[:2] - assert (self.config.pad_token_id is not None - or batch_size == 1), "Cannot handle batch sizes > 1 if no padding token is defined." + assert ( + self.config.pad_token_id is not None or batch_size == 1 + ), "Cannot handle batch sizes > 1 if no padding token is defined." return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - shard_config=shard_config) + outputs = GPT2PipelineForwards.gpt2_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) # If not at the last stage, return hidden_states as in GPT2Model if not stage_manager.is_last_stage(): - return {'hidden_states': outputs['hidden_states']} + return {"hidden_states": outputs["hidden_states"]} hidden_states = outputs[0] logits = self.score(hidden_states) @@ -658,7 +676,8 @@ def gpt2_for_sequence_classification_forward( sequence_lengths = -1 logger.warning_once( f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " - "unexpected if using padding tokens in conjunction with `inputs_embeds.`") + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] @@ -698,7 +717,6 @@ def gpt2_for_sequence_classification_forward( def get_gpt2_flash_attention_forward(): - from transformers.models.gpt2.modeling_gpt2 import GPT2Attention from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention @@ -722,12 +740,12 @@ def forward( use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: - if encoder_hidden_states is not None: if not hasattr(self, "q_attn"): raise ValueError( "If class is used as cross attention, the weights `q_attn` have to be defined. " - "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`.") + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." 
+ ) query = self.q_attn(hidden_states) key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) @@ -759,15 +777,14 @@ def forward( attn_mask_type = AttnMaskType.padding flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() - scale = value.size(-1)**-0.5 + scale = value.size(-1) ** -0.5 if self.scale_attn_by_inverse_layer_idx: scale = scale * (1 / float(self.layer_idx + 1)) # use coloattention - attention = ColoAttention(embed_dim=self.embed_dim, - num_heads=self.num_heads, - dropout=self.attn_dropout.p, - scale=scale) + attention = ColoAttention( + embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale + ) attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type) @@ -781,7 +798,6 @@ def forward( def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): - def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -799,8 +815,9 @@ def forward( return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -849,7 +866,7 @@ def forward( # positions we want to attend and the dtype's smallest value for masked positions. # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min # If a 2D or 3D attention mask is provided for the cross-attention @@ -886,7 +903,8 @@ def forward( if use_cache: logger = logging.get_logger(__name__) logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) use_cache = False presents = () if use_cache else None @@ -896,9 +914,9 @@ def forward( # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] - hidden_states = split_forward_gather_backward(hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group) + hidden_states = split_forward_gather_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): # Model parallel @@ -918,7 +936,6 @@ def forward( if self.gradient_checkpointing and self.training: def create_custom_forward(module): - def custom_forward(*inputs): # None for past_key_value return module(*inputs, use_cache, output_attentions) @@ -962,9 +979,9 @@ def custom_forward(*inputs): hidden_states = hidden_states.to("cuda:" + str(k + 1)) # When sequence parallelism done, gather the output tensor in forward and split it in backward - hidden_states = gather_forward_split_backward(hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group) + hidden_states = gather_forward_split_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) hidden_states = self.ln_f(hidden_states) hidden_states = hidden_states.view(output_shape) @@ -974,8 +991,10 @@ def custom_forward(*inputs): if not return_dict: return tuple( - v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] - if v is not None) + v + for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, diff --git a/colossalai/shardformer/modeling/jit.py b/colossalai/shardformer/modeling/jit.py index 6434348ef823..c92847a3fbcc 100644 --- a/colossalai/shardformer/modeling/jit.py +++ b/colossalai/shardformer/modeling/jit.py @@ -2,7 +2,6 @@ def get_dropout_add_func(): - from transformers.models.bloom.modeling_bloom import dropout_add def self_dropout_add(self, x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor: @@ -12,7 +11,6 @@ def self_dropout_add(self, x: torch.Tensor, residual: torch.Tensor, prob: float, def get_jit_fused_dropout_add_func(): - from colossalai.kernel.jit import bias_dropout_add_fused_inference, bias_dropout_add_fused_train def self_dropout_add(self, x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor: @@ -25,7 +23,6 @@ def self_dropout_add(self, x: torch.Tensor, residual: torch.Tensor, prob: float, def get_jit_fused_gelu_forward_func(): - from colossalai.kernel.jit.bias_gelu import bias_gelu def bloom_gelu_forward(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor: diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index ff622c306c59..4b6c8342534a 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -1,5 +1,5 @@ import warnings -from typing import Callable, List, Optional, Tuple +from typing import List, Optional, Tuple import torch from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -15,10 +15,10 @@ class LlamaPipelineForwards: - ''' + """ This class serves as a micro library for forward function substitution of Llama models under pipeline setting. 
- ''' + """ @staticmethod def llama_model_forward( @@ -39,8 +39,9 @@ def llama_model_forward( logger = logging.get_logger(__name__) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -69,13 +70,13 @@ def llama_model_forward( # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False if past_key_values is not None: @@ -83,10 +84,9 @@ def llama_model_forward( seq_length_with_past = seq_length_with_past + past_key_values_length if position_ids is None: - position_ids = torch.arange(past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device) + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() @@ -94,16 +94,18 @@ def llama_model_forward( # embed positions, for the first stage, hidden_states is the input embeddings, # for the other stages, hidden_states is the output of the previous stage if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length_with_past), - dtype=torch.bool, - device=hidden_states.device) - attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), hidden_states, - past_key_values_length) + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length + ) if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) use_cache = False # decoder layers @@ -121,7 +123,6 @@ def llama_model_forward( if self.gradient_checkpointing and self.training: def create_custom_forward(module): - def custom_forward(*inputs): # None for past_key_value return module(*inputs, output_attentions, None) @@ -169,7 +170,7 @@ def custom_forward(*inputs): attentions=all_self_attns, ) # always return dict for imediate stage - return {'hidden_states': hidden_states} + return {"hidden_states": hidden_states} @staticmethod def llama_for_causal_lm_forward( @@ -189,42 +190,43 @@ def llama_for_causal_lm_forward( stage_index: Optional[List[int]] = None, ): r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - Returns: + Returns: - Example: + Example: - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - >>> prompt = "Hey, are you consciours? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") + >>> prompt = "Hey, are you consciours? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." - ```""" + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." + ```""" logger = logging.get_logger(__name__) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. 
if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) @@ -244,9 +246,6 @@ def llama_for_causal_lm_forward( stage_index=stage_index, ) past_key_values = None - all_hidden_states = None - all_self_attentions = None - all_cross_attentions = None if stage_manager.is_last_stage(): hidden_states = outputs[0] @@ -276,8 +275,8 @@ def llama_for_causal_lm_forward( attentions=outputs.attentions, ) else: - hidden_states = outputs.get('hidden_states') - return {'hidden_states': hidden_states} + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} @staticmethod def llama_for_sequence_classification_forward( @@ -307,10 +306,10 @@ def llama_for_sequence_classification_forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False transformer_outputs = LlamaPipelineForwards.llama_model_forward( @@ -388,16 +387,15 @@ def llama_for_sequence_classification_forward( ) else: - hidden_states = transformer_outputs.get('hidden_states') - return {'hidden_states': hidden_states} + hidden_states = transformer_outputs.get("hidden_states") + return {"hidden_states": hidden_states} def get_llama_flash_attention_forward(): - - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention - from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb + from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + llama_version = 2 try: from transformers.models.llama.modeling_llama import repeat_kv @@ -453,16 +451,15 @@ def forward( if attention_mask != None: if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}") + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() attn_mask_type = AttnMaskType.paddedcausal attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads) - attn_output = attention(query_states, - key_states, - value_states, - attn_mask=flash_attention_mask, - attn_mask_type=attn_mask_type) + attn_output = attention( + query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type + ) attn_output = self.o_proj(attn_output) diff --git 
a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index ad088f3702e5..e0978d38e110 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -21,16 +21,17 @@ class OPTPipelineForwards: - ''' + """ This class serves as a micro library for forward function substitution of OPT models under pipeline setting. - ''' + """ @staticmethod def _prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, device, past_key_values_length): # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] from transformers.models.opt.modeling_opt import _make_causal_mask + combined_attention_mask = None if input_shape[-1] > 1: combined_attention_mask = _make_causal_mask( @@ -42,10 +43,12 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, device, if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = OPTPipelineForwards._expand_mask(attention_mask, _dtype, - tgt_len=input_shape[-1]).to(device) - combined_attention_mask = (expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + - combined_attention_mask) + expanded_attn_mask = OPTPipelineForwards._expand_mask(attention_mask, _dtype, tgt_len=input_shape[-1]).to( + device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) return combined_attention_mask @@ -79,17 +82,19 @@ def opt_model_forward( hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: - ''' + """ This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward - ''' + """ from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.utils import logging + logger = logging.get_logger(__name__) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -133,10 +138,12 @@ def opt_model_forward( elif attention_mask.shape[1] != mask_seq_length: raise ValueError( f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " - f"{mask_seq_length} (sum of the lengths of current and past inputs)") + f"{mask_seq_length} (sum of the lengths of current and past inputs)" + ) - causal_attention_mask = OPTPipelineForwards._prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, - device, past_key_values_length) + causal_attention_mask = OPTPipelineForwards._prepare_decoder_attention_mask( + attention_mask, input_shape, _dtype, device, past_key_values_length + ) if stage_manager.is_first_stage(): pos_embeds = decoder.embed_positions(attention_mask, past_key_values_length) @@ -145,21 +152,22 @@ def opt_model_forward( if decoder.gradient_checkpointing and decoder.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) use_cache = False # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future. if past_key_values: - logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') + logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.") past_key_values = None if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False # decoder layers @@ -173,7 +181,8 @@ def opt_model_forward( if attn_mask.size()[0] != (len(decoder.layers)): raise ValueError( f"The `{mask_name}` should be specified for {len(decoder.layers)} layers, but it is for" - f" {head_mask.size()[0]}.") + f" {head_mask.size()[0]}." + ) start_idx, end_idx = stage_index[0], stage_index[1] @@ -195,7 +204,6 @@ def opt_model_forward( if decoder.gradient_checkpointing and decoder.training: def create_custom_forward(module): - def custom_forward(*inputs): # None for past_key_value return module(*inputs, output_attentions, None) @@ -250,7 +258,7 @@ def custom_forward(*inputs): attentions=all_self_attns, ) else: - return {'hidden_states': hidden_states} + return {"hidden_states": hidden_states} @staticmethod def opt_for_causal_lm_forward( @@ -275,8 +283,9 @@ def opt_for_causal_lm_forward( """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) @@ -319,8 +328,8 @@ def opt_for_causal_lm_forward( attentions=outputs.attentions, ) else: - hidden_states = outputs.get('hidden_states') - return {'hidden_states': hidden_states} + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} @staticmethod def opt_for_sequence_classification_forward( @@ -348,19 +357,21 @@ def opt_for_sequence_classification_forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - transformer_outputs = OPTPipelineForwards.opt_model_forward(self.model, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index) + transformer_outputs = OPTPipelineForwards.opt_model_forward( + self.model, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + 
head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) if stage_manager.is_last_stage(): hidden_states = transformer_outputs[0] @@ -377,7 +388,8 @@ def opt_for_sequence_classification_forward( sequence_lengths = -1 logger.warning( f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " - "unexpected if using padding tokens in conjunction with `inputs_embeds.`") + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] @@ -416,8 +428,8 @@ def opt_for_sequence_classification_forward( attentions=transformer_outputs.attentions, ) else: - hidden_states = transformer_outputs.get('hidden_states') - return {'hidden_states': hidden_states} + hidden_states = transformer_outputs.get("hidden_states") + return {"hidden_states": hidden_states} @staticmethod def opt_for_question_answering_forward( @@ -443,19 +455,21 @@ def opt_for_question_answering_forward( """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - transformer_outputs = OPTPipelineForwards.opt_model_forward(self.model, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index) + transformer_outputs = OPTPipelineForwards.opt_model_forward( + self.model, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) if stage_manager.is_last_stage(): hidden_states = transformer_outputs[0] @@ -493,12 +507,11 @@ def opt_for_question_answering_forward( attentions=transformer_outputs.attentions, ) else: - hidden_states = transformer_outputs.get('hidden_states') - return {'hidden_states': hidden_states} + hidden_states = transformer_outputs.get("hidden_states") + return {"hidden_states": hidden_states} def get_opt_flash_attention_forward(): - from transformers.models.opt.modeling_opt import OPTAttention from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention @@ -555,27 +568,27 @@ def forward( src_len = key_states.size(1) if layer_head_mask != None: if layer_head_mask.size() != (self.num_heads,): - raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}") + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) flash_attention_mask = None attn_mask_type = AttnMaskType.causal if attention_mask != None: if attention_mask.size() != (bsz, 1, tgt_len, src_len): raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}") + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) flash_attention_mask = 
~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() attn_mask_type = AttnMaskType.paddedcausal - attention = ColoAttention(embed_dim=self.embed_dim, - num_heads=self.num_heads, - dropout=self.dropout, - scale=self.scaling) - attn_output = attention(query_states, - key_states, - value_states, - attn_mask=flash_attention_mask, - attn_mask_type=attn_mask_type) + attention = ColoAttention( + embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling + ) + attn_output = attention( + query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type + ) attn_output = self.out_proj(attn_output) return attn_output, None, past_key_value @@ -584,7 +597,6 @@ def forward( def get_jit_fused_opt_decoder_layer_forward(): - from transformers.models.opt.modeling_opt import OPTDecoderLayer def forward( diff --git a/colossalai/shardformer/modeling/sam.py b/colossalai/shardformer/modeling/sam.py index c40c02ec411a..26e0b224d3ab 100644 --- a/colossalai/shardformer/modeling/sam.py +++ b/colossalai/shardformer/modeling/sam.py @@ -7,20 +7,23 @@ def forward_fn(): - def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor: batch_size, height, width, _ = hidden_states.shape # qkv with shape (3, batch_size, nHead, height * width, channel) - qkv = (self.qkv(hidden_states).reshape(batch_size, height * width, 3, self.num_attention_heads, - -1).permute(2, 0, 3, 1, 4)) + qkv = ( + self.qkv(hidden_states) + .reshape(batch_size, height * width, 3, self.num_attention_heads, -1) + .permute(2, 0, 3, 1, 4) + ) # q, k, v with shape (batch_size * nHead, height * width, channel) query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0) attn_weights = (query * self.scale) @ key.transpose(-2, -1) if self.use_rel_pos: - attn_weights = self.add_decomposed_rel_pos(attn_weights, query, self.rel_pos_h, self.rel_pos_w, - (height, width), (height, width)) + attn_weights = self.add_decomposed_rel_pos( + attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width) + ) attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype) @@ -45,8 +48,8 @@ def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch def get_sam_flash_attention_forward(): - from transformers.models.sam.modeling_sam import SamAttention + try: from xformers.ops import memory_efficient_attention as me_attention except: @@ -62,11 +65,9 @@ def _recombine_heads(hidden_states: Tensor, point_batch_size: int) -> Tensor: batch, n_tokens, n_heads, c_per_head = hidden_states.shape return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head) - def forward(self: SamAttention, - query: Tensor, - key: Tensor, - value: Tensor, - attention_similarity: Tensor = None) -> Tensor: + def forward( + self: SamAttention, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Tensor = None + ) -> Tensor: # Input projections query = self.q_proj(query) key = self.k_proj(key) @@ -96,8 +97,8 @@ def forward(self: SamAttention, def get_sam_vision_flash_attention_forward(): - from transformers.models.sam.modeling_sam import SamVisionAttention + try: from xformers.ops import memory_efficient_attention as me_attention except: @@ -181,8 +182,11 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor def forward(self: SamVisionAttention, hidden_states: torch.Tensor, 
output_attentions=False) -> torch.Tensor: batch_size, height, width, _ = hidden_states.shape # qkv with shape (3, batch_size, nHead, height * width, channel) - qkv = (self.qkv(hidden_states).reshape(batch_size, height * width, 3, self.num_attention_heads, - -1).permute(2, 0, 1, 3, 4)) + qkv = ( + self.qkv(hidden_states) + .reshape(batch_size, height * width, 3, self.num_attention_heads, -1) + .permute(2, 0, 1, 3, 4) + ) query, key, value = qkv.reshape(3, batch_size, height * width, self.num_attention_heads, -1).unbind(0) diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py index 9cc071f91dfc..f67aa84e4e72 100644 --- a/colossalai/shardformer/modeling/t5.py +++ b/colossalai/shardformer/modeling/t5.py @@ -17,10 +17,10 @@ class T5PipelineForwards: - ''' + """ This class serves as a micro library for forward function substitution of T5 models under pipeline setting. - ''' + """ @staticmethod def t5_stack_forward( @@ -44,7 +44,6 @@ def t5_stack_forward( stage_index: Optional[List[int]] = None, decoder_starting_stage: Optional[int] = None, ) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: - # This function is modified on the basis of transformers.models.t5.modeling_t5.T5Stack.forward. # Please refer to original code of transformers for more details. @@ -52,16 +51,16 @@ def t5_stack_forward( # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future. if past_key_values: - logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') + logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.") past_key_values = None if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False if use_cache is True: if not in_decoder: @@ -69,7 +68,8 @@ def t5_stack_forward( if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) use_cache = False stage = stage_manager.stage @@ -97,7 +97,8 @@ def t5_stack_forward( else: err_msg_prefix = "decoder_" if in_decoder else "" raise ValueError( - f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds" + ) if inputs_embeds is None: if self.embed_tokens is None: raise ValueError("You have to initialize the model with valid token embeddings") @@ -108,7 +109,8 @@ def t5_stack_forward( else: if hidden_states is None: raise ValueError( - "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder.") + "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder." + ) input_shape = hidden_states.size()[:-1] batch_size, seq_length = input_shape[0], input_shape[1] device = hidden_states.device @@ -153,7 +155,6 @@ def t5_stack_forward( start_idx, end_idx = stage_index[0], stage_index[1] for i in range(start_idx, end_idx): - past_key_value = past_key_values[i] layer_module = self.block[i] layer_head_mask = head_mask[i] @@ -163,7 +164,6 @@ def t5_stack_forward( if self.gradient_checkpointing and self.training: def create_custom_forward(module): - def custom_forward(*inputs): return tuple(module(*inputs, use_cache, output_attentions)) @@ -179,7 +179,7 @@ def custom_forward(*inputs): encoder_decoder_position_bias, layer_head_mask, cross_attn_layer_head_mask, - None, # past_key_value is always None with gradient checkpointing + None, # past_key_value is always None with gradient checkpointing ) else: layer_outputs = layer_module( @@ -220,13 +220,17 @@ def custom_forward(*inputs): hidden_states = self.dropout(hidden_states) if not return_dict: - return tuple(v for v in [ - hidden_states, - present_key_value_states, - all_hidden_states, - all_attentions, - all_cross_attentions, - ] if v is not None) + return tuple( + v + for v in [ + hidden_states, + present_key_value_states, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=present_key_value_states, @@ -236,10 +240,10 @@ def custom_forward(*inputs): ) else: return { - 'hidden_states': hidden_states, - 'position_bias': position_bias, - 'encoder_decoder_position_bias': encoder_decoder_position_bias, - 'backward_tensor_keys': ['hidden_states'] + "hidden_states": hidden_states, + "position_bias": position_bias, + "encoder_decoder_position_bias": encoder_decoder_position_bias, + "backward_tensor_keys": ["hidden_states"], } @staticmethod @@ -269,7 +273,6 @@ def t5_model_forward( stage_index: Optional[List[int]] = None, decoder_starting_stage: Optional[int] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: - # This function is modified on the basis of transformers.models.t5.modeling_t5.T5Model.forward. # Please refer to original code of transformers for more details. @@ -287,16 +290,16 @@ def t5_model_forward( # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future. 
if past_key_values: - logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') + logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.") past_key_values = None if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask @@ -322,10 +325,11 @@ def t5_model_forward( position_bias=position_bias, encoder_decoder_position_bias=encoder_decoder_position_bias, stage_index=stage_index, - decoder_starting_stage=decoder_starting_stage) + decoder_starting_stage=decoder_starting_stage, + ) if stage_manager.stage == decoder_starting_stage - 1: # last stage of encoder - return {'encoder_hidden_states': encoder_outputs[0]} + return {"encoder_hidden_states": encoder_outputs[0]} else: return encoder_outputs @@ -360,23 +364,26 @@ def t5_model_forward( position_bias=position_bias, encoder_decoder_position_bias=encoder_decoder_position_bias, stage_index=stage_index, - decoder_starting_stage=decoder_starting_stage) + decoder_starting_stage=decoder_starting_stage, + ) # Directly return outputs of overloaded T5Stack forward if not at last stage. if not at_last_decoder_stage: # encoder_hidden_states should be passed to the next stage - decoder_outputs['encoder_hidden_states'] = encoder_hidden_states + decoder_outputs["encoder_hidden_states"] = encoder_hidden_states return decoder_outputs if not return_dict: return decoder_outputs + encoder_hidden_states else: - return Seq2SeqModelOutput(last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_hidden_states) + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_hidden_states, + ) @staticmethod def t5_for_conditional_generation_forward( @@ -406,7 +413,6 @@ def t5_for_conditional_generation_forward( stage_index: Optional[List[int]] = None, decoder_starting_stage: Optional[int] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: - # This function is modified on the basis of transformers.models.t5.modeling_t5.T5ForConditionalGeneration.forward. # Please refer to original code of transformers for more details. @@ -424,16 +430,16 @@ def t5_for_conditional_generation_forward( # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future. 
if past_key_values: - logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') + logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.") past_key_values = None if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask @@ -460,10 +466,11 @@ def t5_for_conditional_generation_forward( position_bias=position_bias, encoder_decoder_position_bias=encoder_decoder_position_bias, stage_index=stage_index, - decoder_starting_stage=decoder_starting_stage) + decoder_starting_stage=decoder_starting_stage, + ) if stage_manager.stage == decoder_starting_stage - 1: # last stage of encoder - return {'encoder_hidden_states': encoder_outputs[0]} + return {"encoder_hidden_states": encoder_outputs[0]} else: return encoder_outputs @@ -502,12 +509,13 @@ def t5_for_conditional_generation_forward( position_bias=position_bias, encoder_decoder_position_bias=encoder_decoder_position_bias, stage_index=stage_index, - decoder_starting_stage=decoder_starting_stage) + decoder_starting_stage=decoder_starting_stage, + ) # Directly return outputs of overloaded T5Stack forward if not at last stage. 
if not at_last_decoder_stage: # encoder_hidden_states should be passed to the next stage - decoder_outputs['encoder_hidden_states'] = encoder_hidden_states + decoder_outputs["encoder_hidden_states"] = encoder_hidden_states return decoder_outputs sequence_output = decoder_outputs[0] @@ -530,13 +538,15 @@ def t5_for_conditional_generation_forward( output = (lm_logits,) + decoder_outputs[1:] + encoder_hidden_states return ((loss,) + output) if loss is not None else output - return Seq2SeqLMOutput(loss=loss, - logits=lm_logits, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_hidden_states) + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_hidden_states, + ) @staticmethod def t5_encoder_model_forward( @@ -562,26 +572,27 @@ def t5_encoder_model_forward( ```""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = T5PipelineForwards.t5_stack_forward(self.encoder, - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - position_bias=position_bias, - encoder_decoder_position_bias=encoder_decoder_position_bias, - stage_index=stage_index, - decoder_starting_stage=decoder_starting_stage) + outputs = T5PipelineForwards.t5_stack_forward( + self.encoder, + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + position_bias=position_bias, + encoder_decoder_position_bias=encoder_decoder_position_bias, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage, + ) return outputs def get_t5_flash_attention_forward(): - try: from xformers.ops import memory_efficient_attention as me_attention except: @@ -655,19 +666,21 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): return hidden_states # get query states - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) # get key/value states - key_states = project(hidden_states, self.k, key_value_states, - past_key_value[0] if past_key_value is not None else None) - value_states = project(hidden_states, self.v, key_value_states, - past_key_value[1] if past_key_value is not None else None) + key_states = project( + hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None + ) + value_states = project( + hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None + ) if position_bias is None: if not self.has_relative_attention_bias: - position_bias = torch.zeros((1, self.n_heads, real_seq_length, key_length), - device=query_states.device, - dtype=query_states.dtype) + position_bias 
= torch.zeros( + (1, self.n_heads, real_seq_length, key_length), device=query_states.device, dtype=query_states.dtype + ) if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: @@ -676,10 +689,10 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): # if key and values are already calculated # we want only the last query position bias if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1):, :] + position_bias = position_bias[:, :, -hidden_states.size(1) :, :] if mask is not None: - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) if self.pruned_heads: mask = torch.ones(position_bias.shape[1]) @@ -689,12 +702,9 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): position_bias_masked = position_bias position_bias_masked = position_bias_masked.contiguous() - attn_output = me_attention(query_states, - key_states, - value_states, - attn_bias=position_bias_masked, - p=self.dropout, - scale=1.0) + attn_output = me_attention( + query_states, key_states, value_states, attn_bias=position_bias_masked, p=self.dropout, scale=1.0 + ) attn_output = unshape(attn_output) attn_output = self.o(attn_output) @@ -708,7 +718,6 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): def get_jit_fused_T5_layer_ff_forward(): - from transformers.models.t5.modeling_t5 import T5LayerFF def forward(self: T5LayerFF, hidden_states: torch.Tensor) -> torch.Tensor: @@ -721,7 +730,6 @@ def forward(self: T5LayerFF, hidden_states: torch.Tensor) -> torch.Tensor: def get_T5_layer_self_attention_forward(): - from transformers.models.t5.modeling_t5 import T5LayerSelfAttention def forward( @@ -745,14 +753,13 @@ def forward( output_attentions=output_attentions, ) hidden_states = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training) - outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them return outputs return forward def get_T5_layer_cross_attention_forward(): - from transformers.models.t5.modeling_t5 import T5LayerCrossAttention def forward( @@ -780,7 +787,7 @@ def forward( output_attentions=output_attentions, ) layer_output = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training) - outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them return outputs return forward diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py index 2ce52163ac32..2db83b912112 100644 --- a/colossalai/shardformer/modeling/vit.py +++ b/colossalai/shardformer/modeling/vit.py @@ -1,5 +1,5 @@ import math -from typing import Dict, List, Optional, Set, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from transformers.models.vit.modeling_vit import BaseModelOutput, ViTEncoder @@ -17,7 +17,6 @@ def _encoder_forward( return_dict: bool = True, stage_manager: PipelineStageManager = None, ) -> Union[tuple, BaseModelOutput]: - for i in range(start_idx, end_idx): layer_module = encoder.layer[i] @@ -26,7 +25,6 @@ def _encoder_forward( if encoder.gradient_checkpointing and encoder.training: def create_custom_forward(module): - def 
custom_forward(*inputs): return module(*inputs, False) @@ -54,7 +52,6 @@ def custom_forward(*inputs): def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]): - from transformers.models.vit.modeling_vit import BaseModelOutputWithPooling def pp_forward( @@ -69,19 +66,19 @@ def pp_forward( hidden_states: Optional[torch.FloatTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" - bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): - Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). - """ + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict logger = logging.get_logger(__name__) # Preprocess passed in arguments if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False # Prepare head mask if needed @@ -100,11 +97,13 @@ def pp_forward( if pixel_values.dtype != expected_dtype: pixel_values = pixel_values.to(expected_dtype) - embedding_output = self.embeddings(pixel_values, - bool_masked_pos=bool_masked_pos, - interpolate_pos_encoding=interpolate_pos_encoding) + embedding_output = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding + ) else: - assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None" + assert ( + hidden_states is not None + ), f"Current stage is {stage_manager.stage}, hidden_states should not be None" # Go through encoder if not stage_manager.is_last_stage(): @@ -117,7 +116,7 @@ def pp_forward( return_dict=return_dict, stage_manager=stage_manager, ) - return {'hidden_states': hidden_states} + return {"hidden_states": hidden_states} else: encoder_outputs = _encoder_forward( encoder=self.encoder, @@ -149,7 +148,6 @@ def pp_forward( def ViTForImageClassification_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]): - from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.models.vit.modeling_vit import ImageClassifierOutput @@ -173,7 +171,9 @@ def pp_forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict if not stage_manager.is_first_stage(): - assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None" + assert ( + hidden_states is not None + ), f"Current stage is {stage_manager.stage}, hidden_states should not be None" outputs = self.vit( pixel_values, @@ -234,7 +234,6 @@ def pp_forward( def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]): - import math import torch.nn as nn @@ -286,19 +285,24 @@ def pp_forward( raise ValueError( "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that " "the reconstructed image has the same dimensions as the input." 
- f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}.") + f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}." + ) if not stage_manager.is_first_stage(): - assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None" - - outputs = self.vit(pixel_values, - bool_masked_pos=bool_masked_pos, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - interpolate_pos_encoding=interpolate_pos_encoding, - return_dict=return_dict, - hidden_states=hidden_states) + assert ( + hidden_states is not None + ), f"Current stage is {stage_manager.stage}, hidden_states should not be None" + + outputs = self.vit( + pixel_values, + bool_masked_pos=bool_masked_pos, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + hidden_states=hidden_states, + ) if not stage_manager.is_last_stage(): return outputs else: @@ -317,9 +321,12 @@ def pp_forward( if bool_masked_pos is not None: size = self.config.image_size // self.config.patch_size bool_masked_pos = bool_masked_pos.reshape(-1, size, size) - mask = (bool_masked_pos.repeat_interleave(self.config.patch_size, - 1).repeat_interleave(self.config.patch_size, - 2).unsqueeze(1).contiguous()) + mask = ( + bool_masked_pos.repeat_interleave(self.config.patch_size, 1) + .repeat_interleave(self.config.patch_size, 2) + .unsqueeze(1) + .contiguous() + ) reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none") masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels @@ -338,7 +345,6 @@ def pp_forward( def get_vit_flash_self_attention_forward(): - from transformers.models.vit.modeling_vit import ViTSelfAttention from colossalai.kernel.cuda_native import ColoAttention @@ -348,22 +354,24 @@ def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_si x = x.view(new_x_shape) return x - def forward(self: ViTSelfAttention, - hidden_states: torch.Tensor, - head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + def forward( + self: ViTSelfAttention, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: mixed_query_layer = self.query(hidden_states) key_layer = transpose_for_scores(self.key(hidden_states), self.num_attention_heads, self.attention_head_size) - value_layer = transpose_for_scores(self.value(hidden_states), self.num_attention_heads, - self.attention_head_size) + value_layer = transpose_for_scores( + self.value(hidden_states), self.num_attention_heads, self.attention_head_size + ) query_layer = transpose_for_scores(mixed_query_layer, self.num_attention_heads, self.attention_head_size) scale = 1.0 / math.sqrt(self.attention_head_size) - attention = ColoAttention(embed_dim=self.all_head_size, - num_heads=self.num_attention_heads, - dropout=self.dropout.p, - scale=scale) + attention = ColoAttention( + embed_dim=self.all_head_size, num_heads=self.num_attention_heads, dropout=self.dropout.p, scale=scale + ) context_layer = attention(query_layer, key_layer, value_layer) outputs = (context_layer,) @@ -374,7 +382,6 @@ def forward(self: 
ViTSelfAttention, def get_jit_fused_vit_output_forward(): - from transformers.models.vit.modeling_vit import ViTOutput def forward(self: ViTOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py index 62f8f7b4763e..ef59dbcee680 100644 --- a/colossalai/shardformer/modeling/whisper.py +++ b/colossalai/shardformer/modeling/whisper.py @@ -1,6 +1,6 @@ import logging import random -from typing import Dict, List, Optional, Set, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from torch import nn @@ -24,7 +24,6 @@ def get_whisper_flash_attention_forward(): - from transformers.models.whisper.modeling_whisper import WhisperAttention from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention @@ -53,8 +52,11 @@ def forward( # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as # the provided `key_value_states` to support prefix tuning - if (is_cross_attention and past_key_value is not None - and past_key_value[0].shape[1] == key_value_states.shape[1]): + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[1] == key_value_states.shape[1] + ): # reuse k,v, cross_attentions key_states = past_key_value[0] value_states = past_key_value[1] @@ -89,8 +91,10 @@ def forward( src_len = key_states.size(1) if layer_head_mask is not None: if layer_head_mask.size() != (self.num_heads,): - raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}") + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) attn_type = None flash_attention_mask = None @@ -104,15 +108,12 @@ def forward( flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool).contiguous()) attn_type = AttnMaskType.paddedcausal - attention = ColoAttention(embed_dim=self.embed_dim, - num_heads=self.num_heads, - dropout=self.dropout, - scale=self.scaling) - attn_output = attention(query_states, - key_states, - value_states, - attn_mask=flash_attention_mask, - attn_mask_type=attn_type) + attention = ColoAttention( + embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling + ) + attn_output = attention( + query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_type + ) attn_output = self.out_proj(attn_output) @@ -122,7 +123,6 @@ def forward( def get_jit_fused_whisper_encoder_layer_forward(): - from transformers.models.whisper.modeling_whisper import WhisperEncoderLayer def forward( @@ -160,8 +160,9 @@ def forward( hidden_states = self.fc2(hidden_states) hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training) - if hidden_states.dtype == torch.float16 and (torch.isinf(hidden_states).any() - or torch.isnan(hidden_states).any()): + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) @@ -176,7 +177,6 @@ def forward( def get_jit_fused_whisper_decoder_layer_forward(): - from transformers.models.whisper.modeling_whisper import WhisperDecoderLayer def forward( @@ -269,10 +269,10 @@ def forward( class 
WhisperPipelineForwards: - ''' + """ This class serves as a micro library for forward function substitution of Llama models under pipeline setting. - ''' + """ @staticmethod def whisper_encoder_forward( @@ -315,15 +315,16 @@ def whisper_encoder_forward( return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ - logger = logging.get_logger(__name__) + logging.get_logger(__name__) stage = stage_manager.stage - at_first_stage = (stage == 0) - at_last_stage = (stage == decoder_starting_stage - 1) + at_first_stage = stage == 0 + at_last_stage = stage == decoder_starting_stage - 1 output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Process inputs if at the first stage of encoder. @@ -349,7 +350,8 @@ def whisper_encoder_forward( else: if hidden_states is None: raise ValueError( - "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder.") + "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder." + ) start_idx, end_idx = stage_index[0], stage_index[1] @@ -360,13 +362,12 @@ def whisper_encoder_forward( encoder_states = encoder_states + (hidden_states,) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) dropout_probability = random.uniform(0, 1) - if self.training and (dropout_probability < self.layerdrop): # skip the layer + if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: def create_custom_forward(module): - def custom_forward(*inputs): return module(*inputs, output_attentions) @@ -398,12 +399,12 @@ def custom_forward(*inputs): if not return_dict: return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) - return BaseModelOutput(last_hidden_state=hidden_states, - hidden_states=encoder_states, - attentions=all_attentions) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) else: - return {'hidden_states': hidden_states, 'head_mask': head_mask} + return {"hidden_states": hidden_states, "head_mask": head_mask} @staticmethod def whisper_decoder_forward( @@ -483,12 +484,13 @@ def whisper_decoder_forward( """ logger = logging.get_logger(__name__) stage = stage_manager.stage - at_first_stage = (stage == decoder_starting_stage) - at_last_stage = (stage == stage_manager.num_stages - 1) + at_first_stage = stage == decoder_starting_stage + at_last_stage = stage == stage_manager.num_stages - 1 output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -503,7 +505,8 @@ def whisper_decoder_forward( if attn_mask is not 
None: assert attn_mask.size()[0] == (len(self.layers)), ( f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" - f" {head_mask.size()[0]}.") + f" {head_mask.size()[0]}." + ) # past_key_values_length past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 @@ -529,8 +532,9 @@ def whisper_decoder_forward( else: positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length) - attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, - past_key_values_length) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) hidden_states = inputs_embeds + positions hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -543,14 +547,15 @@ def whisper_decoder_forward( use_cache = False else: - if hidden_states is None: raise ValueError( - "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder.") + "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder." + ) input_shape = hidden_states.size()[:-1] - attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, hidden_states, - past_key_values_length) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, hidden_states, past_key_values_length + ) start_idx, end_idx = stage_index[0], stage_index[1] @@ -569,7 +574,6 @@ def whisper_decoder_forward( if self.gradient_checkpointing and self.training: def create_custom_forward(module): - def custom_forward(*inputs): # None for past_key_value return module(*inputs, output_attentions, use_cache) @@ -581,10 +585,10 @@ def custom_forward(*inputs): hidden_states, attention_mask, encoder_hidden_states, - None, # encoder attention mask + None, # encoder attention mask head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, # past_key_value + None, # past_key_value ) else: layer_outputs = decoder_layer( @@ -592,8 +596,9 @@ def custom_forward(*inputs): attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=(cross_attn_head_mask[idx] - if cross_attn_head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, @@ -617,8 +622,10 @@ def custom_forward(*inputs): next_cache = next_decoder_cache if use_cache else None if not return_dict: return tuple( - v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] - if v is not None) + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -629,9 +636,9 @@ def custom_forward(*inputs): else: return { - 'head_mask': head_mask, - 'cross_attn_head_mask': cross_attn_head_mask, - 'hidden_states': hidden_states, + "head_mask": head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "hidden_states": hidden_states, } @staticmethod @@ -678,23 +685,24 @@ def whisper_model_forward( ```""" # TODO: left the recording kv-value tensors as () or 
None type, this feature may be added in the future. if past_key_values: - logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') + logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.") past_key_values = None if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False - logger = logging.get_logger(__name__) + logging.get_logger(__name__) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict in_decoder = stage_manager.stage >= decoder_starting_stage @@ -712,14 +720,15 @@ def whisper_model_forward( stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, - decoder_starting_stage=decoder_starting_stage) + decoder_starting_stage=decoder_starting_stage, + ) if stage_manager.stage == decoder_starting_stage - 1: # last stage of encoder - return {'encoder_hidden_states': encoder_outputs[0]} + return {"encoder_hidden_states": encoder_outputs[0]} else: return encoder_outputs - # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): encoder_outputs = BaseModelOutput( last_hidden_state=encoder_outputs[0], @@ -738,27 +747,29 @@ def whisper_model_forward( raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.") # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) - decoder_outputs = WhisperPipelineForwards.whisper_decoder_forward(self.decoder, - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - encoder_hidden_states=encoder_hidden_states, - head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - decoder_starting_stage=decoder_starting_stage) + decoder_outputs = WhisperPipelineForwards.whisper_decoder_forward( + self.decoder, + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + 
head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage, + ) # Directly return outputs of overloaded Whisper forward if not at last stage. if not at_last_decoder_stage: # encoder_hidden_states should be passed to the next stage - decoder_outputs['encoder_hidden_states'] = encoder_hidden_states + decoder_outputs["encoder_hidden_states"] = encoder_hidden_states return decoder_outputs if not return_dict: @@ -830,36 +841,39 @@ def whisper_for_conditional_generation_forward( if labels is not None: if decoder_input_ids is None and decoder_inputs_embeds is None: - decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id, - self.config.decoder_start_token_id) + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) in_decoder = stage_manager.stage >= decoder_starting_stage at_last_decoder_stage = stage_manager.is_last_stage() - outputs = WhisperPipelineForwards.whisper_model_forward(self.model, - input_features, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - encoder_outputs=encoder_outputs, - decoder_attention_mask=decoder_attention_mask, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - decoder_inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - stage_index=stage_index, - decoder_starting_stage=decoder_starting_stage) + outputs = WhisperPipelineForwards.whisper_model_forward( + self.model, + input_features, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage, + ) if not in_decoder: return outputs if not at_last_decoder_stage: # encoder_hidden_states should be passed to the next stage - outputs['encoder_hidden_states'] = encoder_hidden_states + outputs["encoder_hidden_states"] = encoder_hidden_states return outputs lm_logits = self.proj_out(outputs[0]) @@ -909,8 +923,9 @@ def whisper_for_audio_classification_forward( Please refer to original code of transformers for more details. 
""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # audio_classification only holds encoder diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 49613ffb37e0..3bea91ef94dc 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -18,6 +18,7 @@ class PolicyLocation: file_name (str): The file name of the policy under colossalai.shardformer.policies class_name (str): The class name of the policy class """ + file_name: str class_name: str @@ -27,121 +28,142 @@ class PolicyLocation: # we will allow the user to only import the policy file needed _POLICY_LIST = { # BERT - "transformers.models.bert.modeling_bert.BertModel": - PolicyLocation(file_name="bert", class_name="BertModelPolicy"), - "transformers.models.bert.modeling_bert.BertForPreTraining": - PolicyLocation(file_name="bert", class_name="BertForPreTrainingPolicy"), - "transformers.models.bert.modeling_bert.BertLMHeadModel": - PolicyLocation(file_name="bert", class_name="BertLMHeadModelPolicy"), - "transformers.models.bert.modeling_bert.BertForMaskedLM": - PolicyLocation(file_name="bert", class_name="BertForMaskedLMPolicy"), - "transformers.models.bert.modeling_bert.BertForSequenceClassification": - PolicyLocation(file_name="bert", class_name="BertForSequenceClassificationPolicy"), - "transformers.models.bert.modeling_bert.BertForTokenClassification": - PolicyLocation(file_name="bert", class_name="BertForTokenClassificationPolicy"), - "transformers.models.bert.modeling_bert.BertForNextSentencePrediction": - PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"), - "transformers.models.bert.modeling_bert.BertForMultipleChoice": - PolicyLocation(file_name="bert", class_name="BertForMultipleChoicePolicy"), - "transformers.models.bert.modeling_bert.BertForQuestionAnswering": - PolicyLocation(file_name="bert", class_name="BertForQuestionAnsweringPolicy"), - + "transformers.models.bert.modeling_bert.BertModel": PolicyLocation(file_name="bert", class_name="BertModelPolicy"), + "transformers.models.bert.modeling_bert.BertForPreTraining": PolicyLocation( + file_name="bert", class_name="BertForPreTrainingPolicy" + ), + "transformers.models.bert.modeling_bert.BertLMHeadModel": PolicyLocation( + file_name="bert", class_name="BertLMHeadModelPolicy" + ), + "transformers.models.bert.modeling_bert.BertForMaskedLM": PolicyLocation( + file_name="bert", class_name="BertForMaskedLMPolicy" + ), + "transformers.models.bert.modeling_bert.BertForSequenceClassification": PolicyLocation( + file_name="bert", class_name="BertForSequenceClassificationPolicy" + ), + "transformers.models.bert.modeling_bert.BertForTokenClassification": PolicyLocation( + file_name="bert", class_name="BertForTokenClassificationPolicy" + ), + "transformers.models.bert.modeling_bert.BertForNextSentencePrediction": PolicyLocation( + file_name="bert", class_name="BertForNextSentencePredictionPolicy" + ), + "transformers.models.bert.modeling_bert.BertForMultipleChoice": PolicyLocation( + file_name="bert", class_name="BertForMultipleChoicePolicy" + ), + 
"transformers.models.bert.modeling_bert.BertForQuestionAnswering": PolicyLocation( + file_name="bert", class_name="BertForQuestionAnsweringPolicy" + ), # LLaMA - "transformers.models.llama.modeling_llama.LlamaModel": - PolicyLocation(file_name="llama", class_name="LlamaModelPolicy"), - "transformers.models.llama.modeling_llama.LlamaForCausalLM": - PolicyLocation(file_name="llama", class_name="LlamaForCausalLMPolicy"), - "transformers.models.llama.modeling_llama.LlamaForSequenceClassification": - PolicyLocation(file_name="llama", class_name="LlamaForSequenceClassificationPolicy"), - + "transformers.models.llama.modeling_llama.LlamaModel": PolicyLocation( + file_name="llama", class_name="LlamaModelPolicy" + ), + "transformers.models.llama.modeling_llama.LlamaForCausalLM": PolicyLocation( + file_name="llama", class_name="LlamaForCausalLMPolicy" + ), + "transformers.models.llama.modeling_llama.LlamaForSequenceClassification": PolicyLocation( + file_name="llama", class_name="LlamaForSequenceClassificationPolicy" + ), # T5 - "transformers.models.t5.modeling_t5.T5Model": - PolicyLocation(file_name="t5", class_name="T5ModelPolicy"), - "transformers.models.t5.modeling_t5.T5ForConditionalGeneration": - PolicyLocation(file_name="t5", class_name="T5ForConditionalGenerationPolicy"), - "transformers.models.t5.modeling_t5.T5EncoderModel": - PolicyLocation(file_name="t5", class_name="T5EncoderPolicy"), - + "transformers.models.t5.modeling_t5.T5Model": PolicyLocation(file_name="t5", class_name="T5ModelPolicy"), + "transformers.models.t5.modeling_t5.T5ForConditionalGeneration": PolicyLocation( + file_name="t5", class_name="T5ForConditionalGenerationPolicy" + ), + "transformers.models.t5.modeling_t5.T5EncoderModel": PolicyLocation(file_name="t5", class_name="T5EncoderPolicy"), # GPT2 - "transformers.models.gpt2.modeling_gpt2.GPT2Model": - PolicyLocation(file_name="gpt2", class_name="GPT2ModelPolicy"), - "transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel": - PolicyLocation(file_name="gpt2", class_name="GPT2LMHeadModelPolicy"), - "transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel": - PolicyLocation(file_name="gpt2", class_name="GPT2DoubleHeadsModelPolicy"), - "transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering": - PolicyLocation(file_name="gpt2", class_name="GPT2ForQuestionAnsweringPolicy"), - "transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification": - PolicyLocation(file_name="gpt2", class_name="GPT2ForTokenClassificationPolicy"), - "transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification": - PolicyLocation(file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy"), - + "transformers.models.gpt2.modeling_gpt2.GPT2Model": PolicyLocation(file_name="gpt2", class_name="GPT2ModelPolicy"), + "transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel": PolicyLocation( + file_name="gpt2", class_name="GPT2LMHeadModelPolicy" + ), + "transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel": PolicyLocation( + file_name="gpt2", class_name="GPT2DoubleHeadsModelPolicy" + ), + "transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering": PolicyLocation( + file_name="gpt2", class_name="GPT2ForQuestionAnsweringPolicy" + ), + "transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification": PolicyLocation( + file_name="gpt2", class_name="GPT2ForTokenClassificationPolicy" + ), + "transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification": PolicyLocation( + file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy" + ), # ViT - 
"transformers.models.vit.modeling_vit.ViTModel": - PolicyLocation(file_name="vit", class_name="ViTModelPolicy"), - "transformers.models.vit.modeling_vit.ViTForImageClassification": - PolicyLocation(file_name="vit", class_name="ViTForImageClassificationPolicy"), - "transformers.models.vit.modeling_vit.ViTForMaskedImageModeling": - PolicyLocation(file_name="vit", class_name="ViTForMaskedImageModelingPolicy"), - + "transformers.models.vit.modeling_vit.ViTModel": PolicyLocation(file_name="vit", class_name="ViTModelPolicy"), + "transformers.models.vit.modeling_vit.ViTForImageClassification": PolicyLocation( + file_name="vit", class_name="ViTForImageClassificationPolicy" + ), + "transformers.models.vit.modeling_vit.ViTForMaskedImageModeling": PolicyLocation( + file_name="vit", class_name="ViTForMaskedImageModelingPolicy" + ), # OPT - "transformers.models.opt.modeling_opt.OPTModel": - PolicyLocation(file_name="opt", class_name="OPTModelPolicy"), - "transformers.models.opt.modeling_opt.OPTForCausalLM": - PolicyLocation(file_name="opt", class_name="OPTForCausalLMPolicy"), - "transformers.models.opt.modeling_opt.OPTForSequenceClassification": - PolicyLocation(file_name="opt", class_name="OPTForSequenceClassificationPolicy"), - "transformers.models.opt.modeling_opt.OPTForQuestionAnswering": - PolicyLocation(file_name="opt", class_name="OPTForQuestionAnsweringPolicy"), - + "transformers.models.opt.modeling_opt.OPTModel": PolicyLocation(file_name="opt", class_name="OPTModelPolicy"), + "transformers.models.opt.modeling_opt.OPTForCausalLM": PolicyLocation( + file_name="opt", class_name="OPTForCausalLMPolicy" + ), + "transformers.models.opt.modeling_opt.OPTForSequenceClassification": PolicyLocation( + file_name="opt", class_name="OPTForSequenceClassificationPolicy" + ), + "transformers.models.opt.modeling_opt.OPTForQuestionAnswering": PolicyLocation( + file_name="opt", class_name="OPTForQuestionAnsweringPolicy" + ), # Bloom - "transformers.models.bloom.modeling_bloom.BloomModel": - PolicyLocation(file_name="bloom", class_name="BloomModelPolicy"), - "transformers.models.bloom.modeling_bloom.BloomForCausalLM": - PolicyLocation(file_name="bloom", class_name="BloomForCausalLMPolicy"), - "transformers.models.bloom.modeling_bloom.BloomForSequenceClassification": - PolicyLocation(file_name="bloom", class_name="BloomForSequenceClassificationPolicy"), - "transformers.models.bloom.modeling_bloom.BloomForTokenClassification": - PolicyLocation(file_name="bloom", class_name="BloomForTokenClassificationPolicy"), - "transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering": - PolicyLocation(file_name="bloom", class_name="BloomForQuestionAnsweringPolicy"), - + "transformers.models.bloom.modeling_bloom.BloomModel": PolicyLocation( + file_name="bloom", class_name="BloomModelPolicy" + ), + "transformers.models.bloom.modeling_bloom.BloomForCausalLM": PolicyLocation( + file_name="bloom", class_name="BloomForCausalLMPolicy" + ), + "transformers.models.bloom.modeling_bloom.BloomForSequenceClassification": PolicyLocation( + file_name="bloom", class_name="BloomForSequenceClassificationPolicy" + ), + "transformers.models.bloom.modeling_bloom.BloomForTokenClassification": PolicyLocation( + file_name="bloom", class_name="BloomForTokenClassificationPolicy" + ), + "transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering": PolicyLocation( + file_name="bloom", class_name="BloomForQuestionAnsweringPolicy" + ), # Whisper - "transformers.models.whisper.modeling_whisper.WhisperModel": - 
PolicyLocation(file_name="whisper", class_name="WhisperModelPolicy"), - "transformers.models.whisper.modeling_whisper.WhisperForConditionalGeneration": - PolicyLocation(file_name="whisper", class_name="WhisperForConditionalGenerationPolicy"), - "transformers.models.whisper.modeling_whisper.WhisperForAudioClassification": - PolicyLocation(file_name="whisper", class_name="WhisperForAudioClassificationPolicy"), - + "transformers.models.whisper.modeling_whisper.WhisperModel": PolicyLocation( + file_name="whisper", class_name="WhisperModelPolicy" + ), + "transformers.models.whisper.modeling_whisper.WhisperForConditionalGeneration": PolicyLocation( + file_name="whisper", class_name="WhisperForConditionalGenerationPolicy" + ), + "transformers.models.whisper.modeling_whisper.WhisperForAudioClassification": PolicyLocation( + file_name="whisper", class_name="WhisperForAudioClassificationPolicy" + ), # Sam - "transformers.models.sam.modeling_sam.SamModel": - PolicyLocation(file_name="sam", class_name="SamModelPolicy"), - + "transformers.models.sam.modeling_sam.SamModel": PolicyLocation(file_name="sam", class_name="SamModelPolicy"), # Blip2 - "transformers.models.blip_2.modeling_blip_2.Blip2Model": - PolicyLocation(file_name="blip2", class_name="Blip2ModelPolicy"), - "transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGeneration": - PolicyLocation(file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy"), - + "transformers.models.blip_2.modeling_blip_2.Blip2Model": PolicyLocation( + file_name="blip2", class_name="Blip2ModelPolicy" + ), + "transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGeneration": PolicyLocation( + file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy" + ), # ChatGLM - "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": - PolicyLocation(file_name="chatglm2", class_name="ChatGLMModelPolicy"), - "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": - PolicyLocation(file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"), + "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": PolicyLocation( + file_name="chatglm2", class_name="ChatGLMModelPolicy" + ), + "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation( + file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy" + ), } _INFER_POLICY_LIST = { # LlaMa - "transformers.models.llama.modeling_llama.LlamaModel": - PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy"), - "transformers.models.llama.modeling_llama.LlamaForCausalLM": - PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy"), + "transformers.models.llama.modeling_llama.LlamaModel": PolicyLocation( + file_name="llama", class_name="LlamaModelInferPolicy" + ), + "transformers.models.llama.modeling_llama.LlamaForCausalLM": PolicyLocation( + file_name="llama", class_name="LlamaModelInferPolicy" + ), # Bloom - "transformers.models.bloom.modeling_bloom.BloomModel": - PolicyLocation(file_name="bloom", class_name="BloomModelInferPolicy"), - "transformers.models.bloom.modeling_bloom.BloomForCausalLM": - PolicyLocation(file_name="bloom", class_name="BloomModelInferPolicy"), + "transformers.models.bloom.modeling_bloom.BloomModel": PolicyLocation( + file_name="bloom", class_name="BloomModelInferPolicy" + ), + "transformers.models.bloom.modeling_bloom.BloomForCausalLM": PolicyLocation( + file_name="bloom", 
class_name="BloomModelInferPolicy" + ), } @@ -163,9 +185,9 @@ def _fullname(obj): """ klass = obj.__class__ module = klass.__module__ - if module == 'builtins': - return klass.__qualname__ # avoid outputs like 'builtins.str' - return module + '.' + klass.__qualname__ + if module == "builtins": + return klass.__qualname__ # avoid outputs like 'builtins.str' + return module + "." + klass.__qualname__ def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) -> Policy: diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 961c6a5259fe..e7f199129a00 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -106,14 +106,12 @@ def config_sanity_check(self): This method is made abstractmethod with no default implementation because we want to the policy writer to take note of the feature supported by his/her model and policy. """ - pass @abstractmethod def preprocess(self) -> nn.Module: r""" Perform some preprocessing of the model, like reshaping the embedding layer. """ - pass @abstractmethod def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: @@ -122,7 +120,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: and the value is the ModulePolicyDescription object. The ModulePolicyDescription object describes how the module will be transformed. """ - pass @abstractmethod def postprocess(self) -> nn.Module: @@ -130,13 +127,13 @@ def postprocess(self) -> nn.Module: Perform some postprocessing of the model, like binding the weight of embedding layer with the classifier layer """ - pass def append_or_create_submodule_replacement( - self, description: Union[SubModuleReplacementDescription, - List[SubModuleReplacementDescription]], policy: Dict[Union[str, nn.Module], - ModulePolicyDescription], - target_key: Union[str, nn.Module]) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + self, + description: Union[SubModuleReplacementDescription, List[SubModuleReplacementDescription]], + policy: Dict[Union[str, nn.Module], ModulePolicyDescription], + target_key: Union[str, nn.Module], + ) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: r""" Append or create a new submodule replacement description to the policy for the given key. @@ -161,8 +158,11 @@ def append_or_create_submodule_replacement( return policy def append_or_create_method_replacement( - self, description: Dict[str, Callable], policy: Dict[Union[str, nn.Module], ModulePolicyDescription], - target_key: Union[str, nn.Module]) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + self, + description: Dict[str, Callable], + policy: Dict[Union[str, nn.Module], ModulePolicyDescription], + target_key: Union[str, nn.Module], + ) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: r""" Append or create a new method replacement description to the policy for the given key. 
@@ -199,9 +199,7 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: @staticmethod def distribute_layers(num_layers: int, num_stages: int) -> List[int]: - """Divide layers into stages - - """ + """Divide layers into stages""" quotient = num_layers // num_stages remainder = num_layers % num_stages diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index a141b7bd8fdf..14146de158ae 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -7,7 +7,6 @@ import colossalai.shardformer.layer as col_nn -from .._utils import getattr_, setattr_ from ..modeling.bert import ( BertPipelineForwards, bert_sequence_parallel_forward_fn, @@ -19,14 +18,20 @@ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ - 'BertPolicy', 'BertModelPolicy', 'BertForPreTrainingPolicy', 'BertLMdHeadModelPolicy', 'BertForMaskedLMPolicy', - 'BertForNextSentencePredictionPolicy', 'BertForSequenceClassificationPolicy', 'BertForTokenClassificationPolicy', - 'BertForMultipleChoicePolicy', 'BertForQuestionAnsweringPolicy' + "BertPolicy", + "BertModelPolicy", + "BertForPreTrainingPolicy", + "BertLMdHeadModelPolicy", + "BertForMaskedLMPolicy", + "BertForNextSentencePredictionPolicy", + "BertForSequenceClassificationPolicy", + "BertForTokenClassificationPolicy", + "BertForMultipleChoicePolicy", + "BertForQuestionAnsweringPolicy", ] class BertPolicy(Policy): - def config_sanity_check(self): pass @@ -58,136 +63,140 @@ def module_policy(self): use_sequence_parallel = self.shard_config.enable_sequence_parallelism overlap = self.shard_config.enable_sequence_overlap if self.shard_config.enable_tensor_parallelism: - policy[BertLayer] = ModulePolicyDescription(attribute_replacement={ - "attention.self.all_head_size": - self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "crossattention.self.all_head_size": - self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "attention.self.num_attention_heads": - self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - "crossattention.self.num_attention_heads": - self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="attention.self.query", - target_module=col_nn.Linear1D_Col, - kwargs={ - "seq_parallel": use_sequence_parallel, - "overlap": overlap - }, - ), - SubModuleReplacementDescription( - suffix="attention.self.key", - target_module=col_nn.Linear1D_Col, - kwargs={ - "seq_parallel": use_sequence_parallel, - "overlap": overlap - }, - ), - SubModuleReplacementDescription( - suffix="attention.self.value", - target_module=col_nn.Linear1D_Col, - kwargs={ - "seq_parallel": use_sequence_parallel, - "overlap": overlap - }, - ), - SubModuleReplacementDescription( - suffix="attention.self.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="attention.output.dense", - target_module=col_nn.Linear1D_Row, - kwargs={"seq_parallel": use_sequence_parallel}, - ), - SubModuleReplacementDescription( - suffix="attention.output.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="intermediate.dense", - target_module=col_nn.Linear1D_Col, - kwargs={ - "seq_parallel": use_sequence_parallel, - "overlap": overlap - }, - ), - SubModuleReplacementDescription( - suffix="output.dense", - 
target_module=col_nn.Linear1D_Row, - kwargs={"seq_parallel": use_sequence_parallel}, - ), - SubModuleReplacementDescription( - suffix="output.dropout", - target_module=col_nn.DropoutForParallelInput, - ) - ]) - - policy[BertEmbeddings] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="word_embeddings", - target_module=col_nn.VocabParallelEmbedding1D, - ), - SubModuleReplacementDescription( - suffix="dropout", - target_module=col_nn.DropoutForReplicatedInput, - ) - ]) + policy[BertLayer] = ModulePolicyDescription( + attribute_replacement={ + "attention.self.all_head_size": self.model.config.hidden_size + // self.shard_config.tensor_parallel_size, + "crossattention.self.all_head_size": self.model.config.hidden_size + // self.shard_config.tensor_parallel_size, + "attention.self.num_attention_heads": self.model.config.num_attention_heads + // self.shard_config.tensor_parallel_size, + "crossattention.self.num_attention_heads": self.model.config.num_attention_heads + // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.self.query", + target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="attention.self.key", + target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="attention.self.value", + target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="attention.self.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attention.output.dense", + target_module=col_nn.Linear1D_Row, + kwargs={"seq_parallel": use_sequence_parallel}, + ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="intermediate.dense", + target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="output.dense", + target_module=col_nn.Linear1D_Row, + kwargs={"seq_parallel": use_sequence_parallel}, + ), + SubModuleReplacementDescription( + suffix="output.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ], + ) + + policy[BertEmbeddings] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="word_embeddings", + target_module=col_nn.VocabParallelEmbedding1D, + ), + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + ] + ) if use_sequence_parallel: self.append_or_create_method_replacement( - description={'forward': bert_sequence_parallel_forward_fn(self.shard_config)}, + description={"forward": bert_sequence_parallel_forward_fn(self.shard_config)}, policy=policy, - target_key=BertModel) + target_key=BertModel, + ) # optimization configuration if self.shard_config.enable_fused_normalization: # Handle bert layer - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="attention.output.LayerNorm", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="output.LayerNorm", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=BertLayer) + 
self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="attention.output.LayerNorm", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="output.LayerNorm", + target_module=col_nn.FusedLayerNorm, + ), + ], + policy=policy, + target_key=BertLayer, + ) # handle embedding layer self.append_or_create_submodule_replacement( - description=[SubModuleReplacementDescription( - suffix="LayerNorm", - target_module=col_nn.FusedLayerNorm, - )], + description=[ + SubModuleReplacementDescription( + suffix="LayerNorm", + target_module=col_nn.FusedLayerNorm, + ) + ], policy=policy, - target_key=BertEmbeddings) + target_key=BertEmbeddings, + ) # use flash attention if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement(description={ - 'forward': get_bert_flash_attention_forward(), - }, - policy=policy, - target_key=BertSelfAttention) + self.append_or_create_method_replacement( + description={ + "forward": get_bert_flash_attention_forward(), + }, + policy=policy, + target_key=BertSelfAttention, + ) # use jit operator if self.shard_config.enable_jit_fused: - self.append_or_create_method_replacement(description={ - 'forward': get_jit_fused_bert_self_output_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }, - policy=policy, - target_key=BertSelfOutput) - self.append_or_create_method_replacement(description={ - 'forward': get_jit_fused_bert_output_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }, - policy=policy, - target_key=BertOutput) + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_bert_self_output_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=BertSelfOutput, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_bert_output_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=BertOutput, + ) return policy @@ -196,31 +205,37 @@ def add_lm_head_policy(self, base_policy): # optimize for tensor parallelism if self.shard_config.enable_tensor_parallelism: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}), - policy=base_policy, - target_key=BertLMPredictionHead) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} + ), + policy=base_policy, + target_key=BertLMPredictionHead, + ) # optimize with fused normalization if self.shard_config.enable_fused_normalization: # Handle bert lm prediction head - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="transform.LayerNorm", - target_module=col_nn.FusedLayerNorm, - ), - policy=base_policy, - target_key=BertLMPredictionHead) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="transform.LayerNorm", + target_module=col_nn.FusedLayerNorm, + ), + policy=base_policy, + target_key=BertLMPredictionHead, + ) return base_policy def add_lm_prediction_policy(self, base_policy): from transformers.models.bert.modeling_bert import BertLMPredictionHead + method_replacement = { - '_save_to_state_dict': col_nn.ParallelModule._save_to_state_dict, - '_load_from_state_dict': col_nn.ParallelModule._load_from_state_dict, + 
"_save_to_state_dict": col_nn.ParallelModule._save_to_state_dict, + "_load_from_state_dict": col_nn.ParallelModule._load_from_state_dict, } - self.append_or_create_method_replacement(description=method_replacement, - policy=base_policy, - target_key=BertLMPredictionHead) + self.append_or_create_method_replacement( + description=method_replacement, policy=base_policy, target_key=BertLMPredictionHead + ) return base_policy def postprocess(self): @@ -228,7 +243,7 @@ def postprocess(self): def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: """If under pipeline parallel setting, replacing the original forward method of huggingface - to customized forward method, and add this changing to policy.""" + to customized forward method, and add this changing to policy.""" if self.pipeline_stage_manager: stage_manager = self.pipeline_stage_manager if self.model.__class__.__name__ == "BertModel": @@ -239,15 +254,13 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = { - 'forward': - partial(new_forward, - stage_manager=stage_manager, - stage_index=stage_index, - shard_config=self.shard_config) + "forward": partial( + new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + ) } - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=model_cls) + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=model_cls + ) return @@ -255,7 +268,7 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" assert self.pipeline_stage_manager is not None - if self.model.__class__.__name__ == 'BertModel': + if self.model.__class__.__name__ == "BertModel": module = self.model else: module = self.model.bert @@ -275,17 +288,17 @@ def get_held_layers(self) -> List[Module]: # BertModel class BertModelPolicy(BertPolicy): - def __init__(self) -> None: super().__init__() def module_policy(self): policy = super().module_policy() from transformers.models.bert.modeling_bert import BertModel + if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=BertModel, - new_forward=BertPipelineForwards.bert_model_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=BertModel, new_forward=BertPipelineForwards.bert_model_forward, policy=policy + ) return policy def get_held_layers(self) -> List[Module]: @@ -300,7 +313,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: # BertForPreTraining class BertForPreTrainingPolicy(BertPolicy): - def __init__(self) -> None: super().__init__() @@ -309,10 +321,13 @@ def module_policy(self): policy = self.add_lm_head_policy(policy) policy = self.add_lm_prediction_policy(policy) from transformers.models.bert.modeling_bert import BertForPreTraining + if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=BertForPreTraining, - new_forward=BertPipelineForwards.bert_for_pretraining_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=BertForPreTraining, + new_forward=BertPipelineForwards.bert_for_pretraining_forward, + policy=policy, + ) return policy def get_held_layers(self) -> List[Module]: @@ -329,16 +344,17 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: if self.pipeline_stage_manager and 
self.pipeline_stage_manager.num_stages > 1: if id(model.bert.embeddings.word_embeddings.weight) == id(model.cls.predictions.decoder.weight): # tie weights - return [{ - 0: model.bert.embeddings.word_embeddings.weight, - self.pipeline_stage_manager.num_stages - 1: model.cls.predictions.decoder.weight - }] + return [ + { + 0: model.bert.embeddings.word_embeddings.weight, + self.pipeline_stage_manager.num_stages - 1: model.cls.predictions.decoder.weight, + } + ] return [] # BertLMHeadModel class BertLMHeadModelPolicy(BertPolicy): - def __init__(self) -> None: super().__init__() @@ -347,10 +363,11 @@ def module_policy(self): policy = self.add_lm_head_policy(policy) policy = self.add_lm_prediction_policy(policy) from transformers.models.bert.modeling_bert import BertLMHeadModel + if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=BertLMHeadModel, - new_forward=BertPipelineForwards.bert_lm_head_model_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=BertLMHeadModel, new_forward=BertPipelineForwards.bert_lm_head_model_forward, policy=policy + ) return policy def get_held_layers(self) -> List[Module]: @@ -368,16 +385,17 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight): # tie weights - return [{ - 0: bert_model.embeddings.word_embeddings.weight, - self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight - }] + return [ + { + 0: bert_model.embeddings.word_embeddings.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight, + } + ] return [] # BertForMaskedLM class BertForMaskedLMPolicy(BertPolicy): - def __init__(self) -> None: super().__init__() @@ -386,10 +404,11 @@ def module_policy(self): policy = self.add_lm_head_policy(policy) policy = self.add_lm_prediction_policy(policy) from transformers.models.bert.modeling_bert import BertForMaskedLM + if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=BertForMaskedLM, - new_forward=BertPipelineForwards.bert_for_masked_lm_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=BertForMaskedLM, new_forward=BertPipelineForwards.bert_for_masked_lm_forward, policy=policy + ) return policy def get_held_layers(self) -> List[Module]: @@ -407,16 +426,17 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight): # tie weights - return [{ - 0: bert_model.embeddings.word_embeddings.weight, - self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight - }] + return [ + { + 0: bert_model.embeddings.word_embeddings.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight, + } + ] return [] # BertForSequenceClassification class BertForSequenceClassificationPolicy(BertPolicy): - def __init__(self) -> None: super().__init__() @@ -427,19 +447,22 @@ def module_policy(self): if self.shard_config.enable_tensor_parallelism: addon_module = { - BertForSequenceClassification: - ModulePolicyDescription(sub_module_replacement=[ + BertForSequenceClassification: ModulePolicyDescription( + sub_module_replacement=[ SubModuleReplacementDescription( suffix="dropout", target_module=col_nn.DropoutForParallelInput, ) - ]) + ] + ) } 
policy.update(addon_module) if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=BertForSequenceClassification, - new_forward=BertPipelineForwards.bert_for_sequence_classification_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=BertForSequenceClassification, + new_forward=BertPipelineForwards.bert_for_sequence_classification_forward, + policy=policy, + ) return policy @@ -461,7 +484,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: # BertForTokenClassification class BertForTokenClassificationPolicy(BertPolicy): - def __init__(self) -> None: super().__init__() @@ -472,19 +494,22 @@ def module_policy(self): if self.shard_config.enable_tensor_parallelism: addon_module = { - BertForTokenClassification: - ModulePolicyDescription(sub_module_replacement=[ + BertForTokenClassification: ModulePolicyDescription( + sub_module_replacement=[ SubModuleReplacementDescription( suffix="dropout", target_module=col_nn.DropoutForParallelInput, ) - ]) + ] + ) } policy.update(addon_module) if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=BertForTokenClassification, - new_forward=BertPipelineForwards.bert_for_token_classification_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=BertForTokenClassification, + new_forward=BertPipelineForwards.bert_for_token_classification_forward, + policy=policy, + ) return policy @@ -506,17 +531,19 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: # BertForNextSentencePrediction class BertForNextSentencePredictionPolicy(BertPolicy): - def __init__(self) -> None: super().__init__() def module_policy(self): policy = super().module_policy() from transformers.models.bert.modeling_bert import BertForNextSentencePrediction + if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=BertForNextSentencePrediction, - new_forward=BertPipelineForwards.bert_for_next_sentence_prediction_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=BertForNextSentencePrediction, + new_forward=BertPipelineForwards.bert_for_next_sentence_prediction_forward, + policy=policy, + ) return policy @@ -537,7 +564,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: # BertForMultipleChoice class BertForMultipleChoicePolicy(BertPolicy): - def __init__(self) -> None: super().__init__() @@ -548,19 +574,22 @@ def module_policy(self): if self.shard_config.enable_tensor_parallelism: addon_module = { - BertForMultipleChoice: - ModulePolicyDescription(sub_module_replacement=[ + BertForMultipleChoice: ModulePolicyDescription( + sub_module_replacement=[ SubModuleReplacementDescription( suffix="dropout", target_module=col_nn.DropoutForParallelInput, ) - ]) + ] + ) } policy.update(addon_module) if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=BertForMultipleChoice, - new_forward=BertPipelineForwards.bert_for_multiple_choice_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=BertForMultipleChoice, + new_forward=BertPipelineForwards.bert_for_multiple_choice_forward, + policy=policy, + ) return policy @@ -581,17 +610,19 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: class BertForQuestionAnsweringPolicy(BertPolicy): - def __init__(self) -> None: super().__init__() def module_policy(self): from transformers.models.bert.modeling_bert import BertForQuestionAnswering + policy = super().module_policy() if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=BertForQuestionAnswering, - 
new_forward=BertPipelineForwards.bert_for_question_answering_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=BertForQuestionAnswering, + new_forward=BertPipelineForwards.bert_for_question_answering_forward, + policy=policy, + ) return policy diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py index 2e5388ab0490..997643d1a911 100644 --- a/colossalai/shardformer/policies/blip2.py +++ b/colossalai/shardformer/policies/blip2.py @@ -1,8 +1,5 @@ -import torch.nn as nn - import colossalai.shardformer.layer as col_nn -from .._utils import getattr_, setattr_ from ..modeling.blip2 import ( forward_fn, get_blip2_flash_attention_forward, @@ -12,11 +9,10 @@ from ..modeling.jit import get_jit_fused_dropout_add_func from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -__all__ = ['BlipPolicy', 'BlipModelPolicy'] +__all__ = ["BlipPolicy", "BlipModelPolicy"] class BlipPolicy(Policy): - def config_sanity_check(self): pass @@ -48,263 +44,293 @@ def module_policy(self): policy = {} if self.shard_config.enable_tensor_parallelism: - policy[Blip2EncoderLayer] = ModulePolicyDescription(attribute_replacement={ - "self_attn.num_heads": - self.model.config.vision_config.num_attention_heads // self.shard_config.tensor_parallel_size, - "self_attn.embed_dim": - self.model.config.vision_config.hidden_size // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="self_attn.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="self_attn.qkv", - target_module=col_nn.FusedLinear1D_Col, - kwargs={ - "n_fused": 3, - }), - SubModuleReplacementDescription( - suffix="self_attn.projection", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="mlp.fc1", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="mlp.fc2", - target_module=col_nn.Linear1D_Row, - ), - ]) + policy[Blip2EncoderLayer] = ModulePolicyDescription( + attribute_replacement={ + "self_attn.num_heads": self.model.config.vision_config.num_attention_heads + // self.shard_config.tensor_parallel_size, + "self_attn.embed_dim": self.model.config.vision_config.hidden_size + // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="self_attn.qkv", + target_module=col_nn.FusedLinear1D_Col, + kwargs={ + "n_fused": 3, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.projection", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="mlp.fc1", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.fc2", + target_module=col_nn.Linear1D_Row, + ), + ], + ) - policy[Blip2QFormerModel] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=col_nn.DropoutForParallelInput, - ), - ]) + policy[Blip2QFormerModel] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ] + ) - policy[Blip2QFormerLayer] = ModulePolicyDescription(attribute_replacement={ - "attention.attention.num_attention_heads": - self.model.config.qformer_config.num_attention_heads // 
self.shard_config.tensor_parallel_size, - "attention.attention.all_head_size": - self.model.config.qformer_config.hidden_size // self.shard_config.tensor_parallel_size, - "crossattention.attention.num_attention_heads": - self.model.config.qformer_config.num_attention_heads // self.shard_config.tensor_parallel_size, - "crossattention.attention.all_head_size": - self.model.config.qformer_config.hidden_size // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="attention.attention.query", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.attention.key", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.attention.value", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.attention.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="attention.output.dense", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="attention.output.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="crossattention.attention.query", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="crossattention.attention.key", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="crossattention.attention.value", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="crossattention.attention.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="crossattention.output.dense", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="crossattention.output.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="intermediate_query.dense", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="output_query.dense", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="output_query.dropout", - target_module=col_nn.DropoutForParallelInput, - ) - ]) + policy[Blip2QFormerLayer] = ModulePolicyDescription( + attribute_replacement={ + "attention.attention.num_attention_heads": self.model.config.qformer_config.num_attention_heads + // self.shard_config.tensor_parallel_size, + "attention.attention.all_head_size": self.model.config.qformer_config.hidden_size + // self.shard_config.tensor_parallel_size, + "crossattention.attention.num_attention_heads": self.model.config.qformer_config.num_attention_heads + // self.shard_config.tensor_parallel_size, + "crossattention.attention.all_head_size": self.model.config.qformer_config.hidden_size + // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.attention.query", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.key", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.value", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attention.output.dense", + target_module=col_nn.Linear1D_Row, + ), + 
SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="crossattention.attention.query", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="crossattention.attention.key", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="crossattention.attention.value", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="crossattention.attention.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="crossattention.output.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="crossattention.output.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="intermediate_query.dense", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="output_query.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="output_query.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ], + ) - policy[OPTDecoderLayer] = ModulePolicyDescription(attribute_replacement={ - "self_attn.embed_dim": - self.model.config.text_config.hidden_size // self.shard_config.tensor_parallel_size, - "self_attn.num_heads": - self.model.config.text_config.num_attention_heads // self.shard_config.tensor_parallel_size - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="self_attn.q_proj", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attn.k_proj", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attn.v_proj", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attn.out_proj", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="fc1", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="fc2", - target_module=col_nn.Linear1D_Row, - ) - ]) + policy[OPTDecoderLayer] = ModulePolicyDescription( + attribute_replacement={ + "self_attn.embed_dim": self.model.config.text_config.hidden_size + // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.text_config.num_attention_heads + // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.out_proj", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="fc1", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=col_nn.Linear1D_Row, + ), + ], + ) - policy[OPTForCausalLM] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="model.decoder.embed_tokens", - target_module=col_nn.VocabParallelEmbedding1D, - ), - SubModuleReplacementDescription( - suffix="lm_head", - target_module=col_nn.Linear1D_Col, - kwargs={"gather_output": True}, - ), - ]) + policy[OPTForCausalLM] = ModulePolicyDescription( + sub_module_replacement=[ + 
SubModuleReplacementDescription( + suffix="model.decoder.embed_tokens", + target_module=col_nn.VocabParallelEmbedding1D, + ), + SubModuleReplacementDescription( + suffix="lm_head", + target_module=col_nn.Linear1D_Col, + kwargs={"gather_output": True}, + ), + ] + ) policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()}) # optimization configuration if self.shard_config.enable_fused_normalization: # Handle Blip2EncoderLayer layer - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="layer_norm1", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="layer_norm2", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=Blip2EncoderLayer) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="layer_norm1", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="layer_norm2", + target_module=col_nn.FusedLayerNorm, + ), + ], + policy=policy, + target_key=Blip2EncoderLayer, + ) # handle Blip2VisionModel layer - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="post_layernorm", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=Blip2VisionModel) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="post_layernorm", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=Blip2VisionModel, + ) # handle Blip2VisionModel layer self.append_or_create_submodule_replacement( - description=[SubModuleReplacementDescription( - suffix="layernorm", - target_module=col_nn.FusedLayerNorm, - )], + description=[ + SubModuleReplacementDescription( + suffix="layernorm", + target_module=col_nn.FusedLayerNorm, + ) + ], policy=policy, - target_key=Blip2QFormerModel) + target_key=Blip2QFormerModel, + ) # handle Blip2QFormerLayer layer - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="attention.output.LayerNorm", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="crossattention.output.LayerNorm", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="output_query.LayerNorm", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=Blip2QFormerLayer) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="attention.output.LayerNorm", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="crossattention.output.LayerNorm", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="output_query.LayerNorm", + target_module=col_nn.FusedLayerNorm, + ), + ], + policy=policy, + target_key=Blip2QFormerLayer, + ) # handle OPTForCausalLM layer - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="model.decoder.final_layer_norm", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=OPTForCausalLM) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="model.decoder.final_layer_norm", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=OPTForCausalLM, + ) # handle OPTDecoderLayer layer - 
self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="self_attn_layer_norm", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="final_layer_norm", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=OPTDecoderLayer) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="self_attn_layer_norm", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="final_layer_norm", + target_module=col_nn.FusedLayerNorm, + ), + ], + policy=policy, + target_key=OPTDecoderLayer, + ) # use flash attention if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement(description={ - 'forward': get_blip2_flash_attention_forward(), - }, - policy=policy, - target_key=Blip2Attention) + self.append_or_create_method_replacement( + description={ + "forward": get_blip2_flash_attention_forward(), + }, + policy=policy, + target_key=Blip2Attention, + ) # use jit operator if self.shard_config.enable_jit_fused: - self.append_or_create_method_replacement(description={ - 'forward': get_jit_fused_blip2_QFormer_self_output_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }, - policy=policy, - target_key=Blip2QFormerSelfOutput) - self.append_or_create_method_replacement(description={ - 'forward': get_jit_fused_blip2_QFormer_output_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }, - policy=policy, - target_key=Blip2QFormerOutput) + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_blip2_QFormer_self_output_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=Blip2QFormerSelfOutput, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_blip2_QFormer_output_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=Blip2QFormerOutput, + ) return policy @@ -314,13 +340,11 @@ def postprocess(self): # Blip2Model class Blip2ModelPolicy(BlipPolicy): - def __init__(self) -> None: super().__init__() # Blip2ForConditionalGeneration class Blip2ForConditionalGenerationPolicy(BlipPolicy): - def __init__(self) -> None: super().__init__() diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 7c418d02bcb6..13b9dd31345d 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List import torch.nn as nn from torch import Tensor @@ -7,7 +7,6 @@ import colossalai.shardformer.layer as col_nn -from .._utils import getattr_, setattr_ from ..modeling.bloom import ( BloomPipelineForwards, build_bloom_alibi_tensor_fn, @@ -22,7 +21,6 @@ class BloomPolicy(Policy): - def config_sanity_check(self): pass @@ -47,39 +45,41 @@ def module_policy(self): use_sequence_parallel = self.shard_config.enable_sequence_parallelism overlap = self.shard_config.enable_sequence_overlap if self.shard_config.enable_tensor_parallelism: - policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={ - "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "self_attention.num_heads": 
self.model.config.n_head // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="self_attention.query_key_value", - target_module=col_nn.Linear1D_Col, - kwargs={ - 'seq_parallel': use_sequence_parallel, - 'overlap': overlap - }), - SubModuleReplacementDescription( - suffix="self_attention.dense", - target_module=col_nn.Linear1D_Row, - kwargs={'seq_parallel': use_sequence_parallel}), - SubModuleReplacementDescription( - suffix="self_attention.attention_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="mlp.dense_h_to_4h", - target_module=col_nn.Linear1D_Col, - kwargs={ - 'seq_parallel': use_sequence_parallel, - 'overlap': overlap - }), - SubModuleReplacementDescription( - suffix="mlp.dense_4h_to_h", - target_module=col_nn.Linear1D_Row, - kwargs={'seq_parallel': use_sequence_parallel}), - ]) + policy[BloomBlock] = ModulePolicyDescription( + attribute_replacement={ + "self_attention.hidden_size": self.model.config.hidden_size + // self.shard_config.tensor_parallel_size, + "self_attention.split_size": self.model.config.hidden_size + // self.shard_config.tensor_parallel_size, + "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attention.query_key_value", + target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="self_attention.dense", + target_module=col_nn.Linear1D_Row, + kwargs={"seq_parallel": use_sequence_parallel}, + ), + SubModuleReplacementDescription( + suffix="self_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_h_to_4h", + target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_4h_to_h", + target_module=col_nn.Linear1D_Row, + kwargs={"seq_parallel": use_sequence_parallel}, + ), + ], + ) policy[BloomModel] = ModulePolicyDescription( attribute_replacement={ @@ -93,72 +93,86 @@ def module_policy(self): suffix="word_embeddings", target_module=col_nn.VocabParallelEmbedding1D, ) - ]) + ], + ) # optimization configuration if self.shard_config.enable_fused_normalization: # handle bloom model - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="ln_f", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="word_embeddings_layernorm", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=BloomModel) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="ln_f", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="word_embeddings_layernorm", + target_module=col_nn.FusedLayerNorm, + ), + ], + policy=policy, + target_key=BloomModel, + ) # handle bloom block - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="input_layernorm", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="post_attention_layernorm", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=BloomBlock) + self.append_or_create_submodule_replacement( + description=[ + 
SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", + target_module=col_nn.FusedLayerNorm, + ), + ], + policy=policy, + target_key=BloomBlock, + ) if use_sequence_parallel: self.append_or_create_method_replacement( - description={'forward': get_bloom_sequence_parallel_forward_fn(self.shard_config)}, + description={"forward": get_bloom_sequence_parallel_forward_fn(self.shard_config)}, policy=policy, - target_key=BloomModel) + target_key=BloomModel, + ) if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement(description={ - 'forward': get_bloom_flash_attention_forward(), - 'dropout_add': get_dropout_add_func(), - }, - policy=policy, - target_key=BloomAttention) + self.append_or_create_method_replacement( + description={ + "forward": get_bloom_flash_attention_forward(), + "dropout_add": get_dropout_add_func(), + }, + policy=policy, + target_key=BloomAttention, + ) # enable jit fused operator if self.shard_config.enable_jit_fused: - self.append_or_create_method_replacement(description={ - 'forward': get_jit_fused_bloom_attention_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }, - policy=policy, - target_key=BloomAttention) - self.append_or_create_method_replacement(description={ - 'forward': get_jit_fused_bloom_mlp_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }, - policy=policy, - target_key=BloomMLP) - self.append_or_create_method_replacement(description={ - 'forward': get_jit_fused_bloom_gelu_forward(), - 'bloom_gelu_forward': get_jit_fused_gelu_forward_func(), - }, - policy=policy, - target_key=BloomGelu) + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_bloom_attention_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=BloomAttention, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_bloom_mlp_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=BloomMLP, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_bloom_gelu_forward(), + "bloom_gelu_forward": get_jit_fused_gelu_forward_func(), + }, + policy=policy, + target_key=BloomGelu, + ) return policy @@ -167,7 +181,7 @@ def postprocess(self): def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: """If under pipeline parallel setting, replacing the original forward method of huggingface - to customized forward method, and add this changing to policy.""" + to customized forward method, and add this changing to policy.""" if self.pipeline_stage_manager: stage_manager = self.pipeline_stage_manager if self.model.__class__.__name__ == "BloomModel": @@ -178,22 +192,20 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = { - 'forward': - partial(new_forward, - stage_manager=stage_manager, - stage_index=stage_index, - shard_config=self.shard_config) + "forward": partial( + new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + ) } - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=model_cls) + 
self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=model_cls + ) return def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" assert self.pipeline_stage_manager is not None - if self.model.__class__.__name__ == 'BloomModel': + if self.model.__class__.__name__ == "BloomModel": module = self.model else: module = self.model.transformer @@ -213,17 +225,17 @@ def get_held_layers(self) -> List[Module]: class BloomModelPolicy(BloomPolicy): - def __init__(self) -> None: super().__init__() def module_policy(self): policy = super().module_policy() from transformers.models.bloom.modeling_bloom import BloomModel + if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=BloomModel, - new_forward=BloomPipelineForwards.bloom_model_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=BloomModel, new_forward=BloomPipelineForwards.bloom_model_forward, policy=policy + ) return policy def get_held_layers(self) -> List[Module]: @@ -234,26 +246,29 @@ def get_held_layers(self) -> List[Module]: return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: - '''no shared params in bloom model''' + """no shared params in bloom model""" return [] class BloomForCausalLMPolicy(BloomPolicy): - def module_policy(self): from transformers.models.bloom.modeling_bloom import BloomForCausalLM + policy = super().module_policy() # handle tensor parallelism if self.shard_config.enable_tensor_parallelism: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)), - policy=policy, - target_key=BloomForCausalLM) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) + ), + policy=policy, + target_key=BloomForCausalLM, + ) if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=BloomForCausalLM, - new_forward=BloomPipelineForwards.bloom_for_causal_lm_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=BloomForCausalLM, new_forward=BloomPipelineForwards.bloom_for_causal_lm_forward, policy=policy + ) return policy def get_held_layers(self) -> List[Module]: @@ -269,29 +284,36 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: if id(bloom_model.transformer.word_embeddings.weight) == id(bloom_model.lm_head.weight): # tie weights - return [{ - 0: bloom_model.transformer.word_embeddings.weight, - self.pipeline_stage_manager.num_stages - 1: bloom_model.lm_head.weight - }] + return [ + { + 0: bloom_model.transformer.word_embeddings.weight, + self.pipeline_stage_manager.num_stages - 1: bloom_model.lm_head.weight, + } + ] return [] class BloomForSequenceClassificationPolicy(BloomPolicy): - def module_policy(self): from transformers.models.bloom.modeling_bloom import BloomForSequenceClassification + policy = super().module_policy() # handle tensor parallelism if self.shard_config.enable_tensor_parallelism: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)), - policy=policy, - target_key=BloomForSequenceClassification) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="score", 
target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) + ), + policy=policy, + target_key=BloomForSequenceClassification, + ) if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=BloomForSequenceClassification, - new_forward=BloomPipelineForwards.bloom_for_sequence_classification_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=BloomForSequenceClassification, + new_forward=BloomPipelineForwards.bloom_for_sequence_classification_forward, + policy=policy, + ) return policy def get_held_layers(self) -> List[Module]: @@ -308,28 +330,32 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: class BloomForTokenClassificationPolicy(BloomPolicy): - def module_policy(self): from transformers.models.bloom.modeling_bloom import BloomForTokenClassification + policy = super().module_policy() # handle tensor parallelism if self.shard_config.enable_tensor_parallelism: - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription(suffix="classifier", - target_module=col_nn.Linear1D_Col, - kwargs=dict(gather_output=True)), - SubModuleReplacementDescription( - suffix="dropout", - target_module=col_nn.DropoutForReplicatedInput, - ), - ], - policy=policy, - target_key=BloomForTokenClassification) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) + ), + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + ], + policy=policy, + target_key=BloomForTokenClassification, + ) if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=BloomForTokenClassification, - new_forward=BloomPipelineForwards.bloom_for_token_classification_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=BloomForTokenClassification, + new_forward=BloomPipelineForwards.bloom_for_token_classification_forward, + policy=policy, + ) return policy @@ -351,11 +377,14 @@ class BloomForQuestionAnsweringPolicy(BloomPolicy): # No head sharding as the output features is only 2 def module_policy(self): from transformers.models.bloom.modeling_bloom import BloomForQuestionAnswering + policy = super().module_policy() if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=BloomForQuestionAnswering, - new_forward=BloomPipelineForwards.bloom_for_question_answering_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=BloomForQuestionAnswering, + new_forward=BloomPipelineForwards.bloom_for_question_answering_forward, + policy=policy, + ) return policy def get_held_layers(self) -> List[Module]: diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index 44898847056a..3c27c848e738 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -1,19 +1,12 @@ from functools import partial -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Union import torch.nn as nn from torch import Tensor -from transformers.modeling_outputs import BaseModelOutputWithPast import colossalai.shardformer.layer as col_nn -from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.modeling.chatglm2 import ChatGLMPipelineForwards -from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig -from 
colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( - ChatGLMForConditionalGeneration, - ChatGLMModel, - GLMBlock, -) +from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel from ..modeling.chatglm2 import ( get_chatglm_sequence_parallel_forward_fn, @@ -23,11 +16,10 @@ from ..modeling.jit import get_jit_fused_dropout_add_func from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -__all__ = ['ChatGLMPolicy', 'ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy'] +__all__ = ["ChatGLMPolicy", "ChatGLMModelPolicy", "ChatGLMForConditionalGenerationPolicy"] class ChatGLMPolicy(Policy): - def config_sanity_check(self): pass @@ -44,12 +36,11 @@ def preprocess(self): if self.pipeline_stage_manager is not None: # the batch_size_dim is bounded to Model bsz_dim = 1 - setattr(self.model, 'batch_size_dim', bsz_dim) + setattr(self.model, "batch_size_dim", bsz_dim) return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel, CoreAttention, GLMBlock policy = {} @@ -57,111 +48,129 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: use_sequence_parallel = self.shard_config.enable_sequence_parallelism overlap = self.shard_config.enable_sequence_overlap if self.shard_config.enable_tensor_parallelism: - policy[ChatGLMModel] = ModulePolicyDescription(attribute_replacement={}, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="embedding.word_embeddings", - target_module=col_nn.VocabParallelEmbedding1D, - ) - ]) + policy[ChatGLMModel] = ModulePolicyDescription( + attribute_replacement={}, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embedding.word_embeddings", + target_module=col_nn.VocabParallelEmbedding1D, + ) + ], + ) policy[GLMBlock] = ModulePolicyDescription( attribute_replacement={ - "self_attention.num_attention_heads_per_partition": - self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - "self_attention.projection_size": - (self.model.config.kv_channels * self.model.config.num_attention_heads) // - self.shard_config.tensor_parallel_size, - "self_attention.qkv_hidden_size": - (self.model.config.kv_channels * self.model.config.num_attention_heads * 3) // - self.shard_config.tensor_parallel_size, - "self_attention.core_attention.num_attention_heads_per_partition": - self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - "self_attention.core_attention.hidden_size_per_partition": - self.model.config.kv_channels * self.model.config.num_attention_heads // - self.shard_config.tensor_parallel_size, + "self_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads + // self.shard_config.tensor_parallel_size, + "self_attention.projection_size": ( + self.model.config.kv_channels * self.model.config.num_attention_heads + ) + // self.shard_config.tensor_parallel_size, + "self_attention.qkv_hidden_size": ( + self.model.config.kv_channels * self.model.config.num_attention_heads * 3 + ) + // self.shard_config.tensor_parallel_size, + "self_attention.core_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads + // self.shard_config.tensor_parallel_size, + "self_attention.core_attention.hidden_size_per_partition": self.model.config.kv_channels + * self.model.config.num_attention_heads + 
// self.shard_config.tensor_parallel_size, }, param_replacement=[], sub_module_replacement=[ - SubModuleReplacementDescription(suffix="self_attention.query_key_value", - target_module=col_nn.Linear1D_Col, - kwargs={ - 'seq_parallel': use_sequence_parallel, - 'seq_parallel_dim': 0, - 'overlap': overlap - }), - SubModuleReplacementDescription(suffix="self_attention.dense", - target_module=col_nn.Linear1D_Row, - kwargs={ - 'seq_parallel': use_sequence_parallel, - 'seq_parallel_dim': 0 - }), + SubModuleReplacementDescription( + suffix="self_attention.query_key_value", + target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel, "seq_parallel_dim": 0, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="self_attention.dense", + target_module=col_nn.Linear1D_Row, + kwargs={"seq_parallel": use_sequence_parallel, "seq_parallel_dim": 0}, + ), SubModuleReplacementDescription( suffix="self_attention.core_attention.attention_dropout", target_module=col_nn.DropoutForParallelInput, ), - ]) + ], + ) # optimization configuration if self.shard_config.enable_fused_normalization: if not self.model.config.rmsnorm: - - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedLayerNorm), - SubModuleReplacementDescription(suffix="post_attention_layernorm", - target_module=col_nn.FusedLayerNorm) - ], - policy=policy, - target_key=GLMBlock) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedLayerNorm), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", target_module=col_nn.FusedLayerNorm + ), + ], + policy=policy, + target_key=GLMBlock, + ) if self.model.config.post_layer_norm: - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription(suffix="encoder.final_layernorm", - target_module=col_nn.FusedLayerNorm) - ], - policy=policy, - target_key=ChatGLMModel) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="encoder.final_layernorm", target_module=col_nn.FusedLayerNorm + ) + ], + policy=policy, + target_key=ChatGLMModel, + ) else: - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedRMSNorm), - SubModuleReplacementDescription(suffix="post_attention_layernorm", - target_module=col_nn.FusedRMSNorm) - ], - policy=policy, - target_key=GLMBlock) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedRMSNorm), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", target_module=col_nn.FusedRMSNorm + ), + ], + policy=policy, + target_key=GLMBlock, + ) if self.model.config.post_layer_norm: - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription(suffix="encoder.final_layernorm", - target_module=col_nn.FusedRMSNorm) - ], - policy=policy, - target_key=ChatGLMModel) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="encoder.final_layernorm", target_module=col_nn.FusedRMSNorm + ) + ], + policy=policy, + target_key=ChatGLMModel, + ) # use flash attention if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement(description={ - 'forward': get_flash_core_attention_forward(), 
- }, - policy=policy, - target_key=CoreAttention) + self.append_or_create_method_replacement( + description={ + "forward": get_flash_core_attention_forward(), + }, + policy=policy, + target_key=CoreAttention, + ) # use sequence parallel if use_sequence_parallel: self.append_or_create_method_replacement( - description={'forward': get_chatglm_sequence_parallel_forward_fn(self.shard_config)}, + description={"forward": get_chatglm_sequence_parallel_forward_fn(self.shard_config)}, policy=policy, - target_key=ChatGLMModel) + target_key=ChatGLMModel, + ) # use jit fused operator if self.shard_config.enable_jit_fused: - self.append_or_create_method_replacement(description={ - 'forward': get_jit_fused_glm_block_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }, - policy=policy, - target_key=GLMBlock) + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_glm_block_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=GLMBlock, + ) return policy @@ -172,7 +181,7 @@ def get_held_layers(self) -> List[nn.Module]: """Get pipeline layers for current stage.""" assert self.pipeline_stage_manager is not None - if self.model.__class__.__name__ == 'ChatGLMModel': + if self.model.__class__.__name__ == "ChatGLMModel": module = self.model else: module = self.model.transformer @@ -195,11 +204,11 @@ def get_held_layers(self) -> List[nn.Module]: def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: """If under pipeline parallel setting, replacing the original forward method of huggingface - to customized forward method, and add this changing to policy.""" + to customized forward method, and add this changing to policy.""" if not self.pipeline_stage_manager: raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") stage_manager = self.pipeline_stage_manager - if self.model.__class__.__name__ == 'ChatGLMModel': + if self.model.__class__.__name__ == "ChatGLMModel": module = self.model else: module = self.model.transformer @@ -207,29 +216,26 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli layers_per_stage = Policy.distribute_layers(module.num_layers, stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = { - 'forward': - partial(new_forward, - stage_manager=stage_manager, - stage_index=stage_index, - shard_config=self.shard_config) + "forward": partial( + new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + ) } self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) class ChatGLMModelPolicy(ChatGLMPolicy): - def __init__(self) -> None: super().__init__() def module_policy(self): - from transformers.models.gpt2.modeling_gpt2 import GPT2Model + pass policy = super().module_policy() if self.pipeline_stage_manager is not None: - self.set_pipeline_forward(model_cls=ChatGLMModel, - new_forward=ChatGLMPipelineForwards.chatglm_model_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=ChatGLMModel, new_forward=ChatGLMPipelineForwards.chatglm_model_forward, policy=policy + ) return policy def get_held_layers(self) -> List[nn.Module]: @@ -241,14 +247,15 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy): - def module_policy(self): policy = super().module_policy() if 
self.pipeline_stage_manager is not None: - self.set_pipeline_forward(model_cls=ChatGLMForConditionalGeneration, - new_forward=ChatGLMPipelineForwards.chatglm_for_conditional_generation_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=ChatGLMForConditionalGeneration, + new_forward=ChatGLMPipelineForwards.chatglm_for_conditional_generation_forward, + policy=policy, + ) return policy def get_held_layers(self) -> List[nn.Module]: diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 5093fd469af8..6f46bfc7ef9f 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -5,18 +5,20 @@ import colossalai.shardformer.layer as col_nn -from .._utils import getattr_, setattr_ from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward, gpt2_sequence_parallel_forward_fn from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ - 'GPT2Policy', 'GPT2ModelPolicy', 'GPT2LMHeadModelPolicy', 'GPT2DoubleHeadsModelPolicy', - 'GPT2ForTokenClassificationPolicy', 'GPT2ForSequenceClassificationPolicy' + "GPT2Policy", + "GPT2ModelPolicy", + "GPT2LMHeadModelPolicy", + "GPT2DoubleHeadsModelPolicy", + "GPT2ForTokenClassificationPolicy", + "GPT2ForSequenceClassificationPolicy", ] class GPT2Policy(Policy): - def config_sanity_check(self): pass @@ -40,16 +42,18 @@ def module_policy(self): use_sequence_parallel = self.shard_config.enable_sequence_parallelism overlap = self.shard_config.enable_sequence_overlap if self.shard_config.enable_tensor_parallelism: - policy[GPT2Model] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="wte", - target_module=col_nn.VocabParallelEmbedding1D, - ), - SubModuleReplacementDescription( - suffix="drop", - target_module=col_nn.DropoutForParallelInput, - ), - ]) + policy[GPT2Model] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wte", + target_module=col_nn.VocabParallelEmbedding1D, + ), + SubModuleReplacementDescription( + suffix="drop", + target_module=col_nn.DropoutForParallelInput, + ), + ] + ) policy[GPT2Block] = ModulePolicyDescription( attribute_replacement={ @@ -61,31 +65,27 @@ def module_policy(self): SubModuleReplacementDescription( suffix="attn.c_attn", target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={"n_fused": 3, "seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="attn.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, kwargs={ - "n_fused": 3, "seq_parallel": use_sequence_parallel, - "overlap": overlap }, ), - SubModuleReplacementDescription(suffix="attn.c_proj", - target_module=col_nn.GPT2FusedLinearConv1D_Row, - kwargs={ - "seq_parallel": use_sequence_parallel, - }), SubModuleReplacementDescription( suffix="mlp.c_fc", target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={"n_fused": 1, "seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="mlp.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, kwargs={ - "n_fused": 1, "seq_parallel": use_sequence_parallel, - "overlap": overlap }, ), - SubModuleReplacementDescription(suffix="mlp.c_proj", - target_module=col_nn.GPT2FusedLinearConv1D_Row, - kwargs={ - "seq_parallel": use_sequence_parallel, - }), SubModuleReplacementDescription( suffix="attn.attn_dropout", target_module=col_nn.DropoutForParallelInput, @@ -98,39 +98,46 @@ def 
module_policy(self): suffix="mlp.dropout", target_module=col_nn.DropoutForParallelInput, ), - ]) + ], + ) # optimization configuration if self.shard_config.enable_fused_normalization: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="ln_f", - target_module=col_nn.FusedLayerNorm, - ), - policy=policy, - target_key=GPT2Model) - - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="ln_1", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="ln_2", + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="ln_f", target_module=col_nn.FusedLayerNorm, ), - SubModuleReplacementDescription(suffix="ln_cross_attn", - target_module=col_nn.FusedLayerNorm, - ignore_if_not_exist=True) - ], - policy=policy, - target_key=GPT2Block) + policy=policy, + target_key=GPT2Model, + ) + + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="ln_1", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="ln_2", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="ln_cross_attn", target_module=col_nn.FusedLayerNorm, ignore_if_not_exist=True + ), + ], + policy=policy, + target_key=GPT2Block, + ) if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement(description={ - 'forward': get_gpt2_flash_attention_forward(), - }, - policy=policy, - target_key=GPT2Attention) + self.append_or_create_method_replacement( + description={ + "forward": get_gpt2_flash_attention_forward(), + }, + policy=policy, + target_key=GPT2Attention, + ) if self.shard_config.enable_sequence_parallelism: policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)} @@ -144,7 +151,7 @@ def get_held_layers(self) -> List[nn.Module]: """Get pipeline layers for current stage.""" assert self.pipeline_stage_manager is not None - if self.model.__class__.__name__ == 'GPT2Model': + if self.model.__class__.__name__ == "GPT2Model": module = self.model else: module = self.model.transformer @@ -164,11 +171,11 @@ def get_held_layers(self) -> List[nn.Module]: def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: """If under pipeline parallel setting, replacing the original forward method of huggingface - to customized forward method, and add this changing to policy.""" + to customized forward method, and add this changing to policy.""" if not self.pipeline_stage_manager: raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") stage_manager = self.pipeline_stage_manager - if self.model.__class__.__name__ == 'GPT2Model': + if self.model.__class__.__name__ == "GPT2Model": module = self.model else: module = self.model.transformer @@ -176,18 +183,15 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = { - 'forward': - partial(new_forward, - stage_manager=stage_manager, - stage_index=stage_index, - shard_config=self.shard_config) + "forward": partial( + new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + ) } 
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) # GPT2Model class GPT2ModelPolicy(GPT2Policy): - def __init__(self) -> None: super().__init__() @@ -197,9 +201,9 @@ def module_policy(self): policy = super().module_policy() if self.pipeline_stage_manager is not None: - self.set_pipeline_forward(model_cls=GPT2Model, - new_forward=GPT2PipelineForwards.gpt2_model_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=GPT2Model, new_forward=GPT2PipelineForwards.gpt2_model_forward, policy=policy + ) return policy def get_held_layers(self) -> List[nn.Module]: @@ -212,7 +216,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: # GPT2LMHeadModel class GPT2LMHeadModelPolicy(GPT2Policy): - def __init__(self) -> None: super().__init__() @@ -223,18 +226,22 @@ def module_policy(self): if self.shard_config.enable_tensor_parallelism: addon_module = { - GPT2LMHeadModel: - ModulePolicyDescription(sub_module_replacement=[ + GPT2LMHeadModel: ModulePolicyDescription( + sub_module_replacement=[ SubModuleReplacementDescription( - suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}) - ]) + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} + ) + ] + ) } module_policy.update(addon_module) if self.pipeline_stage_manager is not None: - self.set_pipeline_forward(model_cls=GPT2LMHeadModel, - new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward, - policy=module_policy) + self.set_pipeline_forward( + model_cls=GPT2LMHeadModel, + new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward, + policy=module_policy, + ) return module_policy def get_held_layers(self) -> List[nn.Module]: @@ -244,7 +251,7 @@ def get_held_layers(self) -> List[nn.Module]: return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: - '''The weights of wte and lm_head are shared.''' + """The weights of wte and lm_head are shared.""" module = self.model stage_manager = self.pipeline_stage_manager if stage_manager is not None: @@ -256,7 +263,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: # GPT2DoubleHeadsModel class GPT2DoubleHeadsModelPolicy(GPT2Policy): - def __init__(self) -> None: super().__init__() @@ -267,18 +273,22 @@ def module_policy(self): if self.shard_config.enable_tensor_parallelism: addon_module = { - GPT2DoubleHeadsModel: - ModulePolicyDescription(sub_module_replacement=[ + GPT2DoubleHeadsModel: ModulePolicyDescription( + sub_module_replacement=[ SubModuleReplacementDescription( - suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}) - ]) + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} + ) + ] + ) } module_policy.update(addon_module) if self.pipeline_stage_manager is not None: - self.set_pipeline_forward(model_cls=GPT2DoubleHeadsModel, - new_forward=GPT2PipelineForwards.gpt2_double_heads_model_forward, - policy=module_policy) + self.set_pipeline_forward( + model_cls=GPT2DoubleHeadsModel, + new_forward=GPT2PipelineForwards.gpt2_double_heads_model_forward, + policy=module_policy, + ) return module_policy @@ -295,7 +305,7 @@ def get_held_layers(self) -> List[nn.Module]: return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: - '''The weights of wte and lm_head are shared.''' + """The weights of wte and lm_head are shared.""" module = self.model stage_manager = self.pipeline_stage_manager if stage_manager is not None: @@ -307,7 +317,6 @@ def get_shared_params(self) -> 
List[Dict[int, Tensor]]: # GPT2ForQuestionAnswering class GPT2ForQuestionAnsweringPolicy(GPT2Policy): - def __init__(self) -> None: super().__init__() @@ -317,9 +326,11 @@ def module_policy(self): module_policy = super().module_policy() if self.pipeline_stage_manager is not None: - self.set_pipeline_forward(model_cls=GPT2ForQuestionAnswering, - new_forward=GPT2PipelineForwards.gpt2_for_question_answering_forward, - policy=module_policy) + self.set_pipeline_forward( + model_cls=GPT2ForQuestionAnswering, + new_forward=GPT2PipelineForwards.gpt2_for_question_answering_forward, + policy=module_policy, + ) return module_policy @@ -330,13 +341,12 @@ def get_held_layers(self) -> List[nn.Module]: return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: - '''No shared_params in gpt2 for QA.''' + """No shared_params in gpt2 for QA.""" return [] # GPT2ForTokenClassification class GPT2ForTokenClassificationPolicy(GPT2Policy): - def __init__(self) -> None: super().__init__() @@ -347,17 +357,20 @@ def module_policy(self): if self.shard_config.enable_tensor_parallelism: addon_module = { - GPT2ForTokenClassification: - ModulePolicyDescription(sub_module_replacement=[ + GPT2ForTokenClassification: ModulePolicyDescription( + sub_module_replacement=[ SubModuleReplacementDescription(suffix="dropout", target_module=col_nn.DropoutForParallelInput) - ]) + ] + ) } module_policy.update(addon_module) if self.pipeline_stage_manager is not None: - self.set_pipeline_forward(model_cls=GPT2ForTokenClassification, - new_forward=GPT2PipelineForwards.gpt2_for_token_classification_forward, - policy=module_policy) + self.set_pipeline_forward( + model_cls=GPT2ForTokenClassification, + new_forward=GPT2PipelineForwards.gpt2_for_token_classification_forward, + policy=module_policy, + ) return module_policy def get_held_layers(self) -> List[nn.Module]: @@ -374,7 +387,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: # GPT2ForSequenceClassification class GPT2ForSequenceClassificationPolicy(GPT2Policy): - def __init__(self) -> None: super().__init__() @@ -384,9 +396,11 @@ def module_policy(self): module_policy = super().module_policy() if self.pipeline_stage_manager is not None: - self.set_pipeline_forward(model_cls=GPT2ForSequenceClassification, - new_forward=GPT2PipelineForwards.gpt2_for_sequence_classification_forward, - policy=module_policy) + self.set_pipeline_forward( + model_cls=GPT2ForSequenceClassification, + new_forward=GPT2PipelineForwards.gpt2_for_sequence_classification_forward, + policy=module_policy, + ) return module_policy def get_held_layers(self) -> List[nn.Module]: diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index cc131e8168fc..099995acb440 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -11,11 +11,10 @@ from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -__all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy'] +__all__ = ["LlamaPolicy", "LlamaForCausalLMPolicy", "LlamaForSequenceClassificationPolicy"] class LlamaPolicy(Policy): - def config_sanity_check(self): pass @@ -40,15 +39,15 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: self.shard_config.enable_sequence_parallelism = False warnings.warn("Llama dosen't support sequence parallelism now, will ignore the sequence parallelism 
flag.") - if self.shard_config.enable_tensor_parallelism: decoder_attribute_replacement = { "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, } if getattr(self.model.config, "num_key_value_heads", False): - decoder_attribute_replacement["self_attn.num_key_value_heads"] = \ + decoder_attribute_replacement["self_attn.num_key_value_heads"] = ( self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size + ) policy[LlamaDecoderLayer] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, @@ -80,45 +79,53 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: SubModuleReplacementDescription( suffix="mlp.down_proj", target_module=Linear1D_Row, - ) + ), ], ) - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, - ), - policy=policy, - target_key=LlamaModel) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ), + policy=policy, + target_key=LlamaModel, + ) # optimization configuration if self.shard_config.enable_fused_normalization: - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="input_layernorm", + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=FusedRMSNorm, + ), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", + target_module=FusedRMSNorm, + ), + ], + policy=policy, + target_key=LlamaDecoderLayer, + ) + + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="norm", target_module=FusedRMSNorm, ), - SubModuleReplacementDescription( - suffix="post_attention_layernorm", - target_module=FusedRMSNorm, - ) - ], - policy=policy, - target_key=LlamaDecoderLayer) - - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="norm", - target_module=FusedRMSNorm, - ), - policy=policy, - target_key=LlamaModel) + policy=policy, + target_key=LlamaModel, + ) if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement(description={ - 'forward': get_llama_flash_attention_forward(), - }, - policy=policy, - target_key=LlamaAttention) + self.append_or_create_method_replacement( + description={ + "forward": get_llama_flash_attention_forward(), + }, + policy=policy, + target_key=LlamaAttention, + ) return policy @@ -127,7 +134,7 @@ def postprocess(self): def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: """If under pipeline parallel setting, replacing the original forward method of huggingface - to customized forward method, and add this changing to policy.""" + to customized forward method, and add this changing to policy.""" if self.pipeline_stage_manager: stage_manager = self.pipeline_stage_manager if self.model.__class__.__name__ == "LlamaModel": @@ -137,10 +144,10 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = {'forward': 
partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=model_cls) + method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=model_cls + ) return @@ -148,7 +155,7 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" assert self.pipeline_stage_manager is not None - if self.model.__class__.__name__ == 'LlamaModel': + if self.model.__class__.__name__ == "LlamaModel": module = self.model else: module = self.model.model @@ -167,18 +174,18 @@ def get_held_layers(self) -> List[Module]: class LlamaModelPolicy(LlamaPolicy): - def __init__(self) -> None: super().__init__() def module_policy(self): policy = super().module_policy() from transformers.models.llama.modeling_llama import LlamaModel + if self.pipeline_stage_manager: # set None as default - self.set_pipeline_forward(model_cls=LlamaModel, - new_forward=LlamaPipelineForwards.llama_model_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=LlamaModel, new_forward=LlamaPipelineForwards.llama_model_forward, policy=policy + ) return policy def get_held_layers(self) -> List[Module]: @@ -192,7 +199,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: class LlamaForCausalLMPolicy(LlamaPolicy): - def module_policy(self): from transformers import LlamaForCausalLM @@ -201,19 +207,21 @@ def module_policy(self): if self.shard_config.enable_tensor_parallelism: # add a new item for casual lm new_item = { - LlamaForCausalLM: - ModulePolicyDescription(sub_module_replacement=[ + LlamaForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) - ]) + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + ) + ] + ) } policy.update(new_item) if self.pipeline_stage_manager: # set None as default - self.set_pipeline_forward(model_cls=LlamaForCausalLM, - new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy + ) return policy @@ -228,18 +236,21 @@ def get_held_layers(self) -> List[Module]: def get_shared_params(self) -> List[Dict[int, Tensor]]: llama_model = self.model.model if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: - if id(llama_model.embed_tokens.weight) == id( - self.model.lm_head.weight) and self.pipeline_stage_manager.num_stages > 1: + if ( + id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight) + and self.pipeline_stage_manager.num_stages > 1 + ): # tie weights - return [{ - 0: llama_model.embed_tokens.weight, - self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight - }] + return [ + { + 0: llama_model.embed_tokens.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, + } + ] return [] class LlamaForSequenceClassificationPolicy(LlamaPolicy): - def module_policy(self): from transformers import LlamaForSequenceClassification @@ -248,19 +259,23 @@ def module_policy(self): if self.shard_config.enable_tensor_parallelism: # add a new item for sequence classification new_item = { - LlamaForSequenceClassification: - 
ModulePolicyDescription(sub_module_replacement=[ + LlamaForSequenceClassification: ModulePolicyDescription( + sub_module_replacement=[ SubModuleReplacementDescription( - suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) - ]) + suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + ) + ] + ) } policy.update(new_item) # to be confirmed if self.pipeline_stage_manager: # set None as default - self.set_pipeline_forward(model_cls=LlamaForSequenceClassification, - new_forward=LlamaPipelineForwards.llama_for_sequence_classification_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=LlamaForSequenceClassification, + new_forward=LlamaPipelineForwards.llama_for_sequence_classification_forward, + policy=policy, + ) return policy def get_held_layers(self) -> List[Module]: diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index abe491bfaace..5739d21a3903 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -13,13 +13,15 @@ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ - 'OPTPolicy', 'OPTModelPolicy', 'OPTForCausalLMPolicy', 'OPTForSequenceClassificationPolicy', - 'OPTForQuestionAnsweringPolicy' + "OPTPolicy", + "OPTModelPolicy", + "OPTForCausalLMPolicy", + "OPTForSequenceClassificationPolicy", + "OPTForQuestionAnsweringPolicy", ] class OPTPolicy(Policy): - def config_sanity_check(self): pass @@ -45,79 +47,94 @@ def module_policy(self): warnings.warn("OPT dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") if self.shard_config.enable_tensor_parallelism: - policy[OPTDecoder] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, - ) - ]) - policy[OPTDecoderLayer] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="fc1", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="fc2", - target_module=Linear1D_Row, - ) - ]) - - policy[OPTAttention] = ModulePolicyDescription(attribute_replacement={ - "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="q_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="k_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="v_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="out_proj", - target_module=Linear1D_Row, - ), - ]) + policy[OPTDecoder] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ) + ] + ) + policy[OPTDecoderLayer] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="fc1", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=Linear1D_Row, + ), + ] + ) + + policy[OPTAttention] = ModulePolicyDescription( + attribute_replacement={ + "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + 
SubModuleReplacementDescription( + suffix="q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="out_proj", + target_module=Linear1D_Row, + ), + ], + ) # optimization configuration if self.shard_config.enable_fused_normalization: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), - policy=policy, - target_key=OPTDecoder) - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription(suffix="self_attn_layer_norm", - target_module=FusedLayerNorm, - ignore_if_not_exist=True), - SubModuleReplacementDescription(suffix="final_layer_norm", - target_module=FusedLayerNorm, - ignore_if_not_exist=True) - ], - policy=policy, - target_key=OPTDecoderLayer) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True + ), + policy=policy, + target_key=OPTDecoder, + ) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="self_attn_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True + ), + SubModuleReplacementDescription( + suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True + ), + ], + policy=policy, + target_key=OPTDecoderLayer, + ) # use flash attention if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement(description={ - 'forward': get_opt_flash_attention_forward(), - }, - policy=policy, - target_key=OPTAttention) + self.append_or_create_method_replacement( + description={ + "forward": get_opt_flash_attention_forward(), + }, + policy=policy, + target_key=OPTAttention, + ) # use jit fused operator if self.shard_config.enable_jit_fused: - self.append_or_create_method_replacement(description={ - 'forward': get_jit_fused_opt_decoder_layer_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }, - policy=policy, - target_key=OPTDecoderLayer) + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_opt_decoder_layer_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=OPTDecoderLayer, + ) return policy @@ -128,7 +145,7 @@ def get_held_layers(self) -> List[nn.Module]: """Get pipeline layers for current stage.""" assert self.pipeline_stage_manager is not None - if self.model.__class__.__name__ == 'OPTModel': + if self.model.__class__.__name__ == "OPTModel": module = self.model.decoder else: module = self.model.model.decoder @@ -149,24 +166,23 @@ def get_held_layers(self) -> List[nn.Module]: def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: """If under pipeline parallel setting, replacing the original forward method of huggingface - to customized forward method, and add this changing to policy.""" + to customized forward method, and add this changing to policy.""" if self.pipeline_stage_manager: stage_manager = self.pipeline_stage_manager - if self.model.__class__.__name__ == 'OPTModel': + if self.model.__class__.__name__ == "OPTModel": module = self.model.decoder else: module = self.model.model.decoder layers_per_stage = Policy.distribute_layers(len(module.layers), 
stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=model_cls) + method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=model_cls + ) class OPTModelPolicy(OPTPolicy): - def __init__(self) -> None: super().__init__() @@ -175,9 +191,9 @@ def module_policy(self): policy = super().module_policy() if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=OPTModel, - new_forward=OPTPipelineForwards.opt_model_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=OPTModel, new_forward=OPTPipelineForwards.opt_model_forward, policy=policy + ) return policy def get_held_layers(self) -> List[nn.Module]: @@ -189,20 +205,22 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: class OPTForCausalLMPolicy(OPTPolicy): - def module_policy(self): from transformers.models.opt.modeling_opt import OPTForCausalLM policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)), - policy=policy, - target_key=OPTForCausalLM) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + ), + policy=policy, + target_key=OPTForCausalLM, + ) if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=OPTForCausalLM, - new_forward=OPTPipelineForwards.opt_for_causal_lm_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=OPTForCausalLM, new_forward=OPTPipelineForwards.opt_for_causal_lm_forward, policy=policy + ) return policy @@ -223,7 +241,7 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: def postprocess(self): if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: binding_map = { - 'model.decoder.embed_tokens': 'lm_head', + "model.decoder.embed_tokens": "lm_head", } for k, v in binding_map.items(): @@ -235,7 +253,6 @@ def postprocess(self): class OPTForSequenceClassificationPolicy(OPTPolicy): - def __init__(self) -> None: super().__init__() @@ -244,9 +261,11 @@ def module_policy(self): policy = super().module_policy() if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=OPTForSequenceClassification, - new_forward=OPTPipelineForwards.opt_for_sequence_classification_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=OPTForSequenceClassification, + new_forward=OPTPipelineForwards.opt_for_sequence_classification_forward, + policy=policy, + ) return policy @@ -262,7 +281,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: class OPTForQuestionAnsweringPolicy(OPTPolicy): - def __init__(self) -> None: super().__init__() @@ -271,9 +289,11 @@ def module_policy(self): policy = super().module_policy() if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=OPTForQuestionAnswering, - new_forward=OPTPipelineForwards.opt_for_question_answering_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=OPTForQuestionAnswering, + 
new_forward=OPTPipelineForwards.opt_for_question_answering_forward, + policy=policy, + ) return policy diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py index 9753d5a737b9..58a8500e3863 100644 --- a/colossalai/shardformer/policies/sam.py +++ b/colossalai/shardformer/policies/sam.py @@ -1,16 +1,12 @@ -import torch.nn as nn - import colossalai.shardformer.layer as col_nn -from .._utils import getattr_, setattr_ from ..modeling.sam import forward_fn, get_sam_flash_attention_forward, get_sam_vision_flash_attention_forward from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -__all__ = ['SamPolicy', 'SamModelPolicy'] +__all__ = ["SamPolicy", "SamModelPolicy"] class SamPolicy(Policy): - def config_sanity_check(self): pass @@ -20,7 +16,6 @@ def preprocess(self): def module_policy(self): from transformers.models.sam.modeling_sam import ( SamAttention, - SamFeedForward, SamTwoWayAttentionBlock, SamTwoWayTransformer, SamVisionAttention, @@ -30,36 +25,37 @@ def module_policy(self): policy = {} if self.shard_config.enable_tensor_parallelism: - policy[SamVisionLayer] = ModulePolicyDescription(attribute_replacement={ - "attn.num_attention_heads": - self.model.config.vision_config.num_attention_heads // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="attn.qkv", - target_module=col_nn.FusedLinear1D_Col, - kwargs={ - "n_fused": 3, - }, - ), - SubModuleReplacementDescription( - suffix="attn.proj", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="mlp.lin1", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="mlp.lin2", - target_module=col_nn.Linear1D_Row, - ) - ]) + policy[SamVisionLayer] = ModulePolicyDescription( + attribute_replacement={ + "attn.num_attention_heads": self.model.config.vision_config.num_attention_heads + // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn.qkv", + target_module=col_nn.FusedLinear1D_Col, + kwargs={ + "n_fused": 3, + }, + ), + SubModuleReplacementDescription( + suffix="attn.proj", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="mlp.lin1", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.lin2", + target_module=col_nn.Linear1D_Row, + ), + ], + ) policy[SamTwoWayAttentionBlock] = ModulePolicyDescription( attribute_replacement={ - "self_attn.num_attention_heads": - self.model.config.mask_decoder_config.num_attention_heads // - self.shard_config.tensor_parallel_size, + "self_attn.num_attention_heads": self.model.config.mask_decoder_config.num_attention_heads + // self.shard_config.tensor_parallel_size, }, sub_module_replacement=[ SubModuleReplacementDescription( @@ -118,97 +114,112 @@ def module_policy(self): suffix="cross_attn_image_to_token.out_proj", target_module=col_nn.Linear1D_Row, ), - ]) - policy[SamTwoWayTransformer] = ModulePolicyDescription(attribute_replacement={ - "final_attn_token_to_image.num_attention_heads": - self.model.config.mask_decoder_config.num_attention_heads // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="final_attn_token_to_image.q_proj", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="final_attn_token_to_image.k_proj", - target_module=col_nn.Linear1D_Col, - ), - 
SubModuleReplacementDescription( - suffix="final_attn_token_to_image.v_proj", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="final_attn_token_to_image.out_proj", - target_module=col_nn.Linear1D_Row, - ) - ]) + ], + ) + policy[SamTwoWayTransformer] = ModulePolicyDescription( + attribute_replacement={ + "final_attn_token_to_image.num_attention_heads": self.model.config.mask_decoder_config.num_attention_heads + // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="final_attn_token_to_image.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="final_attn_token_to_image.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="final_attn_token_to_image.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="final_attn_token_to_image.out_proj", + target_module=col_nn.Linear1D_Row, + ), + ], + ) # add `DropoutForParallelInput` layer to replace the useage of `nn.functional.dropout` - policy[SamVisionAttention] = ModulePolicyDescription(attribute_replacement={ - "dropout_layer": col_nn.DropoutForParallelInput(self.model.config.vision_config.attention_dropout) - }, - method_replacement={"forward": forward_fn()}, - sub_module_replacement=[]) + policy[SamVisionAttention] = ModulePolicyDescription( + attribute_replacement={ + "dropout_layer": col_nn.DropoutForParallelInput(self.model.config.vision_config.attention_dropout) + }, + method_replacement={"forward": forward_fn()}, + sub_module_replacement=[], + ) # optimization configuration if self.shard_config.enable_fused_normalization: # Handle SamVisionLayer - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="layer_norm1", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="layer_norm2", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=SamVisionLayer) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="layer_norm1", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="layer_norm2", + target_module=col_nn.FusedLayerNorm, + ), + ], + policy=policy, + target_key=SamVisionLayer, + ) # Handle SamTwoWayAttentionBlock - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="layer_norm1", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="layer_norm2", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="layer_norm3", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="layer_norm4", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=SamTwoWayAttentionBlock) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="layer_norm1", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="layer_norm2", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="layer_norm3", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="layer_norm4", + target_module=col_nn.FusedLayerNorm, + ), + ], + policy=policy, + target_key=SamTwoWayAttentionBlock, + ) # Handle SamTwoWayTransformer - 
self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="layer_norm_final_attn", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=SamTwoWayTransformer) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="layer_norm_final_attn", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=SamTwoWayTransformer, + ) # use flash attention if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement(description={ - 'forward': get_sam_flash_attention_forward(), - }, - policy=policy, - target_key=SamAttention) - self.append_or_create_method_replacement(description={ - 'forward': get_sam_vision_flash_attention_forward(), - }, - policy=policy, - target_key=SamVisionAttention) + self.append_or_create_method_replacement( + description={ + "forward": get_sam_flash_attention_forward(), + }, + policy=policy, + target_key=SamAttention, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_sam_vision_flash_attention_forward(), + }, + policy=policy, + target_key=SamVisionAttention, + ) return policy @@ -218,6 +229,5 @@ def postprocess(self): # SamModel class SamModelPolicy(SamPolicy): - def __init__(self) -> None: super().__init__() diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 92cbd3f72b83..74cc7337e9f1 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -1,6 +1,6 @@ import warnings from functools import partial -from typing import Callable, Dict, List, Optional, Tuple +from typing import Callable, Dict, List, Tuple import numpy as np from torch import Tensor, nn @@ -15,7 +15,6 @@ ) from colossalai.shardformer.policies.base_policy import ModulePolicyDescription -from .._utils import getattr_, setattr_ from ..modeling.jit import get_jit_fused_dropout_add_func from ..modeling.t5 import ( T5PipelineForwards, @@ -30,7 +29,6 @@ class T5BasePolicy(Policy): - def config_sanity_check(self): pass @@ -65,151 +63,181 @@ def module_policy(self): warnings.warn("T5 dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") if self.shard_config.enable_tensor_parallelism: - policy[T5Stack] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, - ) - ]) - policy[T5LayerSelfAttention] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForParallelInput, - ), - ]) - policy[T5LayerCrossAttention] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForParallelInput, - ) - ]) - policy[T5Attention] = ModulePolicyDescription(attribute_replacement={ - "d_model": - self.model.config.d_model // self.shard_config.tensor_parallel_size, - "n_heads": - self.model.config.num_heads // self.shard_config.tensor_parallel_size, - "inner_dim": - self.model.config.num_heads * self.model.config.d_kv // self.shard_config.tensor_parallel_size - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="q", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="k", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - 
suffix="v", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="o", - target_module=Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="relative_attention_bias", - target_module=Embedding1D, - kwargs=dict(gather_output=False), - ignore_if_not_exist=True) - ]) - policy[T5LayerFF] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForParallelInput, - ), - ]) - policy[T5DenseGatedActDense] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="wi_0 ", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="wi_1", - target_module=Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True)), - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForParallelInput, - ) - ]) - policy[T5DenseActDense] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="wi", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="wo", - target_module=Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForParallelInput, - ) - ]) + policy[T5Stack] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ), + ] + ) + policy[T5LayerSelfAttention] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + ] + ) + policy[T5LayerCrossAttention] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ) + ] + ) + policy[T5Attention] = ModulePolicyDescription( + attribute_replacement={ + "d_model": self.model.config.d_model // self.shard_config.tensor_parallel_size, + "n_heads": self.model.config.num_heads // self.shard_config.tensor_parallel_size, + "inner_dim": self.model.config.num_heads + * self.model.config.d_kv + // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="q", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="k", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="v", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="o", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="relative_attention_bias", + target_module=Embedding1D, + kwargs=dict(gather_output=False), + ignore_if_not_exist=True, + ), + ], + ) + policy[T5LayerFF] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + ] + ) + policy[T5DenseGatedActDense] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wi_0 ", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="wi_1", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + ), + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + ] + ) + 
policy[T5DenseActDense] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wi", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="wo", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + ] + ) # optimization configuration if self.shard_config.enable_fused_normalization: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="layer_norm", - target_module=FusedRMSNorm, - ), - policy=policy, - target_key=T5LayerFF) - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="layer_norm", - target_module=FusedRMSNorm, - ), - policy=policy, - target_key=T5LayerFF) - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="layer_norm", target_module=FusedRMSNorm), - policy=policy, - target_key=T5LayerSelfAttention) - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="layer_norm", target_module=FusedRMSNorm), - policy=policy, - target_key=T5LayerCrossAttention) - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="final_layer_norm", target_module=FusedRMSNorm), - policy=policy, - target_key=T5Stack) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="layer_norm", + target_module=FusedRMSNorm, + ), + policy=policy, + target_key=T5LayerFF, + ) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="layer_norm", + target_module=FusedRMSNorm, + ), + policy=policy, + target_key=T5LayerFF, + ) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm), + policy=policy, + target_key=T5LayerSelfAttention, + ) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm), + policy=policy, + target_key=T5LayerCrossAttention, + ) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedRMSNorm), + policy=policy, + target_key=T5Stack, + ) # use flash attention if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement(description={ - 'forward': get_t5_flash_attention_forward(), - }, - policy=policy, - target_key=T5Attention) + self.append_or_create_method_replacement( + description={ + "forward": get_t5_flash_attention_forward(), + }, + policy=policy, + target_key=T5Attention, + ) # use jit operator if self.shard_config.enable_jit_fused: - self.append_or_create_method_replacement(description={ - 'forward': get_jit_fused_T5_layer_ff_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }, - policy=policy, - target_key=T5LayerFF) - self.append_or_create_method_replacement(description={ - 'forward': get_T5_layer_self_attention_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }, - policy=policy, - target_key=T5LayerSelfAttention) - self.append_or_create_method_replacement(description={ - 'forward': get_T5_layer_cross_attention_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }, - policy=policy, - target_key=T5LayerCrossAttention) + self.append_or_create_method_replacement( + description={ + "forward": 
get_jit_fused_T5_layer_ff_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=T5LayerFF, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_T5_layer_self_attention_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=T5LayerSelfAttention, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_T5_layer_cross_attention_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=T5LayerCrossAttention, + ) return policy @@ -217,8 +245,9 @@ def postprocess(self): return self.model @staticmethod - def distribute_t5_layers(num_encoder_layers: int, num_decoder_layers: int, - num_stages: int) -> Tuple[List[int], int]: + def distribute_t5_layers( + num_encoder_layers: int, num_decoder_layers: int, num_stages: int + ) -> Tuple[List[int], int]: """ Distribute t5 layers into stages when pipeline parallel is used. Return the layer distribution as a list and the starting stage of decoder. @@ -251,8 +280,9 @@ def objective(num_encoder_stages): return encoder_distribution + decoder_distribution, num_encoder_stages @staticmethod - def get_t5_stage_index(layers_per_stage: List[int], stage: int, - decoder_starting_stage: int) -> Tuple[bool, int, int]: + def get_t5_stage_index( + layers_per_stage: List[int], stage: int, decoder_starting_stage: int + ) -> Tuple[bool, int, int]: """ Input the distribution of layers among stages, the current stage and the first stage of decoder. Return the starting/ending idx of layers in encoder/decoder @@ -269,16 +299,18 @@ def get_held_layers(self) -> List[nn.Module]: model = self.model encoder = self.model.encoder - decoder = getattr(self.model, 'decoder', None) + decoder = getattr(self.model, "decoder", None) num_encoder_layers = len(encoder.block) num_decoder_layers = len(decoder.block) if decoder else 0 held_layers = [] layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( - num_encoder_layers, num_decoder_layers, stage_manager.num_stages) - start_idx, end_idx = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage_manager.stage, - decoder_starting_stage) + num_encoder_layers, num_decoder_layers, stage_manager.num_stages + ) + start_idx, end_idx = T5BasePolicy.get_t5_stage_index( + layers_per_stage, stage_manager.stage, decoder_starting_stage + ) if stage_manager.stage < decoder_starting_stage: # current stage is in t5's encoder @@ -303,47 +335,51 @@ def get_held_layers(self) -> List[nn.Module]: def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: """If under pipeline parallel setting, replacing the original forward method of huggingface - to customized forward method, and add this changing to policy.""" + to customized forward method, and add this changing to policy.""" if not self.pipeline_stage_manager: raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") stage_manager = self.pipeline_stage_manager encoder = self.model.encoder - decoder = getattr(self.model, 'decoder', None) + decoder = getattr(self.model, "decoder", None) num_encoder_layers = len(encoder.block) num_decoder_layers = len(decoder.block) if decoder else 0 layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( - num_encoder_layers, num_decoder_layers, stage_manager.num_stages) + num_encoder_layers, num_decoder_layers, stage_manager.num_stages + ) stage_index = 
T5BasePolicy.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage) method_replacement = { - 'forward': - partial(new_forward, - stage_manager=stage_manager, - stage_index=stage_index, - decoder_starting_stage=decoder_starting_stage) + "forward": partial( + new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage, + ) } self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) class T5ModelPolicy(T5BasePolicy): - def __init__(self) -> None: super().__init__() def module_policy(self): from transformers import T5Model + policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="shared", - target_module=VocabParallelEmbedding1D, - ), - policy=policy, - target_key=T5Model) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="shared", + target_module=VocabParallelEmbedding1D, + ), + policy=policy, + target_key=T5Model, + ) if self.pipeline_stage_manager is not None: self.set_pipeline_forward(model_cls=T5Model, new_forward=T5PipelineForwards.t5_model_forward, policy=policy) @@ -356,9 +392,9 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: module = self.model stage_manager = self.pipeline_stage_manager if stage_manager is not None and stage_manager.num_stages > 1: - _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(len(module.encoder.block), - len(module.decoder.block), - stage_manager.num_stages) + _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( + len(module.encoder.block), len(module.decoder.block), stage_manager.num_stages + ) if id(module.decoder.embed_tokens.weight) == id(module.shared.weight): return [{0: module.shared.weight, decoder_starting_stage: module.decoder.embed_tokens.weight}] @@ -366,7 +402,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: class T5ForConditionalGenerationPolicy(T5BasePolicy): - def __init__(self) -> None: super().__init__() @@ -376,22 +411,26 @@ def module_policy(self): policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="shared", - target_module=VocabParallelEmbedding1D, - ), - SubModuleReplacementDescription(suffix="lm_head", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)) - ], - policy=policy, - target_key=T5ForConditionalGeneration) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="shared", + target_module=VocabParallelEmbedding1D, + ), + SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + ), + ], + policy=policy, + target_key=T5ForConditionalGeneration, + ) if self.pipeline_stage_manager is not None: - self.set_pipeline_forward(model_cls=T5ForConditionalGeneration, - new_forward=T5PipelineForwards.t5_for_conditional_generation_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=T5ForConditionalGeneration, + new_forward=T5PipelineForwards.t5_for_conditional_generation_forward, + policy=policy, + ) return policy def get_held_layers(self) -> List[nn.Module]: @@ -404,9 +443,9 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: module = self.model stage_manager = self.pipeline_stage_manager if stage_manager is not None 
and stage_manager.num_stages > 1: - _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(len(module.encoder.block), - len(module.decoder.block), - stage_manager.num_stages) + _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( + len(module.encoder.block), len(module.decoder.block), stage_manager.num_stages + ) shared_params = [] shared_embedding = {} @@ -427,7 +466,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: class T5EncoderPolicy(T5BasePolicy): - def __init__(self) -> None: super().__init__() @@ -437,17 +475,19 @@ def module_policy(self): policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="shared", - target_module=VocabParallelEmbedding1D, - ), - policy=policy, - target_key=T5EncoderModel) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="shared", + target_module=VocabParallelEmbedding1D, + ), + policy=policy, + target_key=T5EncoderModel, + ) if self.pipeline_stage_manager is not None: - self.set_pipeline_forward(model_cls=T5EncoderModel, - new_forward=T5PipelineForwards.t5_encoder_model_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=T5EncoderModel, new_forward=T5PipelineForwards.t5_encoder_model_forward, policy=policy + ) return policy def get_held_layers(self) -> List[nn.Module]: diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index b4fb8692e684..270cdce9b091 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -16,11 +16,10 @@ ) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -__all__ = ['ViTPolicy', 'ViTModelPolicy', 'ViTForImageClassificationPolicy', 'ViTForMaskedImageModelingPolicy'] +__all__ = ["ViTPolicy", "ViTModelPolicy", "ViTForImageClassificationPolicy", "ViTForMaskedImageModelingPolicy"] class ViTPolicy(Policy): - def config_sanity_check(self): pass @@ -28,8 +27,7 @@ def preprocess(self): return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - - from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTModel, ViTOutput, ViTSelfAttention + from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTOutput, ViTSelfAttention policy = {} @@ -38,77 +36,85 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: warnings.warn("Vit dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") if self.shard_config.enable_tensor_parallelism: - policy[ViTEmbeddings] = ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForReplicatedInput, - ) - ]) - - policy[ViTLayer] = ModulePolicyDescription(attribute_replacement={ - "attention.attention.num_attention_heads": - self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - "attention.attention.all_head_size": - self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - }, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="attention.attention.query", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.attention.key", - target_module=col_nn.Linear1D_Col, - ), - 
SubModuleReplacementDescription( - suffix="attention.attention.value", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.attention.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="attention.output.dense", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="attention.output.dropout", - target_module=col_nn.DropoutForReplicatedInput, - ), - SubModuleReplacementDescription( - suffix="intermediate.dense", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="output.dense", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="output.dropout", - target_module=col_nn.DropoutForReplicatedInput, - ), - ]) + policy[ViTEmbeddings] = ModulePolicyDescription( + attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForReplicatedInput, + ) + ], + ) + + policy[ViTLayer] = ModulePolicyDescription( + attribute_replacement={ + "attention.attention.num_attention_heads": self.model.config.num_attention_heads + // self.shard_config.tensor_parallel_size, + "attention.attention.all_head_size": self.model.config.hidden_size + // self.shard_config.tensor_parallel_size, + }, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.attention.query", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.key", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.value", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attention.output.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + SubModuleReplacementDescription( + suffix="intermediate.dense", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="output.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="output.dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + ], + ) # use flash attention if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement(description={ - 'forward': get_vit_flash_self_attention_forward(), - }, - policy=policy, - target_key=ViTSelfAttention) + self.append_or_create_method_replacement( + description={ + "forward": get_vit_flash_self_attention_forward(), + }, + policy=policy, + target_key=ViTSelfAttention, + ) # use jit fused operator if self.shard_config.enable_jit_fused: - self.append_or_create_method_replacement(description={ - 'forward': get_jit_fused_vit_output_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }, - policy=policy, - target_key=ViTOutput) + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_vit_output_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=ViTOutput, + ) return policy def new_model_class(self): @@ -121,7 +127,7 @@ def get_held_layers(self) -> List[nn.Module]: """Get pipeline layers for current stage.""" assert self.pipeline_stage_manager is not None, 
"pipeline_stage_manager is None" - if self.model.__class__.__name__ == 'ViTModel': + if self.model.__class__.__name__ == "ViTModel": module = self.model else: module = self.model.vit @@ -138,22 +144,21 @@ def get_held_layers(self) -> List[nn.Module]: def set_pipeline_forward(self, model_cls: nn.Module, pipeline_forward: Callable, policy: Dict): if self.pipeline_stage_manager: stage_manager = self.pipeline_stage_manager - if self.model.__class__.__name__ == 'ViTModel': + if self.model.__class__.__name__ == "ViTModel": module = self.model else: module = self.model.vit layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = {'forward': pipeline_forward(stage_manager=stage_manager, stage_index=stage_index)} - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=model_cls) + method_replacement = {"forward": pipeline_forward(stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=model_cls + ) # ViTModel class ViTModelPolicy(ViTPolicy): - def __init__(self) -> None: super().__init__() @@ -181,26 +186,29 @@ def get_held_layers(self) -> List[nn.Module]: # ViTForImageClassification class ViTForImageClassificationPolicy(ViTPolicy): - def module_policy(self): from transformers.models.vit.modeling_vit import ViTForImageClassification, ViTModel policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: new_item = { - ViTForImageClassification: - ModulePolicyDescription(sub_module_replacement=[ + ViTForImageClassification: ModulePolicyDescription( + sub_module_replacement=[ SubModuleReplacementDescription( - suffix="classifier", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) - ]) + suffix="classifier", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + ) + ] + ) } policy.update(new_item) if self.shard_config.pipeline_stage_manager is not None: self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy) - self.set_pipeline_forward(model_cls=ViTForImageClassification, - pipeline_forward=ViTForImageClassification_pipeline_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=ViTForImageClassification, + pipeline_forward=ViTForImageClassification_pipeline_forward, + policy=policy, + ) return policy @@ -219,7 +227,6 @@ def get_held_layers(self) -> List[nn.Module]: # ViTForMaskedImageModeling class ViTForMaskedImageModelingPolicy(ViTPolicy): - def __init__(self) -> None: super().__init__() @@ -230,9 +237,11 @@ def module_policy(self): if self.shard_config.pipeline_stage_manager is not None: self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy) - self.set_pipeline_forward(model_cls=ViTForMaskedImageModeling, - pipeline_forward=ViTForMaskedImageModeling_pipeline_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=ViTForMaskedImageModeling, + pipeline_forward=ViTForMaskedImageModeling_pipeline_forward, + policy=policy, + ) return policy def get_held_layers(self) -> List[nn.Module]: diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index 31ba82166b31..d9af2461cdb8 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -8,7 +8,6 @@ import 
colossalai.shardformer.layer as col_nn -from .._utils import getattr_, setattr_ from ..modeling.jit import get_jit_fused_dropout_add_func from ..modeling.whisper import ( WhisperPipelineForwards, @@ -19,13 +18,14 @@ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ - 'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy', - 'WhisperForAudioClassificationPolicy' + "WhisperPolicy", + "WhisperModelPolicy", + "WhisperForConditionalGenerationPolicy", + "WhisperForAudioClassificationPolicy", ] class WhisperPolicy(Policy): - def config_sanity_check(self): pass @@ -55,179 +55,197 @@ def module_policy(self): if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False warnings.warn( - "Whisper dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + "Whisper dosen't support sequence parallelism now, will ignore the sequence parallelism flag." + ) - #TODO using the jit fused add_and_dropout affect the accuracy + # TODO using the jit fused add_and_dropout affect the accuracy if self.shard_config.enable_jit_fused: self.shard_config.enable_jit_fused = False warnings.warn("Whisper dosen't support jit fused operator now, will ignore the jit fused operator flag.") if self.shard_config.enable_tensor_parallelism: - policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={ - "self_attn.embed_dim": - self.model.config.d_model // self.shard_config.tensor_parallel_size, - "self_attn.num_heads": - self.model.config.encoder_attention_heads // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="self_attn.q_proj", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attn.k_proj", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attn.v_proj", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attn.out_proj", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="fc1", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="fc2", - target_module=col_nn.Linear1D_Row, - ), - ]) - - policy[WhisperDecoderLayer] = ModulePolicyDescription(attribute_replacement={ - "self_attn.embed_dim": - self.model.config.d_model // self.shard_config.tensor_parallel_size, - "self_attn.num_heads": - self.model.config.decoder_attention_heads // self.shard_config.tensor_parallel_size, - "encoder_attn.embed_dim": - self.model.config.d_model // self.shard_config.tensor_parallel_size, - "encoder_attn.num_heads": - self.model.config.encoder_attention_heads // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="self_attn.q_proj", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attn.k_proj", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attn.v_proj", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attn.out_proj", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="encoder_attn.q_proj", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="encoder_attn.k_proj", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="encoder_attn.v_proj", - 
target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="encoder_attn.out_proj", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="fc1", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="fc2", - target_module=col_nn.Linear1D_Row, - ), - ]) - - policy[WhisperDecoder] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=col_nn.VocabParallelEmbedding1D, - ), - ]) + policy[WhisperEncoderLayer] = ModulePolicyDescription( + attribute_replacement={ + "self_attn.embed_dim": self.model.config.d_model // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.encoder_attention_heads + // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.out_proj", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="fc1", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=col_nn.Linear1D_Row, + ), + ], + ) + + policy[WhisperDecoderLayer] = ModulePolicyDescription( + attribute_replacement={ + "self_attn.embed_dim": self.model.config.d_model // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.decoder_attention_heads + // self.shard_config.tensor_parallel_size, + "encoder_attn.embed_dim": self.model.config.d_model // self.shard_config.tensor_parallel_size, + "encoder_attn.num_heads": self.model.config.encoder_attention_heads + // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.out_proj", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="encoder_attn.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="encoder_attn.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="encoder_attn.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="encoder_attn.out_proj", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="fc1", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=col_nn.Linear1D_Row, + ), + ], + ) + + policy[WhisperDecoder] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=col_nn.VocabParallelEmbedding1D, + ), + ] + ) # optimization configuration if self.shard_config.enable_fused_normalization: # Handle encoder layer - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="self_attn_layer_norm", - target_module=col_nn.FusedLayerNorm, - ), - 
SubModuleReplacementDescription( - suffix="final_layer_norm", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=WhisperEncoderLayer) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="self_attn_layer_norm", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="final_layer_norm", + target_module=col_nn.FusedLayerNorm, + ), + ], + policy=policy, + target_key=WhisperEncoderLayer, + ) # Handle decoder layer - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="self_attn_layer_norm", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="final_layer_norm", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=WhisperDecoderLayer) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="self_attn_layer_norm", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="final_layer_norm", + target_module=col_nn.FusedLayerNorm, + ), + ], + policy=policy, + target_key=WhisperDecoderLayer, + ) # handle encoder layer - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="layer_norm", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=WhisperEncoder) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="layer_norm", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=WhisperEncoder, + ) # handle decoder layer - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="layer_norm", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=WhisperDecoder) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="layer_norm", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=WhisperDecoder, + ) # enable flash attention if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement(description={ - 'forward': get_whisper_flash_attention_forward(), - }, - policy=policy, - target_key=WhisperAttention) + self.append_or_create_method_replacement( + description={ + "forward": get_whisper_flash_attention_forward(), + }, + policy=policy, + target_key=WhisperAttention, + ) # use jit fused operator if self.shard_config.enable_jit_fused: - self.append_or_create_method_replacement(description={ - 'forward': get_jit_fused_whisper_decoder_layer_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }, - policy=policy, - target_key=WhisperDecoderLayer) - self.append_or_create_method_replacement(description={ - 'forward': get_jit_fused_whisper_encoder_layer_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }, - policy=policy, - target_key=WhisperEncoderLayer) + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_whisper_decoder_layer_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=WhisperDecoderLayer, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_whisper_encoder_layer_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=WhisperEncoderLayer, + ) return policy @@ -236,10 +254,13 @@ def add_lm_head_policy(self, 
base_policy): # optimize for tensor parallelism if self.shard_config.enable_tensor_parallelism: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="proj_out", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}), - policy=base_policy, - target_key=WhisperForConditionalGeneration) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="proj_out", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} + ), + policy=base_policy, + target_key=WhisperForConditionalGeneration, + ) return base_policy @@ -247,8 +268,9 @@ def postprocess(self): return self.model @staticmethod - def distribute_whisper_layers(num_encoder_layers: int, num_decoder_layers: int, - num_stages: int) -> Tuple[List[int], int]: + def distribute_whisper_layers( + num_encoder_layers: int, num_decoder_layers: int, num_stages: int + ) -> Tuple[List[int], int]: """ Distribute whisper layers into stages when pipeline parallel is used. Return the layer distribution as a list and the starting stage of decoder. @@ -281,8 +303,9 @@ def objective(num_encoder_stages): return encoder_distribution + decoder_distribution, num_encoder_stages @staticmethod - def get_whisper_stage_index(layers_per_stage: List[int], stage: int, - decoder_starting_stage: int) -> Tuple[bool, int, int]: + def get_whisper_stage_index( + layers_per_stage: List[int], stage: int, decoder_starting_stage: int + ) -> Tuple[bool, int, int]: """ Input the distribution of layers among stages, the current stage and the first stage of decoder. Return the starting/ending idx of layers in encoder/decoder @@ -293,13 +316,12 @@ def get_whisper_stage_index(layers_per_stage: List[int], stage: int, return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage) def get_held_layers(self) -> List[nn.Module]: - assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None" stage_manager = self.pipeline_stage_manager - if self.model.__class__.__name__ == 'WhisperModel': + if self.model.__class__.__name__ == "WhisperModel": model = self.model - elif self.model.__class__.__name__ == 'WhisperForConditionalGeneration': + elif self.model.__class__.__name__ == "WhisperForConditionalGeneration": model = self.model.model else: model = None @@ -320,9 +342,11 @@ def get_held_layers(self) -> List[nn.Module]: held_layers = [] layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( - num_encoder_layers, num_decoder_layers, stage_manager.num_stages) - start_idx, end_idx = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage_manager.stage, - decoder_starting_stage) + num_encoder_layers, num_decoder_layers, stage_manager.num_stages + ) + start_idx, end_idx = WhisperPolicy.get_whisper_stage_index( + layers_per_stage, stage_manager.stage, decoder_starting_stage + ) if stage_manager.stage < decoder_starting_stage: # current stage is in whisper's encoder @@ -347,14 +371,14 @@ def get_held_layers(self) -> List[nn.Module]: def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: """If under pipeline parallel setting, replacing the original forward method of huggingface - to customized forward method, and add this changing to policy.""" + to customized forward method, and add this changing to policy.""" if not self.pipeline_stage_manager: raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") stage_manager = 
self.pipeline_stage_manager - if self.model.__class__.__name__ == 'WhisperModel': + if self.model.__class__.__name__ == "WhisperModel": model = self.model - elif self.model.__class__.__name__ == 'WhisperForConditionalGeneration': + elif self.model.__class__.__name__ == "WhisperForConditionalGeneration": model = self.model.model else: model = None @@ -373,34 +397,37 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli num_decoder_layers = 0 layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( - num_encoder_layers, num_decoder_layers, stage_manager.num_stages) - stage_index = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage_manager.stage, - decoder_starting_stage) + num_encoder_layers, num_decoder_layers, stage_manager.num_stages + ) + stage_index = WhisperPolicy.get_whisper_stage_index( + layers_per_stage, stage_manager.stage, decoder_starting_stage + ) method_replacement = { - 'forward': - partial(new_forward, - stage_manager=stage_manager, - stage_index=stage_index, - decoder_starting_stage=decoder_starting_stage) + "forward": partial( + new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage, + ) } self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) # WhisperModel class WhisperModelPolicy(WhisperPolicy): - def __init__(self) -> None: super().__init__() def module_policy(self): from transformers import WhisperModel + policy = super().module_policy() if self.pipeline_stage_manager is not None: - self.set_pipeline_forward(model_cls=WhisperModel, - new_forward=WhisperPipelineForwards.whisper_model_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=WhisperModel, new_forward=WhisperPipelineForwards.whisper_model_forward, policy=policy + ) return policy @@ -414,19 +441,21 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: # WhisperForConditionalGeneration class WhisperForConditionalGenerationPolicy(WhisperPolicy): - def __init__(self) -> None: super().__init__() def module_policy(self): from transformers import WhisperForConditionalGeneration + policy = super().module_policy() policy = self.add_lm_head_policy(policy) if self.pipeline_stage_manager is not None: - self.set_pipeline_forward(model_cls=WhisperForConditionalGeneration, - new_forward=WhisperPipelineForwards.whisper_for_conditional_generation_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=WhisperForConditionalGeneration, + new_forward=WhisperPipelineForwards.whisper_for_conditional_generation_forward, + policy=policy, + ) return policy def postprocess(self): @@ -457,8 +486,9 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: stage_manager = self.pipeline_stage_manager if stage_manager is not None and stage_manager.num_stages > 1: - _, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(num_encoder_layers, num_decoder_layers, - stage_manager.num_stages) + _, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( + num_encoder_layers, num_decoder_layers, stage_manager.num_stages + ) shared_params = [] shared_embedding = {} if id(module.proj_out) == id(model.decoder.embed_tokens): @@ -472,7 +502,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: # WhisperForAudioClassification class WhisperForAudioClassificationPolicy(WhisperPolicy): - def __init__(self) -> None: super().__init__() @@ -481,12 +510,15 @@ def preprocess(self): def module_policy(self): from 
transformers import WhisperForAudioClassification + policy = super().module_policy() if self.pipeline_stage_manager is not None: - self.set_pipeline_forward(model_cls=WhisperForAudioClassification, - new_forward=WhisperPipelineForwards.whisper_for_audio_classification_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=WhisperForAudioClassification, + new_forward=WhisperPipelineForwards.whisper_for_audio_classification_forward, + policy=policy, + ) return policy def get_held_layers(self) -> List[nn.Module]: diff --git a/colossalai/shardformer/shard/__init__.py b/colossalai/shardformer/shard/__init__.py index 7abdd45ec7c5..acf8a95a41ca 100644 --- a/colossalai/shardformer/shard/__init__.py +++ b/colossalai/shardformer/shard/__init__.py @@ -2,4 +2,4 @@ from .sharder import ModelSharder from .shardformer import ShardFormer -__all__ = ['ShardConfig', 'ModelSharder', 'ShardFormer'] +__all__ = ["ShardConfig", "ModelSharder", "ShardFormer"] diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 0b6e1640952b..6935288130c9 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -6,7 +6,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager -__all__ = ['ShardConfig'] +__all__ = ["ShardConfig"] @dataclass @@ -45,7 +45,8 @@ def tensor_parallel_size(self): def __post_init__(self): if not self.enable_tensor_parallelism and self.enable_sequence_parallelism: raise ValueError( - "enable_sequence_parallelism can only be set to True when enable_tensor_parallelism is True") + "enable_sequence_parallelism can only be set to True when enable_tensor_parallelism is True" + ) if not self.enable_sequence_parallelism and self.enable_sequence_overlap: raise ValueError("enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True") if not self.enable_tensor_parallelism: diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 7592069a2dd9..1bed850c6581 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -12,7 +12,7 @@ from .shard_config import ShardConfig from .utils import set_tensors_to_none -__all__ = ['ModelSharder', 'shard_model'] +__all__ = ["ModelSharder", "shard_model"] class ModelSharder(object): @@ -64,13 +64,15 @@ def _replace_module(self, include: Optional[Set[nn.Module]] = None) -> None: param_replacement = module_description.param_replacement sub_module_replacement = module_description.sub_module_replacement method_replacement = module_description.method_replacement - self._recursive_replace_layer(self.model, - layer_cls, - attr_replacement, - param_replacement, - method_replacement, - sub_module_replacement, - include=include) + self._recursive_replace_layer( + self.model, + layer_cls, + attr_replacement, + param_replacement, + method_replacement, + sub_module_replacement, + include=include, + ) def _recursive_replace_layer( self, @@ -94,8 +96,9 @@ def _recursive_replace_layer( sub_module_replacement ((List[SubModuleReplacementDescription]): The function list to get sub module shard information in policy include (Set[nn.Module], optional): The set of modules to keep on current device when pipeline parallel is enabled. 
Defaults to None """ - if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \ - (module.__class__ == origin_cls): + if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or ( + module.__class__ == origin_cls + ): if attr_replacement is not None: self._replace_attr(module, attr_replacement) @@ -109,13 +112,15 @@ def _recursive_replace_layer( self._replace_sub_module(module, sub_module_replacement, include) for name, child in module.named_children(): - self._recursive_replace_layer(child, - origin_cls, - attr_replacement, - param_replacement, - method_replacement, - sub_module_replacement, - include=include) + self._recursive_replace_layer( + child, + origin_cls, + attr_replacement, + param_replacement, + method_replacement, + sub_module_replacement, + include=include, + ) def _replace_attr( self, @@ -153,10 +158,12 @@ def _replace_method(self, module: nn.Module, method_replacement: Dict[str, Calla bound_method = MethodType(new_method, module) setattr(module, method_name, bound_method) - def _replace_sub_module(self, - org_layer: nn.Module, - sub_module_replacement: List[SubModuleReplacementDescription], - include: Optional[Set[nn.Module]] = None) -> None: + def _replace_sub_module( + self, + org_layer: nn.Module, + sub_module_replacement: List[SubModuleReplacementDescription], + include: Optional[Set[nn.Module]] = None, + ) -> None: r""" Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict @@ -170,7 +177,7 @@ def _replace_sub_module(self, target_module = description.target_module kwargs = {} if description.kwargs is None else description.kwargs - assert target_module is not None, 'target_module should not be None' + assert target_module is not None, "target_module should not be None" native_sub_module = getattr_(org_layer, suffix, ignore=True) @@ -178,8 +185,9 @@ def _replace_sub_module(self, if (include is not None) and (native_sub_module is not None) and (native_sub_module not in include): continue - assert not isinstance(native_sub_module, target_module), \ - f"The module with suffix {suffix} has been replaced, please check the policy" + assert not isinstance( + native_sub_module, target_module + ), f"The module with suffix {suffix} has been replaced, please check the policy" # if it is None and we are allowed to ignore this module # just skip @@ -187,9 +195,9 @@ def _replace_sub_module(self, continue try: - replace_layer = target_module.from_native_module(native_sub_module, - self.shard_config.tensor_parallel_process_group, - **kwargs) + replace_layer = target_module.from_native_module( + native_sub_module, self.shard_config.tensor_parallel_process_group, **kwargs + ) except Exception as e: raise RuntimeError( f"Failed to replace {suffix} of type {native_sub_module.__class__.__qualname__}" @@ -200,7 +208,6 @@ def _replace_sub_module(self, setattr_(org_layer, suffix, replace_layer) def _get_recursive_held_layers(self, held_layers: Optional[List[nn.Module]]) -> Optional[List[nn.Module]]: - def collect_sub_modules(module: nn.Module): if module is None: return diff --git a/colossalai/tensor/__init__.py b/colossalai/tensor/__init__.py index 099376d931e8..9ed149f33f2f 100644 --- a/colossalai/tensor/__init__.py +++ b/colossalai/tensor/__init__.py @@ -5,7 +5,14 @@ from .utils import convert_dim_partition_dict, convert_parameter, merge_same_dim_mesh_list, named_params_with_colotensor __all__ = [ - 'ColoTensor', 'convert_parameter', 'named_params_with_colotensor', 
'ColoParameter', 'ColoParamOpHook', - 'ColoParamOpHookManager', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict', - 'merge_same_dim_mesh_list' + "ColoTensor", + "convert_parameter", + "named_params_with_colotensor", + "ColoParameter", + "ColoParamOpHook", + "ColoParamOpHookManager", + "CommSpec", + "CollectiveCommPattern", + "convert_dim_partition_dict", + "merge_same_dim_mesh_list", ] diff --git a/colossalai/tensor/colo_parameter.py b/colossalai/tensor/colo_parameter.py index 076661a08824..5712505ae2ff 100644 --- a/colossalai/tensor/colo_parameter.py +++ b/colossalai/tensor/colo_parameter.py @@ -11,7 +11,7 @@ def is_no_hook_op(func) -> bool: - return func.__name__.startswith('__') and func not in WHITE_LIST_FUNCS + return func.__name__.startswith("__") and func not in WHITE_LIST_FUNCS def filter_colo_parameters(*args, **kwargs): @@ -36,18 +36,16 @@ def get_colo_parameters(element) -> None: def replace_args(args, kwargs, new_args): - args = new_args[:len(args)] - for k, v in zip(kwargs.keys(), new_args[len(args):]): + args = new_args[: len(args)] + for k, v in zip(kwargs.keys(), new_args[len(args) :]): kwargs[k] = v return tuple(args), kwargs class ColoParameter(ColoTensor, torch.nn.Parameter): - r"""A kind of ColoTensor to be considered as a module parameter. + r"""A kind of ColoTensor to be considered as a module parameter.""" - """ - - def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad: bool = True) -> 'ColoParameter': + def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad: bool = True) -> "ColoParameter": if data is None: data = torch.empty(0) return torch.Tensor._make_subclass(cls, data, requires_grad) diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py index a20a1444a406..c2de9abce371 100644 --- a/colossalai/tensor/colo_tensor.py +++ b/colossalai/tensor/colo_tensor.py @@ -7,7 +7,7 @@ torch.Tensor.add_: torch.Tensor.add, torch.Tensor.sub_: torch.Tensor.sub, torch.Tensor.mul_: torch.Tensor.mul, - torch.Tensor.div_: torch.Tensor.div + torch.Tensor.div_: torch.Tensor.div, } @@ -18,7 +18,7 @@ def _get_my_nowrap_functions() -> Set[Callable]: Tensor._base.__get__, Tensor.grad.__get__, Tensor._grad.__get__, - Tensor.data.__get__, # make .data returns torch.Tensor rather than ColoTensor + Tensor.data.__get__, # make .data returns torch.Tensor rather than ColoTensor } @@ -37,17 +37,18 @@ def _convert_output(output, func): class ColoTensor(torch.Tensor): - """ Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor. + """Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor. It is only used to trigger the torch function hook. Args: data (torch.Tensor): a torch tensor used as the payload the colotensor. """ - torch_major = int(torch.__version__.split('.')[0]) - torch_minor = int(torch.__version__.split('.')[1]) - def __new__(cls, data: torch.Tensor) -> 'ColoTensor': + torch_major = int(torch.__version__.split(".")[0]) + torch_minor = int(torch.__version__.split(".")[1]) + + def __new__(cls, data: torch.Tensor) -> "ColoTensor": """ The signature of the __new__ has to be consistent with the torch.Tensor. 
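`ColoTensor` and `ColoParameter` rely on PyTorch's tensor-subclass machinery: `torch.Tensor._make_subclass` wraps an existing tensor without copying its storage, and the `__torch_function__` override in the next hunk intercepts every op on the wrapper. A minimal standalone sketch of that pattern; the `TaggedTensor` name is hypothetical and not part of this diff:

```python
import torch


class TaggedTensor(torch.Tensor):
    """Hypothetical subclass illustrating the wrapper pattern used by ColoTensor/ColoParameter."""

    def __new__(cls, data: torch.Tensor = None, requires_grad: bool = False) -> "TaggedTensor":
        if data is None:
            data = torch.empty(0)
        # wrap an existing tensor without copying its storage
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # every torch op on a TaggedTensor lands here first;
        # run the real op with the hook disabled to avoid infinite recursion
        with torch._C.DisableTorchFunction():
            return func(*args, **kwargs)


x = TaggedTensor(torch.randn(2, 2))
y = x + 1  # dispatched through __torch_function__
```

ColoTensor additionally rewraps outputs via `_convert_output` and gives `backward` special handling, as the following hunk shows.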
@@ -74,7 +75,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): # we have to capture the `backward` function # and make sure that it does not in `torch._C.DisableTorchFunction()` context if func is torch.Tensor.backward: - assert len(args) == 1 # only has 1 parameter + assert len(args) == 1 # only has 1 parameter backward_tensor = torch.Tensor(args[0]) tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()} return backward_tensor.backward(**tensor_kwargs) @@ -83,8 +84,8 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): if func in INPALCE_MAPPING: func = INPALCE_MAPPING[func] # set the 'inplace' kwargs to False - if 'inplace' in kwargs: - kwargs['inplace'] = False + if "inplace" in kwargs: + kwargs["inplace"] = False with torch._C.DisableTorchFunction(): ret = func(*args, **kwargs) diff --git a/colossalai/tensor/comm_spec.py b/colossalai/tensor/comm_spec.py index 204f81343199..de0cba26b52a 100644 --- a/colossalai/tensor/comm_spec.py +++ b/colossalai/tensor/comm_spec.py @@ -7,15 +7,15 @@ from torch.distributed import ReduceOp __all__ = [ - 'CollectiveCommPattern', - 'CommSpec', + "CollectiveCommPattern", + "CommSpec", ] def _all_gather(tensor, comm_spec): - ''' + """ Implement all gather operation on device mesh based on information provided by comm_spec. - ''' + """ process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() process_group = process_groups[comm_spec.logical_process_axis] @@ -31,9 +31,9 @@ def _all_gather(tensor, comm_spec): def _split(tensor, comm_spec): - ''' + """ Implement shard operation on device mesh based on information provided by comm_spec. - ''' + """ process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() process_group = process_groups[comm_spec.logical_process_axis] @@ -45,9 +45,9 @@ def _split(tensor, comm_spec): def _all_to_all(tensor, comm_spec): - ''' + """ Implement all to all operation on device mesh based on information provided by comm_spec. - ''' + """ process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() process_group = process_groups[comm_spec.logical_process_axis] world_size = dist.get_world_size(process_group) @@ -66,9 +66,9 @@ def _all_to_all(tensor, comm_spec): def _all_reduce(tensor, comm_spec, async_op=False): - ''' + """ Implement all reduce operation on device mesh based on information provided by comm_spec. - ''' + """ process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() process_group = process_groups[comm_spec.logical_process_axis] @@ -79,7 +79,7 @@ def _all_reduce(tensor, comm_spec, async_op=False): def _mix_gather(tensor, comm_spec): - ''' + """ Implement mix gather operation on device mesh based on information provided by comm_spec. Mix gather is the all-gather operation on all devices in the device_mesh(FlattenDeviceMesh) of the comm_spec. 
It is different from _all_gather because _mix_gather does all-gather in two dimensions of device mesh, while _all_gather @@ -124,7 +124,7 @@ def _mix_gather(tensor, comm_spec): leading_group_dim = 1 process_group = "[0, 1, 2, 3, 4, 5, 6, 7]" tensor_list = [(0,0),(1,0),(2,0),(3,0),(4,0),(5,0),(6,0),(7,0)] - ''' + """ total_slices = comm_spec.device_mesh.shape[0] tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(total_slices)] leading_group_dim = comm_spec.logical_process_axes[0] @@ -155,15 +155,16 @@ def _mix_gather(tensor, comm_spec): torch.zeros(tmp_tensor_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(cat_slice[1]) ] for i in range(cat_slice[1]): - tmp_tensor_list[i] = torch.cat(tuple(tensor_list[i * cat_slice[0]:(i + 1) * cat_slice[0]]), - comm_spec.gather_dim[0]).contiguous() + tmp_tensor_list[i] = torch.cat( + tuple(tensor_list[i * cat_slice[0] : (i + 1) * cat_slice[0]]), comm_spec.gather_dim[0] + ).contiguous() output = torch.cat(tuple(tmp_tensor_list), comm_spec.gather_dim[1]).contiguous() return output def _mix_split(tensor, comm_spec): - ''' + """ Implement mix split operation. Mix split is only called for the backward of mix gather (Use ctx to keep consistent) Mix split shards the tensor on device mesh based on information provided by comm_spec. It is different from split because _mix_split shards the tensor in two dimensions of device mesh, while _split only shards in one dimension. @@ -177,7 +178,7 @@ def _mix_split(tensor, comm_spec): # [[0, 1, 2, 3], # [4, 5, 6, 7]] # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]} - ''' + """ mesh_shape = comm_spec.device_meshes.shape dim = comm_spec.gather_dim total_slices = comm_spec.device_mesh.shape[0] @@ -316,11 +317,13 @@ def symbolic(graph, input_): @staticmethod def forward(ctx, input_, comm_spec): output = _all_to_all(input_, comm_spec) - comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern, - sharding_spec=comm_spec.sharding_spec, - gather_dim=comm_spec.shard_dim, - shard_dim=comm_spec.gather_dim, - logical_process_axis=comm_spec.logical_process_axis) + comm_spec_for_backward = CommSpec( + comm_pattern=comm_spec.comm_pattern, + sharding_spec=comm_spec.sharding_spec, + gather_dim=comm_spec.shard_dim, + shard_dim=comm_spec.gather_dim, + logical_process_axis=comm_spec.logical_process_axis, + ) ctx.comm_spec = comm_spec_for_backward return output @@ -330,7 +333,6 @@ def backward(ctx, grad_outputs): class _MixGatherForwardMixSplitBackward(torch.autograd.Function): - @staticmethod def symbolic(graph, input_): return _mix_gather(input_) @@ -370,16 +372,16 @@ def mixgather_forward_split_backward(input_, comm_spec): class CollectiveCommPattern(Enum): - GATHER_FWD_SPLIT_BWD = 'gather_fwd_split_bwd' - ALL2ALL_FWD_ALL2ALL_BWD = 'all2all_fwd_all2all_bwd' - SPLIT_FWD_GATHER_BWD = 'split_fwd_gather_bwd' - ALLREDUCE_FWD_IDENTITY_BWD = 'all_reduce_fwd_identity_bwd' - IDENTITY_FWD_ALLREDUCE_BWD = 'identity_fwd_all_reduce_bwd' + GATHER_FWD_SPLIT_BWD = "gather_fwd_split_bwd" + ALL2ALL_FWD_ALL2ALL_BWD = "all2all_fwd_all2all_bwd" + SPLIT_FWD_GATHER_BWD = "split_fwd_gather_bwd" + ALLREDUCE_FWD_IDENTITY_BWD = "all_reduce_fwd_identity_bwd" + IDENTITY_FWD_ALLREDUCE_BWD = "identity_fwd_all_reduce_bwd" MIXGATHER_FWD_SPLIT_BWD = "mixgather_fwd_split_bwd" class CommSpec: - ''' + """ Communication spec is used to record the communication action. It has two main functions: 1. Compute the communication cost which will be used in auto parallel solver. 2. 
Convert the communication spec to real action which will be used in runtime. @@ -393,16 +395,18 @@ class CommSpec: gather_dim(int, Optional): The gather_dim of the tensor will be gathered. shard_dim(int, Optional): The shard_dim of the tensor will be sharded. logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action. - ''' - - def __init__(self, - comm_pattern, - sharding_spec, - gather_dim=None, - shard_dim=None, - logical_process_axis=None, - forward_only=False, - mix_gather=False): + """ + + def __init__( + self, + comm_pattern, + sharding_spec, + gather_dim=None, + shard_dim=None, + logical_process_axis=None, + forward_only=False, + mix_gather=False, + ): self.comm_pattern = comm_pattern self.sharding_spec = sharding_spec self.gather_dim = gather_dim @@ -449,14 +453,14 @@ def __repr__(self): res_list.append(f"gather_dim:{self.gather_dim}, ") res_list.append(f"logical_process_asex:{self.logical_process_axes})") - return ''.join(res_list) + return "".join(res_list) def get_comm_cost(self): - ''' + """ For all_gather, all2all, and all_reduce operation, the formula provided in DeviceMesh with alpha-beta model is used to compute the communication cost. For shard operation, it is an on-chip operation, so the communication cost is zero. - ''' + """ comm_size = reduce(operator.mul, self.sharding_spec.get_sharded_shape_per_device(), 1) cost_dict = {} if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: @@ -500,13 +504,13 @@ def get_comm_cost(self): return cost_dict def covert_spec_to_action(self, tensor): - ''' + """ Convert CommSpec into runtime action, implement real collection communication to target tensor. The collection communication action is directed by the CommSpec. Argument: tensor(torch.Tensor): Tensor stored in each device, which could be different in different ranks. - ''' + """ if self.comm_pattern in pattern_to_func_dict: tensor = pattern_to_func_dict[self.comm_pattern](tensor, self) else: diff --git a/colossalai/tensor/d_tensor/__init__.py b/colossalai/tensor/d_tensor/__init__.py index 3ae38a12555b..fad5101d380c 100644 --- a/colossalai/tensor/d_tensor/__init__.py +++ b/colossalai/tensor/d_tensor/__init__.py @@ -21,8 +21,23 @@ from .sharding_spec import ShardingSpec __all__ = [ - 'is_distributed_tensor', 'distribute_tensor', 'to_global', 'is_sharded', 'shard_rowwise', 'shard_colwise', - 'sharded_tensor_to_param', 'compute_global_numel', 'get_sharding_spec', 'get_global_shape', 'get_device_mesh', - 'redistribute', 'get_layout', 'is_customized_distributed_tensor', 'distribute_tensor_with_customization', - 'to_global_for_customized_distributed_tensor', 'customized_distributed_tensor_to_param', 'Layout', 'ShardingSpec' + "is_distributed_tensor", + "distribute_tensor", + "to_global", + "is_sharded", + "shard_rowwise", + "shard_colwise", + "sharded_tensor_to_param", + "compute_global_numel", + "get_sharding_spec", + "get_global_shape", + "get_device_mesh", + "redistribute", + "get_layout", + "is_customized_distributed_tensor", + "distribute_tensor_with_customization", + "to_global_for_customized_distributed_tensor", + "customized_distributed_tensor_to_param", + "Layout", + "ShardingSpec", ] diff --git a/colossalai/tensor/d_tensor/api.py b/colossalai/tensor/d_tensor/api.py index 9848e4ca423e..178bac428ea9 100644 --- a/colossalai/tensor/d_tensor/api.py +++ b/colossalai/tensor/d_tensor/api.py @@ -44,7 +44,7 @@ def is_sharded(dtensor: torch.Tensor) -> bool: Returns: bool: True if the tensor is sharded, False otherwise. 
""" - assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor." return list(dtensor.shape) == list(dtensor.dist_layout.global_shape) @@ -77,8 +77,10 @@ def new_clone(self, *args, **kwargs): return dtensor -def _construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec: - ''' +def _construct_default_sharding_spec( + tensor: torch.Tensor, +) -> ShardingSpec: + """ Construct the default sharding specification for the tensor. Args: @@ -86,14 +88,14 @@ def _construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec: Returns: A `ShardingSpec` object without any sharding specified. - ''' + """ return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={}) def _apply_layout(tensor, layout): - ''' + """ Apply the layout to the local tensor during initializing process. - ''' + """ # layout converter requires a source and target laytout # we construct the source layer for an unsharded tensor # and use self.dist_layer as the targer layout for the sharded tensor @@ -115,7 +117,7 @@ def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_sp Returns: torch.Tensor: The distributed tensor. """ - assert not is_distributed_tensor(tensor), 'The input tensor is already a distributed tensor.' + assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor." dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=tensor.shape) # shard tensor @@ -128,7 +130,7 @@ def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_sp def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None: - ''' + """ Convert the layout of the tensor from source_spec to target_spec. This will update the `local_tensor` and `dist_layout` in place. @@ -136,13 +138,13 @@ def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: dtensor (torch.Tensor): the distributed tensor to be converted. device_mesh (DeviceMesh): the device mesh for abstraction of the compute devices. target_layout (Layout): the target layout specification. - ''' - assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + """ + assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor." global_shape = get_global_shape(dtensor) target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape) - resharded_tensor = layout_converter.apply(tensor=dtensor, - source_layout=dtensor.dist_layout, - target_layout=target_layout) + resharded_tensor = layout_converter.apply( + tensor=dtensor, source_layout=dtensor.dist_layout, target_layout=target_layout + ) return resharded_tensor @@ -157,7 +159,7 @@ def to_global(dtensor: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: the global tensor. """ - assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor." layout_converter = LayoutConverter() global_sharding_spec = ShardingSpec(dtensor.dim(), {}) @@ -193,7 +195,7 @@ def shard_rowwise( if isinstance(group_or_device_mesh, ProcessGroup): device_mesh = DeviceMesh.from_process_group(group_or_device_mesh) else: - assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.' 
+ assert len(group_or_device_mesh.shape) == 1, "Only 1D DeviceMesh is accepted for row-wise sharding." device_mesh = group_or_device_mesh sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]}) @@ -222,7 +224,7 @@ def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup if isinstance(group_or_device_mesh, ProcessGroup): device_mesh = DeviceMesh.from_process_group(group_or_device_mesh) else: - assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.' + assert len(group_or_device_mesh.shape) == 1, "Only 1D DeviceMesh is accepted for row-wise sharding." device_mesh = group_or_device_mesh sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]}) @@ -230,7 +232,7 @@ def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True): - assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor." param = torch.nn.Parameter(dtensor, requires_grad=requires_grad) # make it distributed as well @@ -241,7 +243,7 @@ def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True): def sharded_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter) -> None: - assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor." param.data = dtensor # make it distributed as well param.dist_layout = dtensor.dist_layout @@ -258,7 +260,7 @@ def compute_global_numel(dtensor: torch.Tensor) -> int: Returns: int: The global number of elements in the distributed tensor. """ - assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor." numel = reduce(operator.mul, dtensor.dist_layout.global_shape) return numel @@ -274,7 +276,7 @@ def get_layout(dtensor: torch.Tensor) -> Layout: Layout: The layout of the distributed tensor. """ - assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor." return dtensor.dist_layout @@ -288,7 +290,7 @@ def get_global_shape(dtensor: torch.Tensor) -> torch.Size: Returns: torch.Size: The global shape of the distributed tensor. """ - assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor." return dtensor.dist_layout.global_shape @@ -302,7 +304,7 @@ def get_device_mesh(dtensor: torch.Tensor) -> DeviceMesh: Returns: DeviceMesh: The device mesh of the distributed tensor. """ - assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor." return dtensor.dist_layout.device_mesh @@ -316,7 +318,7 @@ def get_sharding_spec(dtensor: torch.Tensor) -> ShardingSpec: Returns: ShardingSpec: The sharding spec of the distributed tensor. """ - assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor." 
return dtensor.dist_layout.sharding_spec @@ -335,7 +337,7 @@ def is_customized_distributed_tensor(tensor: torch.Tensor): Returns: bool: Whether the given tensor is a customized distributed tensor. """ - return hasattr(tensor, 'shard_fn') and hasattr(tensor, 'gather_fn') + return hasattr(tensor, "shard_fn") and hasattr(tensor, "gather_fn") def _hijack_detach_and_clone_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor: @@ -402,9 +404,9 @@ def gather_fn(tensor): Returns: torch.Tensor: The distributed tensor. """ - assert callable(shard_fn), 'The shard_fn must be callable.' - assert callable(gather_fn), 'The gather_fn must be callable.' - assert not is_distributed_tensor(tensor), 'The input tensor is already a distributed tensor.' + assert callable(shard_fn), "The shard_fn must be callable." + assert callable(gather_fn), "The gather_fn must be callable." + assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor." sharded_tensor = shard_fn(tensor) @@ -428,7 +430,7 @@ def to_global_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch. Returns: torch.Tensor: The global tensor. """ - assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.' + assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor." return dtensor.gather_fn(dtensor) @@ -436,7 +438,7 @@ def customized_distributed_tensor_to_param(dtensor: torch.Tensor, requires_grad: """ Convert the given customized distributed tensor to a parameter. """ - assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.' + assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor." param = torch.nn.Parameter(dtensor, requires_grad=requires_grad) @@ -451,7 +453,7 @@ def customized_distributed_tensor_to_existing_param(dtensor: torch.Tensor, param """ Convert the given customized distributed tensor to an existing parameter. """ - assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.' + assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor." param.data = dtensor.data param.shard_fn = dtensor.shard_fn diff --git a/colossalai/tensor/d_tensor/comm_spec.py b/colossalai/tensor/d_tensor/comm_spec.py index 6158d0bfe2ad..8f5b52aab8f8 100644 --- a/colossalai/tensor/d_tensor/comm_spec.py +++ b/colossalai/tensor/d_tensor/comm_spec.py @@ -6,22 +6,22 @@ from torch.distributed import ReduceOp __all__ = [ - 'CollectiveCommPattern', - 'CommSpec', + "CollectiveCommPattern", + "CommSpec", ] class CollectiveCommPattern(Enum): - GATHER_FWD_SPLIT_BWD = 'gather_fwd_split_bwd' - ALL2ALL_FWD_ALL2ALL_BWD = 'all2all_fwd_all2all_bwd' - SPLIT_FWD_GATHER_BWD = 'split_fwd_gather_bwd' - ALLREDUCE_FWD_IDENTITY_BWD = 'all_reduce_fwd_identity_bwd' - IDENTITY_FWD_ALLREDUCE_BWD = 'identity_fwd_all_reduce_bwd' + GATHER_FWD_SPLIT_BWD = "gather_fwd_split_bwd" + ALL2ALL_FWD_ALL2ALL_BWD = "all2all_fwd_all2all_bwd" + SPLIT_FWD_GATHER_BWD = "split_fwd_gather_bwd" + ALLREDUCE_FWD_IDENTITY_BWD = "all_reduce_fwd_identity_bwd" + IDENTITY_FWD_ALLREDUCE_BWD = "identity_fwd_all_reduce_bwd" MIXGATHER_FWD_SPLIT_BWD = "mixgather_fwd_split_bwd" class CommSpec: - ''' + """ Communication spec is used to record the communication action. It converts the communication spec to real action which will be used in runtime. 
It contains comm_pattern to determine the communication method, process_group_dict to determine the process groups, gather_dim and shard_dim @@ -33,14 +33,16 @@ class CommSpec: gather_dim(int, Optional): The gather_dim of the tensor will be gathered. shard_dim(int, Optional): The shard_dim of the tensor will be sharded. logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action. - ''' - - def __init__(self, - comm_pattern: CollectiveCommPattern, - process_group_dict: Dict, - gather_dim: int = None, - shard_dim: int = None, - logical_process_axis: int = None): + """ + + def __init__( + self, + comm_pattern: CollectiveCommPattern, + process_group_dict: Dict, + gather_dim: int = None, + shard_dim: int = None, + logical_process_axis: int = None, + ): self.comm_pattern = comm_pattern self.gather_dim = gather_dim self.shard_dim = shard_dim @@ -71,16 +73,16 @@ def __repr__(self): res_list.append(f"comm_pattern:IDENTITY_FWD_ALLREDUCE_BWD, ") res_list.append(f"logical_process_axis:{self.logical_process_axis})") - return ''.join(res_list) + return "".join(res_list) def covert_spec_to_action(self, tensor): - ''' + """ Convert CommSpec into runtime action, implement real collection communication to target tensor. The collection communication action is directed by the CommSpec. Argument: tensor(torch.Tensor): Tensor stored in each device, which could be different in different ranks. - ''' + """ if self.comm_pattern in pattern_to_func_dict: tensor = pattern_to_func_dict[self.comm_pattern](tensor, self) else: @@ -89,9 +91,9 @@ def covert_spec_to_action(self, tensor): def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec): - ''' + """ Implement all gather operation on device mesh based on information provided by comm_spec. - ''' + """ process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] world_size = dist.get_world_size(process_group) tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)] @@ -103,9 +105,9 @@ def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec): def _split(tensor: torch.Tensor, comm_spec: CommSpec): - ''' + """ Implement shard operation on device mesh based on information provided by comm_spec. - ''' + """ process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] dim = comm_spec.shard_dim length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group) @@ -115,9 +117,9 @@ def _split(tensor: torch.Tensor, comm_spec: CommSpec): def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec): - ''' + """ Implement all to all operation on device mesh based on information provided by comm_spec. - ''' + """ process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] world_size = dist.get_world_size(process_group) new_shape = list(tensor.shape) @@ -134,9 +136,9 @@ def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec): def _all_reduce(tensor: torch.Tensor, comm_spec: CommSpec, async_op: bool = False): - ''' + """ Implement all reduce operation on device mesh based on information provided by comm_spec. 
- ''' + """ process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] if not tensor.is_contiguous(): tensor = tensor.contiguous() @@ -256,11 +258,13 @@ def symbolic(graph, input_): @staticmethod def forward(ctx, input_, comm_spec): output = _all_to_all(input_, comm_spec) - comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern, - process_group_dict=comm_spec.process_group_dict, - gather_dim=comm_spec.shard_dim, - shard_dim=comm_spec.gather_dim, - logical_process_axis=comm_spec.logical_process_axis) + comm_spec_for_backward = CommSpec( + comm_pattern=comm_spec.comm_pattern, + process_group_dict=comm_spec.process_group_dict, + gather_dim=comm_spec.shard_dim, + shard_dim=comm_spec.gather_dim, + logical_process_axis=comm_spec.logical_process_axis, + ) ctx.comm_spec = comm_spec_for_backward return output diff --git a/colossalai/tensor/d_tensor/layout.py b/colossalai/tensor/d_tensor/layout.py index a35b2f43e44b..6d4c5dbe3c09 100644 --- a/colossalai/tensor/d_tensor/layout.py +++ b/colossalai/tensor/d_tensor/layout.py @@ -25,15 +25,16 @@ def __init__(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_ self._sanity_check() def __hash__(self) -> int: - return hash(f'{self.sharding_spec}') + return hash(f"{self.sharding_spec}") def get_sharded_shape_per_device(self): sharded_shape = list(self.global_shape) for dim, shard_list in self.sharding_spec.dim_partition_dict.items(): mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list] shard_partitions = reduce(operator.mul, mesh_list, 1) - assert sharded_shape[ - dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.' + assert ( + sharded_shape[dim] % shard_partitions == 0 + ), f"Cannot shard dimension {dim} into {shard_partitions} partitions." sharded_shape[dim] //= shard_partitions return torch.Size(sharded_shape) @@ -49,7 +50,8 @@ def _sanity_check(self): dim_check_list.remove(element) else: raise DuplicatedShardingDimensionError( - f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.") + f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}." + ) # make sure that the sharding for a dimension is divisible by the number of devices for dim, shard_list in sharding_spec.dim_partition_dict.items(): @@ -61,5 +63,5 @@ def _sanity_check(self): if tensor_dim_size % num_devices != 0: raise ShardingNotDivisibleError( - f'The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices.' + f"The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices." ) diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py index 528ed7901c4f..e031e0472b0b 100644 --- a/colossalai/tensor/d_tensor/layout_converter.py +++ b/colossalai/tensor/d_tensor/layout_converter.py @@ -14,7 +14,7 @@ from .sharding_spec import ShardingSpec from .utils import get_comm_cost -__all__ = ['LayoutConverter', 'LayoutConverterOptions', 'set_layout_converting_options'] +__all__ = ["LayoutConverter", "LayoutConverterOptions", "set_layout_converting_options"] @dataclass @@ -22,8 +22,8 @@ class LayoutConverterOptions: """ LayoutConverterOptions is a dataclass which specifies the preferences for layout converting. 
""" + # TODO: layout converter option is not implemented yet - pass def set_layout_converting_options(options: LayoutConverterOptions): @@ -63,7 +63,7 @@ def forward_only(self, value): self._forward_only = value def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, CommSpec]: - ''' + """ Get all valid layouts from source_layout with single all-gather operation. For the all-gather operation, we just care about the S dimension. @@ -96,7 +96,7 @@ def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, Co Output: [R, S1, R]: CommSpec:(comm_pattern:GATHER_FWD_SPLIT_BWD, gather_dim:0, shard_dim:0, logical_process_axis:0) [S0, R, R]: CommSpec:(comm_pattern:GATHER_FWD_SPLIT_BWD, gather_dim:1, shard_dim:1, logical_process_axis:1) - ''' + """ valid_spec_dict = {} comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD source_spec = source_layout.sharding_spec @@ -125,16 +125,19 @@ def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, Co comm_pattern, process_group_dict=process_group_dict, gather_dim=gather_dim, - # shard_dim will be used during backward + # shard_dim will be used during backward shard_dim=gather_dim, - logical_process_axis=logical_process_axis) + logical_process_axis=logical_process_axis, + ) # generate new sharding spec try: new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict) - new_layout = Layout(device_mesh=source_layout.device_mesh, - sharding_spec=new_sharding_spec, - global_shape=source_layout.global_shape) + new_layout = Layout( + device_mesh=source_layout.device_mesh, + sharding_spec=new_sharding_spec, + global_shape=source_layout.global_shape, + ) valid_spec_dict[new_layout] = comm_spec except LayoutException: @@ -142,7 +145,7 @@ def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, Co return valid_spec_dict def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec]: - ''' + """ Get all valid layouts from source_layout with single all-to-all operation. For the all-to-all operation, we just care about the pairs containing S dimension. 
@@ -176,7 +179,7 @@ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, Com [S01, R, R]: CommSpec:(comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, gather_dim:1, shard_dim:0, logical_process_axis: 1) [R, S1, S0]: CommSpec:(comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, gather_dim:0, shard_dim:2, logical_process_axis: 0) [S0, R, S1]: CommSpec:(comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, gather_dim:1, shard_dim:2, logical_process_axis: 1) - ''' + """ valid_spec_dict = {} comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD @@ -224,11 +227,13 @@ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, Com gather_dim = b_index shard_dim = f_index logical_process_axis = b_target_pair[1][-1] - comm_spec = CommSpec(comm_pattern, - process_group_dict=process_group_dict, - gather_dim=gather_dim, - shard_dim=shard_dim, - logical_process_axis=logical_process_axis) + comm_spec = CommSpec( + comm_pattern, + process_group_dict=process_group_dict, + gather_dim=gather_dim, + shard_dim=shard_dim, + logical_process_axis=logical_process_axis, + ) new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict) @@ -246,9 +251,11 @@ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, Com # generate new sharding spec try: new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict) - new_layout = Layout(device_mesh=source_layout.device_mesh, - sharding_spec=new_sharding_spec, - global_shape=source_layout.global_shape) + new_layout = Layout( + device_mesh=source_layout.device_mesh, + sharding_spec=new_sharding_spec, + global_shape=source_layout.global_shape, + ) valid_spec_dict[new_layout] = comm_spec except LayoutException: pass @@ -256,7 +263,7 @@ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, Com return valid_spec_dict def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec]: - ''' + """ Get all valid layouts from source_layout with single shard operation. For the sharding operation, we just care about legal sharding dimensions. 
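The `[S0, S1, R]`-style sequences used throughout these docstrings are a printed view of `dim_partition_dict`: each key is a tensor dimension, and its value lists the device-mesh axes that dimension is sharded over (a dimension absent from the dict is replicated, `R`). A short, self-contained illustration with the `ShardingSpec` defined in `colossalai/tensor/d_tensor/sharding_spec.py` (touched later in this patch):

```python
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

# a 3-d tensor on a 2-d device mesh:
#   dim 0 -> sharded over mesh axis 0 (S0)
#   dim 1 -> sharded over mesh axis 1 (S1)
#   dim 2 -> not in the dict, so replicated (R)
spec = ShardingSpec(dim_size=3, dim_partition_dict={0: [0], 1: [1]})
print(spec)       # prints shard_sequence: S0,S1,R

# S01 means a single tensor dimension is flattened across both mesh axes
spec_s01 = ShardingSpec(dim_size=3, dim_partition_dict={0: [0, 1]})
print(spec_s01)   # prints shard_sequence: S01,R,R
```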
@@ -291,7 +298,7 @@ def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec [S01, R, R]: CommSpec:(comm_pattern:SPLIT_FWD_GATHER_BWD, gather_dim:0, shard_dim:0, logical_process_axis:1) [S0, S1, R]: CommSpec:(comm_pattern:SPLIT_FWD_GATHER_BWD, gather_dim:1, shard_dim:1, logical_process_axis:1) [S0, R, S1]: CommSpec:(comm_pattern:SPLIT_FWD_GATHER_BWD, gather_dim:2, shard_dim:2, logical_process_axis:1) - ''' + """ valid_spec_dict = {} comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD source_spec = source_layout.sharding_spec @@ -326,26 +333,31 @@ def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec shard_dim = index logical_process_axis = shard_list[-1] - comm_spec = CommSpec(comm_pattern, - process_group_dict=process_group_dict, - gather_dim=shard_dim, - shard_dim=shard_dim, - logical_process_axis=logical_process_axis) + comm_spec = CommSpec( + comm_pattern, + process_group_dict=process_group_dict, + gather_dim=shard_dim, + shard_dim=shard_dim, + logical_process_axis=logical_process_axis, + ) # generate new sharding spec try: - new_sharding_spec = ShardingSpec(dim_size=source_spec.dims, - dim_partition_dict=new_dim_partition_dict) - new_layout = Layout(device_mesh=source_layout.device_mesh, - sharding_spec=new_sharding_spec, - global_shape=source_layout.global_shape) + new_sharding_spec = ShardingSpec( + dim_size=source_spec.dims, dim_partition_dict=new_dim_partition_dict + ) + new_layout = Layout( + device_mesh=source_layout.device_mesh, + sharding_spec=new_sharding_spec, + global_shape=source_layout.global_shape, + ) valid_spec_dict[new_layout] = comm_spec except LayoutException: pass return valid_spec_dict def get_all_one_step_transform_spec(self, source_layout: Layout) -> Dict[Layout, CommSpec]: - ''' + """ Get all valid layouts from source_layout with one step transform. Note: @@ -358,16 +370,17 @@ def get_all_one_step_transform_spec(self, source_layout: Layout) -> Dict[Layout, Return: valid_spec_dict(Dict[Layout, CommSpec]): all valid layouts from source_layout with one step transform. - ''' + """ valid_spec_dict = {} valid_spec_dict.update(self.all_gather_transform_layouts(source_layout)) valid_spec_dict.update(self.all_to_all_transform_layout(source_layout)) valid_spec_dict.update(self.shard_transform_layout(source_layout)) return valid_spec_dict - def layout_converting(self, source_layout: Layout, - target_layout: Layout) -> Tuple[List[Layout], List[CommSpec], float]: - ''' + def layout_converting( + self, source_layout: Layout, target_layout: Layout + ) -> Tuple[List[Layout], List[CommSpec], float]: + """ This method will find a path to transform source_layout to target_layout with a greedy algorithm. The basic idea is: @@ -419,7 +432,7 @@ def layout_converting(self, source_layout: Layout, output: [R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R] - ''' + """ source_spec = source_layout.sharding_spec target_spec = target_layout.sharding_spec MAX_TRANSFORM_STEPS = 20 @@ -470,11 +483,11 @@ def layout_converting(self, source_layout: Layout, raise RuntimeError(f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps.") def get_total_comm_cost(self, source_layout: Layout, target_layout: Layout) -> Dict[str, float]: - ''' + """ Get the total communication cost of the layout converting process. 
- ''' + """ transform_path, comm_action_sequence = self.layout_converting(source_layout, target_layout) - total_cost = {'forward': 0.0, 'backward': 0.0, 'total': 0.0} + total_cost = {"forward": 0.0, "backward": 0.0, "total": 0.0} for layout, comm_spec in zip(transform_path, comm_action_sequence): cost_dict = get_comm_cost(layout, comm_spec, self.forward_only) for key in total_cost: @@ -482,7 +495,7 @@ def get_total_comm_cost(self, source_layout: Layout, target_layout: Layout) -> D return total_cost def apply(self, tensor: torch.Tensor, source_layout: Layout, target_layout: Layout) -> torch.Tensor: - ''' + """ Apply target_layout to tensor with source layout, the transform path is generated by the layout_converting method. @@ -542,7 +555,7 @@ def apply(self, tensor: torch.Tensor, source_layout: Layout, target_layout: Layo [1.], [3.], [3.]]) - ''' + """ _, comm_action_sequence = self.layout_converting(source_layout, target_layout) for comm_spec in comm_action_sequence: tensor = comm_spec.covert_spec_to_action(tensor) diff --git a/colossalai/tensor/d_tensor/sharding_spec.py b/colossalai/tensor/d_tensor/sharding_spec.py index 565012b58a03..2ac0ca73e4b8 100644 --- a/colossalai/tensor/d_tensor/sharding_spec.py +++ b/colossalai/tensor/d_tensor/sharding_spec.py @@ -4,16 +4,16 @@ from ..utils import merge_same_dim_mesh_list from .misc import ShardingOutOfIndexError -__all__ = ['DimSpec', 'ShardingException', 'ShardingSpec'] +__all__ = ["DimSpec", "ShardingException", "ShardingSpec"] ALLGATHER_COST = 20 SHARD_COST = 5 STEP_PENALTY = 6 -NAN = 'nan' +NAN = "nan" class DimSpec: - ''' + """ Sharding spec for single dimension of the sharded tensor describe the sharding dimension of logical device mesh and give a method to compute the difference between them. This class is used internally in ShardingSpec. @@ -21,7 +21,7 @@ class DimSpec: Argument: shard_list(List[int]): if shard_list is None, the dim spec will be 'R' type. Otherwise, the element in shard_list means the data will be sharded in that dimension. - ''' + """ def __init__(self, shard_list): self.is_replica = len(shard_list) == 0 @@ -33,41 +33,40 @@ def __eq__(self, other): def __repr__(self): if self.is_replica: - return 'R' - target = 'S' + return "R" + target = "S" for dim in self.shard_list: target += str(dim) return target def _convert_str_to_shard_list(self, str_spec): - ''' + """ Convert str_spec into shard_list. Argument: str_spec(str): dim spec in str type. - ''' + """ - if str_spec == 'R': + if str_spec == "R": return [] - if str_spec == 'S0': + if str_spec == "S0": return [0] - if str_spec == 'S1': + if str_spec == "S1": return [1] - if str_spec == 'S01': + if str_spec == "S01": return [0, 1] def build_difference_2d_dict(self): - ''' + """ Build a difference mapping for 2D device mesh case. It will be used to compute the difference between DimSpec pairs. 
- ''' + """ - source_spec_list = ['R', 'S0', 'S1', 'S01'] - target_spec_list = ['R', 'S0', 'S1', 'S01'] + source_spec_list = ["R", "S0", "S1", "S01"] + target_spec_list = ["R", "S0", "S1", "S01"] difference_dict = {} for source_spec in source_spec_list: for target_spec in target_spec_list: - legal_sharding_dims = [] spec_pair = (deepcopy(source_spec), deepcopy(target_spec)) source_shard_list = self._convert_str_to_shard_list(source_spec) target_shard_list = self._convert_str_to_shard_list(target_spec) @@ -77,14 +76,17 @@ def build_difference_2d_dict(self): difference = 0 # all_gather(source) -> target - elif len(source_shard_list - ) == len(target_shard_list) + 1 and source_shard_list[:-1] == target_shard_list: + elif ( + len(source_shard_list) == len(target_shard_list) + 1 and source_shard_list[:-1] == target_shard_list + ): difference = ALLGATHER_COST # shard(source) -> target - elif len(source_shard_list) == len( - target_shard_list) - 1 and source_shard_list == target_shard_list[:-1] and target_shard_list[ - -1] not in source_shard_list: + elif ( + len(source_shard_list) == len(target_shard_list) - 1 + and source_shard_list == target_shard_list[:-1] + and target_shard_list[-1] not in source_shard_list + ): difference = SHARD_COST # S1 -> S0 or S0 -> S1 @@ -115,7 +117,7 @@ def build_difference_2d_dict(self): self.difference_dict = difference_dict def dim_diff(self, other): - ''' + """ The difference between two _DimSpec. Argument: @@ -131,13 +133,13 @@ def dim_diff(self, other): Output: 5 - ''' + """ difference = self.difference_dict[(str(self), str(other))] return difference class ShardingSpec: - ''' + """ Sharding spec describes how to shard a tensor with dim_size dimensions. The sharding sequence looks like [R, R, S0, S1], which means @@ -145,23 +147,27 @@ class ShardingSpec: dim_partition_dict(Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded, and the value of the key describe which logical axis will be sharded in that dimension. sharding_sequence(List[DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1]. - ''' + """ - def __init__(self, - dim_size: int, - dim_partition_dict: Dict[int, List[int]] = None, - sharding_sequence: List[DimSpec] = None): + def __init__( + self, dim_size: int, dim_partition_dict: Dict[int, List[int]] = None, sharding_sequence: List[DimSpec] = None + ): self.dims = dim_size self.dim_partition_dict = dim_partition_dict self.sharding_sequence = sharding_sequence if self.sharding_sequence is None: - assert self.dim_partition_dict is not None, f'dim_partition_dict should not be None, if sharding_sequence is NoneType object.' - self.dim_partition_dict = merge_same_dim_mesh_list(dim_size=self.dims, - dim_partition_dict=self.dim_partition_dict) + assert ( + self.dim_partition_dict is not None + ), f"dim_partition_dict should not be None, if sharding_sequence is NoneType object." + self.dim_partition_dict = merge_same_dim_mesh_list( + dim_size=self.dims, dim_partition_dict=self.dim_partition_dict + ) self.sharding_sequence = self.convert_dict_to_shard_sequence() elif self.dim_partition_dict is None: - assert self.sharding_sequence is not None, f'sharding_sequence should not be None, if dim_partition_dict is NoneType object.' + assert ( + self.sharding_sequence is not None + ), f"sharding_sequence should not be None, if dim_partition_dict is NoneType object." 
self.dim_partition_dict = self.convert_shard_sequence_to_dict() self._sanity_check() @@ -169,31 +175,32 @@ def __init__(self, def _sanity_check(self): if len(self.sharding_sequence) > self.dims: raise ShardingOutOfIndexError( - f'sharding_sequence should have {self.dims} elements, but got index {len(self.sharding_sequence)}.') + f"sharding_sequence should have {self.dims} elements, but got index {len(self.sharding_sequence)}." + ) if list(self.dim_partition_dict.keys()) and max(list(self.dim_partition_dict.keys())) >= self.dims: raise ShardingOutOfIndexError( - f'the key of dim_partition_dict should be less than {self.dims}, but got {max(list(self.dim_partition_dict.keys()))}.' + f"the key of dim_partition_dict should be less than {self.dims}, but got {max(list(self.dim_partition_dict.keys()))}." ) def __repr__(self): res_list = ["ShardingSpec:"] res_list.append(f"\n\tshard_sequence: " + ",".join(str(dimspec) for dimspec in self.sharding_sequence)) - return ' '.join(res_list) + return " ".join(res_list) def convert_dict_to_shard_sequence(self): - ''' + """ Convert dim_partition_dict into list of DimSpec, and assign it to sharding_sequence. - ''' + """ sharding_sequence = [DimSpec([])] * self.dims for dim, shard_list in self.dim_partition_dict.items(): sharding_sequence[dim] = DimSpec(shard_list) return sharding_sequence def convert_shard_sequence_to_dict(self): - ''' + """ Convert sharding_sequence into dim_partition_dict. - ''' + """ new_dim_partition_dict = {} for index, dim_spec in enumerate(self.sharding_sequence): if not dim_spec.is_replica: @@ -203,7 +210,7 @@ def convert_shard_sequence_to_dict(self): return new_dim_partition_dict def spec_diff(self, other): - ''' + """ This function is a naive version of difference computation. It just simply accumulates difference every dimension between the pair of sharding sequence. @@ -228,9 +235,10 @@ def spec_diff(self, other): Return: difference(int): Difference between two ShardingSpec. - ''' + """ assert len(self.sharding_sequence) == len( - other.sharding_sequence), f'Cannot compare difference for two sharding specs with different length.' + other.sharding_sequence + ), f"Cannot compare difference for two sharding specs with different length." difference = 0 for orig_dim_spec, other_dim_spec in zip(self.sharding_sequence, other.sharding_sequence): difference += orig_dim_spec.dim_diff(other_dim_spec) diff --git a/colossalai/tensor/d_tensor/utils.py b/colossalai/tensor/d_tensor/utils.py index fc22b990d879..8f0081246fb3 100644 --- a/colossalai/tensor/d_tensor/utils.py +++ b/colossalai/tensor/d_tensor/utils.py @@ -7,7 +7,7 @@ def get_comm_cost(layout: Layout, comm_spec: CommSpec, forward_only: bool = False) -> Dict[str, float]: - ''' + """ This method is used to compute the communication cost for a given layout and comm_spec. For all_gather, all2all, and all_reduce operation, the formula provided in DeviceMesh with alpha-beta model is used to @@ -18,7 +18,7 @@ def get_comm_cost(layout: Layout, comm_spec: CommSpec, forward_only: bool = Fals comm_spec: the comm_spec to instruct the communication operation. forward_only: if it is True, we will just count the forward communication cost. If it is False, we will count both forward and backward communication cost. 
- ''' + """ comm_size = reduce(operator.mul, layout.get_sharded_shape_per_device(), 1) device_mesh = layout.device_mesh comm_pattern = comm_spec.comm_pattern diff --git a/colossalai/tensor/param_op_hook.py b/colossalai/tensor/param_op_hook.py index e37859bac0c3..1fe99cd89a4e 100644 --- a/colossalai/tensor/param_op_hook.py +++ b/colossalai/tensor/param_op_hook.py @@ -36,6 +36,7 @@ class ColoParamOpHookManager: Manage your param op hooks. It only has static methods. The only static method you should call is ``use_hooks(*hooks)``. """ + hooks: Tuple[ColoParamOpHook, ...] = tuple() @staticmethod @@ -99,7 +100,6 @@ def has_hook() -> bool: class PreFwdPostBwd(torch.autograd.Function): - @staticmethod def forward(ctx, params, *args): ctx.params = params @@ -112,7 +112,6 @@ def backward(ctx, *grads): class PostFwdPreBwd(torch.autograd.Function): - @staticmethod def forward(ctx, params, args): ctx.params = params diff --git a/colossalai/tensor/shape_consistency.py b/colossalai/tensor/shape_consistency.py index b837333a2388..409561b3a26b 100644 --- a/colossalai/tensor/shape_consistency.py +++ b/colossalai/tensor/shape_consistency.py @@ -13,7 +13,7 @@ from .comm_spec import * -__all__ = ['ShapeConsistencyManager', 'ShapeConsistencyOptions', 'set_shape_consistency_options'] +__all__ = ["ShapeConsistencyManager", "ShapeConsistencyOptions", "set_shape_consistency_options"] @dataclass @@ -21,16 +21,17 @@ class ShapeConsistencyOptions: """ ShapeConsistencyOptions is a dataclass which specifies the preferences for shape consistency. """ + # TODO: shape consistency option is not implemented yet - pass def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec) -> torch.Tensor: shape_consistency_manager = ShapeConsistencyManager() global_sharding_spec = ShardingSpec(sharding_spec.device_mesh, sharding_spec.entire_shape, {}) with torch.no_grad(): - global_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(distributed_tensor, sharding_spec, - global_sharding_spec) + global_tensor = shape_consistency_manager.apply_for_autoparallel_runtime( + distributed_tensor, sharding_spec, global_sharding_spec + ) return global_tensor @@ -43,7 +44,6 @@ def set_shape_consistency_options(options: ShapeConsistencyOptions): class ShapeConsistencyManager(metaclass=SingletonMeta): - def __init__(self): self._options = None self._forward_only = False @@ -69,9 +69,10 @@ def forward_only(self, value): assert isinstance(value, bool) self._forward_only = value - def get_all_all_gather_spec(self, source_spec: ShardingSpec, - orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]: - ''' + def get_all_all_gather_spec( + self, source_spec: ShardingSpec, orig_cost_dict: Dict[str, float] + ) -> Dict[ShardingSpec, float]: + """ Get all valid sharding specs from source_spec with single all-gather operation, and accumulate communication cost on origin cost which will finally be used in auto sharding solver. For the all-gather operation, we just care about the S dimension. 
@@ -99,7 +100,7 @@ def get_all_all_gather_spec(self, source_spec: ShardingSpec, device_mesh_shape: (4, 4): 0, DistSpec: shard_sequence: S0,R,R device_mesh_shape: (4, 4): 0} - ''' + """ valid_spec_dict = {} comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD for target_pair in source_spec.dim_partition_dict.items(): @@ -121,19 +122,20 @@ def get_all_all_gather_spec(self, source_spec: ShardingSpec, comm_pattern, sharding_spec=source_spec, gather_dim=gather_dim, - # shard_dim will be used during backward + # shard_dim will be used during backward shard_dim=gather_dim, logical_process_axis=logical_process_axis, - forward_only=self.forward_only) + forward_only=self.forward_only, + ) # compute the communication cost with CommSpec cost_dict = comm_spec.get_comm_cost() # generate new sharding spec try: - new_sharding_spec = ShardingSpec(source_spec.device_mesh, - source_spec.entire_shape, - dim_partition_dict=new_dim_partition_dict) + new_sharding_spec = ShardingSpec( + source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict + ) for phase, cost in cost_dict.items(): cost_dict[phase] = cost + orig_cost_dict[phase] valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict) @@ -141,9 +143,10 @@ def get_all_all_gather_spec(self, source_spec: ShardingSpec, pass return valid_spec_dict - def get_all_all_to_all_spec(self, source_spec: ShardingSpec, - orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]: - ''' + def get_all_all_to_all_spec( + self, source_spec: ShardingSpec, orig_cost_dict: Dict[str, float] + ) -> Dict[ShardingSpec, float]: + """ Get all valid sharding specs from source_spec with single all-to-all operation, and accumulate communication cost on origin cost which will finally be used in auto sharding solver. For the all-to-all operation, we just care about the pairs containing S dimension. 
@@ -173,7 +176,7 @@ def get_all_all_to_all_spec(self, source_spec: ShardingSpec, device_mesh_shape: (4, 4): 0, DistSpec: shard_sequence: S0,R,S1 device_mesh_shape: (4, 4): 0} - ''' + """ valid_spec_dict = {} comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD tensor_dims = len(source_spec.entire_shape) @@ -214,12 +217,14 @@ def get_all_all_to_all_spec(self, source_spec: ShardingSpec, gather_dim = b_index shard_dim = f_index logical_process_axis = b_target_pair[1][-1] - comm_spec = CommSpec(comm_pattern, - sharding_spec=source_spec, - gather_dim=gather_dim, - shard_dim=shard_dim, - logical_process_axis=logical_process_axis, - forward_only=self.forward_only) + comm_spec = CommSpec( + comm_pattern, + sharding_spec=source_spec, + gather_dim=gather_dim, + shard_dim=shard_dim, + logical_process_axis=logical_process_axis, + forward_only=self.forward_only, + ) # compute the communication cost with CommSpec cost_dict = comm_spec.get_comm_cost() @@ -238,9 +243,9 @@ def get_all_all_to_all_spec(self, source_spec: ShardingSpec, # generate new sharding spec try: - new_sharding_spec = ShardingSpec(source_spec.device_mesh, - source_spec.entire_shape, - dim_partition_dict=new_dim_partition_dict) + new_sharding_spec = ShardingSpec( + source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict + ) for phase, cost in cost_dict.items(): cost_dict[phase] = cost + orig_cost_dict[phase] valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict) @@ -250,7 +255,7 @@ def get_all_all_to_all_spec(self, source_spec: ShardingSpec, return valid_spec_dict def get_all_shard_spec(self, source_spec: ShardingSpec, orig_cost_dict): - ''' + """ Get all valid sharding specs from source_spec with single shard operation, and accumulate communication cost on origin cost which will finally be used in auto sharding solver. For the sharding operation, we just care about legal sharding dimensions. 
@@ -280,7 +285,7 @@ def get_all_shard_spec(self, source_spec: ShardingSpec, orig_cost_dict): device_mesh_shape: (4, 4): 0, DistSpec: shard_sequence: S0,R,S1 device_mesh_shape: (4, 4): 0} - ''' + """ valid_spec_dict = {} comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD @@ -308,21 +313,23 @@ def get_all_shard_spec(self, source_spec: ShardingSpec, orig_cost_dict): # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec shard_dim = index logical_process_axis = shard_list[-1] - comm_spec = CommSpec(comm_pattern, - sharding_spec=source_spec, - gather_dim=shard_dim, - shard_dim=shard_dim, - logical_process_axis=logical_process_axis, - forward_only=self.forward_only) + comm_spec = CommSpec( + comm_pattern, + sharding_spec=source_spec, + gather_dim=shard_dim, + shard_dim=shard_dim, + logical_process_axis=logical_process_axis, + forward_only=self.forward_only, + ) # compute the communication cost with CommSpec cost_dict = comm_spec.get_comm_cost() # generate new sharding spec try: - new_sharding_spec = ShardingSpec(source_spec.device_mesh, - source_spec.entire_shape, - dim_partition_dict=new_dim_partition_dict) + new_sharding_spec = ShardingSpec( + source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict + ) for phase, cost in cost_dict.items(): cost_dict[phase] = cost + orig_cost_dict[phase] valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict) @@ -330,14 +337,15 @@ def get_all_shard_spec(self, source_spec: ShardingSpec, orig_cost_dict): pass return valid_spec_dict - def get_all_mix_gather_spec(self, source_spec: ShardingSpec, - orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]: - ''' + def get_all_mix_gather_spec( + self, source_spec: ShardingSpec, orig_cost_dict: Dict[str, float] + ) -> Dict[ShardingSpec, float]: + """ S0S1 -> RR S1S0 -> RR S01R -> RR RS01 -> RR - ''' + """ valid_spec_dict = {} comm_pattern = CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD tensor_dims = len(source_spec.entire_shape) @@ -362,19 +370,21 @@ def get_all_mix_gather_spec(self, source_spec: ShardingSpec, b_target_pair = (b_index, []) gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) - comm_spec = CommSpec(comm_pattern, - sharding_spec=source_spec, - gather_dim=gather_dim, - logical_process_axis=logical_process_axes, - forward_only=self.forward_only, - mix_gather=True) + comm_spec = CommSpec( + comm_pattern, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=self.forward_only, + mix_gather=True, + ) cost_dict = comm_spec.get_comm_cost() new_dim_partition_dict = {} # generate new sharding spec try: - new_sharding_spec = ShardingSpec(source_spec.device_mesh, - source_spec.entire_shape, - dim_partition_dict=new_dim_partition_dict) + new_sharding_spec = ShardingSpec( + source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict + ) for phase, cost in cost_dict.items(): cost_dict[phase] = cost + orig_cost_dict[phase] valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict) @@ -384,7 +394,7 @@ def get_all_mix_gather_spec(self, source_spec: ShardingSpec, return valid_spec_dict def get_all_one_step_transform_spec(self, source_spec: ShardingSpec, orig_cost_dict) -> Dict[ShardingSpec, float]: - ''' + """ Get all valid sharding specs from source_spec with one step transform, and accumulate communication cost on origin cost which will finally be used in auto sharding solver. 
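The mix-gather cases listed above all end at a fully replicated spec; the point of the pattern is that one grouped gather over the flattened mesh replaces what would otherwise be two consecutive all-gathers. A step-count illustration only (dicts stand in for specs, no real communication happens here):

```python
# Going from S0,S1 to R,R with plain all-gathers takes two hops;
# a mix-gather collapses it into a single grouped communication.
two_step_path = [{0: [0], 1: [1]}, {0: [0]}, {}]   # all-gather dim 1, then all-gather dim 0
one_step_path = [{0: [0], 1: [1]}, {}]             # single mix-gather
print(len(two_step_path) - 1, "vs", len(one_step_path) - 1, "communication steps")
```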
Note: @@ -398,7 +408,7 @@ def get_all_one_step_transform_spec(self, source_spec: ShardingSpec, orig_cost_d Return: valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation. - ''' + """ valid_spec_dict = {} valid_spec_dict.update(self.get_all_all_gather_spec(source_spec, orig_cost_dict)) valid_spec_dict.update(self.get_all_all_to_all_spec(source_spec, orig_cost_dict)) @@ -545,18 +555,22 @@ def identity_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int for idx, action_spec_pair in enumerate(zip(fwd_actions, comm_action_sequence)): # the first forward comm action will not discard input fwd_action, comm_spec = action_spec_pair - fwd_alloc_numel, fwd_peak_numel = fwd_action(comm_spec, False, fwd_alloc_numel, - fwd_peak_numel) if idx == 0 else fwd_action( - comm_spec, True, fwd_alloc_numel, fwd_peak_numel) + fwd_alloc_numel, fwd_peak_numel = ( + fwd_action(comm_spec, False, fwd_alloc_numel, fwd_peak_numel) + if idx == 0 + else fwd_action(comm_spec, True, fwd_alloc_numel, fwd_peak_numel) + ) # analyze memory footprint for backward comm actions sequence bwd_alloc_numel = 0 bwd_peak_numel = 0 for idx, action_spec_pair in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))): bwd_action, comm_spec = action_spec_pair - bwd_alloc_numel, bwd_peak_numel = bwd_action(comm_spec, False, bwd_alloc_numel, - bwd_peak_numel) if idx == 0 else bwd_action( - comm_spec, True, bwd_alloc_numel, bwd_peak_numel) + bwd_alloc_numel, bwd_peak_numel = ( + bwd_action(comm_spec, False, bwd_alloc_numel, bwd_peak_numel) + if idx == 0 + else bwd_action(comm_spec, True, bwd_alloc_numel, bwd_peak_numel) + ) fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel) bwd_mem = MemoryCost(activation=bwd_alloc_numel, temp=bwd_peak_numel - bwd_alloc_numel) @@ -564,9 +578,10 @@ def identity_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int return TrainCycleItem(fwd_mem, bwd_mem, total_mem) - def shape_consistency(self, source_spec: ShardingSpec, - target_spec: ShardingSpec) -> Tuple[List[ShardingSpec], List[CommSpec], float]: - ''' + def shape_consistency( + self, source_spec: ShardingSpec, target_spec: ShardingSpec + ) -> Tuple[List[ShardingSpec], List[CommSpec], float]: + """ This method will find a path to transform source_spec to target_spec with a greedy algorithm. The basic idea is: @@ -623,9 +638,9 @@ def shape_consistency(self, source_spec: ShardingSpec, CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:0, logical_process_axis: 0), CommSpec:(comm_pattern:shard, shard_dim:0, logical_process_axis:1)] total_cost: 12294.402000000002 - ''' + """ MAX_TRANSFORM_STEPS = 20 - total_cost_dict = {'forward': 0, 'backward': 0, 'total': 0} + total_cost_dict = {"forward": 0, "backward": 0, "total": 0} total_steps = 0 transform_path = [] comm_action_sequence = [] @@ -672,7 +687,7 @@ def shape_consistency(self, source_spec: ShardingSpec, raise RuntimeError(f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps.") def apply(self, tensor_with_sharding_spec: torch.Tensor, target_spec: ShardingSpec) -> torch.Tensor: - ''' + """ Apply target_spec to tensor with source sharding spec, the transform path is generated by the shape_consistency method. 
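The greedy search in `shape_consistency` repeatedly picks, among all one-step candidates, the spec closest to the target, records the comm action that produced it, and stops either at the target or after `MAX_TRANSFORM_STEPS`. A condensed sketch of that loop; the two callbacks stand in for the real enumeration and difference methods, and the per-phase cost accumulation is omitted:

```python
def greedy_transform_path(source, target, one_step_candidates, spec_difference, max_steps=20):
    """Return (transform_path, comm_action_sequence) from source to target."""
    path, actions = [source], []
    current = source
    for _ in range(max_steps):
        if spec_difference(current, target) == 0:
            return path, actions
        candidates = one_step_candidates(current)        # {candidate_spec: comm_action}
        best = min(candidates, key=lambda spec: spec_difference(spec, target))
        actions.append(candidates[best])
        path.append(best)
        current = best
    raise RuntimeError(f"Could not find a valid transform path within {max_steps} steps.")
```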
@@ -729,7 +744,7 @@ def apply(self, tensor_with_sharding_spec: torch.Tensor, target_spec: ShardingSp [1.], [3.], [3.]]) - ''' + """ _, comm_action_sequence, _ = self.shape_consistency(tensor_with_sharding_spec.sharding_spec, target_spec) for comm_spec in comm_action_sequence: tensor_with_sharding_spec = comm_spec.covert_spec_to_action(tensor_with_sharding_spec) diff --git a/colossalai/tensor/sharding_spec.py b/colossalai/tensor/sharding_spec.py index e594fd297dc4..b78ef6d97dd4 100644 --- a/colossalai/tensor/sharding_spec.py +++ b/colossalai/tensor/sharding_spec.py @@ -8,16 +8,16 @@ from .utils import merge_same_dim_mesh_list -__all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec'] +__all__ = ["_DimSpec", "ShardingException", "ShardingSpec"] ALLGATHER_COST = 20 SHARD_COST = 5 STEP_PENALTY = 6 -NAN = 'nan' +NAN = "nan" class _DimSpec: - ''' + """ Sharding spec for single dimension of the sharded tensor describe the sharding dimension of logical device mesh and give a method to compute the difference between them. This class is used internally in ShardingSpec. @@ -25,7 +25,7 @@ class _DimSpec: Argument: shard_list(List[int]): if shard_list is None, the dim spec will be 'R' type. Otherwise, the element in shard_list means the data will be sharded in that dimension. - ''' + """ def __init__(self, shard_list): self.is_replica = len(shard_list) == 0 @@ -37,41 +37,40 @@ def __eq__(self, other): def __repr__(self): if self.is_replica: - return 'R' - target = 'S' + return "R" + target = "S" for dim in self.shard_list: target += str(dim) return target def _convert_str_to_shard_list(self, str_spec): - ''' + """ Convert str_spec into shard_list. Argument: str_spec(str): dim spec in str type. - ''' + """ - if str_spec == 'R': + if str_spec == "R": return [] - if str_spec == 'S0': + if str_spec == "S0": return [0] - if str_spec == 'S1': + if str_spec == "S1": return [1] - if str_spec == 'S01': + if str_spec == "S01": return [0, 1] def build_difference_2d_dict(self): - ''' + """ Build a difference mapping for 2D device mesh case. It will be used to compute the difference between DimSpec pairs. 
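The string form used throughout these docstrings and the internal shard lists are two views of the same thing; for the 2D mesh case the mapping is small enough to write out. A plain restatement (the dict and helper below are illustrative, not the class API):

```python
STR_TO_SHARD_LIST = {"R": [], "S0": [0], "S1": [1], "S01": [0, 1]}

def shard_list_to_str(shard_list):
    # mirrors _DimSpec.__repr__: empty list means replicated, otherwise "S" + mesh axes
    return "R" if not shard_list else "S" + "".join(str(axis) for axis in shard_list)

assert shard_list_to_str(STR_TO_SHARD_LIST["S01"]) == "S01"
assert shard_list_to_str(STR_TO_SHARD_LIST["R"]) == "R"
```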
- ''' + """ - source_spec_list = ['R', 'S0', 'S1', 'S01'] - target_spec_list = ['R', 'S0', 'S1', 'S01'] + source_spec_list = ["R", "S0", "S1", "S01"] + target_spec_list = ["R", "S0", "S1", "S01"] difference_dict = {} for source_spec in source_spec_list: for target_spec in target_spec_list: - legal_sharding_dims = [] spec_pair = (deepcopy(source_spec), deepcopy(target_spec)) source_shard_list = self._convert_str_to_shard_list(source_spec) target_shard_list = self._convert_str_to_shard_list(target_spec) @@ -81,14 +80,17 @@ def build_difference_2d_dict(self): difference = 0 # all_gather(source) -> target - elif len(source_shard_list - ) == len(target_shard_list) + 1 and source_shard_list[:-1] == target_shard_list: + elif ( + len(source_shard_list) == len(target_shard_list) + 1 and source_shard_list[:-1] == target_shard_list + ): difference = ALLGATHER_COST # shard(source) -> target - elif len(source_shard_list) == len( - target_shard_list) - 1 and source_shard_list == target_shard_list[:-1] and target_shard_list[ - -1] not in source_shard_list: + elif ( + len(source_shard_list) == len(target_shard_list) - 1 + and source_shard_list == target_shard_list[:-1] + and target_shard_list[-1] not in source_shard_list + ): difference = SHARD_COST # S1 -> S0 or S0 -> S1 @@ -119,7 +121,7 @@ def build_difference_2d_dict(self): self.difference_dict = difference_dict def difference(self, other): - ''' + """ The difference between two _DimSpec. Argument: @@ -135,7 +137,7 @@ def difference(self, other): Output: 5 - ''' + """ difference = self.difference_dict[(str(self), str(other))] return difference @@ -157,7 +159,7 @@ class ShardingNotDivisibleError(ShardingSpecException): class ShardingSpec: - ''' + """ Sharding spec for a tensor, it contains info of the logical device mesh this tensor belong to, the entire shape of the tensor before sharded, and the sharding sequence looks like [R, R, S0, S1]. @@ -168,13 +170,11 @@ class ShardingSpec: dim_partition_dict(Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded, and the value of the key describe which logical axis will be sharded in that dimension. sharding_sequence(List[_DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1]. - ''' + """ - def __init__(self, - device_mesh: DeviceMesh, - entire_shape: torch.Size, - dim_partition_dict=None, - sharding_sequence=None): + def __init__( + self, device_mesh: DeviceMesh, entire_shape: torch.Size, dim_partition_dict=None, sharding_sequence=None + ): self.device_mesh = device_mesh if isinstance(entire_shape, (list, tuple)): @@ -183,12 +183,17 @@ def __init__(self, self.dim_partition_dict = dim_partition_dict self.sharding_sequence = sharding_sequence if self.sharding_sequence is None: - assert self.dim_partition_dict is not None, f'dim_partition_dict should not be None, if sharding_sequence is NoneType object.' - self.dim_partition_dict = merge_same_dim_mesh_list(dim_size=len(entire_shape), - dim_partition_dict=self.dim_partition_dict) + assert ( + self.dim_partition_dict is not None + ), f"dim_partition_dict should not be None, if sharding_sequence is NoneType object." + self.dim_partition_dict = merge_same_dim_mesh_list( + dim_size=len(entire_shape), dim_partition_dict=self.dim_partition_dict + ) self.convert_dict_to_shard_sequence() elif self.dim_partition_dict is None: - assert self.sharding_sequence is not None, f'sharding_sequence should not be None, if dim_partition_dict is NoneType object.' 
+ assert ( + self.sharding_sequence is not None + ), f"sharding_sequence should not be None, if dim_partition_dict is NoneType object." self.convert_shard_sequence_to_dict() self._sanity_check() @@ -196,7 +201,7 @@ def __repr__(self): res_list = ["DistSpec:"] res_list.append(f"\n\tshard_sequence: " + ",".join(str(dimspec) for dimspec in self.sharding_sequence)) res_list.append(f"\n\tdevice_mesh_shape: {self.device_mesh.shape}") - return ' '.join(res_list) + return " ".join(res_list) def _sanity_check(self): # make sure all axes in logical device mesh only be used once @@ -207,7 +212,8 @@ def _sanity_check(self): dim_check_list.remove(element) else: raise DuplicatedShardingDimensionError( - f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.") + f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}." + ) # make sure that the dimension is not out of index for dim in self.dim_partition_dict.keys(): @@ -226,22 +232,22 @@ def _sanity_check(self): if tensor_dim_size % num_devices != 0: raise ShardingNotDivisibleError( - f'The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices.' + f"The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices." ) def convert_dict_to_shard_sequence(self): - ''' + """ Convert dim_partition_dict into list of _DimSpec, and assign it to sharding_sequence. - ''' + """ sharding_sequence = [_DimSpec([])] * len(self.entire_shape) for dim, shard_list in self.dim_partition_dict.items(): sharding_sequence[dim] = _DimSpec(shard_list) self.sharding_sequence = sharding_sequence def convert_shard_sequence_to_dict(self): - ''' + """ Convert sharding_sequence into dim_partition_dict. - ''' + """ new_dim_partition_dict = {} for index, dim_spec in enumerate(self.sharding_sequence): if not dim_spec.is_replica: @@ -251,7 +257,7 @@ def convert_shard_sequence_to_dict(self): self.dim_partition_dict = new_dim_partition_dict def sharding_sequence_difference(self, other): - ''' + """ This function is a naive version of difference computation. It just simply accumulates difference every dimension between the pair of sharding sequence. @@ -276,21 +282,22 @@ def sharding_sequence_difference(self, other): Return: difference(int): Difference between two ShardingSpec. - ''' + """ assert len(self.sharding_sequence) == len( - other.sharding_sequence), f'Cannot compare difference for two sharding specs with different length.' + other.sharding_sequence + ), f"Cannot compare difference for two sharding specs with different length." difference = 0 for orig_dim_spec, other_dim_spec in zip(self.sharding_sequence, other.sharding_sequence): difference += orig_dim_spec.difference(other_dim_spec) return difference def get_sharded_shape_per_device(self): - sharded_shape = list(self.entire_shape) for dim, shard_list in self.dim_partition_dict.items(): mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list] shard_partitions = reduce(operator.mul, mesh_list, 1) - assert sharded_shape[ - dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.' + assert ( + sharded_shape[dim] % shard_partitions == 0 + ), f"Cannot shard dimension {dim} into {shard_partitions} partitions." 
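Worked numbers for `get_sharded_shape_per_device`: each sharded dimension is divided by the product of the mesh sizes along its shard axes, and the divisibility assertion above is exactly what guards that division. The shape and mesh below are made up for illustration:

```python
import operator
from functools import reduce

entire_shape = [16, 32, 64]
device_mesh_shape = (4, 4)
dim_partition = {0: [0], 2: [0, 1]}     # sharding sequence S0, R, S01

sharded_shape = list(entire_shape)
for dim, shard_list in dim_partition.items():
    partitions = reduce(operator.mul, (device_mesh_shape[axis] for axis in shard_list), 1)
    assert sharded_shape[dim] % partitions == 0
    sharded_shape[dim] //= partitions

print(sharded_shape)   # [4, 32, 4]
```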
sharded_shape[dim] //= shard_partitions return torch.Size(sharded_shape) diff --git a/colossalai/tensor/utils.py b/colossalai/tensor/utils.py index e7d51d099e02..19dde8febf84 100644 --- a/colossalai/tensor/utils.py +++ b/colossalai/tensor/utils.py @@ -7,7 +7,7 @@ def all_gather_simulator(target_pair): - ''' + """ Simulating all-gather operation, analyze the communication cost and simulate the influence of the DimSpec. @@ -19,7 +19,7 @@ def all_gather_simulator(target_pair): Argument: target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded, and the second element describes which logical axis will be sharded in that dimension. - ''' + """ _, shard_list = target_pair new_shard_list = shard_list[:-1] @@ -27,7 +27,7 @@ def all_gather_simulator(target_pair): def all_to_all_simulator(f_target_pair, b_target_pair): - ''' + """ Simulating all-to-all operation, analyze the communication cost and simulate the influence of the DimSpec. @@ -47,7 +47,7 @@ def all_to_all_simulator(f_target_pair, b_target_pair): Argument: target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded, and the second element describes which logical axis will be sharded in that dimension. - ''' + """ _, f_shard_list = f_target_pair _, b_shard_list = b_target_pair if not len(b_shard_list): @@ -61,7 +61,7 @@ def all_to_all_simulator(f_target_pair, b_target_pair): def shard_simulator(target_pair, legal_sharding_dims): - ''' + """ Simulating shard operation, analyze the communication cost(always ZERO) and simulate the influence of the DimSpec. @@ -78,7 +78,7 @@ def shard_simulator(target_pair, legal_sharding_dims): Argument: target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded, and the second element describes which logical axis will be sharded in that dimension. - ''' + """ _, shard_list = target_pair shard_list_list = [] for dim in legal_sharding_dims: @@ -91,7 +91,7 @@ def shard_simulator(target_pair, legal_sharding_dims): def mix_gather_simulator(f_target_pair, b_target_pair): - ''' + """ Assume index of f and b target pairs are 'f' and 'b' S0S1 => Input: (f, [0]), (b, [1]) Output: [b, f], (1, 0) S1S0 => Input: (f, [1]), (b, [0]) Output: [b, f], (0, 1) @@ -99,7 +99,7 @@ def mix_gather_simulator(f_target_pair, b_target_pair): RS01 => Input: (f, []), (b, [0, 1]) Output: [b], (1, 1) S10R => Input: (f, [0, 1]), (b, []) Output: [f], (0, 0) RS10 => Input: (f, []), (b, [0, 1]) Output: [b], (0, 0) - ''' + """ if f_target_pair[1] and b_target_pair[1]: leading_dim = b_target_pair[1] > f_target_pair[1] return [b_target_pair[0], f_target_pair[0]], [int(leading_dim), int(leading_dim ^ 1)] @@ -118,7 +118,7 @@ def mix_gather_simulator(f_target_pair, b_target_pair): # The function is credited to PyTorch Team def named_params_with_colotensor( module: nn.Module, - prefix: str = '', + prefix: str = "", recurse: bool = True, ) -> Iterator[Tuple[str, Union[nn.Parameter, ColoTensor]]]: r"""Returns an iterator over module parameters (together with the @@ -154,7 +154,7 @@ def named_params_with_colotensor( for name, val in vars(mod).items(): if isinstance(val, ColoTensor) and val not in memo: memo.add(val) - name = mod_prefix + ('.' if mod_prefix else '') + name + name = mod_prefix + ("." if mod_prefix else "") + name yield name, val # find all nn.Parameters @@ -169,15 +169,16 @@ def _convert_tensor(tensor: torch.Tensor) -> ColoTensor: def convert_parameter(module: torch.nn.Module, param_name: str): # Perform some validation first. 
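The simulators above all speak in "target pairs": `(tensor_dim, shard_list)`, where the list names the logical mesh axes that dimension is split over. The all-gather case is the simplest, it just drops the last axis; a standalone restatement of that convention (a simplified stand-in, not the functions from `colossalai.tensor.utils` themselves):

```python
def all_gather_sim(target_pair):
    dim, shard_list = target_pair
    return dim, shard_list[:-1]          # gathering removes the last-used mesh axis

print(all_gather_sim((0, [0, 1])))       # (0, [0]) : S01 on dim 0 becomes S0
print(all_gather_sim((1, [1])))          # (1, [])  : S1 on dim 1 becomes R
```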
if not hasattr(module, param_name): - raise ValueError(f'module: {module} does not have parameter with name: {param_name}') + raise ValueError(f"module: {module} does not have parameter with name: {param_name}") tensor = getattr(module, param_name) if not isinstance(tensor, torch.Tensor): raise ValueError( - f'Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}') + f"Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}" + ) if not tensor.is_contiguous(): - raise ValueError(f'param: {param_name} is not a contiguous Tensor') + raise ValueError(f"param: {param_name} is not a contiguous Tensor") st = _convert_tensor(tensor) @@ -193,9 +194,9 @@ def convert_parameter(module: torch.nn.Module, param_name: str): def convert_dim_partition_dict(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]: - ''' + """ This method is used to convert the negative dim value to positive. - ''' + """ dims_to_convert = [] for dim, mesh_list in dim_partition_dict.items(): if dim < 0: @@ -207,13 +208,13 @@ def convert_dim_partition_dict(dim_size: int, dim_partition_dict: Dict[int, List def merge_same_dim_mesh_list(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]: - ''' + """ This method is used to merge the different key value which points to same physical position. For example: dim_partition_dict: {1 :[0], -1: [1]} or {1: [0], 1: [1]} for a 2d tensor, the dim 1 and -1 point same physical position. In this method, above dim_partition_dict will be converted to {1: [0, 1]} - ''' + """ converted_dim_partition_dict = {} for dim, mesh_list in dim_partition_dict.items(): if dim < 0: diff --git a/colossalai/testing/__init__.py b/colossalai/testing/__init__.py index 0db33361c6a0..c6956e81fbde 100644 --- a/colossalai/testing/__init__.py +++ b/colossalai/testing/__init__.py @@ -19,7 +19,19 @@ ) __all__ = [ - 'assert_equal', 'assert_not_equal', 'assert_close', 'assert_close_loose', 'assert_equal_in_group', 'parameterize', - 'rerun_on_exception', 'rerun_if_address_is_in_use', 'skip_if_not_enough_gpus', 'free_port', 'spawn', - 'clear_cache_before_run', 'run_on_environment_flag', 'check_state_dict_equal', 'assert_hf_output_close' + "assert_equal", + "assert_not_equal", + "assert_close", + "assert_close_loose", + "assert_equal_in_group", + "parameterize", + "rerun_on_exception", + "rerun_if_address_is_in_use", + "skip_if_not_enough_gpus", + "free_port", + "spawn", + "clear_cache_before_run", + "run_on_environment_flag", + "check_state_dict_equal", + "assert_hf_output_close", ] diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py index 8d9ec8ab5f35..816bc0d7b6d7 100644 --- a/colossalai/testing/comparison.py +++ b/colossalai/testing/comparison.py @@ -9,20 +9,22 @@ def assert_equal(a: Tensor, b: Tensor): - assert torch.all(a == b), f'expected a and b to be equal but they are not, {a} vs {b}' + assert torch.all(a == b), f"expected a and b to be equal but they are not, {a} vs {b}" def assert_not_equal(a: Tensor, b: Tensor): - assert not torch.all(a == b), f'expected a and b to be not equal but they are, {a} vs {b}' + assert not torch.all(a == b), f"expected a and b to be not equal but they are, {a} vs {b}" def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1e-3): - assert_close(a, - b, - rtol=rtol, - atol=atol, - msg=f"Tensor not close, shape: {a.shape} vs {b.shape}, \ - dtype: {a.dtype} vs {b.dtype}") + assert_close( + a, + b, + 
rtol=rtol, + atol=atol, + msg=f"Tensor not close, shape: {a.shape} vs {b.shape}, \ + dtype: {a.dtype} vs {b.dtype}", + ) def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None): @@ -35,12 +37,13 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None): for i in range(world_size - 1): a = tensor_list[i] b = tensor_list[i + 1] - assert torch.all(a == b), f'expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}' + assert torch.all(a == b), f"expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}" def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True): - assert len(list(d1.keys())) == len(list(d2.keys())), \ - f"Number of keys unequal: {len(list(d1.keys()))} vs {len(list(d2.keys()))}" + assert len(list(d1.keys())) == len( + list(d2.keys()) + ), f"Number of keys unequal: {len(list(d1.keys()))} vs {len(list(d2.keys()))}" for k, v1 in d1.items(): assert k in d2 v2 = d2[k] @@ -86,12 +89,9 @@ def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_devic assert v1 == v2, f"{v1} not equals to {v2}" -def assert_hf_output_close(out1: Any, - out2: Any, - ignore_keys: List[str] = None, - track_name: str = "", - atol=1e-5, - rtol=1e-5): +def assert_hf_output_close( + out1: Any, out2: Any, ignore_keys: List[str] = None, track_name: str = "", atol=1e-5, rtol=1e-5 +): """ Check if two outputs from huggingface are equal. @@ -108,23 +108,17 @@ def assert_hf_output_close(out1: Any, for k in out1.keys(): if ignore_keys is not None and k in ignore_keys: continue - assert_hf_output_close(out1[k], - out2[k], - track_name=f"{track_name}.{k}", - ignore_keys=ignore_keys, - atol=atol, - rtol=rtol) + assert_hf_output_close( + out1[k], out2[k], track_name=f"{track_name}.{k}", ignore_keys=ignore_keys, atol=atol, rtol=rtol + ) elif isinstance(out1, (list, tuple)) and isinstance(out2, (list, tuple)): # if two values are list # we recursively check the elements assert len(out1) == len(out2) for i in range(len(out1)): - assert_hf_output_close(out1[i], - out2[i], - track_name=f"{track_name}.{i}", - ignore_keys=ignore_keys, - atol=atol, - rtol=rtol) + assert_hf_output_close( + out1[i], out2[i], track_name=f"{track_name}.{i}", ignore_keys=ignore_keys, atol=atol, rtol=rtol + ) elif isinstance(out1, Tensor) and isinstance(out2, Tensor): if out1.shape != out2.shape: raise AssertionError(f"{track_name}: shape mismatch: {out1.shape} vs {out2.shape}") diff --git a/colossalai/testing/pytest_wrapper.py b/colossalai/testing/pytest_wrapper.py index 6a80e1dcc548..b1e82b469c96 100644 --- a/colossalai/testing/pytest_wrapper.py +++ b/colossalai/testing/pytest_wrapper.py @@ -33,13 +33,14 @@ def test_for_something(): import pytest except ImportError: raise ImportError( - 'This function requires `pytest` to be installed, please do `pip install pytest` and try again.') + "This function requires `pytest` to be installed, please do `pip install pytest` and try again." 
+ ) assert isinstance(name, str) - flag = os.environ.get(name.upper(), '0') + flag = os.environ.get(name.upper(), "0") - reason = f'Environment variable {name} is {flag}' - if flag == '1': + reason = f"Environment variable {name} is {flag}" + if flag == "1": return pytest.mark.skipif(False, reason=reason) else: return pytest.mark.skipif(True, reason=reason) diff --git a/colossalai/testing/random.py b/colossalai/testing/random.py index ad6d24a4b94b..4525dff3fe80 100644 --- a/colossalai/testing/random.py +++ b/colossalai/testing/random.py @@ -11,7 +11,7 @@ def seed_all(seed, cuda_deterministic=False): if torch.cuda.is_available(): torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) - if cuda_deterministic: # slower, more reproducible + if cuda_deterministic: # slower, more reproducible torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False else: diff --git a/colossalai/testing/utils.py b/colossalai/testing/utils.py index a4370a8d4933..fdbda9a598bf 100644 --- a/colossalai/testing/utils.py +++ b/colossalai/testing/utils.py @@ -55,7 +55,6 @@ def say_something(person, msg): """ def _wrapper(func): - def _execute_function_by_param(**kwargs): for val in values: arg_map = {argument: val} @@ -120,11 +119,11 @@ def _match_lines(lines, pattern): return False def _wrapper(func): - def _run_until_success(*args, **kwargs): try_count = 0 - assert max_try is None or isinstance(max_try, int), \ - f'Expected max_try to be None or int, but got {type(max_try)}' + assert max_try is None or isinstance( + max_try, int + ), f"Expected max_try to be None or int, but got {type(max_try)}" while max_try is None or try_count < max_try: try: @@ -132,14 +131,14 @@ def _run_until_success(*args, **kwargs): ret = func(*args, **kwargs) return ret except exception_type as e: - error_lines = str(e).split('\n') + error_lines = str(e).split("\n") if try_count < max_try and (pattern is None or _match_lines(error_lines, pattern)): - print('Exception is caught, retrying...') + print("Exception is caught, retrying...") # when pattern is not specified, we always skip the exception # when pattern is specified, we only skip when pattern is matched continue else: - print('Maximum number of attempts is reached or pattern is not matched, no more retrying...') + print("Maximum number of attempts is reached or pattern is not matched, no more retrying...") raise e # Override signature @@ -198,7 +197,6 @@ def test_something(): """ def _wrap_func(f): - def _execute_by_gpu_num(*args, **kwargs): num_avail_gpu = torch.cuda.device_count() if num_avail_gpu >= min_gpus: @@ -263,7 +261,6 @@ def test_something(): """ def _wrap_func(f): - def _clear_cache(*args, **kwargs): torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py index 5226f688b43b..3ec39b949a23 100644 --- a/colossalai/utils/__init__.py +++ b/colossalai/utils/__init__.py @@ -13,20 +13,20 @@ from .timer import MultiTimer, Timer __all__ = [ - 'conditional_context', - 'get_current_device', - 'synchronize', - 'empty_cache', - 'set_to_cuda', - 'Timer', - 'MultiTimer', - 'multi_tensor_applier', - 'TensorDetector', - 'ensure_path_exists', - 'disposable', - '_cast_float', - 'free_storage', - 'set_seed', - 'is_ddp_ignored', - 'set_device', + "conditional_context", + "get_current_device", + "synchronize", + "empty_cache", + "set_to_cuda", + "Timer", + "MultiTimer", + "multi_tensor_applier", + "TensorDetector", + "ensure_path_exists", + "disposable", + "_cast_float", + "free_storage", 
+ "set_seed", + "is_ddp_ignored", + "set_device", ] diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py index 8c769c5b13c0..c43caaff4806 100644 --- a/colossalai/utils/common.py +++ b/colossalai/utils/common.py @@ -28,7 +28,7 @@ def conditional_context(context_manager, enable=True): def is_ddp_ignored(p): - return getattr(p, '_ddp_to_ignore', False) + return getattr(p, "_ddp_to_ignore", False) def disposable(func: Callable) -> Callable: diff --git a/colossalai/utils/cuda.py b/colossalai/utils/cuda.py index 6b5d17cf04e7..6bfb08d1f04a 100644 --- a/colossalai/utils/cuda.py +++ b/colossalai/utils/cuda.py @@ -29,9 +29,9 @@ def get_current_device() -> torch.device: If cuda available, return gpu, otherwise return cpu. """ if torch.cuda.is_available(): - return torch.device(f'cuda:{torch.cuda.current_device()}') + return torch.device(f"cuda:{torch.cuda.current_device()}") else: - return torch.device('cpu') + return torch.device("cpu") def synchronize(): diff --git a/colossalai/utils/model/utils.py b/colossalai/utils/model/utils.py index 21bc530934d3..4eee4fbc0eee 100644 --- a/colossalai/utils/model/utils.py +++ b/colossalai/utils/model/utils.py @@ -27,19 +27,18 @@ def call_to_str(base, *args, **kwargs): Returns: str: A string representation of base(*args, **kwargs) """ - name = f'{base}(' + name = f"{base}(" if args: - name += ', '.join(repr(arg) for arg in args) + name += ", ".join(repr(arg) for arg in args) if kwargs: - name += ', ' + name += ", " if kwargs: - name += ', '.join(f'{key}={repr(arg)}' for key, arg in kwargs.items()) - name += ')' + name += ", ".join(f"{key}={repr(arg)}" for key, arg in kwargs.items()) + name += ")" return name class InsertPostInitMethodToModuleSubClasses(object): - def __init__(self, default_dtype: Optional[torch.dtype] = None): self._old_default_dtype = None self._default_dtype = default_dtype @@ -53,7 +52,6 @@ def __enter__(self): torch.set_default_dtype(self._default_dtype) def preprocess_after(f): - @functools.wraps(f) def wrapper(module: torch.nn.Module, *args, **kwargs): f(module, *args, **kwargs) @@ -74,7 +72,7 @@ def _init_subclass(cls, **kwargs): substitute_init_recursively(torch.nn.modules.module.Module, _enable_class, set()) # holding on to the current __init__subclass__ for exit - torch.nn.modules.module.Module._old_init_subclass = (torch.nn.modules.module.Module.__init_subclass__) + torch.nn.modules.module.Module._old_init_subclass = torch.nn.modules.module.Module.__init_subclass__ # Replace .__init__() for future subclasses of torch.nn.Module torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass) @@ -82,12 +80,11 @@ def _init_subclass(cls, **kwargs): return self def __exit__(self, exc_type, exc_value, traceback): - if self._default_dtype is not None: torch.set_default_dtype(self._old_default_dtype) def _disable_class(cls): - if not hasattr(cls, '_old_init'): + if not hasattr(cls, "_old_init"): raise AttributeError( f"_old_init is not found in the {cls.__name__}, please make sure that you have imported {cls.__name__} before entering the context." 
) @@ -97,7 +94,7 @@ def _disable_class(cls): substitute_init_recursively(torch.nn.modules.module.Module, _disable_class, set()) # Replace .__init__() for future subclasses of torch.nn.Module - torch.nn.modules.module.Module.__init_subclass__ = (torch.nn.modules.module.Module._old_init_subclass) + torch.nn.modules.module.Module.__init_subclass__ = torch.nn.modules.module.Module._old_init_subclass self._post_context_exec() # Now that we cleaned up the metaclass injection, raise the exception. diff --git a/colossalai/utils/moe.py b/colossalai/utils/moe.py index 6456dfb905b0..1b75448bdd3c 100644 --- a/colossalai/utils/moe.py +++ b/colossalai/utils/moe.py @@ -19,8 +19,8 @@ def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]] """ epsize_param_dict = dict() for param in model.parameters(): - if not hasattr(param, 'moe_info'): - ep_size = 1 # set ep_size to 1 for dp parameters + if not hasattr(param, "moe_info"): + ep_size = 1 # set ep_size to 1 for dp parameters else: ep_size = param.moe_info.ep_size if ep_size not in epsize_param_dict: @@ -37,7 +37,6 @@ def sync_moe_model_param(model: nn.Module): model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency. """ if is_using_ddp(): - param_dict = get_moe_epsize_param_dict(model) # synchronize the parameters whose dp_group is the whole world diff --git a/colossalai/utils/multi_tensor_apply/multi_tensor_apply.py b/colossalai/utils/multi_tensor_apply/multi_tensor_apply.py index 2b6de5fe1f3c..750c2a32da34 100644 --- a/colossalai/utils/multi_tensor_apply/multi_tensor_apply.py +++ b/colossalai/utils/multi_tensor_apply/multi_tensor_apply.py @@ -25,7 +25,9 @@ def check_avail(self): raise RuntimeError( "Attempted to call MultiTensorApply method, but MultiTensorApply " "is not available, possibly because Apex was installed without " - "--cpp_ext --cuda_ext. Original import error message:", MultiTensorApply.import_err) + "--cpp_ext --cuda_ext. Original import error message:", + MultiTensorApply.import_err, + ) def __call__(self, op, noop_flag_buffer, tensor_lists, *args): self.check_avail() diff --git a/colossalai/utils/rank_recorder/README.md b/colossalai/utils/rank_recorder/README.md index da8a6039d543..cad6c1fddd71 100644 --- a/colossalai/utils/rank_recorder/README.md +++ b/colossalai/utils/rank_recorder/README.md @@ -1,7 +1,7 @@ # Rank Recorder This is a useful tool to get the records of certain functions in each rank. The records of each rank will dump into a json file after the end of multiple process program. You can parse and visualize the json file easily. -Before using the tool, you should ensure dist.is_initialized() return true before exit of program. +Before using the tool, you should ensure dist.is_initialized() return true before exit of program. ## Usage @@ -58,10 +58,10 @@ def worker(rank): with recorder("calc_1(x100)", rank) as r: calc(100, 100) - + with recorder("calc_2(x400)", rank) as r: calc(400, 400) - + with recorder("calc_2(x200)", rank) as r: calc(200, 200) @@ -69,4 +69,4 @@ if __name__ == "__main__": mp.spawn(worker, nprocs=WORLD_SIZE) ``` -run the script directly and you will get `kernel_select.json` and `kernel_select.png` in your current folder. \ No newline at end of file +run the script directly and you will get `kernel_select.json` and `kernel_select.png` in your current folder. 
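After the run, the merged JSON can be post-processed with ordinary tooling. A small sketch, assuming the merged layout written by `Recorder.merge_recode` (one list of `{"start", "end", "name"}` events per rank, with times relative to the shared base time) and the `kernel_select.json` name from the example above:

```python
import json

with open("kernel_select.json", encoding="utf-8") as f:
    records = json.load(f)

for rank, events in sorted(records.items()):
    busy = sum(event["end"] - event["start"] for event in events)
    print(f"rank {rank}: {len(events)} events, {busy:.3f}s recorded")
```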
diff --git a/colossalai/utils/rank_recorder/__init__.py b/colossalai/utils/rank_recorder/__init__.py index 1274d0e7dbc5..1d347075a8ce 100644 --- a/colossalai/utils/rank_recorder/__init__.py +++ b/colossalai/utils/rank_recorder/__init__.py @@ -1,3 +1,3 @@ from colossalai.utils.rank_recorder.rank_recorder import recorder -__all__ = ["recorder"] \ No newline at end of file +__all__ = ["recorder"] diff --git a/colossalai/utils/rank_recorder/rank_recorder.py b/colossalai/utils/rank_recorder/rank_recorder.py index 40bb7e184a12..1cb9169125a1 100644 --- a/colossalai/utils/rank_recorder/rank_recorder.py +++ b/colossalai/utils/rank_recorder/rank_recorder.py @@ -1,18 +1,15 @@ -import time -from typing import List, Dict +import atexit import json import os -import time import shutil -import atexit +import time +from typing import Dict, List +import matplotlib.colors as mcolors +import matplotlib.pyplot as plt import torch import torch.distributed as dist -import json -import matplotlib.pyplot as plt -import matplotlib.colors as mcolors - cmap = list(mcolors.TABLEAU_COLORS.values()) LOG_FOLDER = "record.log" @@ -20,7 +17,6 @@ class Event: - def __init__(self, start: int, end: int, name: str, rank: int) -> None: self.start = start self.end = end @@ -29,16 +25,15 @@ def __init__(self, start: int, end: int, name: str, rank: int) -> None: class Recorder: - def __init__(self) -> None: self.rank_to_history: Dict[int, List[Event]] = {} self.base_time = time.time() self.temp_event = None - self.export_format = 'png' - self.export_name = 'test' + self.export_format = "png" + self.export_name = "test" self.dpi = 500 - self.theme = 'dark_background' + self.theme = "dark_background" self.figure_width = 30 self.figure_height = 10 self.legend_fontsize = 16 @@ -84,18 +79,18 @@ def __exit__(self, *args): def dump_record(self): rank = dist.get_rank() rank_to_history = self.rank_to_history - records = {'base_time': self.base_time, 'content': {}} + records = {"base_time": self.base_time, "content": {}} for record_rank in rank_to_history: history = rank_to_history[record_rank] recs = [] for event in history: - rec = {'start': event.start, 'end': event.end, 'name': event.name} + rec = {"start": event.start, "end": event.end, "name": event.name} recs.append(rec) - records['content'][record_rank] = recs + records["content"][record_rank] = recs - dump_name = f'{rank}.json' + dump_name = f"{rank}.json" dump_path = os.path.join(LOG_FOLDER, dump_name) - with open(dump_path, 'w', encoding='utf-8') as f: + with open(dump_path, "w", encoding="utf-8") as f: json.dump(records, f, ensure_ascii=False) def merge_recode(self): @@ -117,24 +112,22 @@ def merge_recode(self): logs_path = [os.path.join(LOG_FOLDER, file) for file in os.listdir(LOG_FOLDER)] recoders = {} for path in logs_path: - with open(path, 'r', encoding='utf-8') as f: + with open(path, "r", encoding="utf-8") as f: recs = json.load(f) - for record_rank in recs['content']: - history = recs['content'][record_rank] + for record_rank in recs["content"]: + history = recs["content"][record_rank] recoders[record_rank] = [] for rec in history: - recoders[record_rank].append({ - 'start': rec['start'] - base_time, - 'end': rec['end'] - base_time, - 'name': rec['name'] - }) + recoders[record_rank].append( + {"start": rec["start"] - base_time, "end": rec["end"] - base_time, "name": rec["name"]} + ) shutil.rmtree(LOG_FOLDER) - with open(self.export_name + '.json', 'w', encoding='utf-8') as f: + with open(self.export_name + ".json", "w", encoding="utf-8") as f: json.dump(recoders, f, 
ensure_ascii=False) def visualize_record(self): - with open(self.export_name + '.json', 'r', encoding='utf-8') as f: + with open(self.export_name + ".json", "r", encoding="utf-8") as f: records = json.load(f) records = dict(records) ranks = list(sorted(records.keys())) @@ -147,9 +140,9 @@ def visualize_record(self): for rank in ranks: rank_records = records[rank] for rec in rank_records: - s = rec['start'] - e = rec['end'] - name = rec['name'] + s = rec["start"] + e = rec["end"] + name = rec["name"] if name not in name_list: name_list[name] = len(name_list) bar = plt.barh(rank, width=e - s, height=self.bar_height, left=s, color=cmap[name_list[name]]) @@ -157,8 +150,8 @@ def visualize_record(self): plots[name] = bar plt.legend(list(plots.values()), list(plots.keys()), loc="upper left", fontsize=self.legend_fontsize) - plt.yticks(ticks=ranks, labels=[f'Device:{rank}' for rank in ranks], fontsize=self.device_fontsize) - plt.grid(axis='x') + plt.yticks(ticks=ranks, labels=[f"Device:{rank}" for rank in ranks], fontsize=self.device_fontsize) + plt.grid(axis="x") plt.savefig("{}.{}".format(self.export_name, self.export_format)) def exit_worker(self): diff --git a/colossalai/utils/tensor_detector/__init__.py b/colossalai/utils/tensor_detector/__init__.py index cafc19b67c5c..c6c68aa4009b 100644 --- a/colossalai/utils/tensor_detector/__init__.py +++ b/colossalai/utils/tensor_detector/__init__.py @@ -1 +1 @@ -from .tensor_detector import TensorDetector +from .tensor_detector import TensorDetector diff --git a/colossalai/utils/tensor_detector/readme.md b/colossalai/utils/tensor_detector/readme.md index d6852ea55b54..455eae18116a 100644 --- a/colossalai/utils/tensor_detector/readme.md +++ b/colossalai/utils/tensor_detector/readme.md @@ -14,7 +14,7 @@ class MLP(nn.Module): super().__init__() self.mlp = nn.Sequential(nn.Linear(64, 8), nn.ReLU(), - nn.Linear(8, 32)) + nn.Linear(8, 32)) def forward(self, x): return self.mlp(x) ``` @@ -125,4 +125,3 @@ Total GPU Memory Allocated on cuda:0 is 14.0 KB This tool was inspired by https://github.com/Stonesjtu/pytorch_memlab/blob/master/pytorch_memlab/mem_reporter.py and https://github.com/Oldpan/Pytorch-Memory-Utils - diff --git a/colossalai/utils/tensor_detector/tensor_detector.py b/colossalai/utils/tensor_detector/tensor_detector.py index cfcd4e47b4cb..38cf094b8dd0 100644 --- a/colossalai/utils/tensor_detector/tensor_detector.py +++ b/colossalai/utils/tensor_detector/tensor_detector.py @@ -1,21 +1,19 @@ import gc import inspect +from collections import defaultdict +from typing import Optional + import torch import torch.nn as nn -from typing import Optional -from collections import defaultdict LINE_WIDTH = 108 -LINE = '-' * LINE_WIDTH + '\n' - +LINE = "-" * LINE_WIDTH + "\n" -class TensorDetector(): - def __init__(self, - show_info: bool = True, - log: str = None, - include_cpu: bool = False, - module: Optional[nn.Module] = None): +class TensorDetector: + def __init__( + self, show_info: bool = True, log: str = None, include_cpu: bool = False, module: Optional[nn.Module] = None + ): """This class is a detector to detect tensor on different devices. 
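A hedged usage sketch built from the constructor signature above: create the detector around the module you care about and call `detect()` at the points where you want a snapshot. The model and input below are placeholders, not part of this patch:

```python
import torch
import torch.nn as nn
from colossalai.utils import TensorDetector

model = nn.Linear(64, 8)
detector = TensorDetector(show_info=True, log=None, include_cpu=False, module=model)

x = torch.randn(4, 64)
detector.detect()            # snapshot before the step
loss = model(x).sum()
loss.backward()
detector.detect()            # snapshot after backward: new tensors and grads show up
```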
Args: @@ -57,40 +55,39 @@ def get_tensor_mem(self, tensor): def mem_format(self, real_memory_size): # format the tensor memory into a reasonable magnitude if real_memory_size >= 2**30: - return str(real_memory_size / (2**30)) + ' GB' + return str(real_memory_size / (2**30)) + " GB" if real_memory_size >= 2**20: - return str(real_memory_size / (2**20)) + ' MB' + return str(real_memory_size / (2**20)) + " MB" if real_memory_size >= 2**10: - return str(real_memory_size / (2**10)) + ' KB' - return str(real_memory_size) + ' B' + return str(real_memory_size / (2**10)) + " KB" + return str(real_memory_size) + " B" def collect_tensors_state(self): for obj in gc.get_objects(): if torch.is_tensor(obj): # skip cpu tensor when include_cpu is false and the tensor we have collected before - if (not self.include_cpu) and obj.device == torch.device('cpu'): + if (not self.include_cpu) and obj.device == torch.device("cpu"): continue self.detected.append(id(obj)) # skip parameters we had added in __init__ when module is an instance of nn.Module for the first epoch if id(obj) not in self.tensor_info: - name = type(obj).__name__ # after backward, we want to update the records, to show you the change - if isinstance(self.module, nn.Module) and name == 'Parameter': + if isinstance(self.module, nn.Module) and name == "Parameter": if obj.grad is not None: # with grad attached for par_name, param in self.module.named_parameters(): if param.requires_grad and param.grad.equal(obj.grad): - name = par_name + ' (with grad)' + name = par_name + " (with grad)" else: # with no grad attached # there will be no new parameters created during running # so it must be in saved_tensor_info continue # we can also marked common tensors as tensor(with grad) - if name == 'Tensor' and (obj.is_leaf or obj.retains_grad): + if name == "Tensor" and (obj.is_leaf or obj.retains_grad): if obj.grad is not None: - name = name + ' (with grad)' + name = name + " (with grad)" # in fact, common tensor have no grad # unless you set retain_grad() if id(obj) in self.saved_tensor_info.keys() and name == self.saved_tensor_info[id(obj)][0]: @@ -111,10 +108,10 @@ def collect_tensors_state(self): self.devices.append(obj.device) def print_tensors_state(self): - template_format = '{:3s}{:<30s}{:>10s}{:>20s}{:>10s}{:>20s}{:>15s}' + template_format = "{:3s}{:<30s}{:>10s}{:>20s}{:>10s}{:>20s}{:>15s}" self.info += LINE - self.info += template_format.format(' ', 'Tensor', 'device', 'shape', 'grad', 'dtype', 'Mem') - self.info += '\n' + self.info += template_format.format(" ", "Tensor", "device", "shape", "grad", "dtype", "Mem") + self.info += "\n" self.info += LINE # if a tensor updates this turn, and was recorded before @@ -124,24 +121,30 @@ def print_tensors_state(self): minus = outdated + minus if len(self.order) > 0: for tensor_id in self.order: - self.info += template_format.format('+', str(self.tensor_info[tensor_id][0]), - str(self.tensor_info[tensor_id][1]), - str(tuple(self.tensor_info[tensor_id][2])), - str(self.tensor_info[tensor_id][3]), - str(self.tensor_info[tensor_id][4]), - str(self.tensor_info[tensor_id][5])) - self.info += '\n' + self.info += template_format.format( + "+", + str(self.tensor_info[tensor_id][0]), + str(self.tensor_info[tensor_id][1]), + str(tuple(self.tensor_info[tensor_id][2])), + str(self.tensor_info[tensor_id][3]), + str(self.tensor_info[tensor_id][4]), + str(self.tensor_info[tensor_id][5]), + ) + self.info += "\n" if len(self.order) > 0 and len(minus) > 0: - self.info += '\n' + self.info += "\n" if len(minus) > 0: for 
tensor_id in minus: - self.info += template_format.format('-', str(self.saved_tensor_info[tensor_id][0]), - str(self.saved_tensor_info[tensor_id][1]), - str(tuple(self.saved_tensor_info[tensor_id][2])), - str(self.saved_tensor_info[tensor_id][3]), - str(self.saved_tensor_info[tensor_id][4]), - str(self.saved_tensor_info[tensor_id][5])) - self.info += '\n' + self.info += template_format.format( + "-", + str(self.saved_tensor_info[tensor_id][0]), + str(self.saved_tensor_info[tensor_id][1]), + str(tuple(self.saved_tensor_info[tensor_id][2])), + str(self.saved_tensor_info[tensor_id][3]), + str(self.saved_tensor_info[tensor_id][4]), + str(self.saved_tensor_info[tensor_id][5]), + ) + self.info += "\n" # deleted the updated tensor self.saved_tensor_info.pop(tensor_id) @@ -152,16 +155,16 @@ def print_tensors_state(self): self.info += LINE self.info += f"Detect Location: {locate_msg}\n" for device in self.devices: - if device == torch.device('cpu'): + if device == torch.device("cpu"): continue gpu_mem_alloc = self.mem_format(torch.cuda.memory_allocated(device)) self.info += f"Total GPU Memory Allocated on {device} is {gpu_mem_alloc}\n" self.info += LINE - self.info += '\n\n' + self.info += "\n\n" if self.show_info: print(self.info) if self.log is not None: - with open(self.log + '.log', 'a') as f: + with open(self.log + ".log", "a") as f: f.write(self.info) def detect(self, include_cpu=False): diff --git a/colossalai/utils/timer.py b/colossalai/utils/timer.py index 4b61f4a5ef11..2f61817f0461 100644 --- a/colossalai/utils/timer.py +++ b/colossalai/utils/timer.py @@ -2,12 +2,12 @@ # -*- encoding: utf-8 -*- import time from typing import Tuple + from .cuda import synchronize class Timer: - """A timer object which helps to log the execution times, and provides different tools to assess the times. - """ + """A timer object which helps to log the execution times, and provides different tools to assess the times.""" def __init__(self): self._started = False @@ -25,16 +25,14 @@ def current_time(self) -> float: return time.time() def start(self): - """Firstly synchronize cuda, reset the clock and then start the timer. 
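A minimal usage sketch of the Timer shown in this hunk (`start`, `lap`, `stop(keep_in_history=...)`, `get_elapsed_time`), with a placeholder workload standing in for the measured region:

```python
import time
from colossalai.utils import Timer

timer = Timer()
timer.start()
time.sleep(0.1)                              # placeholder for the measured region
print("so far:", timer.lap(), "s")           # elapsed time without stopping the timer
timer.stop(keep_in_history=True)
print("total:", timer.get_elapsed_time(), "s")
```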
- """ + """Firstly synchronize cuda, reset the clock and then start the timer.""" self._elapsed = 0 synchronize() self._start_time = time.time() self._started = True def lap(self): - """lap time and return elapsed time - """ + """lap time and return elapsed time""" return self.current_time - self._start_time def stop(self, keep_in_history: bool = False): @@ -80,12 +78,11 @@ def get_elapsed_time(self): Note: Use it only when timer is not in progress """ - assert not self._started, 'Timer is still in progress' + assert not self._started, "Timer is still in progress" return self._elapsed def reset(self): - """Clear up the timer and its history - """ + """Clear up the timer and its history""" self._history = [] self._started = False self._elapsed = 0 diff --git a/colossalai/zero/__init__.py b/colossalai/zero/__init__.py index 4991241b8df1..90d0f8de1916 100644 --- a/colossalai/zero/__init__.py +++ b/colossalai/zero/__init__.py @@ -10,6 +10,13 @@ from .wrapper import zero_model_wrapper, zero_optim_wrapper __all__ = [ - 'GeminiDDP', 'GeminiOptimizer', 'GeminiAdamOptimizer', 'zero_model_wrapper', 'zero_optim_wrapper', - 'LowLevelZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx', 'get_static_torch_model' + "GeminiDDP", + "GeminiOptimizer", + "GeminiAdamOptimizer", + "zero_model_wrapper", + "zero_optim_wrapper", + "LowLevelZeroOptimizer", + "ColoInitContext", + "post_process_colo_init_ctx", + "get_static_torch_model", ] diff --git a/colossalai/zero/gemini/__init__.py b/colossalai/zero/gemini/__init__.py index 7ac6a9be4140..358d5c7fd289 100644 --- a/colossalai/zero/gemini/__init__.py +++ b/colossalai/zero/gemini/__init__.py @@ -6,6 +6,15 @@ from .utils import get_static_torch_model __all__ = [ - 'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', 'search_chunk_configuration', 'GeminiDDP', - 'get_static_torch_model', 'GeminiAdamOptimizer', 'GeminiOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx' + "GeminiManager", + "TensorInfo", + "TensorState", + "ChunkManager", + "search_chunk_configuration", + "GeminiDDP", + "get_static_torch_model", + "GeminiAdamOptimizer", + "GeminiOptimizer", + "ColoInitContext", + "post_process_colo_init_ctx", ] diff --git a/colossalai/zero/gemini/chunk/__init__.py b/colossalai/zero/gemini/chunk/__init__.py index 6914d2dbef45..91906f68ad25 100644 --- a/colossalai/zero/gemini/chunk/__init__.py +++ b/colossalai/zero/gemini/chunk/__init__.py @@ -3,4 +3,4 @@ from .search_utils import classify_params_by_dp_degree, search_chunk_configuration from .utils import init_chunk_manager -__all__ = ['Chunk', 'ChunkManager', 'classify_params_by_dp_degree', 'search_chunk_configuration', 'init_chunk_manager'] +__all__ = ["Chunk", "ChunkManager", "classify_params_by_dp_degree", "search_chunk_configuration", "init_chunk_manager"] diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py index 3e7403adb53b..bbef9013c20b 100644 --- a/colossalai/zero/gemini/chunk/chunk.py +++ b/colossalai/zero/gemini/chunk/chunk.py @@ -17,12 +17,17 @@ class TensorState(Enum): READY_FOR_REDUCE = 4 -STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE), - (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE), (TensorState.COMPUTE, - TensorState.HOLD), - (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE), - (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE, - TensorState.HOLD)) +STATE_TRANS = ( + 
(TensorState.FREE, TensorState.HOLD), + (TensorState.FREE, TensorState.COMPUTE), + (TensorState.HOLD, TensorState.FREE), + (TensorState.HOLD, TensorState.COMPUTE), + (TensorState.COMPUTE, TensorState.HOLD), + (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD), + (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE), + (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), + (TensorState.READY_FOR_REDUCE, TensorState.HOLD), +) @dataclass @@ -53,14 +58,16 @@ def alloc_storage(tensor: torch.Tensor) -> None: class Chunk: _total_number = 0 - def __init__(self, - chunk_size: int, - process_group: ProcessGroup, - dtype: torch.dtype, - init_device: Optional[torch.device] = None, - cpu_shard_init: bool = False, - keep_gathered: bool = False, - pin_memory: bool = False) -> None: + def __init__( + self, + chunk_size: int, + process_group: ProcessGroup, + dtype: torch.dtype, + init_device: Optional[torch.device] = None, + cpu_shard_init: bool = False, + keep_gathered: bool = False, + pin_memory: bool = False, + ) -> None: """ Chunk: A container owning a piece of contiguous memory space for tensors Here we use all-gather operation to gather the whole chunk. @@ -99,9 +106,9 @@ def __init__(self, device = init_device or get_current_device() # chunk_temp is a global chunk, which only exists during building the chunks. - self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device) # keep all zero + self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device) # keep all zero - self.cuda_global_chunk = None # we force cuda_global_chunk located in CUDA + self.cuda_global_chunk = None # we force cuda_global_chunk located in CUDA # cuda local chunk, which is sharded on GPUs self.cuda_shard = None @@ -134,7 +141,7 @@ def __init__(self, # they are treated the same as that of the parameters in DDP during training. self.keep_gathered = keep_gathered if self.keep_gathered: - pin_memory = False # since this chunk is gathered, it doesn't need to pin + pin_memory = False # since this chunk is gathered, it doesn't need to pin # if pin_memory is True, we allocate a piece of CPU pin-memory # for it all the time @@ -160,7 +167,7 @@ def memory_usage(self) -> Dict[str, int]: if self.chunk_temp is not None: # this chunk is not closed - if self.chunk_temp.device.type == 'cuda': + if self.chunk_temp.device.type == "cuda": cuda_memory += self.chunk_mem else: cpu_memory += self.chunk_mem @@ -180,11 +187,11 @@ def device_type(self) -> str: return self.chunk_temp.device.type else: if self.is_gathered: - return 'cuda' + return "cuda" elif self.cuda_shard is not None: - return 'cuda' + return "cuda" else: - return 'cpu' + return "cpu" @property def payload(self) -> torch.Tensor: @@ -217,8 +224,10 @@ def can_release(self) -> bool: if self.keep_gathered: return False else: - return self.tensor_state_cnter[TensorState.HOLD] + \ - self.tensor_state_cnter[TensorState.HOLD_AFTER_BWD] == self.num_tensors + return ( + self.tensor_state_cnter[TensorState.HOLD] + self.tensor_state_cnter[TensorState.HOLD_AFTER_BWD] + == self.num_tensors + ) @property def can_reduce(self): @@ -226,27 +235,25 @@ def can_reduce(self): @property def has_inf_or_nan(self) -> bool: - """Check if the chunk has inf or nan values on CUDA. 
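The overflow check above boils down to one fused test over the valid slice of the flat chunk buffer. A standalone torch illustration of the same expression:

```python
import torch

payload = torch.tensor([1.0, float("inf"), 3.0, 0.0])
valid_tensor = payload[:3]                    # only the utilized part of the chunk is checked
has_inf_or_nan = torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
print(has_inf_or_nan)                         # True
```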
- """ + """Check if the chunk has inf or nan values on CUDA.""" if self.is_gathered: - valid_tensor = self.cuda_global_chunk[:self.utilized_size] + valid_tensor = self.cuda_global_chunk[: self.utilized_size] else: - assert self.cuda_shard is not None # only check on CUDA - valid_tensor = self.cuda_shard[:self.valid_end] + assert self.cuda_shard is not None # only check on CUDA + valid_tensor = self.cuda_shard[: self.valid_end] return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item() def set_l2_norm(self) -> None: - """Record l2 norm of this chunks on CUDA. - """ + """Record l2 norm of this chunks on CUDA.""" assert self.l2_norm is None, "you are calculating the l2 norm twice" if self.is_gathered: - valid_tensor = self.cuda_global_chunk[:self.utilized_size] + valid_tensor = self.cuda_global_chunk[: self.utilized_size] else: - assert self.cuda_shard is not None # calculate on CUDA - valid_tensor = self.cuda_shard[:self.valid_end] + assert self.cuda_shard is not None # calculate on CUDA + valid_tensor = self.cuda_shard[: self.valid_end] chunk_l2_norm = valid_tensor.data.float().norm(2) - self.l2_norm = chunk_l2_norm.item()**2 + self.l2_norm = chunk_l2_norm.item() ** 2 def append_tensor(self, tensor: torch.Tensor): """Add a tensor to the chunk. @@ -263,9 +270,9 @@ def append_tensor(self, tensor: torch.Tensor): if new_utilized_size > self.chunk_size: raise ChunkFullError - self.chunk_temp[self.utilized_size:new_utilized_size].copy_(tensor.data.flatten()) + self.chunk_temp[self.utilized_size : new_utilized_size].copy_(tensor.data.flatten()) assert type(self.chunk_temp) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor" - tensor.data = self.chunk_temp[self.utilized_size:new_utilized_size].view(tensor.shape) + tensor.data = self.chunk_temp[self.utilized_size : new_utilized_size].view(tensor.shape) # record all the information about the tensor self.num_tensors += 1 @@ -275,8 +282,7 @@ def append_tensor(self, tensor: torch.Tensor): self.utilized_size = new_utilized_size def close_chunk(self): - """Close the chunk. Any tensor can't be appended to a closed chunk later. - """ + """Close the chunk. 
Any tensor can't be appended to a closed chunk later.""" # sanity check assert self.chunk_temp is not None @@ -286,7 +292,7 @@ def close_chunk(self): elif self.utilized_size < self.shard_end: self.valid_end = self.utilized_size - self.shard_begin - if self.chunk_temp.device.type == 'cpu': + if self.chunk_temp.device.type == "cpu": self.cuda_global_chunk = self.chunk_temp.to(get_current_device()) self.__update_tensors_ptr() else: @@ -298,12 +304,12 @@ def close_chunk(self): if self.keep_gathered: return - if self.pin_memory or self.shard_device.type == 'cpu': + if self.pin_memory or self.shard_device.type == "cpu": self.cpu_shard = torch.empty(self.shard_size, dtype=self.dtype, pin_memory=self.pin_memory) self.cpu_shard.copy_(self.cuda_shard) - self.cpu_vis_flag = True # cpu_shard has been visited + self.cpu_vis_flag = True # cpu_shard has been visited - if self.shard_device.type == 'cpu': + if self.shard_device.type == "cpu": self.cuda_shard = None def shard_move(self, device: torch.device, force_copy: bool = False): @@ -318,12 +324,12 @@ def shard_move(self, device: torch.device, force_copy: bool = False): # when the current chunk is not synchronized with the optimizer # just use another way for the movement if not self.optim_sync_flag: - assert device.type == 'cuda', "each chunk should first be moved to CUDA" + assert device.type == "cuda", "each chunk should first be moved to CUDA" self.__paired_shard_move() self.optim_sync_flag = True return - if device.type == 'cuda': + if device.type == "cuda": assert device == get_current_device(), "can't move chunk to another device" if self.cuda_shard: @@ -333,7 +339,7 @@ def shard_move(self, device: torch.device, force_copy: bool = False): if not self.pin_memory: self.cpu_shard = None - elif device.type == 'cpu': + elif device.type == "cpu": if self.cuda_shard is None: return @@ -350,8 +356,7 @@ def shard_move(self, device: torch.device, force_copy: bool = False): raise NotImplementedError def access_chunk(self): - """Make the chunk usable for the parameters inside it. It's an operation done in CUDA. - """ + """Make the chunk usable for the parameters inside it. It's an operation done in CUDA.""" # sanity check assert self.chunk_temp is None @@ -360,8 +365,7 @@ def access_chunk(self): self.__update_tensors_ptr() def release_chunk(self): - """Release the usable chunk. It's an operation done in CUDA. - """ + """Release the usable chunk. It's an operation done in CUDA.""" # sanity check assert self.chunk_temp is None @@ -369,8 +373,7 @@ def release_chunk(self): self.__scatter() def reduce(self): - """Reduce scatter all the gradients. It's an operation done in CUDA. - """ + """Reduce scatter all the gradients. It's an operation done in CUDA.""" # sanity check assert self.is_gathered @@ -423,20 +426,18 @@ def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Ten assert self.is_gathered tensor_info = self.tensors_info[tensor] - self.cuda_global_chunk[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten()) - tensor.data = self.cuda_global_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape) + self.cuda_global_chunk[tensor_info.offset : tensor_info.end].copy_(data_slice.data.flatten()) + tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape) def get_valid_length(self) -> int: - """Get the valid length of the chunk's payload. 
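Read together, the methods in this hunk define a chunk's lifecycle: tensors are copied into `chunk_temp` via `append_tensor`, `close_chunk` seals and shards the buffer, and during training the chunk is gathered with `access_chunk`, used, and scattered again with `release_chunk` (or reduce-scattered with `reduce` for gradients). A rough ordering sketch, under the assumptions that torch.distributed is initialized, a CUDA device is available, and `Chunk` is exported from this package:

```python
import torch
from torch.distributed.distributed_c10d import _get_default_group

from colossalai.zero.gemini.chunk import Chunk  # assumed export of the package this file lives in

pg = _get_default_group()  # assumes dist.init_process_group(...) was called earlier

chunk = Chunk(chunk_size=1024, process_group=pg, dtype=torch.float16)
p = torch.nn.Parameter(torch.randn(16, 16, dtype=torch.float16))

chunk.append_tensor(p)   # copy the parameter payload into chunk_temp
chunk.close_chunk()      # seal the chunk; it is moved to CUDA and sharded

chunk.access_chunk()     # all-gather: every owned tensor becomes usable again
# ... compute with the parameters owned by this chunk ...
chunk.release_chunk()    # scatter back once all owned tensors are back in HOLD
```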
- """ + """Get the valid length of the chunk's payload.""" if self.keep_gathered: return self.utilized_size else: return self.valid_end - def init_pair(self, friend_chunk: 'Chunk') -> None: - """Initialize the paired chunk. - """ + def init_pair(self, friend_chunk: "Chunk") -> None: + """Initialize the paired chunk.""" if self.paired_chunk is None and friend_chunk.paired_chunk is None: self.paired_chunk = friend_chunk friend_chunk.paired_chunk = self @@ -445,8 +446,7 @@ def init_pair(self, friend_chunk: 'Chunk') -> None: assert friend_chunk.paired_chunk is self def optim_update(self) -> None: - """Update the fp16 chunks via their fp32 chunks. It's used by the optimizer. - """ + """Update the fp16 chunks via their fp32 chunks. It's used by the optimizer.""" # sanity check assert self.paired_chunk is not None @@ -455,15 +455,15 @@ def optim_update(self) -> None: assert friend_chunk.is_gathered is True self.cuda_global_chunk.copy_(friend_chunk.cuda_global_chunk) self.optim_sync_flag = True - elif friend_chunk.device_type == 'cuda' and self.device_type == 'cuda': + elif friend_chunk.device_type == "cuda" and self.device_type == "cuda": self.cuda_shard.copy_(friend_chunk.cuda_shard) self.optim_sync_flag = True self.cpu_vis_flag = False else: # optim_sync_flag is set to False # see shard_move function for more details - assert friend_chunk.device_type == 'cpu' - assert self.device_type == 'cpu' + assert friend_chunk.device_type == "cpu" + assert self.device_type == "cpu" self.optim_sync_flag = False self.cpu_vis_flag = False @@ -492,7 +492,7 @@ def __scatter(self): self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=self.cuda_global_chunk.device) - self.cuda_shard.copy_(self.cuda_global_chunk[self.shard_begin:self.shard_end]) + self.cuda_shard.copy_(self.cuda_global_chunk[self.shard_begin : self.shard_end]) free_storage(self.cuda_global_chunk) self.is_gathered = False @@ -518,7 +518,7 @@ def __update_tensors_ptr(self) -> None: assert type(self.cuda_global_chunk) == torch.Tensor for tensor, tensor_info in self.tensors_info.items(): - tensor.data = self.cuda_global_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape) + tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape) def __update_one_tensor_info(self, tensor_info: TensorInfo, next_state: TensorState): self.tensor_state_cnter[tensor_info.state] -= 1 @@ -539,38 +539,41 @@ def __eq__(self, __o: object) -> bool: def __repr__(self, detailed: bool = True): output = [ "Chunk Information:\n", - "\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(self.chunk_size, self.dtype, - self.pg_size), + "\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format( + self.chunk_size, self.dtype, self.pg_size + ), "\t# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}\n".format( - self.num_tensors, self.utilized_size, self.utilized_size / self.chunk_size) + self.num_tensors, self.utilized_size, self.utilized_size / self.chunk_size + ), ] - def print_tensor(tensor, prefix=''): - output.append("{}shape: {}, dtype: {}, device: {}\n".format(prefix, tensor.shape, tensor.dtype, - tensor.device)) + def print_tensor(tensor, prefix=""): + output.append( + "{}shape: {}, dtype: {}, device: {}\n".format(prefix, tensor.shape, tensor.dtype, tensor.device) + ) if self.chunk_temp is not None: output.append("\tchunk temp:\n") - print_tensor(tensor=self.chunk_temp, prefix='\t\t') + print_tensor(tensor=self.chunk_temp, prefix="\t\t") if self.cuda_global_chunk is not None and 
self.cuda_global_chunk.storage().size() > 0: output.append("\tchunk total:\n") - print_tensor(tensor=self.cuda_global_chunk, prefix='\t\t') + print_tensor(tensor=self.cuda_global_chunk, prefix="\t\t") if self.cuda_shard is not None: output.append("\tcuda shard:\n") - print_tensor(tensor=self.cuda_shard, prefix='\t\t') + print_tensor(tensor=self.cuda_shard, prefix="\t\t") if self.cpu_shard is not None: output.append("\tcpu shard:\n") - print_tensor(tensor=self.cpu_shard, prefix='\t\t') + print_tensor(tensor=self.cpu_shard, prefix="\t\t") memory_info = self.memory_usage - output.append("\tmemory usage: cuda {}, cpu {}\n".format(memory_info['cuda'], memory_info['cpu'])) + output.append("\tmemory usage: cuda {}, cpu {}\n".format(memory_info["cuda"], memory_info["cpu"])) if detailed: output.append("\ttensor state monitor:\n") for st in TensorState: output.append("\t\t# of {}: {}\n".format(st, self.tensor_state_cnter[st])) - return ''.join(output) + return "".join(output) diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py index 1e96234326a9..957e41b02d49 100644 --- a/colossalai/zero/gemini/chunk/manager.py +++ b/colossalai/zero/gemini/chunk/manager.py @@ -20,27 +20,28 @@ class ChunkManager: """ def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None: - self.device = init_device or get_current_device() self.dp_degree_chunk_size_dict: Dict[int, int] = dict() self.kwargs_config = chunk_configuration for k, v in self.kwargs_config.items(): - self.dp_degree_chunk_size_dict[k] = v.pop('chunk_size') - v['init_device'] = self.device + self.dp_degree_chunk_size_dict[k] = v.pop("chunk_size") + v["init_device"] = self.device self.chunk_groups: Dict[str, Deque[Chunk]] = dict() self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict() self.accessed_chunks: Set[Chunk] = set() self.accessed_mem: int = 0 - self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0} - - def register_tensor(self, - tensor: torch.Tensor, - group_type: str, - config_key: int, - process_group: ProcessGroup, - cpu_offload: bool = False, - pin_memory: bool = False) -> None: + self.total_mem: Dict[str, int] = {"cpu": 0, "cuda": 0} + + def register_tensor( + self, + tensor: torch.Tensor, + group_type: str, + config_key: int, + process_group: ProcessGroup, + cpu_offload: bool = False, + pin_memory: bool = False, + ) -> None: """ Register a tensor to the chunk manager. Then, the tensor should be accessed by `get_chunks`. @@ -94,25 +95,22 @@ def register_tensor(self, self.tensor_chunk_map[tensor] = chunk_group[-1] def close_all_groups(self): - """Close all the chunks of all groups. - """ + """Close all the chunks of all groups.""" for group_name in self.chunk_groups: self.__close_one_chunk(self.chunk_groups[group_name][-1]) def access_chunk(self, chunk: Chunk) -> None: - """Make the chunk can be used for calculation. - """ + """Make the chunk can be used for calculation.""" if chunk in self.accessed_chunks: return self.__sub_memory_usage(chunk.memory_usage) - if chunk.device_type == 'cpu': + if chunk.device_type == "cpu": chunk.shard_move(get_current_device()) self.__add_accessed_chunk(chunk) self.__add_memory_usage(chunk.memory_usage) def release_chunk(self, chunk: Chunk) -> None: - """Scatter the chunk in CUDA. 
- """ + """Scatter the chunk in CUDA.""" if chunk not in self.accessed_chunks: return if chunk.can_release: @@ -121,8 +119,7 @@ def release_chunk(self, chunk: Chunk) -> None: self.__add_memory_usage(chunk.memory_usage) def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None: - """Move the shard of the chunk to the target device. - """ + """Move the shard of the chunk to the target device.""" if not chunk.can_move or chunk.device_type == device.type: return self.__sub_memory_usage(chunk.memory_usage) @@ -130,14 +127,12 @@ def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = Fals self.__add_memory_usage(chunk.memory_usage) def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None: - """Transit tensor state according to pre-defined state machine. - """ + """Transit tensor state according to pre-defined state machine.""" chunk = self.tensor_chunk_map[tensor] chunk.tensor_trans_state(tensor, state) def reduce_chunk(self, chunk: Chunk) -> bool: - """Reduce or all reduce the chunk. - """ + """Reduce or all reduce the chunk.""" if not chunk.can_reduce: return False self.__sub_memory_usage(chunk.memory_usage) @@ -213,18 +208,17 @@ def add_extern_static_tensor(self, tensor: torch.Tensor) -> None: def __repr__(self) -> str: msg = [ - 'Chunk Manager Information:\n', - 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n' + "Chunk Manager Information:\n", + "Total memory: " + ", ".join([f"{k}={v}B" for k, v in self.total_mem.items()]) + "\n", ] for group_name, group in self.chunk_groups.items(): - msg.append(f'Group {group_name}:\n') + msg.append(f"Group {group_name}:\n") for i, chunk in enumerate(group): - msg.append(f'[{i}] {chunk}\n') - return ''.join(msg) + msg.append(f"[{i}] {chunk}\n") + return "".join(msg) def __get_chunk_group(self, group_name: str) -> Deque[Chunk]: - """Register a chunk group. 
- """ + """Register a chunk group.""" if group_name not in self.chunk_groups: self.chunk_groups[group_name] = deque() return self.chunk_groups[group_name] diff --git a/colossalai/zero/gemini/chunk/search_utils.py b/colossalai/zero/gemini/chunk/search_utils.py index abaca5f8294d..24d8537bad90 100644 --- a/colossalai/zero/gemini/chunk/search_utils.py +++ b/colossalai/zero/gemini/chunk/search_utils.py @@ -76,8 +76,9 @@ def _tensor_numel(local_param: ColoParameter) -> int: return local_param.numel() -def classify_params_by_dp_degree(param_order: OrderedParamGenerator, - process_group: ProcessGroup) -> Dict[int, List[ColoParameter]]: +def classify_params_by_dp_degree( + param_order: OrderedParamGenerator, process_group: ProcessGroup +) -> Dict[int, List[ColoParameter]]: """classify_params_by_dp_degree Classify the parameters by their dp degree @@ -105,14 +106,15 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator, def search_chunk_configuration( - model: nn.Module, - search_range_m: float, - search_interval: int, # hidden size is the best value for the interval - min_chunk_size_m: float = 32, - filter_exlarge_params: bool = True, - strict_ddp_flag: bool = False, - process_group: Optional[ProcessGroup] = None, - memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]: + model: nn.Module, + search_range_m: float, + search_interval: int, # hidden size is the best value for the interval + min_chunk_size_m: float = 32, + filter_exlarge_params: bool = True, + strict_ddp_flag: bool = False, + process_group: Optional[ProcessGroup] = None, + memstas: Optional[MemStats] = None, +) -> Tuple[Dict, int, int]: """search_chunk_configuration Search the chunk configuration for a model. @@ -168,7 +170,7 @@ def search_chunk_configuration( max_size = max(max_size, max(size_dict[key])) start_size = int(math.ceil(max_size / search_interval) * search_interval) - min_chunk_waste = float('+inf') + min_chunk_waste = float("+inf") best_chunk_size = start_size for chunk_size in range(start_size, start_size + search_range + 1, search_interval): diff --git a/colossalai/zero/gemini/chunk/utils.py b/colossalai/zero/gemini/chunk/utils.py index e98e9cf9c314..7a2ea360650b 100644 --- a/colossalai/zero/gemini/chunk/utils.py +++ b/colossalai/zero/gemini/chunk/utils.py @@ -5,8 +5,6 @@ import torch.distributed as dist import torch.nn as nn -from colossalai.utils import is_ddp_ignored - from .manager import ChunkManager from .search_utils import search_chunk_configuration @@ -17,15 +15,17 @@ def safe_div(a, b): return a / b -def init_chunk_manager(model: nn.Module, - init_device: Optional[torch.device] = None, - hidden_dim: Optional[int] = None, - verbose: bool = False, - **kwargs) -> ChunkManager: +def init_chunk_manager( + model: nn.Module, + init_device: Optional[torch.device] = None, + hidden_dim: Optional[int] = None, + verbose: bool = False, + **kwargs, +) -> ChunkManager: if hidden_dim: search_interval = hidden_dim else: - search_interval = 1024 # defaults to 1024 + search_interval = 1024 # defaults to 1024 kwargs["search_interval"] = search_interval dist.barrier() @@ -41,11 +41,13 @@ def init_chunk_manager(model: nn.Module, wasted_size /= mega_unit if verbose and dist.get_rank() == 0: - print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s), - "used number: {:.2f} * 2^20, wasted number: {:.2f} * 2^20\n".format(total_size, wasted_size), - "total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)), - sep='', - flush=True) + print( + 
"searching chunk configuration is completed in {:.2f} s.\n".format(span_s), + "used number: {:.2f} * 2^20, wasted number: {:.2f} * 2^20\n".format(total_size, wasted_size), + "total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)), + sep="", + flush=True, + ) dist.barrier() chunk_manager = ChunkManager(config_dict, init_device) diff --git a/colossalai/zero/gemini/colo_init_context.py b/colossalai/zero/gemini/colo_init_context.py index 549635af4332..ab2ff8f920aa 100644 --- a/colossalai/zero/gemini/colo_init_context.py +++ b/colossalai/zero/gemini/colo_init_context.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Iterator, Optional, Tuple, Union +from typing import Any, Iterator, Optional, Tuple, Union import torch from torch import nn @@ -12,7 +12,7 @@ def _named_params_with_replica( module: nn.Module, - prefix: str = '', + prefix: str = "", recurse: bool = True, ) -> Iterator[Tuple[str, Union[nn.Parameter, ColoTensor]]]: modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)] @@ -21,16 +21,17 @@ def _named_params_with_replica( for name, val in mod._parameters.items(): if val is None: continue - name = mod_prefix + ('.' if mod_prefix else '') + name + name = mod_prefix + ("." if mod_prefix else "") + name yield name, val -def _convert_to_coloparam(param: torch.nn.Parameter, - device: torch.device, - dtype=torch.float, - default_pg: Optional[ProcessGroup] = None, - default_dist_spec: Optional[Any] = None) -> ColoParameter: - +def _convert_to_coloparam( + param: torch.nn.Parameter, + device: torch.device, + dtype=torch.float, + default_pg: Optional[ProcessGroup] = None, + default_dist_spec: Optional[Any] = None, +) -> ColoParameter: if type(param) is ColoParameter: return param # detaching tensor is necessary for optimizers. @@ -66,12 +67,13 @@ def ColoModulize(module): class ColoInitContext(InsertPostInitMethodToModuleSubClasses): - - def __init__(self, - device: torch.device = torch.device('cpu'), - dtype: torch.dtype = torch.float, - default_pg: Optional[ProcessGroup] = None, - default_dist_spec=None): + def __init__( + self, + device: torch.device = torch.device("cpu"), + dtype: torch.dtype = torch.float, + default_pg: Optional[ProcessGroup] = None, + default_dist_spec=None, + ): """ Args: device (torch.device): the device where parameters initialized are resident. Defaults to torch.device('cpu'). 
@@ -89,6 +91,7 @@ def __init__(self, def _register_colo_modules(self): from colossalai.legacy.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module + register_colo_module(torch.nn.Linear, ColoLinear()) register_colo_module(torch.nn.Embedding, ColoEmbedding()) @@ -105,25 +108,25 @@ def _post_init_method(self, module: torch.nn.Module, *args, **kwargs): if type(param) is ColoParameter: continue - split = name.rfind('.') - if split >= 0: # param in submodule + split = name.rfind(".") + if split >= 0: # param in submodule module_name = name[:split] - param_name = name[split + 1:] + param_name = name[split + 1 :] else: - module_name = '' # param in current module + module_name = "" # param in current module param_name = name name_list.append((module_name, param_name)) - replaced_tensors = dict( - ) # record mapping between (torch.Tensor, ColoTensor) to distinguish the same reference + replaced_tensors = dict() # record mapping between (torch.Tensor, ColoTensor) to distinguish the same reference for module_name, param_name in name_list: submodule = module.get_submodule(module_name) param = submodule.get_parameter(param_name) if param in replaced_tensors: colo_param = replaced_tensors[param] else: - colo_param = _convert_to_coloparam(param, self._device, self._dtype, self._default_pg, - self._default_dist_spec) + colo_param = _convert_to_coloparam( + param, self._device, self._dtype, self._default_pg, self._default_dist_spec + ) replaced_tensors[param] = colo_param delattr(submodule, param_name) setattr(submodule, param_name, colo_param) @@ -136,11 +139,11 @@ def _post_init_method(self, module: torch.nn.Module, *args, **kwargs): for param in module.parameters(): param_number += 1 - meta_param_number += (param.device.type == 'meta') + meta_param_number += param.device.type == "meta" for buffer in module.buffers(): buffer_number += 1 - meta_buffer_number += (buffer.device.type == 'meta') + meta_buffer_number += buffer.device.type == "meta" if meta_param_number > 0 and meta_param_number != param_number: raise ValueError("Meta parameters and valued parameters can not be in the same model") @@ -152,11 +155,13 @@ def _post_init_method(self, module: torch.nn.Module, *args, **kwargs): buffer.data = buffer.data.to(device=self._device) -def post_process_colo_init_ctx(model: torch.nn.Module, - device: torch.device = torch.device('cpu'), - dtype: torch.dtype = torch.float, - default_pg: Optional[ProcessGroup] = None, - default_dist_spec=None): +def post_process_colo_init_ctx( + model: torch.nn.Module, + device: torch.device = torch.device("cpu"), + dtype: torch.dtype = torch.float, + default_pg: Optional[ProcessGroup] = None, + default_dist_spec=None, +): """post_process_colo_init_ctx This function is called after `ColoInitContext`. @@ -178,8 +183,8 @@ def post_process_colo_init_ctx(model: torch.nn.Module, # print(f"{n} is not a ColoParameter. 
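`ColoInitContext` hooks module construction so that parameters created inside it come out as `ColoParameter`s on the requested device and dtype, and `post_process_colo_init_ctx` converts any plain parameter created outside the context after the fact. A hedged sketch of the intended usage; the import path follows this file's location, and using the class as a context manager is inferred from its `InsertPostInitMethodToModuleSubClasses` base:

```python
import torch
import torch.nn as nn

from colossalai.zero.gemini.colo_init_context import (  # path follows this diff's layout
    ColoInitContext,
    post_process_colo_init_ctx,
)

# parameters created inside the context become ColoParameters on CPU in fp32
with ColoInitContext(device=torch.device("cpu"), dtype=torch.float):
    model = nn.Linear(1024, 1024)

# convert anything that slipped through (e.g. parameters created after the context exited)
post_process_colo_init_ctx(model, device=torch.device("cpu"))
```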
We are going to converting it to ColoParameter") torch_params.append((n, p)) - for (n, param) in torch_params: - name_list = n.split('.') + for n, param in torch_params: + name_list = n.split(".") module = model for i in range(len(name_list) - 1): module = module._modules[name_list[i]] diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 918b08cd3150..580b497ce719 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -10,7 +10,7 @@ from torch.distributed import ProcessGroup from torch.distributed.distributed_c10d import _get_default_group -from colossalai.checkpoint_io.utils import StateDictSharder, calculate_tensor_size +from colossalai.checkpoint_io.utils import StateDictSharder from colossalai.interface import ModelWrapper from colossalai.lazy import LazyTensor from colossalai.logging import get_dist_logger @@ -27,10 +27,10 @@ try: from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys except ImportError: - _EXTRA_STATE_KEY_SUFFIX = '_extra_state' + _EXTRA_STATE_KEY_SUFFIX = "_extra_state" __all__ = [ - 'GeminiDDP', + "GeminiDDP", ] @@ -54,27 +54,28 @@ class GeminiDDP(ModelWrapper): """ def __init__( - self, - module: torch.nn.Module, - chunk_config_dict: Optional[dict] = None, - chunk_init_device: torch.device = torch.device('cpu'), - placement_policy: str = "static", - shard_param_frac: float = 1.0, # only for static placement - offload_optim_frac: float = 0.0, # only for static placement - offload_param_frac: float = 0.0, # only for static placement - warmup_non_model_data_ratio: float = 0.8, # only for auto placement - steady_cuda_cap_ratio: float = 0.9, # only for auto placement - search_range_m: int = 32, # chunk search options - hidden_dim: Optional[int] = None, # chunk search options - min_chunk_size_m: float = 32, # chunk search options - pin_memory: bool = False, - force_outputs_fp32: bool = False, - strict_ddp_mode: bool = False, - scatter_after_inference: bool = True, - mixed_precision: torch.dtype = torch.float16, - process_group: Optional[ProcessGroup] = None, - memstats: Optional[MemStats] = None, # genimi memory stats - verbose: bool = False) -> None: + self, + module: torch.nn.Module, + chunk_config_dict: Optional[dict] = None, + chunk_init_device: torch.device = torch.device("cpu"), + placement_policy: str = "static", + shard_param_frac: float = 1.0, # only for static placement + offload_optim_frac: float = 0.0, # only for static placement + offload_param_frac: float = 0.0, # only for static placement + warmup_non_model_data_ratio: float = 0.8, # only for auto placement + steady_cuda_cap_ratio: float = 0.9, # only for auto placement + search_range_m: int = 32, # chunk search options + hidden_dim: Optional[int] = None, # chunk search options + min_chunk_size_m: float = 32, # chunk search options + pin_memory: bool = False, + force_outputs_fp32: bool = False, + strict_ddp_mode: bool = False, + scatter_after_inference: bool = True, + mixed_precision: torch.dtype = torch.float16, + process_group: Optional[ProcessGroup] = None, + memstats: Optional[MemStats] = None, # genimi memory stats + verbose: bool = False, + ) -> None: assert mixed_precision in (torch.float16, torch.bfloat16) if chunk_config_dict is not None: self.chunk_manager = ChunkManager(chunk_config_dict, chunk_init_device) @@ -82,22 +83,26 @@ def __init__( # some ugly hotfix for the compatibility with Lightning if search_range_m is None: search_range_m = 32 - self.chunk_manager = 
init_chunk_manager(model=module, - init_device=chunk_init_device, - hidden_dim=hidden_dim, - search_range_m=search_range_m, - min_chunk_size_m=min_chunk_size_m, - strict_ddp_flag=strict_ddp_mode, - process_group=process_group, - verbose=verbose) - self.gemini_manager = GeminiManager(placement_policy, - self.chunk_manager, - memstats, - shard_param_frac=shard_param_frac, - offload_optim_frac=offload_optim_frac, - offload_param_frac=offload_param_frac, - warmup_non_model_data_ratio=warmup_non_model_data_ratio, - steady_cuda_cap_ratio=steady_cuda_cap_ratio) + self.chunk_manager = init_chunk_manager( + model=module, + init_device=chunk_init_device, + hidden_dim=hidden_dim, + search_range_m=search_range_m, + min_chunk_size_m=min_chunk_size_m, + strict_ddp_flag=strict_ddp_mode, + process_group=process_group, + verbose=verbose, + ) + self.gemini_manager = GeminiManager( + placement_policy, + self.chunk_manager, + memstats, + shard_param_frac=shard_param_frac, + offload_optim_frac=offload_optim_frac, + offload_param_frac=offload_param_frac, + warmup_non_model_data_ratio=warmup_non_model_data_ratio, + steady_cuda_cap_ratio=steady_cuda_cap_ratio, + ) self.force_outputs_fp32 = force_outputs_fp32 self.param_op_hook = GeminiZeROHook(self.gemini_manager) self.fp32_params: List[torch.Tensor] = list() @@ -126,13 +131,15 @@ def __init__( self.param2name[param] = name for m_name, m_var in module.named_modules(): for p_name, p_var in m_var.named_parameters(recurse=False): - param_name = m_name + '.' + p_name if m_name else p_name + param_name = m_name + "." + p_name if m_name else p_name self.name2param[param_name] = p_var - self._init_chunks(param_order=param_order, - strict_ddp_mode=strict_ddp_mode, - cpu_offload=self.gemini_manager.policy_name != 'cuda', - pin_memory=pin_memory) + self._init_chunks( + param_order=param_order, + strict_ddp_mode=strict_ddp_mode, + cpu_offload=self.gemini_manager.policy_name != "cuda", + pin_memory=pin_memory, + ) super().__init__(module) self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module) self._cast_buffers() @@ -146,19 +153,18 @@ def __init__( def parameters(self, recurse: bool = True): return self.module.parameters(recurse) - def named_parameters(self, prefix: str = '', recurse: bool = True): + def named_parameters(self, prefix: str = "", recurse: bool = True): return self.module.named_parameters(prefix, recurse) - def named_buffers(self, prefix: str = '', recurse: bool = True): + def named_buffers(self, prefix: str = "", recurse: bool = True): return self.module.named_buffers(prefix, recurse) def named_children(self): return self.module.named_children() - def named_modules(self, - memo: Optional[Set[torch.nn.Module]] = None, - prefix: str = '', - remove_duplicate: bool = True): + def named_modules( + self, memo: Optional[Set[torch.nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True + ): return self.module.named_modules(memo, prefix, remove_duplicate) @staticmethod @@ -184,11 +190,9 @@ def unwrap(self): # as save/load state dict is overwrited, only return self return self - def _get_non_persistent_buffers_set(self, - module, - memo: Optional[Set[nn.Module]] = None, - prefix: str = '', - remove_duplicate: bool = True): + def _get_non_persistent_buffers_set( + self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True + ): r""" Args: memo: a memo to store the set of modules already added to the result @@ -204,19 +208,20 @@ def _get_non_persistent_buffers_set(self, if remove_duplicate: 
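The constructor above either takes an explicit `chunk_config_dict` or runs the chunk-size search via `init_chunk_manager`, then hands the chunk manager to a `GeminiManager` that implements the chosen placement policy. A hedged construction sketch using only keyword arguments that appear in the signature; the `colossalai.zero` import path is an assumption, and an initialized distributed backend plus a CUDA device are required since the default process group is used when `process_group` is None:

```python
import torch
import torch.nn as nn

from colossalai.zero import GeminiDDP  # assumed export; the class is defined in gemini_ddp.py above

# assumes colossalai.launch / torch.distributed has been initialized beforehand
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))

gemini_model = GeminiDDP(
    model,
    chunk_init_device=torch.device("cpu"),
    placement_policy="static",
    shard_param_frac=1.0,        # static placement only: keep all parameter shards distributed
    offload_optim_frac=0.0,
    pin_memory=True,
    mixed_precision=torch.float16,
    hidden_dim=1024,             # forwarded to the chunk-size search
)
```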
memo.add(module) self_non_persistent_set = set( - map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set)) + map(lambda key: prefix + ("." if prefix else "") + key, module._non_persistent_buffers_set) + ) for name, sub_module in module._modules.items(): if sub_module is None: continue - submodule_prefix = prefix + ('.' if prefix else '') + name - child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix, - remove_duplicate) + submodule_prefix = prefix + ("." if prefix else "") + name + child_non_persistent_set = self._get_non_persistent_buffers_set( + sub_module, memo, submodule_prefix, remove_duplicate + ) self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set) return self_non_persistent_set def _post_forward(self): - """This function is only triggered for inference. - """ + """This function is only triggered for inference.""" access_list = list(self.chunk_manager.accessed_chunks) # we need to scatter all accessed chunks and move them to their original places for chunk in access_list: @@ -233,7 +238,8 @@ def forward(self, *args, **kwargs): # check whether we are in a inference mode grad_flag = torch.is_grad_enabled() if not grad_flag: - assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup( + assert ( + not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup() ), "You should run a completed iteration as your warmup iter" args, kwargs = _cast_float(args, self.mixed_precision), _cast_float(kwargs, self.mixed_precision) @@ -250,8 +256,7 @@ def forward(self, *args, **kwargs): return outputs def _inference_forward(self, *args, **kwargs): - """This function is only triggered for inference. - """ + """This function is only triggered for inference.""" fwd_ctx = ColoParamOpHookManager.use_hooks(self.param_op_hook) if not self.scatter_after_inference: # gather all chunks @@ -287,12 +292,14 @@ def _post_backward(self): if not is_ddp_ignored(param) and not getattr(param, "_gemini_reduced"): error_params.append(self.param2name[param]) error_str = "\n\t".join(error_params) - raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.", - "The most possible reason is that the model is not compatible with GeminiDDP.\n", - f"{error_str}") + raise RuntimeError( + "ZERO DDP error: the synchronization of gradients doesn't exit properly.", + "The most possible reason is that the model is not compatible with GeminiDDP.\n", + f"{error_str}", + ) self._setup_grads_ptr() self._logger.debug( - f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}' + f"comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}" ) self.gemini_manager.post_iter() @@ -314,8 +321,10 @@ def grad_handle(self, p, grad): with torch._C.DisableTorchFunction(): chunk = self.chunk_manager.get_chunk(p) if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD: - raise RuntimeError(f"Parameter `{self.param2name[p]}` failed at the gradient reduction. 
" - "Some unsupported torch function is operated upon this parameter.") + raise RuntimeError( + f"Parameter `{self.param2name[p]}` failed at the gradient reduction. " + "Some unsupported torch function is operated upon this parameter." + ) self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE) chunk.copy_tensor_to_chunk_slice(p, grad) reduced = self.chunk_manager.reduce_chunk(chunk) @@ -339,12 +348,9 @@ def set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None: for tensor in chunk.get_tensors(): self.grads_device[tensor] = device - def state_dict(self, - destination=None, - prefix='', - keep_vars=False, - only_rank_0: bool = True, - dtype: torch.dtype = torch.float16): + def state_dict( + self, destination=None, prefix="", keep_vars=False, only_rank_0: bool = True, dtype: torch.dtype = torch.float16 + ): """Returns a dictionary containing a whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. @@ -391,7 +397,7 @@ def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool, dtype: torch. record_tensor = torch.empty([0]) record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0) if record_flag: - record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu() + record_tensor = temp_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape).cpu() assert tensor not in chunk_to_save_data chunk_to_save_data[tensor] = record_tensor @@ -399,8 +405,9 @@ def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool, dtype: torch. del temp_chunk return chunk_to_save_data - def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool, - dtype: torch.dtype) -> Dict: + def _get_param_to_save_data( + self, param_list: List[torch.nn.Parameter], only_rank_0: bool, dtype: torch.dtype + ) -> Dict: """ get param content from chunks. @@ -459,11 +466,13 @@ def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True, destination[prefix + name] = buf if keep_vars else buf.detach() # save extra states extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX - if getattr(self.__class__, "get_extra_state", - torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: + if ( + getattr(self.__class__, "get_extra_state", torch.nn.Module.get_extra_state) + is not torch.nn.Module.get_extra_state + ): destination[extra_state_key] = self.get_extra_state() - def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True): + def load_state_dict(self, state_dict: "OrderedDict[str, torch.Tensor]", strict: bool = True): r"""Copies parameters and buffers from :attr:`state_dict` into this module and its descendants. 
If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned @@ -491,32 +500,38 @@ def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: error_msgs: List[str] = [] # copy state_dict so _load_from_state_dict can modify it - metadata = getattr(state_dict, '_metadata', None) + metadata = getattr(state_dict, "_metadata", None) state_dict = state_dict.copy() if metadata is not None: # mypy isn't aware that "_metadata" exists in state_dict - state_dict._metadata = metadata # type: ignore[attr-defined] + state_dict._metadata = metadata # type: ignore[attr-defined] - prefix = '' + prefix = "" local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) self._load_from_state_dict(state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) if strict: if len(unexpected_keys) > 0: error_msgs.insert( - 0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join( - '"{}"'.format(k) for k in unexpected_keys))) + 0, + "Unexpected key(s) in state_dict: {}. ".format( + ", ".join('"{}"'.format(k) for k in unexpected_keys) + ), + ) if len(missing_keys) > 0: error_msgs.insert( - 0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys))) + 0, "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys)) + ) if len(error_msgs) > 0: - raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( - self.__class__.__name__, "\n\t".join(error_msgs))) + raise RuntimeError( + "Error(s) in loading state_dict for {}:\n\t{}".format(self.__class__.__name__, "\n\t".join(error_msgs)) + ) return _IncompatibleKeys(missing_keys, unexpected_keys) - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs): + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): r"""Copies parameters and buffers from :attr:`state_dict` into only this module, but not its descendants. This is called on every submodule in :meth:`~torch.nn.Module.load_state_dict`. 
Metadata saved for this @@ -564,19 +579,21 @@ def load(param_name, dest_tensor, copy_func): input_param = input_param[0] if input_param.shape != dest_tensor.shape: # local shape should match the one in checkpoint - error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' - 'the shape in current model is {}.'.format(state_key, input_param.shape, - dest_tensor.shape)) + error_msgs.append( + "size mismatch for {}: copying a param with shape {} from checkpoint, " + "the shape in current model is {}.".format(state_key, input_param.shape, dest_tensor.shape) + ) return try: with torch.no_grad(): copy_func(input_param) except Exception as ex: - error_msgs.append('While copying the parameter named "{}", ' - 'whose dimensions in the model are {} and ' - 'whose dimensions in the checkpoint are {}, ' - 'an exception occurred : {}.'.format(state_key, dest_tensor.size(), - input_param.size(), ex.args)) + error_msgs.append( + 'While copying the parameter named "{}", ' + "whose dimensions in the model are {} and " + "whose dimensions in the checkpoint are {}, " + "an exception occurred : {}.".format(state_key, dest_tensor.size(), input_param.size(), ex.args) + ) elif strict: missing_keys.append(state_key) @@ -600,15 +617,15 @@ def load_fp32_parameter(chunk_slice, data): for tensor, tensor_info in chunk.tensors_info.items(): parameter_name = fp32_to_name[tensor] - parameter_slice = temp_chunk[tensor_info.offset:tensor_info.end] + parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end] load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice)) if chunk.is_gathered: chunk.cuda_global_chunk.copy_(temp_chunk) elif chunk.cuda_shard is not None: - chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end]) + chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin : chunk.shard_end]) else: - chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end]) + chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin : chunk.shard_end]) del temp_chunk @@ -622,8 +639,10 @@ def load_fp32_parameter(chunk_slice, data): load(name, buf, buf.copy_) extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX - if getattr(self.__class__, "set_extra_state", - torch.nn.Module.set_extra_state) is not torch.nn.Module.set_extra_state: + if ( + getattr(self.__class__, "set_extra_state", torch.nn.Module.set_extra_state) + is not torch.nn.Module.set_extra_state + ): if extra_state_key in state_dict: self.set_extra_state(state_dict[extra_state_key]) elif strict: @@ -634,7 +653,7 @@ def load_fp32_parameter(chunk_slice, data): if strict: for key in state_dict.keys(): if key.startswith(prefix) and key != extra_state_key: - input_name = key[len(prefix):] + input_name = key[len(prefix) :] if input_name not in local_state: unexpected_keys.append(key) @@ -659,18 +678,22 @@ def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pi p.data = p.data.to(self.mixed_precision) # register the fp16 parameter and fp32 parameter in the chunk manager - self.chunk_manager.register_tensor(tensor=p, - group_type='fp16_param', - config_key=dp_world_size, - process_group=self.dp_process_group, - cpu_offload=cpu_offload, - pin_memory=pin_memory) - self.chunk_manager.register_tensor(tensor=fp32_p, - group_type='fp32_param', - config_key=dp_world_size, - process_group=self.dp_process_group, - cpu_offload=cpu_offload, - pin_memory=pin_memory) + self.chunk_manager.register_tensor( + tensor=p, + group_type="fp16_param", + config_key=dp_world_size, + process_group=self.dp_process_group, 
+ cpu_offload=cpu_offload, + pin_memory=pin_memory, + ) + self.chunk_manager.register_tensor( + tensor=fp32_p, + group_type="fp32_param", + config_key=dp_world_size, + process_group=self.dp_process_group, + cpu_offload=cpu_offload, + pin_memory=pin_memory, + ) self.fp16_params.append(p) self.fp32_params.append(fp32_p) @@ -694,7 +717,7 @@ def _cast_buffers(self): if torch.is_floating_point(buffer): buffer.data = buffer.to(self.mixed_precision) - def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, 'LazyTensor']) -> None: + def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, "LazyTensor"]) -> None: """Convert parameter to ColoParameter in-place. Args: p (Union[nn.Parameter, ColoParameter, LazyTensor]): parameter to be converted @@ -709,12 +732,14 @@ def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, 'LazyTensor']) p.__class__ = ColoParameter p.__init__(p, requires_grad=requires_grad) - def state_dict_shard(self, - prefix: str = '', - keep_vars: bool = False, - max_shard_size: int = 1024, - only_rank_0: bool = True, - dtype: torch.dtype = torch.float16) -> Iterator[Tuple[OrderedDict, int]]: + def state_dict_shard( + self, + prefix: str = "", + keep_vars: bool = False, + max_shard_size: int = 1024, + only_rank_0: bool = True, + dtype: torch.dtype = torch.float16, + ) -> Iterator[Tuple[OrderedDict, int]]: """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``. Both parameters and persistent buffers (e.g. running averages) are included. @@ -770,8 +795,10 @@ def state_dict_shard(self, yield block, block_size # save extra states extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX - if getattr(self.__class__, "get_extra_state", - torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: + if ( + getattr(self.__class__, "get_extra_state", torch.nn.Module.get_extra_state) + is not torch.nn.Module.get_extra_state + ): extra_state = self.get_extra_state() block, block_size = sharder.append_param(extra_state_key, extra_state) if block is not None: diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index dbc2924858e6..480a14511b69 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -17,7 +17,6 @@ class TrainingPhase(Enum): class GeminiZeROHook(ColoParamOpHook): - def __init__(self, gemini_manager: GeminiManager) -> None: super().__init__() self._gemini_manager = gemini_manager @@ -40,7 +39,11 @@ def pre_op(self, params): def post_op(self, params): params = [p for p in params if not is_ddp_ignored(p)] for p in params: - tensor_state = TensorState.HOLD if self._training_phase == TrainingPhase.FORWARD or not p.requires_grad else TensorState.HOLD_AFTER_BWD + tensor_state = ( + TensorState.HOLD + if self._training_phase == TrainingPhase.FORWARD or not p.requires_grad + else TensorState.HOLD_AFTER_BWD + ) self._chunk_manager.trans_tensor_state(p, tensor_state) def pre_forward(self, params: List[torch.Tensor]) -> None: diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py index b8e4717908f7..f7ff3f6cdd86 100644 --- a/colossalai/zero/gemini/gemini_mgr.py +++ b/colossalai/zero/gemini/gemini_mgr.py @@ -26,12 +26,13 @@ class GeminiManager: memstats (MemStats, optional): a mem stats collected by a runtime mem tracer. if None then GeminiManager will collect it during a warmup iteration. 
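`state_dict_shard` above yields `(shard, size)` pairs so a large model can be checkpointed without materializing the whole state dict at once. A hedged helper sketch; the `max_shard_size` value is illustrative, the `colossalai.zero` import path is an assumption, and with `only_rank_0` left at its default only rank 0 receives real data:

```python
import torch
import torch.distributed as dist

from colossalai.zero import GeminiDDP  # assumed export


def save_sharded(gemini_model: GeminiDDP, out_prefix: str = "model_shard") -> None:
    """Stream the state dict out in bounded-size blocks instead of one huge dict."""
    for idx, (shard, shard_size) in enumerate(gemini_model.state_dict_shard(max_shard_size=1024)):
        # only_rank_0=True (the default) gathers full values on rank 0 only
        if dist.get_rank() == 0:
            torch.save(shard, f"{out_prefix}_{idx}.bin")
```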
""" - def __init__(self, - placement_policy: str, - chunk_manager: ChunkManager, - memstats: Optional[MemStats] = None, - **placement_kwargs) -> None: - + def __init__( + self, + placement_policy: str, + chunk_manager: ChunkManager, + memstats: Optional[MemStats] = None, + **placement_kwargs, + ) -> None: assert placement_policy in PlacementPolicyFactory.get_policy_names() self.policy_name = placement_policy policy_cls = PlacementPolicyFactory.create(placement_policy) @@ -39,8 +40,9 @@ def __init__(self, self._premade_memstats_ = memstats is not None self._memstats = memstats - self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager, - self._memstats) if policy_cls.need_mem_stats else None + self._mem_stats_collector = ( + ChunkMemStatsCollector(chunk_manager, self._memstats) if policy_cls.need_mem_stats else None + ) self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector, **placement_kwargs) self._compute_list: List[Tuple[Chunk, ...]] = [] self._compute_idx: int = -1 @@ -62,7 +64,7 @@ def reset_attributes(self): @property def need_warmup(self) -> bool: - return self.policy_name in ('auto', 'const') + return self.policy_name in ("auto", "const") def is_warmup(self): return self._warmup @@ -85,15 +87,14 @@ def pre_iter(self, *args): self._mem_stats_collector.start_collection() def post_iter(self): - """This function must be called when each iteration finishes - """ + """This function must be called when each iteration finishes""" if self._mem_stats_collector and self._warmup: self._mem_stats_collector.finish_collection() self._warmup = False self.reset_attributes() def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None: - """ Adjust the layout of stateful tensors according to the information provided + """Adjust the layout of stateful tensors according to the information provided by mem_stats_collector, which should belongs to a Sharded Model. 
""" # find stateful tensor in state COMPUTE @@ -102,11 +103,13 @@ def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None: cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks) self._layout_time += time() - start - vol, evict_time = self._placement_policy.evict_tensors(can_evict_chunks=hold_cuda_tensor_list, - cuda_demand=cuda_demand, - warmup=self._warmup, - compute_list=self._compute_list, - compute_idx=self._compute_idx) + vol, evict_time = self._placement_policy.evict_tensors( + can_evict_chunks=hold_cuda_tensor_list, + cuda_demand=cuda_demand, + warmup=self._warmup, + compute_list=self._compute_list, + compute_idx=self._compute_idx, + ) self._d2h_volume += vol self._evict_time += evict_time @@ -118,12 +121,12 @@ def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, start = time() cuda_demand = 0 for chunk in chunks: - if chunk.device_type == 'cuda': + if chunk.device_type == "cuda": if chunk.is_gathered: pass else: cuda_demand += chunk.chunk_mem - chunk.shard_mem - elif chunk.device_type == 'cpu': + elif chunk.device_type == "cpu": cuda_demand += chunk.chunk_mem else: raise RuntimeError @@ -159,6 +162,7 @@ def cuda_margin_mem(self) -> Optional[float]: def is_cuda_margin_mem_avail(self) -> bool: return self._placement_policy.need_mem_stats - def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, - torch.device]) -> None: + def setup_grads_device( + self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, torch.device] + ) -> None: self._placement_policy.setup_grads_device(params, grads_device_map) diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 0c593deff225..d785eda2dc12 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -10,34 +10,35 @@ from torch.optim import Optimizer from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin -from colossalai.checkpoint_io.utils import calculate_tensor_size, StateDictSharder +from colossalai.checkpoint_io.utils import StateDictSharder from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam -from colossalai.tensor.d_tensor import is_distributed_tensor from colossalai.utils import disposable, get_current_device, is_ddp_ignored from .chunk import Chunk, ChunkManager from .gemini_ddp import GeminiDDP -__all__ = ['GeminiOptimizer', 'GeminiAdamOptimizer'] +__all__ = ["GeminiOptimizer", "GeminiAdamOptimizer"] _AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam} class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): - - def __init__(self, - module: GeminiDDP, - initial_scale: float = 2**16, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: float = 2**32) -> None: - super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, - max_scale) + def __init__( + self, + module: GeminiDDP, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + ) -> None: + super().__init__( + initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale + ) 
self.module = module def check_local_overflow(self) -> bool: @@ -77,25 +78,28 @@ class GeminiOptimizer(OptimizerWrapper): verbose (bool, optional): Whether to print verbose information, including grad overflow info. Defaults to False. """ - def __init__(self, - optim: Optimizer, - module: GeminiDDP, - gpu_margin_mem_ratio: float = 0.0, - initial_scale: float = 2**32, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: float = 2**32, - max_norm: float = 0.0, - norm_type: float = 2.0, - verbose: bool = False, - **defaults: Any): + def __init__( + self, + optim: Optimizer, + module: GeminiDDP, + gpu_margin_mem_ratio: float = 0.0, + initial_scale: float = 2**32, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0.0, + norm_type: float = 2.0, + verbose: bool = False, + **defaults: Any, + ): super().__init__(optim) assert isinstance(module, GeminiDDP) - assert type(optim) in _AVAIL_OPTIM_LIST, "You should use an optimizer in the available list:\n" \ - f"{_AVAIL_OPTIM_LIST}" + assert type(optim) in _AVAIL_OPTIM_LIST, ( + "You should use an optimizer in the available list:\n" f"{_AVAIL_OPTIM_LIST}" + ) self.module = module self.gemini_manager = module.gemini_manager self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager @@ -118,8 +122,10 @@ def __init__(self, for name, param in module.named_parameters(): if is_ddp_ignored(param): if param.requires_grad: - warnings.warn(f"Parameter `{name}` is ignored by DDP but requires gradient! " - "You should handle its optimizer update by yourself!") + warnings.warn( + f"Parameter `{name}` is ignored by DDP but requires gradient! " + "You should handle its optimizer update by yourself!" 
+ ) else: ddp_param_list.append(param) @@ -132,14 +138,16 @@ def __init__(self, self.__init__optimizer() if module.mixed_precision is torch.float16: - self.mix_precision_mixin = GeminiFP16MixedPrecisionMixin(module, - initial_scale=initial_scale, - min_scale=min_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - max_scale=max_scale) + self.mix_precision_mixin = GeminiFP16MixedPrecisionMixin( + module, + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + ) elif module.mixed_precision is torch.bfloat16: self.mix_precision_mixin = BF16MixedPrecisionMixin() else: @@ -148,12 +156,15 @@ def __init__(self, self._logger = get_dist_logger() self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio) - assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0' + assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f"gpu_margin_mem_ratio must >=0.0 and <=1.0" # Only move fp32 shards from CPU to GPU when user allows and inner optimizer is valid # Inner optimizer must support optimizing hybrid (CPU and CUDA) tensors, # and it must set `num_fp32_shards_per_param` correctly - self._should_move_fp32_params_h2d: bool = self.gemini_manager.is_cuda_margin_mem_avail and self.gpu_margin_mem_ratio > 0.0 and getattr( - optim, 'num_fp32_shards_per_param', 0) >= 2 + self._should_move_fp32_params_h2d: bool = ( + self.gemini_manager.is_cuda_margin_mem_avail + and self.gpu_margin_mem_ratio > 0.0 + and getattr(optim, "num_fp32_shards_per_param", 0) >= 2 + ) if self.gpu_margin_mem_ratio > 0.0 and not self.gemini_manager.is_cuda_margin_mem_avail: self._logger.warning(f'gpu_margin_mem_ratio is meaningless when placement_policy is not "auto"', ranks=[0]) @@ -161,7 +172,7 @@ def __init__(self, def _set_grad_ptr(self): for group in self.param_groups: - for fake_param in group['params']: + for fake_param in group["params"]: chunk32 = self.param_to_chunk32[fake_param] begin, end = self.param_to_range[fake_param] chunk16 = chunk32.paired_chunk @@ -173,7 +184,7 @@ def _set_grad_ptr(self): def _update_fp16_params(self): none_tensor = torch.empty([0]) for group in self.param_groups: - for fake_param in group['params']: + for fake_param in group["params"]: assert fake_param.grad is None fake_param.data = none_tensor.to(fake_param.device) @@ -198,7 +209,7 @@ def _calc_global_norm(self) -> float: group_to_norm[c16.torch_pg] = 0.0 group_to_norm[c16.torch_pg] += c16.l2_norm - c16.l2_norm = None # clear l2 norm + c16.l2_norm = None # clear l2 norm comm_buffer = torch.zeros(1, dtype=torch.float, device=get_current_device()) for group, part_norm in group_to_norm.items(): @@ -230,9 +241,9 @@ def step(self, *args, **kwargs): if self.mix_precision_mixin.should_skip_step(): if self.verbose: - self._logger.info(f'Found overflow. Skip step') - self._clear_global_norm() # clear recorded norm - self.zero_grad() # reset all gradients + self._logger.info(f"Found overflow. 
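The optimizer above wraps one of the allowed Adam variants, manages fp16 loss scaling through the mixed-precision mixin, and skips the update (clearing gradients and refreshing fp16 params from fp32) when an overflow is detected. A hedged sketch of wiring it into a training step; the `colossalai.zero` import path and the `backward()` / `zero_grad()` wrappers (inherited from `OptimizerWrapper`, not shown in this hunk) are assumptions:

```python
import torch

from colossalai.nn.optimizer import HybridAdam           # import mirrored from this hunk
from colossalai.zero import GeminiDDP, GeminiOptimizer   # assumed exports


def build_optimizer(gemini_model: GeminiDDP) -> GeminiOptimizer:
    base = HybridAdam(gemini_model.parameters(), lr=1e-3)
    return GeminiOptimizer(
        base,
        gemini_model,
        initial_scale=2**16,       # fp16 loss-scaling knobs from the signature above
        max_norm=1.0,              # clip gradients by global norm
        gpu_margin_mem_ratio=0.0,  # >0 only takes effect under the "auto" placement policy
    )


def train_step(gemini_model: GeminiDDP, optimizer: GeminiOptimizer, data: torch.Tensor) -> None:
    loss = gemini_model(data).sum()
    # assumed wrapper methods: backward() applies loss scaling, step() skips the
    # update and resets gradients when the scaler reports an overflow
    optimizer.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
```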
Skip step") + self._clear_global_norm() # clear recorded norm + self.zero_grad() # reset all gradients self._update_fp16_params() return @@ -269,11 +280,11 @@ def _maybe_move_fp32_params(self): fp32_params_used_cuda_margin_mem = 0 for group in self.param_groups: - for fake_param in group['params']: + for fake_param in group["params"]: chunk32 = self.param_to_chunk32[fake_param] chunk16 = chunk32.paired_chunk - if chunk32.device_type == 'cuda': + if chunk32.device_type == "cuda": continue if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem: @@ -284,9 +295,9 @@ def _maybe_move_fp32_params(self): fp32_params_used_cuda_margin_mem += chunk32.payload_mem for group in self.param_groups: - for fake_param in group['params']: + for fake_param in group["params"]: chunk32 = self.param_to_chunk32[fake_param] - if chunk32.device_type == 'cuda': + if chunk32.device_type == "cuda": state = self.optim.state[fake_param] for k, v in state.items(): if isinstance(v, torch.Tensor): @@ -294,14 +305,13 @@ def _maybe_move_fp32_params(self): def _register_states_(self): for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: state = self.optim.state[p] for val in state.values(): if isinstance(val, torch.Tensor): self.chunk_manager.add_extern_static_tensor(val) def __init__optimizer(self): - def get_range_pair(local_chunk: Chunk, local_param: Parameter): param_info = local_chunk.tensors_info[local_param] if local_chunk.keep_gathered: @@ -313,10 +323,9 @@ def get_range_pair(local_chunk: Chunk, local_param: Parameter): param_id = -1 for group in self.optim.param_groups: fake_params_list = list() - group_backup = {k: v for k, v in group.items() if k != 'params'} + group_backup = {k: v for k, v in group.items() if k != "params"} group_ids = [] - for param in group['params']: - + for param in group["params"]: # Record the mapping of id to current param. param_id += 1 self.id_to_real_params[param_id] = param @@ -337,12 +346,12 @@ def get_range_pair(local_chunk: Chunk, local_param: Parameter): fake_params_list.append(fake_param) # Update self.optim.param_groups as well as backup group. - group['params'] = fake_params_list - group_backup['params'] = group_ids + group["params"] = fake_params_list + group_backup["params"] = group_ids self.param_groups_backup.append(group_backup) def get_offsets(self, param_id: int) -> tuple: - ''' + """ Args: param_id(int): The id of parameter. @@ -351,7 +360,7 @@ def get_offsets(self, param_id: int) -> tuple: shard_offset(int): Offset of its optimizer state shard relative to the whole optimizer state. shard_size(int): Length of parameter shard owned by current process. - ''' + """ if param_id not in self.id_to_fake_params: return -1, -1, -1 @@ -425,11 +434,11 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: if is_collector: states = self.optim.state[fake_param] for state_name in state_names: - if state_name == 'step': + if state_name == "step": # To keep aligned with pytorch, state 'step' is stored as a pytorch tensor with type float32. 
- collected_states[state_name] = torch.tensor(states['step'], - dtype=torch.float32, - requires_grad=False).cpu() + collected_states[state_name] = torch.tensor( + states["step"], dtype=torch.float32, requires_grad=False + ).cpu() else: state_tensor = states[state_name].detach().clone().to(torch.float32).cpu() collected_states[state_name] = torch.reshape(state_tensor, param.shape) @@ -441,12 +450,13 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: # Collector gets prepared for state collecting. if is_collector: for state_name in state_names: - if state_name == 'step': + if state_name == "step": # To keep aligned with pytorch, state 'step' is stored as a pytorch tensor with type float32. collected_states[state_name] = torch.tensor(0.0, dtype=torch.float32, requires_grad=False).cpu() else: - collected_states[state_name] = torch.zeros(param.numel(), dtype=torch.float32, - requires_grad=False).cpu() + collected_states[state_name] = torch.zeros( + param.numel(), dtype=torch.float32, requires_grad=False + ).cpu() # Materials for gathering, including compacted state tensors, and the offset of shard inside each state. compacted_states = self.pack_optimizer_states_to_tensor(param_id, state_names) if own_param else None @@ -465,8 +475,9 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: shard_size = state_shard[2] if compacted_states is None: continue - self.load_from_compacted_states(compacted_states, collected_states, state_names, shard_offset, - shard_size) + self.load_from_compacted_states( + compacted_states, collected_states, state_names, shard_offset, shard_size + ) # Reshape tensors if is_collector: @@ -476,14 +487,16 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: return collected_states - def pack_optimizer_states_to_tensor(self, - param_id: int, - state_names: list, - device: torch.device = torch.device('cuda'), - dtype: torch.dtype = torch.float32) -> torch.Tensor: - ''' + def pack_optimizer_states_to_tensor( + self, + param_id: int, + state_names: list, + device: torch.device = torch.device("cuda"), + dtype: torch.dtype = torch.float32, + ) -> torch.Tensor: + """ With param id given, pack its optimizer states into a compact tensor and return. - ''' + """ if param_id not in self.id_to_fake_params: return None @@ -493,7 +506,7 @@ def pack_optimizer_states_to_tensor(self, shard_size = param_range[1] - param_range[0] compacted_size = 0 for name in state_names: - if name == 'step': + if name == "step": compacted_size += 1 else: compacted_size += shard_size @@ -502,7 +515,7 @@ def pack_optimizer_states_to_tensor(self, next_state_offset = 0 for state_name, state_tensor in states.items(): # State 'step' needs special operation. 
- if state_name == 'step': + if state_name == "step": if isinstance(state_tensor, torch.Tensor): compacted_states[next_state_offset] = state_tensor[0].item() else: @@ -511,47 +524,53 @@ def pack_optimizer_states_to_tensor(self, next_state_offset += 1 else: assert state_tensor.numel() == shard_size - compacted_states[next_state_offset:next_state_offset + shard_size].copy_(state_tensor) + compacted_states[next_state_offset : next_state_offset + shard_size].copy_(state_tensor) next_state_offset += shard_size return compacted_states - def load_from_compacted_states(self, compacted_states: torch.Tensor, collected_states: dict, state_names: list, - shard_start: int, shard_size: int): - ''' + def load_from_compacted_states( + self, + compacted_states: torch.Tensor, + collected_states: dict, + state_names: list, + shard_start: int, + shard_size: int, + ): + """ Given a tensor carrying compacted optimizer states, update these states to collected_states. - ''' + """ shard_end = shard_start + shard_size next_state_offset = 0 for state_name in state_names: - if state_name == 'step': - collected_states['step'].data = torch.tensor(compacted_states[next_state_offset].item(), - dtype=torch.float32, - requires_grad=False).cpu() + if state_name == "step": + collected_states["step"].data = torch.tensor( + compacted_states[next_state_offset].item(), dtype=torch.float32, requires_grad=False + ).cpu() next_state_offset += 1 else: target_segment = collected_states[state_name][shard_start:shard_end] - target_segment.copy_(compacted_states[next_state_offset:next_state_offset + shard_size]) + target_segment.copy_(compacted_states[next_state_offset : next_state_offset + shard_size]) next_state_offset += shard_size def get_param_groups_for_saving(self) -> list: - ''' + """ Return the param_groups in Pytorch format when saving to checkpoint. - ''' + """ param_groups = copy.deepcopy(self.param_groups_backup) # To be compatible with pytorch checkpointing, # store extra hyperparameters used by pytorch Adam optimizer. torch_special_hyperparameters = { - 'amsgrad': False, - 'maximize': False, - 'foreach': None, - 'capturable': False, - 'differentiable': False, - 'fused': False + "amsgrad": False, + "maximize": False, + "foreach": None, + "capturable": False, + "differentiable": False, + "fused": False, } for group in param_groups: @@ -580,13 +599,13 @@ def state_dict(self, only_rank_0: bool = True) -> dict: so it should be called only when memory resources are abundant. """ state_dict = {} - state_dict['param_groups'] = self.get_param_groups_for_saving() + state_dict["param_groups"] = self.get_param_groups_for_saving() # Collect optimizer states. 
- state_dict['state'] = dict() + state_dict["state"] = dict() for param_id in self.id_to_real_params.keys(): dist.barrier() - state_dict['state'][param_id] = self.collect_states(param_id=param_id, only_rank_0=only_rank_0) + state_dict["state"][param_id] = self.collect_states(param_id=param_id, only_rank_0=only_rank_0) return state_dict def load_param_groups(self, saved_param_groups: list): @@ -601,13 +620,13 @@ def load_param_groups(self, saved_param_groups: list): for group in saved_param_groups: fake_params_list = list() - updated_group = {k: v for k, v in group.items() if k != 'params'} - for param_id in group['params']: + updated_group = {k: v for k, v in group.items() if k != "params"} + for param_id in group["params"]: if param_id not in self.id_to_fake_params: continue fake_param = self.id_to_fake_params[param_id] fake_params_list.append(fake_param) - updated_group['params'] = fake_params_list + updated_group["params"] = fake_params_list self.optim.param_groups.append(updated_group) def load_single_param_states(self, param_id: int, saved_states: dict): @@ -621,15 +640,14 @@ def cast(param, state_range, value, key=None): """ assert isinstance(value, torch.Tensor) ret_val = value - if (key == "step"): + if key == "step": assert value.numel() == 1 ret_val = int(value.item()) else: state_start, state_end = state_range - ret_val = torch.zeros(state_end - state_start, - dtype=torch.float32, - device=param.device, - requires_grad=False) + ret_val = torch.zeros( + state_end - state_start, dtype=torch.float32, device=param.device, requires_grad=False + ) ret_val.copy_(value.flatten()[state_start:state_end]) return ret_val @@ -642,7 +660,7 @@ def cast(param, state_range, value, key=None): updated_states = dict() for k, v in saved_states.items(): updated_states[k] = cast(fake_param, state_range, v, k) - del v # clean loaded states + del v # clean loaded states self.optim.state[fake_param].update(updated_states) def load_param_states(self, param_states: dict): @@ -658,8 +676,8 @@ def load_param_states(self, param_states: dict): def optimizer_loading_epilogue(self): # Epilogue when loading state_dict to pytorch optimizer. - self.optim._hook_for_profile() # To support multiprocessing pickle/unpickle. - self.optim.defaults.setdefault('differentiable', False) + self.optim._hook_for_profile() # To support multiprocessing pickle/unpickle. + self.optim.defaults.setdefault("differentiable", False) def load_state_dict(self, state_dict: dict): """Loads optimizer state from complete optimizer state_dict. @@ -669,16 +687,15 @@ def load_state_dict(self, state_dict: dict): state_dict (dict): optimizer state. Should be an object returned from a call to :meth:`state_dict`. """ - assert 'param_groups' in state_dict - assert 'state' in state_dict - self.load_param_groups(state_dict['param_groups']) - self.load_param_states(state_dict['state']) + assert "param_groups" in state_dict + assert "state" in state_dict + self.load_param_groups(state_dict["param_groups"]) + self.load_param_states(state_dict["state"]) self.optimizer_loading_epilogue() - def state_shard(self, - prefix: str = '', - max_shard_size: int = 1024, - only_rank_0: bool = True) -> Iterator[Tuple[OrderedDict, int]]: + def state_shard( + self, prefix: str = "", max_shard_size: int = 1024, only_rank_0: bool = True + ) -> Iterator[Tuple[OrderedDict, int]]: """Returns dictionaries containing shards of optimizer states one by one. The max size of each dictionary shard is specified by ``max_shard_size``. 
@@ -694,7 +711,6 @@ def state_shard(self, sharder = StateDictSharder(max_shard_size) for param_id in self.id_to_real_params.keys(): - dist.barrier() state = self.collect_states(param_id=param_id, only_rank_0=only_rank_0) @@ -705,19 +721,20 @@ def state_shard(self, yield sharder.current_block, sharder.current_block_size def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None: - raise NotImplementedError('Gemini does not support clip_grad_by_value') + raise NotImplementedError("Gemini does not support clip_grad_by_value") - def clip_grad_by_norm(self, - max_norm: Union[float, int], - norm_type: Union[float, int] = 2, - error_if_nonfinite: bool = False, - *args, - **kwargs) -> torch.Tensor: - warnings.warn(f'Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm') + def clip_grad_by_norm( + self, + max_norm: Union[float, int], + norm_type: Union[float, int] = 2, + error_if_nonfinite: bool = False, + *args, + **kwargs, + ) -> torch.Tensor: + warnings.warn(f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm") class GeminiAdamOptimizer(GeminiOptimizer): - def __init__(self, model: torch.nn.Module, **defaults: Any) -> None: optimizer = HybridAdam(model.parameters(), **defaults) super().__init__(optimizer, model, **defaults) diff --git a/colossalai/zero/gemini/memory_tracer/__init__.py b/colossalai/zero/gemini/memory_tracer/__init__.py index e1fe904ebf1a..cb7f626ff446 100644 --- a/colossalai/zero/gemini/memory_tracer/__init__.py +++ b/colossalai/zero/gemini/memory_tracer/__init__.py @@ -1,10 +1,14 @@ -from .param_runtime_order import OrderedParamGenerator # isort:skip -from .memory_stats import MemStats # isort:skip -from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip -from .memstats_collector import MemStatsCollector # isort:skip -from .chunk_memstats_collector import ChunkMemStatsCollector # isort:skip +from .param_runtime_order import OrderedParamGenerator # isort:skip +from .memory_stats import MemStats # isort:skip +from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip +from .memstats_collector import MemStatsCollector # isort:skip +from .chunk_memstats_collector import ChunkMemStatsCollector # isort:skip __all__ = [ - 'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector', 'MemStats', - 'OrderedParamGenerator' + "AsyncMemoryMonitor", + "SyncCudaMemoryMonitor", + "MemStatsCollector", + "ChunkMemStatsCollector", + "MemStats", + "OrderedParamGenerator", ] diff --git a/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py b/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py index b93ad2c44104..b5e40a817e58 100644 --- a/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py +++ b/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py @@ -8,7 +8,6 @@ class ChunkMemStatsCollector(MemStatsCollector): - def __init__(self, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None: """ @@ -27,10 +26,11 @@ def record_model_data_volume(self) -> None: record model data volume on cuda and cpu. 
""" if self._start_flag and not self.use_outside_memstats: - cuda_mem = self._chunk_manager.total_mem['cuda'] + cuda_mem = self._chunk_manager.total_mem["cuda"] self._memstats.record_max_cuda_model_data(cuda_mem) @property def cuda_margin_mem(self) -> float: from colossalai.legacy.utils.memory import colo_device_memory_capacity + return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda diff --git a/colossalai/zero/gemini/memory_tracer/memory_monitor.py b/colossalai/zero/gemini/memory_tracer/memory_monitor.py index 2a65d4b55409..513a6326d5f1 100644 --- a/colossalai/zero/gemini/memory_tracer/memory_monitor.py +++ b/colossalai/zero/gemini/memory_tracer/memory_monitor.py @@ -111,6 +111,7 @@ def finish(self): def _measure_usage(self): from colossalai.legacy.utils import colo_device_memory_used + max_usage = 0 while self.keep_measuring: max_usage = max( diff --git a/colossalai/zero/gemini/memory_tracer/memory_stats.py b/colossalai/zero/gemini/memory_tracer/memory_stats.py index 02de6ecb97a9..1c141169f045 100644 --- a/colossalai/zero/gemini/memory_tracer/memory_stats.py +++ b/colossalai/zero/gemini/memory_tracer/memory_stats.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import List, Optional import torch @@ -6,7 +6,6 @@ class MemStats(object): - def __init__(self) -> None: """ Store the non model data statistics used for Gemini and GeminiOptimizer. @@ -92,17 +91,17 @@ def param_order(self): return self._param_runtime_order def non_model_data_list(self, device_type: str) -> List[int]: - if device_type == 'cuda': + if device_type == "cuda": return self._non_model_data_cuda_list - elif device_type == 'cpu': + elif device_type == "cpu": return self._non_model_data_cpu_list else: raise TypeError def max_non_model_data(self, device_type: str) -> float: - if device_type == 'cuda': + if device_type == "cuda": return max(self._non_model_data_cuda_list) - elif device_type == 'cpu': + elif device_type == "cpu": return max(self._non_model_data_cpu_list) else: raise TypeError diff --git a/colossalai/zero/gemini/memory_tracer/memstats_collector.py b/colossalai/zero/gemini/memory_tracer/memstats_collector.py index abb3dcc74b27..e4459831109a 100644 --- a/colossalai/zero/gemini/memory_tracer/memstats_collector.py +++ b/colossalai/zero/gemini/memory_tracer/memstats_collector.py @@ -40,11 +40,12 @@ def next_period_non_model_data_usage(self, device_type: str) -> int: Returns: int: max non model data memory usage of current sampling period """ - assert not self._start_flag, 'Cannot get mem stats info during collection phase.' - assert self._step_total > 0, 'Cannot get mem stats info before collection phase.' - assert len(self._memstats.non_model_data_list(device_type)) > self._step_idx, \ - f"{len(self._memstats.non_model_data_list(device_type))} should be > than step idx {self._step_idx}, "\ + assert not self._start_flag, "Cannot get mem stats info during collection phase." + assert self._step_total > 0, "Cannot get mem stats info before collection phase." 
+ assert len(self._memstats.non_model_data_list(device_type)) > self._step_idx, ( + f"{len(self._memstats.non_model_data_list(device_type))} should be > than step idx {self._step_idx}, " f"step total {self._step_total}" + ) next_non_model_data = self._memstats.non_model_data_list(device_type)[self._step_idx] self._step_idx = (self._step_idx + 1) % self._step_total return next_non_model_data @@ -60,9 +61,9 @@ def start_collection(self): def finish_collection(self): self.sample_overall_data() # self._step_total = len(self._sampling_time) - self._step_total = len(self._memstats.non_model_data_list('cuda')) + self._step_total = len(self._memstats.non_model_data_list("cuda")) self._start_flag = False - print(f'finish_collection {self._step_total}') + print(f"finish_collection {self._step_total}") # deprecated def record_model_data_volume(self) -> None: @@ -73,7 +74,7 @@ def record_model_data_volume(self) -> None: from colossalai.legacy.zero.gemini import StatefulTensor # The following code work for ZeroInitContext, which is deprecated in v0.1.12 - cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda'] + cuda_mem = StatefulTensor.GST_MGR.total_mem["cuda"] self._memstats.record_max_cuda_model_data(cuda_mem) def sample_overall_data(self) -> None: diff --git a/colossalai/zero/gemini/memory_tracer/param_runtime_order.py b/colossalai/zero/gemini/memory_tracer/param_runtime_order.py index 638c0533ce92..670edb9ec0d2 100644 --- a/colossalai/zero/gemini/memory_tracer/param_runtime_order.py +++ b/colossalai/zero/gemini/memory_tracer/param_runtime_order.py @@ -4,7 +4,6 @@ class ParamGenerator(ABC): - def append(self, param: torch.nn.Parameter): pass diff --git a/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py b/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py index 6656821fef74..b0d258824d2b 100644 --- a/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py +++ b/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py @@ -10,10 +10,10 @@ from .memory_stats import MemStats -__all__ = ['RuntimeMemTracer'] +__all__ = ["RuntimeMemTracer"] -class RuntimeMemTracer(): +class RuntimeMemTracer: """RuntimeMemTracer for the module training using ColoParameter. Trace non-model memory usage during fwd+bwd process. 
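(Editorial aside, not part of the patch: the memstats collector touched above replays its recorded per-step statistics cyclically via `self._step_idx = (self._step_idx + 1) % self._step_total`. A minimal, self-contained sketch of that replay pattern, using hypothetical names, is shown below.)

```python
# Illustrative sketch only: replay recorded per-step "non model data" samples
# cyclically once collection has finished, mirroring the collector's index logic.
class CyclicSampler:
    def __init__(self, samples):
        self._samples = list(samples)   # one recorded value per training step
        self._step_idx = 0
        self._step_total = len(self._samples)

    def next_value(self):
        assert self._step_total > 0, "no samples were collected"
        value = self._samples[self._step_idx]
        # Wrap around so later iterations reuse the recorded memory profile.
        self._step_idx = (self._step_idx + 1) % self._step_total
        return value


sampler = CyclicSampler([512, 1024, 768])
print([sampler.next_value() for _ in range(5)])  # 512, 1024, 768, 512, 1024
```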
diff --git a/colossalai/zero/gemini/memory_tracer/static_memstats_collector.py b/colossalai/zero/gemini/memory_tracer/static_memstats_collector.py index b8f9a095f422..2a1a3745f81c 100644 --- a/colossalai/zero/gemini/memory_tracer/static_memstats_collector.py +++ b/colossalai/zero/gemini/memory_tracer/static_memstats_collector.py @@ -15,9 +15,9 @@ class ModuleInfos: - - def __init__(self, module: torch.nn.Module, module_name: str, module_full_name: str, - parent_module: torch.nn.Module): + def __init__( + self, module: torch.nn.Module, module_name: str, module_full_name: str, parent_module: torch.nn.Module + ): self.module = module self.module_name = module_name self.module_full_name = module_full_name @@ -35,14 +35,13 @@ def __init__(self, module: nn.Module, chunk_manager: ChunkManager) -> None: self.module_info_list = [] def init_mem_stats(self, *inputs): - self.register_opnodes_recursively(self.module) self.refactor_module() self.module = self.module.cpu() self.module.train() - data = [MetaTensor(torch.rand(inp.shape, device='meta'), fake_device='cpu') for inp in inputs] + data = [MetaTensor(torch.rand(inp.shape, device="meta"), fake_device="cpu") for inp in inputs] gm = symbolic_trace(self.module) interp = MetaInfoProp(gm) interp.propagate(*data) @@ -87,12 +86,13 @@ def recover_module(self): for modInfo in self.module_info_list: modInfo.parent_module.__setattr__(modInfo.module_name, modInfo.module) - def register_opnodes_recursively(self, - module: torch.nn.Module, - name: str = "", - full_name: str = "", - parent_module: Optional[torch.nn.Module] = None): - + def register_opnodes_recursively( + self, + module: torch.nn.Module, + name: str = "", + full_name: str = "", + parent_module: Optional[torch.nn.Module] = None, + ): assert isinstance(module, torch.nn.Module) for child_name, child in module.named_children(): diff --git a/colossalai/zero/gemini/memory_tracer/utils.py b/colossalai/zero/gemini/memory_tracer/utils.py index 65f6ba775139..9faf81af63d7 100644 --- a/colossalai/zero/gemini/memory_tracer/utils.py +++ b/colossalai/zero/gemini/memory_tracer/utils.py @@ -14,7 +14,7 @@ def colo_model_optimizer_usage(optim) -> Tuple[int, int]: """ if optim is None: return 0, 0 - assert hasattr(optim, 'get_memory_usage'), f"{type(optim)} has no attr get_memory_usage()" + assert hasattr(optim, "get_memory_usage"), f"{type(optim)} has no attr get_memory_usage()" return optim.get_memory_usage() @@ -35,16 +35,16 @@ def _get_tensor_mem_use(t: Optional[torch.Tensor]): return 0, 0 assert isinstance(t, torch.Tensor) _cpu_mem_usage, _cuda_mem_usage = 0, 0 - if t.device.type == 'cpu': + if t.device.type == "cpu": _cpu_mem_usage += t.numel() * t.element_size() - elif t.device.type == 'cuda': + elif t.device.type == "cuda": _cuda_mem_usage += t.numel() * t.element_size() return _cuda_mem_usage, _cpu_mem_usage cuda_mem_usage = 0 cpu_mem_usage = 0 for param in model.parameters(): - if hasattr(param, 'colo_attr'): + if hasattr(param, "colo_attr"): t_cuda, t_cpu = param.colo_attr.get_memory_usage() cuda_mem_usage += t_cuda cpu_mem_usage += t_cpu diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index a35529723a68..8a74eb587b83 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -17,10 +17,9 @@ class PlacementPolicy(ABC): need_mem_stats: bool = False - def __init__(self, - chunk_manager: ChunkManager, - mem_stats_collector: Optional[ChunkMemStatsCollector] = None, - **kwargs) -> None: + def __init__( + self, 
chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, **kwargs + ) -> None: self.chunk_manager = chunk_manager self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector @@ -29,23 +28,25 @@ def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, f raise NotImplementedError @abstractmethod - def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, - torch.device]) -> None: + def setup_grads_device( + self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, torch.device] + ) -> None: raise NotImplementedError class StaticPlacementPolicy(PlacementPolicy): - - def __init__(self, - chunk_manager: ChunkManager, - mem_stats_collector: Optional[ChunkMemStatsCollector] = None, - shard_param_frac: float = 1.0, - offload_optim_frac: float = 0.0, - offload_param_frac: float = 0.0, - **kwargs) -> None: + def __init__( + self, + chunk_manager: ChunkManager, + mem_stats_collector: Optional[ChunkMemStatsCollector] = None, + shard_param_frac: float = 1.0, + offload_optim_frac: float = 0.0, + offload_param_frac: float = 0.0, + **kwargs, + ) -> None: super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0): - warnings.warn('offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0') + warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0") offload_param_frac = 0.0 self.shard_param_frac = shard_param_frac self.offload_optim_frac = offload_optim_frac @@ -66,13 +67,14 @@ def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, f for chunk in can_evict_chunks: if can_offload_chunk_mem <= self.keep_cuda_chunk_mem: break - self.chunk_manager.move_chunk(chunk, torch.device('cpu')) + self.chunk_manager.move_chunk(chunk, torch.device("cpu")) # real saved mem is shard_mem, for simplicity we use chunk_mem can_offload_chunk_mem -= chunk.chunk_mem return 0, 0.0 - def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, - torch.device]) -> None: + def setup_grads_device( + self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, torch.device] + ) -> None: total_chunk_mem = sum(self.chunk_manager.get_chunk(p).chunk_mem for p in params) offload_optim_chunk_mem = total_chunk_mem * self.offload_optim_frac @@ -85,7 +87,7 @@ def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[ if chunk.keep_gathered or offloaded_optim_chunk_mem >= offload_optim_chunk_mem: device = get_current_device() else: - device = torch.device('cpu') + device = torch.device("cpu") # real offloaded mem is chunk.shard_mem, for simplicity we use chunk mem here offloaded_optim_chunk_mem += chunk.chunk_mem for p in params: @@ -97,12 +99,14 @@ def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[ class AutoPlacementPolicy(PlacementPolicy): need_mem_stats: bool = True - def __init__(self, - chunk_manager: ChunkManager, - mem_stats_collector: Optional[ChunkMemStatsCollector] = None, - warmup_non_model_data_ratio: float = 0.8, - steady_cuda_cap_ratio: float = 0.9, - **kwargs) -> None: + def __init__( + self, + chunk_manager: ChunkManager, + mem_stats_collector: Optional[ChunkMemStatsCollector] = None, + warmup_non_model_data_ratio: float = 0.8, + steady_cuda_cap_ratio: float = 0.9, + **kwargs, + ) -> None: 
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) # model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio() @@ -110,13 +114,15 @@ def __init__(self, self._warmup_non_model_data_ratio = warmup_non_model_data_ratio self._steady_cuda_cap_ratio = steady_cuda_cap_ratio - def evict_tensors(self, - can_evict_chunks: List[Chunk], - cuda_demand: int = 0, - warmup: bool = True, - compute_list: Optional[List[Tuple[Chunk, ...]]] = None, - compute_idx: int = 0, - **kwargs) -> Tuple[int, float]: + def evict_tensors( + self, + can_evict_chunks: List[Chunk], + cuda_demand: int = 0, + warmup: bool = True, + compute_list: Optional[List[Tuple[Chunk, ...]]] = None, + compute_idx: int = 0, + **kwargs, + ) -> Tuple[int, float]: """ Evict tensors from CUDA device. @@ -135,13 +141,13 @@ def evict_tensors(self, """ start = time() cuda_capacity = colo_device_memory_capacity(get_current_device()) - used_cuda_model_data = self.chunk_manager.total_mem['cuda'] + used_cuda_model_data = self.chunk_manager.total_mem["cuda"] if warmup: # We designate a part of CUDA memory for model data in warmup iterations. max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio else: # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment. - max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda') + max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage("cuda") cuda_capacity *= self._steady_cuda_cap_ratio total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data @@ -160,11 +166,13 @@ def evict_tensors(self, break self.chunk_manager.release_chunk(chunk) - self.chunk_manager.move_chunk(chunk, torch.device('cpu')) + self.chunk_manager.move_chunk(chunk, torch.device("cpu")) freed_cuda_model_data += chunk.chunk_mem if freed_cuda_model_data < to_free_cuda_model_data: - raise RuntimeError(f"Adjust layout failed! No enough CUDA memory! " - f"Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}") + raise RuntimeError( + f"Adjust layout failed! No enough CUDA memory! 
" + f"Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}" + ) return freed_cuda_model_data, time() - start @staticmethod @@ -178,8 +186,9 @@ def _sort_can_evict_chunks(can_evict_chunks: tuple, compute_idx: int, compute_li next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True) return [t for (t, idx) in next_compute_idx] - def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, - torch.device]) -> None: + def setup_grads_device( + self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, torch.device] + ) -> None: for p in params: chunk = self.chunk_manager.get_chunk(p) # init offload optim settings @@ -187,13 +196,13 @@ def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[ if chunk.keep_gathered: grads_device_map[p] = get_current_device() else: - grads_device_map[p] = torch.device('cpu') + grads_device_map[p] = torch.device("cpu") class PlacementPolicyFactory: policies: Dict[str, Type[PlacementPolicy]] = { - 'auto': AutoPlacementPolicy, - 'static': StaticPlacementPolicy, + "auto": AutoPlacementPolicy, + "static": StaticPlacementPolicy, } @staticmethod diff --git a/colossalai/zero/gemini/utils.py b/colossalai/zero/gemini/utils.py index 0d92d32e5603..264099d22de2 100644 --- a/colossalai/zero/gemini/utils.py +++ b/colossalai/zero/gemini/utils.py @@ -27,16 +27,15 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk): return total_temp -def _get_dfs_module_list(module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = ''): - """Get a dfs module list of the given module. Its order is same as the order of creations of modules. - """ +def _get_dfs_module_list(module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = ""): + """Get a dfs module list of the given module. Its order is same as the order of creations of modules.""" if memo is None: memo = set() if module not in memo: for name, submodule in module._modules.items(): if submodule is None: continue - submodule_prefix = prefix + ('.' if prefix else '') + name + submodule_prefix = prefix + ("." if prefix else "") + name for m in _get_dfs_module_list(submodule, memo, submodule_prefix): yield m @@ -60,10 +59,9 @@ def _get_shallow_copy_model(model: nn.Module): return old_to_new[model] -def get_static_torch_model(zero_ddp_model, - device=torch.device("cpu"), - dtype=torch.float32, - only_rank_0=True) -> torch.nn.Module: +def get_static_torch_model( + zero_ddp_model, device=torch.device("cpu"), dtype=torch.float32, only_rank_0=True +) -> torch.nn.Module: """Get a static torch.nn.Module model from the given GeminiDDP module. You should notice that the original GeminiDDP model is not modified. Thus, you can use the original model in further training. 
@@ -79,6 +77,7 @@ def get_static_torch_model(zero_ddp_model, torch.nn.Module: a static torch model used for saving checkpoints or numeric checks """ from colossalai.zero.gemini.gemini_ddp import GeminiDDP + assert isinstance(zero_ddp_model, GeminiDDP) state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0) @@ -86,15 +85,17 @@ def get_static_torch_model(zero_ddp_model, torch_model = _get_shallow_copy_model(colo_model) if not only_rank_0 or dist.get_rank() == 0: - for (name, colo_module), (_, torch_module) in \ - zip(_get_dfs_module_list(colo_model), _get_dfs_module_list(torch_model)): + for (name, colo_module), (_, torch_module) in zip( + _get_dfs_module_list(colo_model), _get_dfs_module_list(torch_model) + ): # clean the parameter list of the new torch module torch_module._parameters = OrderedDict() for sufix_param_name, param in colo_module.named_parameters(recurse=False): # get the full name of the parameter - full_param_name = name + ('.' if name else '') + sufix_param_name - assert full_param_name in state_dict, \ - f"Can not find parameter `{full_param_name}` in the GeminiDDP module" + full_param_name = name + ("." if name else "") + sufix_param_name + assert ( + full_param_name in state_dict + ), f"Can not find parameter `{full_param_name}` in the GeminiDDP module" state_param = state_dict[full_param_name] torch_param = torch.nn.Parameter(state_param.data.to(device=device, dtype=dtype)) diff --git a/colossalai/zero/low_level/__init__.py b/colossalai/zero/low_level/__init__.py index ae3c1de3a5bc..270a6a6a4786 100644 --- a/colossalai/zero/low_level/__init__.py +++ b/colossalai/zero/low_level/__init__.py @@ -1,3 +1,3 @@ from .low_level_optim import LowLevelZeroOptimizer -__all__ = ['LowLevelZeroOptimizer'] +__all__ = ["LowLevelZeroOptimizer"] diff --git a/colossalai/zero/low_level/_utils.py b/colossalai/zero/low_level/_utils.py index ece92fe02e28..ba1135940df0 100644 --- a/colossalai/zero/low_level/_utils.py +++ b/colossalai/zero/low_level/_utils.py @@ -44,8 +44,8 @@ def shuffle_by_round_robin(tensor_list, num_partitions): for partition_id in range(partitions_count): partition_tensors = partitions[partition_id] for item in partition_tensors: - tensor_index_mapping[item['index']] = len(new_tensor_list) - new_tensor_list.append(item['tensor']) + tensor_index_mapping[item["index"]] = len(new_tensor_list) + new_tensor_list.append(item["tensor"]) return new_tensor_list, tensor_index_mapping @@ -107,11 +107,13 @@ def split_by_dtype(tensor_list): return buckets -def reduce_tensor_dp_group(tensor: torch.Tensor, - dtype: Optional[torch.dtype] = None, - dst_local_rank: Optional[int] = None, - dst_global_rank: Optional[int] = None, - group: Optional[dist.ProcessGroup] = None): +def reduce_tensor_dp_group( + tensor: torch.Tensor, + dtype: Optional[torch.dtype] = None, + dst_local_rank: Optional[int] = None, + dst_global_rank: Optional[int] = None, + group: Optional[dist.ProcessGroup] = None, +): """ Reduce the tensor in the data parallel process group @@ -173,7 +175,7 @@ def has_inf_or_nan(tensor): raise return True else: - if tensor_sum == float('inf') or tensor_sum == -float('inf') or tensor_sum != tensor_sum: + if tensor_sum == float("inf") or tensor_sum == -float("inf") or tensor_sum != tensor_sum: return True return False @@ -184,8 +186,7 @@ def release_param_grad(tensor_list): def calculate_global_norm_from_list(norm_list): - """ Compute total from a list of norms - """ + """Compute total from a list of norms""" total_norm = 0.0 for norm in norm_list: total_norm += norm**2.0 @@ -221,7 
+222,7 @@ def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGro total_norm = 0.0 for g in gradients: param_norm = g.data.double().norm(2) - total_norm += param_norm.item()**2 + total_norm += param_norm.item() ** 2 # Sum across all model parallel GPUs. total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) @@ -230,9 +231,9 @@ def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGro if tp_group is not None: dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=tp_group) - total_norm = total_norm_cuda[0].item()**(1. / norm_type) + total_norm = total_norm_cuda[0].item() ** (1.0 / norm_type) - if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm: + if total_norm == float("inf") or total_norm == -float("inf") or total_norm != total_norm: total_norm = -1 return total_norm diff --git a/colossalai/zero/low_level/bookkeeping/__init__.py b/colossalai/zero/low_level/bookkeeping/__init__.py index 7bcacfabfded..427973772f9c 100644 --- a/colossalai/zero/low_level/bookkeeping/__init__.py +++ b/colossalai/zero/low_level/bookkeeping/__init__.py @@ -3,4 +3,4 @@ from .parameter_store import ParameterStore from .tensor_bucket import TensorBucket -__all__ = ['GradientStore', 'ParameterStore', 'BucketStore', 'TensorBucket'] +__all__ = ["GradientStore", "ParameterStore", "BucketStore", "TensorBucket"] diff --git a/colossalai/zero/low_level/bookkeeping/base_store.py b/colossalai/zero/low_level/bookkeeping/base_store.py index 2ebd122464f4..107d62dcbc0e 100644 --- a/colossalai/zero/low_level/bookkeeping/base_store.py +++ b/colossalai/zero/low_level/bookkeeping/base_store.py @@ -3,7 +3,6 @@ class BaseStore: - def __init__(self, torch_pg: ProcessGroup): self._world_size = dist.get_world_size(group=torch_pg) self._local_rank = dist.get_rank(group=torch_pg) diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py index 0ab10e25d407..2a75d704711a 100644 --- a/colossalai/zero/low_level/bookkeeping/bucket_store.py +++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py @@ -9,7 +9,6 @@ class BucketStore(BaseStore): - def __init__(self, torch_pg: ProcessGroup): super().__init__(torch_pg) @@ -38,8 +37,7 @@ def num_elements_in_bucket(self) -> int: return self._num_elements_in_bucket def reset_num_elements_in_bucket(self): - """Set the number of elements in bucket to zero. 
- """ + """Set the number of elements in bucket to zero.""" self._num_elements_in_bucket = 0 @@ -54,7 +52,7 @@ def add_param_grad(self, group_id: int, param: Tensor, padding_size: int): self._param_list.append(param) self._padding_size.append(padding_size) - self._num_elements_in_bucket += (param.numel() + padding_size) + self._num_elements_in_bucket += param.numel() + padding_size self.current_group_id = group_id # number of tensors in current bucket @@ -119,8 +117,7 @@ def get_param_id_of_grad(self, grad: Tensor) -> int: return self.grad_to_param_mapping[id(grad)] def reset(self): - """Reset the bucket storage after reduction, only release the tensors have been reduced - """ + """Reset the bucket storage after reduction, only release the tensors have been reduced""" cur_offset = self.offset_list.pop(0) self._param_list = self._param_list[cur_offset:] self._padding_size = self._padding_size[cur_offset:] diff --git a/colossalai/zero/low_level/bookkeeping/gradient_store.py b/colossalai/zero/low_level/bookkeeping/gradient_store.py index 2890b329a642..3ce688cfa930 100644 --- a/colossalai/zero/low_level/bookkeeping/gradient_store.py +++ b/colossalai/zero/low_level/bookkeeping/gradient_store.py @@ -1,13 +1,11 @@ from typing import List from torch import Tensor -from torch._utils import _flatten_dense_tensors from .base_store import BaseStore class GradientStore(BaseStore): - def __init__(self, *args, partition_grad: bool = False): super().__init__(*args) """ diff --git a/colossalai/zero/low_level/bookkeeping/parameter_store.py b/colossalai/zero/low_level/bookkeeping/parameter_store.py index 63f7c5506069..e94fb4de9b9f 100644 --- a/colossalai/zero/low_level/bookkeeping/parameter_store.py +++ b/colossalai/zero/low_level/bookkeeping/parameter_store.py @@ -5,7 +5,6 @@ class ParameterStore(BaseStore): - def __init__(self, torch_pg: ProcessGroup): super().__init__(torch_pg) diff --git a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py index b32816a046cd..16ba8a6d6445 100644 --- a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py +++ b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py @@ -2,7 +2,6 @@ class TensorBucket: - def __init__(self, size): self._max_size = size self._current_size = 0 @@ -26,8 +25,7 @@ def add_to_bucket(self, tensor, allow_oversize=False): tensor_size = tensor.numel() if not allow_oversize and self.will_exceed_max_size(tensor_size): - msg = f"The param bucket max size {self._max_size} is exceeded" \ - + f"by tensor (size {tensor_size})" + msg = f"The param bucket max size {self._max_size} is exceeded" + f"by tensor (size {tensor_size})" raise RuntimeError(msg) self._bucket.append(tensor) diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 0bdd6a3e2370..1bf5302efcfb 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -17,6 +17,7 @@ ) from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger + # from colossalai.tensor import ColoParameter, ProcessGroup from colossalai.utils.cuda import get_current_device @@ -32,19 +33,21 @@ class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): - - def __init__(self, - num_working_param_groups: int, - grad_store: GradientStore, - initial_scale: float = 2**16, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: 
float = 2**32) -> None: - super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, - max_scale) + def __init__( + self, + num_working_param_groups: int, + grad_store: GradientStore, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + ) -> None: + super().__init__( + initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale + ) self.num_working_param_groups = num_working_param_groups self.grad_store = grad_store @@ -57,32 +60,31 @@ def check_local_overflow(self) -> bool: class LowLevelZeroOptimizer(OptimizerWrapper): - """Optimizer used for ZeRO-1 and ZeRO-2. - """ + """Optimizer used for ZeRO-1 and ZeRO-2.""" def __init__( - self, - optimizer: Optimizer, - initial_scale: int = 2**16, # grad scaler config - min_scale: int = 1, - growth_factor: float = 2., - backoff_factor: float = .5, - growth_interval: int = 2000, - hysteresis: int = 2, - max_scale: int = 2**24, - clip_grad_norm: float = 0.0, # grad clipping - verbose: bool = False, - reduce_bucket_size: int = 1024 * 1024, # communication - communication_dtype: Optional[torch.dtype] = None, - overlap_communication: bool = False, - partition_grad: bool = False, # stage 2 flag - cpu_offload: bool = False, # cpu offload - dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm - tp_process_group: Optional[ProcessGroup] = None, # if using tp - forced_dtype: Optional[torch.dtype] = None): - + self, + optimizer: Optimizer, + initial_scale: int = 2**16, # grad scaler config + min_scale: int = 1, + growth_factor: float = 2.0, + backoff_factor: float = 0.5, + growth_interval: int = 2000, + hysteresis: int = 2, + max_scale: int = 2**24, + clip_grad_norm: float = 0.0, # grad clipping + verbose: bool = False, + reduce_bucket_size: int = 1024 * 1024, # communication + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = False, + partition_grad: bool = False, # stage 2 flag + cpu_offload: bool = False, # cpu offload + dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm + tp_process_group: Optional[ProcessGroup] = None, # if using tp + forced_dtype: Optional[torch.dtype] = None, + ): super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) - self._dtype = self.optim.param_groups[0]['params'][0].dtype + self._dtype = self.optim.param_groups[0]["params"][0].dtype self._logger = get_dist_logger() self._verbose = verbose @@ -115,7 +117,7 @@ def __init__( if forced_dtype: for group in self.optim.param_groups: - group_params = group['params'] + group_params = group["params"] for param in group_params: param.data = param.data.to(forced_dtype) self._dtype = forced_dtype @@ -134,7 +136,7 @@ def __init__( # and add buffers to parameter store for future access for group_id, param_group in enumerate(self.optim.param_groups): group_params = list() - for param in param_group['params']: + for param in param_group["params"]: if param.requires_grad: group_params.append(param) @@ -148,7 +150,7 @@ def __init__( # need to replace the params in the `params` field in the optimizer # so that when the optimizer calls step(), it only updates the tensors # managed by this data parallel rank - param_group['params'] = master_param_current_rank + param_group["params"] = master_param_current_rank # intialize communication stream for # communication-compuation overlapping @@ -164,15 +166,17 @@ 
def __init__( # initialize mixed precision mixin self.mixed_precision_mixin: Optional[MixedPrecisionMixin] = None if self._dtype is torch.float16: - self.mixed_precision_mixin = LowLevelZeroFP16MixedPrecisionMixin(self.num_param_groups, - self._grad_store, - initial_scale=initial_scale, - min_scale=min_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - max_scale=max_scale) + self.mixed_precision_mixin = LowLevelZeroFP16MixedPrecisionMixin( + self.num_param_groups, + self._grad_store, + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + ) elif self._dtype is torch.bfloat16: self.mixed_precision_mixin = BF16MixedPrecisionMixin() @@ -185,17 +189,18 @@ def num_param_groups(self): return len(self._working_param_groups) def _sanity_checks(self): - assert torch.cuda.is_available(), 'CUDA is required' + assert torch.cuda.is_available(), "CUDA is required" for param_group in self.optim.param_groups: - group_params = param_group['params'] + group_params = param_group["params"] for param in group_params: - assert param.dtype == self._dtype, \ - f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`" + assert ( + param.dtype == self._dtype + ), f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`" def _create_master_param_current_rank(self, param_list): # split each param evenly by world size params_current_rank = [] - device = 'cpu' if self._cpu_offload else get_current_device() + device = "cpu" if self._cpu_offload else get_current_device() for param in param_list: padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size @@ -275,8 +280,10 @@ def _run_reduction(self): sync_tensor(flat_grads_per_rank[rank], grad_list) for grad in grad_list: param_id = self._bucket_store.get_param_id_of_grad(grad) - if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, - param_id)) < self._world_size: + if ( + len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) + < self._world_size + ): self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) else: self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id) @@ -307,8 +314,10 @@ def _add_to_bucket(self, param, group_id): # if full, will reduce the grads already in the bucket # or got a grad of param from another group # after reduction, the bucket will be empty - if self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size or \ - group_id != self._bucket_store.current_group_id: + if ( + self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size + or group_id != self._bucket_store.current_group_id + ): self._run_reduction() padding_size = self._param_store.get_param_padding_size(param) @@ -319,8 +328,9 @@ def _add_to_bucket(self, param, group_id): ################################ def backward(self, loss, retain_graph=False): - assert not(self._partition_grads and not self.require_grad_sync), \ - "ZeRO2(partition_grads) and no_sync are not compatible" + assert not ( + self._partition_grads and not self.require_grad_sync + ), "ZeRO2(partition_grads) and no_sync are not compatible" if self.mixed_precision_mixin is not None: loss = self.mixed_precision_mixin.pre_backward(loss) @@ -339,8 +349,9 @@ def backward(self, loss, 
retain_graph=False): self.zero_grad() def backward_by_grad(self, tensor, grad): - assert not(self._partition_grads and not self.require_grad_sync), \ - "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible" + assert not ( + self._partition_grads and not self.require_grad_sync + ), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible" if self.mixed_precision_mixin is not None: grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad) @@ -380,14 +391,14 @@ def zero_grad(self, set_to_none=True): #################### def step(self, closure=None): - assert closure is None, 'closure is not supported by step()' + assert closure is None, "closure is not supported by step()" if not self.require_grad_sync: return if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step(): self._grad_store.reset_all_gradients() if self._verbose: - self._logger.info(f'Found overflow. Skip step') + self._logger.info(f"Found overflow. Skip step") self.zero_grad() return @@ -428,7 +439,7 @@ def step(self, closure=None): self._grad_store.reset_grads_by_group_id(group_id) # update the params in the optimizer - self.optim.param_groups[group_id]['params'] = real_master_params[group_id] + self.optim.param_groups[group_id]["params"] = real_master_params[group_id] # unscale and clip grads global_norm = calculate_global_norm_from_list(norm_list=norm_groups) @@ -445,16 +456,16 @@ def step(self, closure=None): # update working partition updated by the current rank dtype = real_working_params[0][0].dtype for group_id in range(self.num_param_groups): - master_working_param = self.optim.param_groups[group_id]['params'] + master_working_param = self.optim.param_groups[group_id]["params"] for idx, splited_param in enumerate(master_working_param): working_param = real_working_params[group_id][idx] all_splited_param = [ torch.zeros(splited_param.shape, device="cuda", dtype=dtype) for _ in range(self._world_size) ] dist.all_gather(all_splited_param, splited_param.cuda().to(dtype), group=self.dp_pg) - working_param.data.copy_(flatten(all_splited_param)[:working_param.numel()].reshape_as(working_param)) + working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param)) - self.optim.param_groups[group_id]['params'] = self._master_param_groups_of_current_rank[group_id] + self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id] ############################# # Mixed Precision Utilities # @@ -466,14 +477,14 @@ def _unscale_and_clip_grads(self, grad_groups_flat, total_norm): if self.mixed_precision_mixin is not None: div_scale = self.mixed_precision_mixin.get_grad_div_scale() - if self._clip_grad_norm > 0.: + if self._clip_grad_norm > 0.0: # norm is in fact norm*scale clip = ((total_norm / div_scale) + 1e-6) / self._clip_grad_norm if clip > 1: div_scale = clip * div_scale for grad in grad_groups_flat: - grad.data.mul_(1. 
/ div_scale) + grad.data.mul_(1.0 / div_scale) ############################ # Gradient Synchronization # @@ -518,18 +529,19 @@ def _pack_state(self, state: Dict) -> Dict: def pack_group(group): nonlocal start_index - packed = {k: v for k, v in group.items() if k != 'params'} + packed = {k: v for k, v in group.items() if k != "params"} param_mappings.update( - {id(p): i for i, p in enumerate(group['params'], start_index) if id(p) not in param_mappings}) - packed['params'] = [param_mappings[id(p)] for p in group['params']] - start_index += len(packed['params']) + {id(p): i for i, p in enumerate(group["params"], start_index) if id(p) not in param_mappings} + ) + packed["params"] = [param_mappings[id(p)] for p in group["params"]] + start_index += len(packed["params"]) return packed param_groups = [pack_group(g) for g in self.optim.param_groups] # Remap state to use order indices as keys packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v for k, v in state.items()} - return {'state': packed_state, 'param_groups': param_groups} + return {"state": packed_state, "param_groups": param_groups} def state_dict(self) -> Dict: """Return a state_dict same with DDP @@ -541,14 +553,15 @@ def state_dict(self) -> Dict: for param, state in self.optim.state.items(): zero_state[param] = copy.deepcopy(state) for k, v in state.items(): - if isinstance(v, torch.Tensor) and k != 'step': + if isinstance(v, torch.Tensor) and k != "step": working_param = self._param_store.master_to_working_param[id(param)] gather_tensor = [ - torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size) + torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size) ] dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg) - param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as( - working_param).cpu() + param_state = ( + torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() + ) zero_state[param][k] = param_state states_dict = self._pack_state(zero_state) @@ -562,16 +575,16 @@ def load_state_dict(self, state_dict: Dict): state_dict (dict): A pytorch form state_dict """ zero_state_dict = copy.deepcopy(state_dict) - for param_idx, state in zero_state_dict['state'].items(): + for param_idx, state in zero_state_dict["state"].items(): for k, v in state.items(): - if isinstance(v, torch.Tensor) and k != 'step': + if isinstance(v, torch.Tensor) and k != "step": padding_size = (self._world_size - v.numel() % self._world_size) % self._world_size with torch.no_grad(): v = v.flatten() if padding_size > 0: v = torch.nn.functional.pad(v, [0, padding_size]) v_list = v.split(v.numel() // self._world_size) - zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach().clone() + zero_state_dict["state"][param_idx][k] = v_list[self._local_rank].detach().clone() self.optim.load_state_dict(zero_state_dict) @@ -588,7 +601,7 @@ def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, i ret_block = dict() ret_block_size = 0 - local_states = self.optim.state_dict()['state'] + local_states = self.optim.state_dict()["state"] for param_idx, states in local_states.items(): current_block_size = 0 current_block = copy.deepcopy(states) @@ -601,11 +614,12 @@ def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, i working_param = self._param_store.master_to_working_param[id(master_param)] for k, v in states.items(): - if isinstance(v, torch.Tensor) and k != 'step': 
- state_tensor = [torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size)] + if isinstance(v, torch.Tensor) and k != "step": + state_tensor = [torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)] dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg) - state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as( - working_param).cpu() + state_tensor = ( + torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() + ) current_block_size += state_tensor.numel() current_block[k] = state_tensor diff --git a/colossalai/zero/wrapper.py b/colossalai/zero/wrapper.py index 90325fe0a704..ed873254e301 100644 --- a/colossalai/zero/wrapper.py +++ b/colossalai/zero/wrapper.py @@ -7,10 +7,9 @@ from .gemini import GeminiDDP -def zero_model_wrapper(model: nn.Module, - zero_stage: int = 1, - gemini_config: Optional[Dict] = None, - verbose: bool = False): +def zero_model_wrapper( + model: nn.Module, zero_stage: int = 1, gemini_config: Optional[Dict] = None, verbose: bool = False +): """This wrapper function is used to wrap your training model for ZeRO DDP. Example: @@ -50,19 +49,21 @@ def zero_model_wrapper(model: nn.Module, return wrapped_model -def zero_optim_wrapper(model: nn.Module, - optimizer: torch.optim.Optimizer, - initial_scale: float = 2**16, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - min_scale: float = 1, - max_scale: float = 2**32, - max_norm: float = 0.0, - norm_type: float = 2.0, - optim_config: Optional[Dict] = None, - verbose: bool = False): +def zero_optim_wrapper( + model: nn.Module, + optimizer: torch.optim.Optimizer, + initial_scale: float = 2**16, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + min_scale: float = 1, + max_scale: float = 2**32, + max_norm: float = 0.0, + norm_type: float = 2.0, + optim_config: Optional[Dict] = None, + verbose: bool = False, +): """This wrapper function is used to wrap your training optimizer for ZeRO DDP. 
Args: @@ -95,20 +96,22 @@ def zero_optim_wrapper(model: nn.Module, else: config_dict = copy(optim_config) - config_dict['initial_scale'] = initial_scale - config_dict['growth_factor'] = growth_factor - config_dict['backoff_factor'] = backoff_factor - config_dict['growth_interval'] = growth_interval - config_dict['hysteresis'] = hysteresis - config_dict['min_scale'] = min_scale - config_dict['max_scale'] = max_scale + config_dict["initial_scale"] = initial_scale + config_dict["growth_factor"] = growth_factor + config_dict["backoff_factor"] = backoff_factor + config_dict["growth_interval"] = growth_interval + config_dict["hysteresis"] = hysteresis + config_dict["min_scale"] = min_scale + config_dict["max_scale"] = max_scale if zero_stage in [1, 2]: from colossalai.zero.low_level import LowLevelZeroOptimizer - config_dict['partition_grad'] = zero_stage == 2 - config_dict['clip_grad_norm'] = max_norm + + config_dict["partition_grad"] = zero_stage == 2 + config_dict["clip_grad_norm"] = max_norm return LowLevelZeroOptimizer(optimizer, **config_dict, verbose=verbose) else: from colossalai.zero.gemini.gemini_optimizer import GeminiOptimizer - config_dict['clipping_norm'] = max_norm + + config_dict["clipping_norm"] = max_norm return GeminiOptimizer(optimizer, model, **config_dict, verbose=verbose) diff --git a/examples/community/fp8/mnist/main.py b/examples/community/fp8/mnist/main.py index a534663d380f..2bb912dec247 100644 --- a/examples/community/fp8/mnist/main.py +++ b/examples/community/fp8/mnist/main.py @@ -13,13 +13,13 @@ try: from transformer_engine import pytorch as te + HAVE_TE = True except (ImportError, ModuleNotFoundError): HAVE_TE = False class Net(nn.Module): - def __init__(self, use_te=False): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 32, 3, 1) @@ -64,10 +64,12 @@ def train(args, model, device, train_loader, optimizer, epoch, use_fp8): loss.backward() optimizer.step() if batch_idx % args.log_interval == 0: - print(f"Train Epoch: {epoch} " - f"[{batch_idx * len(data)}/{len(train_loader.dataset)} " - f"({100. * batch_idx / len(train_loader):.0f}%)]\t" - f"Loss: {loss.item():.6f}") + print( + f"Train Epoch: {epoch} " + f"[{batch_idx * len(data)}/{len(train_loader.dataset)} " + f"({100. * batch_idx / len(train_loader):.0f}%)]\t" + f"Loss: {loss.item():.6f}" + ) if args.dry_run: break @@ -75,13 +77,11 @@ def train(args, model, device, train_loader, optimizer, epoch, use_fp8): def calibrate(model, device, test_loader): """Calibration function.""" model.eval() - test_loss = 0 - correct = 0 with torch.no_grad(): for data, target in test_loader: data, target = data.to(device), target.to(device) with te.fp8_autocast(enabled=False, calibrating=True): - output = model(data) + model(data) def test(model, device, test_loader, use_fp8): @@ -94,15 +94,17 @@ def test(model, device, test_loader, use_fp8): data, target = data.to(device), target.to(device) with te.fp8_autocast(enabled=use_fp8): output = model(data) - test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss - pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability + test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss + pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability correct += pred.eq(target.view_as(pred)).sum().item() test_loss /= len(test_loader.dataset) - print(f"\nTest set: Average loss: {test_loss:.4f}, " - f"Accuracy: {correct}/{len(test_loader.dataset)} " - f"({100. 
* correct / len(test_loader.dataset):.0f}%)\n") + print( + f"\nTest set: Average loss: {test_loss:.4f}, " + f"Accuracy: {correct}/{len(test_loader.dataset)} " + f"({100. * correct / len(test_loader.dataset):.0f}%)\n" + ) def main(): @@ -163,10 +165,9 @@ def main(): default=False, help="For Saving the current Model", ) - parser.add_argument("--use-fp8", - action="store_true", - default=False, - help="Use FP8 for inference and training without recalibration") + parser.add_argument( + "--use-fp8", action="store_true", default=False, help="Use FP8 for inference and training without recalibration" + ) parser.add_argument("--use-fp8-infer", action="store_true", default=False, help="Use FP8 inference only") parser.add_argument("--use-te", action="store_true", default=False, help="Use Transformer Engine") args = parser.parse_args() @@ -215,7 +216,7 @@ def main(): if args.save_model or args.use_fp8_infer: torch.save(model.state_dict(), "mnist_cnn.pt") - print('Eval with reloaded checkpoint : fp8=' + str(args.use_fp8_infer)) + print("Eval with reloaded checkpoint : fp8=" + str(args.use_fp8_infer)) weights = torch.load("mnist_cnn.pt") model.load_state_dict(weights) test(model, device, test_loader, args.use_fp8_infer) diff --git a/examples/community/roberta/preprocessing/get_mask.py b/examples/community/roberta/preprocessing/get_mask.py index 74c97a63a9f3..f0ba8fe38501 100644 --- a/examples/community/roberta/preprocessing/get_mask.py +++ b/examples/community/roberta/preprocessing/get_mask.py @@ -1,13 +1,8 @@ import collections import logging -import os import random -import time -from enum import IntEnum -from random import choice import jieba -import torch jieba.setLogLevel(logging.CRITICAL) import re @@ -23,14 +18,15 @@ def map_to_numpy(data): return np.asarray(data) -class PreTrainingDataset(): - - def __init__(self, - tokenizer, - max_seq_length, - backend='python', - max_predictions_per_seq: int = 80, - do_whole_word_mask: bool = True): +class PreTrainingDataset: + def __init__( + self, + tokenizer, + max_seq_length, + backend="python", + max_predictions_per_seq: int = 80, + do_whole_word_mask: bool = True, + ): self.tokenizer = tokenizer self.max_seq_length = max_seq_length self.masked_lm_prob = 0.15 @@ -38,8 +34,8 @@ def __init__(self, self.do_whole_word_mask = do_whole_word_mask self.max_predictions_per_seq = max_predictions_per_seq self.vocab_words = list(tokenizer.vocab.keys()) - self.rec = re.compile('[\u4E00-\u9FA5]') - self.whole_rec = re.compile('##[\u4E00-\u9FA5]') + self.rec = re.compile("[\u4E00-\u9FA5]") + self.whole_rec = re.compile("##[\u4E00-\u9FA5]") self.mlm_p = 0.15 self.mlm_mask_p = 0.8 @@ -64,7 +60,7 @@ def create_training_instance(self, instance): original_tokens = [] segment_ids = [] tokens.append("[CLS]") - original_tokens.append('[CLS]') + original_tokens.append("[CLS]") segment_ids.append(0) for index, token in enumerate(tokens_a): tokens.append(token) @@ -72,7 +68,7 @@ def create_training_instance(self, instance): segment_ids.append(0) tokens.append("[SEP]") - original_tokens.append('[SEP]') + original_tokens.append("[SEP]") segment_ids.append(0) # for token in tokens_b: @@ -83,11 +79,16 @@ def create_training_instance(self, instance): # segment_ids.append(1) # Get Masked LM predictions - if self.backend == 'c++': + if self.backend == "c++": output_tokens, masked_lm_output = mask.create_whole_masked_lm_predictions( - tokens, original_tokens, self.vocab_words, self.tokenizer.vocab, self.max_predictions_per_seq, - self.masked_lm_prob) - elif self.backend == 'python': + 
tokens, + original_tokens, + self.vocab_words, + self.tokenizer.vocab, + self.max_predictions_per_seq, + self.masked_lm_prob, + ) + elif self.backend == "python": output_tokens, masked_lm_output = self.create_whole_masked_lm_predictions(tokens) # Convert to Ids @@ -99,20 +100,20 @@ def create_training_instance(self, instance): segment_ids.append(PAD) input_mask.append(PAD) masked_lm_output.append(-1) - return ([ + return [ map_to_numpy(input_ids), map_to_numpy(input_mask), map_to_numpy(segment_ids), map_to_numpy(masked_lm_output), - map_to_numpy([is_next]) - ]) + map_to_numpy([is_next]), + ] def create_masked_lm_predictions(self, tokens): cand_indexes = [] for i, token in enumerate(tokens): if token == "[CLS]" or token == "[SEP]": continue - if (self.do_whole_word_mask and len(cand_indexes) >= 1 and token.startswith("##")): + if self.do_whole_word_mask and len(cand_indexes) >= 1 and token.startswith("##"): cand_indexes[-1].append(i) else: cand_indexes.append([i]) @@ -160,7 +161,7 @@ def get_new_segment(self, segment): Input a sentence, return a processed sentence: In order to support the Chinese whole word mask, the words that are separated will be marked with a special mark ("#"), so that the subsequent processing module can know which words belong to the same word. :param segment: a sentence """ - seq_cws = jieba.lcut(''.join(segment)) + seq_cws = jieba.lcut("".join(segment)) seq_cws_dict = {x: 1 for x in seq_cws} new_segment = [] i = 0 @@ -174,10 +175,10 @@ def get_new_segment(self, segment): for length in range(3, 0, -1): if i + length > len(segment): continue - if ''.join(segment[i:i + length]) in seq_cws_dict: + if "".join(segment[i : i + length]) in seq_cws_dict: new_segment.append(segment[i]) for l in range(1, length): - new_segment.append('##' + segment[i + l]) + new_segment.append("##" + segment[i + l]) i += length has_add = True break @@ -190,7 +191,7 @@ def create_whole_masked_lm_predictions(self, tokens): """Creates the predictions for the masked LM objective.""" cand_indexes = [] - for (i, token) in enumerate(tokens): + for i, token in enumerate(tokens): if token == "[CLS]" or token == "[SEP]": continue # Whole Word Masking means that if we mask all of the wordpieces @@ -202,14 +203,14 @@ def create_whole_masked_lm_predictions(self, tokens): # Note that Whole Word Masking does *not* change the training code # at all -- we still predict each WordPiece independently, softmaxed # over the entire vocabulary. 
- if (self.do_whole_word_mask and len(cand_indexes) >= 1 and token.startswith("##")): + if self.do_whole_word_mask and len(cand_indexes) >= 1 and token.startswith("##"): cand_indexes[-1].append(i) else: cand_indexes.append([i]) random.shuffle(cand_indexes) - output_tokens = [t[2:] if len(self.whole_rec.findall(t)) > 0 else t for t in tokens] # 去掉"##" + output_tokens = [t[2:] if len(self.whole_rec.findall(t)) > 0 else t for t in tokens] # 去掉"##" num_to_predict = min(self.max_predictions_per_seq, max(1, int(round(len(tokens) * self.masked_lm_prob)))) @@ -239,8 +240,9 @@ def create_whole_masked_lm_predictions(self, tokens): else: # 10% of the time, keep original if random.random() < 0.5: - masked_token = tokens[index][2:] if len(self.whole_rec.findall( - tokens[index])) > 0 else tokens[index] # 去掉"##" + masked_token = ( + tokens[index][2:] if len(self.whole_rec.findall(tokens[index])) > 0 else tokens[index] + ) # 去掉"##" # 10% of the time, replace with random word else: masked_token = self.vocab_words[random.randint(0, len(self.vocab_words) - 1)] @@ -250,7 +252,9 @@ def create_whole_masked_lm_predictions(self, tokens): masked_lms.append( MaskedLMInstance( index=index, - label=tokens[index][2:] if len(self.whole_rec.findall(tokens[index])) > 0 else tokens[index])) + label=tokens[index][2:] if len(self.whole_rec.findall(tokens[index])) > 0 else tokens[index], + ) + ) assert len(masked_lms) <= num_to_predict masked_lms = sorted(masked_lms, key=lambda x: x.index) masked_lm_output = [-1] * len(output_tokens) diff --git a/examples/community/roberta/preprocessing/sentence_split.py b/examples/community/roberta/preprocessing/sentence_split.py index 76e8bd428723..8c83ce095582 100644 --- a/examples/community/roberta/preprocessing/sentence_split.py +++ b/examples/community/roberta/preprocessing/sentence_split.py @@ -14,17 +14,19 @@ def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[s sent_list = [] try: if flag == "zh": - document = re.sub('(?P<quotation_mark>([。?!…](?![”’"\'])))', r'\g<quotation_mark>\n', document) - document = re.sub('(?P<quotation_mark>([。?!]|…{1,2})[”’"\'])', r'\g<quotation_mark>\n', document) + document = re.sub("(?P<quotation_mark>([。?!…](?![”’\"'])))", r"\g<quotation_mark>\n", document) + document = re.sub("(?P<quotation_mark>([。?!]|…{1,2})[”’\"'])", r"\g<quotation_mark>\n", document) elif flag == "en": - document = re.sub('(?P<quotation_mark>([.?!](?![”’"\'])))', r'\g<quotation_mark>\n', document) - document = re.sub('(?P<quotation_mark>([?!.]["\']))', r'\g<quotation_mark>\n', - document) # Special quotation marks + document = re.sub("(?P<quotation_mark>([.?!](?![”’\"'])))", r"\g<quotation_mark>\n", document) + document = re.sub( + "(?P<quotation_mark>([?!.][\"']))", r"\g<quotation_mark>\n", document + ) # Special quotation marks else: - document = re.sub('(?P<quotation_mark>([。?!….?!](?![”’"\'])))', r'\g<quotation_mark>\n', document) + document = re.sub("(?P<quotation_mark>([。?!….?!](?![”’\"'])))", r"\g<quotation_mark>\n", document) - document = re.sub('(?P<quotation_mark>(([。?!.!?]|…{1,2})[”’"\']))', r'\g<quotation_mark>\n', - document) # Special quotation marks + document = re.sub( + "(?P<quotation_mark>(([。?!.!?]|…{1,2})[”’\"']))", r"\g<quotation_mark>\n", document + ) # Special quotation marks sent_list_ori = document.splitlines() for sent in sent_list_ori: @@ -46,36 +48,35 @@ def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[s def get_sent(output_path, input_path, fin_list=[], host=-1, seq_len=512) -> None: - workers = 32 - if input_path[-1] == '/': + if input_path[-1] == "/": input_path = input_path[:-1] - cur_path = os.path.join(output_path, str(host) + '.txt') + cur_path = os.path.join(output_path, str(host) + ".txt") new_split_sentence = functools.partial(split_sentence, limit=seq_len - 2) - with open(cur_path, 'w', encoding='utf-8') as f: + with open(cur_path, "w", encoding="utf-8") as f:
for fi, fin_path in enumerate(fin_list): if not os.path.exists(os.path.join(input_path, fin_path[0])): continue - if '.json' not in fin_path[0]: + if ".json" not in fin_path[0]: continue print("Processing ", fin_path[0], " ", fi) - with open(os.path.join(input_path, fin_path[0]), 'r') as fin: - f_data = [l['content'] for l in json.load(fin)] + with open(os.path.join(input_path, fin_path[0]), "r") as fin: + f_data = [l["content"] for l in json.load(fin)] pool = multiprocessing.Pool(workers) all_sent = pool.imap_unordered(new_split_sentence, f_data, 32) pool.close() - print('finished..') + print("finished..") cnt = 0 for d in tqdm(all_sent): for i in d: - f.write(i.strip() + '\n') - f.write(']]' + '\n') + f.write(i.strip() + "\n") + f.write("]]" + "\n") cnt += 1 # if cnt >= 2: # exit() @@ -86,7 +87,7 @@ def getFileSize(filepath, shard): for i in os.listdir(filepath): all_data.append(os.path.join(filepath, i)) all_size = sum([os.path.getsize(os.path.join(filepath, f)) for f in all_data]) - ans = [[f.split('/')[-1], os.path.getsize(os.path.join(filepath, f))] for f in all_data] + ans = [[f.split("/")[-1], os.path.getsize(os.path.join(filepath, f))] for f in all_data] ans = sorted(ans, key=lambda x: x[1], reverse=True) per_size = all_size / shard real_shard = [] @@ -106,24 +107,24 @@ def getFileSize(filepath, shard): return real_shard -def get_start_end(real_shard, base=0, server_num=10, server_name='GPU'): +def get_start_end(real_shard, base=0, server_num=10, server_name="GPU"): import socket + host = int(socket.gethostname().split(server_name)[-1]) fin_list = real_shard[server_num * base + host - 1] print(fin_list) - print(f'I am server {host}, process {server_num * base + host - 1}, len {len(fin_list)}') + print(f"I am server {host}, process {server_num * base + host - 1}, len {len(fin_list)}") return fin_list, host -if __name__ == '__main__': - +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--server_num', type=int, default=10, help='number of servers') - parser.add_argument('--seq_len', type=int, default=512, help='sequence length') - parser.add_argument('--shard', type=int, default=100, help='number of shards, e.g., 10, 50, or 100') - parser.add_argument('--input_path', type=str, required=True, help='input path of original corpus') - parser.add_argument('--output_path', type=str, required=True, help='output path of shard which has split sentence') + parser.add_argument("--server_num", type=int, default=10, help="number of servers") + parser.add_argument("--seq_len", type=int, default=512, help="sequence length") + parser.add_argument("--shard", type=int, default=100, help="number of shards, e.g., 10, 50, or 100") + parser.add_argument("--input_path", type=str, required=True, help="input path of original corpus") + parser.add_argument("--output_path", type=str, required=True, help="output path of shard which has split sentence") args = parser.parse_args() server_num = args.server_num @@ -137,7 +138,7 @@ def get_start_end(real_shard, base=0, server_num=10, server_name='GPU'): start = time.time() for index, shard in enumerate(real_shard): get_sent(output_path, input_path, fin_list=shard, host=index, seq_len=seq_len) - print(f'cost {str(time.time() - start)}') + print(f"cost {str(time.time() - start)}") # if you have multiple server, you can use code below or modify code to openmpi diff --git a/examples/community/roberta/preprocessing/tokenize_mask.py b/examples/community/roberta/preprocessing/tokenize_mask.py index f3d49c3d965f..19dbaf5384de 100644 --- 
a/examples/community/roberta/preprocessing/tokenize_mask.py +++ b/examples/community/roberta/preprocessing/tokenize_mask.py @@ -1,7 +1,6 @@ import argparse import multiprocessing import os -import socket import time from random import shuffle @@ -29,8 +28,7 @@ def get_raw_instance(document, max_sequence_length=512): curr_seq = [] sz_idx = 0 while sz_idx < len(sizes): - - if len(curr_seq) + sizes[sz_idx] <= max_sequence_length_allowed: # or len(curr_seq)==0: + if len(curr_seq) + sizes[sz_idx] <= max_sequence_length_allowed: # or len(curr_seq)==0: curr_seq += document[sz_idx] sz_idx += 1 elif sizes[sz_idx] >= max_sequence_length_allowed: @@ -43,7 +41,7 @@ def get_raw_instance(document, max_sequence_length=512): result_list.append(curr_seq) curr_seq = [] - if len(curr_seq) > max_sequence_length_allowed / 2: # /2 + if len(curr_seq) > max_sequence_length_allowed / 2: # /2 result_list.append(curr_seq) # num_instance=int(len(big_list)/max_sequence_length_allowed)+1 @@ -58,33 +56,30 @@ def get_raw_instance(document, max_sequence_length=512): def split_numpy_chunk(path, tokenizer, pretrain_data, host): - documents = [] instances = [] s = time.time() - with open(path, encoding='utf-8') as fd: + with open(path, encoding="utf-8") as fd: document = [] for i, line in enumerate(tqdm(fd)): line = line.strip() # document = line # if len(document.split("")) <= 3: # continue - if len(line) > 0 and line[:2] == "]]": # This is end of document + if len(line) > 0 and line[:2] == "]]": # This is end of document documents.append(document) document = [] elif len(line) >= 2: document.append(line) if len(document) > 0: documents.append(document) - print('read_file ', time.time() - s) + print("read_file ", time.time() - s) # documents = [x for x in documents if x] # print(len(documents)) # print(len(documents[0])) # print(documents[0][0:10]) - import multiprocessing - from typing import List ans = [] for docs in tqdm(documents): @@ -98,7 +93,7 @@ def split_numpy_chunk(path, tokenizer, pretrain_data, host): instances.extend(raw_ins) del ans - print('len instance', len(instances)) + print("len instance", len(instances)) sen_num = len(instances) seq_len = 512 @@ -114,7 +109,7 @@ def split_numpy_chunk(path, tokenizer, pretrain_data, host): segment_ids[index] = mask_dict[2] masked_lm_output[index] = mask_dict[3] - with h5py.File(f'/output/{host}.h5', 'w') as hf: + with h5py.File(f"/output/{host}.h5", "w") as hf: hf.create_dataset("input_ids", data=input_ids) hf.create_dataset("input_mask", data=input_ids) hf.create_dataset("segment_ids", data=segment_ids) @@ -124,45 +119,44 @@ def split_numpy_chunk(path, tokenizer, pretrain_data, host): def split_numpy_chunk_pool(input_path, output_path, pretrain_data, worker, dupe_factor, seq_len, file_name): - - if os.path.exists(os.path.join(output_path, f'{file_name}.h5')): - print(f'{file_name}.h5 exists') + if os.path.exists(os.path.join(output_path, f"{file_name}.h5")): + print(f"{file_name}.h5 exists") return documents = [] instances = [] s = time.time() - with open(input_path, 'r', encoding='utf-8') as fd: + with open(input_path, "r", encoding="utf-8") as fd: document = [] for i, line in enumerate(tqdm(fd)): line = line.strip() - if len(line) > 0 and line[:2] == "]]": # This is end of document + if len(line) > 0 and line[:2] == "]]": # This is end of document documents.append(document) document = [] elif len(line) >= 2: document.append(line) if len(document) > 0: documents.append(document) - print(f'read_file cost {time.time() - s}, length is {len(documents)}') + print(f"read_file 
cost {time.time() - s}, length is {len(documents)}") ans = [] s = time.time() pool = multiprocessing.Pool(worker) encoded_doc = pool.imap_unordered(pretrain_data.tokenize, documents, 100) - for index, res in tqdm(enumerate(encoded_doc, start=1), total=len(documents), colour='cyan'): + for index, res in tqdm(enumerate(encoded_doc, start=1), total=len(documents), colour="cyan"): ans.append(res) pool.close() print((time.time() - s) / 60) del documents instances = [] - for a in tqdm(ans, colour='MAGENTA'): + for a in tqdm(ans, colour="MAGENTA"): raw_ins = get_raw_instance(a, max_sequence_length=seq_len) instances.extend(raw_ins) del ans - print('len instance', len(instances)) + print("len instance", len(instances)) new_instances = [] for _ in range(dupe_factor): @@ -171,7 +165,7 @@ def split_numpy_chunk_pool(input_path, output_path, pretrain_data, worker, dupe_ shuffle(new_instances) instances = new_instances - print('after dupe_factor, len instance', len(instances)) + print("after dupe_factor, len instance", len(instances)) sentence_num = len(instances) input_ids = np.zeros([sentence_num, seq_len], dtype=np.int32) @@ -182,7 +176,7 @@ def split_numpy_chunk_pool(input_path, output_path, pretrain_data, worker, dupe_ s = time.time() pool = multiprocessing.Pool(worker) encoded_docs = pool.imap_unordered(pretrain_data.create_training_instance, instances, 32) - for index, mask_dict in tqdm(enumerate(encoded_docs), total=len(instances), colour='blue'): + for index, mask_dict in tqdm(enumerate(encoded_docs), total=len(instances), colour="blue"): input_ids[index] = mask_dict[0] input_mask[index] = mask_dict[1] segment_ids[index] = mask_dict[2] @@ -190,7 +184,7 @@ def split_numpy_chunk_pool(input_path, output_path, pretrain_data, worker, dupe_ pool.close() print((time.time() - s) / 60) - with h5py.File(os.path.join(output_path, f'{file_name}.h5'), 'w') as hf: + with h5py.File(os.path.join(output_path, f"{file_name}.h5"), "w") as hf: hf.create_dataset("input_ids", data=input_ids) hf.create_dataset("input_mask", data=input_mask) hf.create_dataset("segment_ids", data=segment_ids) @@ -199,50 +193,48 @@ def split_numpy_chunk_pool(input_path, output_path, pretrain_data, worker, dupe_ del instances -if __name__ == '__main__': - +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--tokenizer_path', type=str, required=True, default=10, help='path of tokenizer') - parser.add_argument('--seq_len', type=int, default=512, help='sequence length') - parser.add_argument('--max_predictions_per_seq', - type=int, - default=80, - help='number of shards, e.g., 10, 50, or 100') - parser.add_argument('--input_path', type=str, required=True, help='input path of shard which has split sentence') - parser.add_argument('--output_path', type=str, required=True, help='output path of h5 contains token id') - parser.add_argument('--backend', - type=str, - default='python', - help='backend of mask token, python, c++, numpy respectively') + parser.add_argument("--tokenizer_path", type=str, required=True, default=10, help="path of tokenizer") + parser.add_argument("--seq_len", type=int, default=512, help="sequence length") + parser.add_argument( + "--max_predictions_per_seq", type=int, default=80, help="number of shards, e.g., 10, 50, or 100" + ) + parser.add_argument("--input_path", type=str, required=True, help="input path of shard which has split sentence") + parser.add_argument("--output_path", type=str, required=True, help="output path of h5 contains token id") + parser.add_argument( + "--backend", 
type=str, default="python", help="backend of mask token, python, c++, numpy respectively" + ) parser.add_argument( - '--dupe_factor', + "--dupe_factor", type=int, default=1, - help='specifies how many times the preprocessor repeats to create the input from the same article/document') - parser.add_argument('--worker', type=int, default=32, help='number of process') - parser.add_argument('--server_num', type=int, default=10, help='number of servers') + help="specifies how many times the preprocessor repeats to create the input from the same article/document", + ) + parser.add_argument("--worker", type=int, default=32, help="number of process") + parser.add_argument("--server_num", type=int, default=10, help="number of servers") args = parser.parse_args() tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) - pretrain_data = PreTrainingDataset(tokenizer, - args.seq_len, - args.backend, - max_predictions_per_seq=args.max_predictions_per_seq) + pretrain_data = PreTrainingDataset( + tokenizer, args.seq_len, args.backend, max_predictions_per_seq=args.max_predictions_per_seq + ) data_len = len(os.listdir(args.input_path)) for i in range(data_len): - input_path = os.path.join(args.input_path, f'{i}.txt') + input_path = os.path.join(args.input_path, f"{i}.txt") if os.path.exists(input_path): start = time.time() - print(f'process {input_path}') - split_numpy_chunk_pool(input_path, args.output_path, pretrain_data, args.worker, args.dupe_factor, - args.seq_len, i) + print(f"process {input_path}") + split_numpy_chunk_pool( + input_path, args.output_path, pretrain_data, args.worker, args.dupe_factor, args.seq_len, i + ) end_ = time.time() - print(u'memory:%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024)) - print(f'has cost {(end_ - start) / 60}') - print('-' * 100) - print('') + print("memory:%.4f GB" % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024)) + print(f"has cost {(end_ - start) / 60}") + print("-" * 100) + print("") # if you have multiple server, you can use code below or modify code to openmpi diff --git a/examples/community/roberta/pretraining/arguments.py b/examples/community/roberta/pretraining/arguments.py index e0702ceb59b0..35b809d80947 100644 --- a/examples/community/roberta/pretraining/arguments.py +++ b/examples/community/roberta/pretraining/arguments.py @@ -1,8 +1,6 @@ -from numpy import require - import colossalai -__all__ = ['parse_args'] +__all__ = ["parse_args"] def parse_args(): @@ -11,7 +9,7 @@ def parse_args(): parser.add_argument( "--distplan", type=str, - default='CAI_Gemini', + default="CAI_Gemini", help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].", ) parser.add_argument( @@ -23,65 +21,66 @@ def parse_args(): parser.add_argument( "--placement", type=str, - default='cpu', + default="cpu", help="Placement Policy for Gemini. Valid when using colossalai as dist plan.", ) parser.add_argument( "--shardinit", - action='store_true', - help= - "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.", + action="store_true", + help="Shard the tensors when init the model to shrink peak memory size on the assigned device. 
Valid when using colossalai as dist plan.", ) - parser.add_argument('--lr', type=float, required=True, help='initial learning rate') - parser.add_argument('--epoch', type=int, required=True, help='number of epoch') - parser.add_argument('--data_path_prefix', type=str, required=True, help="location of the train data corpus") - parser.add_argument('--eval_data_path_prefix', - type=str, - required=True, - help='location of the evaluation data corpus') - parser.add_argument('--tokenizer_path', type=str, required=True, help='location of the tokenizer') - parser.add_argument('--max_seq_length', type=int, default=512, help='sequence length') - parser.add_argument('--refresh_bucket_size', - type=int, - default=1, - help="This param makes sure that a certain task is repeated for this time steps to \ - optimize on the back propagation speed with APEX's DistributedDataParallel") - parser.add_argument("--max_predictions_per_seq", - "--max_pred", - default=80, - type=int, - help="The maximum number of masked tokens in a sequence to be predicted.") + parser.add_argument("--lr", type=float, required=True, help="initial learning rate") + parser.add_argument("--epoch", type=int, required=True, help="number of epoch") + parser.add_argument("--data_path_prefix", type=str, required=True, help="location of the train data corpus") + parser.add_argument( + "--eval_data_path_prefix", type=str, required=True, help="location of the evaluation data corpus" + ) + parser.add_argument("--tokenizer_path", type=str, required=True, help="location of the tokenizer") + parser.add_argument("--max_seq_length", type=int, default=512, help="sequence length") + parser.add_argument( + "--refresh_bucket_size", + type=int, + default=1, + help="This param makes sure that a certain task is repeated for this time steps to \ + optimize on the back propagation speed with APEX's DistributedDataParallel", + ) + parser.add_argument( + "--max_predictions_per_seq", + "--max_pred", + default=80, + type=int, + help="The maximum number of masked tokens in a sequence to be predicted.", + ) parser.add_argument("--gradient_accumulation_steps", default=1, type=int, help="accumulation_steps") parser.add_argument("--train_micro_batch_size_per_gpu", default=2, type=int, required=True, help="train batch size") parser.add_argument("--eval_micro_batch_size_per_gpu", default=2, type=int, required=True, help="eval batch size") parser.add_argument("--num_workers", default=8, type=int, help="") - parser.add_argument("--async_worker", action='store_true', help="") + parser.add_argument("--async_worker", action="store_true", help="") parser.add_argument("--bert_config", required=True, type=str, help="location of config.json") - parser.add_argument("--wandb", action='store_true', help="use wandb to watch model") - parser.add_argument("--wandb_project_name", default='roberta', help="wandb project name") + parser.add_argument("--wandb", action="store_true", help="use wandb to watch model") + parser.add_argument("--wandb_project_name", default="roberta", help="wandb project name") parser.add_argument("--log_interval", default=100, type=int, help="report interval") parser.add_argument("--log_path", type=str, required=True, help="log file which records train step") parser.add_argument("--tensorboard_path", type=str, required=True, help="location of tensorboard file") - parser.add_argument("--colossal_config", - type=str, - required=True, - help="colossal config, which contains zero config and so on") - parser.add_argument("--ckpt_path", - type=str, - required=True, 
- help="location of saving checkpoint, which contains model and optimizer") - parser.add_argument('--seed', type=int, default=42, help="random seed for initialization") - parser.add_argument('--vscode_debug', action='store_true', help="use vscode to debug") - parser.add_argument('--load_pretrain_model', default='', type=str, help="location of model's checkpoint") parser.add_argument( - '--load_optimizer_lr', - default='', + "--colossal_config", type=str, required=True, help="colossal config, which contains zero config and so on" + ) + parser.add_argument( + "--ckpt_path", type=str, required=True, help="location of saving checkpoint, which contains model and optimizer" + ) + parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") + parser.add_argument("--vscode_debug", action="store_true", help="use vscode to debug") + parser.add_argument("--load_pretrain_model", default="", type=str, help="location of model's checkpoint") + parser.add_argument( + "--load_optimizer_lr", + default="", type=str, - help="location of checkpoint, which contains optimizer, learning rate, epoch, shard and global_step") - parser.add_argument('--resume_train', action='store_true', help="whether resume training from a early checkpoint") - parser.add_argument('--mlm', default='bert', type=str, help="model type, bert or deberta") - parser.add_argument('--checkpoint_activations', action='store_true', help="whether to use gradient checkpointing") + help="location of checkpoint, which contains optimizer, learning rate, epoch, shard and global_step", + ) + parser.add_argument("--resume_train", action="store_true", help="whether resume training from a early checkpoint") + parser.add_argument("--mlm", default="bert", type=str, help="model type, bert or deberta") + parser.add_argument("--checkpoint_activations", action="store_true", help="whether to use gradient checkpointing") args = parser.parse_args() return args diff --git a/examples/community/roberta/pretraining/bert_dataset_provider.py b/examples/community/roberta/pretraining/bert_dataset_provider.py index eaf165ed18f4..1d8cf2a910e9 100644 --- a/examples/community/roberta/pretraining/bert_dataset_provider.py +++ b/examples/community/roberta/pretraining/bert_dataset_provider.py @@ -1,5 +1,4 @@ class BertDatasetProviderInterface: - def get_shard(self, index, shuffle=True): raise NotImplementedError diff --git a/examples/community/roberta/pretraining/evaluation.py b/examples/community/roberta/pretraining/evaluation.py index 009242cd1cf5..e1bce48023c3 100644 --- a/examples/community/roberta/pretraining/evaluation.py +++ b/examples/community/roberta/pretraining/evaluation.py @@ -19,23 +19,27 @@ def evaluate(model, args, logger, global_step, criterion): world_size = torch.distributed.get_world_size() with torch.no_grad(): - for shard in range(start_shard, len(os.listdir(args.eval_data_path_prefix))): - - timers('eval_shard_time').start() + timers("eval_shard_time").start() dataset_iterator, total_length = evaluate_dataset_provider.get_shard(shard) # evaluate_dataset_provider.prefetch_shard(shard + 1) if torch.distributed.get_rank() == 0: - iterator_data = tqdm(enumerate(dataset_iterator), - total=(total_length // args.eval_micro_batch_size_per_gpu // world_size), - colour='MAGENTA', - smoothing=1) + iterator_data = tqdm( + enumerate(dataset_iterator), + total=(total_length // args.eval_micro_batch_size_per_gpu // world_size), + colour="MAGENTA", + smoothing=1, + ) else: iterator_data = enumerate(dataset_iterator) - for step, batch_data in 
iterator_data: #tqdm(enumerate(dataset_iterator), total=(total_length // args.train_micro_batch_size_per_gpu // world_size), colour='cyan', smoothing=1): - + for ( + step, + batch_data, + ) in ( + iterator_data + ): # tqdm(enumerate(dataset_iterator), total=(total_length // args.train_micro_batch_size_per_gpu // world_size), colour='cyan', smoothing=1): # batch_data = pretrain_dataset_provider.get_batch(batch_index) eval_step += 1 input_ids = batch_data[0].cuda() @@ -46,7 +50,7 @@ def evaluate(model, args, logger, global_step, criterion): output = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) - loss = criterion(output.logits, mlm_label) #prediction_scores + loss = criterion(output.logits, mlm_label) # prediction_scores evaluate_dataset_provider.prefetch_batch() eval_loss += loss.float().item() @@ -58,18 +62,18 @@ def evaluate(model, args, logger, global_step, criterion): if args.wandb and torch.distributed.get_rank() == 0: tensorboard_log = get_tensorboard_writer() - tensorboard_log.log_eval({ - 'loss': cur_loss, - 'ppl': ppl, - 'mins_batch': elapsed_time_per_iteration - }, global_step) + tensorboard_log.log_eval( + {"loss": cur_loss, "ppl": ppl, "mins_batch": elapsed_time_per_iteration}, global_step + ) - eval_log_str = f'evaluation shard: {shard} | step: {eval_step} | elapsed_time: {elapsed_time / 60 :.3f} minutes ' + \ - f'| mins/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {ppl:.7f}' + eval_log_str = ( + f"evaluation shard: {shard} | step: {eval_step} | elapsed_time: {elapsed_time / 60 :.3f} minutes " + + f"| mins/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {ppl:.7f}" + ) logger.info(eval_log_str) - logger.info('-' * 100) - logger.info('') + logger.info("-" * 100) + logger.info("") evaluate_dataset_provider.release_shard() model.train() diff --git a/examples/community/roberta/pretraining/loss.py b/examples/community/roberta/pretraining/loss.py index 989c2bd5c450..636246292809 100644 --- a/examples/community/roberta/pretraining/loss.py +++ b/examples/community/roberta/pretraining/loss.py @@ -1,10 +1,9 @@ import torch -__all__ = ['LossForPretraining'] +__all__ = ["LossForPretraining"] class LossForPretraining(torch.nn.Module): - def __init__(self, vocab_size): super(LossForPretraining, self).__init__() self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-1) @@ -13,5 +12,5 @@ def __init__(self, vocab_size): def forward(self, prediction_scores, masked_lm_labels, next_sentence_labels=None): masked_lm_loss = self.loss_fn(prediction_scores.view(-1, self.vocab_size), masked_lm_labels.view(-1)) # next_sentence_loss = self.loss_fn(seq_relationship_score.view(-1, 2), next_sentence_labels.view(-1)) - total_loss = masked_lm_loss #+ next_sentence_loss + total_loss = masked_lm_loss # + next_sentence_loss return total_loss diff --git a/examples/community/roberta/pretraining/model/bert.py b/examples/community/roberta/pretraining/model/bert.py index abdf925d0540..31e3d7075a0c 100644 --- a/examples/community/roberta/pretraining/model/bert.py +++ b/examples/community/roberta/pretraining/model/bert.py @@ -59,7 +59,8 @@ # TokenClassification docstring _CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "dbmdz/bert-large-cased-finetuned-conll03-english" _TOKEN_CLASS_EXPECTED_OUTPUT = ( - "['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] ") + "['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] " +) _TOKEN_CLASS_EXPECTED_LOSS = 0.01 
# QuestionAnswering docstring @@ -109,8 +110,10 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path): import numpy as np import tensorflow as tf except ImportError: - logger.error("Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions.") + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) raise tf_path = os.path.abspath(tf_checkpoint_path) logger.info(f"Converting TensorFlow checkpoint from {tf_path}") @@ -128,8 +131,10 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path): name = name.split("/") # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v # which are not required for using pretrained model - if any(n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] - for n in name): + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): logger.info(f"Skipping {'/'.join(name)}") continue pointer = model @@ -209,7 +214,7 @@ def forward( seq_length = input_shape[1] if position_ids is None: - position_ids = self.position_ids[:, past_key_values_length:seq_length + past_key_values_length] + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves @@ -236,12 +241,13 @@ def forward( class BertSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): - raise ValueError(f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " - f"heads ({config.num_attention_heads})") + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) @@ -320,7 +326,7 @@ def forward( position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) distance = position_ids_l - position_ids_r positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility if self.position_embedding_type == "relative_key": relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) @@ -360,7 +366,6 @@ def forward( class BertSelfOutput(nn.Module): - def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) @@ -375,7 +380,6 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class BertAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): super().__init__() self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type) @@ -385,8 +389,9 @@ def __init__(self, 
config, position_embedding_type=None): def prune_heads(self, heads): if len(heads) == 0: return - heads, index = find_pruneable_heads_and_indices(heads, self.self.num_attention_heads, - self.self.attention_head_size, self.pruned_heads) + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) # Prune linear layers self.self.query = prune_linear_layer(self.self.query, index) @@ -419,12 +424,11 @@ def forward( output_attentions, ) attention_output = self.output(self_outputs[0], hidden_states) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them return outputs class BertIntermediate(nn.Module): - def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) @@ -440,7 +444,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class BertOutput(nn.Module): - def __init__(self, config): super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) @@ -455,7 +458,6 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class BertLayer(nn.Module): - def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -496,14 +498,15 @@ def forward( outputs = self_attention_outputs[1:-1] present_key_value = self_attention_outputs[-1] else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" - " by setting `config.add_cross_attention=True`") + " by setting `config.add_cross_attention=True`" + ) # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None @@ -517,14 +520,15 @@ def forward( output_attentions, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights # add cross-attn cache to positions 3,4 of present_key_value tuple cross_attn_present_key_value = cross_attention_outputs[-1] present_key_value = present_key_value + cross_attn_present_key_value - layer_output = apply_chunking_to_forward(self.feed_forward_chunk, self.chunk_size_feed_forward, - self.seq_len_dim, attention_output) + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) outputs = (layer_output,) + outputs # if decoder, return the attn key/values as the last output @@ -540,7 +544,6 @@ def feed_forward_chunk(self, attention_output): class BertEncoder(nn.Module): - def __init__(self, config): super().__init__() self.config = config @@ -573,14 +576,13 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - if use_cache: logger.warning( - "`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) use_cache = False def create_custom_forward(module): - def custom_forward(*inputs): return module(*inputs, past_key_value, output_attentions) @@ -617,13 +619,17 @@ def custom_forward(*inputs): all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: - return tuple(v for v in [ - hidden_states, - next_decoder_cache, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] if v is not None) + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=next_decoder_cache, @@ -634,7 +640,6 @@ def custom_forward(*inputs): class BertPooler(nn.Module): - def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) @@ -650,7 +655,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class BertPredictionHeadTransform(nn.Module): - def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) @@ -668,7 +672,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class BertLMPredictionHead(nn.Module): - def __init__(self, config): super().__init__() self.transform = BertPredictionHeadTransform(config) @@ -689,7 +692,6 @@ def forward(self, hidden_states): class BertOnlyMLMHead(nn.Module): - def __init__(self, config): super().__init__() self.predictions = BertLMPredictionHead(config) @@ -700,7 +702,6 @@ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: class BertOnlyNSPHead(nn.Module): - def __init__(self, config): super().__init__() self.seq_relationship = nn.Linear(config.hidden_size, 2) @@ -711,7 +712,6 @@ def forward(self, pooled_output): class BertPreTrainingHeads(nn.Module): - def __init__(self, config): super().__init__() self.predictions = BertLMPredictionHead(config) @@ -943,8 +943,9 @@ def forward( `past_key_values`). 
""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if self.config.is_decoder: @@ -1043,7 +1044,6 @@ def forward( BERT_START_DOCSTRING, ) class BertForPreTraining(BertPreTrainedModel): - def __init__(self, config): super().__init__(config) @@ -1144,10 +1144,10 @@ def forward( ) -@add_start_docstrings("""Bert Model with a `language modeling` head on top for CLM fine-tuning.""", - BERT_START_DOCSTRING) +@add_start_docstrings( + """Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING +) class BertLMHeadModel(BertPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [r"pooler"] _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] @@ -1282,7 +1282,6 @@ def _reorder_cache(self, past, beam_idx): @add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING) class BertForMaskedLM(BertPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [r"pooler"] _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] @@ -1290,8 +1289,10 @@ def __init__(self, config): super().__init__(config) if config.is_decoder: - logger.warning("If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " - "bi-directional self-attention.") + logger.warning( + "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." 
+ ) self.bert = BertModel(config, add_pooling_layer=False) self.cls = BertOnlyMLMHead(config) @@ -1357,7 +1358,7 @@ def forward( masked_lm_loss = None if labels is not None: - loss_fct = CrossEntropyLoss() # -100 index = padding token + loss_fct = CrossEntropyLoss() # -100 index = padding token masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) if not return_dict: @@ -1380,10 +1381,9 @@ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_ raise ValueError("The PAD token should be defined for generation") attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1) - dummy_token = torch.full((effective_batch_size, 1), - self.config.pad_token_id, - dtype=torch.long, - device=input_ids.device) + dummy_token = torch.full( + (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device + ) input_ids = torch.cat([input_ids, dummy_token], dim=1) return {"input_ids": input_ids, "attention_mask": attention_mask} @@ -1394,7 +1394,6 @@ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_ BERT_START_DOCSTRING, ) class BertForNextSentencePrediction(BertPreTrainedModel): - def __init__(self, config): super().__init__(config) @@ -1500,15 +1499,15 @@ def forward( BERT_START_DOCSTRING, ) class BertForSequenceClassification(BertPreTrainedModel): - def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.config = config self.bert = BertModel(config) - classifier_dropout = (config.classifier_dropout - if config.classifier_dropout is not None else config.hidden_dropout_prob) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) self.dropout = nn.Dropout(classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) @@ -1604,13 +1603,13 @@ def forward( BERT_START_DOCSTRING, ) class BertForMultipleChoice(BertPreTrainedModel): - def __init__(self, config): super().__init__(config) self.bert = BertModel(config) - classifier_dropout = (config.classifier_dropout - if config.classifier_dropout is not None else config.hidden_dropout_prob) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) self.dropout = nn.Dropout(classifier_dropout) self.classifier = nn.Linear(config.hidden_size, 1) @@ -1650,8 +1649,11 @@ def forward( attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None - inputs_embeds = (inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) - if inputs_embeds is not None else None) + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) outputs = self.bert( input_ids, @@ -1696,7 +1698,6 @@ def forward( BERT_START_DOCSTRING, ) class BertForTokenClassification(BertPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [r"pooler"] def __init__(self, config): @@ -1704,8 +1705,9 @@ def __init__(self, config): self.num_labels = config.num_labels self.bert = BertModel(config, add_pooling_layer=False) - classifier_dropout = (config.classifier_dropout - if config.classifier_dropout 
is not None else config.hidden_dropout_prob) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) self.dropout = nn.Dropout(classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) @@ -1782,7 +1784,6 @@ def forward( BERT_START_DOCSTRING, ) class BertForQuestionAnswering(BertPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [r"pooler"] def __init__(self, config): diff --git a/examples/community/roberta/pretraining/model/deberta_v2.py b/examples/community/roberta/pretraining/model/deberta_v2.py index 5fc284911e38..c7457942e164 100644 --- a/examples/community/roberta/pretraining/model/deberta_v2.py +++ b/examples/community/roberta/pretraining/model/deberta_v2.py @@ -23,7 +23,6 @@ import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss -from transformers import FillMaskPipeline, T5ForConditionalGeneration, T5Tokenizer from transformers.activations import ACT2FN from transformers.modeling_outputs import ( BaseModelOutput, @@ -59,7 +58,6 @@ # Copied from transformers.models.deberta.modeling_deberta.ContextPooler class ContextPooler(nn.Module): - def __init__(self, config): super().__init__() self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size) @@ -138,15 +136,15 @@ def symbolic(g, self, mask, dim): g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value), to_i=sym_help.cast_pytorch_to_onnx["Byte"], ) - output = masked_fill(g, self, r_mask, - g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))) + output = masked_fill( + g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min)) + ) output = softmax(g, output, dim) return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8))) # Copied from transformers.models.deberta.modeling_deberta.DropoutContext class DropoutContext(object): - def __init__(self): self.dropout = 0 self.mask = None @@ -249,7 +247,6 @@ def get_context(self): # Copied from transformers.models.deberta.modeling_deberta.DebertaSelfOutput with DebertaLayerNorm->LayerNorm class DebertaV2SelfOutput(nn.Module): - def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) @@ -265,7 +262,6 @@ def forward(self, hidden_states, input_tensor): # Copied from transformers.models.deberta.modeling_deberta.DebertaAttention with Deberta->DebertaV2 class DebertaV2Attention(nn.Module): - def __init__(self, config): super().__init__() self.self = DisentangledSelfAttention(config) @@ -303,7 +299,6 @@ def forward( # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->DebertaV2 class DebertaV2Intermediate(nn.Module): - def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) @@ -320,7 +315,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Copied from transformers.models.deberta.modeling_deberta.DebertaOutput with DebertaLayerNorm->LayerNorm class DebertaV2Output(nn.Module): - def __init__(self, config): super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) @@ -337,7 +331,6 @@ def forward(self, hidden_states, input_tensor): # Copied from transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->DebertaV2 class DebertaV2Layer(nn.Module): - def __init__(self, config): 
super().__init__() self.attention = DebertaV2Attention(config) @@ -372,17 +365,14 @@ def forward( class ConvLayer(nn.Module): - def __init__(self, config): super().__init__() kernel_size = getattr(config, "conv_kernel_size", 3) groups = getattr(config, "conv_groups", 1) self.conv_act = getattr(config, "conv_act", "tanh") - self.conv = nn.Conv1d(config.hidden_size, - config.hidden_size, - kernel_size, - padding=(kernel_size - 1) // 2, - groups=groups) + self.conv = nn.Conv1d( + config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups + ) self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) self.dropout = StableDropout(config.hidden_dropout_prob) self.config = config @@ -465,10 +455,9 @@ def get_attention_mask(self, attention_mask): def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None): if self.relative_attention and relative_pos is None: q = query_states.size(-2) if query_states is not None else hidden_states.size(-2) - relative_pos = build_relative_position(q, - hidden_states.size(-2), - bucket_size=self.position_buckets, - max_position=self.max_relative_positions) + relative_pos = build_relative_position( + q, hidden_states.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions + ) return relative_pos def forward( @@ -498,14 +487,12 @@ def forward( rel_embeddings = self.get_rel_embedding() output_states = next_kv for i, layer_module in enumerate(self.layer): - if output_hidden_states: all_hidden_states = all_hidden_states + (output_states,) if self.gradient_checkpointing and self.training: def create_custom_forward(module): - def custom_forward(*inputs): return module(*inputs, output_attentions) @@ -550,9 +537,9 @@ def custom_forward(*inputs): if not return_dict: return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None) - return BaseModelOutput(last_hidden_state=output_states, - hidden_states=all_hidden_states, - attentions=all_attentions) + return BaseModelOutput( + last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions + ) def make_log_bucket_position(relative_pos, bucket_size, max_position): @@ -625,8 +612,10 @@ class DisentangledSelfAttention(nn.Module): def __init__(self, config): super().__init__() if config.hidden_size % config.num_attention_heads != 0: - raise ValueError(f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " - f"heads ({config.num_attention_heads})") + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) self.num_attention_heads = config.num_attention_heads _attention_head_size = config.hidden_size // config.num_attention_heads self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size) @@ -719,22 +708,28 @@ def forward( attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale if self.relative_attention: rel_embeddings = self.pos_dropout(rel_embeddings) - rel_att = self.disentangled_attention_bias(query_layer, key_layer, relative_pos, rel_embeddings, - scale_factor) + rel_att = self.disentangled_attention_bias( + query_layer, key_layer, relative_pos, rel_embeddings, scale_factor + ) if rel_att is not None: attention_scores = attention_scores + rel_att attention_scores = attention_scores - attention_scores = attention_scores.view(-1, self.num_attention_heads, attention_scores.size(-2), - 
attention_scores.size(-1)) + attention_scores = attention_scores.view( + -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1) + ) # bsz x height x length x dimension attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1) attention_probs = self.dropout(attention_probs) - context_layer = torch.bmm(attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), - value_layer) - context_layer = (context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), - context_layer.size(-1)).permute(0, 2, 1, 3).contiguous()) + context_layer = torch.bmm( + attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer + ) + context_layer = ( + context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1)) + .permute(0, 2, 1, 3) + .contiguous() + ) new_context_layer_shape = context_layer.size()[:-2] + (-1,) context_layer = context_layer.view(new_context_layer_shape) if output_attentions: @@ -745,10 +740,9 @@ def forward( def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor): if relative_pos is None: q = query_layer.size(-2) - relative_pos = build_relative_position(q, - key_layer.size(-2), - bucket_size=self.position_buckets, - max_position=self.max_relative_positions) + relative_pos = build_relative_position( + q, key_layer.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions + ) if relative_pos.dim() == 2: relative_pos = relative_pos.unsqueeze(0).unsqueeze(0) elif relative_pos.dim() == 3: @@ -766,22 +760,25 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_ # rel_embeddings = rel_embeddings.unsqueeze(0) # rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0) if self.share_att_key: - pos_query_layer = self.transpose_for_scores(self.query_proj(rel_embeddings), - self.num_attention_heads).repeat( - query_layer.size(0) // self.num_attention_heads, 1, 1) + pos_query_layer = self.transpose_for_scores( + self.query_proj(rel_embeddings), self.num_attention_heads + ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1) pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat( - query_layer.size(0) // self.num_attention_heads, 1, 1) + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) else: if "c2p" in self.pos_att_type: - pos_key_layer = self.transpose_for_scores(self.pos_key_proj(rel_embeddings), - self.num_attention_heads).repeat( - query_layer.size(0) // self.num_attention_heads, 1, - 1) # .split(self.all_head_size, dim=-1) + pos_key_layer = self.transpose_for_scores( + self.pos_key_proj(rel_embeddings), self.num_attention_heads + ).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) # .split(self.all_head_size, dim=-1) if "p2c" in self.pos_att_type: - pos_query_layer = self.transpose_for_scores(self.pos_query_proj(rel_embeddings), - self.num_attention_heads).repeat( - query_layer.size(0) // self.num_attention_heads, 1, - 1) # .split(self.all_head_size, dim=-1) + pos_query_layer = self.transpose_for_scores( + self.pos_query_proj(rel_embeddings), self.num_attention_heads + ).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) # .split(self.all_head_size, dim=-1) score = 0 # content->position @@ -792,9 +789,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_ c2p_att = torch.gather( c2p_att, dim=-1, - 
index=c2p_pos.squeeze(0).expand([query_layer.size(0), - query_layer.size(1), - relative_pos.size(-1)]), + index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]), ) score += c2p_att / scale @@ -817,9 +812,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_ p2c_att = torch.gather( p2c_att, dim=-1, - index=p2c_pos.squeeze(0).expand([query_layer.size(0), - key_layer.size(-2), - key_layer.size(-2)]), + index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]), ).transpose(-1, -2) score += p2c_att / scale @@ -999,7 +992,6 @@ def _set_gradient_checkpointing(self, module, value=False): ) # Copied from transformers.models.deberta.modeling_deberta.DebertaModel with Deberta->DebertaV2 class DebertaV2Model(DebertaV2PreTrainedModel): - def __init__(self, config): super().__init__(config) @@ -1042,8 +1034,9 @@ def forward( return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutput]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_ids is not None and inputs_embeds is not None: @@ -1100,7 +1093,7 @@ def forward( sequence_output = encoded_layers[-1] if not return_dict: - return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2):] + return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2) :] return BaseModelOutput( last_hidden_state=sequence_output, @@ -1174,7 +1167,7 @@ def forward( masked_lm_loss = None if labels is not None: - loss_fct = CrossEntropyLoss() # -100 index = padding token + loss_fct = CrossEntropyLoss() # -100 index = padding token masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) if not return_dict: @@ -1191,7 +1184,6 @@ def forward( # copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta class DebertaV2PredictionHeadTransform(nn.Module): - def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) @@ -1210,7 +1202,6 @@ def forward(self, hidden_states): # copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta class DebertaV2LMPredictionHead(nn.Module): - def __init__(self, config): super().__init__() self.transform = DebertaV2PredictionHeadTransform(config) @@ -1232,7 +1223,6 @@ def forward(self, hidden_states): # copied from transformers.models.bert.BertOnlyMLMHead with bert -> deberta class DebertaV2OnlyMLMHead(nn.Module): - def __init__(self, config): super().__init__() self.predictions = DebertaV2LMPredictionHead(config) @@ -1251,7 +1241,6 @@ def forward(self, sequence_output): ) # Copied from transformers.models.deberta.modeling_deberta.DebertaForSequenceClassification with Deberta->DebertaV2 class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel): - def __init__(self, config): super().__init__(config) @@ -1331,8 +1320,9 @@ def forward( label_index = (labels >= 0).nonzero() labels = labels.long() if label_index.size(0) > 0: - labeled_logits = torch.gather(logits, 0, label_index.expand(label_index.size(0), - logits.size(1))) + labeled_logits = torch.gather( + 
logits, 0, label_index.expand(label_index.size(0), logits.size(1)) + ) labels = torch.gather(labels, 0, label_index.view(-1)) loss_fct = CrossEntropyLoss() loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1)) @@ -1357,10 +1347,9 @@ def forward( output = (logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output - return SequenceClassifierOutput(loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions) + return SequenceClassifierOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) @add_start_docstrings( @@ -1435,10 +1424,9 @@ def forward( output = (logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output - return TokenClassifierOutput(loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions) + return TokenClassifierOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) @add_start_docstrings( @@ -1550,7 +1538,6 @@ def forward( DEBERTA_START_DOCSTRING, ) class DebertaV2ForMultipleChoice(DebertaV2PreTrainedModel): - def __init__(self, config): super().__init__(config) @@ -1606,8 +1593,11 @@ def forward( flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None - flat_inputs_embeds = (inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) - if inputs_embeds is not None else None) + flat_inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) outputs = self.deberta( flat_input_ids, diff --git a/examples/community/roberta/pretraining/nvidia_bert_dataset_provider.py b/examples/community/roberta/pretraining/nvidia_bert_dataset_provider.py index 72c7bd852a40..09677a6195cb 100644 --- a/examples/community/roberta/pretraining/nvidia_bert_dataset_provider.py +++ b/examples/community/roberta/pretraining/nvidia_bert_dataset_provider.py @@ -1,5 +1,3 @@ -import json -import logging import os import random import time @@ -12,14 +10,10 @@ from bert_dataset_provider import BertDatasetProviderInterface from torch.utils.data import DataLoader, Dataset from torch.utils.data.distributed import DistributedSampler -from torch.utils.data.sampler import RandomSampler - -import colossalai.utils as utils # Workaround because python functions are not picklable class WorkerInitObj(object): - def __init__(self, seed): self.seed = seed @@ -28,44 +22,46 @@ def __call__(self, id): random.seed(self.seed + id) -def create_pretraining_dataset(input_file, max_predictions_per_seq, num_workers, train_batch_size, worker_init, - data_sampler): +def create_pretraining_dataset( + input_file, max_predictions_per_seq, num_workers, train_batch_size, worker_init, data_sampler +): train_data = pretraining_dataset(input_file=input_file, max_predictions_per_seq=max_predictions_per_seq) - train_dataloader = DataLoader(train_data, - sampler=data_sampler(train_data), - batch_size=train_batch_size, - num_workers=num_workers, - worker_init_fn=worker_init, - pin_memory=True) + train_dataloader = DataLoader( + train_data, + sampler=data_sampler(train_data), + batch_size=train_batch_size, + num_workers=num_workers, + 
worker_init_fn=worker_init, + pin_memory=True, + ) return train_dataloader, len(train_data) class pretraining_dataset(Dataset): - def __init__(self, input_file, max_predictions_per_seq): self.input_file = input_file self.max_predictions_per_seq = max_predictions_per_seq f = h5py.File(input_file, "r") - keys = ['input_ids', 'input_mask', 'segment_ids', 'masked_lm_positions'] + keys = ["input_ids", "input_mask", "segment_ids", "masked_lm_positions"] self.inputs = [np.asarray(f[key][:]) for key in keys] f.close() def __len__(self): - 'Denotes the total number of samples' + "Denotes the total number of samples" return len(self.inputs[0]) def __getitem__(self, index): - [input_ids, input_mask, segment_ids, masked_lm_labels] = [ - torch.from_numpy(input[index].astype(np.int64)) if indice < 5 else torch.from_numpy( - np.asarray(input[index].astype(np.int64))) for indice, input in enumerate(self.inputs) + torch.from_numpy(input[index].astype(np.int64)) + if indice < 5 + else torch.from_numpy(np.asarray(input[index].astype(np.int64))) + for indice, input in enumerate(self.inputs) ] return [input_ids, input_mask, segment_ids, masked_lm_labels] class NvidiaBertDatasetProvider(BertDatasetProviderInterface): - def __init__(self, args, evaluate=False): self.num_workers = args.num_workers self.max_seq_length = args.max_seq_length @@ -86,13 +82,13 @@ def __init__(self, args, evaluate=False): self.dataset_files = [ os.path.join(args.data_path_prefix, f) for f in os.listdir(args.data_path_prefix) - if os.path.isfile(os.path.join(args.data_path_prefix, f)) and 'h5' in f + if os.path.isfile(os.path.join(args.data_path_prefix, f)) and "h5" in f ] else: self.dataset_files = [ os.path.join(args.eval_data_path_prefix, f) for f in os.listdir(args.eval_data_path_prefix) - if os.path.isfile(os.path.join(args.eval_data_path_prefix, f)) and 'h5' in f + if os.path.isfile(os.path.join(args.eval_data_path_prefix, f)) and "h5" in f ] self.dataset_files.sort() @@ -120,7 +116,8 @@ def get_shard(self, index): num_workers=self.num_workers, train_batch_size=self.train_micro_batch_size_per_gpu, worker_init=self.worker_init, - data_sampler=self.data_sampler) + data_sampler=self.data_sampler, + ) else: self.train_dataloader, sample_count = self.dataset_future.result(timeout=None) @@ -136,9 +133,15 @@ def release_shard(self): def prefetch_shard(self, index): self.data_file = self._get_shard_file(index) - self.dataset_future = self.pool.submit(create_pretraining_dataset, self.data_file, self.max_predictions_per_seq, - self.num_workers, self.train_micro_batch_size_per_gpu, self.worker_init, - self.data_sampler) + self.dataset_future = self.pool.submit( + create_pretraining_dataset, + self.data_file, + self.max_predictions_per_seq, + self.num_workers, + self.train_micro_batch_size_per_gpu, + self.worker_init, + self.data_sampler, + ) def get_batch(self, batch_iter): return batch_iter diff --git a/examples/community/roberta/pretraining/pretrain_utils.py b/examples/community/roberta/pretraining/pretrain_utils.py index e6a393a57dda..1370b413b712 100644 --- a/examples/community/roberta/pretraining/pretrain_utils.py +++ b/examples/community/roberta/pretraining/pretrain_utils.py @@ -1,24 +1,12 @@ -import logging import os import sys import torch import transformers -from torch.optim import AdamW -from transformers import ( - AutoModelForMaskedLM, - AutoTokenizer, - BertForPreTraining, - GPT2Config, - GPT2LMHeadModel, - RobertaConfig, - RobertaForMaskedLM, - get_linear_schedule_with_warmup, -) +from transformers import 
get_linear_schedule_with_warmup from colossalai.legacy.core import global_context as gpc -from colossalai.nn.lr_scheduler import LinearWarmupLR -from colossalai.nn.optimizer import FusedAdam, HybridAdam +from colossalai.nn.optimizer import HybridAdam sys.path.append(os.getcwd()) from collections import OrderedDict @@ -27,7 +15,7 @@ from model.bert import BertForMaskedLM from model.deberta_v2 import DebertaV2ForMaskedLM -__all__ = ['get_model', 'get_optimizer', 'get_lr_scheduler', 'get_dataloader_for_pretraining'] +__all__ = ["get_model", "get_optimizer", "get_lr_scheduler", "get_dataloader_for_pretraining"] def get_new_state_dict(state_dict, start_index=13): @@ -39,7 +27,6 @@ def get_new_state_dict(state_dict, start_index=13): class LMModel(nn.Module): - def __init__(self, model, config, args): super().__init__() @@ -55,11 +42,10 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None): def get_model(args, logger): - - if args.mlm == 'bert': + if args.mlm == "bert": config = transformers.BertConfig.from_json_file(args.bert_config) model = BertForMaskedLM(config) - elif args.mlm == 'deberta_v2': + elif args.mlm == "deberta_v2": config = transformers.DebertaV2Config.from_json_file(args.bert_config) model = DebertaV2ForMaskedLM(config) else: @@ -68,11 +54,13 @@ def get_model(args, logger): if len(args.load_pretrain_model) > 0: assert os.path.exists(args.load_pretrain_model) # load_checkpoint(args.load_pretrain_model, model, strict=False) - m_state_dict = torch.load(args.load_pretrain_model, - map_location=torch.device(f"cuda:{torch.cuda.current_device()}")) + m_state_dict = torch.load( + args.load_pretrain_model, map_location=torch.device(f"cuda:{torch.cuda.current_device()}") + ) # new_state_dict = get_new_state_dict(m_state_dict) - model.load_state_dict(m_state_dict, - strict=True) # must insure that every process have identical parameters !!!!!!! + model.load_state_dict( + m_state_dict, strict=True + ) # must insure that every process have identical parameters !!!!!!! 
logger.info("load model success") numel = sum([p.numel() for p in model.parameters()]) @@ -85,40 +73,36 @@ def get_model(args, logger): def get_optimizer(model, lr): param_optimizer = list(model.named_parameters()) - no_decay = ['bias', 'gamma', 'beta', 'LayerNorm'] + no_decay = ["bias", "gamma", "beta", "LayerNorm"] # configure the weight decay for bert models - optimizer_grouped_parameters = [{ - 'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], - 'weight_decay': 0.1 - }, { - 'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], - 'weight_decay': 0.0 - }] + optimizer_grouped_parameters = [ + {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.1}, + {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, + ] optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, betas=[0.9, 0.95]) return optimizer def get_lr_scheduler(optimizer, total_steps, warmup_steps=2000, last_epoch=-1): # warmup_steps = int(total_steps * warmup_ratio) - lr_scheduler = get_linear_schedule_with_warmup(optimizer, - num_warmup_steps=warmup_steps, - num_training_steps=total_steps, - last_epoch=last_epoch) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps, last_epoch=last_epoch + ) # lr_scheduler = LinearWarmupLR(optimizer, total_steps=total_steps, warmup_steps=warmup_steps) return lr_scheduler def save_ckpt(model, optimizer, lr_scheduler, path, epoch, shard, global_step): - model_path = path + '_pytorch_model.bin' - optimizer_lr_path = path + '.op_lrs' + model_path = path + "_pytorch_model.bin" + optimizer_lr_path = path + ".op_lrs" checkpoint = {} - checkpoint['optimizer'] = optimizer.state_dict() - checkpoint['lr_scheduler'] = lr_scheduler.state_dict() - checkpoint['epoch'] = epoch - checkpoint['shard'] = shard - checkpoint['global_step'] = global_step - model_state = model.state_dict() #each process must run model.state_dict() + checkpoint["optimizer"] = optimizer.state_dict() + checkpoint["lr_scheduler"] = lr_scheduler.state_dict() + checkpoint["epoch"] = epoch + checkpoint["shard"] = shard + checkpoint["global_step"] = global_step + model_state = model.state_dict() # each process must run model.state_dict() if gpc.get_global_rank() == 0: torch.save(checkpoint, optimizer_lr_path) torch.save(model_state, model_path) diff --git a/examples/community/roberta/pretraining/run_pretraining.py b/examples/community/roberta/pretraining/run_pretraining.py index fa6457cab328..5396de6935cb 100644 --- a/examples/community/roberta/pretraining/run_pretraining.py +++ b/examples/community/roberta/pretraining/run_pretraining.py @@ -17,16 +17,13 @@ import colossalai from colossalai.context import ParallelMode -from colossalai.legacy.core import global_context as gpc -from colossalai.nn.parallel import GeminiDDP, zero_model_wrapper, zero_optim_wrapper -from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec +from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper +from colossalai.tensor import ProcessGroup, ShardSpec from colossalai.utils import get_current_device from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.zero import GeminiOptimizer def main(): - args = parse_args() launch_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime()) @@ -37,20 +34,17 @@ def main(): logger = 
Logger(os.path.join(args.log_path, launch_time), cuda=torch.cuda.is_available(), debug=args.vscode_debug) if args.vscode_debug: - colossalai.launch(config={}, - rank=args.rank, - world_size=args.world_size, - host=args.host, - port=args.port, - backend=args.backend) + colossalai.launch( + config={}, rank=args.rank, world_size=args.world_size, host=args.host, port=args.port, backend=args.backend + ) args.local_rank = -1 args.log_interval = 1 else: - colossalai.launch_from_torch(config={}) # args.colossal_config + colossalai.launch_from_torch(config={}) # args.colossal_config args.local_rank = int(os.environ["LOCAL_RANK"]) logger.info( - f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' + - f'ParallelMode.MODEL: {ParallelMode.MODEL} | ParallelMode.DATA: {ParallelMode.DATA} | ParallelMode.TENSOR: {ParallelMode.TENSOR}' + f"launch_from_torch, world size: {torch.distributed.get_world_size()} | " + + f"ParallelMode.MODEL: {ParallelMode.MODEL} | ParallelMode.DATA: {ParallelMode.DATA} | ParallelMode.TENSOR: {ParallelMode.TENSOR}" ) log_args(logger, args) @@ -59,7 +53,7 @@ def main(): set_global_variables(launch_time, args.tensorboard_path) world_size = torch.distributed.get_world_size() - init_dev = get_current_device() + get_current_device() # build model, optimizer and criterion if args.distplan.startswith("CAI"): @@ -72,10 +66,9 @@ def main(): raise RuntimeError("You can only use shardinit with CAI_Gemini") # build GPT model - with ColoInitContext(device=get_current_device(), - dtype=torch.half, - default_dist_spec=default_dist_spec, - default_pg=shard_pg): + with ColoInitContext( + device=get_current_device(), dtype=torch.half, default_dist_spec=default_dist_spec, default_pg=shard_pg + ): config, model, numel = get_model(args, logger) # assign running configurations @@ -83,13 +76,15 @@ def main(): if args.distplan.startswith("CAI_ZeRO"): optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True) elif args.distplan == "CAI_Gemini": - gemini_config = dict(strict_ddp_mode=args.tp_degree == 1, - device=get_current_device(), - placement_policy=args.placement, - pin_memory=True, - hidden_dim=model.config.hidden_size, - search_range_m=128) - optim_config = dict(gpu_margin_mem_ratio=0.) 
+ gemini_config = dict( + strict_ddp_mode=args.tp_degree == 1, + device=get_current_device(), + placement_policy=args.placement, + pin_memory=True, + hidden_dim=model.config.hidden_size, + search_range_m=128, + ) + optim_config = dict(gpu_margin_mem_ratio=0.0) else: raise RuntimeError @@ -109,7 +104,7 @@ def main(): model = zero_model_wrapper(model, zero_stage, gemini_config) optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config) - logger.info(get_mem_info(prefix='After init optim, ')) + logger.info(get_mem_info(prefix="After init optim, ")) else: config, model, numel = get_model(args, logger) @@ -118,13 +113,19 @@ def main(): if torch.distributed.get_rank() == 0: os.mkdir(os.path.join(args.ckpt_path, launch_time)) - logger.info(f'Model numel: {numel}') + logger.info(f"Model numel: {numel}") get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length) # 144003367 is is the length of the entire dataset # len(dataloader) - steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size + steps_per_epoch = ( + 144003367 + // world_size + // args.train_micro_batch_size_per_gpu + // args.gradient_accumulation_steps + // args.refresh_bucket_size + ) total_steps = steps_per_epoch * args.epoch lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1) @@ -134,25 +135,25 @@ def main(): global_step = 0 if args.resume_train: assert os.path.exists(args.load_optimizer_lr) - o_l_state_dict = torch.load(args.load_optimizer_lr, map_location='cpu') - o_l_state_dict['lr_scheduler']['last_epoch'] = o_l_state_dict['lr_scheduler']['last_epoch'] - 1 - optimizer.load_state_dict(o_l_state_dict['optimizer']) + o_l_state_dict = torch.load(args.load_optimizer_lr, map_location="cpu") + o_l_state_dict["lr_scheduler"]["last_epoch"] = o_l_state_dict["lr_scheduler"]["last_epoch"] - 1 + optimizer.load_state_dict(o_l_state_dict["optimizer"]) # o_l_state_dict['lr_scheduler']['last_epoch'] - lr_scheduler = get_lr_scheduler(optimizer, - total_steps=total_steps, - last_epoch=o_l_state_dict['lr_scheduler']['last_epoch']) + lr_scheduler = get_lr_scheduler( + optimizer, total_steps=total_steps, last_epoch=o_l_state_dict["lr_scheduler"]["last_epoch"] + ) for state in optimizer.state.values(): for k, v in state.items(): if isinstance(v, torch.Tensor): state[k] = v.cuda(f"cuda:{torch.cuda.current_device()}") # if you want delete the above three code, must move the model to gpu. 
Because in optimizer.step() - lr_scheduler.load_state_dict(o_l_state_dict['lr_scheduler']) + lr_scheduler.load_state_dict(o_l_state_dict["lr_scheduler"]) - start_epoch = o_l_state_dict['epoch'] - start_shard = o_l_state_dict['shard'] + 1 + start_epoch = o_l_state_dict["epoch"] + start_shard = o_l_state_dict["shard"] + 1 # global_step = o_l_state_dict['global_step'] + 1 logger.info( - f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}' + f"resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}" ) criterion = LossForPretraining(config.vocab_size) @@ -160,34 +161,32 @@ def main(): # build dataloader pretrain_dataset_provider = NvidiaBertDatasetProvider(args) - logger.info(get_mem_info(prefix='After init model, ')) + logger.info(get_mem_info(prefix="After init model, ")) - best_loss = None eval_loss = 0 train_loss = 0 timers = get_timers() - timers('interval_time').start() - timers('epoch_time').start() - timers('shard_time').start() + timers("interval_time").start() + timers("epoch_time").start() + timers("shard_time").start() for epoch in range(start_epoch, args.epoch): - for shard in range(start_shard, len(os.listdir(args.data_path_prefix))): - dataset_iterator, total_length = pretrain_dataset_provider.get_shard(shard) # pretrain_dataset_provider.prefetch_shard(shard + 1) # may cause cpu memory overload if torch.distributed.get_rank() == 0: - iterator_data = tqdm(enumerate(dataset_iterator), - total=(total_length // args.train_micro_batch_size_per_gpu // world_size), - colour='cyan', - smoothing=1) + iterator_data = tqdm( + enumerate(dataset_iterator), + total=(total_length // args.train_micro_batch_size_per_gpu // world_size), + colour="cyan", + smoothing=1, + ) else: iterator_data = enumerate(dataset_iterator) model.train() for step, batch_data in iterator_data: - # batch_data = pretrain_dataset_provider.get_batch(batch_index) input_ids = batch_data[0].cuda(f"cuda:{torch.cuda.current_device()}") attention_mask = batch_data[1].cuda(f"cuda:{torch.cuda.current_device()}") @@ -209,56 +208,70 @@ def main(): global_step += 1 - if global_step % args.log_interval == 0 and global_step != 0 \ - and torch.distributed.get_rank() == 0: - elapsed_time = timers('interval_time').elapsed(reset=False) + if global_step % args.log_interval == 0 and global_step != 0 and torch.distributed.get_rank() == 0: + elapsed_time = timers("interval_time").elapsed(reset=False) elapsed_time_per_iteration = elapsed_time / global_step samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator( - numel, args, config, elapsed_time, global_step, world_size) + numel, args, config, elapsed_time, global_step, world_size + ) cur_loss = train_loss / args.log_interval current_lr = lr_scheduler.get_last_lr()[0] - log_str = f'| epoch: {epoch} | shard: {shard} | step: {global_step} | lr {current_lr:.7f} | elapsed_time: {elapsed_time / 60 :.3f} minutes ' + \ - f'| mins/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {math.exp(cur_loss):.3f} | TFLOPS: {get_tflops_func(elapsed_time_per_iteration):.3f} or {tflops:.3f}' + log_str = ( + f"| epoch: {epoch} | shard: {shard} | step: {global_step} | lr {current_lr:.7f} | elapsed_time: {elapsed_time / 60 :.3f} minutes " + + f"| mins/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {math.exp(cur_loss):.3f} | TFLOPS: {get_tflops_func(elapsed_time_per_iteration):.3f} or {tflops:.3f}" + ) 
logger.info(log_str, print_=False) if args.wandb: tensorboard_log = get_tensorboard_writer() tensorboard_log.log_train( { - 'lr': current_lr, - 'loss': cur_loss, - 'ppl': math.exp(cur_loss), - 'mins_batch': elapsed_time_per_iteration - }, global_step) + "lr": current_lr, + "loss": cur_loss, + "ppl": math.exp(cur_loss), + "mins_batch": elapsed_time_per_iteration, + }, + global_step, + ) train_loss = 0 logger.info(f'epoch {epoch} shard {shard} has cost {timers("shard_time").elapsed() / 60 :.3f} mins') - logger.info('*' * 100) + logger.info("*" * 100) eval_loss += evaluate(model, args, logger, global_step, criterion) - save_ckpt(model, optimizer, lr_scheduler, - os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch, - shard, global_step) + save_ckpt( + model, + optimizer, + lr_scheduler, + os.path.join(args.ckpt_path, launch_time, f"epoch-{epoch}_shard-{shard}_" + launch_time), + epoch, + shard, + global_step, + ) eval_loss /= len(os.listdir(args.data_path_prefix)) logger.info( f'epoch {epoch} | shard_length {len(os.listdir(args.data_path_prefix))} | elapsed_time: {timers("epoch_time").elapsed() / 60 :.3f} mins' - + f'eval_loss: {eval_loss} | ppl: {math.exp(eval_loss)}') - logger.info('-' * 100) + + f"eval_loss: {eval_loss} | ppl: {math.exp(eval_loss)}" + ) + logger.info("-" * 100) if args.wandb and torch.distributed.get_rank() == 0: tensorboard_log = get_tensorboard_writer() - tensorboard_log.log_eval({ - 'all_eval_shard_loss': eval_loss, - }, epoch) + tensorboard_log.log_eval( + { + "all_eval_shard_loss": eval_loss, + }, + epoch, + ) start_shard = 0 eval_loss = 0 pretrain_dataset_provider.release_shard() - logger.info('Congratulation, training has finished!!!') + logger.info("Congratulation, training has finished!!!") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/community/roberta/pretraining/utils/WandbLog.py b/examples/community/roberta/pretraining/utils/WandbLog.py index b68ba8387dcd..d73393c348d8 100644 --- a/examples/community/roberta/pretraining/utils/WandbLog.py +++ b/examples/community/roberta/pretraining/utils/WandbLog.py @@ -6,7 +6,6 @@ class WandbLog: - @classmethod def init_wandb(cls, project, notes=None, name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), config=None): wandb.init(project=project, notes=notes, name=name, config=config) @@ -23,7 +22,6 @@ def log(cls, result, model=None, gradient=None): class TensorboardLog: - def __init__(self, location, name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), config=None): if not os.path.exists(location): os.mkdir(location) @@ -31,12 +29,12 @@ def __init__(self, location, name=time.strftime("%Y-%m-%d %H:%M:%S", time.localt def log_train(self, result, step): for k, v in result.items(): - self.writer.add_scalar(f'{k}/train', v, step) + self.writer.add_scalar(f"{k}/train", v, step) def log_eval(self, result, step): for k, v in result.items(): - self.writer.add_scalar(f'{k}/eval', v, step) + self.writer.add_scalar(f"{k}/eval", v, step) def log_zeroshot(self, result, step): for k, v in result.items(): - self.writer.add_scalar(f'{k}_acc/eval', v, step) + self.writer.add_scalar(f"{k}_acc/eval", v, step) diff --git a/examples/community/roberta/pretraining/utils/exp_util.py b/examples/community/roberta/pretraining/utils/exp_util.py index 1fcaa428b277..e95b6efda4c8 100644 --- a/examples/community/roberta/pretraining/utils/exp_util.py +++ b/examples/community/roberta/pretraining/utils/exp_util.py @@ -12,8 +12,8 @@ def logging(s, log_path, print_=True, 
log_=True): if print_: print(s) if log_: - with open(log_path, 'a+') as f_log: - f_log.write(s + '\n') + with open(log_path, "a+") as f_log: + f_log.write(s + "\n") def get_logger(log_path, **kwargs): @@ -22,22 +22,22 @@ def get_logger(log_path, **kwargs): def create_exp_dir(dir_path, scripts_to_save=None, debug=False): if debug: - print('Debug Mode : no experiment dir created') + print("Debug Mode : no experiment dir created") return functools.partial(logging, log_path=None, log_=False) if not os.path.exists(dir_path): os.makedirs(dir_path) - print('Experiment dir : {}'.format(dir_path)) + print("Experiment dir : {}".format(dir_path)) if scripts_to_save is not None: - script_path = os.path.join(dir_path, 'scripts') + script_path = os.path.join(dir_path, "scripts") if not os.path.exists(script_path): os.makedirs(script_path) for script in scripts_to_save: - dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script)) + dst_file = os.path.join(dir_path, "scripts", os.path.basename(script)) shutil.copyfile(script, dst_file) - return get_logger(log_path=os.path.join(dir_path, 'log.txt')) + return get_logger(log_path=os.path.join(dir_path, "log.txt")) def get_cpu_mem(): @@ -48,8 +48,8 @@ def get_gpu_mem(): return torch.cuda.memory_allocated() / 1024**2 -def get_mem_info(prefix=''): - return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB' +def get_mem_info(prefix=""): + return f"{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB" def get_tflops(model_numel, batch_size, seq_len, step_time): @@ -59,11 +59,12 @@ def get_tflops(model_numel, batch_size, seq_len, step_time): def get_parameters_in_billions(model, world_size=1): gpus_per_model = world_size - approx_parameters_in_billions = sum([ - sum([p.ds_numel if hasattr(p, 'ds_id') else p.nelement() - for p in model_module.parameters()]) - for model_module in model - ]) + approx_parameters_in_billions = sum( + [ + sum([p.ds_numel if hasattr(p, "ds_id") else p.nelement() for p in model_module.parameters()]) + for model_module in model + ] + ) return approx_parameters_in_billions * gpus_per_model / (1e9) @@ -71,13 +72,13 @@ def get_parameters_in_billions(model, world_size=1): def throughput_calculator(numel, args, config, iteration_time, total_iterations, world_size=1): gpus_per_model = 1 batch_size = args.train_micro_batch_size_per_gpu - samples_per_model = batch_size * args.max_seq_length - model_replica_count = world_size / gpus_per_model + batch_size * args.max_seq_length + world_size / gpus_per_model approx_parameters_in_billions = numel elapsed_time_per_iter = iteration_time / total_iterations samples_per_second = batch_size / elapsed_time_per_iter - #flops calculator + # flops calculator hidden_size = config.hidden_size num_layers = config.num_hidden_layers vocab_size = config.vocab_size @@ -87,9 +88,9 @@ def throughput_calculator(numel, args, config, iteration_time, total_iterations, # The factor of 4 is when used with activation check-pointing, # otherwise it will be 3. checkpoint_activations_factor = 4 if args.checkpoint_activations else 3 - flops_per_iteration = (24 * checkpoint_activations_factor * batch_size * args.max_seq_length * num_layers * - (hidden_size**2)) * (1. + (args.max_seq_length / (6. * hidden_size)) + - (vocab_size / (16. 
* num_layers * hidden_size))) + flops_per_iteration = ( + 24 * checkpoint_activations_factor * batch_size * args.max_seq_length * num_layers * (hidden_size**2) + ) * (1.0 + (args.max_seq_length / (6.0 * hidden_size)) + (vocab_size / (16.0 * num_layers * hidden_size))) tflops = flops_per_iteration / (elapsed_time_per_iter * (10**12)) return samples_per_second, tflops, approx_parameters_in_billions @@ -106,9 +107,9 @@ def synchronize(): def log_args(logger, args): - logger.info('--------args----------') - message = '\n'.join([f'{k:<30}: {v}' for k, v in vars(args).items()]) - message += '\n' - message += '\n'.join([f'{k:<30}: {v}' for k, v in gpc.config.items()]) + logger.info("--------args----------") + message = "\n".join([f"{k:<30}: {v}" for k, v in vars(args).items()]) + message += "\n" + message += "\n".join([f"{k:<30}: {v}" for k, v in gpc.config.items()]) logger.info(message) - logger.info('--------args----------\n') + logger.info("--------args----------\n") diff --git a/examples/community/roberta/pretraining/utils/global_vars.py b/examples/community/roberta/pretraining/utils/global_vars.py index 9eef19e71614..176c0a5b3474 100644 --- a/examples/community/roberta/pretraining/utils/global_vars.py +++ b/examples/community/roberta/pretraining/utils/global_vars.py @@ -16,21 +16,21 @@ def set_global_variables(launch_time, tensorboard_path): def _set_timers(): """Initialize timers.""" global _GLOBAL_TIMERS - _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers') + _ensure_var_is_not_initialized(_GLOBAL_TIMERS, "timers") _GLOBAL_TIMERS = Timers() def _set_tensorboard_writer(launch_time, tensorboard_path): """Set tensorboard writer.""" global _GLOBAL_TENSORBOARD_WRITER - _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER, 'tensorboard writer') + _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER, "tensorboard writer") if torch.distributed.get_rank() == 0: - _GLOBAL_TENSORBOARD_WRITER = TensorboardLog(tensorboard_path + f'/{launch_time}', launch_time) + _GLOBAL_TENSORBOARD_WRITER = TensorboardLog(tensorboard_path + f"/{launch_time}", launch_time) def get_timers(): """Return timers.""" - _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers') + _ensure_var_is_initialized(_GLOBAL_TIMERS, "timers") return _GLOBAL_TIMERS @@ -42,12 +42,12 @@ def get_tensorboard_writer(): def _ensure_var_is_initialized(var, name): """Make sure the input variable is not None.""" - assert var is not None, '{} is not initialized.'.format(name) + assert var is not None, "{} is not initialized.".format(name) def _ensure_var_is_not_initialized(var, name): """Make sure the input variable is not None.""" - assert var is None, '{} is already initialized.'.format(name) + assert var is None, "{} is already initialized.".format(name) class _Timer: @@ -68,9 +68,9 @@ def start(self): def stop(self): """Stop the timer.""" - assert self.started_, 'timer is not started' + assert self.started_, "timer is not started" torch.cuda.synchronize() - self.elapsed_ += (time.time() - self.start_time) + self.elapsed_ += time.time() - self.start_time self.started_ = False def reset(self): @@ -114,15 +114,15 @@ def write(self, names, writer, iteration, normalizer=1.0, reset=False): assert normalizer > 0.0 for name in names: value = self.timers[name].elapsed(reset=reset) / normalizer - writer.add_scalar(name + '-time', value, iteration) + writer.add_scalar(name + "-time", value, iteration) def log(self, names, normalizer=1.0, reset=True): """Log a group of timers.""" assert normalizer > 0.0 - string = 'time (ms)' + string = "time 
(ms)" for name in names: elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer - string += ' | {}: {:.2f}'.format(name, elapsed_time) + string += " | {}: {:.2f}".format(name, elapsed_time) if torch.distributed.is_initialized(): if torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1): print(string, flush=True) diff --git a/examples/community/roberta/pretraining/utils/logger.py b/examples/community/roberta/pretraining/utils/logger.py index 75c9bf4bef25..9913892b89e9 100644 --- a/examples/community/roberta/pretraining/utils/logger.py +++ b/examples/community/roberta/pretraining/utils/logger.py @@ -1,16 +1,14 @@ import logging -import os import torch.distributed as dist -logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', - datefmt='%m/%d/%Y %H:%M:%S', - level=logging.INFO) +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO +) logger = logging.getLogger(__name__) -class Logger(): - +class Logger: def __init__(self, log_path, cuda=False, debug=False): self.logger = logging.getLogger(__name__) self.cuda = cuda @@ -23,8 +21,8 @@ def info(self, message, log_=True, print_=True, *args, **kwargs): self.logger.info(message, *args, **kwargs) if log_: - with open(self.log_path, 'a+') as f_log: - f_log.write(message + '\n') + with open(self.log_path, "a+") as f_log: + f_log.write(message + "\n") def error(self, message, *args, **kwargs): self.logger.error(message, *args, **kwargs) diff --git a/examples/images/diffusion/README.md b/examples/images/diffusion/README.md index 0c7f42ded318..b63896524909 100644 --- a/examples/images/diffusion/README.md +++ b/examples/images/diffusion/README.md @@ -132,7 +132,7 @@ bash train_colossalai.sh ``` It is important for you to configure your volume mapping in order to get the best training experience. -1. **Mandatory**, mount your prepared data to `/data/scratch` via `-v :/data/scratch`, where you need to replace `` with the actual data path on your machine. Notice that within docker we need to transform the Windows path to a Linux one, e.g. `C:\User\Desktop` into `/mnt/c/User/Desktop`. +1. **Mandatory**, mount your prepared data to `/data/scratch` via `-v :/data/scratch`, where you need to replace `` with the actual data path on your machine. Notice that within docker we need to transform the Windows path to a Linux one, e.g. `C:\User\Desktop` into `/mnt/c/User/Desktop`. 2. **Recommended**, store the downloaded model weights to your host machine instead of the container directory via `-v :/root/.cache/huggingface`, where you need to replace the `` with the actual path. In this way, you don't have to repeatedly download the pretrained weights for every `docker run`. 3. **Optional**, if you encounter any problem stating that shared memory is insufficient inside container, please add `-v /dev/shm:/dev/shm` to your `docker run` command. 
diff --git a/examples/images/diffusion/configs/train_ddp.yaml b/examples/images/diffusion/configs/train_ddp.yaml index f3ae3ddb5ff6..72dc05b649a4 100644 --- a/examples/images/diffusion/configs/train_ddp.yaml +++ b/examples/images/diffusion/configs/train_ddp.yaml @@ -80,7 +80,7 @@ data: lightning: trainer: - accelerator: 'gpu' + accelerator: 'gpu' devices: 8 log_gpu_memory: all max_epochs: 2 diff --git a/examples/images/diffusion/ldm/data/base.py b/examples/images/diffusion/ldm/data/base.py index a12492c95a16..11bd0c5954a2 100644 --- a/examples/images/diffusion/ldm/data/base.py +++ b/examples/images/diffusion/ldm/data/base.py @@ -1,17 +1,15 @@ -import math import os -from abc import abstractmethod import cv2 import numpy as np import torch -from torch.utils.data import ChainDataset, ConcatDataset, Dataset, IterableDataset +from torch.utils.data import IterableDataset class Txt2ImgIterableBaseDataset(IterableDataset): - ''' + """ Define an interface to make the IterableDatasets for text2img data chainable - ''' + """ def __init__(self, file_path: str, rank, world_size): super().__init__() @@ -20,8 +18,8 @@ def __init__(self, file_path: str, rank, world_size): self.file_list = [] self.txt_list = [] self.info = self._get_file_info(file_path) - self.start = self.info['start'] - self.end = self.info['end'] + self.start = self.info["start"] + self.end = self.info["end"] self.rank = rank self.world_size = world_size @@ -33,7 +31,7 @@ def __init__(self, file_path: str, rank, world_size): self.num_records = self.end - self.start self.valid_ids = [i for i in range(self.end)] - print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') + print(f"{self.__class__.__name__} dataset contains {self.__len__()} examples.") def __len__(self): # return self.iter_end - self.iter_start @@ -48,7 +46,7 @@ def _sample_generator(self, start, end): for idx in range(start, end): file_name = self.file_list[idx] txt_name = self.txt_list[idx] - f_ = open(txt_name, 'r') + f_ = open(txt_name, "r") txt_ = f_.read() f_.close() image = cv2.imdecode(np.fromfile(file_name, dtype=np.uint8), 1) @@ -57,18 +55,17 @@ def _sample_generator(self, start, end): yield {"txt": txt_, "image": image} def _get_file_info(self, file_path): - info = \ - { + info = { "start": 1, "end": 0, } - self.folder_list = [file_path + i for i in os.listdir(file_path) if '.' not in i] + self.folder_list = [file_path + i for i in os.listdir(file_path) if "." 
not in i] for folder in self.folder_list: - files = [folder + '/' + i for i in os.listdir(folder) if 'jpg' in i] - txts = [k.replace('jpg', 'txt') for k in files] + files = [folder + "/" + i for i in os.listdir(folder) if "jpg" in i] + txts = [k.replace("jpg", "txt") for k in files] self.file_list.extend(files) self.txt_list.extend(txts) - info['end'] = len(self.file_list) + info["end"] = len(self.file_list) # with open(file_path, 'r') as fin: # for _ in enumerate(fin): # info['end'] += 1 diff --git a/examples/images/diffusion/ldm/data/cifar10.py b/examples/images/diffusion/ldm/data/cifar10.py index 53cd61263b47..85c6e1b5dd38 100644 --- a/examples/images/diffusion/ldm/data/cifar10.py +++ b/examples/images/diffusion/ldm/data/cifar10.py @@ -1,15 +1,16 @@ +import json +from pathlib import Path from typing import Dict -import numpy as np -from omegaconf import DictConfig, ListConfig + import torch -from torch.utils.data import Dataset -from pathlib import Path -import json -from PIL import Image -from torchvision import transforms +from datasets import load_dataset from einops import rearrange from ldm.util import instantiate_from_config -from datasets import load_dataset +from omegaconf import DictConfig, ListConfig +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + def make_multi_folder_data(paths, caption_files=None, **kwargs): """Make a concat dataset from multiple folders @@ -19,10 +20,9 @@ def make_multi_folder_data(paths, caption_files=None, **kwargs): """ list_of_paths = [] if isinstance(paths, (Dict, DictConfig)): - assert caption_files is None, \ - "Caption files not yet supported for repeats" + assert caption_files is None, "Caption files not yet supported for repeats" for folder_path, repeats in paths.items(): - list_of_paths.extend([folder_path]*repeats) + list_of_paths.extend([folder_path] * repeats) paths = list_of_paths if caption_files is not None: @@ -31,8 +31,10 @@ def make_multi_folder_data(paths, caption_files=None, **kwargs): datasets = [FolderData(p, **kwargs) for p in paths] return torch.utils.data.ConcatDataset(datasets) + class FolderData(Dataset): - def __init__(self, + def __init__( + self, root_dir, caption_file=None, image_transforms=[], @@ -40,7 +42,7 @@ def __init__(self, default_caption="", postprocess=None, return_paths=False, - ) -> None: + ) -> None: """Create a dataset from a folder of images. If you pass in a root directory it will be searched for images ending in ext (ext can be a list) @@ -75,12 +77,12 @@ def __init__(self, self.paths.extend(list(self.root_dir.rglob(f"*.{e}"))) if isinstance(image_transforms, ListConfig): image_transforms = [instantiate_from_config(tt) for tt in image_transforms] - image_transforms.extend([transforms.ToTensor(), - transforms.Lambda(lambda x: rearrange(x * 2. 
- 1., 'c h w -> h w c'))]) + image_transforms.extend( + [transforms.ToTensor(), transforms.Lambda(lambda x: rearrange(x * 2.0 - 1.0, "c h w -> h w c"))] + ) image_transforms = transforms.Compose(image_transforms) self.tform = image_transforms - def __len__(self): if self.captions is not None: return len(self.captions.keys()) @@ -94,7 +96,7 @@ def __getitem__(self, index): caption = self.captions.get(chosen, None) if caption is None: caption = self.default_caption - filename = self.root_dir/chosen + filename = self.root_dir / chosen else: filename = self.paths[index] @@ -119,22 +121,23 @@ def process_im(self, im): im = im.convert("RGB") return self.tform(im) + def hf_dataset( name, image_transforms=[], image_column="img", label_column="label", text_column="txt", - split='train', - image_key='image', - caption_key='txt', - ): - """Make huggingface dataset with appropriate list of transforms applied - """ + split="train", + image_key="image", + caption_key="txt", +): + """Make huggingface dataset with appropriate list of transforms applied""" ds = load_dataset(name, split=split) image_transforms = [instantiate_from_config(tt) for tt in image_transforms] - image_transforms.extend([transforms.ToTensor(), - transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) + image_transforms.extend( + [transforms.ToTensor(), transforms.Lambda(lambda x: rearrange(x * 2.0 - 1.0, "c h w -> h w c"))] + ) tform = transforms.Compose(image_transforms) assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}" @@ -144,7 +147,18 @@ def pre_process(examples): processed = {} processed[image_key] = [tform(im) for im in examples[image_column]] - label_to_text_dict = {0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer", 5: "dog", 6: "frog", 7: "horse", 8: "ship", 9: "truck"} + label_to_text_dict = { + 0: "airplane", + 1: "automobile", + 2: "bird", + 3: "cat", + 4: "deer", + 5: "dog", + 6: "frog", + 7: "horse", + 8: "ship", + 9: "truck", + } processed[caption_key] = [label_to_text_dict[label] for label in examples[label_column]] @@ -153,6 +167,7 @@ def pre_process(examples): ds.set_transform(pre_process) return ds + class TextOnly(Dataset): def __init__(self, captions, output_size, image_key="image", caption_key="txt", n_gpus=1): """Returns only captions with dummy images""" @@ -166,7 +181,7 @@ def __init__(self, captions, output_size, image_key="image", caption_key="txt", if n_gpus > 1: # hack to make sure that all the captions appear on each gpu - repeated = [n_gpus*[x] for x in self.captions] + repeated = [n_gpus * [x] for x in self.captions] self.captions = [] [self.captions.extend(x) for x in repeated] @@ -175,10 +190,10 @@ def __len__(self): def __getitem__(self, index): dummy_im = torch.zeros(3, self.output_size, self.output_size) - dummy_im = rearrange(dummy_im * 2. 
- 1., 'c h w -> h w c') + dummy_im = rearrange(dummy_im * 2.0 - 1.0, "c h w -> h w c") return {self.image_key: dummy_im, self.caption_key: self.captions[index]} def _load_caption_file(self, filename): - with open(filename, 'rt') as f: + with open(filename, "rt") as f: captions = f.readlines() - return [x.strip('\n') for x in captions] \ No newline at end of file + return [x.strip("\n") for x in captions] diff --git a/examples/images/diffusion/ldm/data/imagenet.py b/examples/images/diffusion/ldm/data/imagenet.py index 1c473f9c6965..8483e16ab23a 100644 --- a/examples/images/diffusion/ldm/data/imagenet.py +++ b/examples/images/diffusion/ldm/data/imagenet.py @@ -1,32 +1,35 @@ -import os, yaml, pickle, shutil, tarfile, glob -import cv2 +import glob +import os +import pickle +import shutil +import tarfile +from functools import partial + import albumentations -import PIL +import cv2 import numpy as np +import PIL +import taming.data.utils as tdu import torchvision.transforms.functional as TF +import yaml +from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light from omegaconf import OmegaConf -from functools import partial from PIL import Image -from tqdm import tqdm +from taming.data.imagenet import ImagePaths, download, give_synsets_from_indices, retrieve, str_to_indices from torch.utils.data import Dataset, Subset - -import taming.data.utils as tdu -from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve -from taming.data.imagenet import ImagePaths - -from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light +from tqdm import tqdm def synset2idx(path_to_yaml="data/index_synset.yaml"): with open(path_to_yaml) as f: di2s = yaml.load(f) - return dict((v,k) for k,v in di2s.items()) + return dict((v, k) for k, v in di2s.items()) class ImageNetBase(Dataset): def __init__(self, config=None): self.config = config or OmegaConf.create() - if not type(self.config)==dict: + if not type(self.config) == dict: self.config = OmegaConf.to_container(self.config) self.keep_orig_class_label = self.config.get("keep_orig_class_label", False) self.process_images = True # if False we skip loading & processing images and self.data contains filepaths @@ -46,9 +49,11 @@ def _prepare(self): raise NotImplementedError() def _filter_relpaths(self, relpaths): - ignore = set([ - "n06596364_9591.JPEG", - ]) + ignore = set( + [ + "n06596364_9591.JPEG", + ] + ) relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore] if "sub_indices" in self.config: indices = str_to_indices(self.config["sub_indices"]) @@ -67,20 +72,19 @@ def _prepare_synset_to_human(self): SIZE = 2655750 URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1" self.human_dict = os.path.join(self.root, "synset_human.txt") - if (not os.path.exists(self.human_dict) or - not os.path.getsize(self.human_dict)==SIZE): + if not os.path.exists(self.human_dict) or not os.path.getsize(self.human_dict) == SIZE: download(URL, self.human_dict) def _prepare_idx_to_synset(self): URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1" self.idx2syn = os.path.join(self.root, "index_synset.yaml") - if (not os.path.exists(self.idx2syn)): + if not os.path.exists(self.idx2syn): download(URL, self.idx2syn) def _prepare_human_to_integer_label(self): URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1" self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt") - if (not 
os.path.exists(self.human2integer)): + if not os.path.exists(self.human2integer): download(URL, self.human2integer) with open(self.human2integer, "r") as f: lines = f.read().splitlines() @@ -122,11 +126,12 @@ def _load(self): if self.process_images: self.size = retrieve(self.config, "size", default=256) - self.data = ImagePaths(self.abspaths, - labels=labels, - size=self.size, - random_crop=self.random_crop, - ) + self.data = ImagePaths( + self.abspaths, + labels=labels, + size=self.size, + random_crop=self.random_crop, + ) else: self.data = self.abspaths @@ -157,8 +162,7 @@ def _prepare(self): self.datadir = os.path.join(self.root, "data") self.txt_filelist = os.path.join(self.root, "filelist.txt") self.expected_length = 1281167 - self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", - default=True) + self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", default=True) if not tdu.is_prepared(self.root): # prep print("Preparing dataset {} in {}".format(self.NAME, self.root)) @@ -166,8 +170,9 @@ def _prepare(self): datadir = self.datadir if not os.path.exists(datadir): path = os.path.join(self.root, self.FILES[0]) - if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + if not os.path.exists(path) or not os.path.getsize(path) == self.SIZES[0]: import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) assert atpath == path @@ -179,7 +184,7 @@ def _prepare(self): print("Extracting sub-tars.") subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar"))) for subpath in tqdm(subpaths): - subdir = subpath[:-len(".tar")] + subdir = subpath[: -len(".tar")] os.makedirs(subdir, exist_ok=True) with tarfile.open(subpath, "r:") as tar: tar.extractall(path=subdir) @@ -187,7 +192,7 @@ def _prepare(self): filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) filelist = [os.path.relpath(p, start=datadir) for p in filelist] filelist = sorted(filelist) - filelist = "\n".join(filelist)+"\n" + filelist = "\n".join(filelist) + "\n" with open(self.txt_filelist, "w") as f: f.write(filelist) @@ -222,8 +227,7 @@ def _prepare(self): self.datadir = os.path.join(self.root, "data") self.txt_filelist = os.path.join(self.root, "filelist.txt") self.expected_length = 50000 - self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", - default=False) + self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", default=False) if not tdu.is_prepared(self.root): # prep print("Preparing dataset {} in {}".format(self.NAME, self.root)) @@ -231,8 +235,9 @@ def _prepare(self): datadir = self.datadir if not os.path.exists(datadir): path = os.path.join(self.root, self.FILES[0]) - if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + if not os.path.exists(path) or not os.path.getsize(path) == self.SIZES[0]: import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) assert atpath == path @@ -242,7 +247,7 @@ def _prepare(self): tar.extractall(path=datadir) vspath = os.path.join(self.root, self.FILES[1]) - if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]: + if not os.path.exists(vspath) or not os.path.getsize(vspath) == self.SIZES[1]: download(self.VS_URL, vspath) with open(vspath, "r") as f: @@ -261,18 +266,15 @@ def _prepare(self): filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) filelist = [os.path.relpath(p, start=datadir) for p in filelist] filelist = sorted(filelist) - filelist = "\n".join(filelist)+"\n" + filelist = "\n".join(filelist) + 
"\n" with open(self.txt_filelist, "w") as f: f.write(filelist) tdu.mark_prepared(self.root) - class ImageNetSR(Dataset): - def __init__(self, size=None, - degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1., - random_crop=True): + def __init__(self, size=None, degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.0, random_crop=True): """ Imagenet Superresolution Dataloader Performs following ops in order: @@ -296,12 +298,12 @@ def __init__(self, size=None, self.LR_size = int(size / downscale_f) self.min_crop_f = min_crop_f self.max_crop_f = max_crop_f - assert(max_crop_f <= 1.) + assert max_crop_f <= 1.0 self.center_crop = not random_crop self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA) - self.pil_interpolation = False # gets reset later if incase interp_op is from pillow + self.pil_interpolation = False # gets reset later if incase interp_op is from pillow if degradation == "bsrgan": self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f) @@ -311,17 +313,17 @@ def __init__(self, size=None, else: interpolation_fn = { - "cv_nearest": cv2.INTER_NEAREST, - "cv_bilinear": cv2.INTER_LINEAR, - "cv_bicubic": cv2.INTER_CUBIC, - "cv_area": cv2.INTER_AREA, - "cv_lanczos": cv2.INTER_LANCZOS4, - "pil_nearest": PIL.Image.NEAREST, - "pil_bilinear": PIL.Image.BILINEAR, - "pil_bicubic": PIL.Image.BICUBIC, - "pil_box": PIL.Image.BOX, - "pil_hamming": PIL.Image.HAMMING, - "pil_lanczos": PIL.Image.LANCZOS, + "cv_nearest": cv2.INTER_NEAREST, + "cv_bilinear": cv2.INTER_LINEAR, + "cv_bicubic": cv2.INTER_CUBIC, + "cv_area": cv2.INTER_AREA, + "cv_lanczos": cv2.INTER_LANCZOS4, + "pil_nearest": PIL.Image.NEAREST, + "pil_bilinear": PIL.Image.BILINEAR, + "pil_bicubic": PIL.Image.BICUBIC, + "pil_box": PIL.Image.BOX, + "pil_hamming": PIL.Image.HAMMING, + "pil_lanczos": PIL.Image.LANCZOS, }[degradation] self.pil_interpolation = degradation.startswith("pil_") @@ -330,8 +332,9 @@ def __init__(self, size=None, self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn) else: - self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size, - interpolation=interpolation_fn) + self.degradation_process = albumentations.SmallestMaxSize( + max_size=self.LR_size, interpolation=interpolation_fn + ) def __len__(self): return len(self.base) @@ -366,8 +369,8 @@ def __getitem__(self, i): else: LR_image = self.degradation_process(image=image)["image"] - example["image"] = (image/127.5 - 1.0).astype(np.float32) - example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32) + example["image"] = (image / 127.5 - 1.0).astype(np.float32) + example["LR_image"] = (LR_image / 127.5 - 1.0).astype(np.float32) return example @@ -379,7 +382,9 @@ def __init__(self, **kwargs): def get_base(self): with open("data/imagenet_train_hr_indices.p", "rb") as f: indices = pickle.load(f) - dset = ImageNetTrain(process_images=False,) + dset = ImageNetTrain( + process_images=False, + ) return Subset(dset, indices) @@ -390,5 +395,7 @@ def __init__(self, **kwargs): def get_base(self): with open("data/imagenet_val_hr_indices.p", "rb") as f: indices = pickle.load(f) - dset = ImageNetValidation(process_images=False,) + dset = ImageNetValidation( + process_images=False, + ) return Subset(dset, indices) diff --git a/examples/images/diffusion/ldm/data/lsun.py b/examples/images/diffusion/ldm/data/lsun.py index f5bf26c14254..e5c374aa2d51 100644 --- a/examples/images/diffusion/ldm/data/lsun.py +++ 
b/examples/images/diffusion/ldm/data/lsun.py @@ -1,47 +1,49 @@ import os + import numpy as np import PIL from PIL import Image from torch.utils.data import Dataset from torchvision import transforms + # This class is used to create a dataset of images from LSUN dataset for training class LSUNBase(Dataset): - def __init__(self, - txt_file, # path to the text file containing the list of image paths - data_root, # root directory of the LSUN dataset - size=None, # the size of images to resize to - interpolation="bicubic", # interpolation method to be used while resizing - flip_p=0.5 # probability of random horizontal flipping - ): - self.data_paths = txt_file # store path to text file containing list of images - self.data_root = data_root # store path to root directory of the dataset - with open(self.data_paths, "r") as f: # open and read the text file - self.image_paths = f.read().splitlines() # read the lines of the file and store as list - self._length = len(self.image_paths) # store the number of images - + def __init__( + self, + txt_file, # path to the text file containing the list of image paths + data_root, # root directory of the LSUN dataset + size=None, # the size of images to resize to + interpolation="bicubic", # interpolation method to be used while resizing + flip_p=0.5, # probability of random horizontal flipping + ): + self.data_paths = txt_file # store path to text file containing list of images + self.data_root = data_root # store path to root directory of the dataset + with open(self.data_paths, "r") as f: # open and read the text file + self.image_paths = f.read().splitlines() # read the lines of the file and store as list + self._length = len(self.image_paths) # store the number of images + # create dictionary to hold image path information self.labels = { "relative_file_path_": [l for l in self.image_paths], - "file_path_": [os.path.join(self.data_root, l) - for l in self.image_paths], + "file_path_": [os.path.join(self.data_root, l) for l in self.image_paths], } # set the image size to be resized - self.size = size + self.size = size # set the interpolation method for resizing the image - self.interpolation = {"linear": PIL.Image.LINEAR, - "bilinear": PIL.Image.BILINEAR, - "bicubic": PIL.Image.BICUBIC, - "lanczos": PIL.Image.LANCZOS, - }[interpolation] + self.interpolation = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + }[interpolation] # randomly flip the image horizontally with a given probability self.flip = transforms.RandomHorizontalFlip(p=flip_p) def __len__(self): # return the length of dataset return self._length - def __getitem__(self, i): # get the image path for the given index @@ -52,59 +54,71 @@ def __getitem__(self, i): image = image.convert("RGB") # default to score-sde preprocessing - - img = np.array(image).astype(np.uint8) # convert image to numpy array - crop = min(img.shape[0], img.shape[1]) # crop the image to a square shape - h, w, = img.shape[0], img.shape[1] # get the height and width of image - img = img[(h - crop) // 2:(h + crop) // 2, - (w - crop) // 2:(w + crop) // 2] # crop the image to a square shape - - image = Image.fromarray(img) # create an image from numpy array - if self.size is not None: # if image size is provided, resize the image + + img = np.array(image).astype(np.uint8) # convert image to numpy array + crop = min(img.shape[0], img.shape[1]) # crop the image to a square shape + ( + h, + w, + ) = ( + img.shape[0], + img.shape[1], + ) # get the 
height and width of image + img = img[ + (h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2 + ] # crop the image to a square shape + + image = Image.fromarray(img) # create an image from numpy array + if self.size is not None: # if image size is provided, resize the image image = image.resize((self.size, self.size), resample=self.interpolation) - image = self.flip(image) # flip the image horizontally with the given probability - image = np.array(image).astype(np.uint8) + image = self.flip(image) # flip the image horizontally with the given probability + image = np.array(image).astype(np.uint8) example["image"] = (image / 127.5 - 1.0).astype(np.float32) # normalize the image values and convert to float32 - return example # return the example dictionary containing the image and its file paths + return example # return the example dictionary containing the image and its file paths + -#A dataset class for LSUN Churches training set. -# It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments. +# A dataset class for LSUN Churches training set. +# It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments. # The text file containing the paths to the images and the root directory where the images are stored are passed as arguments. Any additional keyword arguments passed to this class will be forwarded to the constructor of the parent class. class LSUNChurchesTrain(LSUNBase): def __init__(self, **kwargs): super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) -#A dataset class for LSUN Churches validation set. + +# A dataset class for LSUN Churches validation set. # It is similar to LSUNChurchesTrain except that it uses a different text file and sets the flip probability to zero by default. class LSUNChurchesValidation(LSUNBase): - def __init__(self, flip_p=0., **kwargs): - super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", - flip_p=flip_p, **kwargs) + def __init__(self, flip_p=0.0, **kwargs): + super().__init__( + txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", flip_p=flip_p, **kwargs + ) + -# A dataset class for LSUN Bedrooms training set. -# It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments. +# A dataset class for LSUN Bedrooms training set. +# It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments. class LSUNBedroomsTrain(LSUNBase): def __init__(self, **kwargs): super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) -# A dataset class for LSUN Bedrooms validation set. + +# A dataset class for LSUN Bedrooms validation set. # It is similar to LSUNBedroomsTrain except that it uses a different text file and sets the flip probability to zero by default. class LSUNBedroomsValidation(LSUNBase): def __init__(self, flip_p=0.0, **kwargs): - super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", - flip_p=flip_p, **kwargs) + super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", flip_p=flip_p, **kwargs) -# A dataset class for LSUN Cats training set. -# It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments. + +# A dataset class for LSUN Cats training set. 
+# It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments. # The text file containing the paths to the images and the root directory where the images are stored are passed as arguments. class LSUNCatsTrain(LSUNBase): def __init__(self, **kwargs): super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) -# A dataset class for LSUN Cats validation set. + +# A dataset class for LSUN Cats validation set. # It is similar to LSUNCatsTrain except that it uses a different text file and sets the flip probability to zero by default. class LSUNCatsValidation(LSUNBase): - def __init__(self, flip_p=0., **kwargs): - super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", - flip_p=flip_p, **kwargs) + def __init__(self, flip_p=0.0, **kwargs): + super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", flip_p=flip_p, **kwargs) diff --git a/examples/images/diffusion/ldm/data/teyvat.py b/examples/images/diffusion/ldm/data/teyvat.py index eb5d3ea469d4..4a50a78f2dbc 100644 --- a/examples/images/diffusion/ldm/data/teyvat.py +++ b/examples/images/diffusion/ldm/data/teyvat.py @@ -1,15 +1,16 @@ +import json +from pathlib import Path from typing import Dict -import numpy as np -from omegaconf import DictConfig, ListConfig + import torch -from torch.utils.data import Dataset -from pathlib import Path -import json -from PIL import Image -from torchvision import transforms +from datasets import load_dataset from einops import rearrange from ldm.util import instantiate_from_config -from datasets import load_dataset +from omegaconf import DictConfig, ListConfig +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + def make_multi_folder_data(paths, caption_files=None, **kwargs): """Make a concat dataset from multiple folders @@ -19,10 +20,9 @@ def make_multi_folder_data(paths, caption_files=None, **kwargs): """ list_of_paths = [] if isinstance(paths, (Dict, DictConfig)): - assert caption_files is None, \ - "Caption files not yet supported for repeats" + assert caption_files is None, "Caption files not yet supported for repeats" for folder_path, repeats in paths.items(): - list_of_paths.extend([folder_path]*repeats) + list_of_paths.extend([folder_path] * repeats) paths = list_of_paths if caption_files is not None: @@ -31,8 +31,10 @@ def make_multi_folder_data(paths, caption_files=None, **kwargs): datasets = [FolderData(p, **kwargs) for p in paths] return torch.utils.data.ConcatDataset(datasets) + class FolderData(Dataset): - def __init__(self, + def __init__( + self, root_dir, caption_file=None, image_transforms=[], @@ -40,7 +42,7 @@ def __init__(self, default_caption="", postprocess=None, return_paths=False, - ) -> None: + ) -> None: """Create a dataset from a folder of images. If you pass in a root directory it will be searched for images ending in ext (ext can be a list) @@ -75,12 +77,12 @@ def __init__(self, self.paths.extend(list(self.root_dir.rglob(f"*.{e}"))) if isinstance(image_transforms, ListConfig): image_transforms = [instantiate_from_config(tt) for tt in image_transforms] - image_transforms.extend([transforms.ToTensor(), - transforms.Lambda(lambda x: rearrange(x * 2. 
- 1., 'c h w -> h w c'))]) + image_transforms.extend( + [transforms.ToTensor(), transforms.Lambda(lambda x: rearrange(x * 2.0 - 1.0, "c h w -> h w c"))] + ) image_transforms = transforms.Compose(image_transforms) self.tform = image_transforms - def __len__(self): if self.captions is not None: return len(self.captions.keys()) @@ -94,7 +96,7 @@ def __getitem__(self, index): caption = self.captions.get(chosen, None) if caption is None: caption = self.default_caption - filename = self.root_dir/chosen + filename = self.root_dir / chosen else: filename = self.paths[index] @@ -119,23 +121,26 @@ def process_im(self, im): im = im.convert("RGB") return self.tform(im) + def hf_dataset( - path = "Fazzie/Teyvat", + path="Fazzie/Teyvat", image_transforms=[], image_column="image", text_column="text", - image_key='image', - caption_key='txt', - ): - """Make huggingface dataset with appropriate list of transforms applied - """ + image_key="image", + caption_key="txt", +): + """Make huggingface dataset with appropriate list of transforms applied""" ds = load_dataset(path, name="train") ds = ds["train"] image_transforms = [instantiate_from_config(tt) for tt in image_transforms] - image_transforms.extend([transforms.Resize((256, 256)), - transforms.ToTensor(), - transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))] - ) + image_transforms.extend( + [ + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Lambda(lambda x: rearrange(x * 2.0 - 1.0, "c h w -> h w c")), + ] + ) tform = transforms.Compose(image_transforms) assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}" @@ -149,4 +154,4 @@ def pre_process(examples): return processed ds.set_transform(pre_process) - return ds \ No newline at end of file + return ds diff --git a/examples/images/diffusion/ldm/lr_scheduler.py b/examples/images/diffusion/ldm/lr_scheduler.py index be39da9ca6da..f4efb12f28b8 100644 --- a/examples/images/diffusion/ldm/lr_scheduler.py +++ b/examples/images/diffusion/ldm/lr_scheduler.py @@ -5,18 +5,20 @@ class LambdaWarmUpCosineScheduler: """ note: use with a base_lr of 1.0 """ + def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): self.lr_warm_up_steps = warm_up_steps self.lr_start = lr_start self.lr_min = lr_min self.lr_max = lr_max self.lr_max_decay_steps = max_decay_steps - self.last_lr = 0. + self.last_lr = 0.0 self.verbosity_interval = verbosity_interval def schedule(self, n, **kwargs): if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n % self.verbosity_interval == 0: + print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") if n < self.lr_warm_up_steps: lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start self.last_lr = lr @@ -24,13 +26,12 @@ def schedule(self, n, **kwargs): else: t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) t = min(t, 1.0) - lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( - 1 + np.cos(t * np.pi)) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(t * np.pi)) self.last_lr = lr return lr def __call__(self, n, **kwargs): - return self.schedule(n,**kwargs) + return self.schedule(n, **kwargs) class LambdaWarmUpCosineScheduler2: @@ -38,6 +39,7 @@ class LambdaWarmUpCosineScheduler2: supports repeated iterations, configurable via lists note: use with a base_lr of 1.0. 
""" + def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) self.lr_warm_up_steps = warm_up_steps @@ -46,7 +48,7 @@ def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosit self.f_max = f_max self.cycle_lengths = cycle_lengths self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) - self.last_f = 0. + self.last_f = 0.0 self.verbosity_interval = verbosity_interval def find_in_interval(self, n): @@ -60,8 +62,8 @@ def schedule(self, n, **kwargs): cycle = self.find_in_interval(n) n = n - self.cum_cycles[cycle] if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " - f"current cycle {cycle}") + if n % self.verbosity_interval == 0: + print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}") if n < self.lr_warm_up_steps[cycle]: f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] self.last_f = f @@ -69,8 +71,7 @@ def schedule(self, n, **kwargs): else: t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) t = min(t, 1.0) - f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( - 1 + np.cos(t * np.pi)) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi)) self.last_f = f return f @@ -79,20 +80,20 @@ def __call__(self, n, **kwargs): class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): - def schedule(self, n, **kwargs): cycle = self.find_in_interval(n) n = n - self.cum_cycles[cycle] if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " - f"current cycle {cycle}") + if n % self.verbosity_interval == 0: + print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}") if n < self.lr_warm_up_steps[cycle]: f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] self.last_f = f return f else: - f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / ( + self.cycle_lengths[cycle] + ) self.last_f = f return f - diff --git a/examples/images/diffusion/ldm/models/autoencoder.py b/examples/images/diffusion/ldm/models/autoencoder.py index f0a69fe63a8c..1c54dfe74f74 100644 --- a/examples/images/diffusion/ldm/models/autoencoder.py +++ b/examples/images/diffusion/ldm/models/autoencoder.py @@ -1,29 +1,28 @@ -import torch -import lightning.pytorch as pl - -from torch import nn -from torch.nn import functional as F -from torch.nn import Identity from contextlib import contextmanager -from ldm.modules.diffusionmodules.model import Encoder, Decoder +import lightning.pytorch as pl +import torch +from ldm.modules.diffusionmodules.model import Decoder, Encoder from ldm.modules.distributions.distributions import DiagonalGaussianDistribution from ldm.modules.ema import LitEma +from torch.nn import Identity +from torch.nn import functional as F class AutoencoderKL(pl.LightningModule): - def __init__(self, - ddconfig, - lossconfig, - embed_dim, - ckpt_path=None, - ignore_keys=[], - image_key="image", - colorize_nlabels=None, - monitor=None, - ema_decay=None, - 
learn_logvar=False - ): + def __init__( + self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ema_decay=None, + learn_logvar=False, + ): super().__init__() self.learn_logvar = learn_logvar self.image_key = image_key @@ -31,11 +30,11 @@ def __init__(self, self.decoder = Decoder(**ddconfig) self.loss = Identity() assert ddconfig["double_z"] - self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) self.embed_dim = embed_dim if colorize_nlabels is not None: - assert type(colorize_nlabels)==int + assert type(colorize_nlabels) == int self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) if monitor is not None: self.monitor = monitor @@ -43,7 +42,7 @@ def __init__(self, self.use_ema = ema_decay is not None if self.use_ema: self.ema_decay = ema_decay - assert 0. < ema_decay < 1. + assert 0.0 < ema_decay < 1.0 self.model_ema = LitEma(self, decay=ema_decay) print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") @@ -113,16 +112,30 @@ def training_step(self, batch, batch_idx, optimizer_idx): if optimizer_idx == 0: # train encoder+decoder+logvar - aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") + aeloss, log_dict_ae = self.loss( + inputs, + reconstructions, + posterior, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split="train", + ) self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) return aeloss if optimizer_idx == 1: # train the discriminator - discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") + discloss, log_dict_disc = self.loss( + inputs, + reconstructions, + posterior, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split="train", + ) self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) @@ -137,11 +150,25 @@ def validation_step(self, batch, batch_idx): def _validation_step(self, batch, batch_idx, postfix=""): inputs = self.get_input(batch, self.image_key) reconstructions, posterior = self(inputs) - aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, - last_layer=self.get_last_layer(), split="val"+postfix) - - discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, - last_layer=self.get_last_layer(), split="val"+postfix) + aeloss, log_dict_ae = self.loss( + inputs, + reconstructions, + posterior, + 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val" + postfix, + ) + + discloss, log_dict_disc = self.loss( + inputs, + reconstructions, + posterior, + 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val" + postfix, + ) self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) self.log_dict(log_dict_ae) @@ -150,15 +177,17 @@ def _validation_step(self, batch, batch_idx, postfix=""): def configure_optimizers(self): lr = self.learning_rate - ae_params_list = 
list(self.encoder.parameters()) + list(self.decoder.parameters()) + list( - self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()) + ae_params_list = ( + list(self.encoder.parameters()) + + list(self.decoder.parameters()) + + list(self.quant_conv.parameters()) + + list(self.post_quant_conv.parameters()) + ) if self.learn_logvar: print(f"{self.__class__.__name__}: Learning logvar") ae_params_list.append(self.loss.logvar) - opt_ae = torch.optim.Adam(ae_params_list, - lr=lr, betas=(0.5, 0.9)) - opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), - lr=lr, betas=(0.5, 0.9)) + opt_ae = torch.optim.Adam(ae_params_list, lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)) return [opt_ae, opt_disc], [] def get_last_layer(self): @@ -195,7 +224,7 @@ def to_rgb(self, x): if not hasattr(self, "colorize"): self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) x = F.conv2d(x, weight=self.colorize) - x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 return x @@ -217,4 +246,3 @@ def quantize(self, x, *args, **kwargs): def forward(self, x, *args, **kwargs): return x - diff --git a/examples/images/diffusion/ldm/models/diffusion/classifier.py b/examples/images/diffusion/ldm/models/diffusion/classifier.py index 3cf12f093bea..73aba26c9d89 100644 --- a/examples/images/diffusion/ldm/models/diffusion/classifier.py +++ b/examples/images/diffusion/ldm/models/diffusion/classifier.py @@ -1,23 +1,21 @@ import os -import torch +from copy import deepcopy +from glob import glob + import lightning.pytorch as pl +import torch +from einops import rearrange +from ldm.lr_scheduler import LambdaLinearScheduler +from ldm.models.diffusion.ddpm import LatentDiffusion +from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel +from ldm.util import default, ismap, log_txt_as_img +from natsort import natsorted from omegaconf import OmegaConf from torch.nn import functional as F from torch.optim import AdamW from torch.optim.lr_scheduler import LambdaLR -from copy import deepcopy -from einops import rearrange -from glob import glob -from natsort import natsorted -from ldm.models.diffusion.ddpm import LatentDiffusion -from ldm.lr_scheduler import LambdaLinearScheduler -from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel -from ldm.util import log_txt_as_img, default, ismap -__models__ = { - 'class_label': EncoderUNetModel, - 'segmentation': UNetModel -} +__models__ = {"class_label": EncoderUNetModel, "segmentation": UNetModel} def disabled_train(self, mode=True): @@ -27,24 +25,25 @@ def disabled_train(self, mode=True): class NoisyLatentImageClassifier(pl.LightningModule): - - def __init__(self, - diffusion_path, - num_classes, - ckpt_path=None, - pool='attention', - label_key=None, - diffusion_ckpt_path=None, - scheduler_config=None, - weight_decay=1.e-2, - log_steps=10, - monitor='val/loss', - *args, - **kwargs): + def __init__( + self, + diffusion_path, + num_classes, + ckpt_path=None, + pool="attention", + label_key=None, + diffusion_ckpt_path=None, + scheduler_config=None, + weight_decay=1.0e-2, + log_steps=10, + monitor="val/loss", + *args, + **kwargs, + ): super().__init__(*args, **kwargs) self.num_classes = num_classes # get latest config of diffusion model - diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] + diffusion_config = natsorted(glob(os.path.join(diffusion_path, 
"configs", "*-project.yaml")))[-1] self.diffusion_config = OmegaConf.load(diffusion_config).model self.diffusion_config.params.ckpt_path = diffusion_ckpt_path self.load_diffusion() @@ -54,10 +53,11 @@ def __init__(self, self.log_time_interval = self.diffusion_model.num_timesteps // log_steps self.log_steps = log_steps - self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ - else self.diffusion_model.cond_stage_key + self.label_key = ( + label_key if not hasattr(self.diffusion_model, "cond_stage_key") else self.diffusion_model.cond_stage_key + ) - assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' + assert self.label_key is not None, "label_key neither in diffusion model nor in model.params" if self.label_key not in __models__: raise NotImplementedError() @@ -78,8 +78,9 @@ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): if k.startswith(ik): print("Deleting key {} from state_dict.".format(k)) del sd[k] - missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( - sd, strict=False) + missing, unexpected = ( + self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(sd, strict=False) + ) print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") if len(missing) > 0: print(f"Missing Keys: {missing}") @@ -87,7 +88,7 @@ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): print(f"Unexpected Keys: {unexpected}") def load_diffusion(self): - model = LatentDiffusion(**self.diffusion_config.get('params',dict())) + model = LatentDiffusion(**self.diffusion_config.get("params", dict())) self.diffusion_model = model.eval() self.diffusion_model.train = disabled_train for param in self.diffusion_model.parameters(): @@ -97,14 +98,14 @@ def load_classifier(self, ckpt_path, pool): model_config = deepcopy(self.diffusion_config.params.unet_config.params) model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels model_config.out_channels = self.num_classes - if self.label_key == 'class_label': + if self.label_key == "class_label": model_config.pool = pool self.model = __models__[self.label_key](**model_config) if ckpt_path is not None: - print('#####################################################################') + print("#####################################################################") print(f'load from ckpt "{ckpt_path}"') - print('#####################################################################') + print("#####################################################################") self.init_from_ckpt(ckpt_path) @torch.no_grad() @@ -115,8 +116,9 @@ def get_x_noisy(self, x, t, noise=None): continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) # todo: make sure t+1 is correct here - return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, - continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) + return self.diffusion_model.q_sample( + x_start=x, t=t, noise=noise, continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod + ) def forward(self, x_noisy, t, *args, **kwargs): return self.model(x_noisy, t) @@ -126,7 +128,7 @@ def get_input(self, batch, k): x = batch[k] if len(x.shape) == 3: x = x[..., None] - x = rearrange(x, 'b h w c -> b c h w') + x = rearrange(x, "b h w c -> b c h w") x = x.to(memory_format=torch.contiguous_format).float() return x @@ -134,15 +136,15 @@ def 
get_input(self, batch, k): def get_conditioning(self, batch, k=None): if k is None: k = self.label_key - assert k is not None, 'Needs to provide label key' + assert k is not None, "Needs to provide label key" targets = batch[k].to(self.device) - if self.label_key == 'segmentation': - targets = rearrange(targets, 'b h w c -> b c h w') + if self.label_key == "segmentation": + targets = rearrange(targets, "b h w c -> b c h w") for down in range(self.numd): h, w = targets.shape[-2:] - targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') + targets = F.interpolate(targets, size=(h // 2, w // 2), mode="nearest") # targets = rearrange(targets,'b c h w -> b h w c') @@ -157,25 +159,21 @@ def compute_top_k(self, logits, labels, k, reduction="mean"): def on_train_epoch_start(self): # save some memory - self.diffusion_model.model.to('cpu') + self.diffusion_model.model.to("cpu") @torch.no_grad() def write_logs(self, loss, logits, targets): - log_prefix = 'train' if self.training else 'val' + log_prefix = "train" if self.training else "val" log = {} log[f"{log_prefix}/loss"] = loss.mean() - log[f"{log_prefix}/acc@1"] = self.compute_top_k( - logits, targets, k=1, reduction="mean" - ) - log[f"{log_prefix}/acc@5"] = self.compute_top_k( - logits, targets, k=5, reduction="mean" - ) + log[f"{log_prefix}/acc@1"] = self.compute_top_k(logits, targets, k=1, reduction="mean") + log[f"{log_prefix}/acc@5"] = self.compute_top_k(logits, targets, k=5, reduction="mean") self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) - self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) - self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) - lr = self.optimizers().param_groups[0]['lr'] - self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) + self.log("loss", log[f"{log_prefix}/loss"], prog_bar=True, logger=False) + self.log("global_step", self.global_step, logger=False, on_epoch=False, prog_bar=True) + lr = self.optimizers().param_groups[0]["lr"] + self.log("lr_abs", lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) def shared_step(self, batch, t=None): x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) @@ -189,7 +187,7 @@ def shared_step(self, batch, t=None): x_noisy = self.get_x_noisy(x, t) logits = self(x_noisy, t) - loss = F.cross_entropy(logits, targets, reduction='none') + loss = F.cross_entropy(logits, targets, reduction="none") self.write_logs(loss.detach(), logits.detach(), targets.detach()) @@ -201,8 +199,10 @@ def training_step(self, batch, batch_idx): return loss def reset_noise_accs(self): - self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in - range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} + self.noisy_acc = { + t: {"acc@1": [], "acc@5": []} + for t in range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t) + } def on_validation_start(self): self.reset_noise_accs() @@ -213,8 +213,8 @@ def validation_step(self, batch, batch_idx): for t in self.noisy_acc: _, logits, _, targets = self.shared_step(batch, t) - self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) - self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) + self.noisy_acc[t]["acc@1"].append(self.compute_top_k(logits, targets, k=1, reduction="mean")) + self.noisy_acc[t]["acc@5"].append(self.compute_top_k(logits, targets, k=5, reduction="mean")) 
return loss @@ -222,15 +222,12 @@ def configure_optimizers(self): optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) if self.use_scheduler: - scheduler = LambdaLinearScheduler(**self.scheduler_config.get('params',dict())) + scheduler = LambdaLinearScheduler(**self.scheduler_config.get("params", dict())) print("Setting up LambdaLR scheduler...") scheduler = [ - { - 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), - 'interval': 'step', - 'frequency': 1 - }] + {"scheduler": LambdaLR(optimizer, lr_lambda=scheduler.schedule), "interval": "step", "frequency": 1} + ] return [optimizer], scheduler return optimizer @@ -239,28 +236,28 @@ def configure_optimizers(self): def log_images(self, batch, N=8, *args, **kwargs): log = dict() x = self.get_input(batch, self.diffusion_model.first_stage_key) - log['inputs'] = x + log["inputs"] = x y = self.get_conditioning(batch) - if self.label_key == 'class_label': + if self.label_key == "class_label": y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) - log['labels'] = y + log["labels"] = y if ismap(y): - log['labels'] = self.diffusion_model.to_rgb(y) + log["labels"] = self.diffusion_model.to_rgb(y) for step in range(self.log_steps): current_time = step * self.log_time_interval _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) - log[f'inputs@t{current_time}'] = x_noisy + log[f"inputs@t{current_time}"] = x_noisy pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) - pred = rearrange(pred, 'b h w c -> b c h w') + pred = rearrange(pred, "b h w c -> b c h w") - log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) + log[f"pred@t{current_time}"] = self.diffusion_model.to_rgb(pred) for key in log: log[key] = log[key][:N] diff --git a/examples/images/diffusion/ldm/models/diffusion/ddim.py b/examples/images/diffusion/ldm/models/diffusion/ddim.py index 27ead0ea914c..a9e28792f864 100644 --- a/examples/images/diffusion/ldm/models/diffusion/ddim.py +++ b/examples/images/diffusion/ldm/models/diffusion/ddim.py @@ -1,11 +1,15 @@ """SAMPLING ONLY.""" -import torch import numpy as np +import torch +from ldm.modules.diffusionmodules.util import ( + extract_into_tensor, + make_ddim_sampling_parameters, + make_ddim_timesteps, + noise_like, +) from tqdm import tqdm -from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor - class DDIMSampler(object): def __init__(self, model, schedule="linear", **kwargs): @@ -20,67 +24,75 @@ def register_buffer(self, name, attr): attr = attr.to(torch.device("cuda")) setattr(self, name, attr) - def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): - self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True): + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) alphas_cumprod = self.model.alphas_cumprod - assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, "alphas have to be defined for each timestep" to_torch = lambda x: 
x.clone().detach().to(torch.float32).to(self.model.device) - self.register_buffer('betas', to_torch(self.model.betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + self.register_buffer("betas", to_torch(self.model.betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)) # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod.cpu()))) + self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))) + self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))) + self.register_buffer("sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1))) # ddim sampling parameters - ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), - ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta,verbose=verbose) - self.register_buffer('ddim_sigmas', ddim_sigmas) - self.register_buffer('ddim_alphas', ddim_alphas) - self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) - self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), ddim_timesteps=self.ddim_timesteps, eta=ddim_eta, verbose=verbose + ) + self.register_buffer("ddim_sigmas", ddim_sigmas) + self.register_buffer("ddim_alphas", ddim_alphas) + self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) + self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( - 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) - self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer("ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps) @torch.no_grad() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
- dynamic_threshold=None, - ucg_schedule=None, - **kwargs - ): + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, + ucg_schedule=None, + **kwargs, + ): if conditioning is not None: if isinstance(conditioning, dict): ctmp = conditioning[list(conditioning.keys())[0]] - while isinstance(ctmp, list): ctmp = ctmp[0] + while isinstance(ctmp, list): + ctmp = ctmp[0] cbs = ctmp.shape[0] if cbs != batch_size: print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") @@ -98,35 +110,53 @@ def sample(self, # sampling C, H, W = shape size = (batch_size, C, H, W) - print(f'Data shape for DDIM sampling is {size}, eta {eta}') - - samples, intermediates = self.ddim_sampling(conditioning, size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold, - ucg_schedule=ucg_schedule - ) + print(f"Data shape for DDIM sampling is {size}, eta {eta}") + + samples, intermediates = self.ddim_sampling( + conditioning, + size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ucg_schedule=ucg_schedule, + ) return samples, intermediates @torch.no_grad() - def ddim_sampling(self, cond, shape, - x_T=None, ddim_use_original_steps=False, - callback=None, timesteps=None, quantize_denoised=False, - mask=None, x0=None, img_callback=None, log_every_t=100, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, - ucg_schedule=None): + def ddim_sampling( + self, + cond, + shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + dynamic_threshold=None, + ucg_schedule=None, + ): device = self.model.betas.device b = shape[0] if x_T is None: @@ -140,12 +170,12 @@ def ddim_sampling(self, cond, shape, subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 timesteps = self.ddim_timesteps[:subset_end] - intermediates = {'x_inter': [img], 'pred_x0': [img]} - time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else 
np.flip(timesteps) + intermediates = {"x_inter": [img], "pred_x0": [img]} + time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps) total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] print(f"Running DDIM Sampling with {total_steps} timesteps") - iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps) for i, step in enumerate(iterator): index = total_steps - i - 1 @@ -154,37 +184,60 @@ def ddim_sampling(self, cond, shape, if mask is not None: assert x0 is not None img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? - img = img_orig * mask + (1. - mask) * img + img = img_orig * mask + (1.0 - mask) * img if ucg_schedule is not None: assert len(ucg_schedule) == len(time_range) unconditional_guidance_scale = ucg_schedule[i] - outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, temperature=temperature, - noise_dropout=noise_dropout, score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold) + outs = self.p_sample_ddim( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ) img, pred_x0 = outs - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) + if callback: + callback(i) + if img_callback: + img_callback(pred_x0, i) if index % log_every_t == 0 or index == total_steps - 1: - intermediates['x_inter'].append(img) - intermediates['pred_x0'].append(pred_x0) + intermediates["x_inter"].append(img) + intermediates["pred_x0"].append(pred_x0) return img, intermediates @torch.no_grad() - def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, - dynamic_threshold=None): + def p_sample_ddim( + self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + dynamic_threshold=None, + ): b, *_, device = *x.shape, x.device - if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + if unconditional_conditioning is None or unconditional_guidance_scale == 1.0: model_output = self.model.apply_model(x, t, c) else: x_in = torch.cat([x] * 2) @@ -194,13 +247,9 @@ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=F c_in = dict() for k in c: if isinstance(c[k], list): - c_in[k] = [torch.cat([ - unconditional_conditioning[k][i], - c[k][i]]) for i in range(len(c[k]))] + c_in[k] = [torch.cat([unconditional_conditioning[k][i], c[k][i]]) for i in range(len(c[k]))] else: - c_in[k] = torch.cat([ - unconditional_conditioning[k], - c[k]]) + c_in[k] = torch.cat([unconditional_conditioning[k], c[k]]) 
elif isinstance(c, list): c_in = list() assert isinstance(unconditional_conditioning, list) @@ -217,18 +266,20 @@ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=F e_t = model_output if score_corrector is not None: - assert self.model.parameterization == "eps", 'not implemented' + assert self.model.parameterization == "eps", "not implemented" e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev - sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + ) sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas # select parameters corresponding to the currently considered timestep a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device) # current prediction for x_0 if self.model.parameterization != "v": @@ -243,16 +294,25 @@ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=F raise NotImplementedError() # direction pointing to x_t - dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: + if noise_dropout > 0.0: noise = torch.nn.functional.dropout(noise, p=noise_dropout) x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise return x_prev, pred_x0 @torch.no_grad() - def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None, - unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None): + def encode( + self, + x0, + c, + t_enc, + use_original_steps=False, + return_intermediates=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + callback=None, + ): num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0] assert t_enc <= num_reference_steps @@ -268,33 +328,37 @@ def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=No x_next = x0 intermediates = [] inter_steps = [] - for i in tqdm(range(num_steps), desc='Encoding Image'): + for i in tqdm(range(num_steps), desc="Encoding Image"): t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long) - if unconditional_guidance_scale == 1.: + if unconditional_guidance_scale == 1.0: noise_pred = self.model.apply_model(x_next, t, c) else: assert unconditional_conditioning is not None e_t_uncond, noise_pred = torch.chunk( - self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)), - torch.cat((unconditional_conditioning, c))), 2) + self.model.apply_model( + torch.cat((x_next, x_next)), torch.cat((t, t)), torch.cat((unconditional_conditioning, c)) + ), + 2, + ) noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond) xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next - weighted_noise_pred = 
alphas_next[i].sqrt() * ( - (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred + weighted_noise_pred = ( + alphas_next[i].sqrt() * ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred + ) x_next = xt_weighted + weighted_noise_pred - if return_intermediates and i % ( - num_steps // return_intermediates) == 0 and i < num_steps - 1: + if return_intermediates and i % (num_steps // return_intermediates) == 0 and i < num_steps - 1: intermediates.append(x_next) inter_steps.append(i) elif return_intermediates and i >= num_steps - 2: intermediates.append(x_next) inter_steps.append(i) - if callback: callback(i) + if callback: + callback(i) - out = {'x_encoded': x_next, 'intermediate_steps': inter_steps} + out = {"x_encoded": x_next, "intermediate_steps": inter_steps} if return_intermediates: - out.update({'intermediates': intermediates}) + out.update({"intermediates": intermediates}) return x_next, out @torch.no_grad() @@ -310,13 +374,22 @@ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): if noise is None: noise = torch.randn_like(x0) - return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + - extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) + return ( + extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise + ) @torch.no_grad() - def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, - use_original_steps=False, callback=None): - + def decode( + self, + x_latent, + cond, + t_start, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + use_original_steps=False, + callback=None, + ): timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps timesteps = timesteps[:t_start] @@ -324,13 +397,20 @@ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unco total_steps = timesteps.shape[0] print(f"Running DDIM Sampling with {total_steps} timesteps") - iterator = tqdm(time_range, desc='Decoding image', total=total_steps) + iterator = tqdm(time_range, desc="Decoding image", total=total_steps) x_dec = x_latent for i, step in enumerate(iterator): index = total_steps - i - 1 ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) - x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning) - if callback: callback(i) - return x_dec \ No newline at end of file + x_dec, _ = self.p_sample_ddim( + x_dec, + cond, + ts, + index=index, + use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + if callback: + callback(i) + return x_dec diff --git a/examples/images/diffusion/ldm/models/diffusion/ddpm.py b/examples/images/diffusion/ldm/models/diffusion/ddpm.py index 842ec1371ea0..20e26256e18e 100644 --- a/examples/images/diffusion/ldm/models/diffusion/ddpm.py +++ b/examples/images/diffusion/ldm/models/diffusion/ddpm.py @@ -27,23 +27,22 @@ from ldm.models.autoencoder import AutoencoderKL, IdentityFirstStage from ldm.models.diffusion.ddim import * from ldm.models.diffusion.ddim import DDIMSampler -from ldm.modules.midas.api import MiDaSInference from ldm.modules.diffusionmodules.model import * -from 
ldm.modules.diffusionmodules.model import Decoder, Encoder, Model from ldm.modules.diffusionmodules.openaimodel import * -from ldm.modules.diffusionmodules.openaimodel import AttentionPool2d, UNetModel +from ldm.modules.diffusionmodules.openaimodel import UNetModel +from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule, noise_like from ldm.modules.distributions.distributions import DiagonalGaussianDistribution, normal_kl -from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation from ldm.modules.ema import LitEma from ldm.modules.encoders.modules import * +from ldm.modules.midas.api import MiDaSInference from ldm.util import count_params, default, exists, isimage, ismap, log_txt_as_img, mean_flat from omegaconf import ListConfig from torch.optim.lr_scheduler import LambdaLR from torchvision.utils import make_grid from tqdm import tqdm -__conditioning_keys__ = {'concat': 'c_concat', 'crossattn': 'c_crossattn', 'adm': 'y'} +__conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"} def disabled_train(self, mode=True): @@ -78,15 +77,15 @@ def __init__( linear_end=2e-2, cosine_s=8e-3, given_betas=None, - original_elbo_weight=0., - v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta - l_simple_weight=1., + original_elbo_weight=0.0, + v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1.0, conditioning_key=None, - parameterization="eps", # all assuming fixed variance schedules + parameterization="eps", # all assuming fixed variance schedules scheduler_config=None, use_positional_encodings=False, learn_logvar=False, - logvar_init=0., + logvar_init=0.0, use_fp16=True, make_it_fit=False, ucg_training=None, @@ -133,9 +132,9 @@ def __init__( if reset_ema: assert exists(ckpt) - ''' + """ Uncomment if you Use DDP Strategy - ''' + """ # if ckpt is not None: # self.init_from_ckpt(ckpt, ignore_keys=ignore_keys, only_model=load_only_unet) # if reset_ema: @@ -155,12 +154,14 @@ def __init__( self.linear_end = linear_end self.cosine_s = cosine_s - self.register_schedule(given_betas=given_betas, - beta_schedule=beta_schedule, - timesteps=timesteps, - linear_start=linear_start, - linear_end=linear_end, - cosine_s=cosine_s) + self.register_schedule( + given_betas=given_betas, + beta_schedule=beta_schedule, + timesteps=timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) self.loss_type = loss_type @@ -174,67 +175,73 @@ def __init__( if self.ucg_training: self.ucg_prng = np.random.RandomState() - def register_schedule(self, - given_betas=None, - beta_schedule="linear", - timesteps=1000, - linear_start=1e-4, - linear_end=2e-2, - cosine_s=8e-3): + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): if exists(given_betas): betas = given_betas else: - betas = make_beta_schedule(beta_schedule, - timesteps, - linear_start=linear_start, - linear_end=linear_end, - cosine_s=cosine_s) - alphas = 1. 
- betas + betas = make_beta_schedule( + beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s + ) + alphas = 1.0 - betas alphas_cumprod = np.cumprod(alphas, axis=0) - alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) - timesteps, = betas.shape + (timesteps,) = betas.shape self.num_timesteps = int(timesteps) self.linear_start = linear_start self.linear_end = linear_end - assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + assert alphas_cumprod.shape[0] == self.num_timesteps, "alphas have to be defined for each timestep" to_torch = partial(torch.tensor, dtype=torch.float32) - self.register_buffer('betas', to_torch(betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + self.register_buffer("betas", to_torch(betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))) + self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))) + self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))) + self.register_buffer("sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))) # calculations for posterior q(x_{t-1} | x_t, x_0) - posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( - 1. - alphas_cumprod) + self.v_posterior * betas + posterior_variance = (1 - self.v_posterior) * betas * (1.0 - alphas_cumprod_prev) / ( + 1.0 - alphas_cumprod + ) + self.v_posterior * betas # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) - self.register_buffer('posterior_variance', to_torch(posterior_variance)) + self.register_buffer("posterior_variance", to_torch(posterior_variance)) # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) - self.register_buffer('posterior_mean_coef1', - to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) - self.register_buffer('posterior_mean_coef2', - to_torch((1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. 
- alphas_cumprod))) + self.register_buffer("posterior_log_variance_clipped", to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer( + "posterior_mean_coef1", to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)) + ) + self.register_buffer( + "posterior_mean_coef2", to_torch((1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)) + ) if self.parameterization == "eps": - lvlb_weights = self.betas**2 / (2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + lvlb_weights = self.betas**2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod) + ) elif self.parameterization == "x0": - lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2.0 * 1 - torch.Tensor(alphas_cumprod)) elif self.parameterization == "v": - lvlb_weights = torch.ones_like(self.betas**2 / (2 * self.posterior_variance * to_torch(alphas) * - (1 - self.alphas_cumprod))) + lvlb_weights = torch.ones_like( + self.betas**2 / (2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + ) else: raise NotImplementedError("mu not supported") lvlb_weights[0] = lvlb_weights[1] - self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + self.register_buffer("lvlb_weights", lvlb_weights, persistent=False) assert not torch.isnan(self.lvlb_weights).all() @contextmanager @@ -265,9 +272,11 @@ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): del sd[k] if self.make_it_fit: n_params = len([name for name, _ in itertools.chain(self.named_parameters(), self.named_buffers())]) - for name, param in tqdm(itertools.chain(self.named_parameters(), self.named_buffers()), - desc="Fitting old weights to new weights", - total=n_params): + for name, param in tqdm( + itertools.chain(self.named_parameters(), self.named_buffers()), + desc="Fitting old weights to new weights", + total=n_params, + ): if not name in sd: continue old_shape = sd[name].shape @@ -302,8 +311,9 @@ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): sd[name] = new_param - missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( - sd, strict=False) + missing, unexpected = ( + self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(sd, strict=False) + ) rank_zero_info(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") if len(missing) > 0: rank_zero_info(f"Missing Keys:\n {missing}") @@ -317,28 +327,36 @@ def q_mean_variance(self, x_start, t): :param t: the number of diffusion steps (minus 1). Here, 0 means one step. :return: A tuple (mean, variance, log_variance), all of x_start's shape. 
""" - mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) + mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) return mean, variance, log_variance def predict_start_from_noise(self, x_t, t, noise): - return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise) + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) def predict_start_from_z_and_v(self, x_t, t, v): # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) - return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v + ) def predict_eps_from_z_and_v(self, x_t, t, v): - return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t + ) def q_posterior(self, x_start, x_t, t): - posterior_mean = (extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + - extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t) + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) return posterior_mean, posterior_variance, posterior_log_variance_clipped @@ -350,7 +368,7 @@ def p_mean_variance(self, x, t, clip_denoised: bool): elif self.parameterization == "x0": x_recon = model_out if clip_denoised: - x_recon.clamp_(-1., 1.) 
+ x_recon.clamp_(-1.0, 1.0) model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) return model_mean, posterior_variance, posterior_log_variance @@ -370,10 +388,10 @@ def p_sample_loop(self, shape, return_intermediates=False): b = shape[0] img = torch.randn(shape, device=device) intermediates = [img] - for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): - img = self.p_sample(img, - torch.full((b,), i, device=device, dtype=torch.long), - clip_denoised=self.clip_denoised) + for i in tqdm(reversed(range(0, self.num_timesteps)), desc="Sampling t", total=self.num_timesteps): + img = self.p_sample( + img, torch.full((b,), i, device=device, dtype=torch.long), clip_denoised=self.clip_denoised + ) if i % self.log_every_t == 0 or i == self.num_timesteps - 1: intermediates.append(img) if return_intermediates: @@ -384,28 +402,33 @@ def p_sample_loop(self, shape, return_intermediates=False): def sample(self, batch_size=16, return_intermediates=False): image_size = self.image_size channels = self.channels - return self.p_sample_loop((batch_size, channels, image_size, image_size), - return_intermediates=return_intermediates) + return self.p_sample_loop( + (batch_size, channels, image_size, image_size), return_intermediates=return_intermediates + ) def q_sample(self, x_start, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start)) - return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) def get_v(self, x, noise, t): - return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x + ) def get_loss(self, pred, target, mean=True): - if self.loss_type == 'l1': + if self.loss_type == "l1": loss = (target - pred).abs() if mean: loss = loss.mean() - elif self.loss_type == 'l2': + elif self.loss_type == "l2": if mean: loss = torch.nn.functional.mse_loss(target, pred) else: - loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + loss = torch.nn.functional.mse_loss(target, pred, reduction="none") else: raise NotImplementedError("unknown loss type '{loss_type}'") @@ -428,17 +451,17 @@ def p_losses(self, x_start, t, noise=None): loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) - log_prefix = 'train' if self.training else 'val' + log_prefix = "train" if self.training else "val" - loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_dict.update({f"{log_prefix}/loss_simple": loss.mean()}) loss_simple = loss.mean() * self.l_simple_weight loss_vlb = (self.lvlb_weights[t] * loss).mean() - loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + loss_dict.update({f"{log_prefix}/loss_vlb": loss_vlb}) loss = loss_simple + self.original_elbo_weight * loss_vlb - loss_dict.update({f'{log_prefix}/loss': loss}) + loss_dict.update({f"{log_prefix}/loss": loss}) return loss, loss_dict @@ -452,7 +475,7 @@ def get_input(self, batch, k): x = batch[k] if len(x.shape) == 3: x = x[..., None] - x = rearrange(x, 'b h w c -> b c h w') + x = rearrange(x, 
"b h w c -> b c h w") if self.use_fp16: x = x.to(memory_format=torch.contiguous_format).half() else: @@ -481,8 +504,8 @@ def training_step(self, batch, batch_idx): self.log("global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False) if self.use_scheduler: - lr = self.optimizers().param_groups[0]['lr'] - self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + lr = self.optimizers().param_groups[0]["lr"] + self.log("lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) return loss @@ -491,7 +514,7 @@ def validation_step(self, batch, batch_idx): _, loss_dict_no_ema = self.shared_step(batch) with self.ema_scope(): _, loss_dict_ema = self.shared_step(batch) - loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + loss_dict_ema = {key + "_ema": loss_dict_ema[key] for key in loss_dict_ema} self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) @@ -501,8 +524,8 @@ def on_train_batch_end(self, *args, **kwargs): def _get_rows_from_list(self, samples): n_imgs_per_row = len(samples) - denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') - denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = rearrange(samples, "n b c h w -> b n c h w") + denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w") denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) return denoise_grid @@ -521,7 +544,7 @@ def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwarg for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: - t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = repeat(torch.tensor([t]), "1 -> b", b=n_row) t = t.to(self.device).long() noise = torch.randn_like(x_start) x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) @@ -556,29 +579,31 @@ def configure_optimizers(self): class LatentDiffusion(DDPM): """main class""" - def __init__(self, - first_stage_config, - cond_stage_config, - num_timesteps_cond=None, - cond_stage_key="image", - cond_stage_trainable=False, - concat_mode=True, - cond_stage_forward=None, - conditioning_key=None, - scale_factor=1.0, - scale_by_std=False, - use_fp16=True, - force_null_conditioning=False, - *args, - **kwargs): + def __init__( + self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + use_fp16=True, + force_null_conditioning=False, + *args, + **kwargs, + ): self.force_null_conditioning = force_null_conditioning self.num_timesteps_cond = default(num_timesteps_cond, 1) self.scale_by_std = scale_by_std - assert self.num_timesteps_cond <= kwargs['timesteps'] + assert self.num_timesteps_cond <= kwargs["timesteps"] # for backwards compatibility after implementation of DiffusionWrapper if conditioning_key is None: - conditioning_key = 'concat' if concat_mode else 'crossattn' - if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning: + conditioning_key = "concat" if concat_mode else "crossattn" + if cond_stage_config == "__is_unconditional__" and not self.force_null_conditioning: conditioning_key = None super().__init__(conditioning_key=conditioning_key, *args, **kwargs) @@ -593,7 +618,7 @@ def __init__(self, if not scale_by_std: 
self.scale_factor = scale_factor else: - self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.register_buffer("scale_factor", torch.tensor(scale_factor)) self.first_stage_config = first_stage_config self.cond_stage_config = cond_stage_config self.instantiate_first_stage(first_stage_config) @@ -601,9 +626,9 @@ def __init__(self, self.cond_stage_forward = cond_stage_forward self.clip_denoised = False self.bbox_tokenizer = None - ''' + """ Uncomment if you Use DDP Strategy - ''' + """ # self.restarted_from_ckpt = False # if self.ckpt is not None: # self.init_from_ckpt(self.ckpt, self.ignore_keys) @@ -630,15 +655,18 @@ def configure_sharded_model(self) -> None: if self.reset_ema: assert self.use_ema rank_zero_info( - f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") + f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint." + ) self.model_ema = LitEma(self.model) - self.register_schedule(given_betas=self.given_betas, - beta_schedule=self.beta_schedule, - timesteps=self.timesteps, - linear_start=self.linear_start, - linear_end=self.linear_end, - cosine_s=self.cosine_s) + self.register_schedule( + given_betas=self.given_betas, + beta_schedule=self.beta_schedule, + timesteps=self.timesteps, + linear_start=self.linear_start, + linear_end=self.linear_end, + cosine_s=self.cosine_s, + ) self.logvar = torch.full(fill_value=self.logvar_init, size=(self.num_timesteps,)) if self.learn_logvar: @@ -654,20 +682,29 @@ def configure_sharded_model(self) -> None: if self.reset_ema: assert self.use_ema rank_zero_info( - f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") + f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint." + ) self.model_ema = LitEma(self.model) - def make_cond_schedule(self,): + def make_cond_schedule( + self, + ): self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() - self.cond_ids[:self.num_timesteps_cond] = ids + self.cond_ids[: self.num_timesteps_cond] = ids @rank_zero_only @torch.no_grad() def on_train_batch_start(self, batch, batch_idx): # only for very first batch - if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: - assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + if ( + self.scale_by_std + and self.current_epoch == 0 + and self.global_step == 0 + and batch_idx == 0 + and not self.restarted_from_ckpt + ): + assert self.scale_factor == 1.0, "rather not use custom rescaling and std-rescaling simultaneously" # set rescale weight to 1./std of encodings rank_zero_info("### USING STD-RESCALING ###") x = super().get_input(batch, self.first_stage_key) @@ -675,17 +712,19 @@ def on_train_batch_start(self, batch, batch_idx): encoder_posterior = self.encode_first_stage(x) z = self.get_first_stage_encoding(encoder_posterior).detach() del self.scale_factor - self.register_buffer('scale_factor', 1. 
/ z.flatten().std()) + self.register_buffer("scale_factor", 1.0 / z.flatten().std()) rank_zero_info(f"setting self.scale_factor to {self.scale_factor}") rank_zero_info("### USING STD-RESCALING ###") - def register_schedule(self, - given_betas=None, - beta_schedule="linear", - timesteps=1000, - linear_start=1e-4, - linear_end=2e-2, - cosine_s=8e-3): + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) self.shorten_cond_schedule = self.num_timesteps_cond > 1 @@ -718,15 +757,16 @@ def instantiate_cond_stage(self, config): model = FrozenOpenCLIPEmbedder(**config) self.cond_stage_model = model - def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): + def _get_denoise_row_from_list(self, samples, desc="", force_no_decoder_quantization=False): denoise_row = [] for zd in tqdm(samples, desc=desc): denoise_row.append( - self.decode_first_stage(zd.to(self.device), force_not_quantize=force_no_decoder_quantization)) + self.decode_first_stage(zd.to(self.device), force_not_quantize=force_no_decoder_quantization) + ) n_imgs_per_row = len(denoise_row) - denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W - denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') - denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, "n b c h w -> b n c h w") + denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w") denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) return denoise_grid @@ -741,7 +781,7 @@ def get_first_stage_encoding(self, encoder_posterior): def get_learned_conditioning(self, c): if self.cond_stage_forward is None: - if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + if hasattr(self.cond_stage_model, "encode") and callable(self.cond_stage_model.encode): c = self.cond_stage_model.encode(c) if isinstance(c, DiagonalGaussianDistribution): c = c.mode() @@ -784,14 +824,17 @@ def get_weighting(self, h, w, Ly, Lx, device): if self.split_input_params["tie_braker"]: L_weighting = self.delta_border(Ly, Lx) - L_weighting = torch.clip(L_weighting, self.split_input_params["clip_min_tie_weight"], - self.split_input_params["clip_max_tie_weight"]) + L_weighting = torch.clip( + L_weighting, + self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"], + ) L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) weighting = weighting * L_weighting return weighting - def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code + def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code """ :param x: img of size (bs, c, h, w) :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) @@ -809,35 +852,39 @@ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load on fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) - normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap weighting = weighting.view((1, 1, kernel_size[0], 
kernel_size[1], Ly * Lx)) elif uf > 1 and df == 1: fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) unfold = torch.nn.Unfold(**fold_params) - fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), - dilation=1, - padding=0, - stride=(stride[0] * uf, stride[1] * uf)) + fold_params2 = dict( + kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, + padding=0, + stride=(stride[0] * uf, stride[1] * uf), + ) fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) - normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap + normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) elif df > 1 and uf == 1: fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) unfold = torch.nn.Unfold(**fold_params) - fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), - dilation=1, - padding=0, - stride=(stride[0] // df, stride[1] // df)) + fold_params2 = dict( + kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, + padding=0, + stride=(stride[0] // df, stride[1] // df), + ) fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) - normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap + normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) else: @@ -846,15 +893,17 @@ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load on return fold, unfold, normalization, weighting @torch.no_grad() - def get_input(self, - batch, - k, - return_first_stage_outputs=False, - force_c_encode=False, - cond_key=None, - return_original_cond=False, - bs=None, - return_x=False): + def get_input( + self, + batch, + k, + return_first_stage_outputs=False, + force_c_encode=False, + cond_key=None, + return_original_cond=False, + bs=None, + return_x=False, + ): x = super().get_input(batch, k) if bs is not None: x = x[:bs] @@ -866,9 +915,9 @@ def get_input(self, if cond_key is None: cond_key = self.cond_stage_key if cond_key != self.first_stage_key: - if cond_key in ['caption', 'coordinates_bbox', "txt"]: + if cond_key in ["caption", "coordinates_bbox", "txt"]: xc = batch[cond_key] - elif cond_key in ['class_label', 'cls']: + elif cond_key in ["class_label", "cls"]: xc = batch else: xc = super().get_input(batch, cond_key).to(self.device) @@ -887,14 +936,14 @@ def get_input(self, if self.use_positional_encodings: pos_x, pos_y = self.compute_latent_shifts(batch) ckey = __conditioning_keys__[self.model.conditioning_key] - c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y} + c = {ckey: c, "pos_x": pos_x, "pos_y": pos_y} else: c = None xc = None if self.use_positional_encodings: pos_x, pos_y = self.compute_latent_shifts(batch) - c = {'pos_x': pos_x, 'pos_y': pos_y} + c = {"pos_x": pos_x, "pos_y": pos_y} out = [z, c] if return_first_stage_outputs: xrec = self.decode_first_stage(z) @@ -912,9 +961,9 @@ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): if z.dim() == 4: z = torch.argmax(z.exp(), dim=1).long() z = 
self.first_stage_model.quantize.get_codebook_entry(z, shape=None) - z = rearrange(z, 'b h w c -> b c h w').contiguous() + z = rearrange(z, "b h w c -> b c h w").contiguous() - z = 1. / self.scale_factor * z + z = 1.0 / self.scale_factor * z return self.first_stage_model.decode(z) @torch.no_grad() @@ -932,7 +981,7 @@ def forward(self, x, c, *args, **kwargs): assert c is not None if self.cond_stage_trainable: c = self.get_learned_conditioning(c) - if self.shorten_cond_schedule: # TODO: drop this option + if self.shorten_cond_schedule: # TODO: drop this option tc = self.cond_ids[t].to(self.device) c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) return self.p_losses(x, c, t, *args, **kwargs) @@ -944,7 +993,7 @@ def apply_model(self, x_noisy, t, cond, return_ids=False): else: if not isinstance(cond, list): cond = [cond] - key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + key = "c_concat" if self.model.conditioning_key == "concat" else "c_crossattn" cond = {key: cond} x_recon = self.model(x_noisy, t, **cond) @@ -955,8 +1004,9 @@ def apply_model(self, x_noisy, t, cond, return_ids=False): return x_recon def _predict_eps_from_xstart(self, x_t, t, pred_xstart): - return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart + ) / extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) def _prior_bpd(self, x_start): """ @@ -978,7 +1028,7 @@ def p_losses(self, x_start, cond, t, noise=None): model_output = self.apply_model(x_noisy, t, cond) loss_dict = {} - prefix = 'train' if self.training else 'val' + prefix = "train" if self.training else "val" if self.parameterization == "x0": target = x_start @@ -990,36 +1040,38 @@ def p_losses(self, x_start, cond, t, noise=None): raise NotImplementedError() loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) - loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + loss_dict.update({f"{prefix}/loss_simple": loss_simple.mean()}) logvar_t = self.logvar[t].to(self.device) loss = loss_simple / torch.exp(logvar_t) + logvar_t # loss = loss_simple / torch.exp(self.logvar) + self.logvar if self.learn_logvar: - loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) - loss_dict.update({'logvar': self.logvar.data.mean()}) + loss_dict.update({f"{prefix}/loss_gamma": loss.mean()}) + loss_dict.update({"logvar": self.logvar.data.mean()}) loss = self.l_simple_weight * loss.mean() loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() - loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) - loss += (self.original_elbo_weight * loss_vlb) - loss_dict.update({f'{prefix}/loss': loss}) + loss_dict.update({f"{prefix}/loss_vlb": loss_vlb}) + loss += self.original_elbo_weight * loss_vlb + loss_dict.update({f"{prefix}/loss": loss}) return loss, loss_dict - def p_mean_variance(self, - x, - c, - t, - clip_denoised: bool, - return_codebook_ids=False, - quantize_denoised=False, - return_x0=False, - score_corrector=None, - corrector_kwargs=None): + def p_mean_variance( + self, + x, + c, + t, + clip_denoised: bool, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + score_corrector=None, + corrector_kwargs=None, + ): t_in = t model_out = self.apply_model(x, t_in, c, 
return_ids=return_codebook_ids) @@ -1038,7 +1090,7 @@ def p_mean_variance(self, raise NotImplementedError() if clip_denoised: - x_recon.clamp_(-1., 1.) + x_recon.clamp_(-1.0, 1.0) if quantize_denoised: x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) @@ -1050,29 +1102,33 @@ def p_mean_variance(self, return model_mean, posterior_variance, posterior_log_variance @torch.no_grad() - def p_sample(self, - x, - c, - t, - clip_denoised=False, - repeat_noise=False, - return_codebook_ids=False, - quantize_denoised=False, - return_x0=False, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None): + def p_sample( + self, + x, + c, + t, + clip_denoised=False, + repeat_noise=False, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + ): b, *_, device = *x.shape, x.device - outputs = self.p_mean_variance(x=x, - c=c, - t=t, - clip_denoised=clip_denoised, - return_codebook_ids=return_codebook_ids, - quantize_denoised=quantize_denoised, - return_x0=return_x0, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs) + outputs = self.p_mean_variance( + x=x, + c=c, + t=t, + clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + ) if return_codebook_ids: raise DeprecationWarning("Support dropped.") model_mean, _, model_log_variance, logits = outputs @@ -1082,7 +1138,7 @@ def p_sample(self, model_mean, _, model_log_variance = outputs noise = noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: + if noise_dropout > 0.0: noise = torch.nn.functional.dropout(noise, p=noise_dropout) # no noise when t == 0 nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) @@ -1095,23 +1151,25 @@ def p_sample(self, return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise @torch.no_grad() - def progressive_denoising(self, - cond, - shape, - verbose=True, - callback=None, - quantize_denoised=False, - img_callback=None, - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - batch_size=None, - x_T=None, - start_T=None, - log_every_t=None): + def progressive_denoising( + self, + cond, + shape, + verbose=True, + callback=None, + quantize_denoised=False, + img_callback=None, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + batch_size=None, + x_T=None, + start_T=None, + log_every_t=None, + ): if not log_every_t: log_every_t = self.log_every_t timesteps = self.num_timesteps @@ -1128,40 +1186,47 @@ def progressive_denoising(self, if cond is not None: if isinstance(cond, dict): cond = { - key: cond[key][:batch_size] if not isinstance(cond[key], list) else list( - map(lambda x: x[:batch_size], cond[key])) for key in cond + key: cond[key][:batch_size] + if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + for key in cond } else: cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] if start_T is not None: timesteps = min(timesteps, start_T) - iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', - total=timesteps) if verbose else 
reversed(range(0, timesteps)) + iterator = ( + tqdm(reversed(range(0, timesteps)), desc="Progressive Generation", total=timesteps) + if verbose + else reversed(range(0, timesteps)) + ) if type(temperature) == float: temperature = [temperature] * timesteps for i in iterator: ts = torch.full((b,), i, device=self.device, dtype=torch.long) if self.shorten_cond_schedule: - assert self.model.conditioning_key != 'hybrid' + assert self.model.conditioning_key != "hybrid" tc = self.cond_ids[ts].to(cond.device) cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) - img, x0_partial = self.p_sample(img, - cond, - ts, - clip_denoised=self.clip_denoised, - quantize_denoised=quantize_denoised, - return_x0=True, - temperature=temperature[i], - noise_dropout=noise_dropout, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs) + img, x0_partial = self.p_sample( + img, + cond, + ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, + return_x0=True, + temperature=temperature[i], + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + ) if mask is not None: assert x0 is not None img_orig = self.q_sample(x0, ts) - img = img_orig * mask + (1. - mask) * img + img = img_orig * mask + (1.0 - mask) * img if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(x0_partial) @@ -1172,21 +1237,22 @@ def progressive_denoising(self, return img, intermediates @torch.no_grad() - def p_sample_loop(self, - cond, - shape, - return_intermediates=False, - x_T=None, - verbose=True, - callback=None, - timesteps=None, - quantize_denoised=False, - mask=None, - x0=None, - img_callback=None, - start_T=None, - log_every_t=None): - + def p_sample_loop( + self, + cond, + shape, + return_intermediates=False, + x_T=None, + verbose=True, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + start_T=None, + log_every_t=None, + ): if not log_every_t: log_every_t = self.log_every_t device = self.betas.device @@ -1202,24 +1268,27 @@ def p_sample_loop(self, if start_T is not None: timesteps = min(timesteps, start_T) - iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( - range(0, timesteps)) + iterator = ( + tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps) + if verbose + else reversed(range(0, timesteps)) + ) if mask is not None: assert x0 is not None - assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match for i in iterator: ts = torch.full((b,), i, device=device, dtype=torch.long) if self.shorten_cond_schedule: - assert self.model.conditioning_key != 'hybrid' + assert self.model.conditioning_key != "hybrid" tc = self.cond_ids[ts].to(cond.device) cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) img = self.p_sample(img, cond, ts, clip_denoised=self.clip_denoised, quantize_denoised=quantize_denoised) if mask is not None: img_orig = self.q_sample(x0, ts) - img = img_orig * mask + (1. 
- mask) * img + img = img_orig * mask + (1.0 - mask) * img if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(img) @@ -1233,37 +1302,43 @@ def p_sample_loop(self, return img @torch.no_grad() - def sample(self, - cond, - batch_size=16, - return_intermediates=False, - x_T=None, - verbose=True, - timesteps=None, - quantize_denoised=False, - mask=None, - x0=None, - shape=None, - **kwargs): + def sample( + self, + cond, + batch_size=16, + return_intermediates=False, + x_T=None, + verbose=True, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + shape=None, + **kwargs, + ): if shape is None: shape = (batch_size, self.channels, self.image_size, self.image_size) if cond is not None: if isinstance(cond, dict): cond = { - key: cond[key][:batch_size] if not isinstance(cond[key], list) else list( - map(lambda x: x[:batch_size], cond[key])) for key in cond + key: cond[key][:batch_size] + if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + for key in cond } else: cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] - return self.p_sample_loop(cond, - shape, - return_intermediates=return_intermediates, - x_T=x_T, - verbose=verbose, - timesteps=timesteps, - quantize_denoised=quantize_denoised, - mask=mask, - x0=x0) + return self.p_sample_loop( + cond, + shape, + return_intermediates=return_intermediates, + x_T=x_T, + verbose=verbose, + timesteps=timesteps, + quantize_denoised=quantize_denoised, + mask=mask, + x0=x0, + ) @torch.no_grad() def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs): @@ -1295,41 +1370,45 @@ def get_unconditional_conditioning(self, batch_size, null_label=None): return self.get_learned_conditioning(xc) else: raise NotImplementedError("todo") - if isinstance(c, list): # in case the encoder gives us a list + if isinstance(c, list): # in case the encoder gives us a list for i in range(len(c)): - c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device) + c[i] = repeat(c[i], "1 ... -> b ...", b=batch_size).to(self.device) else: - c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device) + c = repeat(c, "1 ... 
-> b ...", b=batch_size).to(self.device) return c @torch.no_grad() - def log_images(self, - batch, - N=8, - n_row=4, - sample=True, - ddim_steps=50, - ddim_eta=0., - return_keys=None, - quantize_denoised=True, - inpaint=True, - plot_denoise_rows=False, - plot_progressive_rows=True, - plot_diffusion_rows=True, - unconditional_guidance_scale=1., - unconditional_guidance_label=None, - use_ema_scope=True, - **kwargs): + def log_images( + self, + batch, + N=8, + n_row=4, + sample=True, + ddim_steps=50, + ddim_eta=0.0, + return_keys=None, + quantize_denoised=True, + inpaint=True, + plot_denoise_rows=False, + plot_progressive_rows=True, + plot_diffusion_rows=True, + unconditional_guidance_scale=1.0, + unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs, + ): ema_scope = self.ema_scope if use_ema_scope else nullcontext use_ddim = ddim_steps is not None log = dict() - z, c, x, xrec, xc = self.get_input(batch, - self.first_stage_key, - return_first_stage_outputs=True, - force_c_encode=True, - return_original_cond=True, - bs=N) + z, c, x, xrec, xc = self.get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=N, + ) N = min(x.shape[0], N) n_row = min(x.shape[0], n_row) log["inputs"] = x @@ -1341,10 +1420,10 @@ def log_images(self, elif self.cond_stage_key in ["caption", "txt"]: xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) log["conditioning"] = xc - elif self.cond_stage_key in ['class_label', "cls"]: + elif self.cond_stage_key in ["class_label", "cls"]: try: xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) - log['conditioning'] = xc + log["conditioning"] = xc except KeyError: # probably no "human_label" in batch pass @@ -1359,26 +1438,24 @@ def log_images(self, z_start = z[:n_row] for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: - t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = repeat(torch.tensor([t]), "1 -> b", b=n_row) t = t.to(self.device).long() noise = torch.randn_like(z_start) z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) diffusion_row.append(self.decode_first_stage(z_noisy)) - diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W - diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') - diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w") + diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w") diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) log["diffusion_row"] = diffusion_grid if sample: # get denoise row with ema_scope("Sampling"): - samples, z_denoise_row = self.sample_log(cond=c, - batch_size=N, - ddim=use_ddim, - ddim_steps=ddim_steps, - eta=ddim_eta) + samples, z_denoise_row = self.sample_log( + cond=c, batch_size=N, ddim=use_ddim, ddim_steps=ddim_steps, eta=ddim_eta + ) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) x_samples = self.decode_first_stage(samples) log["samples"] = x_samples @@ -1386,16 +1463,16 @@ def log_images(self, denoise_grid = self._get_denoise_row_from_list(z_denoise_row) log["denoise_row"] = denoise_grid - if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( - self.first_stage_model, 
IdentityFirstStage): + if ( + quantize_denoised + and not isinstance(self.first_stage_model, AutoencoderKL) + and not isinstance(self.first_stage_model, IdentityFirstStage) + ): # also display when quantizing x0 while sampling with ema_scope("Plotting Quantized Denoised"): - samples, z_denoise_row = self.sample_log(cond=c, - batch_size=N, - ddim=use_ddim, - ddim_steps=ddim_steps, - eta=ddim_eta, - quantize_denoised=True) + samples, z_denoise_row = self.sample_log( + cond=c, batch_size=N, ddim=use_ddim, ddim_steps=ddim_steps, eta=ddim_eta, quantize_denoised=True + ) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, # quantize_denoised=True) x_samples = self.decode_first_stage(samples.to(self.device)) @@ -1423,38 +1500,30 @@ def log_images(self, b, h, w = z.shape[0], z.shape[2], z.shape[3] mask = torch.ones(N, h, w).to(self.device) # zeros will be filled in - mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. + mask[:, h // 4 : 3 * h // 4, w // 4 : 3 * w // 4] = 0.0 mask = mask[:, None, ...] with ema_scope("Plotting Inpaint"): - samples, _ = self.sample_log(cond=c, - batch_size=N, - ddim=use_ddim, - eta=ddim_eta, - ddim_steps=ddim_steps, - x0=z[:N], - mask=mask) + samples, _ = self.sample_log( + cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, ddim_steps=ddim_steps, x0=z[:N], mask=mask + ) x_samples = self.decode_first_stage(samples.to(self.device)) log["samples_inpainting"] = x_samples log["mask"] = mask # outpaint - mask = 1. - mask + mask = 1.0 - mask with ema_scope("Plotting Outpaint"): - samples, _ = self.sample_log(cond=c, - batch_size=N, - ddim=use_ddim, - eta=ddim_eta, - ddim_steps=ddim_steps, - x0=z[:N], - mask=mask) + samples, _ = self.sample_log( + cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, ddim_steps=ddim_steps, x0=z[:N], mask=mask + ) x_samples = self.decode_first_stage(samples.to(self.device)) log["samples_outpainting"] = x_samples if plot_progressive_rows: with ema_scope("Plotting Progressives"): - img, progressives = self.progressive_denoising(c, - shape=(self.channels, self.image_size, self.image_size), - batch_size=N) + img, progressives = self.progressive_denoising( + c, shape=(self.channels, self.image_size, self.image_size), batch_size=N + ) prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") log["progressive_row"] = prog_row @@ -1472,10 +1541,11 @@ def configure_optimizers(self): rank_zero_info(f"{self.__class__.__name__}: Also optimizing conditioner params!") params = params + list(self.cond_stage_model.parameters()) if self.learn_logvar: - rank_zero_info('Diffusion model optimizing logvar') + rank_zero_info("Diffusion model optimizing logvar") params.append(self.logvar) from colossalai.nn.optimizer import HybridAdam + opt = HybridAdam(params, lr=lr) # opt = torch.optim.AdamW(params, lr=lr) @@ -1483,7 +1553,7 @@ def configure_optimizers(self): scheduler = LambdaLinearScheduler(**self.scheduler_config) rank_zero_info("Setting up LambdaLR scheduler...") - scheduler = [{'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), 'interval': 'step', 'frequency': 1}] + scheduler = [{"scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), "interval": "step", "frequency": 1}] return [opt], scheduler return opt @@ -1493,45 +1563,44 @@ def to_rgb(self, x): if not hasattr(self, "colorize"): self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) x = nn.functional.conv2d(x, weight=self.colorize) - x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. 
+ x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 return x class DiffusionWrapper(pl.LightningModule): - def __init__(self, diff_model_config, conditioning_key): super().__init__() self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False) self.diffusion_model = UNetModel(**diff_model_config) self.conditioning_key = conditioning_key - assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm'] + assert self.conditioning_key in [None, "concat", "crossattn", "hybrid", "adm", "hybrid-adm", "crossattn-adm"] def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None): if self.conditioning_key is None: out = self.diffusion_model(x, t) - elif self.conditioning_key == 'concat': + elif self.conditioning_key == "concat": xc = torch.cat([x] + c_concat, dim=1) out = self.diffusion_model(xc, t) - elif self.conditioning_key == 'crossattn': + elif self.conditioning_key == "crossattn": if not self.sequential_cross_attn: cc = torch.cat(c_crossattn, 1) else: cc = c_crossattn out = self.diffusion_model(x, t, context=cc) - elif self.conditioning_key == 'hybrid': + elif self.conditioning_key == "hybrid": xc = torch.cat([x] + c_concat, dim=1) cc = torch.cat(c_crossattn, 1) out = self.diffusion_model(xc, t, context=cc) - elif self.conditioning_key == 'hybrid-adm': + elif self.conditioning_key == "hybrid-adm": assert c_adm is not None xc = torch.cat([x] + c_concat, dim=1) cc = torch.cat(c_crossattn, 1) out = self.diffusion_model(xc, t, context=cc, y=c_adm) - elif self.conditioning_key == 'crossattn-adm': + elif self.conditioning_key == "crossattn-adm": assert c_adm is not None cc = torch.cat(c_crossattn, 1) out = self.diffusion_model(x, t, context=cc, y=c_adm) - elif self.conditioning_key == 'adm': + elif self.conditioning_key == "adm": cc = c_crossattn[0] out = self.diffusion_model(x, t, y=cc) else: @@ -1541,7 +1610,6 @@ def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=N class LatentUpscaleDiffusion(LatentDiffusion): - def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs): super().__init__(*args, **kwargs) # assumes that neither the cond_stage nor the low_scale_model contain trainable params @@ -1562,14 +1630,16 @@ def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False): if not log_mode: z, c = super().get_input(batch, k, force_c_encode=True, bs=bs) else: - z, c, x, xrec, xc = super().get_input(batch, - self.first_stage_key, - return_first_stage_outputs=True, - force_c_encode=True, - return_original_cond=True, - bs=bs) + z, c, x, xrec, xc = super().get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs, + ) x_low = batch[self.low_scale_key][:bs] - x_low = rearrange(x_low, 'b h w c -> b c h w') + x_low = rearrange(x_low, "b h w c -> b c h w") if self.use_fp16: x_low = x_low.to(memory_format=torch.contiguous_format).half() else: @@ -1577,7 +1647,7 @@ def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False): zx, noise_level = self.low_scale_model(x_low) if self.noise_level_key is not None: # get noise level from batch instead, e.g. 
when extracting a custom noise level for bsr - raise NotImplementedError('TODO') + raise NotImplementedError("TODO") all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level} if log_mode: @@ -1587,29 +1657,30 @@ def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False): return z, all_conds @torch.no_grad() - def log_images(self, - batch, - N=8, - n_row=4, - sample=True, - ddim_steps=200, - ddim_eta=1., - return_keys=None, - plot_denoise_rows=False, - plot_progressive_rows=True, - plot_diffusion_rows=True, - unconditional_guidance_scale=1., - unconditional_guidance_label=None, - use_ema_scope=True, - **kwargs): + def log_images( + self, + batch, + N=8, + n_row=4, + sample=True, + ddim_steps=200, + ddim_eta=1.0, + return_keys=None, + plot_denoise_rows=False, + plot_progressive_rows=True, + plot_diffusion_rows=True, + unconditional_guidance_scale=1.0, + unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs, + ): ema_scope = self.ema_scope if use_ema_scope else nullcontext use_ddim = ddim_steps is not None log = dict() - z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, - self.first_stage_key, - bs=N, - log_mode=True) + z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input( + batch, self.first_stage_key, bs=N, log_mode=True + ) N = min(x.shape[0], N) n_row = min(x.shape[0], n_row) log["inputs"] = x @@ -1623,9 +1694,9 @@ def log_images(self, elif self.cond_stage_key in ["caption", "txt"]: xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) log["conditioning"] = xc - elif self.cond_stage_key in ['class_label', 'cls']: + elif self.cond_stage_key in ["class_label", "cls"]: xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) - log['conditioning'] = xc + log["conditioning"] = xc elif isimage(xc): log["conditioning"] = xc if ismap(xc): @@ -1637,26 +1708,24 @@ def log_images(self, z_start = z[:n_row] for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: - t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = repeat(torch.tensor([t]), "1 -> b", b=n_row) t = t.to(self.device).long() noise = torch.randn_like(z_start) z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) diffusion_row.append(self.decode_first_stage(z_noisy)) - diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W - diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') - diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w") + diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w") diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) log["diffusion_row"] = diffusion_grid if sample: # get denoise row with ema_scope("Sampling"): - samples, z_denoise_row = self.sample_log(cond=c, - batch_size=N, - ddim=use_ddim, - ddim_steps=ddim_steps, - eta=ddim_eta) + samples, z_denoise_row = self.sample_log( + cond=c, batch_size=N, ddim=use_ddim, ddim_steps=ddim_steps, eta=ddim_eta + ) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) x_samples = self.decode_first_stage(samples) log["samples"] = x_samples @@ -1673,9 +1742,9 @@ def log_images(self, if k == "c_crossattn": assert isinstance(c[k], list) and len(c[k]) == 1 uc[k] = [uc_tmp] - elif k == "c_adm": # todo: only run with text-based 
guidance? + elif k == "c_adm": # todo: only run with text-based guidance? assert isinstance(c[k], torch.Tensor) - #uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level + # uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level uc[k] = c[k] elif isinstance(c[k], list): uc[k] = [c[k][i] for i in range(len(c[k]))] @@ -1697,9 +1766,9 @@ def log_images(self, if plot_progressive_rows: with ema_scope("Plotting Progressives"): - img, progressives = self.progressive_denoising(c, - shape=(self.channels, self.image_size, self.image_size), - batch_size=N) + img, progressives = self.progressive_denoising( + c, shape=(self.channels, self.image_size, self.image_size), batch_size=N + ) prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") log["progressive_row"] = prog_row @@ -1708,21 +1777,24 @@ def log_images(self, class LatentFinetuneDiffusion(LatentDiffusion): """ - Basis for different finetunas, such as inpainting or depth2image - To disable finetuning mode, set finetune_keys to None + Basis for different finetunas, such as inpainting or depth2image + To disable finetuning mode, set finetune_keys to None """ def __init__( - self, - concat_keys: tuple, - finetune_keys=("model.diffusion_model.input_blocks.0.0.weight", - "model_ema.diffusion_modelinput_blocks00weight"), - keep_finetune_dims=4, - # if model was trained without concat mode before and we would like to keep these channels - c_concat_log_start=None, # to log reconstruction of c_concat codes - c_concat_log_end=None, - *args, - **kwargs): + self, + concat_keys: tuple, + finetune_keys=( + "model.diffusion_model.input_blocks.0.0.weight", + "model_ema.diffusion_modelinput_blocks00weight", + ), + keep_finetune_dims=4, + # if model was trained without concat mode before and we would like to keep these channels + c_concat_log_start=None, # to log reconstruction of c_concat codes + c_concat_log_end=None, + *args, + **kwargs, + ): ckpt = kwargs.pop("ckpt", None) ignore_keys = kwargs.pop("ignore_keys", list()) super().__init__(*args, **kwargs) @@ -1732,7 +1804,7 @@ def __init__( self.c_concat_log_start = c_concat_log_start self.c_concat_log_end = c_concat_log_end if exists(self.finetune_keys): - assert exists(ckpt), 'can only finetune from a given checkpoint' + assert exists(ckpt), "can only finetune from a given checkpoint" if exists(ckpt): self.init_from_ckpt(ckpt, ignore_keys) @@ -1755,13 +1827,14 @@ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): rank_zero_info( f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only" ) - new_entry = torch.zeros_like(param) # zero init - assert exists(new_entry), 'did not find matching parameter to modify' - new_entry[:, :self.keep_dims, ...] = sd[k] + new_entry = torch.zeros_like(param) # zero init + assert exists(new_entry), "did not find matching parameter to modify" + new_entry[:, : self.keep_dims, ...] 
= sd[k] sd[k] = new_entry - missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( - sd, strict=False) + missing, unexpected = ( + self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(sd, strict=False) + ) rank_zero_info(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") if len(missing) > 0: rank_zero_info(f"Missing Keys: {missing}") @@ -1769,23 +1842,25 @@ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): rank_zero_info(f"Unexpected Keys: {unexpected}") @torch.no_grad() - def log_images(self, - batch, - N=8, - n_row=4, - sample=True, - ddim_steps=200, - ddim_eta=1., - return_keys=None, - quantize_denoised=True, - inpaint=True, - plot_denoise_rows=False, - plot_progressive_rows=True, - plot_diffusion_rows=True, - unconditional_guidance_scale=1., - unconditional_guidance_label=None, - use_ema_scope=True, - **kwargs): + def log_images( + self, + batch, + N=8, + n_row=4, + sample=True, + ddim_steps=200, + ddim_eta=1.0, + return_keys=None, + quantize_denoised=True, + inpaint=True, + plot_denoise_rows=False, + plot_progressive_rows=True, + plot_diffusion_rows=True, + unconditional_guidance_scale=1.0, + unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs, + ): ema_scope = self.ema_scope if use_ema_scope else nullcontext use_ddim = ddim_steps is not None @@ -1803,16 +1878,16 @@ def log_images(self, elif self.cond_stage_key in ["caption", "txt"]: xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) log["conditioning"] = xc - elif self.cond_stage_key in ['class_label', 'cls']: + elif self.cond_stage_key in ["class_label", "cls"]: xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) - log['conditioning'] = xc + log["conditioning"] = xc elif isimage(xc): log["conditioning"] = xc if ismap(xc): log["original_conditioning"] = self.to_rgb(xc) if not (self.c_concat_log_start is None and self.c_concat_log_end is None): - log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end]) + log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start : self.c_concat_log_end]) if plot_diffusion_rows: # get diffusion row @@ -1820,29 +1895,28 @@ def log_images(self, z_start = z[:n_row] for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: - t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = repeat(torch.tensor([t]), "1 -> b", b=n_row) t = t.to(self.device).long() noise = torch.randn_like(z_start) z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) diffusion_row.append(self.decode_first_stage(z_noisy)) - diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W - diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') - diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w") + diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w") diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) log["diffusion_row"] = diffusion_grid if sample: # get denoise row with ema_scope("Sampling"): - samples, z_denoise_row = self.sample_log(cond={ - "c_concat": [c_cat], - "c_crossattn": [c] - }, - batch_size=N, - ddim=use_ddim, - ddim_steps=ddim_steps, - 
eta=ddim_eta) + samples, z_denoise_row = self.sample_log( + cond={"c_concat": [c_cat], "c_crossattn": [c]}, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + ) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) x_samples = self.decode_first_stage(samples) log["samples"] = x_samples @@ -1856,10 +1930,7 @@ def log_images(self, uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]} with ema_scope("Sampling with classifier-free guidance"): samples_cfg, _ = self.sample_log( - cond={ - "c_concat": [c_cat], - "c_crossattn": [c] - }, + cond={"c_concat": [c_cat], "c_crossattn": [c]}, batch_size=N, ddim=use_ddim, ddim_steps=ddim_steps, @@ -1878,7 +1949,7 @@ class LatentInpaintDiffusion(LatentFinetuneDiffusion): can either run as pure inpainting model (only concat mode) or with mixed conditionings, e.g. mask as concat and text via cross-attn. To disable finetuning mode, set finetune_keys to None - """ + """ def __init__(self, concat_keys=("mask", "masked_image"), masked_image_key="masked_image", *args, **kwargs): super().__init__(concat_keys, *args, **kwargs) @@ -1888,21 +1959,23 @@ def __init__(self, concat_keys=("mask", "masked_image"), masked_image_key="maske @torch.no_grad() def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): # note: restricted to non-trainable encoders currently - assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting' - z, c, x, xrec, xc = super().get_input(batch, - self.first_stage_key, - return_first_stage_outputs=True, - force_c_encode=True, - return_original_cond=True, - bs=bs) + assert not self.cond_stage_trainable, "trainable cond stages not yet supported for inpainting" + z, c, x, xrec, xc = super().get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs, + ) assert exists(self.concat_keys) c_cat = list() for ck in self.concat_keys: if self.use_fp16: - cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).half() + cc = rearrange(batch[ck], "b h w c -> b c h w").to(memory_format=torch.contiguous_format).half() else: - cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float() + cc = rearrange(batch[ck], "b h w c -> b c h w").to(memory_format=torch.contiguous_format).float() if bs is not None: cc = cc[:bs] cc = cc.to(self.device) @@ -1921,8 +1994,9 @@ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs @torch.no_grad() def log_images(self, *args, **kwargs): log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs) - log["masked_image"] = rearrange(args[0]["masked_image"], - 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float() + log["masked_image"] = ( + rearrange(args[0]["masked_image"], "b h w c -> b c h w").to(memory_format=torch.contiguous_format).float() + ) return log @@ -1939,13 +2013,15 @@ def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwarg @torch.no_grad() def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): # note: restricted to non-trainable encoders currently - assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img' - z, c, x, xrec, xc = super().get_input(batch, - self.first_stage_key, - return_first_stage_outputs=True, - force_c_encode=True, - return_original_cond=True, - bs=bs) + assert not 
self.cond_stage_trainable, "trainable cond stages not yet supported for depth2img" + z, c, x, xrec, xc = super().get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs, + ) assert exists(self.concat_keys) assert len(self.concat_keys) == 1 @@ -1963,10 +2039,10 @@ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs align_corners=False, ) - depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, - dim=[1, 2, 3], - keepdim=True) - cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1. + depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax( + cc, dim=[1, 2, 3], keepdim=True + ) + cc = 2.0 * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.0 c_cat.append(cc) c_cat = torch.cat(c_cat, dim=1) all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} @@ -1978,24 +2054,21 @@ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs def log_images(self, *args, **kwargs): log = super().log_images(*args, **kwargs) depth = self.depth_model(args[0][self.depth_stage_key]) - depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \ - torch.amax(depth, dim=[1, 2, 3], keepdim=True) - log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1. + depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), torch.amax( + depth, dim=[1, 2, 3], keepdim=True + ) + log["depth"] = 2.0 * (depth - depth_min) / (depth_max - depth_min) - 1.0 return log class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion): """ - condition on low-res image (and optionally on some spatial noise augmentation) + condition on low-res image (and optionally on some spatial noise augmentation) """ - def __init__(self, - concat_keys=("lr",), - reshuffle_patch_size=None, - low_scale_config=None, - low_scale_key=None, - *args, - **kwargs): + def __init__( + self, concat_keys=("lr",), reshuffle_patch_size=None, low_scale_config=None, low_scale_key=None, *args, **kwargs + ): super().__init__(concat_keys=concat_keys, *args, **kwargs) self.reshuffle_patch_size = reshuffle_patch_size self.low_scale_model = None @@ -2015,13 +2088,15 @@ def instantiate_low_stage(self, config): @torch.no_grad() def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): # note: restricted to non-trainable encoders currently - assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft' - z, c, x, xrec, xc = super().get_input(batch, - self.first_stage_key, - return_first_stage_outputs=True, - force_c_encode=True, - return_original_cond=True, - bs=bs) + assert not self.cond_stage_trainable, "trainable cond stages not yet supported for upscaling-ft" + z, c, x, xrec, xc = super().get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs, + ) assert exists(self.concat_keys) assert len(self.concat_keys) == 1 @@ -2030,13 +2105,15 @@ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs noise_level = None for ck in self.concat_keys: cc = batch[ck] - cc = rearrange(cc, 'b h w c -> b c h w') + cc = rearrange(cc, "b h w c -> b c h w") if exists(self.reshuffle_patch_size): assert isinstance(self.reshuffle_patch_size, int) - cc = rearrange(cc, - 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w', - p1=self.reshuffle_patch_size, - p2=self.reshuffle_patch_size) + cc = rearrange( + cc, + 
"b c (p1 h) (p2 w) -> b (p1 p2 c) h w", + p1=self.reshuffle_patch_size, + p2=self.reshuffle_patch_size, + ) if bs is not None: cc = cc[:bs] cc = cc.to(self.device) @@ -2055,5 +2132,5 @@ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs @torch.no_grad() def log_images(self, *args, **kwargs): log = super().log_images(*args, **kwargs) - log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w') + log["lr"] = rearrange(args[0]["lr"], "b h w c -> b c h w") return log diff --git a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/__init__.py b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/__init__.py index 7427f38c0753..f56611cb5fb3 100644 --- a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/__init__.py +++ b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/__init__.py @@ -1 +1 @@ -from .sampler import DPMSolverSampler \ No newline at end of file +from .sampler import DPMSolverSampler diff --git a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py index 095e5ba3ce0b..66063320ec78 100644 --- a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py +++ b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py @@ -1,17 +1,17 @@ -import torch -import torch.nn.functional as F import math + +import torch from tqdm import tqdm class NoiseScheduleVP: def __init__( - self, - schedule='discrete', - betas=None, - alphas_cumprod=None, - continuous_beta_0=0.1, - continuous_beta_1=20., + self, + schedule="discrete", + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20.0, ): """Create a wrapper class for the forward SDE (VP type). *** @@ -70,50 +70,63 @@ def __init__( >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) """ - if schedule not in ['discrete', 'linear', 'cosine']: + if schedule not in ["discrete", "linear", "cosine"]: raise ValueError( "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format( - schedule)) + schedule + ) + ) self.schedule = schedule - if schedule == 'discrete': + if schedule == "discrete": if betas is not None: log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) else: assert alphas_cumprod is not None log_alphas = 0.5 * torch.log(alphas_cumprod) self.total_N = len(log_alphas) - self.T = 1. - self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)) - self.log_alpha_array = log_alphas.reshape((1, -1,)) + self.T = 1.0 + self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape((1, -1)) + self.log_alpha_array = log_alphas.reshape( + ( + 1, + -1, + ) + ) else: self.total_N = 1000 self.beta_0 = continuous_beta_0 self.beta_1 = continuous_beta_1 self.cosine_s = 0.008 - self.cosine_beta_max = 999. - self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * ( - 1. + self.cosine_s) / math.pi - self.cosine_s - self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. 
+ self.cosine_s) * math.pi / 2.)) + self.cosine_beta_max = 999.0 + self.cosine_t_max = ( + math.atan(self.cosine_beta_max * (1.0 + self.cosine_s) / math.pi) + * 2.0 + * (1.0 + self.cosine_s) + / math.pi + - self.cosine_s + ) + self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0)) self.schedule = schedule - if schedule == 'cosine': + if schedule == "cosine": # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. self.T = 0.9946 else: - self.T = 1. + self.T = 1.0 def marginal_log_mean_coeff(self, t): """ Compute log(alpha_t) of a given continuous-time label t in [0, T]. """ - if self.schedule == 'discrete': - return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), - self.log_alpha_array.to(t.device)).reshape((-1)) - elif self.schedule == 'linear': - return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 - elif self.schedule == 'cosine': - log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.)) + if self.schedule == "discrete": + return interpolate_fn( + t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device) + ).reshape((-1)) + elif self.schedule == "linear": + return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + elif self.schedule == "cosine": + log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0)) log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 return log_alpha_t @@ -127,48 +140,56 @@ def marginal_std(self, t): """ Compute sigma_t of a given continuous-time label t in [0, T]. """ - return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t))) def marginal_lambda(self, t): """ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. """ log_mean_coeff = self.marginal_log_mean_coeff(t) - log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff)) return log_mean_coeff - log_std def inverse_lambda(self, lamb): """ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. """ - if self.schedule == 'linear': - tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) - Delta = self.beta_0 ** 2 + tmp + if self.schedule == "linear": + tmp = 2.0 * (self.beta_1 - self.beta_0) * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0**2 + tmp return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) - elif self.schedule == 'discrete': - log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) - t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), - torch.flip(self.t_array.to(lamb.device), [1])) + elif self.schedule == "discrete": + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2.0 * lamb) + t = interpolate_fn( + log_alpha.reshape((-1, 1)), + torch.flip(self.log_alpha_array.to(lamb.device), [1]), + torch.flip(self.t_array.to(lamb.device), [1]), + ) return t.reshape((-1,)) else: - log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) - t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. 
* ( - 1. + self.cosine_s) / math.pi - self.cosine_s + log_alpha = -0.5 * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb)) + t_fn = ( + lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) + * 2.0 + * (1.0 + self.cosine_s) + / math.pi + - self.cosine_s + ) t = t_fn(log_alpha) return t def model_wrapper( - model, - noise_schedule, - model_type="noise", - model_kwargs={}, - guidance_type="uncond", - condition=None, - unconditional_condition=None, - guidance_scale=1., - classifier_fn=None, - classifier_kwargs={}, + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1.0, + classifier_fn=None, + classifier_kwargs={}, ): """Create a wrapper function for the noise prediction model. DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to @@ -249,8 +270,8 @@ def get_model_input_time(t_continuous): For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. For continuous-time DPMs, we just use `t_continuous`. """ - if noise_schedule.schedule == 'discrete': - return (t_continuous - 1. / noise_schedule.total_N) * 1000. + if noise_schedule.schedule == "discrete": + return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0 else: return t_continuous @@ -302,7 +323,7 @@ def model_fn(x, t_continuous): noise = noise_pred_fn(x, t_continuous) return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad elif guidance_type == "classifier-free": - if guidance_scale == 1. or unconditional_condition is None: + if guidance_scale == 1.0 or unconditional_condition is None: return noise_pred_fn(x, t_continuous, cond=condition) else: x_in = torch.cat([x] * 2) @@ -317,7 +338,7 @@ def model_fn(x, t_continuous): class DPM_Solver: - def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.): + def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.0): """Construct a DPM-Solver. We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0"). If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver). @@ -387,20 +408,21 @@ def get_time_steps(self, skip_type, t_T, t_0, N, device): Returns: A pytorch tensor of the time steps, with the shape (N + 1,). """ - if skip_type == 'logSNR': + if skip_type == "logSNR": lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) return self.noise_schedule.inverse_lambda(logSNR_steps) - elif skip_type == 'time_uniform': + elif skip_type == "time_uniform": return torch.linspace(t_T, t_0, N + 1).to(device) - elif skip_type == 'time_quadratic': + elif skip_type == "time_quadratic": t_order = 2 - t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. 
/ t_order), N + 1).pow(t_order).to(device) + t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device) return t else: raise ValueError( - "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type) + ) def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): """ @@ -435,29 +457,57 @@ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type if order == 3: K = steps // 3 + 1 if steps % 3 == 0: - orders = [3, ] * (K - 2) + [2, 1] + orders = [ + 3, + ] * ( + K - 2 + ) + [2, 1] elif steps % 3 == 1: - orders = [3, ] * (K - 1) + [1] + orders = [ + 3, + ] * ( + K - 1 + ) + [1] else: - orders = [3, ] * (K - 1) + [2] + orders = [ + 3, + ] * ( + K - 1 + ) + [2] elif order == 2: if steps % 2 == 0: K = steps // 2 - orders = [2, ] * K + orders = [ + 2, + ] * K else: K = steps // 2 + 1 - orders = [2, ] * (K - 1) + [1] + orders = [ + 2, + ] * ( + K - 1 + ) + [1] elif order == 1: K = 1 - orders = [1, ] * steps + orders = [ + 1, + ] * steps else: raise ValueError("'order' must be '1' or '2' or '3'.") - if skip_type == 'logSNR': + if skip_type == "logSNR": # To reproduce the results in DPM-Solver paper timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) else: timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[ - torch.cumsum(torch.tensor([0, ] + orders)).to(device)] + torch.cumsum( + torch.tensor( + [ + 0, + ] + + orders + ) + ).to(device) + ] return timesteps_outer, orders def denoise_to_zero_fn(self, x, s): @@ -491,12 +541,9 @@ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=Fal phi_1 = torch.expm1(-h) if model_s is None: model_s = self.model_fn(x, s) - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - ) + x_t = expand_dims(sigma_t / sigma_s, dims) * x - expand_dims(alpha_t * phi_1, dims) * model_s if return_intermediate: - return x_t, {'model_s': model_s} + return x_t, {"model_s": model_s} else: return x_t else: @@ -504,16 +551,17 @@ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=Fal if model_s is None: model_s = self.model_fn(x, s) x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s ) if return_intermediate: - return x_t, {'model_s': model_s} + return x_t, {"model_s": model_s} else: return x_t - def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, - solver_type='dpm_solver'): + def singlestep_dpm_solver_second_update( + self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type="dpm_solver" + ): """ Singlestep solver DPM-Solver-2 from time `s` to time `t`. Args: @@ -529,7 +577,7 @@ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, ret Returns: x_t: A pytorch tensor. The approximated solution at time `t`. 
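In scalar form, the data-prediction branch of `dpm_solver_first_update` above is just an exponential-integrator step in half-logSNR (lambda) space. A minimal sketch, using the default `'linear'`-schedule betas from `NoiseScheduleVP` as illustrative constants (the toy inputs are assumptions, not taken from this patch):

```python
import math

beta_0, beta_1 = 0.1, 20.0  # default continuous-time betas in NoiseScheduleVP

def log_alpha(t):
    # log(alpha_t) for the 'linear' schedule, as in marginal_log_mean_coeff
    return -0.25 * t**2 * (beta_1 - beta_0) - 0.5 * t * beta_0

def sigma(t):
    # sigma_t = sqrt(1 - alpha_t^2), as in marginal_std
    return math.sqrt(1.0 - math.exp(2.0 * log_alpha(t)))

def lam(t):
    # half-logSNR lambda_t = log(alpha_t) - log(sigma_t), as in marginal_lambda
    return log_alpha(t) - math.log(sigma(t))

def dpm_solver_1_step(x_s, x0_pred, s, t):
    # One first-order step in data-prediction mode:
    # x_t = (sigma_t / sigma_s) * x_s - alpha_t * expm1(-h) * x0_pred,  h = lambda_t - lambda_s
    h = lam(t) - lam(s)
    return (sigma(t) / sigma(s)) * x_s - math.exp(log_alpha(t)) * math.expm1(-h) * x0_pred

print(dpm_solver_1_step(x_s=1.0, x0_pred=0.3, s=0.5, t=0.4))  # toy scalars
```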
""" - if solver_type not in ['dpm_solver', 'taylor']: + if solver_type not in ["dpm_solver", "taylor"]: raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) if r1 is None: r1 = 0.5 @@ -539,8 +587,11 @@ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, ret h = lambda_t - lambda_s lambda_s1 = lambda_s + r1 * h s1 = ns.inverse_lambda(lambda_s1) - log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff( - s1), ns.marginal_log_mean_coeff(t) + log_alpha_s, log_alpha_s1, log_alpha_t = ( + ns.marginal_log_mean_coeff(s), + ns.marginal_log_mean_coeff(s1), + ns.marginal_log_mean_coeff(t), + ) sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) @@ -550,23 +601,19 @@ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, ret if model_s is None: model_s = self.model_fn(x, s) - x_s1 = ( - expand_dims(sigma_s1 / sigma_s, dims) * x - - expand_dims(alpha_s1 * phi_11, dims) * model_s - ) + x_s1 = expand_dims(sigma_s1 / sigma_s, dims) * x - expand_dims(alpha_s1 * phi_11, dims) * model_s model_s1 = self.model_fn(x_s1, s1) - if solver_type == 'dpm_solver': + if solver_type == "dpm_solver": x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s) + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s) ) - elif solver_type == 'taylor': + elif solver_type == "taylor": x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * ( - model_s1 - model_s) + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1.0 / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims) * (model_s1 - model_s) ) else: phi_11 = torch.expm1(r1 * h) @@ -575,29 +622,39 @@ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, ret if model_s is None: model_s = self.model_fn(x, s) x_s1 = ( - expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x - - expand_dims(sigma_s1 * phi_11, dims) * model_s + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s ) model_s1 = self.model_fn(x_s1, s1) - if solver_type == 'dpm_solver': + if solver_type == "dpm_solver": x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s) + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s) ) - elif solver_type == 'taylor': + elif solver_type == "taylor": x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) 
/ h - 1.), dims) * (model_s1 - model_s) + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1.0 / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims) * (model_s1 - model_s) ) if return_intermediate: - return x_t, {'model_s': model_s, 'model_s1': model_s1} + return x_t, {"model_s": model_s, "model_s1": model_s1} else: return x_t - def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None, - return_intermediate=False, solver_type='dpm_solver'): + def singlestep_dpm_solver_third_update( + self, + x, + s, + t, + r1=1.0 / 3.0, + r2=2.0 / 3.0, + model_s=None, + model_s1=None, + return_intermediate=False, + solver_type="dpm_solver", + ): """ Singlestep solver DPM-Solver-3 from time `s` to time `t`. Args: @@ -616,12 +673,12 @@ def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., mo Returns: x_t: A pytorch tensor. The approximated solution at time `t`. """ - if solver_type not in ['dpm_solver', 'taylor']: + if solver_type not in ["dpm_solver", "taylor"]: raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) if r1 is None: - r1 = 1. / 3. + r1 = 1.0 / 3.0 if r2 is None: - r2 = 2. / 3. + r2 = 2.0 / 3.0 ns = self.noise_schedule dims = x.dim() lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) @@ -630,93 +687,98 @@ def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., mo lambda_s2 = lambda_s + r2 * h s1 = ns.inverse_lambda(lambda_s1) s2 = ns.inverse_lambda(lambda_s2) - log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff( - s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) - sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std( - s2), ns.marginal_std(t) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ( + ns.marginal_log_mean_coeff(s), + ns.marginal_log_mean_coeff(s1), + ns.marginal_log_mean_coeff(s2), + ns.marginal_log_mean_coeff(t), + ) + sigma_s, sigma_s1, sigma_s2, sigma_t = ( + ns.marginal_std(s), + ns.marginal_std(s1), + ns.marginal_std(s2), + ns.marginal_std(t), + ) alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) if self.predict_x0: phi_11 = torch.expm1(-r1 * h) phi_12 = torch.expm1(-r2 * h) phi_1 = torch.expm1(-h) - phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1. - phi_2 = phi_1 / h + 1. + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.0 + phi_2 = phi_1 / h + 1.0 phi_3 = phi_2 / h - 0.5 if model_s is None: model_s = self.model_fn(x, s) if model_s1 is None: - x_s1 = ( - expand_dims(sigma_s1 / sigma_s, dims) * x - - expand_dims(alpha_s1 * phi_11, dims) * model_s - ) + x_s1 = expand_dims(sigma_s1 / sigma_s, dims) * x - expand_dims(alpha_s1 * phi_11, dims) * model_s model_s1 = self.model_fn(x_s1, s1) x_s2 = ( - expand_dims(sigma_s2 / sigma_s, dims) * x - - expand_dims(alpha_s2 * phi_12, dims) * model_s - + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s) + expand_dims(sigma_s2 / sigma_s, dims) * x + - expand_dims(alpha_s2 * phi_12, dims) * model_s + + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s) ) model_s2 = self.model_fn(x_s2, s2) - if solver_type == 'dpm_solver': + if solver_type == "dpm_solver": x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + (1. 
/ r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s) + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1.0 / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s) ) - elif solver_type == 'taylor': - D1_0 = (1. / r1) * (model_s1 - model_s) - D1_1 = (1. / r2) * (model_s2 - model_s) + elif solver_type == "taylor": + D1_0 = (1.0 / r1) * (model_s1 - model_s) + D1_1 = (1.0 / r2) * (model_s2 - model_s) D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) - D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1) x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + expand_dims(alpha_t * phi_2, dims) * D1 - - expand_dims(alpha_t * phi_3, dims) * D2 + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + expand_dims(alpha_t * phi_2, dims) * D1 + - expand_dims(alpha_t * phi_3, dims) * D2 ) else: phi_11 = torch.expm1(r1 * h) phi_12 = torch.expm1(r2 * h) phi_1 = torch.expm1(h) - phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1. - phi_2 = phi_1 / h - 1. + phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.0 + phi_2 = phi_1 / h - 1.0 phi_3 = phi_2 / h - 0.5 if model_s is None: model_s = self.model_fn(x, s) if model_s1 is None: x_s1 = ( - expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x - - expand_dims(sigma_s1 * phi_11, dims) * model_s + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s ) model_s1 = self.model_fn(x_s1, s1) x_s2 = ( - expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x - - expand_dims(sigma_s2 * phi_12, dims) * model_s - - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s) + expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x + - expand_dims(sigma_s2 * phi_12, dims) * model_s + - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s) ) model_s2 = self.model_fn(x_s2, s2) - if solver_type == 'dpm_solver': + if solver_type == "dpm_solver": x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s) + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1.0 / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s) ) - elif solver_type == 'taylor': - D1_0 = (1. / r1) * (model_s1 - model_s) - D1_1 = (1. / r2) * (model_s2 - model_s) + elif solver_type == "taylor": + D1_0 = (1.0 / r1) * (model_s1 - model_s) + D1_1 = (1.0 / r2) * (model_s2 - model_s) D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) - D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1) x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - expand_dims(sigma_t * phi_2, dims) * D1 - - expand_dims(sigma_t * phi_3, dims) * D2 + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - expand_dims(sigma_t * phi_2, dims) * D1 + - expand_dims(sigma_t * phi_3, dims) * D2 ) if return_intermediate: - return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2} + return x_t, {"model_s": model_s, "model_s1": model_s1, "model_s2": model_s2} else: return x_t @@ -733,14 +795,17 @@ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, Returns: x_t: A pytorch tensor. 
The approximated solution at time `t`. """ - if solver_type not in ['dpm_solver', 'taylor']: + if solver_type not in ["dpm_solver", "taylor"]: raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) ns = self.noise_schedule dims = x.dim() model_prev_1, model_prev_0 = model_prev_list t_prev_1, t_prev_0 = t_prev_list - lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda( - t_prev_0), ns.marginal_lambda(t) + lambda_prev_1, lambda_prev_0, lambda_t = ( + ns.marginal_lambda(t_prev_1), + ns.marginal_lambda(t_prev_0), + ns.marginal_lambda(t), + ) log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) alpha_t = torch.exp(log_alpha_t) @@ -748,36 +813,36 @@ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, h_0 = lambda_prev_0 - lambda_prev_1 h = lambda_t - lambda_prev_0 r0 = h_0 / h - D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) + D1_0 = expand_dims(1.0 / r0, dims) * (model_prev_0 - model_prev_1) if self.predict_x0: - if solver_type == 'dpm_solver': + if solver_type == "dpm_solver": x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0 + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0 + - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * D1_0 ) - elif solver_type == 'taylor': + elif solver_type == "taylor": x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0 + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims) * D1_0 ) else: - if solver_type == 'dpm_solver': + if solver_type == "dpm_solver": x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0 + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0 + - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * D1_0 ) - elif solver_type == 'taylor': + elif solver_type == "taylor": x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0 + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims) * D1_0 ) return x_t - def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'): + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"): """ Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. 
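The intermediate times `s1`/`s2` in the singlestep updates above are obtained by inverting the half-logSNR map via `ns.inverse_lambda`. For the `'linear'` schedule that inversion is exact, which a scalar round-trip check can illustrate (the betas below are the `NoiseScheduleVP` defaults; everything else is assumed for illustration):

```python
import math

beta_0, beta_1 = 0.1, 20.0  # default continuous-time betas in NoiseScheduleVP

def lam(t):
    # lambda_t = log(alpha_t) - log(sigma_t) for the 'linear' schedule
    log_alpha = -0.25 * t**2 * (beta_1 - beta_0) - 0.5 * t * beta_0
    return log_alpha - 0.5 * math.log(1.0 - math.exp(2.0 * log_alpha))

def inv_lam(l):
    # mirrors NoiseScheduleVP.inverse_lambda for the 'linear' schedule
    tmp = 2.0 * (beta_1 - beta_0) * math.log1p(math.exp(-2.0 * l))  # logaddexp(-2*l, 0)
    delta = beta_0**2 + tmp
    return tmp / (math.sqrt(delta) + beta_0) / (beta_1 - beta_0)

t = 0.37
print(abs(inv_lam(lam(t)) - t) < 1e-9)  # True: the round trip recovers t
```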
Args: @@ -794,8 +859,12 @@ def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, dims = x.dim() model_prev_2, model_prev_1, model_prev_0 = model_prev_list t_prev_2, t_prev_1, t_prev_0 = t_prev_list - lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda( - t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ( + ns.marginal_lambda(t_prev_2), + ns.marginal_lambda(t_prev_1), + ns.marginal_lambda(t_prev_0), + ns.marginal_lambda(t), + ) log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) alpha_t = torch.exp(log_alpha_t) @@ -804,28 +873,29 @@ def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, h_0 = lambda_prev_0 - lambda_prev_1 h = lambda_t - lambda_prev_0 r0, r1 = h_0 / h, h_1 / h - D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) - D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2) + D1_0 = expand_dims(1.0 / r0, dims) * (model_prev_0 - model_prev_1) + D1_1 = expand_dims(1.0 / r1, dims) * (model_prev_1 - model_prev_2) D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1) - D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1) + D2 = expand_dims(1.0 / (r0 + r1), dims) * (D1_0 - D1_1) if self.predict_x0: x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1 - - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2 + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims) * D1 + - expand_dims(alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5), dims) * D2 ) else: x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1 - - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2 + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims) * D1 + - expand_dims(sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5), dims) * D2 ) return x_t - def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, - r2=None): + def singlestep_dpm_solver_update( + self, x, s, t, order, return_intermediate=False, solver_type="dpm_solver", r1=None, r2=None + ): """ Singlestep DPM-Solver with the order `order` from time `s` to time `t`. 
Args: @@ -844,15 +914,17 @@ def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False if order == 1: return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) elif order == 2: - return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, - solver_type=solver_type, r1=r1) + return self.singlestep_dpm_solver_second_update( + x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1 + ) elif order == 3: - return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, - solver_type=solver_type, r1=r1, r2=r2) + return self.singlestep_dpm_solver_third_update( + x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2 + ) else: raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) - def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'): + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type="dpm_solver"): """ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. Args: @@ -875,8 +947,9 @@ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, else: raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) - def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, - solver_type='dpm_solver'): + def dpm_solver_adaptive( + self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type="dpm_solver" + ): """ The adaptive step size solver based on singlestep DPM-Solver. Args: @@ -906,17 +979,17 @@ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol if order == 2: r1 = 0.5 lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) - higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, - solver_type=solver_type, - **kwargs) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update( + x, s, t, r1=r1, solver_type=solver_type, **kwargs + ) elif order == 3: - r1, r2 = 1. / 3., 2. / 3. 
- lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, - return_intermediate=True, - solver_type=solver_type) - higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, - solver_type=solver_type, - **kwargs) + r1, r2 = 1.0 / 3.0, 2.0 / 3.0 + lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update( + x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type + ) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update( + x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs + ) else: raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) while torch.abs((s - t_0)).mean() > t_err: @@ -926,20 +999,31 @@ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) E = norm_fn((x_higher - x_lower) / delta).max() - if torch.all(E <= 1.): + if torch.all(E <= 1.0): x = x_higher s = t x_prev = x_lower lambda_s = ns.marginal_lambda(s) - h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s) + h = torch.min(theta * h * torch.float_power(E, -1.0 / order).float(), lambda_0 - lambda_s) nfe += order - print('adaptive solver nfe', nfe) + print("adaptive solver nfe", nfe) return x - def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform', - method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', - atol=0.0078, rtol=0.05, - ): + def sample( + self, + x, + steps=20, + t_start=None, + t_end=None, + order=3, + skip_type="time_uniform", + method="singlestep", + lower_order_final=True, + denoise_to_zero=False, + solver_type="dpm_solver", + atol=0.0078, + rtol=0.05, + ): """ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. ===================================================== @@ -1034,14 +1118,15 @@ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time Returns: x_end: A pytorch tensor. The approximated solution at time `t_end`. """ - t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end t_T = self.noise_schedule.T if t_start is None else t_start device = x.device - if method == 'adaptive': + if method == "adaptive": with torch.no_grad(): - x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, - solver_type=solver_type) - elif method == 'multistep': + x = self.dpm_solver_adaptive( + x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type + ) + elif method == "multistep": assert steps >= order timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) assert timesteps.shape[0] - 1 == steps @@ -1052,8 +1137,9 @@ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time # Init the first `order` values by lower order multistep DPM-Solver. 
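The multistep loop that follows is what a caller reaches through the `DPMSolverSampler` changes later in this diff (`NoiseScheduleVP` -> `model_wrapper` -> `DPM_Solver.sample`). Below is a minimal sketch of that wiring with stand-in tensors; `toy_unet`, the shapes, the schedule values, and the import path are illustrative assumptions, not part of the patch:

```python
import torch

# assumes examples/images/diffusion is on PYTHONPATH
from ldm.models.diffusion.dpm_solver.dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper

def toy_unet(x, t, c):
    # stand-in for model.apply_model(x, t, c); pretends to predict epsilon
    return torch.zeros_like(x)

alphas_cumprod = torch.linspace(0.9991, 0.0047, 1000)  # stand-in DDPM schedule
ns = NoiseScheduleVP("discrete", alphas_cumprod=alphas_cumprod)

model_fn = model_wrapper(
    toy_unet,
    ns,
    model_type="noise",                              # the network predicts epsilon
    guidance_type="classifier-free",
    condition=torch.zeros(1, 77, 768),               # placeholder text embedding
    unconditional_condition=torch.zeros(1, 77, 768),
    guidance_scale=7.5,
)

solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
x_T = torch.randn(1, 4, 64, 64)                      # latent-space noise
x_0 = solver.sample(x_T, steps=20, skip_type="time_uniform",
                    method="multistep", order=2, lower_order_final=True)
print(x_0.shape)  # torch.Size([1, 4, 64, 64])
```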
for init_order in tqdm(range(1, order), desc="DPM init order"): vec_t = timesteps[init_order].expand(x.shape[0]) - x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order, - solver_type=solver_type) + x = self.multistep_dpm_solver_update( + x, model_prev_list, t_prev_list, vec_t, init_order, solver_type=solver_type + ) model_prev_list.append(self.model_fn(x, vec_t)) t_prev_list.append(vec_t) # Compute the remaining values by `order`-th order multistep DPM-Solver. @@ -1063,8 +1149,9 @@ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time step_order = min(order, steps + 1 - step) else: step_order = order - x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order, - solver_type=solver_type) + x = self.multistep_dpm_solver_update( + x, model_prev_list, t_prev_list, vec_t, step_order, solver_type=solver_type + ) for i in range(order - 1): t_prev_list[i] = t_prev_list[i + 1] model_prev_list[i] = model_prev_list[i + 1] @@ -1072,20 +1159,22 @@ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time # We do not need to evaluate the final model value. if step < steps: model_prev_list[-1] = self.model_fn(x, vec_t) - elif method in ['singlestep', 'singlestep_fixed']: - if method == 'singlestep': - timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, - skip_type=skip_type, - t_T=t_T, t_0=t_0, - device=device) - elif method == 'singlestep_fixed': + elif method in ["singlestep", "singlestep_fixed"]: + if method == "singlestep": + timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver( + steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device + ) + elif method == "singlestep_fixed": K = steps // order - orders = [order, ] * K + orders = [ + order, + ] * K timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) for i, order in enumerate(orders): t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1] - timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), - N=order, device=device) + timesteps_inner = self.get_time_steps( + skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), N=order, device=device + ) lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0]) h = lambda_inner[-1] - lambda_inner[0] @@ -1101,6 +1190,7 @@ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time # other utility functions ############################################################# + def interpolate_fn(x, xp, yp): """ A piecewise linear function y = f(x), using xp and yp as keypoints. 
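`interpolate_fn` backs the `'discrete'` schedule above: `marginal_log_mean_coeff` looks up `log_alpha` by piecewise-linear interpolation over a keypoint table. The shape convention is easiest to see on a tiny example; the import path is assumed, as in the previous sketch:

```python
import torch

from ldm.models.diffusion.dpm_solver.dpm_solver import interpolate_fn  # path assumed

xp = torch.tensor([[0.0, 1.0, 2.0]])    # keypoint x-coordinates, shape [C=1, K=3]
yp = torch.tensor([[0.0, 10.0, 20.0]])  # keypoint y-values,       shape [C=1, K=3]
x = torch.tensor([[0.5], [1.5]])        # query points,            shape [N=2, C=1]

print(interpolate_fn(x, xp, yp))        # -> [[5.0], [15.0]], linear between keypoints
```

Compared with `np.interp`, the sort/gather formulation keeps the lookup batched and differentiable with respect to the query points, which is why it is used inside the schedule.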
@@ -1122,7 +1212,9 @@ def interpolate_fn(x, xp, yp): torch.eq(x_idx, 0), torch.tensor(1, device=x.device), torch.where( - torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + torch.eq(x_idx, K), + torch.tensor(K - 2, device=x.device), + cand_start_idx, ), ) end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) @@ -1132,7 +1224,9 @@ def interpolate_fn(x, xp, yp): torch.eq(x_idx, 0), torch.tensor(0, device=x.device), torch.where( - torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + torch.eq(x_idx, K), + torch.tensor(K - 2, device=x.device), + cand_start_idx, ), ) y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) @@ -1151,4 +1245,4 @@ def expand_dims(v, dims): Returns: a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. """ - return v[(...,) + (None,) * (dims - 1)] \ No newline at end of file + return v[(...,) + (None,) * (dims - 1)] diff --git a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py index 7d137b8cf367..55dac8555e5f 100644 --- a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py +++ b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py @@ -1,13 +1,9 @@ """SAMPLING ONLY.""" import torch -from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver +from .dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper - -MODEL_TYPES = { - "eps": "noise", - "v": "v" -} +MODEL_TYPES = {"eps": "noise", "v": "v"} class DPMSolverSampler(object): @@ -15,7 +11,7 @@ def __init__(self, model, **kwargs): super().__init__() self.model = model to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) - self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) + self.register_buffer("alphas_cumprod", to_torch(model.alphas_cumprod)) def register_buffer(self, name, attr): if type(attr) == torch.Tensor: @@ -24,30 +20,31 @@ def register_buffer(self, name, attr): setattr(self, name, attr) @torch.no_grad() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - **kwargs - ): + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
+ **kwargs, + ): if conditioning is not None: if isinstance(conditioning, dict): cbs = conditioning[list(conditioning.keys())[0]].shape[0] @@ -61,7 +58,7 @@ def sample(self, C, H, W = shape size = (batch_size, C, H, W) - print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') + print(f"Data shape for DPM-Solver sampling is {size}, sampling steps {S}") device = self.model.betas.device if x_T is None: @@ -69,7 +66,7 @@ def sample(self, else: img = x_T - ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) + ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod) model_fn = model_wrapper( lambda x, t, c: self.model.apply_model(x, t, c), @@ -82,6 +79,8 @@ def sample(self, ) dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) - x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) + x = dpm_solver.sample( + img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True + ) - return x.to(device), None \ No newline at end of file + return x.to(device), None diff --git a/examples/images/diffusion/ldm/models/diffusion/plms.py b/examples/images/diffusion/ldm/models/diffusion/plms.py index 7002a365d271..b2b3f032e491 100644 --- a/examples/images/diffusion/ldm/models/diffusion/plms.py +++ b/examples/images/diffusion/ldm/models/diffusion/plms.py @@ -1,12 +1,10 @@ """SAMPLING ONLY.""" -import torch import numpy as np -from tqdm import tqdm -from functools import partial - -from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like +import torch from ldm.models.diffusion.sampling_util import norm_thresholding +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like +from tqdm import tqdm class PLMSSampler(object): @@ -22,65 +20,72 @@ def register_buffer(self, name, attr): attr = attr.to(torch.device("cuda")) setattr(self, name, attr) - def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True): if ddim_eta != 0: - raise ValueError('ddim_eta must be 0 for PLMS') - self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + raise ValueError("ddim_eta must be 0 for PLMS") + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) alphas_cumprod = self.model.alphas_cumprod - assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, "alphas have to be defined for each timestep" to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) - self.register_buffer('betas', to_torch(self.model.betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + self.register_buffer("betas", to_torch(self.model.betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)) # calculations for diffusion q(x_t | x_{t-1}) and others - 
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod.cpu()))) + self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))) + self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))) + self.register_buffer("sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1))) # ddim sampling parameters - ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), - ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta,verbose=verbose) - self.register_buffer('ddim_sigmas', ddim_sigmas) - self.register_buffer('ddim_alphas', ddim_alphas) - self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) - self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), ddim_timesteps=self.ddim_timesteps, eta=ddim_eta, verbose=verbose + ) + self.register_buffer("ddim_sigmas", ddim_sigmas) + self.register_buffer("ddim_alphas", ddim_alphas) + self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) + self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( - 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) - self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer("ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps) @torch.no_grad() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - dynamic_threshold=None, - **kwargs - ): + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
+ dynamic_threshold=None, + **kwargs, + ): if conditioning is not None: if isinstance(conditioning, dict): cbs = conditioning[list(conditioning.keys())[0]].shape[0] @@ -94,34 +99,51 @@ def sample(self, # sampling C, H, W = shape size = (batch_size, C, H, W) - print(f'Data shape for PLMS sampling is {size}') - - samples, intermediates = self.plms_sampling(conditioning, size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold, - ) + print(f"Data shape for PLMS sampling is {size}") + + samples, intermediates = self.plms_sampling( + conditioning, + size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ) return samples, intermediates @torch.no_grad() - def plms_sampling(self, cond, shape, - x_T=None, ddim_use_original_steps=False, - callback=None, timesteps=None, quantize_denoised=False, - mask=None, x0=None, img_callback=None, log_every_t=100, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, - dynamic_threshold=None): + def plms_sampling( + self, + cond, + shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + dynamic_threshold=None, + ): device = self.model.betas.device b = shape[0] if x_T is None: @@ -135,12 +157,12 @@ def plms_sampling(self, cond, shape, subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 timesteps = self.ddim_timesteps[:subset_end] - intermediates = {'x_inter': [img], 'pred_x0': [img]} - time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) + intermediates = {"x_inter": [img], "pred_x0": [img]} + time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps) total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] print(f"Running PLMS Sampling with {total_steps} timesteps") - iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) + iterator = tqdm(time_range, desc="PLMS Sampler", total=total_steps) old_eps = [] for i, step in enumerate(iterator): @@ -151,38 +173,64 @@ def plms_sampling(self, cond, shape, if mask is not None: assert x0 is not None img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? - img = img_orig * mask + (1. 
- mask) * img - - outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, temperature=temperature, - noise_dropout=noise_dropout, score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - old_eps=old_eps, t_next=ts_next, - dynamic_threshold=dynamic_threshold) + img = img_orig * mask + (1.0 - mask) * img + + outs = self.p_sample_plms( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, + t_next=ts_next, + dynamic_threshold=dynamic_threshold, + ) img, pred_x0, e_t = outs old_eps.append(e_t) if len(old_eps) >= 4: old_eps.pop(0) - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) + if callback: + callback(i) + if img_callback: + img_callback(pred_x0, i) if index % log_every_t == 0 or index == total_steps - 1: - intermediates['x_inter'].append(img) - intermediates['pred_x0'].append(pred_x0) + intermediates["x_inter"].append(img) + intermediates["pred_x0"].append(pred_x0) return img, intermediates @torch.no_grad() - def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, - dynamic_threshold=None): + def p_sample_plms( + self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + old_eps=None, + t_next=None, + dynamic_threshold=None, + ): b, *_, device = *x.shape, x.device def get_model_output(x, t): - if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + if unconditional_conditioning is None or unconditional_guidance_scale == 1.0: e_t = self.model.apply_model(x, t, c) else: x_in = torch.cat([x] * 2) @@ -199,7 +247,9 @@ def get_model_output(x, t): alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev - sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + ) sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas def get_x_prev_and_pred_x0(e_t, index): @@ -207,7 +257,7 @@ def get_x_prev_and_pred_x0(e_t, index): a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device) # current prediction for x_0 pred_x0 = (x - 
sqrt_one_minus_at * e_t) / a_t.sqrt() @@ -216,9 +266,9 @@ def get_x_prev_and_pred_x0(e_t, index): if dynamic_threshold is not None: pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) # direction pointing to x_t - dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: + if noise_dropout > 0.0: noise = torch.nn.functional.dropout(noise, p=noise_dropout) x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise return x_prev, pred_x0 diff --git a/examples/images/diffusion/ldm/models/diffusion/sampling_util.py b/examples/images/diffusion/ldm/models/diffusion/sampling_util.py index 7eff02be6d7c..a4681368112b 100644 --- a/examples/images/diffusion/ldm/models/diffusion/sampling_util.py +++ b/examples/images/diffusion/ldm/models/diffusion/sampling_util.py @@ -1,13 +1,9 @@ -import torch -import numpy as np - - def append_dims(x, target_dims): """Appends dimensions to the end of a tensor until it has target_dims dimensions. From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" dims_to_append = target_dims - x.ndim if dims_to_append < 0: - raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') + raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") return x[(...,) + (None,) * dims_to_append] @@ -19,4 +15,4 @@ def norm_thresholding(x0, value): def spatial_norm_thresholding(x0, value): # b c h w s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) - return x0 * (value / s) \ No newline at end of file + return x0 * (value / s) diff --git a/examples/images/diffusion/ldm/modules/attention.py b/examples/images/diffusion/ldm/modules/attention.py index d504d939f6a0..f3c385e5138f 100644 --- a/examples/images/diffusion/ldm/modules/attention.py +++ b/examples/images/diffusion/ldm/modules/attention.py @@ -1,17 +1,17 @@ -from inspect import isfunction import math +from inspect import isfunction +from typing import Any, Optional + import torch import torch.nn.functional as F -from torch import nn, einsum from einops import rearrange, repeat -from typing import Optional, Any - from ldm.modules.diffusionmodules.util import checkpoint - +from torch import einsum, nn try: import xformers import xformers.ops + XFORMERS_IS_AVAILBLE = True except: XFORMERS_IS_AVAILBLE = False @@ -22,7 +22,7 @@ def exists(val): def uniq(arr): - return{el: True for el in arr}.keys() + return {el: True for el in arr}.keys() def default(val, d): @@ -54,20 +54,13 @@ def forward(self, x): class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) - project_in = nn.Sequential( - nn.Linear(dim, inner_dim), - nn.GELU() - ) if not glu else GEGLU(dim, inner_dim) - - self.net = nn.Sequential( - project_in, - nn.Dropout(dropout), - nn.Linear(inner_dim, dim_out) - ) + project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) def forward(self, x): return self.net(x) @@ -92,26 +85,10 @@ def __init__(self, in_channels): self.in_channels = in_channels self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - 
padding=0) - self.k = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) def forward(self, x): h_ = x @@ -121,41 +98,38 @@ def forward(self, x): v = self.v(h_) # compute attention - b,c,h,w = q.shape - q = rearrange(q, 'b c h w -> b (h w) c') - k = rearrange(k, 'b c h w -> b c (h w)') - w_ = torch.einsum('bij,bjk->bik', q, k) + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b (h w) c") + k = rearrange(k, "b c h w -> b c (h w)") + w_ = torch.einsum("bij,bjk->bik", q, k) - w_ = w_ * (int(c)**(-0.5)) + w_ = w_ * (int(c) ** (-0.5)) w_ = torch.nn.functional.softmax(w_, dim=2) # attend to values - v = rearrange(v, 'b c h w -> b c (h w)') - w_ = rearrange(w_, 'b i j -> b j i') - h_ = torch.einsum('bij,bjk->bik', v, w_) - h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + v = rearrange(v, "b c h w -> b c (h w)") + w_ = rearrange(w_, "b i j -> b j i") + h_ = torch.einsum("bij,bjk->bik", v, w_) + h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) h_ = self.proj_out(h_) - return x+h_ + return x + h_ class CrossAttention(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) - self.scale = dim_head ** -0.5 + self.scale = dim_head**-0.5 self.heads = heads self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_k = nn.Linear(context_dim, inner_dim, bias=False) self.to_v = nn.Linear(context_dim, inner_dim, bias=False) - self.to_out = nn.Sequential( - nn.Linear(inner_dim, query_dim), - nn.Dropout(dropout) - ) + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) def forward(self, x, context=None, mask=None): h = self.heads @@ -165,22 +139,22 @@ def forward(self, x, context=None, mask=None): k = self.to_k(context) v = self.to_v(context) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + sim = einsum("b i d, b j d -> b i j", q, k) * self.scale del q, k if exists(mask): - mask = rearrange(mask, 'b ... -> b (...)') + mask = rearrange(mask, "b ... 
-> b (...)") max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) + mask = repeat(mask, "b j -> (b h) () j", h=h) sim.masked_fill_(~mask, max_neg_value) # attention, what we cannot get enough of sim = sim.softmax(dim=-1) - out = einsum('b i j, b j d -> b i d', sim, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + out = einsum("b i j, b j d -> b i d", sim, v) + out = rearrange(out, "(b h) n d -> b n (h d)", h=h) return self.to_out(out) @@ -188,8 +162,10 @@ class MemoryEfficientCrossAttention(nn.Module): # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): super().__init__() - print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " - f"{heads} heads.") + print( + f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " + f"{heads} heads." + ) inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) @@ -236,20 +212,36 @@ def forward(self, x, context=None, mask=None): class BasicTransformerBlock(nn.Module): ATTENTION_MODES = { "softmax": CrossAttention, # vanilla attention - "softmax-xformers": MemoryEfficientCrossAttention + "softmax-xformers": MemoryEfficientCrossAttention, } - def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, - disable_self_attn=False): + + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + disable_self_attn=False, + ): super().__init__() attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" assert attn_mode in self.ATTENTION_MODES attn_cls = self.ATTENTION_MODES[attn_mode] self.disable_self_attn = disable_self_attn - self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, - context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn + self.attn1 = attn_cls( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None, + ) # is a self-attention if not self.disable_self_attn self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) - self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, - heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.attn2 = attn_cls( + query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is self-attn if context is none self.norm1 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim) self.norm3 = nn.LayerNorm(dim) @@ -274,10 +266,19 @@ class SpatialTransformer(nn.Module): Finally, reshape to image NEW: use_linear for more efficiency instead of the 1x1 convs """ - def __init__(self, in_channels, n_heads, d_head, - depth=1, dropout=0., context_dim=None, - disable_self_attn=False, use_linear=False, - use_checkpoint=True): + + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None, + disable_self_attn=False, + use_linear=False, + use_checkpoint=True, + ): super().__init__() if exists(context_dim) and not isinstance(context_dim, list): context_dim = [context_dim] @@ -285,25 +286,26 @@ def __init__(self, in_channels, n_heads, d_head, inner_dim = n_heads * d_head self.norm = Normalize(in_channels) if not 
use_linear: - self.proj_in = nn.Conv2d(in_channels, - inner_dim, - kernel_size=1, - stride=1, - padding=0) + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) else: self.proj_in = nn.Linear(in_channels, inner_dim) self.transformer_blocks = nn.ModuleList( - [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], - disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) - for d in range(depth)] + [ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim[d], + disable_self_attn=disable_self_attn, + checkpoint=use_checkpoint, + ) + for d in range(depth) + ] ) if not use_linear: - self.proj_out = zero_module(nn.Conv2d(inner_dim, - in_channels, - kernel_size=1, - stride=1, - padding=0)) + self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) else: self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) self.use_linear = use_linear @@ -317,15 +319,14 @@ def forward(self, x, context=None): x = self.norm(x) if not self.use_linear: x = self.proj_in(x) - x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + x = rearrange(x, "b c h w -> b (h w) c").contiguous() if self.use_linear: x = self.proj_in(x) for i, block in enumerate(self.transformer_blocks): x = block(x, context=context[i]) if self.use_linear: x = self.proj_out(x) - x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() if not self.use_linear: x = self.proj_out(x) return x + x_in - diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/model.py b/examples/images/diffusion/ldm/modules/diffusionmodules/model.py index fb088db58919..7ed8d98a44ad 100644 --- a/examples/images/diffusion/ldm/modules/diffusionmodules/model.py +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/model.py @@ -17,6 +17,7 @@ try: import xformers import xformers.ops + XFORMERS_IS_AVAILBLE = True except: XFORMERS_IS_AVAILBLE = False @@ -39,7 +40,7 @@ def get_timestep_embedding(timesteps, embedding_dim): emb = emb.to(device=timesteps.device) emb = timesteps.float()[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad + if embedding_dim % 2 == 1: # zero pad emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) return emb @@ -54,7 +55,6 @@ def Normalize(in_channels, num_groups=32): class Upsample(nn.Module): - def __init__(self, in_channels, with_conv): super().__init__() self.with_conv = with_conv @@ -69,7 +69,6 @@ def forward(self, x): class Downsample(nn.Module): - def __init__(self, in_channels, with_conv): super().__init__() self.with_conv = with_conv @@ -88,7 +87,6 @@ def forward(self, x): class ResnetBlock(nn.Module): - def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): super().__init__() self.in_channels = in_channels @@ -133,7 +131,6 @@ def forward(self, x, temb): class AttnBlock(nn.Module): - def __init__(self, in_channels): super().__init__() self.in_channels = in_channels @@ -154,16 +151,16 @@ def forward(self, x): # compute attention b, c, h, w = q.shape q = q.reshape(b, c, h * w) - q = q.permute(0, 2, 1) # b,hw,c - k = k.reshape(b, c, h * w) # b,c,hw - w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c)**(-0.5)) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c 
q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) w_ = torch.nn.functional.softmax(w_, dim=2) # attend to values v = v.reshape(b, c, h * w) - w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] h_ = h_.reshape(b, c, h, w) h_ = self.proj_out(h_) @@ -173,9 +170,9 @@ def forward(self, x): class MemoryEfficientAttnBlock(nn.Module): """ - Uses xformers efficient implementation, - see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 - Note: this is a single-head self-attention operation + Uses xformers efficient implementation, + see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + Note: this is a single-head self-attention operation """ # @@ -199,34 +196,41 @@ def forward(self, x): # compute attention B, C, H, W = q.shape - q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) + q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v)) q, k, v = map( - lambda t: t.unsqueeze(3).reshape(B, t.shape[1], 1, C).permute(0, 2, 1, 3).reshape(B * 1, t.shape[1], C). - contiguous(), + lambda t: t.unsqueeze(3) + .reshape(B, t.shape[1], 1, C) + .permute(0, 2, 1, 3) + .reshape(B * 1, t.shape[1], C) + .contiguous(), (q, k, v), ) out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) - out = (out.unsqueeze(0).reshape(B, 1, out.shape[1], C).permute(0, 2, 1, 3).reshape(B, out.shape[1], C)) - out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) + out = out.unsqueeze(0).reshape(B, 1, out.shape[1], C).permute(0, 2, 1, 3).reshape(B, out.shape[1], C) + out = rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C) out = self.proj_out(out) return x + out class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): - def forward(self, x, context=None, mask=None): b, c, h, w = x.shape - x = rearrange(x, 'b c h w -> b (h w) c') + x = rearrange(x, "b c h w -> b (h w) c") out = super().forward(x, context=context, mask=mask) - out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c) + out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c) return x + out def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): - assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", - "none"], f'attn_type {attn_type} unknown' + assert attn_type in [ + "vanilla", + "vanilla-xformers", + "memory-efficient-cross-attn", + "linear", + "none", + ], f"attn_type {attn_type} unknown" if XFORMERS_IS_AVAILBLE and attn_type == "vanilla": attn_type = "vanilla-xformers" if attn_type == "vanilla": @@ -245,21 +249,22 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): class Model(nn.Module): - - def __init__(self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - use_timestep=True, - use_linear_attn=False, - attn_type="vanilla"): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type="vanilla", + ): super().__init__() if 
use_linear_attn: attn_type = "linear" @@ -274,10 +279,12 @@ def __init__(self, if self.use_timestep: # timestep embedding self.temb = nn.Module() - self.temb.dense = nn.ModuleList([ - torch.nn.Linear(self.ch, self.temb_ch), - torch.nn.Linear(self.temb_ch, self.temb_ch), - ]) + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) # downsampling self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) @@ -292,10 +299,10 @@ def __init__(self, block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks): block.append( - ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) @@ -309,15 +316,13 @@ def __init__(self, # middle self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) # upsampling self.up = nn.ModuleList() @@ -330,10 +335,13 @@ def __init__(self, if i_block == self.num_res_blocks: skip_in = ch * in_ch_mult[i_level] block.append( - ResnetBlock(in_channels=block_in + skip_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + ResnetBlock( + in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) @@ -343,14 +351,14 @@ def __init__(self, if i_level != 0: up.upsample = Upsample(block_in, resamp_with_conv) curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order + self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = Normalize(block_in) self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) def forward(self, x, t=None, context=None): - #assert x.shape[2] == x.shape[3] == self.resolution + # assert x.shape[2] == x.shape[3] == self.resolution if context is not None: # assume aligned context, cat along channel axis x = torch.cat((x, context), dim=1) @@ -401,23 +409,24 @@ def get_last_layer(self): class Encoder(nn.Module): - - def __init__(self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - double_z=True, - use_linear_attn=False, - attn_type="vanilla", - **ignore_kwargs): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + **ignore_kwargs, + ): super().__init__() if use_linear_attn: attn_type = "linear" @@ -442,10 +451,10 @@ def __init__(self, block_out = ch * 
ch_mult[i_level] for i_block in range(self.num_res_blocks): block.append( - ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) @@ -459,23 +468,19 @@ def __init__(self, # middle self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - 2 * z_channels if double_z else z_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv_out = torch.nn.Conv2d( + block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1 + ) def forward(self, x): # timestep embedding @@ -506,24 +511,25 @@ def forward(self, x): class Decoder(nn.Module): - - def __init__(self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - give_pre_end=False, - tanh_out=False, - use_linear_attn=False, - attn_type="vanilla", - **ignorekwargs): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + attn_type="vanilla", + **ignorekwargs, + ): super().__init__() if use_linear_attn: attn_type = "linear" @@ -537,9 +543,9 @@ def __init__(self, self.tanh_out = tanh_out # compute in_ch_mult, block_in and curr_res at lowest res - in_ch_mult = (1,) + tuple(ch_mult) + (1,) + tuple(ch_mult) block_in = ch * ch_mult[self.num_resolutions - 1] - curr_res = resolution // 2**(self.num_resolutions - 1) + curr_res = resolution // 2 ** (self.num_resolutions - 1) self.z_shape = (1, z_channels, curr_res, curr_res) rank_zero_info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) @@ -548,15 +554,13 @@ def __init__(self, # middle self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) # upsampling self.up = nn.ModuleList() @@ -566,10 +570,10 @@ def __init__(self, block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks + 1): block.append( - ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - 
dropout=dropout)) + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) @@ -579,14 +583,14 @@ def __init__(self, if i_level != 0: up.upsample = Upsample(block_in, resamp_with_conv) curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order + self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = Normalize(block_in) self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) def forward(self, z): - #assert z.shape[1:] == self.z_shape[1:] + # assert z.shape[1:] == self.z_shape[1:] self.last_z_shape = z.shape # timestep embedding @@ -622,17 +626,18 @@ def forward(self, z): class SimpleDecoder(nn.Module): - def __init__(self, in_channels, out_channels, *args, **kwargs): super().__init__() - self.model = nn.ModuleList([ - nn.Conv2d(in_channels, in_channels, 1), - ResnetBlock(in_channels=in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0), - ResnetBlock(in_channels=2 * in_channels, out_channels=4 * in_channels, temb_channels=0, dropout=0.0), - ResnetBlock(in_channels=4 * in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0), - nn.Conv2d(2 * in_channels, in_channels, 1), - Upsample(in_channels, with_conv=True) - ]) + self.model = nn.ModuleList( + [ + nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, out_channels=4 * in_channels, temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0), + nn.Conv2d(2 * in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True), + ] + ) # end self.norm_out = Normalize(in_channels) self.conv_out = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) @@ -651,7 +656,6 @@ def forward(self, x): class UpsampleDecoder(nn.Module): - def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, ch_mult=(2, 2), dropout=0.0): super().__init__() # upsampling @@ -659,7 +663,7 @@ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, ch self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks block_in = in_channels - curr_res = resolution // 2**(self.num_resolutions - 1) + curr_res = resolution // 2 ** (self.num_resolutions - 1) self.res_blocks = nn.ModuleList() self.upsample_blocks = nn.ModuleList() for i_level in range(self.num_resolutions): @@ -667,10 +671,10 @@ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, ch block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks + 1): res_block.append( - ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) block_in = block_out self.res_blocks.append(nn.ModuleList(res_block)) if i_level != self.num_resolutions - 1: @@ -696,21 +700,24 @@ def forward(self, x): class LatentRescaler(nn.Module): - def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): super().__init__() # residual block, interpolate, residual block self.factor = factor self.conv_in = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1) - 
self.res_block1 = nn.ModuleList([ - ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0) - for _ in range(depth) - ]) + self.res_block1 = nn.ModuleList( + [ + ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0) + for _ in range(depth) + ] + ) self.attn = AttnBlock(mid_channels) - self.res_block2 = nn.ModuleList([ - ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0) - for _ in range(depth) - ]) + self.res_block2 = nn.ModuleList( + [ + ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0) + for _ in range(depth) + ] + ) self.conv_out = nn.Conv2d( mid_channels, @@ -722,9 +729,9 @@ def forward(self, x): x = self.conv_in(x) for block in self.res_block1: x = block(x, None) - x = torch.nn.functional.interpolate(x, - size=(int(round(x.shape[2] * self.factor)), - int(round(x.shape[3] * self.factor)))) + x = torch.nn.functional.interpolate( + x, size=(int(round(x.shape[2] * self.factor)), int(round(x.shape[3] * self.factor))) + ) x = self.attn(x) for block in self.res_block2: x = block(x, None) @@ -733,37 +740,42 @@ def forward(self, x): class MergedRescaleEncoder(nn.Module): - - def __init__(self, - in_channels, - ch, - resolution, - out_ch, - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - ch_mult=(1, 2, 4, 8), - rescale_factor=1.0, - rescale_module_depth=1): + def __init__( + self, + in_channels, + ch, + resolution, + out_ch, + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + ch_mult=(1, 2, 4, 8), + rescale_factor=1.0, + rescale_module_depth=1, + ): super().__init__() intermediate_chn = ch * ch_mult[-1] - self.encoder = Encoder(in_channels=in_channels, - num_res_blocks=num_res_blocks, - ch=ch, - ch_mult=ch_mult, - z_channels=intermediate_chn, - double_z=False, - resolution=resolution, - attn_resolutions=attn_resolutions, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - out_ch=None) - self.rescaler = LatentRescaler(factor=rescale_factor, - in_channels=intermediate_chn, - mid_channels=intermediate_chn, - out_channels=out_ch, - depth=rescale_module_depth) + self.encoder = Encoder( + in_channels=in_channels, + num_res_blocks=num_res_blocks, + ch=ch, + ch_mult=ch_mult, + z_channels=intermediate_chn, + double_z=False, + resolution=resolution, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + out_ch=None, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=intermediate_chn, + mid_channels=intermediate_chn, + out_channels=out_ch, + depth=rescale_module_depth, + ) def forward(self, x): x = self.encoder(x) @@ -772,36 +784,41 @@ def forward(self, x): class MergedRescaleDecoder(nn.Module): - - def __init__(self, - z_channels, - out_ch, - resolution, - num_res_blocks, - attn_resolutions, - ch, - ch_mult=(1, 2, 4, 8), - dropout=0.0, - resamp_with_conv=True, - rescale_factor=1.0, - rescale_module_depth=1): + def __init__( + self, + z_channels, + out_ch, + resolution, + num_res_blocks, + attn_resolutions, + ch, + ch_mult=(1, 2, 4, 8), + dropout=0.0, + resamp_with_conv=True, + rescale_factor=1.0, + rescale_module_depth=1, + ): super().__init__() tmp_chn = z_channels * ch_mult[-1] - self.decoder = Decoder(out_ch=out_ch, - z_channels=tmp_chn, - attn_resolutions=attn_resolutions, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - in_channels=None, - num_res_blocks=num_res_blocks, - ch_mult=ch_mult, - 
resolution=resolution, - ch=ch) - self.rescaler = LatentRescaler(factor=rescale_factor, - in_channels=z_channels, - mid_channels=tmp_chn, - out_channels=tmp_chn, - depth=rescale_module_depth) + self.decoder = Decoder( + out_ch=out_ch, + z_channels=tmp_chn, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + in_channels=None, + num_res_blocks=num_res_blocks, + ch_mult=ch_mult, + resolution=resolution, + ch=ch, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=z_channels, + mid_channels=tmp_chn, + out_channels=tmp_chn, + depth=rescale_module_depth, + ) def forward(self, x): x = self.rescaler(x) @@ -810,27 +827,27 @@ def forward(self, x): class Upsampler(nn.Module): - def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): super().__init__() assert out_size >= in_size num_blocks = int(np.log2(out_size // in_size)) + 1 - factor_up = 1. + (out_size % in_size) + factor_up = 1.0 + (out_size % in_size) rank_zero_info( f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}" ) - self.rescaler = LatentRescaler(factor=factor_up, - in_channels=in_channels, - mid_channels=2 * in_channels, - out_channels=in_channels) - self.decoder = Decoder(out_ch=out_channels, - resolution=out_size, - z_channels=in_channels, - num_res_blocks=2, - attn_resolutions=[], - in_channels=None, - ch=in_channels, - ch_mult=[ch_mult for _ in range(num_blocks)]) + self.rescaler = LatentRescaler( + factor=factor_up, in_channels=in_channels, mid_channels=2 * in_channels, out_channels=in_channels + ) + self.decoder = Decoder( + out_ch=out_channels, + resolution=out_size, + z_channels=in_channels, + num_res_blocks=2, + attn_resolutions=[], + in_channels=None, + ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)], + ) def forward(self, x): x = self.rescaler(x) @@ -839,14 +856,14 @@ def forward(self, x): class Resize(nn.Module): - def __init__(self, in_channels=None, learned=False, mode="bilinear"): super().__init__() self.with_conv = learned self.mode = mode if self.with_conv: rank_zero_info( - f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") + f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode" + ) raise NotImplementedError() assert in_channels is not None # no asymmetric padding in torch conv, must do it ourselves diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py b/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py index cd639d936046..614fe510f20e 100644 --- a/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py @@ -1,21 +1,20 @@ -from abc import abstractmethod import math +from abc import abstractmethod import numpy as np import torch as th import torch.nn as nn import torch.nn.functional as F - +from ldm.modules.attention import SpatialTransformer from ldm.modules.diffusionmodules.util import ( + avg_pool_nd, checkpoint, conv_nd, linear, - avg_pool_nd, - zero_module, normalization, timestep_embedding, + zero_module, ) -from ldm.modules.attention import SpatialTransformer from ldm.util import exists @@ -23,6 +22,7 @@ def convert_module_to_f16(x): pass + def convert_module_to_f32(x): pass @@ -41,7 +41,7 @@ def __init__( output_dim: int = None, ): super().__init__() - self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 
1) / embed_dim ** 0.5) + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5) self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) self.num_heads = embed_dim // num_heads_channels @@ -108,25 +108,25 @@ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): def forward(self, x): assert x.shape[1] == self.channels if self.dims == 3: - x = F.interpolate( - x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" - ) + x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") else: x = F.interpolate(x, scale_factor=2, mode="nearest") if self.use_conv: x = self.conv(x) return x + class TransposedUpsample(nn.Module): - 'Learned 2x upsampling without padding' + "Learned 2x upsampling without padding" + def __init__(self, channels, out_channels=None, ks=5): super().__init__() self.channels = channels self.out_channels = out_channels or channels - self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + self.up = nn.ConvTranspose2d(self.channels, self.out_channels, kernel_size=ks, stride=2) - def forward(self,x): + def forward(self, x): return self.up(x) @@ -139,7 +139,7 @@ class Downsample(nn.Module): downsampling occurs in the inner-two dimensions. """ - def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): super().__init__() self.channels = channels self.out_channels = out_channels or channels @@ -147,9 +147,7 @@ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): self.dims = dims stride = 2 if dims != 3 else (1, 2, 2) if use_conv: - self.op = conv_nd( - dims, self.channels, self.out_channels, 3, stride=stride, padding=padding - ) + self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding) else: assert self.channels == self.out_channels self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) @@ -225,17 +223,13 @@ def __init__( normalization(self.out_channels), nn.SiLU(), nn.Dropout(p=dropout), - zero_module( - conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) - ), + zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)), ) if self.out_channels == channels: self.skip_connection = nn.Identity() elif use_conv: - self.skip_connection = conv_nd( - dims, channels, self.out_channels, 3, padding=1 - ) + self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1) else: self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) @@ -246,10 +240,7 @@ def forward(self, x, emb): :param emb: an [N x emb_channels] Tensor of timestep embeddings. :return: an [N x C x ...] Tensor of outputs. """ - return checkpoint( - self._forward, (x, emb), self.parameters(), self.use_checkpoint - ) - + return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint) def _forward(self, x, emb): if self.updown: @@ -311,8 +302,10 @@ def __init__( self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) def forward(self, x): - return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! - #return pt_checkpoint(self._forward, x) # pytorch + return checkpoint( + self._forward, (x,), self.parameters(), True + ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! 
+ # return pt_checkpoint(self._forward, x) # pytorch def _forward(self, x): b, c, *spatial = x.shape @@ -339,7 +332,7 @@ def count_flops_attn(model, _x, y): # We perform two matmuls with the same number of ops. # The first computes the weight matrix, the second computes # the combination of the value vectors. - matmul_ops = 2 * b * (num_spatial ** 2) * c + matmul_ops = 2 * b * (num_spatial**2) * c model.total_ops += th.DoubleTensor([matmul_ops]) @@ -363,9 +356,7 @@ def forward(self, qkv): ch = width // (3 * self.n_heads) q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) scale = 1 / math.sqrt(math.sqrt(ch)) - weight = th.einsum( - "bct,bcs->bts", q * scale, k * scale - ) # More stable with f16 than dividing afterwards + weight = th.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) a = th.einsum("bts,bcs->bct", weight, v) return a.reshape(bs, -1, length) @@ -460,10 +451,10 @@ def __init__( use_scale_shift_norm=False, resblock_updown=False, use_new_attention_order=False, - use_spatial_transformer=False, # custom transformer support - transformer_depth=1, # custom transformer support - context_dim=None, # custom transformer support - n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model legacy=True, disable_self_attentions=None, num_attention_blocks=None, @@ -472,11 +463,16 @@ def __init__( ): super().__init__() if use_spatial_transformer: - assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + assert ( + context_dim is not None + ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..." if context_dim is not None: - assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + assert ( + use_spatial_transformer + ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." 
from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: context_dim = list(context_dim) @@ -484,10 +480,10 @@ def __init__( num_heads_upsample = num_heads if num_heads == -1: - assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + assert num_head_channels != -1, "Either num_heads or num_head_channels has to be set" if num_head_channels == -1: - assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + assert num_heads != -1, "Either num_heads or num_head_channels has to be set" self.image_size = image_size self.in_channels = in_channels @@ -497,19 +493,25 @@ def __init__( self.num_res_blocks = len(channel_mult) * [num_res_blocks] else: if len(num_res_blocks) != len(channel_mult): - raise ValueError("provide num_res_blocks either as an int (globally constant) or " - "as a list/tuple (per-level) with the same length as channel_mult") + raise ValueError( + "provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult" + ) self.num_res_blocks = num_res_blocks if disable_self_attentions is not None: # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not assert len(disable_self_attentions) == len(channel_mult) if num_attention_blocks is not None: assert len(num_attention_blocks) == len(self.num_res_blocks) - assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) - print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " - f"This option has LESS priority than attention_resolutions {attention_resolutions}, " - f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " - f"attention will still not be set.") + assert all( + map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))) + ) + print( + f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set." 
+ ) self.attention_resolutions = attention_resolutions self.dropout = dropout @@ -540,11 +542,7 @@ def __init__( raise ValueError() self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1) - ) - ] + [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))] ) self._feature_size = model_channels input_block_chans = [model_channels] @@ -571,7 +569,7 @@ def __init__( num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: - #num_heads = 1 + # num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels if exists(disable_self_attentions): disabled_sa = disable_self_attentions[level] @@ -586,10 +584,17 @@ def __init__( num_heads=num_heads, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, - disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) @@ -610,9 +615,7 @@ def __init__( down=True, ) if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch - ) + else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) ) ) ch = out_ch @@ -626,7 +629,7 @@ def __init__( num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: - #num_heads = 1 + # num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels self.middle_block = TimestepEmbedSequential( ResBlock( @@ -643,11 +646,18 @@ def __init__( num_heads=num_heads, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, - disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint - ), + ) + if not use_spatial_transformer + else SpatialTransformer( # always uses a self-attn + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + ), ResBlock( ch, time_embed_dim, @@ -682,7 +692,7 @@ def __init__( num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: - #num_heads = 1 + # num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels if exists(disable_self_attentions): disabled_sa = disable_self_attentions[level] @@ -697,10 +707,17 @@ def __init__( num_heads=num_heads_upsample, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, - disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + 
use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, ) ) if level and i == self.num_res_blocks[level]: @@ -730,10 +747,10 @@ def __init__( ) if self.predict_codebook_ids: self.id_predictor = nn.Sequential( - normalization(ch), - conv_nd(dims, model_channels, n_embed, 1), - #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits - ) + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) def convert_to_fp16(self): """ @@ -751,7 +768,7 @@ def convert_to_fp32(self): self.middle_block.apply(convert_module_to_f32) self.output_blocks.apply(convert_module_to_f32) - def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): """ Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/upscaling.py b/examples/images/diffusion/ldm/modules/diffusionmodules/upscaling.py index 03816662098c..82cc2157ca68 100644 --- a/examples/images/diffusion/ldm/modules/diffusionmodules/upscaling.py +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/upscaling.py @@ -1,8 +1,8 @@ -import torch -import torch.nn as nn -import numpy as np from functools import partial +import numpy as np +import torch +import torch.nn as nn from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule from ldm.util import default @@ -14,37 +14,41 @@ def __init__(self, noise_schedule_config=None): if noise_schedule_config is not None: self.register_schedule(**noise_schedule_config) - def register_schedule(self, beta_schedule="linear", timesteps=1000, - linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): - betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, - cosine_s=cosine_s) - alphas = 1. - betas + def register_schedule( + self, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 + ): + betas = make_beta_schedule( + beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s + ) + alphas = 1.0 - betas alphas_cumprod = np.cumprod(alphas, axis=0) - alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) - timesteps, = betas.shape + (timesteps,) = betas.shape self.num_timesteps = int(timesteps) self.linear_start = linear_start self.linear_end = linear_end - assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + assert alphas_cumprod.shape[0] == self.num_timesteps, "alphas have to be defined for each timestep" to_torch = partial(torch.tensor, dtype=torch.float32) - self.register_buffer('betas', to_torch(betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + self.register_buffer("betas", to_torch(betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. 
- alphas_cumprod))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))) + self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))) + self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))) + self.register_buffer("sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))) def q_sample(self, x_start, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start)) - return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) def forward(self, x): return x, None @@ -76,6 +80,3 @@ def forward(self, x, noise_level=None): assert isinstance(noise_level, torch.Tensor) z = self.q_sample(x, noise_level) return z, noise_level - - - diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/util.py b/examples/images/diffusion/ldm/modules/diffusionmodules/util.py index 36b4a171b6c2..aed1b061323a 100644 --- a/examples/images/diffusion/ldm/modules/diffusionmodules/util.py +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/util.py @@ -8,7 +8,6 @@ # thanks! import math -import os import numpy as np import torch @@ -19,10 +18,10 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): if schedule == "linear": - betas = (torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64)**2) + betas = torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64) ** 2 elif schedule == "cosine": - timesteps = (torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s) + timesteps = torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s alphas = timesteps / (1 + cosine_s) * np.pi / 2 alphas = torch.cos(alphas).pow(2) alphas = alphas / alphas[0] @@ -32,18 +31,18 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, elif schedule == "sqrt_linear": betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) elif schedule == "sqrt": - betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)**0.5 + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 else: raise ValueError(f"schedule '{schedule}' unknown.") return betas.numpy() def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): - if ddim_discr_method == 'uniform': + if ddim_discr_method == "uniform": c = num_ddpm_timesteps // num_ddim_timesteps ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) - elif ddim_discr_method == 'quad': - ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps))**2).astype(int) + elif ddim_discr_method == "quad": + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2).astype(int) else: raise NotImplementedError(f'There is no ddim discretization method called 
"{ddim_discr_method}"') @@ -51,7 +50,7 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep # add one to get the final alpha values right (the ones from first scale to data during sampling) steps_out = ddim_timesteps + 1 if verbose: - print(f'Selected timesteps for ddim sampler: {steps_out}') + print(f"Selected timesteps for ddim sampler: {steps_out}") return steps_out @@ -63,9 +62,11 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): # according the the formula provided in https://arxiv.org/abs/2010.02502 sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) if verbose: - print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') - print(f'For the chosen value of eta, which is {eta}, ' - f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + print(f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}") + print( + f"For the chosen value of eta, which is {eta}, " + f"this results in the following sigma_t schedule for ddim sampler {sigmas}" + ) return sigmas, alphas, alphas_prev @@ -106,6 +107,7 @@ def checkpoint(func, inputs, params, flag): """ if flag: from torch.utils.checkpoint import checkpoint as torch_checkpoint + return torch_checkpoint(func, *inputs) # args = tuple(inputs) + tuple(params) # return CheckpointFunction.apply(func, len(inputs), *args) @@ -114,7 +116,6 @@ def checkpoint(func, inputs, params, flag): class CheckpointFunction(torch.autograd.Function): - @staticmethod def forward(ctx, run_function, length, *args): ctx.run_function = run_function @@ -123,7 +124,7 @@ def forward(ctx, run_function, length, *args): ctx.gpu_autocast_kwargs = { "enabled": torch.is_autocast_enabled(), "dtype": torch.get_autocast_gpu_dtype(), - "cache_enabled": torch.is_autocast_cache_enabled() + "cache_enabled": torch.is_autocast_cache_enabled(), } with torch.no_grad(): output_tensors = ctx.run_function(*ctx.input_tensors) @@ -132,8 +133,7 @@ def forward(ctx, run_function, length, *args): @staticmethod def backward(ctx, *output_grads): ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] - with torch.enable_grad(), \ - torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): + with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): # Fixes a bug where the first op in run_function modifies the # Tensor storage in place, which is not allowed for detach()'d # Tensors. @@ -162,14 +162,15 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): """ if not repeat_only: half = dim // 2 - freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / - half).to(device=timesteps.device) + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=timesteps.device + ) args = timesteps[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) else: - embedding = repeat(timesteps, 'b -> b d', d=dim) + embedding = repeat(timesteps, "b -> b d", d=dim) return embedding @@ -210,13 +211,11 @@ def normalization(channels): # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 
class SiLU(nn.Module): - def forward(self, x): return x * torch.sigmoid(x) class GroupNorm32(nn.GroupNorm): - def forward(self, x): return super().forward(x.float()).type(x.dtype) @@ -255,7 +254,6 @@ def avg_pool_nd(dims, *args, **kwargs): class HybridConditioner(nn.Module): - def __init__(self, c_concat_config, c_crossattn_config): super().__init__() self.concat_conditioner = instantiate_from_config(c_concat_config) @@ -264,7 +262,7 @@ def __init__(self, c_concat_config, c_crossattn_config): def forward(self, c_concat, c_crossattn): c_concat = self.concat_conditioner(c_concat) c_crossattn = self.crossattn_conditioner(c_crossattn) - return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]} def noise_like(shape, device, repeat=False): diff --git a/examples/images/diffusion/ldm/modules/distributions/distributions.py b/examples/images/diffusion/ldm/modules/distributions/distributions.py index f2b8ef901130..b5f3b1ad48da 100644 --- a/examples/images/diffusion/ldm/modules/distributions/distributions.py +++ b/examples/images/diffusion/ldm/modules/distributions/distributions.py @@ -1,5 +1,5 @@ -import torch import numpy as np +import torch class AbstractDistribution: @@ -38,25 +38,25 @@ def sample(self): def kl(self, other=None): if self.deterministic: - return torch.Tensor([0.]) + return torch.Tensor([0.0]) else: if other is None: - return 0.5 * torch.sum(torch.pow(self.mean, 2) - + self.var - 1.0 - self.logvar, - dim=[1, 2, 3]) + return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3]) else: return 0.5 * torch.sum( torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - 1.0 - self.logvar + other.logvar, - dim=[1, 2, 3]) - - def nll(self, sample, dims=[1,2,3]): + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): if self.deterministic: - return torch.Tensor([0.]) + return torch.Tensor([0.0]) logtwopi = np.log(2.0 * np.pi) - return 0.5 * torch.sum( - logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, - dim=dims) + return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) def mode(self): return self.mean @@ -78,15 +78,8 @@ def normal_kl(mean1, logvar1, mean2, logvar2): # Force variances to be Tensors. Broadcasting helps convert scalars to # Tensors, but it does not work for torch.exp(). 
- logvar1, logvar2 = [ - x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) - for x in (logvar1, logvar2) - ] + logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)] return 0.5 * ( - -1.0 - + logvar2 - - logvar1 - + torch.exp(logvar1 - logvar2) - + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) ) diff --git a/examples/images/diffusion/ldm/modules/ema.py b/examples/images/diffusion/ldm/modules/ema.py index bded25019b9b..c3863269675e 100644 --- a/examples/images/diffusion/ldm/modules/ema.py +++ b/examples/images/diffusion/ldm/modules/ema.py @@ -6,17 +6,18 @@ class LitEma(nn.Module): def __init__(self, model, decay=0.9999, use_num_upates=True): super().__init__() if decay < 0.0 or decay > 1.0: - raise ValueError('Decay must be between 0 and 1') + raise ValueError("Decay must be between 0 and 1") self.m_name2s_name = {} - self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) - self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates - else torch.tensor(-1, dtype=torch.int)) + self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) + self.register_buffer( + "num_updates", torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int) + ) for name, p in model.named_parameters(): if p.requires_grad: # remove as '.'-character is not allowed in buffers - s_name = name.replace('.', '') + s_name = name.replace(".", "") self.m_name2s_name.update({name: s_name}) self.register_buffer(s_name, p.clone().detach().data) @@ -24,7 +25,7 @@ def __init__(self, model, decay=0.9999, use_num_upates=True): def reset_num_updates(self): del self.num_updates - self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) + self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) def forward(self, model): decay = self.decay diff --git a/examples/images/diffusion/ldm/modules/encoders/modules.py b/examples/images/diffusion/ldm/modules/encoders/modules.py index 4edd5496b9e6..58bff2382c47 100644 --- a/examples/images/diffusion/ldm/modules/encoders/modules.py +++ b/examples/images/diffusion/ldm/modules/encoders/modules.py @@ -1,11 +1,9 @@ +import open_clip import torch import torch.nn as nn +from ldm.util import count_params from torch.utils.checkpoint import checkpoint - -from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel - -import open_clip -from ldm.util import default, count_params +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer class AbstractEncoder(nn.Module): @@ -17,13 +15,12 @@ def encode(self, *args, **kwargs): class IdentityEncoder(AbstractEncoder): - def encode(self, x): return x class ClassEmbedder(nn.Module): - def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): + def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1): super().__init__() self.key = key self.embedding = nn.Embedding(n_classes, embed_dim) @@ -35,9 +32,9 @@ def forward(self, batch, key=None, disable_dropout=False): key = self.key # this is for use in crossattn c = batch[key][:, None] - if self.ucg_rate > 0. and not disable_dropout: - mask = 1. 
- torch.bernoulli(torch.ones_like(c) * self.ucg_rate) - c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1) + if self.ucg_rate > 0.0 and not disable_dropout: + mask = 1.0 - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) + c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1) c = c.long() c = self.embedding(c) return c @@ -57,24 +54,34 @@ def disabled_train(self, mode=True): class FrozenT5Embedder(AbstractEncoder): """Uses the T5 transformer encoder for text""" - def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + + def __init__( + self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True + ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl super().__init__() self.tokenizer = T5Tokenizer.from_pretrained(version) self.transformer = T5EncoderModel.from_pretrained(version) self.device = device - self.max_length = max_length # TODO: typical value? + self.max_length = max_length # TODO: typical value? if freeze: self.freeze() def freeze(self): self.transformer = self.transformer.eval() - #self.train = disabled_train + # self.train = disabled_train for param in self.parameters(): param.requires_grad = False def forward(self, text): - batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, - return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) tokens = batch_encoding["input_ids"].to(self.device) outputs = self.transformer(input_ids=tokens) @@ -87,13 +94,18 @@ def encode(self, text): class FrozenCLIPEmbedder(AbstractEncoder): """Uses the CLIP transformer encoder for text (from huggingface)""" - LAYERS = [ - "last", - "pooled", - "hidden" - ] - def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, - freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 + + LAYERS = ["last", "pooled", "hidden"] + + def __init__( + self, + version="openai/clip-vit-large-patch14", + device="cuda", + max_length=77, + freeze=True, + layer="last", + layer_idx=None, + ): # clip-vit-base-patch32 super().__init__() assert layer in self.LAYERS self.tokenizer = CLIPTokenizer.from_pretrained(version) @@ -110,15 +122,22 @@ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_l def freeze(self): self.transformer = self.transformer.eval() - #self.train = disabled_train + # self.train = disabled_train for param in self.parameters(): param.requires_grad = False def forward(self, text): - batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, - return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) tokens = batch_encoding["input_ids"].to(self.device) - outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") + outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden") if self.layer == "last": z = outputs.last_hidden_state elif self.layer == "pooled": @@ -135,16 +154,19 @@ class 
FrozenOpenCLIPEmbedder(AbstractEncoder): """ Uses the OpenCLIP transformer encoder for text """ + LAYERS = [ - #"pooled", + # "pooled", "last", - "penultimate" + "penultimate", ] - def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, - freeze=True, layer="last"): + + def __init__( + self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, freeze=True, layer="last" + ): super().__init__() assert layer in self.LAYERS - model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) + model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device("cpu"), pretrained=version) del model.visual self.model = model @@ -179,7 +201,7 @@ def encode_with_transformer(self, text): x = self.model.ln_final(x) return x - def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): + def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): for i, r in enumerate(self.model.transformer.resblocks): if i == len(self.model.transformer.resblocks) - self.layer_idx: break @@ -194,13 +216,21 @@ def encode(self, text): class FrozenCLIPT5Encoder(AbstractEncoder): - def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", - clip_max_length=77, t5_max_length=77): + def __init__( + self, + clip_version="openai/clip-vit-large-patch14", + t5_version="google/t5-v1_1-xl", + device="cuda", + clip_max_length=77, + t5_max_length=77, + ): super().__init__() self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) - print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " - f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.") + print( + f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " + f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params." + ) def encode(self, text): return self(text) @@ -209,5 +239,3 @@ def forward(self, text): clip_z = self.clip_encoder.encode(text) t5_z = self.t5_encoder.encode(text) return [clip_z, t5_z] - - diff --git a/examples/images/diffusion/ldm/modules/image_degradation/bsrgan.py b/examples/images/diffusion/ldm/modules/image_degradation/bsrgan.py index 32ef56169978..879b2aa099b6 100644 --- a/examples/images/diffusion/ldm/modules/image_degradation/bsrgan.py +++ b/examples/images/diffusion/ldm/modules/image_degradation/bsrgan.py @@ -10,33 +10,32 @@ # -------------------------------------------- """ -import numpy as np -import cv2 -import torch - -from functools import partial import random -from scipy import ndimage +from functools import partial + +import albumentations +import cv2 +import ldm.modules.image_degradation.utils_image as util +import numpy as np import scipy import scipy.stats as ss +import torch +from scipy import ndimage from scipy.interpolate import interp2d from scipy.linalg import orth -import albumentations - -import ldm.modules.image_degradation.utils_image as util def modcrop_np(img, sf): - ''' + """ Args: img: numpy image, WxH or WxHxC sf: scale factor Return: cropped image - ''' + """ w, h = img.shape[:2] im = np.copy(img) - return im[:w - w % sf, :h - h % sf, ...] + return im[: w - w % sf, : h - h % sf, ...] 
""" @@ -54,7 +53,7 @@ def analytic_kernel(k): # Loop over the small kernel to fill the big one for r in range(k_size): for c in range(k_size): - big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += k[r, c] * k # Crop the edges of the big kernel to ignore very small values and increase run time of SR crop = k_size // 2 cropped_big_k = big_k[crop:-crop, crop:-crop] @@ -63,7 +62,7 @@ def analytic_kernel(k): def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): - """ generate an anisotropic Gaussian kernel + """generate an anisotropic Gaussian kernel Args: ksize : e.g., 15, kernel size theta : [0, pi], rotation angle range @@ -74,7 +73,7 @@ def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): k : kernel """ - v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1.0, 0.0])) V = np.array([[v[0], v[1]], [v[1], -v[0]]]) D = np.array([[l1, 0], [0, l2]]) Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) @@ -126,13 +125,13 @@ def shift_pixel(x, sf, upper_left=True): def blur(x, k): - ''' + """ x: image, NxcxHxW k: kernel, Nx1xhxw - ''' + """ n, c = x.shape[:2] p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 - x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode="replicate") k = k.repeat(1, c, 1, 1) k = k.view(-1, 1, k.shape[2], k.shape[3]) x = x.view(1, -1, x.shape[2], x.shape[3]) @@ -142,8 +141,8 @@ def blur(x, k): return x -def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): - """" +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10.0, noise_level=0): + """ " # modified version of https://github.com/assafshocher/BlindSR_dataset_generator # Kai Zhang # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var @@ -157,8 +156,7 @@ def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var # Set COV matrix using Lambdas and Theta LAMBDA = np.diag([lambda_1, lambda_2]) - Q = np.array([[np.cos(theta), -np.sin(theta)], - [np.sin(theta), np.cos(theta)]]) + Q = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) SIGMA = Q @ LAMBDA @ Q.T INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] @@ -208,13 +206,13 @@ def fspecial_laplacian(alpha): def fspecial(filter_type, *args, **kwargs): - ''' + """ python code from: https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py - ''' - if filter_type == 'gaussian': + """ + if filter_type == "gaussian": return fspecial_gaussian(*args, **kwargs) - if filter_type == 'laplacian': + if filter_type == "laplacian": return fspecial_laplacian(*args, **kwargs) @@ -226,19 +224,19 @@ def fspecial(filter_type, *args, **kwargs): def bicubic_degradation(x, sf=3): - ''' + """ Args: x: HxWxC image, [0, 1] sf: down-scale factor Return: bicubicly downsampled LR image - ''' + """ x = util.imresize_np(x, scale=1 / sf) return x def srmd_degradation(x, k, sf=3): - ''' blur + bicubic downsampling + """blur + bicubic downsampling Args: x: HxWxC image, [0, 1] k: hxw, double @@ -253,14 +251,14 @@ def srmd_degradation(x, k, sf=3): pages={3262--3271}, year={2018} } - ''' - x = 
ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + """ + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap") # 'nearest' | 'mirror' x = bicubic_degradation(x, sf=sf) return x def dpsr_degradation(x, k, sf=3): - ''' bicubic downsampling + blur + """bicubic downsampling + blur Args: x: HxWxC image, [0, 1] k: hxw, double @@ -275,22 +273,22 @@ def dpsr_degradation(x, k, sf=3): pages={1671--1681}, year={2019} } - ''' + """ x = bicubic_degradation(x, sf=sf) - x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap") return x def classical_degradation(x, k, sf=3): - ''' blur + downsampling + """blur + downsampling Args: x: HxWxC image, [0, 1]/[0, 255] k: hxw, double sf: down-scale factor Return: downsampled LR image - ''' - x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + """ + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap") # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) st = 0 return x[st::sf, st::sf, ...] @@ -314,7 +312,7 @@ def add_sharpening(img, weight=0.5, radius=50, threshold=10): blur = cv2.GaussianBlur(img, (radius, radius), 0) residual = img - blur mask = np.abs(residual) * 255 > threshold - mask = mask.astype('float32') + mask = mask.astype("float32") soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) K = img + weight * residual @@ -330,8 +328,8 @@ def add_blur(img, sf=4): l2 = wd2 * random.random() k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) else: - k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random()) - img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + k = fspecial("gaussian", 2 * random.randint(2, 11) + 3, wd * random.random()) + img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode="mirror") return img @@ -366,6 +364,7 @@ def add_resize(img, sf=4): # img = np.clip(img, 0.0, 1.0) # return img + def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): noise_level = random.randint(noise_level1, noise_level2) rnum = np.random.rand() @@ -374,11 +373,11 @@ def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): elif rnum < 0.4: # add grayscale Gaussian noise img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) else: # add noise - L = noise_level2 / 255. + L = noise_level2 / 255.0 D = np.diag(np.random.rand(3)) U = orth(np.random.rand(3, 3)) conv = np.dot(np.dot(np.transpose(U), D), U) - img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32) img = np.clip(img, 0.0, 1.0) return img @@ -392,23 +391,23 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25): elif rnum < 0.4: img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) else: - L = noise_level2 / 255. 
+ L = noise_level2 / 255.0 D = np.diag(np.random.rand(3)) U = orth(np.random.rand(3, 3)) conv = np.dot(np.dot(np.transpose(U), D), U) - img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32) img = np.clip(img, 0.0, 1.0) return img def add_Poisson_noise(img): - img = np.clip((img * 255.0).round(), 0, 255) / 255. + img = np.clip((img * 255.0).round(), 0, 255) / 255.0 vals = 10 ** (2 * random.random() + 2.0) # [2, 4] if random.random() < 0.5: img = np.random.poisson(img * vals).astype(np.float32) / vals else: img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) - img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0 noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray img += noise_gray[:, :, np.newaxis] img = np.clip(img, 0.0, 1.0) @@ -418,7 +417,7 @@ def add_Poisson_noise(img): def add_JPEG_noise(img): quality_factor = random.randint(30, 95) img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) - result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + result, encimg = cv2.imencode(".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) img = cv2.imdecode(encimg, 1) img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) return img @@ -428,10 +427,10 @@ def random_crop(lq, hq, sf=4, lq_patchsize=64): h, w = lq.shape[:2] rnd_h = random.randint(0, h - lq_patchsize) rnd_w = random.randint(0, w - lq_patchsize) - lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :] rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) - hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + hq = hq[rnd_h_H : rnd_h_H + lq_patchsize * sf, rnd_w_H : rnd_w_H + lq_patchsize * sf, :] return lq, hq @@ -452,18 +451,19 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): sf_ori = sf h1, w1 = img.shape[:2] - img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] 
# mod crop h, w = img.shape[:2] if h < lq_patchsize * sf or w < lq_patchsize * sf: - raise ValueError(f'img size ({h1}X{w1}) is too small!') + raise ValueError(f"img size ({h1}X{w1}) is too small!") hq = img.copy() if sf == 4 and random.random() < scale2_prob: # downsample1 if np.random.rand() < 0.5: - img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), - interpolation=random.choice([1, 2, 3])) + img = cv2.resize( + img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), interpolation=random.choice([1, 2, 3]) + ) else: img = util.imresize_np(img, 1 / 2, True) img = np.clip(img, 0.0, 1.0) @@ -475,7 +475,6 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] for i in shuffle_order: - if i == 0: img = add_blur(img, sf=sf) @@ -487,13 +486,16 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): # downsample2 if random.random() < 0.75: sf1 = random.uniform(1, 2 * sf) - img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), - interpolation=random.choice([1, 2, 3])) + img = cv2.resize( + img, + (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) else: - k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode="mirror") img = img[0::sf, 0::sf, ...] # nearest downsampling img = np.clip(img, 0.0, 1.0) @@ -541,18 +543,20 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): """ image = util.uint2single(image) isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 - sf_ori = sf h1, w1 = image.shape[:2] - image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] 
# mod crop h, w = image.shape[:2] - hq = image.copy() + image.copy() if sf == 4 and random.random() < scale2_prob: # downsample1 if np.random.rand() < 0.5: - image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), - interpolation=random.choice([1, 2, 3])) + image = cv2.resize( + image, + (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) else: image = util.imresize_np(image, 1 / 2, True) image = np.clip(image, 0.0, 1.0) @@ -564,7 +568,6 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] for i in shuffle_order: - if i == 0: image = add_blur(image, sf=sf) @@ -576,13 +579,16 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): # downsample2 if random.random() < 0.75: sf1 = random.uniform(1, 2 * sf) - image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), - interpolation=random.choice([1, 2, 3])) + image = cv2.resize( + image, + (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) else: - k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode="mirror") image = image[0::sf, 0::sf, ...] # nearest downsampling image = np.clip(image, 0.0, 1.0) @@ -609,7 +615,7 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): # add final JPEG compression noise image = add_JPEG_noise(image) image = util.single2uint(image) - example = {"image":image} + example = {"image": image} return example @@ -630,11 +636,11 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc """ h1, w1 = img.shape[:2] - img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] 
# mod crop h, w = img.shape[:2] if h < lq_patchsize * sf or w < lq_patchsize * sf: - raise ValueError(f'img size ({h1}X{w1}) is too small!') + raise ValueError(f"img size ({h1}X{w1}) is too small!") if use_sharp: img = add_sharpening(img) @@ -686,11 +692,12 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc with torch.no_grad(): img, hq = isp_model.forward(img.copy(), hq) else: - print('check the shuffle!') + print("check the shuffle!") # resize to desired size - img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])), - interpolation=random.choice([1, 2, 3])) + img = cv2.resize( + img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])), interpolation=random.choice([1, 2, 3]) + ) # add final JPEG compression noise img = add_JPEG_noise(img) @@ -701,30 +708,30 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc return img, hq -if __name__ == '__main__': - print("hey") - img = util.imread_uint('utils/test.png', 3) - print(img) - img = util.uint2single(img) - print(img) - img = img[:448, :448] - h = img.shape[0] // 4 - print("resizing to", h) - sf = 4 - deg_fn = partial(degradation_bsrgan_variant, sf=sf) - for i in range(20): - print(i) - img_lq = deg_fn(img) - print(img_lq) - img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] - print(img_lq.shape) - print("bicubic", img_lq_bicubic.shape) - print(img_hq.shape) - lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), - interpolation=0) - lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), - interpolation=0) - img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) - util.imsave(img_concat, str(i) + '.png') - - +if __name__ == "__main__": + print("hey") + img = util.imread_uint("utils/test.png", 3) + print(img) + img = util.uint2single(img) + print(img) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_lq = deg_fn(img) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize( + util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), interpolation=0 + ) + lq_bicubic_nearest = cv2.resize( + util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), interpolation=0 + ) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i) + ".png") diff --git a/examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py b/examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py index 808c7f882cb7..cf3f83f0c011 100644 --- a/examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py +++ b/examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py @@ -1,18 +1,17 @@ # -*- coding: utf-8 -*- -import numpy as np -import cv2 -import torch - -from functools import partial import random -from scipy import ndimage +from functools import partial + +import albumentations +import cv2 +import ldm.modules.image_degradation.utils_image as util +import numpy as np import scipy import 
scipy.stats as ss +import torch +from scipy import ndimage from scipy.interpolate import interp2d from scipy.linalg import orth -import albumentations - -import ldm.modules.image_degradation.utils_image as util """ # -------------------------------------------- @@ -25,17 +24,18 @@ # -------------------------------------------- """ + def modcrop_np(img, sf): - ''' + """ Args: img: numpy image, WxH or WxHxC sf: scale factor Return: cropped image - ''' + """ w, h = img.shape[:2] im = np.copy(img) - return im[:w - w % sf, :h - h % sf, ...] + return im[: w - w % sf, : h - h % sf, ...] """ @@ -53,7 +53,7 @@ def analytic_kernel(k): # Loop over the small kernel to fill the big one for r in range(k_size): for c in range(k_size): - big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += k[r, c] * k # Crop the edges of the big kernel to ignore very small values and increase run time of SR crop = k_size // 2 cropped_big_k = big_k[crop:-crop, crop:-crop] @@ -62,7 +62,7 @@ def analytic_kernel(k): def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): - """ generate an anisotropic Gaussian kernel + """generate an anisotropic Gaussian kernel Args: ksize : e.g., 15, kernel size theta : [0, pi], rotation angle range @@ -73,7 +73,7 @@ def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): k : kernel """ - v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1.0, 0.0])) V = np.array([[v[0], v[1]], [v[1], -v[0]]]) D = np.array([[l1, 0], [0, l2]]) Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) @@ -125,13 +125,13 @@ def shift_pixel(x, sf, upper_left=True): def blur(x, k): - ''' + """ x: image, NxcxHxW k: kernel, Nx1xhxw - ''' + """ n, c = x.shape[:2] p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 - x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode="replicate") k = k.repeat(1, c, 1, 1) k = k.view(-1, 1, k.shape[2], k.shape[3]) x = x.view(1, -1, x.shape[2], x.shape[3]) @@ -141,8 +141,8 @@ def blur(x, k): return x -def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): - """" +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10.0, noise_level=0): + """ " # modified version of https://github.com/assafshocher/BlindSR_dataset_generator # Kai Zhang # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var @@ -156,8 +156,7 @@ def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var # Set COV matrix using Lambdas and Theta LAMBDA = np.diag([lambda_1, lambda_2]) - Q = np.array([[np.cos(theta), -np.sin(theta)], - [np.sin(theta), np.cos(theta)]]) + Q = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) SIGMA = Q @ LAMBDA @ Q.T INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] @@ -207,13 +206,13 @@ def fspecial_laplacian(alpha): def fspecial(filter_type, *args, **kwargs): - ''' + """ python code from: https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py - ''' - if filter_type == 'gaussian': + """ + if filter_type == "gaussian": return fspecial_gaussian(*args, **kwargs) - if filter_type == 'laplacian': + if filter_type == 
"laplacian": return fspecial_laplacian(*args, **kwargs) @@ -225,19 +224,19 @@ def fspecial(filter_type, *args, **kwargs): def bicubic_degradation(x, sf=3): - ''' + """ Args: x: HxWxC image, [0, 1] sf: down-scale factor Return: bicubicly downsampled LR image - ''' + """ x = util.imresize_np(x, scale=1 / sf) return x def srmd_degradation(x, k, sf=3): - ''' blur + bicubic downsampling + """blur + bicubic downsampling Args: x: HxWxC image, [0, 1] k: hxw, double @@ -252,14 +251,14 @@ def srmd_degradation(x, k, sf=3): pages={3262--3271}, year={2018} } - ''' - x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + """ + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode="wrap") # 'nearest' | 'mirror' x = bicubic_degradation(x, sf=sf) return x def dpsr_degradation(x, k, sf=3): - ''' bicubic downsampling + blur + """bicubic downsampling + blur Args: x: HxWxC image, [0, 1] k: hxw, double @@ -274,22 +273,22 @@ def dpsr_degradation(x, k, sf=3): pages={1671--1681}, year={2019} } - ''' + """ x = bicubic_degradation(x, sf=sf) - x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode="wrap") return x def classical_degradation(x, k, sf=3): - ''' blur + downsampling + """blur + downsampling Args: x: HxWxC image, [0, 1]/[0, 255] k: hxw, double sf: down-scale factor Return: downsampled LR image - ''' - x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + """ + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode="wrap") # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) st = 0 return x[st::sf, st::sf, ...] @@ -313,7 +312,7 @@ def add_sharpening(img, weight=0.5, radius=50, threshold=10): blur = cv2.GaussianBlur(img, (radius, radius), 0) residual = img - blur mask = np.abs(residual) * 255 > threshold - mask = mask.astype('float32') + mask = mask.astype("float32") soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) K = img + weight * residual @@ -325,16 +324,16 @@ def add_blur(img, sf=4): wd2 = 4.0 + sf wd = 2.0 + 0.2 * sf - wd2 = wd2/4 - wd = wd/4 + wd2 = wd2 / 4 + wd = wd / 4 if random.random() < 0.5: l1 = wd2 * random.random() l2 = wd2 * random.random() k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) else: - k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random()) - img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + k = fspecial("gaussian", random.randint(2, 4) + 3, wd * random.random()) + img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode="mirror") return img @@ -369,6 +368,7 @@ def add_resize(img, sf=4): # img = np.clip(img, 0.0, 1.0) # return img + def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): noise_level = random.randint(noise_level1, noise_level2) rnum = np.random.rand() @@ -377,11 +377,11 @@ def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): elif rnum < 0.4: # add grayscale Gaussian noise img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) else: # add noise - L = noise_level2 / 255. 
+ L = noise_level2 / 255.0 D = np.diag(np.random.rand(3)) U = orth(np.random.rand(3, 3)) conv = np.dot(np.dot(np.transpose(U), D), U) - img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32) img = np.clip(img, 0.0, 1.0) return img @@ -395,23 +395,23 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25): elif rnum < 0.4: img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) else: - L = noise_level2 / 255. + L = noise_level2 / 255.0 D = np.diag(np.random.rand(3)) U = orth(np.random.rand(3, 3)) conv = np.dot(np.dot(np.transpose(U), D), U) - img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32) img = np.clip(img, 0.0, 1.0) return img def add_Poisson_noise(img): - img = np.clip((img * 255.0).round(), 0, 255) / 255. + img = np.clip((img * 255.0).round(), 0, 255) / 255.0 vals = 10 ** (2 * random.random() + 2.0) # [2, 4] if random.random() < 0.5: img = np.random.poisson(img * vals).astype(np.float32) / vals else: img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) - img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0 noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray img += noise_gray[:, :, np.newaxis] img = np.clip(img, 0.0, 1.0) @@ -421,7 +421,7 @@ def add_Poisson_noise(img): def add_JPEG_noise(img): quality_factor = random.randint(80, 95) img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) - result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + result, encimg = cv2.imencode(".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) img = cv2.imdecode(encimg, 1) img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) return img @@ -431,10 +431,10 @@ def random_crop(lq, hq, sf=4, lq_patchsize=64): h, w = lq.shape[:2] rnd_h = random.randint(0, h - lq_patchsize) rnd_w = random.randint(0, w - lq_patchsize) - lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :] rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) - hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + hq = hq[rnd_h_H : rnd_h_H + lq_patchsize * sf, rnd_w_H : rnd_w_H + lq_patchsize * sf, :] return lq, hq @@ -455,18 +455,19 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): sf_ori = sf h1, w1 = img.shape[:2] - img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] 
# mod crop h, w = img.shape[:2] if h < lq_patchsize * sf or w < lq_patchsize * sf: - raise ValueError(f'img size ({h1}X{w1}) is too small!') + raise ValueError(f"img size ({h1}X{w1}) is too small!") hq = img.copy() if sf == 4 and random.random() < scale2_prob: # downsample1 if np.random.rand() < 0.5: - img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), - interpolation=random.choice([1, 2, 3])) + img = cv2.resize( + img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), interpolation=random.choice([1, 2, 3]) + ) else: img = util.imresize_np(img, 1 / 2, True) img = np.clip(img, 0.0, 1.0) @@ -478,7 +479,6 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] for i in shuffle_order: - if i == 0: img = add_blur(img, sf=sf) @@ -490,13 +490,16 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): # downsample2 if random.random() < 0.75: sf1 = random.uniform(1, 2 * sf) - img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), - interpolation=random.choice([1, 2, 3])) + img = cv2.resize( + img, + (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) else: - k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode="mirror") img = img[0::sf, 0::sf, ...] # nearest downsampling img = np.clip(img, 0.0, 1.0) @@ -544,18 +547,20 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False): """ image = util.uint2single(image) isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 - sf_ori = sf h1, w1 = image.shape[:2] - image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] 
# mod crop h, w = image.shape[:2] - hq = image.copy() + image.copy() if sf == 4 and random.random() < scale2_prob: # downsample1 if np.random.rand() < 0.5: - image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), - interpolation=random.choice([1, 2, 3])) + image = cv2.resize( + image, + (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) else: image = util.imresize_np(image, 1 / 2, True) image = np.clip(image, 0.0, 1.0) @@ -567,7 +572,6 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False): shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] for i in shuffle_order: - if i == 0: image = add_blur(image, sf=sf) @@ -582,13 +586,16 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False): # downsample2 if random.random() < 0.8: sf1 = random.uniform(1, 2 * sf) - image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), - interpolation=random.choice([1, 2, 3])) + image = cv2.resize( + image, + (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) else: - k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode="mirror") image = image[0::sf, 0::sf, ...] # nearest downsampling image = np.clip(image, 0.0, 1.0) @@ -617,16 +624,16 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False): image = add_JPEG_noise(image) image = util.single2uint(image) if up: - image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_CUBIC) # todo: random, as above? want to condition on it then + image = cv2.resize( + image, (w1, h1), interpolation=cv2.INTER_CUBIC + ) # todo: random, as above? 
want to condition on it then example = {"image": image} return example - - -if __name__ == '__main__': +if __name__ == "__main__": print("hey") - img = util.imread_uint('utils/test.png', 3) + img = util.imread_uint("utils/test.png", 3) img = img[:448, :448] h = img.shape[0] // 4 print("resizing to", h) @@ -638,14 +645,17 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False): img_lq = deg_fn(img)["image"] img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq) print(img_lq) - img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"] + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)[ + "image" + ] print(img_lq.shape) print("bicubic", img_lq_bicubic.shape) print(img_hq.shape) - lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), - interpolation=0) - lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), - (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), - interpolation=0) + lq_nearest = cv2.resize( + util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), interpolation=0 + ) + lq_bicubic_nearest = cv2.resize( + util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), interpolation=0 + ) img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) - util.imsave(img_concat, str(i) + '.png') + util.imsave(img_concat, str(i) + ".png") diff --git a/examples/images/diffusion/ldm/modules/image_degradation/utils_image.py b/examples/images/diffusion/ldm/modules/image_degradation/utils_image.py index 0175f155ad90..71fae1084b61 100644 --- a/examples/images/diffusion/ldm/modules/image_degradation/utils_image.py +++ b/examples/images/diffusion/ldm/modules/image_degradation/utils_image.py @@ -1,18 +1,20 @@ -import os import math +import os import random +from datetime import datetime + +import cv2 import numpy as np import torch -import cv2 from torchvision.utils import make_grid -from datetime import datetime -#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py + +# import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py -os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" +os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" -''' +""" # -------------------------------------------- # Kai Zhang (github: https://github.com/cszn) # 03/Mar/2019 @@ -20,10 +22,10 @@ # https://github.com/twhui/SRGAN-pyTorch # https://github.com/xinntao/BasicSR # -------------------------------------------- -''' +""" -IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif'] +IMG_EXTENSIONS = [".jpg", ".JPG", ".jpeg", ".JPEG", ".png", ".PNG", ".ppm", ".PPM", ".bmp", ".BMP", ".tif"] def is_image_file(filename): @@ -31,12 +33,12 @@ def is_image_file(filename): def get_timestamp(): - return datetime.now().strftime('%y%m%d-%H%M%S') + return datetime.now().strftime("%y%m%d-%H%M%S") def imshow(x, title=None, cbar=False, figsize=None): plt.figure(figsize=figsize) - plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray') + plt.imshow(np.squeeze(x), interpolation="nearest", cmap="gray") if title: plt.title(title) if cbar: @@ -44,24 +46,24 @@ def imshow(x, title=None, cbar=False, figsize=None): plt.show() -def surf(Z, cmap='rainbow', figsize=None): +def surf(Z, cmap="rainbow", figsize=None): 
plt.figure(figsize=figsize) - ax3 = plt.axes(projection='3d') + ax3 = plt.axes(projection="3d") w, h = Z.shape[:2] - xx = np.arange(0,w,1) - yy = np.arange(0,h,1) + xx = np.arange(0, w, 1) + yy = np.arange(0, h, 1) X, Y = np.meshgrid(xx, yy) - ax3.plot_surface(X,Y,Z,cmap=cmap) - #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap) + ax3.plot_surface(X, Y, Z, cmap=cmap) + # ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap) plt.show() -''' +""" # -------------------------------------------- # get image pathes # -------------------------------------------- -''' +""" def get_image_paths(dataroot): @@ -72,37 +74,37 @@ def get_image_paths(dataroot): def _get_paths_from_images(path): - assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) + assert os.path.isdir(path), "{:s} is not a valid directory".format(path) images = [] for dirpath, _, fnames in sorted(os.walk(path)): for fname in sorted(fnames): if is_image_file(fname): img_path = os.path.join(dirpath, fname) images.append(img_path) - assert images, '{:s} has no valid image file'.format(path) + assert images, "{:s} has no valid image file".format(path) return images -''' +""" # -------------------------------------------- -# split large images into small images +# split large images into small images # -------------------------------------------- -''' +""" def patches_from_image(img, p_size=512, p_overlap=64, p_max=800): w, h = img.shape[:2] patches = [] if w > p_max and h > p_max: - w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int)) - h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int)) - w1.append(w-p_size) - h1.append(h-p_size) -# print(w1) -# print(h1) + w1 = list(np.arange(0, w - p_size, p_size - p_overlap, dtype=np.int)) + h1 = list(np.arange(0, h - p_size, p_size - p_overlap, dtype=np.int)) + w1.append(w - p_size) + h1.append(h - p_size) + # print(w1) + # print(h1) for i in w1: for j in h1: - patches.append(img[i:i+p_size, j:j+p_size,:]) + patches.append(img[i : i + p_size, j : j + p_size, :]) else: patches.append(img) @@ -118,7 +120,7 @@ def imssave(imgs, img_path): for i, img in enumerate(imgs): if img.ndim == 3: img = img[:, :, [2, 1, 0]] - new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png') + new_path = os.path.join(os.path.dirname(img_path), img_name + str("_s{:04d}".format(i)) + ".png") cv2.imwrite(new_path, img) @@ -139,15 +141,16 @@ def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, # img_name, ext = os.path.splitext(os.path.basename(img_path)) img = imread_uint(img_path, n_channels=n_channels) patches = patches_from_image(img, p_size, p_overlap, p_max) - imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path))) - #if original_dataroot == taget_dataroot: - #del img_path + imssave(patches, os.path.join(taget_dataroot, os.path.basename(img_path))) + # if original_dataroot == taget_dataroot: + # del img_path + -''' +""" # -------------------------------------------- # makedir # -------------------------------------------- -''' +""" def mkdir(path): @@ -165,18 +168,18 @@ def mkdirs(paths): def mkdir_and_rename(path): if os.path.exists(path): - new_name = path + '_archived_' + get_timestamp() - print('Path already exists. Rename it to [{:s}]'.format(new_name)) + new_name = path + "_archived_" + get_timestamp() + print("Path already exists. 
Rename it to [{:s}]".format(new_name)) os.rename(path, new_name) os.makedirs(path) -''' +""" # -------------------------------------------- # read image from path # opencv is fast, but read BGR numpy image # -------------------------------------------- -''' +""" # -------------------------------------------- @@ -206,6 +209,7 @@ def imsave(img, img_path): img = img[:, :, [2, 1, 0]] cv2.imwrite(img_path, img) + def imwrite(img, img_path): img = np.squeeze(img) if img.ndim == 3: @@ -213,7 +217,6 @@ def imwrite(img, img_path): cv2.imwrite(img_path, img) - # -------------------------------------------- # get single image of size HxWxn_channles (BGR) # -------------------------------------------- @@ -221,7 +224,7 @@ def read_img(path): # read image by cv2 # return: Numpy float32, HWC, BGR, [0,1] img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE - img = img.astype(np.float32) / 255. + img = img.astype(np.float32) / 255.0 if img.ndim == 2: img = np.expand_dims(img, axis=2) # some images have 4 channels @@ -230,7 +233,7 @@ def read_img(path): return img -''' +""" # -------------------------------------------- # image format conversion # -------------------------------------------- @@ -238,7 +241,7 @@ def read_img(path): # numpy(single) <---> tensor # numpy(unit) <---> tensor # -------------------------------------------- -''' +""" # -------------------------------------------- @@ -247,23 +250,19 @@ def read_img(path): def uint2single(img): - - return np.float32(img/255.) + return np.float32(img / 255.0) def single2uint(img): - - return np.uint8((img.clip(0, 1)*255.).round()) + return np.uint8((img.clip(0, 1) * 255.0).round()) def uint162single(img): - - return np.float32(img/65535.) + return np.float32(img / 65535.0) def single2uint16(img): - - return np.uint16((img.clip(0, 1)*65535.).round()) + return np.uint16((img.clip(0, 1) * 65535.0).round()) # -------------------------------------------- @@ -275,14 +274,14 @@ def single2uint16(img): def uint2tensor4(img): if img.ndim == 2: img = np.expand_dims(img, axis=2) - return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.0).unsqueeze(0) # convert uint to 3-dimensional torch tensor def uint2tensor3(img): if img.ndim == 2: img = np.expand_dims(img, axis=2) - return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.) 
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.0) # convert 2/3/4-dimensional torch tensor to uint @@ -290,7 +289,7 @@ def tensor2uint(img): img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy() if img.ndim == 3: img = np.transpose(img, (1, 2, 0)) - return np.uint8((img*255.0).round()) + return np.uint8((img * 255.0).round()) # -------------------------------------------- @@ -316,6 +315,7 @@ def tensor2single(img): return img + # convert torch tensor to single def tensor2single3(img): img = img.data.squeeze().float().cpu().numpy() @@ -340,11 +340,11 @@ def single42tensor4(img): # from skimage.io import imread, imsave def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): - ''' + """ Converts a torch Tensor into an image Numpy array of BGR channel order Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) - ''' + """ tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] n_dim = tensor.dim() @@ -358,15 +358,14 @@ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): elif n_dim == 2: img_np = tensor.numpy() else: - raise TypeError( - 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) + raise TypeError("Only support 4D, 3D and 2D tensor. But received with dimension: {:d}".format(n_dim)) if out_type == np.uint8: img_np = (img_np * 255.0).round() # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. return img_np.astype(out_type) -''' +""" # -------------------------------------------- # Augmentation, flipe and/or rotate # -------------------------------------------- @@ -374,12 +373,11 @@ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): # (1) augmet_img: numpy image of WxHxC or WxH # (2) augment_img_tensor4: tensor image 1xCxWxH # -------------------------------------------- -''' +""" def augment_img(img, mode=0): - '''Kai Zhang (github: https://github.com/cszn) - ''' + """Kai Zhang (github: https://github.com/cszn)""" if mode == 0: return img elif mode == 1: @@ -399,8 +397,7 @@ def augment_img(img, mode=0): def augment_img_tensor4(img, mode=0): - '''Kai Zhang (github: https://github.com/cszn) - ''' + """Kai Zhang (github: https://github.com/cszn)""" if mode == 0: return img elif mode == 1: @@ -420,8 +417,7 @@ def augment_img_tensor4(img, mode=0): def augment_img_tensor(img, mode=0): - '''Kai Zhang (github: https://github.com/cszn) - ''' + """Kai Zhang (github: https://github.com/cszn)""" img_size = img.size() img_np = img.data.cpu().numpy() if len(img_size) == 3: @@ -484,11 +480,11 @@ def _augment(img): return [_augment(img) for img in img_list] -''' +""" # -------------------------------------------- # modcrop and shave # -------------------------------------------- -''' +""" def modcrop(img_in, scale): @@ -497,13 +493,13 @@ def modcrop(img_in, scale): if img.ndim == 2: H, W = img.shape H_r, W_r = H % scale, W % scale - img = img[:H - H_r, :W - W_r] + img = img[: H - H_r, : W - W_r] elif img.ndim == 3: H, W, C = img.shape H_r, W_r = H % scale, W % scale - img = img[:H - H_r, :W - W_r, :] + img = img[: H - H_r, : W - W_r, :] else: - raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) + raise ValueError("Wrong img ndim: [{:d}].".format(img.ndim)) return img @@ -511,11 +507,11 @@ def shave(img_in, border=0): # img_in: Numpy, HWC or HW img = np.copy(img_in) h, w = img.shape[:2] - img = 
img[border:h-border, border:w-border] + img = img[border : h - border, border : w - border] return img -''' +""" # -------------------------------------------- # image processing process on numpy image # channel_convert(in_c, tar_type, img_list): @@ -523,96 +519,99 @@ def shave(img_in, border=0): # bgr2ycbcr(img, only_y=True): # ycbcr2rgb(img): # -------------------------------------------- -''' +""" def rgb2ycbcr(img, only_y=True): - '''same as matlab rgb2ycbcr + """same as matlab rgb2ycbcr only_y: only return Y channel Input: uint8, [0, 255] float, [0, 1] - ''' + """ in_img_type = img.dtype img.astype(np.float32) if in_img_type != np.uint8: - img *= 255. + img *= 255.0 # convert if only_y: rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 else: - rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], - [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] + rlt = np.matmul( + img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]] + ) / 255.0 + [16, 128, 128] if in_img_type == np.uint8: rlt = rlt.round() else: - rlt /= 255. + rlt /= 255.0 return rlt.astype(in_img_type) def ycbcr2rgb(img): - '''same as matlab ycbcr2rgb + """same as matlab ycbcr2rgb Input: uint8, [0, 255] float, [0, 1] - ''' + """ in_img_type = img.dtype img.astype(np.float32) if in_img_type != np.uint8: - img *= 255. + img *= 255.0 # convert - rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], - [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] + rlt = np.matmul( + img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], [0.00625893, -0.00318811, 0]] + ) * 255.0 + [-222.921, 135.576, -276.836] if in_img_type == np.uint8: rlt = rlt.round() else: - rlt /= 255. + rlt /= 255.0 return rlt.astype(in_img_type) def bgr2ycbcr(img, only_y=True): - '''bgr version of rgb2ycbcr + """bgr version of rgb2ycbcr only_y: only return Y channel Input: uint8, [0, 255] float, [0, 1] - ''' + """ in_img_type = img.dtype img.astype(np.float32) if in_img_type != np.uint8: - img *= 255. + img *= 255.0 # convert if only_y: rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 else: - rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], - [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] + rlt = np.matmul( + img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]] + ) / 255.0 + [16, 128, 128] if in_img_type == np.uint8: rlt = rlt.round() else: - rlt /= 255. 
+ rlt /= 255.0 return rlt.astype(in_img_type) def channel_convert(in_c, tar_type, img_list): # conversion among BGR, gray and y - if in_c == 3 and tar_type == 'gray': # BGR to gray + if in_c == 3 and tar_type == "gray": # BGR to gray gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] return [np.expand_dims(img, axis=2) for img in gray_list] - elif in_c == 3 and tar_type == 'y': # BGR to y + elif in_c == 3 and tar_type == "y": # BGR to y y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] return [np.expand_dims(img, axis=2) for img in y_list] - elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR + elif in_c == 1 and tar_type == "RGB": # gray/y to BGR return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] else: return img_list -''' +""" # -------------------------------------------- # metric, PSNR and SSIM # -------------------------------------------- -''' +""" # -------------------------------------------- @@ -620,19 +619,19 @@ def channel_convert(in_c, tar_type, img_list): # -------------------------------------------- def calculate_psnr(img1, img2, border=0): # img1 and img2 have range [0, 255] - #img1 = img1.squeeze() - #img2 = img2.squeeze() + # img1 = img1.squeeze() + # img2 = img2.squeeze() if not img1.shape == img2.shape: - raise ValueError('Input images must have the same dimensions.') + raise ValueError("Input images must have the same dimensions.") h, w = img1.shape[:2] - img1 = img1[border:h-border, border:w-border] - img2 = img2[border:h-border, border:w-border] + img1 = img1[border : h - border, border : w - border] + img2 = img2[border : h - border, border : w - border] img1 = img1.astype(np.float64) img2 = img2.astype(np.float64) - mse = np.mean((img1 - img2)**2) + mse = np.mean((img1 - img2) ** 2) if mse == 0: - return float('inf') + return float("inf") return 20 * math.log10(255.0 / math.sqrt(mse)) @@ -640,17 +639,17 @@ def calculate_psnr(img1, img2, border=0): # SSIM # -------------------------------------------- def calculate_ssim(img1, img2, border=0): - '''calculate SSIM + """calculate SSIM the same outputs as MATLAB's img1, img2: [0, 255] - ''' - #img1 = img1.squeeze() - #img2 = img2.squeeze() + """ + # img1 = img1.squeeze() + # img2 = img2.squeeze() if not img1.shape == img2.shape: - raise ValueError('Input images must have the same dimensions.') + raise ValueError("Input images must have the same dimensions.") h, w = img1.shape[:2] - img1 = img1[border:h-border, border:w-border] - img2 = img2[border:h-border, border:w-border] + img1 = img1[border : h - border, border : w - border] + img2 = img2[border : h - border, border : w - border] if img1.ndim == 2: return ssim(img1, img2) @@ -658,17 +657,17 @@ def calculate_ssim(img1, img2, border=0): if img1.shape[2] == 3: ssims = [] for i in range(3): - ssims.append(ssim(img1[:,:,i], img2[:,:,i])) + ssims.append(ssim(img1[:, :, i], img2[:, :, i])) return np.array(ssims).mean() elif img1.shape[2] == 1: return ssim(np.squeeze(img1), np.squeeze(img2)) else: - raise ValueError('Wrong input image dimensions.') + raise ValueError("Wrong input image dimensions.") def ssim(img1, img2): - C1 = (0.01 * 255)**2 - C2 = (0.03 * 255)**2 + C1 = (0.01 * 255) ** 2 + C2 = (0.03 * 255) ** 2 img1 = img1.astype(np.float64) img2 = img2.astype(np.float64) @@ -684,16 +683,15 @@ def ssim(img1, img2): sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 - ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq 
+ C1) * - (sigma1_sq + sigma2_sq + C2)) + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) return ssim_map.mean() -''' +""" # -------------------------------------------- # matlab's bicubic imresize (numpy and torch) [0, 1] # -------------------------------------------- -''' +""" # matlab 'imresize' function, now only support 'bicubic' @@ -701,8 +699,9 @@ def cubic(x): absx = torch.abs(x) absx2 = absx**2 absx3 = absx**3 - return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \ - (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx)) + return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + ( + -0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2 + ) * (((absx > 1) * (absx <= 2)).type_as(absx)) def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): @@ -729,8 +728,9 @@ def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width # The indices of the input pixels involved in computing the k-th output # pixel are in row k of the indices matrix. - indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( - 1, P).expand(out_length, P) + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(1, P).expand( + out_length, P + ) # The weights used to compute the k-th output pixel are in row k of the # weights matrix. @@ -773,7 +773,7 @@ def imresize(img, scale, antialiasing=True): in_C, in_H, in_W = img.size() out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) kernel_width = 4 - kernel = 'cubic' + kernel = "cubic" # Return the desired dimension order for performing the resize. The # strategy is to perform the resize first along the dimension with the @@ -782,9 +782,11 @@ def imresize(img, scale, antialiasing=True): # get weights and indices weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( - in_H, out_H, scale, kernel, kernel_width, antialiasing) + in_H, out_H, scale, kernel, kernel_width, antialiasing + ) weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( - in_W, out_W, scale, kernel, kernel_width, antialiasing) + in_W, out_W, scale, kernel, kernel_width, antialiasing + ) # process H dimension # symmetric copying img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) @@ -805,7 +807,7 @@ def imresize(img, scale, antialiasing=True): for i in range(out_H): idx = int(indices_H[i][0]) for j in range(out_C): - out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + out_1[j, i, :] = img_aug[j, idx : idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) # process W dimension # symmetric copying @@ -827,7 +829,7 @@ def imresize(img, scale, antialiasing=True): for i in range(out_W): idx = int(indices_W[i][0]) for j in range(out_C): - out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i]) + out_2[j, :, i] = out_1_aug[j, :, idx : idx + kernel_width].mv(weights_W[i]) if need_squeeze: out_2.squeeze_() return out_2 @@ -848,7 +850,7 @@ def imresize_np(img, scale, antialiasing=True): in_H, in_W, in_C = img.size() out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) kernel_width = 4 - kernel = 'cubic' + kernel = "cubic" # Return the desired dimension order for performing the resize. 
The # strategy is to perform the resize first along the dimension with the @@ -857,9 +859,11 @@ def imresize_np(img, scale, antialiasing=True): # get weights and indices weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( - in_H, out_H, scale, kernel, kernel_width, antialiasing) + in_H, out_H, scale, kernel, kernel_width, antialiasing + ) weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( - in_W, out_W, scale, kernel, kernel_width, antialiasing) + in_W, out_W, scale, kernel, kernel_width, antialiasing + ) # process H dimension # symmetric copying img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) @@ -880,7 +884,7 @@ def imresize_np(img, scale, antialiasing=True): for i in range(out_H): idx = int(indices_H[i][0]) for j in range(out_C): - out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) + out_1[i, :, j] = img_aug[idx : idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) # process W dimension # symmetric copying @@ -902,15 +906,15 @@ def imresize_np(img, scale, antialiasing=True): for i in range(out_W): idx = int(indices_W[i][0]) for j in range(out_C): - out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i]) + out_2[:, i, j] = out_1_aug[:, idx : idx + kernel_width, j].mv(weights_W[i]) if need_squeeze: out_2.squeeze_() return out_2.numpy() -if __name__ == '__main__': - print('---') +if __name__ == "__main__": + print("---") # img = imread_uint('test.bmp', 3) # img = uint2single(img) -# img_bicubic = imresize_np(img, 1/4) \ No newline at end of file +# img_bicubic = imresize_np(img, 1/4) diff --git a/examples/images/diffusion/ldm/modules/midas/api.py b/examples/images/diffusion/ldm/modules/midas/api.py index b58ebbffd942..6619f515fa0e 100644 --- a/examples/images/diffusion/ldm/modules/midas/api.py +++ b/examples/images/diffusion/ldm/modules/midas/api.py @@ -3,13 +3,11 @@ import cv2 import torch import torch.nn as nn -from torchvision.transforms import Compose - from ldm.modules.midas.midas.dpt_depth import DPTDepthModel from ldm.modules.midas.midas.midas_net import MidasNet from ldm.modules.midas.midas.midas_net_custom import MidasNet_small -from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet - +from ldm.modules.midas.midas.transforms import NormalizeImage, PrepareForNet, Resize +from torchvision.transforms import Compose ISL_PATHS = { "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt", @@ -98,18 +96,20 @@ def load_model(model_type): model = MidasNet(model_path, non_negative=True) net_w, net_h = 384, 384 resize_mode = "upper_bound" - normalization = NormalizeImage( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] - ) + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) elif model_type == "midas_v21_small": - model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, - non_negative=True, blocks={'expand': True}) + model = MidasNet_small( + model_path, + features=64, + backbone="efficientnet_lite3", + exportable=True, + non_negative=True, + blocks={"expand": True}, + ) net_w, net_h = 256, 256 resize_mode = "upper_bound" - normalization = NormalizeImage( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] - ) + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) else: print(f"model_type '{model_type}' not implemented, use: --model_type large") @@ -135,11 +135,7 @@ def load_model(model_type): class 
MiDaSInference(nn.Module): - MODEL_TYPES_TORCH_HUB = [ - "DPT_Large", - "DPT_Hybrid", - "MiDaS_small" - ] + MODEL_TYPES_TORCH_HUB = ["DPT_Large", "DPT_Hybrid", "MiDaS_small"] MODEL_TYPES_ISL = [ "dpt_large", "dpt_hybrid", @@ -149,7 +145,7 @@ class MiDaSInference(nn.Module): def __init__(self, model_type): super().__init__() - assert (model_type in self.MODEL_TYPES_ISL) + assert model_type in self.MODEL_TYPES_ISL model, _ = load_model(model_type) self.model = model self.model.train = disabled_train @@ -167,4 +163,3 @@ def forward(self, x): ) assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3]) return prediction - diff --git a/examples/images/diffusion/ldm/modules/midas/midas/base_model.py b/examples/images/diffusion/ldm/modules/midas/midas/base_model.py index 5cf430239b47..5c2e0e93b049 100644 --- a/examples/images/diffusion/ldm/modules/midas/midas/base_model.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/base_model.py @@ -8,7 +8,7 @@ def load(self, path): Args: path (str): file path """ - parameters = torch.load(path, map_location=torch.device('cpu')) + parameters = torch.load(path, map_location=torch.device("cpu")) if "optimizer" in parameters: parameters = parameters["model"] diff --git a/examples/images/diffusion/ldm/modules/midas/midas/blocks.py b/examples/images/diffusion/ldm/modules/midas/midas/blocks.py index 2145d18fa980..154de57cd2e8 100644 --- a/examples/images/diffusion/ldm/modules/midas/midas/blocks.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/blocks.py @@ -1,18 +1,22 @@ import torch import torch.nn as nn -from .vit import ( - _make_pretrained_vitb_rn50_384, - _make_pretrained_vitl16_384, - _make_pretrained_vitb16_384, - forward_vit, -) - -def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): +from .vit import _make_pretrained_vitb16_384, _make_pretrained_vitb_rn50_384, _make_pretrained_vitl16_384 + + +def _make_encoder( + backbone, + features, + use_pretrained, + groups=1, + expand=False, + exportable=True, + hooks=None, + use_vit_only=False, + use_readout="ignore", +): if backbone == "vitl16_384": - pretrained = _make_pretrained_vitl16_384( - use_pretrained, hooks=hooks, use_readout=use_readout - ) + pretrained = _make_pretrained_vitl16_384(use_pretrained, hooks=hooks, use_readout=use_readout) scratch = _make_scratch( [256, 512, 1024, 1024], features, groups=groups, expand=expand ) # ViT-L/16 - 85.0% Top1 (backbone) @@ -27,22 +31,20 @@ def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, ex [256, 512, 768, 768], features, groups=groups, expand=expand ) # ViT-H/16 - 85.0% Top1 (backbone) elif backbone == "vitb16_384": - pretrained = _make_pretrained_vitb16_384( - use_pretrained, hooks=hooks, use_readout=use_readout - ) + pretrained = _make_pretrained_vitb16_384(use_pretrained, hooks=hooks, use_readout=use_readout) scratch = _make_scratch( [96, 192, 384, 768], features, groups=groups, expand=expand ) # ViT-B/16 - 84.6% Top1 (backbone) elif backbone == "resnext101_wsl": pretrained = _make_pretrained_resnext101_wsl(use_pretrained) - scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 elif backbone == "efficientnet_lite3": pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) - scratch = _make_scratch([32, 48, 136, 384], 
features, groups=groups, expand=expand) # efficientnet_lite3 + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 else: print(f"Backbone '{backbone}' not implemented") assert False - + return pretrained, scratch @@ -53,11 +55,11 @@ def _make_scratch(in_shape, out_shape, groups=1, expand=False): out_shape2 = out_shape out_shape3 = out_shape out_shape4 = out_shape - if expand==True: + if expand == True: out_shape1 = out_shape - out_shape2 = out_shape*2 - out_shape3 = out_shape*4 - out_shape4 = out_shape*8 + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + out_shape4 = out_shape * 8 scratch.layer1_rn = nn.Conv2d( in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups @@ -77,10 +79,7 @@ def _make_scratch(in_shape, out_shape, groups=1, expand=False): def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): efficientnet = torch.hub.load( - "rwightman/gen-efficientnet-pytorch", - "tf_efficientnet_lite3", - pretrained=use_pretrained, - exportable=exportable + "rwightman/gen-efficientnet-pytorch", "tf_efficientnet_lite3", pretrained=use_pretrained, exportable=exportable ) return _make_efficientnet_backbone(efficientnet) @@ -88,21 +87,17 @@ def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): def _make_efficientnet_backbone(effnet): pretrained = nn.Module() - pretrained.layer1 = nn.Sequential( - effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] - ) + pretrained.layer1 = nn.Sequential(effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]) pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) return pretrained - + def _make_resnet_backbone(resnet): pretrained = nn.Module() - pretrained.layer1 = nn.Sequential( - resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 - ) + pretrained.layer1 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1) pretrained.layer2 = resnet.layer2 pretrained.layer3 = resnet.layer3 @@ -116,10 +111,8 @@ def _make_pretrained_resnext101_wsl(use_pretrained): return _make_resnet_backbone(resnet) - class Interpolate(nn.Module): - """Interpolation module. - """ + """Interpolation module.""" def __init__(self, scale_factor, mode, align_corners=False): """Init. @@ -145,16 +138,13 @@ def forward(self, x): tensor: interpolated data """ - x = self.interp( - x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners - ) + x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners) return x class ResidualConvUnit(nn.Module): - """Residual convolution module. - """ + """Residual convolution module.""" def __init__(self, features): """Init. @@ -164,13 +154,9 @@ def __init__(self, features): """ super().__init__() - self.conv1 = nn.Conv2d( - features, features, kernel_size=3, stride=1, padding=1, bias=True - ) + self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True) - self.conv2 = nn.Conv2d( - features, features, kernel_size=3, stride=1, padding=1, bias=True - ) + self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True) self.relu = nn.ReLU(inplace=True) @@ -192,8 +178,7 @@ def forward(self, x): class FeatureFusionBlock(nn.Module): - """Feature fusion block. - """ + """Feature fusion block.""" def __init__(self, features): """Init. 
@@ -219,18 +204,13 @@ def forward(self, *xs): output = self.resConfUnit2(output) - output = nn.functional.interpolate( - output, scale_factor=2, mode="bilinear", align_corners=True - ) + output = nn.functional.interpolate(output, scale_factor=2, mode="bilinear", align_corners=True) return output - - class ResidualConvUnit_custom(nn.Module): - """Residual convolution module. - """ + """Residual convolution module.""" def __init__(self, features, activation, bn): """Init. @@ -242,17 +222,13 @@ def __init__(self, features, activation, bn): self.bn = bn - self.groups=1 + self.groups = 1 - self.conv1 = nn.Conv2d( - features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups - ) - - self.conv2 = nn.Conv2d( - features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups - ) + self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + + self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) - if self.bn==True: + if self.bn == True: self.bn1 = nn.BatchNorm2d(features) self.bn2 = nn.BatchNorm2d(features) @@ -269,15 +245,15 @@ def forward(self, x): Returns: tensor: output """ - + out = self.activation(x) out = self.conv1(out) - if self.bn==True: + if self.bn == True: out = self.bn1(out) - + out = self.activation(out) out = self.conv2(out) - if self.bn==True: + if self.bn == True: out = self.bn2(out) if self.groups > 1: @@ -289,8 +265,7 @@ def forward(self, x): class FeatureFusionBlock_custom(nn.Module): - """Feature fusion block. - """ + """Feature fusion block.""" def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): """Init. @@ -303,18 +278,18 @@ def __init__(self, features, activation, deconv=False, bn=False, expand=False, a self.deconv = deconv self.align_corners = align_corners - self.groups=1 + self.groups = 1 self.expand = expand out_features = features - if self.expand==True: - out_features = features//2 - + if self.expand == True: + out_features = features // 2 + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) - + self.skip_add = nn.quantized.FloatFunctional() def forward(self, *xs): @@ -332,11 +307,8 @@ def forward(self, *xs): output = self.resConfUnit2(output) - output = nn.functional.interpolate( - output, scale_factor=2, mode="bilinear", align_corners=self.align_corners - ) + output = nn.functional.interpolate(output, scale_factor=2, mode="bilinear", align_corners=self.align_corners) output = self.out_conv(output) return output - diff --git a/examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py b/examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py index 4e9aab5d2767..74871e8b1fce 100644 --- a/examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py @@ -1,15 +1,8 @@ import torch import torch.nn as nn -import torch.nn.functional as F from .base_model import BaseModel -from .blocks import ( - FeatureFusionBlock, - FeatureFusionBlock_custom, - Interpolate, - _make_encoder, - forward_vit, -) +from .blocks import FeatureFusionBlock_custom, Interpolate, _make_encoder, forward_vit def _make_fusion_block(features, use_bn): @@ -33,7 +26,6 @@ def __init__( channels_last=False, use_bn=False, ): - super(DPT, 
self).__init__() self.channels_last = channels_last @@ -48,7 +40,7 @@ def __init__( self.pretrained, self.scratch = _make_encoder( backbone, features, - False, # Set to true of you want to train from scratch, uses ImageNet weights + False, # Set to true of you want to train from scratch, uses ImageNet weights groups=1, expand=False, exportable=False, @@ -63,7 +55,6 @@ def __init__( self.scratch.output_conv = head - def forward(self, x): if self.channels_last == True: x.contiguous(memory_format=torch.channels_last) @@ -102,8 +93,7 @@ def __init__(self, path=None, non_negative=True, **kwargs): super().__init__(head, **kwargs) if path is not None: - self.load(path) + self.load(path) def forward(self, x): return super().forward(x).squeeze(dim=1) - diff --git a/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py b/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py index 8a954977800b..0dd87b59619c 100644 --- a/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py @@ -10,8 +10,7 @@ class MidasNet(BaseModel): - """Network for monocular depth estimation. - """ + """Network for monocular depth estimation.""" def __init__(self, path=None, features=256, non_negative=True): """Init. @@ -27,7 +26,9 @@ def __init__(self, path=None, features=256, non_negative=True): use_pretrained = False if path is None else True - self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + self.pretrained, self.scratch = _make_encoder( + backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained + ) self.scratch.refinenet4 = FeatureFusionBlock(features) self.scratch.refinenet3 = FeatureFusionBlock(features) diff --git a/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py b/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py index 50e4acb5e53d..4d30744c46d3 100644 --- a/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py @@ -6,15 +6,23 @@ import torch.nn as nn from .base_model import BaseModel -from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder +from .blocks import FeatureFusionBlock_custom, Interpolate, _make_encoder class MidasNet_small(BaseModel): - """Network for monocular depth estimation. - """ - - def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, - blocks={'expand': True}): + """Network for monocular depth estimation.""" + + def __init__( + self, + path=None, + features=64, + backbone="efficientnet_lite3", + non_negative=True, + exportable=True, + channels_last=False, + align_corners=True, + blocks={"expand": True}, + ): """Init. 
Args: @@ -27,49 +35,57 @@ def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_ne super(MidasNet_small, self).__init__() use_pretrained = False if path else True - + self.channels_last = channels_last self.blocks = blocks self.backbone = backbone self.groups = 1 - features1=features - features2=features - features3=features - features4=features + features1 = features + features2 = features + features3 = features + features4 = features self.expand = False - if "expand" in self.blocks and self.blocks['expand'] == True: + if "expand" in self.blocks and self.blocks["expand"] == True: self.expand = True - features1=features - features2=features*2 - features3=features*4 - features4=features*8 + features1 = features + features2 = features * 2 + features3 = features * 4 + features4 = features * 8 + + self.pretrained, self.scratch = _make_encoder( + self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable + ) - self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) - - self.scratch.activation = nn.ReLU(False) + self.scratch.activation = nn.ReLU(False) - self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) - self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) - self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) - self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + self.scratch.refinenet4 = FeatureFusionBlock_custom( + features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners + ) + self.scratch.refinenet3 = FeatureFusionBlock_custom( + features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners + ) + self.scratch.refinenet2 = FeatureFusionBlock_custom( + features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners + ) + self.scratch.refinenet1 = FeatureFusionBlock_custom( + features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners + ) - self.scratch.output_conv = nn.Sequential( - nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1, groups=self.groups), Interpolate(scale_factor=2, mode="bilinear"), - nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), self.scratch.activation, nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), nn.ReLU(True) if non_negative else nn.Identity(), nn.Identity(), ) - + if path: self.load(path) - def forward(self, x): """Forward pass. 
@@ -79,38 +95,35 @@ def forward(self, x): Returns: tensor: depth """ - if self.channels_last==True: + if self.channels_last == True: print("self.channels_last = ", self.channels_last) x.contiguous(memory_format=torch.channels_last) - layer_1 = self.pretrained.layer1(x) layer_2 = self.pretrained.layer2(layer_1) layer_3 = self.pretrained.layer3(layer_2) layer_4 = self.pretrained.layer4(layer_3) - + layer_1_rn = self.scratch.layer1_rn(layer_1) layer_2_rn = self.scratch.layer2_rn(layer_2) layer_3_rn = self.scratch.layer3_rn(layer_3) layer_4_rn = self.scratch.layer4_rn(layer_4) - path_4 = self.scratch.refinenet4(layer_4_rn) path_3 = self.scratch.refinenet3(path_4, layer_3_rn) path_2 = self.scratch.refinenet2(path_3, layer_2_rn) path_1 = self.scratch.refinenet1(path_2, layer_1_rn) - + out = self.scratch.output_conv(path_1) return torch.squeeze(out, dim=1) - def fuse_model(m): prev_previous_type = nn.Identity() - prev_previous_name = '' + prev_previous_name = "" previous_type = nn.Identity() - previous_name = '' + previous_name = "" for name, module in m.named_modules(): if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: # print("FUSED ", prev_previous_name, previous_name, name) @@ -125,4 +138,4 @@ def fuse_model(m): prev_previous_type = previous_type prev_previous_name = previous_name previous_type = type(module) - previous_name = name \ No newline at end of file + previous_name = name diff --git a/examples/images/diffusion/ldm/modules/midas/midas/transforms.py b/examples/images/diffusion/ldm/modules/midas/midas/transforms.py index 350cbc116626..aede0fa0c73f 100644 --- a/examples/images/diffusion/ldm/modules/midas/midas/transforms.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/transforms.py @@ -1,7 +1,8 @@ -import numpy as np -import cv2 import math +import cv2 +import numpy as np + def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): """Rezise the sample to ensure the given size. Keeps aspect ratio. @@ -28,13 +29,9 @@ def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): shape[1] = math.ceil(scale * shape[1]) # resize - sample["image"] = cv2.resize( - sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method - ) + sample["image"] = cv2.resize(sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method) - sample["disparity"] = cv2.resize( - sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST - ) + sample["disparity"] = cv2.resize(sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST) sample["mask"] = cv2.resize( sample["mask"].astype(np.float32), tuple(shape[::-1]), @@ -46,8 +43,7 @@ def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): class Resize(object): - """Resize sample to given size (width, height). 
- """ + """Resize sample to given size (width, height).""" def __init__( self, @@ -133,24 +129,14 @@ def get_size(self, width, height): # fit height scale_width = scale_height else: - raise ValueError( - f"resize_method {self.__resize_method} not implemented" - ) + raise ValueError(f"resize_method {self.__resize_method} not implemented") if self.__resize_method == "lower_bound": - new_height = self.constrain_to_multiple_of( - scale_height * height, min_val=self.__height - ) - new_width = self.constrain_to_multiple_of( - scale_width * width, min_val=self.__width - ) + new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height) + new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width) elif self.__resize_method == "upper_bound": - new_height = self.constrain_to_multiple_of( - scale_height * height, max_val=self.__height - ) - new_width = self.constrain_to_multiple_of( - scale_width * width, max_val=self.__width - ) + new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height) + new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width) elif self.__resize_method == "minimal": new_height = self.constrain_to_multiple_of(scale_height * height) new_width = self.constrain_to_multiple_of(scale_width * width) @@ -160,9 +146,7 @@ def get_size(self, width, height): return (new_width, new_height) def __call__(self, sample): - width, height = self.get_size( - sample["image"].shape[1], sample["image"].shape[0] - ) + width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0]) # resize sample sample["image"] = cv2.resize( @@ -180,9 +164,7 @@ def __call__(self, sample): ) if "depth" in sample: - sample["depth"] = cv2.resize( - sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST - ) + sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST) sample["mask"] = cv2.resize( sample["mask"].astype(np.float32), @@ -195,8 +177,7 @@ def __call__(self, sample): class NormalizeImage(object): - """Normlize image by given mean and std. - """ + """Normlize image by given mean and std.""" def __init__(self, mean, std): self.__mean = mean @@ -209,8 +190,7 @@ def __call__(self, sample): class PrepareForNet(object): - """Prepare sample for usage as network input. 
- """ + """Prepare sample for usage as network input.""" def __init__(self): pass diff --git a/examples/images/diffusion/ldm/modules/midas/midas/vit.py b/examples/images/diffusion/ldm/modules/midas/midas/vit.py index ea46b1be88b2..41bdb566fd4f 100644 --- a/examples/images/diffusion/ldm/modules/midas/midas/vit.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/vit.py @@ -1,8 +1,9 @@ +import math +import types + +import timm import torch import torch.nn as nn -import timm -import types -import math import torch.nn.functional as F @@ -56,7 +57,7 @@ def forward(self, x): def forward_vit(pretrained, x): b, c, h, w = x.shape - glob = pretrained.model.forward_flex(x) + pretrained.model.forward_flex(x) layer_1 = pretrained.activations["1"] layer_2 = pretrained.activations["2"] @@ -117,9 +118,7 @@ def _resize_pos_embed(self, posemb, gs_h, gs_w): def forward_flex(self, x): b, c, h, w = x.shape - pos_embed = self._resize_pos_embed( - self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] - ) + pos_embed = self._resize_pos_embed(self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]) B = x.shape[0] @@ -131,15 +130,11 @@ def forward_flex(self, x): x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) if getattr(self, "dist_token", None) is not None: - cls_tokens = self.cls_token.expand( - B, -1, -1 - ) # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks dist_token = self.dist_token.expand(B, -1, -1) x = torch.cat((cls_tokens, dist_token, x), dim=1) else: - cls_tokens = self.cls_token.expand( - B, -1, -1 - ) # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks x = torch.cat((cls_tokens, x), dim=1) x = x + pos_embed @@ -169,13 +164,9 @@ def get_readout_oper(vit_features, features, use_readout, start_index=1): elif use_readout == "add": readout_oper = [AddReadout(start_index)] * len(features) elif use_readout == "project": - readout_oper = [ - ProjectReadout(vit_features, start_index) for out_feat in features - ] + readout_oper = [ProjectReadout(vit_features, start_index) for out_feat in features] else: - assert ( - False - ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + assert False, "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" return readout_oper @@ -287,9 +278,7 @@ def _make_vit_b16_backbone( # We inject this function into the VisionTransformer instances so that # we can use it with interpolated position embeddings without modifying the library source. 
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) - pretrained.model._resize_pos_embed = types.MethodType( - _resize_pos_embed, pretrained.model - ) + pretrained.model._resize_pos_embed = types.MethodType(_resize_pos_embed, pretrained.model) return pretrained @@ -311,24 +300,18 @@ def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) hooks = [2, 5, 8, 11] if hooks == None else hooks - return _make_vit_b16_backbone( - model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout - ) + return _make_vit_b16_backbone(model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout) def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) hooks = [2, 5, 8, 11] if hooks == None else hooks - return _make_vit_b16_backbone( - model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout - ) + return _make_vit_b16_backbone(model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout) def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): - model = timm.create_model( - "vit_deit_base_distilled_patch16_384", pretrained=pretrained - ) + model = timm.create_model("vit_deit_base_distilled_patch16_384", pretrained=pretrained) hooks = [2, 5, 8, 11] if hooks == None else hooks return _make_vit_b16_backbone( @@ -358,12 +341,8 @@ def _make_vit_b_rn50_backbone( pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) else: - pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( - get_activation("1") - ) - pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( - get_activation("2") - ) + pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(get_activation("1")) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(get_activation("2")) pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) @@ -419,12 +398,8 @@ def _make_vit_b_rn50_backbone( ), ) else: - pretrained.act_postprocess1 = nn.Sequential( - nn.Identity(), nn.Identity(), nn.Identity() - ) - pretrained.act_postprocess2 = nn.Sequential( - nn.Identity(), nn.Identity(), nn.Identity() - ) + pretrained.act_postprocess1 = nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity()) + pretrained.act_postprocess2 = nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity()) pretrained.act_postprocess3 = nn.Sequential( readout_oper[2], @@ -468,16 +443,12 @@ def _make_vit_b_rn50_backbone( # We inject this function into the VisionTransformer instances so that # we can use it with interpolated position embeddings without modifying the library source. 
- pretrained.model._resize_pos_embed = types.MethodType( - _resize_pos_embed, pretrained.model - ) + pretrained.model._resize_pos_embed = types.MethodType(_resize_pos_embed, pretrained.model) return pretrained -def _make_pretrained_vitb_rn50_384( - pretrained, use_readout="ignore", hooks=None, use_vit_only=False -): +def _make_pretrained_vitb_rn50_384(pretrained, use_readout="ignore", hooks=None, use_vit_only=False): model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) hooks = [0, 1, 8, 11] if hooks == None else hooks diff --git a/examples/images/diffusion/ldm/modules/midas/utils.py b/examples/images/diffusion/ldm/modules/midas/utils.py index 9a9d3b5b6637..1428d42b2445 100644 --- a/examples/images/diffusion/ldm/modules/midas/utils.py +++ b/examples/images/diffusion/ldm/modules/midas/utils.py @@ -1,8 +1,9 @@ """Utils for monoDepth.""" -import sys import re -import numpy as np +import sys + import cv2 +import numpy as np import torch @@ -16,7 +17,6 @@ def read_pfm(path): tuple: (data, scale) """ with open(path, "rb") as file: - color = None width = None height = None @@ -74,9 +74,7 @@ def write_pfm(path, image, scale=1): if len(image.shape) == 3 and image.shape[2] == 3: # color image color = True - elif ( - len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 - ): # greyscale + elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale color = False else: raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") @@ -135,9 +133,7 @@ def resize_image(img): img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) - img_resized = ( - torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() - ) + img_resized = torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() img_resized = img_resized.unsqueeze(0) return img_resized @@ -156,12 +152,11 @@ def resize_depth(depth, width, height): """ depth = torch.squeeze(depth[0, :, :, :]).to("cpu") - depth_resized = cv2.resize( - depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC - ) + depth_resized = cv2.resize(depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC) return depth_resized + def write_depth(path, depth, bits=1): """Write depth map to pfm and png file. 
@@ -174,7 +169,7 @@ def write_depth(path, depth, bits=1): depth_min = depth.min() depth_max = depth.max() - max_val = (2**(8*bits))-1 + max_val = (2 ** (8 * bits)) - 1 if depth_max - depth_min > np.finfo("float").eps: out = max_val * (depth - depth_min) / (depth_max - depth_min) diff --git a/examples/images/diffusion/ldm/util.py b/examples/images/diffusion/ldm/util.py index 8c09ca1c72f7..9b52b199aa2c 100644 --- a/examples/images/diffusion/ldm/util.py +++ b/examples/images/diffusion/ldm/util.py @@ -1,11 +1,10 @@ import importlib +from inspect import isfunction -import torch -from torch import optim import numpy as np - -from inspect import isfunction +import torch from PIL import Image, ImageDraw, ImageFont +from torch import optim def log_txt_as_img(wh, xc, size=10): @@ -16,9 +15,9 @@ def log_txt_as_img(wh, xc, size=10): for bi in range(b): txt = Image.new("RGB", wh, color="white") draw = ImageDraw.Draw(txt) - font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) + font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) nc = int(40 * (wh[0] / 256)) - lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) + lines = "\n".join(xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)) try: draw.text((0, 0), lines, fill="black", font=font) @@ -39,7 +38,7 @@ def ismap(x): def isimage(x): - if not isinstance(x,torch.Tensor): + if not isinstance(x, torch.Tensor): return False return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) @@ -71,7 +70,7 @@ def count_params(model, verbose=False): def instantiate_from_config(config): if not "target" in config: - if config == '__is_first_stage__': + if config == "__is_first_stage__": return None elif config == "__is_unconditional__": return None @@ -89,9 +88,18 @@ def get_obj_from_str(string, reload=False): class AdamWwithEMAandWings(optim.Optimizer): # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 - def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using - weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code - ema_power=1., param_names=()): + def __init__( + self, + params, + lr=1.0e-3, + betas=(0.9, 0.999), + eps=1.0e-8, # TODO: check hyperparameters before using + weight_decay=1.0e-2, + amsgrad=False, + ema_decay=0.9999, # ema decay to match previous code + ema_power=1.0, + param_names=(), + ): """AdamW that saves EMA versions of the parameters.""" if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) @@ -105,15 +113,22 @@ def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: che raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) if not 0.0 <= ema_decay <= 1.0: raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) - defaults = dict(lr=lr, betas=betas, eps=eps, - weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, - ema_power=ema_power, param_names=param_names) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, + ema_decay=ema_decay, + ema_power=ema_power, + param_names=param_names, + ) super().__init__(params, defaults) def __setstate__(self, state): super().__setstate__(state) for group in self.param_groups: - group.setdefault('amsgrad', False) + group.setdefault("amsgrad", False) @torch.no_grad() def step(self, closure=None): @@ -133,65 +148,66 @@ def step(self, closure=None): exp_avgs = [] exp_avg_sqs = [] ema_params_with_grad = [] - 
state_sums = [] max_exp_avg_sqs = [] state_steps = [] - amsgrad = group['amsgrad'] - beta1, beta2 = group['betas'] - ema_decay = group['ema_decay'] - ema_power = group['ema_power'] + amsgrad = group["amsgrad"] + beta1, beta2 = group["betas"] + ema_decay = group["ema_decay"] + ema_power = group["ema_power"] - for p in group['params']: + for p in group["params"]: if p.grad is None: continue params_with_grad.append(p) if p.grad.is_sparse: - raise RuntimeError('AdamW does not support sparse gradients') + raise RuntimeError("AdamW does not support sparse gradients") grads.append(p.grad) state = self.state[p] # State initialization if len(state) == 0: - state['step'] = 0 + state["step"] = 0 # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format) # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) if amsgrad: # Maintains max of all exp. moving avg. of sq. grad. values - state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + state["max_exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) # Exponential moving average of parameter values - state['param_exp_avg'] = p.detach().float().clone() + state["param_exp_avg"] = p.detach().float().clone() - exp_avgs.append(state['exp_avg']) - exp_avg_sqs.append(state['exp_avg_sq']) - ema_params_with_grad.append(state['param_exp_avg']) + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + ema_params_with_grad.append(state["param_exp_avg"]) if amsgrad: - max_exp_avg_sqs.append(state['max_exp_avg_sq']) + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) # update the steps for each param group update - state['step'] += 1 + state["step"] += 1 # record the step after step update - state_steps.append(state['step']) - - optim._functional.adamw(params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=amsgrad, - beta1=beta1, - beta2=beta2, - lr=group['lr'], - weight_decay=group['weight_decay'], - eps=group['eps'], - maximize=False) - - cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) + state_steps.append(state["step"]) + + optim._functional.adamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=False, + ) + + cur_ema_decay = min(ema_decay, 1 - state["step"] ** -ema_power) for param, ema_param in zip(params_with_grad, ema_params_with_grad): ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) - return loss \ No newline at end of file + return loss diff --git a/examples/images/diffusion/main.py b/examples/images/diffusion/main.py index 713029fc677d..6d44df667fce 100644 --- a/examples/images/diffusion/main.py +++ b/examples/images/diffusion/main.py @@ -1,33 +1,28 @@ import argparse -import csv import datetime import glob -import importlib import os import sys import time +from functools import partial +import lightning.pytorch as pl import numpy as np import torch import torchvision -import lightning.pytorch as pl - - -from functools import partial - -from omegaconf import OmegaConf -from packaging import version -from PIL import Image -from 
prefetch_generator import BackgroundGenerator -from torch.utils.data import DataLoader, Dataset, Subset, random_split from ldm.models.diffusion.ddpm import LatentDiffusion - from lightning.pytorch import seed_everything from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint +from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger +from lightning.pytorch.strategies import ColossalAIStrategy, DDPStrategy from lightning.pytorch.trainer import Trainer from lightning.pytorch.utilities import rank_zero_info, rank_zero_only -from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger -from lightning.pytorch.strategies import ColossalAIStrategy,DDPStrategy +from omegaconf import OmegaConf +from packaging import version +from PIL import Image +from prefetch_generator import BackgroundGenerator +from torch.utils.data import DataLoader, Dataset + LIGHTNING_PACK_NAME = "lightning.pytorch." from ldm.data.base import Txt2ImgIterableBaseDataset @@ -37,15 +32,15 @@ class DataLoaderX(DataLoader): -# A custom data loader class that inherits from DataLoader + # A custom data loader class that inherits from DataLoader def __iter__(self): # Overriding the __iter__ method of DataLoader to return a BackgroundGenerator - #This is to enable data loading in the background to improve training performance + # This is to enable data loading in the background to improve training performance return BackgroundGenerator(super().__iter__()) def get_parser(**parser_kwargs): - #A function to create an ArgumentParser object and add arguments to it + # A function to create an ArgumentParser object and add arguments to it def str2bool(v): # A helper function to parse boolean values from command line arguments @@ -57,6 +52,7 @@ def str2bool(v): return False else: raise argparse.ArgumentTypeError("Boolean value expected.") + # Create an ArgumentParser object with specifies kwargs parser = argparse.ArgumentParser(**parser_kwargs) @@ -160,6 +156,7 @@ def str2bool(v): return parser + # A function that returns the non-default arguments between two objects def nondefault_trainer_args(opt): # create an argument parser @@ -171,6 +168,7 @@ def nondefault_trainer_args(opt): # return all non-default arguments return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k)) + # A dataset wrapper class to create a pytorch dataset from an arbitrary object class WrappedDataset(Dataset): """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset""" @@ -184,6 +182,7 @@ def __len__(self): def __getitem__(self, idx): return self.data[idx] + # A function to initialize worker processes def worker_init_fn(_): worker_info = torch.utils.data.get_worker_info() @@ -192,31 +191,33 @@ def worker_init_fn(_): worker_id = worker_info.id if isinstance(dataset, Txt2ImgIterableBaseDataset): - #divide the dataset into equal parts for each worker + # divide the dataset into equal parts for each worker split_size = dataset.num_records // worker_info.num_workers - #set the sample IDs for the current worker + # set the sample IDs for the current worker # reset num_records to the true number to retain reliable length information - dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size] + dataset.sample_ids = dataset.valid_ids[worker_id * split_size : (worker_id + 1) * split_size] # set the seed for the current worker current_id = np.random.choice(len(np.random.get_state()[1]), 1) return np.random.seed(np.random.get_state()[1][current_id] + 
worker_id) else: return np.random.seed(np.random.get_state()[1][0] + worker_id) -#Provide functionality for creating data loaders based on provided dataset configurations -class DataModuleFromConfig(pl.LightningDataModule): - def __init__(self, - batch_size, - train=None, - validation=None, - test=None, - predict=None, - wrap=False, - num_workers=None, - shuffle_test_loader=False, - use_worker_init_fn=False, - shuffle_val_dataloader=False): +# Provide functionality for creating data loaders based on provided dataset configurations +class DataModuleFromConfig(pl.LightningDataModule): + def __init__( + self, + batch_size, + train=None, + validation=None, + test=None, + predict=None, + wrap=False, + num_workers=None, + shuffle_test_loader=False, + use_worker_init_fn=False, + shuffle_val_dataloader=False, + ): super().__init__() # Set data module attributes self.batch_size = batch_size @@ -246,43 +247,47 @@ def prepare_data(self): def setup(self, stage=None): # Instantiate datasets from the dataset configs self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs) - + # If wrap is true, create a WrappedDataset for each dataset if self.wrap: for k in self.datasets: self.datasets[k] = WrappedDataset(self.datasets[k]) def _train_dataloader(self): - #Check if the train dataset is iterable - is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) - #Set the worker initialization function of the dataset is iterable or use_worker_init_fn is True + # Check if the train dataset is iterable + is_iterable_dataset = isinstance(self.datasets["train"], Txt2ImgIterableBaseDataset) + # Set the worker initialization function of the dataset is iterable or use_worker_init_fn is True if is_iterable_dataset or self.use_worker_init_fn: init_fn = worker_init_fn else: init_fn = None # Return a DataLoaderX object for the train dataset - return DataLoaderX(self.datasets["train"], - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False if is_iterable_dataset else True, - worker_init_fn=init_fn) + return DataLoaderX( + self.datasets["train"], + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False if is_iterable_dataset else True, + worker_init_fn=init_fn, + ) def _val_dataloader(self, shuffle=False): - #Check if the validation dataset is iterable - if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: + # Check if the validation dataset is iterable + if isinstance(self.datasets["validation"], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: init_fn = worker_init_fn else: init_fn = None # Return a DataLoaderX object for the validation dataset - return DataLoaderX(self.datasets["validation"], - batch_size=self.batch_size, - num_workers=self.num_workers, - worker_init_fn=init_fn, - shuffle=shuffle) + return DataLoaderX( + self.datasets["validation"], + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn, + shuffle=shuffle, + ) def _test_dataloader(self, shuffle=False): # Check if the test dataset is iterable - is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) + is_iterable_dataset = isinstance(self.datasets["train"], Txt2ImgIterableBaseDataset) # Set the worker initialization function if the dataset is iterable or use_worker_init_fn is True if is_iterable_dataset or self.use_worker_init_fn: init_fn = worker_init_fn @@ -292,21 +297,22 @@ def _test_dataloader(self, shuffle=False): # do not 
shuffle dataloader for iterable dataset shuffle = shuffle and (not is_iterable_dataset) - return DataLoaderX(self.datasets["test"], - batch_size=self.batch_size, - num_workers=self.num_workers, - worker_init_fn=init_fn, - shuffle=shuffle) + return DataLoaderX( + self.datasets["test"], + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn, + shuffle=shuffle, + ) def _predict_dataloader(self, shuffle=False): - if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: + if isinstance(self.datasets["predict"], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: init_fn = worker_init_fn else: init_fn = None - return DataLoaderX(self.datasets["predict"], - batch_size=self.batch_size, - num_workers=self.num_workers, - worker_init_fn=init_fn) + return DataLoaderX( + self.datasets["predict"], batch_size=self.batch_size, num_workers=self.num_workers, worker_init_fn=init_fn + ) class SetupCallback(Callback): @@ -338,10 +344,10 @@ def on_fit_start(self, trainer, pl_module): os.makedirs(self.ckptdir, exist_ok=True) os.makedirs(self.cfgdir, exist_ok=True) - #Create trainstep checkpoint directory if necessary + # Create trainstep checkpoint directory if necessary if "callbacks" in self.lightning_config: - if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']: - os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True) + if "metrics_over_trainsteps_checkpoint" in self.lightning_config["callbacks"]: + os.makedirs(os.path.join(self.ckptdir, "trainstep_checkpoints"), exist_ok=True) print("Project config") print(OmegaConf.to_yaml(self.config)) OmegaConf.save(self.config, os.path.join(self.cfgdir, "{}-project.yaml".format(self.now))) @@ -349,8 +355,10 @@ def on_fit_start(self, trainer, pl_module): # Save project config and lightning config as YAML files print("Lightning config") print(OmegaConf.to_yaml(self.lightning_config)) - OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}), - os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now))) + OmegaConf.save( + OmegaConf.create({"lightning": self.lightning_config}), + os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)), + ) # Remove log directory if resuming training and directory already exists else: @@ -373,24 +381,25 @@ def on_fit_start(self, trainer, pl_module): # PyTorch Lightning callback for logging images during training and validation of a deep learning model class ImageLogger(Callback): - - def __init__(self, - batch_frequency, # Frequency of batches on which to log images - max_images, # Maximum number of images to log - clamp=True, # Whether to clamp pixel values to [-1,1] - increase_log_steps=True, # Whether to increase frequency of log steps exponentially - rescale=True, # Whether to rescale pixel values to [0,1] - disabled=False, # Whether to disable logging - log_on_batch_idx=False, # Whether to log on batch index instead of global step - log_first_step=False, # Whether to log on the first step - log_images_kwargs=None): # Additional keyword arguments to pass to log_images method + def __init__( + self, + batch_frequency, # Frequency of batches on which to log images + max_images, # Maximum number of images to log + clamp=True, # Whether to clamp pixel values to [-1,1] + increase_log_steps=True, # Whether to increase frequency of log steps exponentially + rescale=True, # Whether to rescale pixel values to [0,1] + disabled=False, # Whether to disable logging + log_on_batch_idx=False, # Whether 
to log on batch index instead of global step + log_first_step=False, # Whether to log on the first step + log_images_kwargs=None, + ): # Additional keyword arguments to pass to log_images method super().__init__() self.rescale = rescale self.batch_freq = batch_frequency self.max_images = max_images self.logger_log_images = { # Dictionary of logger classes and their corresponding logging methods - pl.loggers.CSVLogger: self._testtube, + pl.loggers.CSVLogger: self._testtube, } # Create a list of exponentially increasing log steps, starting from 1 and ending at batch_frequency self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)] @@ -402,37 +411,39 @@ def __init__(self, self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} self.log_first_step = log_first_step - @rank_zero_only # Ensure that only the first process in distributed training executes this method - def _testtube(self, # The PyTorch Lightning module - pl_module, # A dictionary of images to log. - images, # - batch_idx, # The batch index. - split # The split (train/val) on which to log the images - ): - # Method for logging images using test-tube logger + @rank_zero_only # Ensure that only the first process in distributed training executes this method + def _testtube( + self, # The PyTorch Lightning module + pl_module, # A dictionary of images to log. + images, # + batch_idx, # The batch index. + split, # The split (train/val) on which to log the images + ): + # Method for logging images using test-tube logger for k in images: grid = torchvision.utils.make_grid(images[k]) - grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w + grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w tag = f"{split}/{k}" # Add image grid to logger's experiment pl_module.logger.experiment.add_image(tag, grid, global_step=pl_module.global_step) @rank_zero_only - def log_local(self, - save_dir, - split, # The split (train/val) on which to log the images - images, # A dictionary of images to log - global_step, # The global step - current_epoch, # The current epoch. - batch_idx - ): - # Method for saving image grids to local file system + def log_local( + self, + save_dir, + split, # The split (train/val) on which to log the images + images, # A dictionary of images to log + global_step, # The global step + current_epoch, # The current epoch. + batch_idx, + ): + # Method for saving image grids to local file system root = os.path.join(save_dir, "images", split) for k in images: grid = torchvision.utils.make_grid(images[k], nrow=4) if self.rescale: - grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w + grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) grid = grid.numpy() grid = (grid * 255).astype(np.uint8) @@ -443,11 +454,15 @@ def log_local(self, Image.fromarray(grid).save(path) def log_img(self, pl_module, batch, batch_idx, split="train"): - #Function for logging images to both the logger and local file system. + # Function for logging images to both the logger and local file system. 
check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step # check if it's time to log an image batch - if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0 - hasattr(pl_module, "log_images") and callable(pl_module.log_images) and self.max_images > 0): + if ( + self.check_frequency(check_idx) + and hasattr(pl_module, "log_images") # batch_idx % self.batch_freq == 0 + and callable(pl_module.log_images) + and self.max_images > 0 + ): # Get logger type and check if training mode is on logger = type(pl_module.logger) @@ -466,11 +481,12 @@ def log_img(self, pl_module, batch, batch_idx, split="train"): if isinstance(images[k], torch.Tensor): images[k] = images[k].detach().cpu() if self.clamp: - images[k] = torch.clamp(images[k], -1., 1.) + images[k] = torch.clamp(images[k], -1.0, 1.0) # Log images locally to file system - self.log_local(pl_module.logger.save_dir, split, images, pl_module.global_step, pl_module.current_epoch, - batch_idx) + self.log_local( + pl_module.logger.save_dir, split, images, pl_module.global_step, pl_module.current_epoch, batch_idx + ) # log the images using the logger logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None) @@ -482,13 +498,13 @@ def log_img(self, pl_module, batch, batch_idx, split="train"): # The function checks if it's time to log an image batch def check_frequency(self, check_idx): - if ((check_idx % self.batch_freq) == 0 or - (check_idx in self.log_steps)) and (check_idx > 0 or self.log_first_step): + if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and ( + check_idx > 0 or self.log_first_step + ): try: self.log_steps.pop(0) except IndexError as e: print(e) - pass return True return False @@ -503,7 +519,7 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) if not self.disabled and pl_module.global_step > 0: self.log_img(pl_module, batch, batch_idx, split="val") # log gradients during calibration if necessary - if hasattr(pl_module, 'calibrate_grad_norm'): + if hasattr(pl_module, "calibrate_grad_norm"): if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0: self.log_gradients(trainer, pl_module, batch_idx=batch_idx) @@ -514,7 +530,7 @@ class CUDACallback(Callback): def on_train_start(self, trainer, pl_module): rank_zero_info("Training is starting") - #the method is called at the end of each training epoch + # the method is called at the end of each training epoch def on_train_end(self, trainer, pl_module): rank_zero_info("Training is ending") @@ -595,9 +611,11 @@ def on_train_epoch_end(self, trainer, pl_module): opt, unknown = parser.parse_known_args() # Verify the arguments are both specified if opt.name and opt.resume: - raise ValueError("-n/--name and -r/--resume cannot be specified both." - "If you want to resume training in a new log folder, " - "use -n/--name in combination with --resume_from_checkpoint") + raise ValueError( + "-n/--name and -r/--resume cannot be specified both." + "If you want to resume training in a new log folder, " + "use -n/--name in combination with --resume_from_checkpoint" + ) # Check if the "resume" option is specified, resume training from the checkpoint if it is true ckpt = None @@ -646,7 +664,7 @@ def on_train_epoch_end(self, trainer, pl_module): # Sets the seed for the random number generator to ensure reproducibility seed_everything(opt.seed) - # Initialize and save configuration using teh OmegaConf library. + # Initialize and save configuration using teh OmegaConf library. 
try: # init and save configs configs = [OmegaConf.load(cfg) for cfg in opt.base] @@ -676,7 +694,7 @@ def on_train_epoch_end(self, trainer, pl_module): config.model["params"].update({"use_fp16": False}) if ckpt is not None: - #If a checkpoint path is specified in the ckpt variable, the code updates the "ckpt" key in the "params" dictionary of the config.model configuration with the value of ckpt + # If a checkpoint path is specified in the ckpt variable, the code updates the "ckpt" key in the "params" dictionary of the config.model configuration with the value of ckpt config.model["params"].update({"ckpt": ckpt}) rank_zero_info("Using ckpt_path = {}".format(config.model["params"]["ckpt"])) @@ -688,17 +706,12 @@ def on_train_epoch_end(self, trainer, pl_module): # Default logger configs to log training metrics during the training process. default_logger_cfgs = { "wandb": { - "name": nowname, - "save_dir": logdir, - "offline": opt.debug, - "id": nowname, - } - , - "tensorboard": { - "save_dir": logdir, - "name": "diff_tb", - "log_graph": True - } + "name": nowname, + "save_dir": logdir, + "offline": opt.debug, + "id": nowname, + }, + "tensorboard": {"save_dir": logdir, "name": "diff_tb", "log_graph": True}, } # Set up the logger for TensorBoard @@ -722,11 +735,11 @@ def on_train_epoch_end(self, trainer, pl_module): # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to # specify which metric is used to determine best models default_modelckpt_cfg = { - "dirpath": ckptdir, - "filename": "{epoch:06}", - "verbose": True, - "save_last": True, - } + "dirpath": ckptdir, + "filename": "{epoch:06}", + "verbose": True, + "save_last": True, + } if hasattr(model, "monitor"): default_modelckpt_cfg["monitor"] = model.monitor default_modelckpt_cfg["save_top_k"] = 3 @@ -736,48 +749,47 @@ def on_train_epoch_end(self, trainer, pl_module): else: modelckpt_cfg = OmegaConf.create() modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg) - if version.parse(pl.__version__) < version.parse('1.4.0'): + if version.parse(pl.__version__) < version.parse("1.4.0"): trainer_kwargs["checkpoint_callback"] = ModelCheckpoint(**modelckpt_cfg) - #Create an empty OmegaConf configuration object + # Create an empty OmegaConf configuration object callbacks_cfg = OmegaConf.create() - - #Instantiate items according to the configs + + # Instantiate items according to the configs trainer_kwargs.setdefault("callbacks", []) setup_callback_config = { - "resume": opt.resume, # resume training if applicable - "now": now, - "logdir": logdir, # directory to save the log file - "ckptdir": ckptdir, # directory to save the checkpoint file - "cfgdir": cfgdir, # directory to save the configuration file - "config": config, # configuration dictionary - "lightning_config": lightning_config, # LightningModule configuration - } + "resume": opt.resume, # resume training if applicable + "now": now, + "logdir": logdir, # directory to save the log file + "ckptdir": ckptdir, # directory to save the checkpoint file + "cfgdir": cfgdir, # directory to save the configuration file + "config": config, # configuration dictionary + "lightning_config": lightning_config, # LightningModule configuration + } trainer_kwargs["callbacks"].append(SetupCallback(**setup_callback_config)) - + image_logger_config = { - - "batch_frequency": 750, # how frequently to log images - "max_images": 4, # maximum number of images to log - "clamp": True # whether to clamp pixel values to [0,1] - } + "batch_frequency": 750, # how frequently to log images + 
"max_images": 4, # maximum number of images to log + "clamp": True, # whether to clamp pixel values to [0,1] + } trainer_kwargs["callbacks"].append(ImageLogger(**image_logger_config)) - + learning_rate_logger_config = { - "logging_interval": "step", # logging frequency (either 'step' or 'epoch') - # "log_momentum": True # whether to log momentum (currently commented out) - } + "logging_interval": "step", # logging frequency (either 'step' or 'epoch') + # "log_momentum": True # whether to log momentum (currently commented out) + } trainer_kwargs["callbacks"].append(LearningRateMonitor(**learning_rate_logger_config)) - - metrics_over_trainsteps_checkpoint_config= { - "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'), + + metrics_over_trainsteps_checkpoint_config = { + "dirpath": os.path.join(ckptdir, "trainstep_checkpoints"), "filename": "{epoch:06}-{step:09}", "verbose": True, - 'save_top_k': -1, - 'every_n_train_steps': 10000, - 'save_weights_only': True - } + "save_top_k": -1, + "every_n_train_steps": 10000, + "save_weights_only": True, + } trainer_kwargs["callbacks"].append(ModelCheckpoint(**metrics_over_trainsteps_checkpoint_config)) trainer_kwargs["callbacks"].append(CUDACallback()) @@ -805,7 +817,7 @@ def on_train_epoch_end(self, trainer, pl_module): ngpu = trainer_config["devices"] else: ngpu = 1 - if 'accumulate_grad_batches' in lightning_config.trainer: + if "accumulate_grad_batches" in lightning_config.trainer: accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches else: accumulate_grad_batches = 1 @@ -814,8 +826,10 @@ def on_train_epoch_end(self, trainer, pl_module): if opt.scale_lr: model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr rank_zero_info( - "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)" - .format(model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr)) + "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format( + model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr + ) + ) else: model.learning_rate = base_lr rank_zero_info("++++ NOT USING LR SCALING ++++") @@ -832,9 +846,11 @@ def melk(*args, **kwargs): def divein(*args, **kwargs): if trainer.global_rank == 0: import pudb + pudb.set_trace() import signal + # Assign melk to SIGUSR1 signal and divein to SIGUSR2 signal signal.signal(signal.SIGUSR1, melk) signal.signal(signal.SIGUSR2, divein) diff --git a/examples/images/diffusion/scripts/download_first_stages.sh b/examples/images/diffusion/scripts/download_first_stages.sh index a8d79e99ccdf..50dab5de5b90 100755 --- a/examples/images/diffusion/scripts/download_first_stages.sh +++ b/examples/images/diffusion/scripts/download_first_stages.sh @@ -38,4 +38,4 @@ unzip -o model.zip cd ../vq-f16 unzip -o model.zip -cd ../.. \ No newline at end of file +cd ../.. 
diff --git a/examples/images/diffusion/scripts/img2img.py b/examples/images/diffusion/scripts/img2img.py index 877538d4733d..4c386113dcc3 100644 --- a/examples/images/diffusion/scripts/img2img.py +++ b/examples/images/diffusion/scripts/img2img.py @@ -1,28 +1,30 @@ """make variations of input image""" -import argparse, os +import argparse +import os +from contextlib import nullcontext +from itertools import islice + +import numpy as np import PIL import torch -import numpy as np +from einops import rearrange, repeat from omegaconf import OmegaConf from PIL import Image -from tqdm import tqdm, trange -from itertools import islice -from einops import rearrange, repeat -from torchvision.utils import make_grid from torch import autocast -from contextlib import nullcontext +from torchvision.utils import make_grid +from tqdm import tqdm, trange + try: from lightning.pytorch import seed_everything except: from pytorch_lightning import seed_everything -from imwatermark import WatermarkEncoder - -from scripts.txt2img import put_watermark -from ldm.util import instantiate_from_config +from imwatermark import WatermarkEncoder from ldm.models.diffusion.ddim import DDIMSampler -from utils import replace_module, getModelSize +from ldm.util import instantiate_from_config +from scripts.txt2img import put_watermark +from utils import replace_module def chunk(it, size): @@ -58,7 +60,7 @@ def load_img(path): image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) - return 2. * image - 1. + return 2.0 * image - 1.0 def main(): @@ -69,22 +71,13 @@ def main(): type=str, nargs="?", default="a painting of a virus monster playing guitar", - help="the prompt to render" + help="the prompt to render", ) - parser.add_argument( - "--init-img", - type=str, - nargs="?", - help="path to the input image" - ) + parser.add_argument("--init-img", type=str, nargs="?", help="path to the input image") parser.add_argument( - "--outdir", - type=str, - nargs="?", - help="dir to write results to", - default="outputs/img2img-samples" + "--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/img2img-samples" ) parser.add_argument( @@ -96,7 +89,7 @@ def main(): parser.add_argument( "--fixed_code", - action='store_true', + action="store_true", help="if enabled, uses the same starting code across all samples ", ) @@ -177,11 +170,7 @@ def main(): help="the seed (for reproducible sampling)", ) parser.add_argument( - "--precision", - type=str, - help="evaluate at this precision", - choices=["full", "autocast"], - default="autocast" + "--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast" ) parser.add_argument( "--use_int8", @@ -204,7 +193,7 @@ def main(): model = replace_module(model) # # to compute the model size # getModelSize(model) - + sampler = DDIMSampler(model) os.makedirs(opt.outdir, exist_ok=True) @@ -213,7 +202,7 @@ def main(): print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") wm = "SDV2" wm_encoder = WatermarkEncoder() - wm_encoder.set_watermark('bytes', wm.encode('utf-8')) + wm_encoder.set_watermark("bytes", wm.encode("utf-8")) batch_size = opt.n_samples n_rows = opt.n_rows if opt.n_rows > 0 else batch_size @@ -235,12 +224,12 @@ def main(): assert os.path.isfile(opt.init_img) init_image = load_img(opt.init_img).to(device) - init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) + init_image = repeat(init_image, "1 ... 
-> b ...", b=batch_size) init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space sampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False) - assert 0. <= opt.strength <= 1., 'can only work with strength in [0.0, 1.0]' + assert 0.0 <= opt.strength <= 1.0, "can only work with strength in [0.0, 1.0]" t_enc = int(opt.strength * opt.ddim_steps) print(f"target t_enc is {t_enc} steps") @@ -261,14 +250,19 @@ def main(): # encode (scaled latent) z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(device)) # decode it - samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale, - unconditional_conditioning=uc, ) + samples = sampler.decode( + z_enc, + c, + t_enc, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + ) x_samples = model.decode_first_stage(samples) x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) for x_sample in x_samples: - x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c") img = Image.fromarray(x_sample.astype(np.uint8)) img = put_watermark(img, wm_encoder) img.save(os.path.join(sample_path, f"{base_count:05}.png")) @@ -277,14 +271,14 @@ def main(): # additionally, save as grid grid = torch.stack(all_samples, 0) - grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = rearrange(grid, "n b c h w -> (n b) c h w") grid = make_grid(grid, nrow=n_rows) # to image - grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy() grid = Image.fromarray(grid.astype(np.uint8)) grid = put_watermark(grid, wm_encoder) - grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid.save(os.path.join(outpath, f"grid-{grid_count:04}.png")) grid_count += 1 print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.") diff --git a/examples/images/diffusion/scripts/inpaint.py b/examples/images/diffusion/scripts/inpaint.py index d6e6387a9a3b..afffcf1685e6 100644 --- a/examples/images/diffusion/scripts/inpaint.py +++ b/examples/images/diffusion/scripts/inpaint.py @@ -1,32 +1,35 @@ -import argparse, os, sys, glob -from omegaconf import OmegaConf -from PIL import Image -from tqdm import tqdm +import argparse +import glob +import os + import numpy as np import torch -from main import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler +from main import instantiate_from_config +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm def make_batch(image, mask, device): image = np.array(Image.open(image).convert("RGB")) - image = image.astype(np.float32)/255.0 - image = image[None].transpose(0,3,1,2) + image = image.astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) mask = np.array(Image.open(mask).convert("L")) - mask = mask.astype(np.float32)/255.0 - mask = mask[None,None] + mask = mask.astype(np.float32) / 255.0 + mask = mask[None, None] mask[mask < 0.5] = 0 mask[mask >= 0.5] = 1 mask = torch.from_numpy(mask) - masked_image = (1-mask)*image + masked_image = (1 - mask) * image batch = {"image": image, "mask": mask, "masked_image": masked_image} for k in batch: batch[k] = batch[k].to(device=device) - batch[k] = batch[k]*2.0-1.0 + batch[k] = batch[k] * 2.0 - 1.0 return batch @@ -58,8 +61,7 @@ def make_batch(image, mask, device): config = 
OmegaConf.load("models/ldm/inpainting_big/config.yaml") model = instantiate_from_config(config.model) - model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], - strict=False) + model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], strict=False) device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = model.to(device) @@ -74,25 +76,19 @@ def make_batch(image, mask, device): # encode masked image and concat downsampled mask c = model.cond_stage_model.encode(batch["masked_image"]) - cc = torch.nn.functional.interpolate(batch["mask"], - size=c.shape[-2:]) + cc = torch.nn.functional.interpolate(batch["mask"], size=c.shape[-2:]) c = torch.cat((c, cc), dim=1) - shape = (c.shape[1]-1,)+c.shape[2:] - samples_ddim, _ = sampler.sample(S=opt.steps, - conditioning=c, - batch_size=c.shape[0], - shape=shape, - verbose=False) + shape = (c.shape[1] - 1,) + c.shape[2:] + samples_ddim, _ = sampler.sample( + S=opt.steps, conditioning=c, batch_size=c.shape[0], shape=shape, verbose=False + ) x_samples_ddim = model.decode_first_stage(samples_ddim) - image = torch.clamp((batch["image"]+1.0)/2.0, - min=0.0, max=1.0) - mask = torch.clamp((batch["mask"]+1.0)/2.0, - min=0.0, max=1.0) - predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0, - min=0.0, max=1.0) + image = torch.clamp((batch["image"] + 1.0) / 2.0, min=0.0, max=1.0) + mask = torch.clamp((batch["mask"] + 1.0) / 2.0, min=0.0, max=1.0) + predicted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - inpainted = (1-mask)*image+mask*predicted_image - inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255 + inpainted = (1 - mask) * image + mask * predicted_image + inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255 Image.fromarray(inpainted.astype(np.uint8)).save(outpath) diff --git a/examples/images/diffusion/scripts/knn2img.py b/examples/images/diffusion/scripts/knn2img.py index e6eaaecab53e..763811665bbc 100644 --- a/examples/images/diffusion/scripts/knn2img.py +++ b/examples/images/diffusion/scripts/knn2img.py @@ -1,22 +1,22 @@ -import argparse, os, sys, glob -import clip -import torch -import torch.nn as nn -import numpy as np -from omegaconf import OmegaConf -from PIL import Image -from tqdm import tqdm, trange -from itertools import islice -from einops import rearrange, repeat -from torchvision.utils import make_grid -import scann +import argparse +import glob +import os import time +from itertools import islice from multiprocessing import cpu_count -from ldm.util import instantiate_from_config, parallel_data_prefetch +import numpy as np +import scann +import torch +from einops import rearrange from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.plms import PLMSSampler from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder +from ldm.util import instantiate_from_config, parallel_data_prefetch +from omegaconf import OmegaConf +from PIL import Image +from torchvision.utils import make_grid +from tqdm import tqdm, trange DATABASES = [ "openimages", @@ -59,29 +59,24 @@ def load_model_from_config(config, ckpt, verbose=False): class Searcher(object): - def __init__(self, database, retriever_version='ViT-L/14'): + def __init__(self, database, retriever_version="ViT-L/14"): assert database in DATABASES # self.database = self.load_database(database) self.database_name = database - self.searcher_savedir = f'data/rdm/searchers/{self.database_name}' - self.database_path = 
f'data/rdm/retrieval_databases/{self.database_name}' + self.searcher_savedir = f"data/rdm/searchers/{self.database_name}" + self.database_path = f"data/rdm/retrieval_databases/{self.database_name}" self.retriever = self.load_retriever(version=retriever_version) - self.database = {'embedding': [], - 'img_id': [], - 'patch_coords': []} + self.database = {"embedding": [], "img_id": [], "patch_coords": []} self.load_database() self.load_searcher() - def train_searcher(self, k, - metric='dot_product', - searcher_savedir=None): - - print('Start training searcher') - searcher = scann.scann_ops_pybind.builder(self.database['embedding'] / - np.linalg.norm(self.database['embedding'], axis=1)[:, np.newaxis], - k, metric) + def train_searcher(self, k, metric="dot_product", searcher_savedir=None): + print("Start training searcher") + searcher = scann.scann_ops_pybind.builder( + self.database["embedding"] / np.linalg.norm(self.database["embedding"], axis=1)[:, np.newaxis], k, metric + ) self.searcher = searcher.score_brute_force().build() - print('Finish training searcher') + print("Finish training searcher") if searcher_savedir is not None: print(f'Save trained searcher under "{searcher_savedir}"') @@ -91,36 +86,40 @@ def train_searcher(self, k, def load_single_file(self, saved_embeddings): compressed = np.load(saved_embeddings) self.database = {key: compressed[key] for key in compressed.files} - print('Finished loading of clip embeddings.') + print("Finished loading of clip embeddings.") def load_multi_files(self, data_archive): out_data = {key: [] for key in self.database} - for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'): + for d in tqdm(data_archive, desc=f"Loading datapool from {len(data_archive)} individual files."): for key in d.files: out_data[key].append(d[key]) return out_data def load_database(self): - print(f'Load saved patch embedding from "{self.database_path}"') - file_content = glob.glob(os.path.join(self.database_path, '*.npz')) + file_content = glob.glob(os.path.join(self.database_path, "*.npz")) if len(file_content) == 1: self.load_single_file(file_content[0]) elif len(file_content) > 1: data = [np.load(f) for f in file_content] - prefetched_data = parallel_data_prefetch(self.load_multi_files, data, - n_proc=min(len(data), cpu_count()), target_data_type='dict') + prefetched_data = parallel_data_prefetch( + self.load_multi_files, data, n_proc=min(len(data), cpu_count()), target_data_type="dict" + ) - self.database = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in - self.database} + self.database = { + key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in self.database + } else: raise ValueError(f'No npz-files in specified path "{self.database_path}" is this directory existing?') print(f'Finished loading of retrieval database of length {self.database["embedding"].shape[0]}.') - def load_retriever(self, version='ViT-L/14', ): + def load_retriever( + self, + version="ViT-L/14", + ): model = FrozenClipImageEmbedder(model=version) if torch.cuda.is_available(): model.cuda() @@ -128,14 +127,14 @@ def load_retriever(self, version='ViT-L/14', ): return model def load_searcher(self): - print(f'load searcher for database {self.database_name} from {self.searcher_savedir}') + print(f"load searcher for database {self.database_name} from {self.searcher_savedir}") self.searcher = scann.scann_ops_pybind.load_searcher(self.searcher_savedir) - print('Finished loading searcher.') + print("Finished 
loading searcher.") def search(self, x, k): - if self.searcher is None and self.database['embedding'].shape[0] < 2e4: - self.train_searcher(k) # quickly fit searcher on the fly for small databases - assert self.searcher is not None, 'Cannot search with uninitialized searcher' + if self.searcher is None and self.database["embedding"].shape[0] < 2e4: + self.train_searcher(k) # quickly fit searcher on the fly for small databases + assert self.searcher is not None, "Cannot search with uninitialized searcher" if isinstance(x, torch.Tensor): x = x.detach().cpu().numpy() if len(x.shape) == 3: @@ -146,17 +145,19 @@ def search(self, x, k): nns, distances = self.searcher.search_batched(query_embeddings, final_num_neighbors=k) end = time.time() - out_embeddings = self.database['embedding'][nns] - out_img_ids = self.database['img_id'][nns] - out_pc = self.database['patch_coords'][nns] + out_embeddings = self.database["embedding"][nns] + out_img_ids = self.database["img_id"][nns] + out_pc = self.database["patch_coords"][nns] - out = {'nn_embeddings': out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis], - 'img_ids': out_img_ids, - 'patch_coords': out_pc, - 'queries': x, - 'exec_time': end - start, - 'nns': nns, - 'q_embeddings': query_embeddings} + out = { + "nn_embeddings": out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis], + "img_ids": out_img_ids, + "patch_coords": out_pc, + "queries": x, + "exec_time": end - start, + "nns": nns, + "q_embeddings": query_embeddings, + } return out @@ -173,20 +174,16 @@ def __call__(self, x, n): type=str, nargs="?", default="a painting of a virus monster playing guitar", - help="the prompt to render" + help="the prompt to render", ) parser.add_argument( - "--outdir", - type=str, - nargs="?", - help="dir to write results to", - default="outputs/txt2img-samples" + "--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/txt2img-samples" ) parser.add_argument( "--skip_grid", - action='store_true', + action="store_true", help="do not save a grid, only individual samples. 
Helpful when evaluating lots of samples", ) @@ -206,7 +203,7 @@ def __call__(self, x, n): parser.add_argument( "--plms", - action='store_true', + action="store_true", help="use plms sampling", ) @@ -287,14 +284,14 @@ def __call__(self, x, n): parser.add_argument( "--database", type=str, - default='artbench-surrealism', + default="artbench-surrealism", choices=DATABASES, help="The database used for the search, only applied when --use_neighbors=True", ) parser.add_argument( "--use_neighbors", default=False, - action='store_true', + action="store_true", help="Include neighbors in addition to text prompt for conditioning", ) parser.add_argument( @@ -358,41 +355,43 @@ def __call__(self, x, n): uc = None if searcher is not None: nn_dict = searcher(c, opt.knn) - c = torch.cat([c, torch.from_numpy(nn_dict['nn_embeddings']).cuda()], dim=1) + c = torch.cat([c, torch.from_numpy(nn_dict["nn_embeddings"]).cuda()], dim=1) if opt.scale != 1.0: uc = torch.zeros_like(c) if isinstance(prompts, tuple): prompts = list(prompts) shape = [16, opt.H // 16, opt.W // 16] # note: currently hardcoded for f16 model - samples_ddim, _ = sampler.sample(S=opt.ddim_steps, - conditioning=c, - batch_size=c.shape[0], - shape=shape, - verbose=False, - unconditional_guidance_scale=opt.scale, - unconditional_conditioning=uc, - eta=opt.ddim_eta, - ) + samples_ddim, _ = sampler.sample( + S=opt.ddim_steps, + conditioning=c, + batch_size=c.shape[0], + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + ) x_samples_ddim = model.decode_first_stage(samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) for x_sample in x_samples_ddim: - x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c") Image.fromarray(x_sample.astype(np.uint8)).save( - os.path.join(sample_path, f"{base_count:05}.png")) + os.path.join(sample_path, f"{base_count:05}.png") + ) base_count += 1 all_samples.append(x_samples_ddim) if not opt.skip_grid: # additionally, save as grid grid = torch.stack(all_samples, 0) - grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = rearrange(grid, "n b c h w -> (n b) c h w") grid = make_grid(grid, nrow=n_rows) # to image - grid = 255. 
* rearrange(grid, 'c h w -> h w c').cpu().numpy() - Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy() + Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png")) grid_count += 1 print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.") diff --git a/examples/images/diffusion/scripts/sample_diffusion.py b/examples/images/diffusion/scripts/sample_diffusion.py index 876fe3c3642f..740aae2435d2 100644 --- a/examples/images/diffusion/scripts/sample_diffusion.py +++ b/examples/images/diffusion/scripts/sample_diffusion.py @@ -1,21 +1,26 @@ -import argparse, os, sys, glob, datetime, yaml -import torch +import argparse +import datetime +import glob +import os +import sys import time -import numpy as np -from tqdm import trange +import numpy as np +import torch +import yaml +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.util import instantiate_from_config from omegaconf import OmegaConf from PIL import Image +from tqdm import trange -from ldm.models.diffusion.ddim import DDIMSampler -from ldm.util import instantiate_from_config +rescale = lambda x: (x + 1.0) / 2.0 -rescale = lambda x: (x + 1.) / 2. def custom_to_pil(x): x = x.detach().cpu() - x = torch.clamp(x, -1., 1.) - x = (x + 1.) / 2. + x = torch.clamp(x, -1.0, 1.0) + x = (x + 1.0) / 2.0 x = x.permute(1, 2, 0).numpy() x = (255 * x).astype(np.uint8) x = Image.fromarray(x) @@ -51,49 +56,51 @@ def logs2pil(logs, keys=["sample"]): @torch.no_grad() -def convsample(model, shape, return_intermediates=True, - verbose=True, - make_prog_row=False): - - +def convsample(model, shape, return_intermediates=True, verbose=True, make_prog_row=False): if not make_prog_row: - return model.p_sample_loop(None, shape, - return_intermediates=return_intermediates, verbose=verbose) + return model.p_sample_loop(None, shape, return_intermediates=return_intermediates, verbose=verbose) else: - return model.progressive_denoising( - None, shape, verbose=True - ) + return model.progressive_denoising(None, shape, verbose=True) @torch.no_grad() -def convsample_ddim(model, steps, shape, eta=1.0 - ): +def convsample_ddim(model, steps, shape, eta=1.0): ddim = DDIMSampler(model) bs = shape[0] shape = shape[1:] - samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, eta=eta, verbose=False,) + samples, intermediates = ddim.sample( + steps, + batch_size=bs, + shape=shape, + eta=eta, + verbose=False, + ) return samples, intermediates @torch.no_grad() -def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0,): - - +def make_convolutional_sample( + model, + batch_size, + vanilla=False, + custom_steps=None, + eta=1.0, +): log = dict() - shape = [batch_size, - model.model.diffusion_model.in_channels, - model.model.diffusion_model.image_size, - model.model.diffusion_model.image_size] + shape = [ + batch_size, + model.model.diffusion_model.in_channels, + model.model.diffusion_model.image_size, + model.model.diffusion_model.image_size, + ] with model.ema_scope("Plotting"): t0 = time.time() if vanilla: - sample, progrow = convsample(model, shape, - make_prog_row=True) + sample, progrow = convsample(model, shape, make_prog_row=True) else: - sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape, - eta=eta) + sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape, eta=eta) t1 = time.time() @@ -101,32 +108,32 @@ def 
make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=Non log["sample"] = x_sample log["time"] = t1 - t0 - log['throughput'] = sample.shape[0] / (t1 - t0) + log["throughput"] = sample.shape[0] / (t1 - t0) print(f'Throughput for this batch: {log["throughput"]}') return log + def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None): if vanilla: - print(f'Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.') + print(f"Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.") else: - print(f'Using DDIM sampling with {custom_steps} sampling steps and eta={eta}') - + print(f"Using DDIM sampling with {custom_steps} sampling steps and eta={eta}") tstart = time.time() - n_saved = len(glob.glob(os.path.join(logdir,'*.png')))-1 + n_saved = len(glob.glob(os.path.join(logdir, "*.png"))) - 1 # path = logdir if model.cond_stage_model is None: all_images = [] print(f"Running unconditional sampling for {n_samples} samples") for _ in trange(n_samples // batch_size, desc="Sampling Batches (unconditional)"): - logs = make_convolutional_sample(model, batch_size=batch_size, - vanilla=vanilla, custom_steps=custom_steps, - eta=eta) + logs = make_convolutional_sample( + model, batch_size=batch_size, vanilla=vanilla, custom_steps=custom_steps, eta=eta + ) n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample") all_images.extend([custom_to_np(logs["sample"])]) if n_saved >= n_samples: - print(f'Finish after generating {n_saved} samples') + print(f"Finish after generating {n_saved} samples") break all_img = np.concatenate(all_images, axis=0) all_img = all_img[:n_samples] @@ -135,7 +142,7 @@ def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None np.savez(nppath, all_img) else: - raise NotImplementedError('Currently only sampling for unconditional models supported.') + raise NotImplementedError("Currently only sampling for unconditional models supported.") print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.") @@ -168,58 +175,33 @@ def get_parser(): nargs="?", help="load from logdir or checkpoint in logdir", ) - parser.add_argument( - "-n", - "--n_samples", - type=int, - nargs="?", - help="number of samples to draw", - default=50000 - ) + parser.add_argument("-n", "--n_samples", type=int, nargs="?", help="number of samples to draw", default=50000) parser.add_argument( "-e", "--eta", type=float, nargs="?", help="eta for ddim sampling (0.0 yields deterministic sampling)", - default=1.0 + default=1.0, ) parser.add_argument( "-v", "--vanilla_sample", default=False, - action='store_true', + action="store_true", help="vanilla sampling (default option is DDIM sampling)?", ) + parser.add_argument("-l", "--logdir", type=str, nargs="?", help="extra logdir", default="none") parser.add_argument( - "-l", - "--logdir", - type=str, - nargs="?", - help="extra logdir", - default="none" - ) - parser.add_argument( - "-c", - "--custom_steps", - type=int, - nargs="?", - help="number of steps for ddim and fastdpm sampling", - default=50 - ) - parser.add_argument( - "--batch_size", - type=int, - nargs="?", - help="the bs", - default=10 + "-c", "--custom_steps", type=int, nargs="?", help="number of steps for ddim and fastdpm sampling", default=50 ) + parser.add_argument("--batch_size", type=int, nargs="?", help="the bs", default=10) return parser def load_model_from_config(config, sd): model = instantiate_from_config(config) - 
model.load_state_dict(sd,strict=False) + model.load_state_dict(sd, strict=False) model.cuda() model.eval() return model @@ -233,8 +215,7 @@ def load_model(config, ckpt, gpu, eval_mode): else: pl_sd = {"state_dict": None} global_step = None - model = load_model_from_config(config.model, - pl_sd["state_dict"]) + model = load_model_from_config(config.model, pl_sd["state_dict"]) return model, global_step @@ -253,9 +234,9 @@ def load_model(config, ckpt, gpu, eval_mode): if os.path.isfile(opt.resume): # paths = opt.resume.split("/") try: - logdir = '/'.join(opt.resume.split('/')[:-1]) + logdir = "/".join(opt.resume.split("/")[:-1]) # idx = len(paths)-paths[::-1].index("logs")+1 - print(f'Logdir is {logdir}') + print(f"Logdir is {logdir}") except ValueError: paths = opt.resume.split("/") idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt @@ -278,7 +259,8 @@ def load_model(config, ckpt, gpu, eval_mode): if opt.logdir != "none": locallog = logdir.split(os.sep)[-1] - if locallog == "": locallog = logdir.split(os.sep)[-2] + if locallog == "": + locallog = logdir.split(os.sep)[-2] print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'") logdir = os.path.join(opt.logdir, locallog) @@ -301,13 +283,19 @@ def load_model(config, ckpt, gpu, eval_mode): sampling_file = os.path.join(logdir, "sampling_config.yaml") sampling_conf = vars(opt) - with open(sampling_file, 'w') as f: + with open(sampling_file, "w") as f: yaml.dump(sampling_conf, f, default_flow_style=False) print(sampling_conf) - - run(model, imglogdir, eta=opt.eta, - vanilla=opt.vanilla_sample, n_samples=opt.n_samples, custom_steps=opt.custom_steps, - batch_size=opt.batch_size, nplog=numpylogdir) + run( + model, + imglogdir, + eta=opt.eta, + vanilla=opt.vanilla_sample, + n_samples=opt.n_samples, + custom_steps=opt.custom_steps, + batch_size=opt.batch_size, + nplog=numpylogdir, + ) print("done.") diff --git a/examples/images/diffusion/scripts/tests/test_checkpoint.py b/examples/images/diffusion/scripts/tests/test_checkpoint.py index 13622c4989fd..c0af17bdecaa 100644 --- a/examples/images/diffusion/scripts/tests/test_checkpoint.py +++ b/examples/images/diffusion/scripts/tests/test_checkpoint.py @@ -1,28 +1,18 @@ -import os -import sys -from copy import deepcopy - +import torch import yaml -from datetime import datetime - from diffusers import StableDiffusionPipeline -import torch - -from main import get_parser from ldm.modules.diffusionmodules.openaimodel import UNetModel if __name__ == "__main__": with torch.no_grad(): yaml_path = "../../train_colossalai.yaml" - with open(yaml_path, 'r', encoding='utf-8') as f: + with open(yaml_path, "r", encoding="utf-8") as f: config = f.read() base_config = yaml.load(config, Loader=yaml.FullLoader) - unet_config = base_config['model']['params']['unet_config'] + unet_config = base_config["model"]["params"]["unet_config"] diffusion_model = UNetModel(**unet_config).to("cuda:0") - pipe = StableDiffusionPipeline.from_pretrained( - "/data/scratch/diffuser/stable-diffusion-v1-4" - ).to("cuda:0") + pipe = StableDiffusionPipeline.from_pretrained("/data/scratch/diffuser/stable-diffusion-v1-4").to("cuda:0") dif_model_2 = pipe.unet random_input_ = torch.rand((4, 4, 32, 32)).to("cuda:0") @@ -35,4 +25,4 @@ out_1 = diffusion_model(random_input_, time_stamp, context_) out_2 = dif_model_2(random_input_2, time_stamp2, context_2) print(out_1.shape) - print(out_2['sample'].shape) \ No newline at end of file + print(out_2["sample"].shape) diff --git 
a/examples/images/diffusion/scripts/tests/test_watermark.py b/examples/images/diffusion/scripts/tests/test_watermark.py index f93f8a6e7076..9bfc9fc7d9cb 100644 --- a/examples/images/diffusion/scripts/tests/test_watermark.py +++ b/examples/images/diffusion/scripts/tests/test_watermark.py @@ -5,14 +5,14 @@ def testit(img_path): bgr = cv2.imread(img_path) - decoder = WatermarkDecoder('bytes', 136) - watermark = decoder.decode(bgr, 'dwtDct') + decoder = WatermarkDecoder("bytes", 136) + watermark = decoder.decode(bgr, "dwtDct") try: - dec = watermark.decode('utf-8') + dec = watermark.decode("utf-8") except: dec = "null" print(dec) if __name__ == "__main__": - fire.Fire(testit) \ No newline at end of file + fire.Fire(testit) diff --git a/examples/images/diffusion/scripts/train_searcher.py b/examples/images/diffusion/scripts/train_searcher.py index 1e7904889c01..1df0baa7e5cf 100644 --- a/examples/images/diffusion/scripts/train_searcher.py +++ b/examples/images/diffusion/scripts/train_searcher.py @@ -1,33 +1,39 @@ -import os, sys -import numpy as np -import scann import argparse import glob +import os +import sys from multiprocessing import cpu_count -from tqdm import tqdm +import numpy as np +import scann from ldm.util import parallel_data_prefetch +from tqdm import tqdm def search_bruteforce(searcher): return searcher.score_brute_force().build() -def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k, - partioning_trainsize, num_leaves, num_leaves_to_search): - return searcher.tree(num_leaves=num_leaves, - num_leaves_to_search=num_leaves_to_search, - training_sample_size=partioning_trainsize). \ - score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build() +def search_partioned_ah( + searcher, dims_per_block, aiq_threshold, reorder_k, partioning_trainsize, num_leaves, num_leaves_to_search +): + return ( + searcher.tree( + num_leaves=num_leaves, num_leaves_to_search=num_leaves_to_search, training_sample_size=partioning_trainsize + ) + .score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold) + .reorder(reorder_k) + .build() + ) def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k): - return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder( - reorder_k).build() - -def load_datapool(dpath): + return ( + searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build() + ) +def load_datapool(dpath): def load_single_file(saved_embeddings): compressed = np.load(saved_embeddings) database = {key: compressed[key] for key in compressed.files} @@ -35,23 +41,26 @@ def load_single_file(saved_embeddings): def load_multi_files(data_archive): database = {key: [] for key in data_archive[0].files} - for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'): + for d in tqdm(data_archive, desc=f"Loading datapool from {len(data_archive)} individual files."): for key in d.files: database[key].append(d[key]) return database print(f'Load saved patch embedding from "{dpath}"') - file_content = glob.glob(os.path.join(dpath, '*.npz')) + file_content = glob.glob(os.path.join(dpath, "*.npz")) if len(file_content) == 1: data_pool = load_single_file(file_content[0]) elif len(file_content) > 1: data = [np.load(f) for f in file_content] - prefetched_data = parallel_data_prefetch(load_multi_files, data, - n_proc=min(len(data), cpu_count()), target_data_type='dict') + prefetched_data = parallel_data_prefetch( + 
load_multi_files, data, n_proc=min(len(data), cpu_count()), target_data_type="dict" + ) - data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()} + data_pool = { + key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys() + } else: raise ValueError(f'No npz-files in specified path "{dpath}" is this directory existing?') @@ -59,16 +68,17 @@ def load_multi_files(data_archive): return data_pool -def train_searcher(opt, - metric='dot_product', - partioning_trainsize=None, - reorder_k=None, - # todo tune - aiq_thld=0.2, - dims_per_block=2, - num_leaves=None, - num_leaves_to_search=None,): - +def train_searcher( + opt, + metric="dot_product", + partioning_trainsize=None, + reorder_k=None, + # todo tune + aiq_thld=0.2, + dims_per_block=2, + num_leaves=None, + num_leaves_to_search=None, +): data_pool = load_datapool(opt.database) k = opt.knn @@ -77,71 +87,83 @@ def train_searcher(opt, # normalize # embeddings = - searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric) - pool_size = data_pool['embedding'].shape[0] - - print(*(['#'] * 100)) - print('Initializing scaNN searcher with the following values:') - print(f'k: {k}') - print(f'metric: {metric}') - print(f'reorder_k: {reorder_k}') - print(f'anisotropic_quantization_threshold: {aiq_thld}') - print(f'dims_per_block: {dims_per_block}') - print(*(['#'] * 100)) - print('Start training searcher....') - print(f'N samples in pool is {pool_size}') + searcher = scann.scann_ops_pybind.builder( + data_pool["embedding"] / np.linalg.norm(data_pool["embedding"], axis=1)[:, np.newaxis], k, metric + ) + pool_size = data_pool["embedding"].shape[0] + + print(*(["#"] * 100)) + print("Initializing scaNN searcher with the following values:") + print(f"k: {k}") + print(f"metric: {metric}") + print(f"reorder_k: {reorder_k}") + print(f"anisotropic_quantization_threshold: {aiq_thld}") + print(f"dims_per_block: {dims_per_block}") + print(*(["#"] * 100)) + print("Start training searcher....") + print(f"N samples in pool is {pool_size}") # this reflects the recommended design choices proposed at # https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md if pool_size < 2e4: - print('Using brute force search.') + print("Using brute force search.") searcher = search_bruteforce(searcher) elif 2e4 <= pool_size and pool_size < 1e5: - print('Using asymmetric hashing search and reordering.') + print("Using asymmetric hashing search and reordering.") searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k) else: - print('Using using partioning, asymmetric hashing search and reordering.') + print("Using using partioning, asymmetric hashing search and reordering.") if not partioning_trainsize: - partioning_trainsize = data_pool['embedding'].shape[0] // 10 + partioning_trainsize = data_pool["embedding"].shape[0] // 10 if not num_leaves: num_leaves = int(np.sqrt(pool_size)) if not num_leaves_to_search: num_leaves_to_search = max(num_leaves // 20, 1) - print('Partitioning params:') - print(f'num_leaves: {num_leaves}') - print(f'num_leaves_to_search: {num_leaves_to_search}') + print("Partitioning params:") + print(f"num_leaves: {num_leaves}") + print(f"num_leaves_to_search: {num_leaves_to_search}") # self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k) - searcher = search_partioned_ah(searcher, 
dims_per_block, aiq_thld, reorder_k, - partioning_trainsize, num_leaves, num_leaves_to_search) + searcher = search_partioned_ah( + searcher, dims_per_block, aiq_thld, reorder_k, partioning_trainsize, num_leaves, num_leaves_to_search + ) - print('Finish training searcher') + print("Finish training searcher") searcher_savedir = opt.target_path os.makedirs(searcher_savedir, exist_ok=True) searcher.serialize(searcher_savedir) print(f'Saved trained searcher under "{searcher_savedir}"') -if __name__ == '__main__': + +if __name__ == "__main__": sys.path.append(os.getcwd()) parser = argparse.ArgumentParser() - parser.add_argument('--database', - '-d', - default='data/rdm/retrieval_databases/openimages', - type=str, - help='path to folder containing the clip feature of the database') - parser.add_argument('--target_path', - '-t', - default='data/rdm/searchers/openimages', - type=str, - help='path to the target folder where the searcher shall be stored.') - parser.add_argument('--knn', - '-k', - default=20, - type=int, - help='number of nearest neighbors, for which the searcher shall be optimized') - - opt, _ = parser.parse_known_args() - - train_searcher(opt,) \ No newline at end of file + parser.add_argument( + "--database", + "-d", + default="data/rdm/retrieval_databases/openimages", + type=str, + help="path to folder containing the clip feature of the database", + ) + parser.add_argument( + "--target_path", + "-t", + default="data/rdm/searchers/openimages", + type=str, + help="path to the target folder where the searcher shall be stored.", + ) + parser.add_argument( + "--knn", + "-k", + default=20, + type=int, + help="number of nearest neighbors, for which the searcher shall be optimized", + ) + + opt, _ = parser.parse_known_args() + + train_searcher( + opt, + ) diff --git a/examples/images/diffusion/scripts/txt2img.py b/examples/images/diffusion/scripts/txt2img.py index 364ebac6c67b..feb17b9f77ae 100644 --- a/examples/images/diffusion/scripts/txt2img.py +++ b/examples/images/diffusion/scripts/txt2img.py @@ -1,29 +1,34 @@ -import argparse, os +import argparse +import os +from itertools import islice + import cv2 -import torch import numpy as np +import torch +from einops import rearrange from omegaconf import OmegaConf from PIL import Image -from tqdm import tqdm, trange -from itertools import islice -from einops import rearrange from torchvision.utils import make_grid +from tqdm import tqdm, trange + try: from lightning.pytorch import seed_everything except: from pytorch_lightning import seed_everything -from torch import autocast + from contextlib import nullcontext -from imwatermark import WatermarkEncoder -from ldm.util import instantiate_from_config +from imwatermark import WatermarkEncoder from ldm.models.diffusion.ddim import DDIMSampler -from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.dpm_solver import DPMSolverSampler -from utils import replace_module, getModelSize +from ldm.models.diffusion.plms import PLMSSampler +from ldm.util import instantiate_from_config +from torch import autocast +from utils import replace_module torch.set_grad_enabled(False) + def chunk(it, size): it = iter(it) return iter(lambda: tuple(islice(it, size)), ()) @@ -55,14 +60,10 @@ def parse_args(): type=str, nargs="?", default="a professional photograph of an astronaut riding a triceratops", - help="the prompt to render" + help="the prompt to render", ) parser.add_argument( - "--outdir", - type=str, - nargs="?", - help="dir to write results to", - default="outputs/txt2img-samples" + 
"--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/txt2img-samples" ) parser.add_argument( "--steps", @@ -72,17 +73,17 @@ def parse_args(): ) parser.add_argument( "--plms", - action='store_true', + action="store_true", help="use plms sampling", ) parser.add_argument( "--dpm", - action='store_true', + action="store_true", help="use DPM (2) sampler", ) parser.add_argument( "--fixed_code", - action='store_true', + action="store_true", help="if enabled, uses the same starting code across all samples ", ) parser.add_argument( @@ -162,11 +163,7 @@ def parse_args(): help="the seed (for reproducible sampling)", ) parser.add_argument( - "--precision", - type=str, - help="evaluate at this precision", - choices=["full", "autocast"], - default="autocast" + "--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast" ) parser.add_argument( "--repeat", @@ -187,7 +184,7 @@ def parse_args(): def put_watermark(img, wm_encoder=None): if wm_encoder is not None: img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) - img = wm_encoder.encode(img, 'dwtDct') + img = wm_encoder.encode(img, "dwtDct") img = Image.fromarray(img[:, :, ::-1]) return img @@ -197,17 +194,17 @@ def main(opt): config = OmegaConf.load(f"{opt.config}") model = load_model_from_config(config, f"{opt.ckpt}") - + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = model.to(device) - + # quantize model if opt.use_int8: model = replace_module(model) # # to compute the model size # getModelSize(model) - + if opt.plms: sampler = PLMSSampler(model) elif opt.dpm: @@ -221,7 +218,7 @@ def main(opt): print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") wm = "SDV2" wm_encoder = WatermarkEncoder() - wm_encoder.set_watermark('bytes', wm.encode('utf-8')) + wm_encoder.set_watermark("bytes", wm.encode("utf-8")) batch_size = opt.n_samples n_rows = opt.n_rows if opt.n_rows > 0 else batch_size @@ -248,56 +245,55 @@ def main(opt): start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device) precision_scope = autocast if opt.precision == "autocast" else nullcontext - with torch.no_grad(), \ - precision_scope("cuda"), \ - model.ema_scope(): - all_samples = list() - for n in trange(opt.n_iter, desc="Sampling"): - for prompts in tqdm(data, desc="data"): - uc = None - if opt.scale != 1.0: - uc = model.get_learned_conditioning(batch_size * [""]) - if isinstance(prompts, tuple): - prompts = list(prompts) - c = model.get_learned_conditioning(prompts) - shape = [opt.C, opt.H // opt.f, opt.W // opt.f] - samples, _ = sampler.sample(S=opt.steps, - conditioning=c, - batch_size=opt.n_samples, - shape=shape, - verbose=False, - unconditional_guidance_scale=opt.scale, - unconditional_conditioning=uc, - eta=opt.ddim_eta, - x_T=start_code) - - x_samples = model.decode_first_stage(samples) - x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) - - for x_sample in x_samples: - x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') - img = Image.fromarray(x_sample.astype(np.uint8)) - img = put_watermark(img, wm_encoder) - img.save(os.path.join(sample_path, f"{base_count:05}.png")) - base_count += 1 - sample_count += 1 - - all_samples.append(x_samples) - - # additionally, save as grid - grid = torch.stack(all_samples, 0) - grid = rearrange(grid, 'n b c h w -> (n b) c h w') - grid = make_grid(grid, nrow=n_rows) - - # to image - grid = 255. 
* rearrange(grid, 'c h w -> h w c').cpu().numpy() - grid = Image.fromarray(grid.astype(np.uint8)) - grid = put_watermark(grid, wm_encoder) - grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) - grid_count += 1 - - print(f"Your samples are ready and waiting for you here: \n{outpath} \n" - f" \nEnjoy.") + with torch.no_grad(), precision_scope("cuda"), model.ema_scope(): + all_samples = list() + for n in trange(opt.n_iter, desc="Sampling"): + for prompts in tqdm(data, desc="data"): + uc = None + if opt.scale != 1.0: + uc = model.get_learned_conditioning(batch_size * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = model.get_learned_conditioning(prompts) + shape = [opt.C, opt.H // opt.f, opt.W // opt.f] + samples, _ = sampler.sample( + S=opt.steps, + conditioning=c, + batch_size=opt.n_samples, + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + x_T=start_code, + ) + + x_samples = model.decode_first_stage(samples) + x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) + + for x_sample in x_samples: + x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c") + img = Image.fromarray(x_sample.astype(np.uint8)) + img = put_watermark(img, wm_encoder) + img.save(os.path.join(sample_path, f"{base_count:05}.png")) + base_count += 1 + sample_count += 1 + + all_samples.append(x_samples) + + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, "n b c h w -> (n b) c h w") + grid = make_grid(grid, nrow=n_rows) + + # to image + grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy() + grid = Image.fromarray(grid.astype(np.uint8)) + grid = put_watermark(grid, wm_encoder) + grid.save(os.path.join(outpath, f"grid-{grid_count:04}.png")) + grid_count += 1 + + print(f"Your samples are ready and waiting for you here: \n{outpath} \n" f" \nEnjoy.") if __name__ == "__main__": diff --git a/examples/images/diffusion/scripts/utils.py b/examples/images/diffusion/scripts/utils.py index c954b22ca190..92ed0b4dfd0a 100644 --- a/examples/images/diffusion/scripts/utils.py +++ b/examples/images/diffusion/scripts/utils.py @@ -1,6 +1,7 @@ import bitsandbytes as bnb -import torch.nn as nn import torch +import torch.nn as nn + class Linear8bit(nn.Linear): def __init__( @@ -12,11 +13,9 @@ def __init__( memory_efficient_backward=False, threshold=6.0, weight_data=None, - bias_data=None + bias_data=None, ): - super(Linear8bit, self).__init__( - input_features, output_features, bias - ) + super(Linear8bit, self).__init__(input_features, output_features, bias) self.state = bnb.MatmulLtState() self.bias = bias_data self.state.threshold = threshold @@ -24,13 +23,12 @@ def __init__( self.state.memory_efficient_backward = memory_efficient_backward if threshold > 0.0 and not has_fp16_weights: self.state.use_pool = True - + self.register_parameter("SCB", nn.Parameter(torch.empty(0), requires_grad=False)) self.weight = weight_data self.quant() - - def quant(self): + def quant(self): weight = self.weight.data.contiguous().half().cuda() CB, _, SCB, _, _ = bnb.functional.double_quant(weight) delattr(self, "weight") @@ -41,32 +39,34 @@ def quant(self): def forward(self, x): self.state.is_training = self.training - + if self.bias is not None and self.bias.dtype != torch.float16: self.bias.data = self.bias.data.half() - + self.state.CB = self.weight.data self.state.SCB = self.SCB.data - + out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) del self.state.CxB 
return out + def replace_module(model): for name, module in model.named_children(): if len(list(module.children())) > 0: replace_module(module) - if isinstance(module, nn.Linear) and "out_proj" not in name: + if isinstance(module, nn.Linear) and "out_proj" not in name: model._modules[name] = Linear8bit( - input_features=module.in_features, - output_features=module.out_features, - threshold=6.0, - weight_data=module.weight, - bias_data=module.bias, - ) + input_features=module.in_features, + output_features=module.out_features, + threshold=6.0, + weight_data=module.weight, + bias_data=module.bias, + ) return model + def getModelSize(model): param_size = 0 param_sum = 0 @@ -79,5 +79,5 @@ def getModelSize(model): buffer_size += buffer.nelement() * buffer.element_size() buffer_sum += buffer.nelement() all_size = (param_size + buffer_size) / 1024 / 1024 - print('Model Size: {:.3f}MB'.format(all_size)) + print("Model Size: {:.3f}MB".format(all_size)) return (param_size, param_sum, buffer_size, buffer_sum, all_size) diff --git a/examples/images/diffusion/setup.py b/examples/images/diffusion/setup.py index a24d54167640..13d9f8927801 100644 --- a/examples/images/diffusion/setup.py +++ b/examples/images/diffusion/setup.py @@ -1,13 +1,13 @@ -from setuptools import setup, find_packages +from setuptools import find_packages, setup setup( - name='latent-diffusion', - version='0.0.1', - description='', + name="latent-diffusion", + version="0.0.1", + description="", packages=find_packages(), install_requires=[ - 'torch', - 'numpy', - 'tqdm', + "torch", + "numpy", + "tqdm", ], -) \ No newline at end of file +) diff --git a/examples/images/diffusion/train_colossalai.sh b/examples/images/diffusion/train_colossalai.sh index 7f1a1bd14615..c56ed7876e5a 100755 --- a/examples/images/diffusion/train_colossalai.sh +++ b/examples/images/diffusion/train_colossalai.sh @@ -3,4 +3,3 @@ TRANSFORMERS_OFFLINE=1 DIFFUSERS_OFFLINE=1 python main.py --logdir /tmp --train --base configs/Teyvat/train_colossalai_teyvat.yaml --ckpt diffuser_root_dir/512-base-ema.ckpt - diff --git a/examples/images/diffusion/train_ddp.sh b/examples/images/diffusion/train_ddp.sh index 78fe765488c6..8304d6fa8b4f 100644 --- a/examples/images/diffusion/train_ddp.sh +++ b/examples/images/diffusion/train_ddp.sh @@ -1,5 +1,5 @@ -HF_DATASETS_OFFLINE=1 -TRANSFORMERS_OFFLINE=1 -DIFFUSERS_OFFLINE=1 +HF_DATASETS_OFFLINE=1 +TRANSFORMERS_OFFLINE=1 +DIFFUSERS_OFFLINE=1 python main.py --logdir /tmp -t -b /configs/train_ddp.yaml diff --git a/examples/images/dreambooth/README.md b/examples/images/dreambooth/README.md index ba4c1a71034a..4e9febbc5fa8 100644 --- a/examples/images/dreambooth/README.md +++ b/examples/images/dreambooth/README.md @@ -93,7 +93,7 @@ torchrun --nproc_per_node 2 train_dreambooth_colossalai.py \ ``` ## New API -We have modified our previous implementation of Dreambooth with our new Booster API, which offers a more flexible and efficient way to train your model. The new API is more user-friendly and easy to use. You can find the new API in `train_dreambooth_colossalai.py`. +We have modified our previous implementation of Dreambooth with our new Booster API, which offers a more flexible and efficient way to train your model. The new API is more user-friendly and easy to use. You can find the new API in `train_dreambooth_colossalai.py`. We have also offer a shell script `test_ci.sh` for you to go through all our plugins for the booster. For more information about the booster API you can refer to https://colossalai.org/docs/basics/booster_api/. 
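
The Booster API mentioned above follows a plugin pattern: a `Booster` wraps the model, optimizer and dataloader according to the chosen plugin, and `booster.backward` replaces the plain `loss.backward()` call. Below is a minimal sketch with a toy model and an illustrative plugin choice; the actual DreamBooth setup lives in `train_dreambooth_colossalai.py` and the linked documentation.

```python
# Minimal sketch of Booster usage (toy model and data, launch with torchrun).
import colossalai
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin  # plugin choice is illustrative

colossalai.launch_from_torch(config={})  # torchrun provides the distributed environment

model = nn.Linear(8, 1).cuda()           # stand-in for the real diffusion model
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)
dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
dataloader = DataLoader(dataset, batch_size=8)

booster = Booster(plugin=TorchDDPPlugin())
model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)

for x, y in dataloader:
    loss = torch.nn.functional.mse_loss(model(x.cuda()), y.cuda())
    booster.backward(loss, optimizer)    # the booster's backward replaces loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```

Other plugins, such as `LowLevelZeroPlugin` (benchmarked as `low_level_zero` in the table below), can be swapped in without changing the surrounding training loop.
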
@@ -111,7 +111,7 @@ For more information about the booster API you can refer to https://colossalai.o | low_level_zero | 4 | 8 | 28.87 | 2.02 | The evaluation is performed on 4 Nvidia A100 GPUs with 80GB memory each, with GPU 0 & 1, 2 & 3 connected with NVLink. -We finetuned the [stable-diffusion-v1-4](https://huggingface.co/stabilityai/stable-diffusion-v1-4) model with 512x512 resolution on the [Teyvat](https://huggingface.co/datasets/Fazzie/Teyvat) dataset and compared +We finetuned the [stable-diffusion-v1-4](https://huggingface.co/stabilityai/stable-diffusion-v1-4) model with 512x512 resolution on the [Teyvat](https://huggingface.co/datasets/Fazzie/Teyvat) dataset and compared the memory cost and the throughput for the plugins. diff --git a/examples/images/dreambooth/debug.py b/examples/images/dreambooth/debug.py index 33219b2caa29..8ce4dc3bbd80 100644 --- a/examples/images/dreambooth/debug.py +++ b/examples/images/dreambooth/debug.py @@ -1,16 +1,16 @@ -''' +""" torchrun --standalone --nproc_per_node=1 debug.py -''' +""" from diffusers import AutoencoderKL import colossalai -from colossalai.zero import ColoInitContext, post_process_colo_init_ctx +from colossalai.zero import ColoInitContext path = "/data/scratch/diffuser/stable-diffusion-v1-4" colossalai.launch_from_torch(config={}) -with ColoInitContext(device='cpu'): +with ColoInitContext(device="cpu"): vae = AutoencoderKL.from_pretrained( path, subfolder="vae", diff --git a/examples/images/dreambooth/inference.py b/examples/images/dreambooth/inference.py index c342821c7830..ff317827aff7 100644 --- a/examples/images/dreambooth/inference.py +++ b/examples/images/dreambooth/inference.py @@ -1,7 +1,7 @@ -from diffusers import StableDiffusionPipeline, DiffusionPipeline import torch +from diffusers import DiffusionPipeline -model_id = +model_id = "" print(f"Loading model... from{model_id}") pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") diff --git a/examples/images/dreambooth/train_dreambooth.py b/examples/images/dreambooth/train_dreambooth.py index b989955f7fb7..9b66089b2752 100644 --- a/examples/images/dreambooth/train_dreambooth.py +++ b/examples/images/dreambooth/train_dreambooth.py @@ -104,8 +104,10 @@ def parse_args(input_args=None): "--num_class_images", type=int, default=100, - help=("Minimal class images for prior preservation loss. If there are not enough images already present in" - " class_data_dir, additional images will be sampled with class_prompt."), + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." 
+ ), ) parser.add_argument( "--output_dir", @@ -118,17 +120,18 @@ def parse_args(input_args=None): "--resolution", type=int, default=512, - help=("The resolution for input images, all the images in the train/validation dataset will be resized to this" - " resolution"), + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" ) - parser.add_argument("--center_crop", - action="store_true", - help="Whether to center crop images before resizing to resolution") parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder") - parser.add_argument("--train_batch_size", - type=int, - default=4, - help="Batch size (per device) for the training dataloader.") + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.") parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( @@ -165,16 +168,17 @@ def parse_args(input_args=None): "--lr_scheduler", type=str, default="constant", - help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' - ' "constant", "constant_with_warmup"]'), - ) - parser.add_argument("--lr_warmup_steps", - type=int, - default=500, - help="Number of steps for the warmup in the lr scheduler.") - parser.add_argument("--use_8bit_adam", - action="store_true", - help="Whether or not to use 8-bit Adam from bitsandbytes.") + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") @@ -192,8 +196,10 @@ def parse_args(input_args=None): "--logging_dir", type=str, default="logs", - help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" - " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."), + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), ) parser.add_argument( "--mixed_precision", @@ -203,7 +209,8 @@ def parse_args(input_args=None): help=( "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" - " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."), + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 
+ ), ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") @@ -269,12 +276,14 @@ def __init__( else: self.class_data_root = None - self.image_transforms = transforms.Compose([ - transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ]) + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) def __len__(self): return self._length @@ -350,7 +359,8 @@ def main(args): if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: raise ValueError( "Gradient accumulation is not supported when training the text encoder in distributed training. " - "Please set gradient_accumulation_steps to 1. This feature will be supported in the future.") + "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." + ) if args.seed is not None: set_seed(args.seed) @@ -380,9 +390,9 @@ def main(args): sample_dataloader = accelerator.prepare(sample_dataloader) pipeline.to(accelerator.device) - for example in tqdm(sample_dataloader, - desc="Generating class images", - disable=not accelerator.is_local_main_process): + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): images = pipeline(example["prompt"]).images for i, image in enumerate(images): @@ -456,8 +466,9 @@ def main(args): text_encoder.gradient_checkpointing_enable() if args.scale_lr: - args.learning_rate = (args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * - accelerator.num_processes) + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs if args.use_8bit_adam: @@ -470,8 +481,9 @@ def main(args): else: optimizer_class = torch.optim.AdamW - params_to_optimize = (itertools.chain(unet.parameters(), text_encoder.parameters()) - if args.train_text_encoder else unet.parameters()) + params_to_optimize = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters() + ) optimizer = optimizer_class( params_to_optimize, lr=args.learning_rate, @@ -506,9 +518,7 @@ def collate_fn(examples): pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() input_ids = tokenizer.pad( - { - "input_ids": input_ids - }, + {"input_ids": input_ids}, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt", @@ -520,11 +530,9 @@ def collate_fn(examples): } return batch - train_dataloader = torch.utils.data.DataLoader(train_dataset, - batch_size=args.train_batch_size, - shuffle=True, - collate_fn=collate_fn, - num_workers=1) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1 + ) # Scheduler and math around the number of training steps. 
overrode_max_train_steps = False @@ -542,10 +550,12 @@ def collate_fn(examples): if args.train_text_encoder: unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, lr_scheduler) + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) else: - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, - lr_scheduler) + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": @@ -641,8 +651,11 @@ def collate_fn(examples): accelerator.backward(loss) if accelerator.sync_gradients: - params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()) - if args.train_text_encoder else unet.parameters()) + params_to_clip = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) + if args.train_text_encoder + else unet.parameters() + ) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() lr_scheduler.step() diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py index 9b2ed3b971ae..1a7f8da7f7d0 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai.py @@ -117,8 +117,10 @@ def parse_args(input_args=None): "--num_class_images", type=int, default=100, - help=("Minimal class images for prior preservation loss. If there are not enough images already present in" - " class_data_dir, additional images will be sampled with class_prompt."), + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), ) parser.add_argument( "--output_dir", @@ -131,8 +133,10 @@ def parse_args(input_args=None): "--resolution", type=int, default=512, - help=("The resolution for input images, all the images in the train/validation dataset will be resized to this" - " resolution"), + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), ) parser.add_argument( "--offload_optim_frac", @@ -144,13 +148,14 @@ def parse_args(input_args=None): "--center_crop", default=False, action="store_true", - help=("Whether to center crop the input images to the resolution. If not set, the images will be randomly" - " cropped. The images will be resized to the resolution first before cropping."), + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." ) - parser.add_argument("--train_batch_size", - type=int, - default=4, - help="Batch size (per device) for the training dataloader.") parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.") parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( @@ -181,16 +186,17 @@ def parse_args(input_args=None): "--lr_scheduler", type=str, default="constant", - help=('The scheduler type to use. 
Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' - ' "constant", "constant_with_warmup"]'), + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." ) - parser.add_argument("--lr_warmup_steps", - type=int, - default=500, - help="Number of steps for the warmup in the lr scheduler.") - parser.add_argument("--use_8bit_adam", - action="store_true", - help="Whether or not to use 8-bit Adam from bitsandbytes.") parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") @@ -202,18 +208,22 @@ def parse_args(input_args=None): default=None, help="The name of the repository to keep in sync with the local `output_dir`.", ) - parser.add_argument('-p', - '--plugin', - type=str, - default='torch_ddp', - choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'], - help="plugin to use") + parser.add_argument( + "-p", + "--plugin", + type=str, + default="torch_ddp", + choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero"], + help="plugin to use", + ) parser.add_argument( "--logging_dir", type=str, default="logs", - help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" - " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."), + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), ) parser.add_argument( "--mixed_precision", @@ -223,7 +233,8 @@ def parse_args(input_args=None): help=( "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" - " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."), + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 
+ ), ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") @@ -292,12 +303,14 @@ def __init__( else: self.class_data_root = None - self.image_transforms = transforms.Compose([ - transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ]) + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) def __len__(self): return self._length @@ -391,9 +404,9 @@ def main(args): pipeline.to(get_current_device()) for example in tqdm( - sample_dataloader, - desc="Generating class images", - disable=not local_rank == 0, + sample_dataloader, + desc="Generating class images", + disable=not local_rank == 0, ): images = pipeline(example["prompt"]).images @@ -460,15 +473,14 @@ def main(args): if args.externel_unet_path is None: logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0]) - unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, - subfolder="unet", - revision=args.revision, - low_cpu_mem_usage=False) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False + ) else: logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0]) - unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path, - revision=args.revision, - low_cpu_mem_usage=False) + unet = UNet2DConditionModel.from_pretrained( + args.externel_unet_path, revision=args.revision, low_cpu_mem_usage=False + ) vae.requires_grad_(False) text_encoder.requires_grad_(False) @@ -482,36 +494,37 @@ def main(args): # Use Booster API to use Gemini/Zero with ColossalAI booster_kwargs = {} - if args.plugin == 'torch_ddp_fp16': - booster_kwargs['mixed_precision'] = 'fp16' - if args.plugin.startswith('torch_ddp'): + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() - elif args.plugin == 'gemini': + elif args.plugin == "gemini": plugin = GeminiPlugin(offload_optim_frac=args.offload_optim_frac, strict_ddp_mode=True, initial_scale=2**5) - elif args.plugin == 'low_level_zero': + elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) booster = Booster(plugin=plugin, **booster_kwargs) # config optimizer for colossalai zero - optimizer = HybridAdam(unet.parameters(), - lr=args.learning_rate, - initial_scale=2**5, - clipping_norm=args.max_grad_norm) + optimizer = HybridAdam( + unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm + ) # load noise_scheduler noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") # prepare dataset logger.info(f"Prepare dataset from {args.instance_data_dir}", ranks=[0]) - train_dataset = DreamBoothDataset(instance_data_root=args.instance_data_dir, - instance_prompt=args.instance_prompt, - class_data_root=args.class_data_dir if args.with_prior_preservation else None, - class_prompt=args.class_prompt, - tokenizer=tokenizer, - size=args.resolution, - center_crop=args.center_crop, - 
test=args.test_run) + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_prompt=args.class_prompt, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + test=args.test_run, + ) def collate_fn(examples): input_ids = [example["instance_prompt_ids"] for example in examples] @@ -527,9 +540,7 @@ def collate_fn(examples): pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() input_ids = tokenizer.pad( - { - "input_ids": input_ids - }, + {"input_ids": input_ids}, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt", @@ -541,11 +552,9 @@ def collate_fn(examples): } return batch - train_dataloader = torch.utils.data.DataLoader(train_dataset, - batch_size=args.train_batch_size, - shuffle=True, - collate_fn=collate_fn, - num_workers=1) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1 + ) # Scheduler and math around the number of training steps. overrode_max_train_steps = False @@ -664,7 +673,7 @@ def collate_fn(examples): logs = { "loss": loss.detach().item(), "lr": optimizer.param_groups[0]["lr"], - } # lr_scheduler.get_last_lr()[0]} + } # lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) if global_step % args.save_steps == 0: diff --git a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py index 654bce36ccb7..ea6dde8bb578 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py @@ -28,8 +28,6 @@ from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam from colossalai.utils import get_current_device -from colossalai.zero import ColoInitContext, GeminiAdamOptimizer -from colossalai.zero.gemini import get_static_torch_model disable_existing_loggers() logger = get_dist_logger() @@ -122,8 +120,10 @@ def parse_args(input_args=None): "--num_class_images", type=int, default=100, - help=("Minimal class images for prior preservation loss. If there are not enough images already present in" - " class_data_dir, additional images will be sampled with class_prompt."), + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), ) parser.add_argument( "--output_dir", @@ -136,8 +136,10 @@ def parse_args(input_args=None): "--resolution", type=int, default=512, - help=("The resolution for input images, all the images in the train/validation dataset will be resized to this" - " resolution"), + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), ) parser.add_argument( "--placement", @@ -149,13 +151,14 @@ def parse_args(input_args=None): "--center_crop", default=False, action="store_true", - help=("Whether to center crop the input images to the resolution. If not set, the images will be randomly" - " cropped. The images will be resized to the resolution first before cropping."), + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. 
The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." ) - parser.add_argument("--train_batch_size", - type=int, - default=4, - help="Batch size (per device) for the training dataloader.") parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.") parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( @@ -186,16 +189,17 @@ def parse_args(input_args=None): "--lr_scheduler", type=str, default="constant", - help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' - ' "constant", "constant_with_warmup"]'), + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." ) - parser.add_argument("--lr_warmup_steps", - type=int, - default=500, - help="Number of steps for the warmup in the lr scheduler.") - parser.add_argument("--use_8bit_adam", - action="store_true", - help="Whether or not to use 8-bit Adam from bitsandbytes.") parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") @@ -206,18 +210,22 @@ def parse_args(input_args=None): default=None, help="The name of the repository to keep in sync with the local `output_dir`.", ) - parser.add_argument('-p', - '--plugin', - type=str, - default='torch_ddp', - choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'], - help="plugin to use") + parser.add_argument( + "-p", + "--plugin", + type=str, + default="torch_ddp", + choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero"], + help="plugin to use", + ) parser.add_argument( "--logging_dir", type=str, default="logs", - help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" - " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."), + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), ) parser.add_argument( "--mixed_precision", @@ -227,7 +235,8 @@ def parse_args(input_args=None): help=( "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" - " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."), + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 
+ ), ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") @@ -293,12 +302,14 @@ def __init__( else: self.class_data_root = None - self.image_transforms = transforms.Compose([ - transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ]) + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) def __len__(self): return self._length @@ -392,9 +403,9 @@ def main(args): pipeline.to(get_current_device()) for example in tqdm( - sample_dataloader, - desc="Generating class images", - disable=not local_rank == 0, + sample_dataloader, + desc="Generating class images", + disable=not local_rank == 0, ): images = pipeline(example["prompt"]).images @@ -461,19 +472,17 @@ def main(args): if args.externel_unet_path is None: logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0]) - unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, - subfolder="unet", - revision=args.revision, - low_cpu_mem_usage=False) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False + ) else: logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0]) - unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path, - revision=args.revision, - low_cpu_mem_usage=False) - unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, - subfolder="unet", - revision=args.revision, - low_cpu_mem_usage=False) + unet = UNet2DConditionModel.from_pretrained( + args.externel_unet_path, revision=args.revision, low_cpu_mem_usage=False + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False + ) unet.requires_grad_(False) # Set correct lora layers @@ -492,7 +501,7 @@ def main(args): lora_attn_procs[name] = LoRACrossAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) unet.set_attn_processor(lora_attn_procs) - lora_layers = AttnProcsLayers(unet.attn_processors) + AttnProcsLayers(unet.attn_processors) vae.requires_grad_(False) text_encoder.requires_grad_(False) @@ -506,22 +515,21 @@ def main(args): # Use Booster API to use Gemini/Zero with ColossalAI booster_kwargs = {} - if args.plugin == 'torch_ddp_fp16': - booster_kwargs['mixed_precision'] = 'fp16' - if args.plugin.startswith('torch_ddp'): + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() - elif args.plugin == 'gemini': + elif args.plugin == "gemini": plugin = GeminiPlugin(strict_ddp_mode=True, initial_scale=2**5) - elif args.plugin == 'low_level_zero': + elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) booster = Booster(plugin=plugin, **booster_kwargs) # config optimizer for colossalai zero - optimizer = HybridAdam(unet.parameters(), - lr=args.learning_rate, - initial_scale=2**5, - clipping_norm=args.max_grad_norm) + optimizer = HybridAdam( + unet.parameters(), lr=args.learning_rate, 
initial_scale=2**5, clipping_norm=args.max_grad_norm + ) # load noise_scheduler noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") @@ -552,9 +560,7 @@ def collate_fn(examples): pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() input_ids = tokenizer.pad( - { - "input_ids": input_ids - }, + {"input_ids": input_ids}, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt", @@ -566,11 +572,9 @@ def collate_fn(examples): } return batch - train_dataloader = torch.utils.data.DataLoader(train_dataset, - batch_size=args.train_batch_size, - shuffle=True, - collate_fn=collate_fn, - num_workers=1) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1 + ) # Scheduler and math around the number of training steps. overrode_max_train_steps = False @@ -689,7 +693,7 @@ def collate_fn(examples): logs = { "loss": loss.detach().item(), "lr": optimizer.param_groups[0]["lr"], - } # lr_scheduler.get_last_lr()[0]} + } # lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) if global_step % args.save_steps == 0: diff --git a/examples/images/dreambooth/train_dreambooth_inpaint.py b/examples/images/dreambooth/train_dreambooth_inpaint.py index 774cd4c458e9..32f1b4959879 100644 --- a/examples/images/dreambooth/train_dreambooth_inpaint.py +++ b/examples/images/dreambooth/train_dreambooth_inpaint.py @@ -126,8 +126,10 @@ def parse_args(): "--num_class_images", type=int, default=100, - help=("Minimal class images for prior preservation loss. If not have enough images, additional images will be" - " sampled with class_prompt."), + help=( + "Minimal class images for prior preservation loss. If not have enough images, additional images will be" + " sampled with class_prompt." + ), ) parser.add_argument( "--output_dir", @@ -140,17 +142,18 @@ def parse_args(): "--resolution", type=int, default=512, - help=("The resolution for input images, all the images in the train/validation dataset will be resized to this" - " resolution"), + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" ) - parser.add_argument("--center_crop", - action="store_true", - help="Whether to center crop images before resizing to resolution") parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder") - parser.add_argument("--train_batch_size", - type=int, - default=4, - help="Batch size (per device) for the training dataloader.") + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.") parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( @@ -186,16 +189,17 @@ def parse_args(): "--lr_scheduler", type=str, default="constant", - help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' - ' "constant", "constant_with_warmup"]'), + help=( + 'The scheduler type to use. 
Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." ) - parser.add_argument("--lr_warmup_steps", - type=int, - default=500, - help="Number of steps for the warmup in the lr scheduler.") - parser.add_argument("--use_8bit_adam", - action="store_true", - help="Whether or not to use 8-bit Adam from bitsandbytes.") parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") @@ -213,17 +217,21 @@ def parse_args(): "--logging_dir", type=str, default="logs", - help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" - " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."), + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), ) parser.add_argument( "--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], - help=("Whether to use mixed precision. Choose" - "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." - "and an Nvidia Ampere GPU."), + help=( + "Whether to use mixed precision. Choose" + "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." + "and an Nvidia Ampere GPU." + ), ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") @@ -283,12 +291,14 @@ def __init__( else: self.class_data_root = None - self.image_transforms = transforms.Compose([ - transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ]) + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) def __len__(self): return self._length @@ -369,7 +379,8 @@ def main(): if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: raise ValueError( "Gradient accumulation is not supported when training the text encoder in distributed training. " - "Please set gradient_accumulation_steps to 1. This feature will be supported in the future.") + "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." 
+ ) if args.seed is not None: set_seed(args.seed) @@ -382,25 +393,25 @@ def main(): if cur_class_images < args.num_class_images: torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 - pipeline = StableDiffusionInpaintPipeline.from_pretrained(args.pretrained_model_name_or_path, - torch_dtype=torch_dtype, - safety_checker=None) + pipeline = StableDiffusionInpaintPipeline.from_pretrained( + args.pretrained_model_name_or_path, torch_dtype=torch_dtype, safety_checker=None + ) pipeline.set_progress_bar_config(disable=True) num_new_images = args.num_class_images - cur_class_images logger.info(f"Number of class images to sample: {num_new_images}.") sample_dataset = PromptDataset(args.class_prompt, num_new_images) - sample_dataloader = torch.utils.data.DataLoader(sample_dataset, - batch_size=args.sample_batch_size, - num_workers=1) + sample_dataloader = torch.utils.data.DataLoader( + sample_dataset, batch_size=args.sample_batch_size, num_workers=1 + ) sample_dataloader = accelerator.prepare(sample_dataloader) pipeline.to(accelerator.device) transform_to_pil = transforms.ToPILImage() - for example in tqdm(sample_dataloader, - desc="Generating class images", - disable=not accelerator.is_local_main_process): + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): bsz = len(example["prompt"]) fake_images = torch.rand((3, args.resolution, args.resolution)) transform_to_pil = transforms.ToPILImage() @@ -457,8 +468,9 @@ def main(): text_encoder.gradient_checkpointing_enable() if args.scale_lr: - args.learning_rate = (args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * - accelerator.num_processes) + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs if args.use_8bit_adam: @@ -471,8 +483,9 @@ def main(): else: optimizer_class = torch.optim.AdamW - params_to_optimize = (itertools.chain(unet.parameters(), text_encoder.parameters()) - if args.train_text_encoder else unet.parameters()) + params_to_optimize = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters() + ) optimizer = optimizer_class( params_to_optimize, lr=args.learning_rate, @@ -494,10 +507,12 @@ def main(): ) def collate_fn(examples): - image_transforms = transforms.Compose([ - transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), - ]) + image_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + ] + ) input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] @@ -545,10 +560,9 @@ def collate_fn(examples): batch = {"input_ids": input_ids, "pixel_values": pixel_values, "masks": masks, "masked_images": masked_images} return batch - train_dataloader = torch.utils.data.DataLoader(train_dataset, - batch_size=args.train_batch_size, - shuffle=True, - collate_fn=collate_fn) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn + ) # 
Scheduler and math around the number of training steps. overrode_max_train_steps = False @@ -566,10 +580,12 @@ def collate_fn(examples): if args.train_text_encoder: unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, lr_scheduler) + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) else: - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, - lr_scheduler) + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) weight_dtype = torch.float32 if args.mixed_precision == "fp16": @@ -622,16 +638,19 @@ def collate_fn(examples): latents = latents * 0.18215 # Convert masked images to latent space - masked_latents = vae.encode(batch["masked_images"].reshape( - batch["pixel_values"].shape).to(dtype=weight_dtype)).latent_dist.sample() + masked_latents = vae.encode( + batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype) + ).latent_dist.sample() masked_latents = masked_latents * 0.18215 masks = batch["masks"] # resize the mask to latents shape as we concatenate the mask to the latents - mask = torch.stack([ - torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8)) - for mask in masks - ]) + mask = torch.stack( + [ + torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8)) + for mask in masks + ] + ) mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8) # Sample noise that we'll add to the latents @@ -680,8 +699,11 @@ def collate_fn(examples): accelerator.backward(loss) if accelerator.sync_gradients: - params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()) - if args.train_text_encoder else unet.parameters()) + params_to_clip = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) + if args.train_text_encoder + else unet.parameters() + ) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() lr_scheduler.step() diff --git a/examples/images/resnet/eval.py b/examples/images/resnet/eval.py index 657708ec3ff2..526e41a2850f 100644 --- a/examples/images/resnet/eval.py +++ b/examples/images/resnet/eval.py @@ -1,7 +1,6 @@ import argparse import torch -import torch.nn as nn import torchvision import torchvision.transforms as transforms @@ -9,15 +8,15 @@ # Parse Arguments # ============================== parser = argparse.ArgumentParser() -parser.add_argument('-e', '--epoch', type=int, default=80, help="resume from the epoch's checkpoint") -parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory") +parser.add_argument("-e", "--epoch", type=int, default=80, help="resume from the epoch's checkpoint") +parser.add_argument("-c", "--checkpoint", type=str, default="./checkpoint", help="checkpoint directory") args = parser.parse_args() # ============================== # Prepare Test Dataset # ============================== # CIFAR-10 dataset -test_dataset = torchvision.datasets.CIFAR10(root='./data/', train=False, transform=transforms.ToTensor()) +test_dataset = torchvision.datasets.CIFAR10(root="./data/", train=False, transform=transforms.ToTensor()) # Data loader test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=128, shuffle=False) @@ -26,7 +25,7 @@ # Load Model # ============================== model = torchvision.models.resnet18(num_classes=10).cuda() 
-state_dict = torch.load(f'{args.checkpoint}/model_{args.epoch}.pth') +state_dict = torch.load(f"{args.checkpoint}/model_{args.epoch}.pth") model.load_state_dict(state_dict) # ============================== @@ -45,4 +44,4 @@ total += labels.size(0) correct += (predicted == labels).sum().item() - print('Accuracy of the model on the test images: {} %'.format(100 * correct / total)) + print("Accuracy of the model on the test images: {} %".format(100 * correct / total)) diff --git a/examples/images/resnet/requirements.txt b/examples/images/resnet/requirements.txt index 3c7da7743702..46b7da7d4870 100644 --- a/examples/images/resnet/requirements.txt +++ b/examples/images/resnet/requirements.txt @@ -2,4 +2,4 @@ colossalai torch torchvision tqdm -pytest \ No newline at end of file +pytest diff --git a/examples/images/resnet/train.py b/examples/images/resnet/train.py index fa300395c9f3..13df516d4189 100644 --- a/examples/images/resnet/train.py +++ b/examples/images/resnet/train.py @@ -30,23 +30,19 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPluginBase): # transform transform_train = transforms.Compose( - [transforms.Pad(4), - transforms.RandomHorizontalFlip(), - transforms.RandomCrop(32), - transforms.ToTensor()]) + [transforms.Pad(4), transforms.RandomHorizontalFlip(), transforms.RandomCrop(32), transforms.ToTensor()] + ) transform_test = transforms.ToTensor() # CIFAR-10 dataset - data_path = os.environ.get('DATA', './data') + data_path = os.environ.get("DATA", "./data") with coordinator.priority_execution(): - train_dataset = torchvision.datasets.CIFAR10(root=data_path, - train=True, - transform=transform_train, - download=True) - test_dataset = torchvision.datasets.CIFAR10(root=data_path, - train=False, - transform=transform_test, - download=True) + train_dataset = torchvision.datasets.CIFAR10( + root=data_path, train=True, transform=transform_train, download=True + ) + test_dataset = torchvision.datasets.CIFAR10( + root=data_path, train=False, transform=transform_test, download=True + ) # Data loader train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True) @@ -70,14 +66,21 @@ def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoo dist.all_reduce(total) accuracy = correct.item() / total.item() if coordinator.is_master(): - print(f'Accuracy of the model on the test images: {accuracy * 100:.2f} %') + print(f"Accuracy of the model on the test images: {accuracy * 100:.2f} %") return accuracy -def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: nn.Module, train_dataloader: DataLoader, - booster: Booster, coordinator: DistCoordinator): +def train_epoch( + epoch: int, + model: nn.Module, + optimizer: Optimizer, + criterion: nn.Module, + train_dataloader: DataLoader, + booster: Booster, + coordinator: DistCoordinator, +): model.train() - with tqdm(train_dataloader, desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not coordinator.is_master()) as pbar: + with tqdm(train_dataloader, desc=f"Epoch [{epoch + 1}/{NUM_EPOCHS}]", disable=not coordinator.is_master()) as pbar: for images, labels in pbar: images = images.cuda() labels = labels.cuda() @@ -91,7 +94,7 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: n optimizer.zero_grad() # Print log info - pbar.set_postfix({'loss': loss.item()}) + pbar.set_postfix({"loss": loss.item()}) def main(): @@ -100,19 +103,20 @@ def main(): # ============================== parser = 
argparse.ArgumentParser() # FIXME(ver217): gemini is not supported resnet now - parser.add_argument('-p', - '--plugin', - type=str, - default='torch_ddp', - choices=['torch_ddp', 'torch_ddp_fp16', 'low_level_zero', 'gemini'], - help="plugin to use") - parser.add_argument('-r', '--resume', type=int, default=-1, help="resume from the epoch's checkpoint") - parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory") - parser.add_argument('-i', '--interval', type=int, default=5, help="interval of saving checkpoint") - parser.add_argument('--target_acc', - type=float, - default=None, - help="target accuracy. Raise exception if not reached") + parser.add_argument( + "-p", + "--plugin", + type=str, + default="torch_ddp", + choices=["torch_ddp", "torch_ddp_fp16", "low_level_zero", "gemini"], + help="plugin to use", + ) + parser.add_argument("-r", "--resume", type=int, default=-1, help="resume from the epoch's checkpoint") + parser.add_argument("-c", "--checkpoint", type=str, default="./checkpoint", help="checkpoint directory") + parser.add_argument("-i", "--interval", type=int, default=5, help="interval of saving checkpoint") + parser.add_argument( + "--target_acc", type=float, default=None, help="target accuracy. Raise exception if not reached" + ) args = parser.parse_args() # ============================== @@ -136,13 +140,13 @@ def main(): # Instantiate Plugin and Booster # ============================== booster_kwargs = {} - if args.plugin == 'torch_ddp_fp16': - booster_kwargs['mixed_precision'] = 'fp16' - if args.plugin.startswith('torch_ddp'): + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() - elif args.plugin == 'gemini': + elif args.plugin == "gemini": plugin = GeminiPlugin(initial_scale=2**5) - elif args.plugin == 'low_level_zero': + elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) booster = Booster(plugin=plugin, **booster_kwargs) @@ -168,18 +172,17 @@ def main(): # ============================== # Boost with ColossalAI # ============================== - model, optimizer, criterion, _, lr_scheduler = booster.boost(model, - optimizer, - criterion=criterion, - lr_scheduler=lr_scheduler) + model, optimizer, criterion, _, lr_scheduler = booster.boost( + model, optimizer, criterion=criterion, lr_scheduler=lr_scheduler + ) # ============================== # Resume from checkpoint # ============================== if args.resume >= 0: - booster.load_model(model, f'{args.checkpoint}/model_{args.resume}.pth') - booster.load_optimizer(optimizer, f'{args.checkpoint}/optimizer_{args.resume}.pth') - booster.load_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{args.resume}.pth') + booster.load_model(model, f"{args.checkpoint}/model_{args.resume}.pth") + booster.load_optimizer(optimizer, f"{args.checkpoint}/optimizer_{args.resume}.pth") + booster.load_lr_scheduler(lr_scheduler, f"{args.checkpoint}/lr_scheduler_{args.resume}.pth") # ============================== # Train model @@ -191,14 +194,14 @@ def main(): # save checkpoint if args.interval > 0 and (epoch + 1) % args.interval == 0: - booster.save_model(model, f'{args.checkpoint}/model_{epoch + 1}.pth') - booster.save_optimizer(optimizer, f'{args.checkpoint}/optimizer_{epoch + 1}.pth') - booster.save_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{epoch + 1}.pth') + booster.save_model(model, f"{args.checkpoint}/model_{epoch + 1}.pth") + 
booster.save_optimizer(optimizer, f"{args.checkpoint}/optimizer_{epoch + 1}.pth") + booster.save_lr_scheduler(lr_scheduler, f"{args.checkpoint}/lr_scheduler_{epoch + 1}.pth") accuracy = evaluate(model, test_dataloader, coordinator) if args.target_acc is not None: - assert accuracy >= args.target_acc, f'Accuracy {accuracy} is lower than target accuracy {args.target_acc}' + assert accuracy >= args.target_acc, f"Accuracy {accuracy} is lower than target accuracy {args.target_acc}" -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/images/vit/args.py b/examples/images/vit/args.py index e6c52c4e97fd..7d54020f85c4 100644 --- a/examples/images/vit/args.py +++ b/examples/images/vit/args.py @@ -2,44 +2,47 @@ def parse_demo_args(): - parser = get_default_parser() - parser.add_argument("--model_name_or_path", - type=str, - default="google/vit-base-patch16-224", - help="Path to pretrained model or model identifier from huggingface.co/models.") - parser.add_argument("--output_path", - type=str, - default="./output_model", - help="The path of your saved model after finetuning.") + parser.add_argument( + "--model_name_or_path", + type=str, + default="google/vit-base-patch16-224", + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--output_path", type=str, default="./output_model", help="The path of your saved model after finetuning." + ) parser.add_argument( "--plugin", type=str, default="gemini", - help= - "Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero', 'hybrid_parallel'." + help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero', 'hybrid_parallel'.", ) parser.add_argument("--num_epoch", type=int, default=3, help="Number of epochs.") - parser.add_argument("--batch_size", - type=int, - default=32, - help="Batch size (per dp group) for the training dataloader.") - parser.add_argument("--tp_size", - type=int, - default=1, - help="The size along tensor parallel dimension, only be used when enabling hybrid parallel.") - parser.add_argument("--pp_size", - type=int, - default=1, - help="The size along pipeline parallel dimension, only be used when enabling hybrid parallel.") - parser.add_argument("--learning_rate", - type=float, - default=3e-4, - help="Initial learning rate (after the potential warmup period) to use.") - parser.add_argument("--warmup_ratio", - type=float, - default=0.3, - help="Ratio of warmup steps against total training steps.") + parser.add_argument( + "--batch_size", type=int, default=32, help="Batch size (per dp group) for the training dataloader." + ) + parser.add_argument( + "--tp_size", + type=int, + default=1, + help="The size along tensor parallel dimension, only be used when enabling hybrid parallel.", + ) + parser.add_argument( + "--pp_size", + type=int, + default=1, + help="The size along pipeline parallel dimension, only be used when enabling hybrid parallel.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=3e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--warmup_ratio", type=float, default=0.3, help="Ratio of warmup steps against total training steps." 
+ ) parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay to use.") parser.add_argument("--grad_checkpoint", type=bool, default=True, help="Whether to use gradient checkpointing.") parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") @@ -49,29 +52,30 @@ def parse_demo_args(): def parse_benchmark_args(): - parser = get_default_parser() - parser.add_argument("--model_name_or_path", - type=str, - default="google/vit-base-patch16-224", - help="Path to a pretrained model or model identifier from huggingface.co/models.") + parser.add_argument( + "--model_name_or_path", + type=str, + default="google/vit-base-patch16-224", + help="Path to a pretrained model or model identifier from huggingface.co/models.", + ) parser.add_argument( "--plugin", type=str, default="gemini", - help= - "Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero', 'hybrid_parallel'." + help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero', 'hybrid_parallel'.", + ) + parser.add_argument( + "--batch_size", type=int, default=8, help="Batch size (per dp group) for the training dataloader." ) - parser.add_argument("--batch_size", - type=int, - default=8, - help="Batch size (per dp group) for the training dataloader.") parser.add_argument("--num_labels", type=int, default=10, help="Number of labels for classification.") - parser.add_argument("--learning_rate", - type=float, - default=5e-5, - help="Initial learning rate (after the potential warmup period) to use.") + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") parser.add_argument("--grad_checkpoint", type=bool, default=True, help="Whether to use gradient checkpointing.") parser.add_argument("--max_train_steps", type=int, default=20, help="Total number of training steps to perform.") diff --git a/examples/images/vit/data.py b/examples/images/vit/data.py index 77a8ad525056..5361fe9a3bad 100644 --- a/examples/images/vit/data.py +++ b/examples/images/vit/data.py @@ -4,13 +4,11 @@ class BeansDataset(Dataset): - - def __init__(self, image_processor, tp_size=1, split='train'): - + def __init__(self, image_processor, tp_size=1, split="train"): super().__init__() self.image_processor = image_processor - self.ds = load_dataset('beans')[split] - self.label_names = self.ds.features['labels'].names + self.ds = load_dataset("beans")[split] + self.label_names = self.ds.features["labels"].names while len(self.label_names) % tp_size != 0: # ensure that the number of labels is multiple of tp_size self.label_names.append(f"pad_label_{len(self.label_names)}") @@ -26,13 +24,13 @@ def __getitem__(self, idx): return self.inputs[idx] def process_example(self, example): - input = self.image_processor(example['image'], return_tensors='pt') - input['labels'] = example['labels'] + input = self.image_processor(example["image"], return_tensors="pt") + input["labels"] = example["labels"] return input def beans_collator(batch): return { - 'pixel_values': torch.cat([data['pixel_values'] for data in batch], dim=0), - 'labels': torch.tensor([data['labels'] for data in batch], dtype=torch.int64) + "pixel_values": torch.cat([data["pixel_values"] for data in batch], dim=0), + "labels": torch.tensor([data["labels"] for data in batch], dtype=torch.int64), } diff --git 
a/examples/images/vit/requirements.txt b/examples/images/vit/requirements.txt index edad87ca380f..69e41c61cd67 100644 --- a/examples/images/vit/requirements.txt +++ b/examples/images/vit/requirements.txt @@ -3,4 +3,4 @@ torch >= 1.8.1 numpy>=1.24.1 tqdm>=4.61.2 transformers>=4.20.0 -datasets \ No newline at end of file +datasets diff --git a/examples/images/vit/vit_benchmark.py b/examples/images/vit/vit_benchmark.py index d822fe23ecf0..b770bc9cfb95 100644 --- a/examples/images/vit/vit_benchmark.py +++ b/examples/images/vit/vit_benchmark.py @@ -25,18 +25,16 @@ def format_num(num: int, bytes=False): def get_data_batch(batch_size, num_labels, num_channels=3, height=224, width=224): - pixel_values = torch.randn(batch_size, - num_channels, - height, - width, - device=torch.cuda.current_device(), - dtype=torch.float) + pixel_values = torch.randn( + batch_size, num_channels, height, width, device=torch.cuda.current_device(), dtype=torch.float + ) labels = torch.randint(0, num_labels, (batch_size,), device=torch.cuda.current_device(), dtype=torch.int64) return dict(pixel_values=pixel_values, labels=labels) def colo_memory_cap(size_in_GB): from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device + cuda_capacity = colo_device_memory_capacity(get_current_device()) if size_in_GB * (1024**3) < cuda_capacity: colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity) @@ -44,7 +42,6 @@ def colo_memory_cap(size_in_GB): def main(): - args = parse_benchmark_args() # Launch ColossalAI @@ -75,22 +72,24 @@ def main(): # Set plugin booster_kwargs = {} - if args.plugin == 'torch_ddp_fp16': - booster_kwargs['mixed_precision'] = 'fp16' - if args.plugin.startswith('torch_ddp'): + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() - elif args.plugin == 'gemini': + elif args.plugin == "gemini": plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5) - elif args.plugin == 'low_level_zero': + elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) - elif args.plugin == 'hybrid_parallel': - plugin = HybridParallelPlugin(tp_size=2, - pp_size=2, - num_microbatches=None, - microbatch_size=1, - enable_all_optimization=True, - precision='fp16', - initial_scale=1) + elif args.plugin == "hybrid_parallel": + plugin = HybridParallelPlugin( + tp_size=2, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_all_optimization=True, + precision="fp16", + initial_scale=1, + ) logger.info(f"Set plugin as {args.plugin}", ranks=[0]) # Set optimizer @@ -119,12 +118,9 @@ def criterion(outputs, inputs): if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None: # run pipeline forward backward batch = iter([batch]) - outputs = booster.execute_pipeline(batch, - model, - criterion, - optimizer, - return_loss=True, - return_outputs=True) + outputs = booster.execute_pipeline( + batch, model, criterion, optimizer, return_loss=True, return_outputs=True + ) else: outputs = model(**batch) loss = criterion(outputs, None) @@ -146,7 +142,8 @@ def criterion(outputs, inputs): f"plugin: {args.plugin}, " f"throughput: {throughput}, " f"maximum memory usage per gpu: {max_mem}.", - ranks=[0]) + ranks=[0], + ) torch.cuda.empty_cache() diff --git a/examples/images/vit/vit_train_demo.py b/examples/images/vit/vit_train_demo.py index 206d8694b8f5..81009b3707b6 100644 --- 
a/examples/images/vit/vit_train_demo.py +++ b/examples/images/vit/vit_train_demo.py @@ -25,19 +25,21 @@ def move_to_cuda(batch, device): return {k: v.to(device) for k, v in batch.items()} -def run_forward_backward(model: nn.Module, optimizer: Optimizer, criterion: Callable[[Any, Any], torch.Tensor], - data_iter: Iterator, booster: Booster): +def run_forward_backward( + model: nn.Module, + optimizer: Optimizer, + criterion: Callable[[Any, Any], torch.Tensor], + data_iter: Iterator, + booster: Booster, +): if optimizer is not None: optimizer.zero_grad() if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1: # run pipeline forward backward when enabling pp in hybrid parallel plugin - output_dict = booster.execute_pipeline(data_iter, - model, - criterion, - optimizer, - return_loss=True, - return_outputs=True) - loss, outputs = output_dict['loss'], output_dict['outputs'] + output_dict = booster.execute_pipeline( + data_iter, model, criterion, optimizer, return_loss=True, return_outputs=True + ) + loss, outputs = output_dict["loss"], output_dict["outputs"] else: batch = next(data_iter) batch = move_to_cuda(batch, torch.cuda.current_device()) @@ -49,9 +51,16 @@ def run_forward_backward(model: nn.Module, optimizer: Optimizer, criterion: Call return loss, outputs -def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: Callable[[Any, Any], torch.Tensor], - lr_scheduler: LRScheduler, dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator): - +def train_epoch( + epoch: int, + model: nn.Module, + optimizer: Optimizer, + criterion: Callable[[Any, Any], torch.Tensor], + lr_scheduler: LRScheduler, + dataloader: DataLoader, + booster: Booster, + coordinator: DistCoordinator, +): torch.cuda.synchronize() num_steps = len(dataloader) @@ -61,12 +70,11 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: C # when using pp, only the last stage of master pipeline (dp_rank and tp_rank are both zero) shows pbar tp_rank = dist.get_rank(booster.plugin.tp_group) dp_rank = dist.get_rank(booster.plugin.dp_group) - enable_pbar = tp_rank == 0 and dp_rank == 0 \ - and booster.plugin.stage_manager.is_last_stage() + enable_pbar = tp_rank == 0 and dp_rank == 0 and booster.plugin.stage_manager.is_last_stage() model.train() - with tqdm(range(num_steps), desc=f'Epoch [{epoch + 1}]', disable=not enable_pbar) as pbar: + with tqdm(range(num_steps), desc=f"Epoch [{epoch + 1}]", disable=not enable_pbar) as pbar: for _ in pbar: loss, _ = run_forward_backward(model, optimizer, criterion, data_iter, booster) optimizer.step() @@ -74,13 +82,18 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: C # Print batch loss if enable_pbar: - pbar.set_postfix({'loss': loss.item()}) + pbar.set_postfix({"loss": loss.item()}) @torch.no_grad() -def evaluate_model(epoch: int, model: nn.Module, criterion: Callable[[Any, Any], torch.Tensor], - eval_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator): - +def evaluate_model( + epoch: int, + model: nn.Module, + criterion: Callable[[Any, Any], torch.Tensor], + eval_dataloader: DataLoader, + booster: Booster, + coordinator: DistCoordinator, +): torch.cuda.synchronize() model.eval() accum_loss = torch.zeros(1, device=torch.cuda.current_device()) @@ -99,13 +112,13 @@ def evaluate_model(epoch: int, model: nn.Module, criterion: Callable[[Any, Any], to_accum = to_accum and booster.plugin.stage_manager.is_last_stage() if to_accum: - accum_loss += (loss / 
len(eval_dataloader)) + accum_loss += loss / len(eval_dataloader) logits = outputs["logits"] preds = torch.argmax(logits, dim=1) labels = batch["labels"] total_num += batch["labels"].shape[0] - accum_correct += (torch.sum(preds == labels)) + accum_correct += torch.sum(preds == labels) dist.all_reduce(accum_loss) dist.all_reduce(total_num) @@ -113,13 +126,14 @@ def evaluate_model(epoch: int, model: nn.Module, criterion: Callable[[Any, Any], avg_loss = "{:.4f}".format(accum_loss.item()) accuracy = "{:.4f}".format(accum_correct.item() / total_num.item()) if coordinator.is_master(): - print(f"Evaluation result for epoch {epoch + 1}: \ + print( + f"Evaluation result for epoch {epoch + 1}: \ average_loss={avg_loss}, \ - accuracy={accuracy}.") + accuracy={accuracy}." + ) def main(): - args = parse_demo_args() # Launch ColossalAI @@ -136,14 +150,14 @@ def main(): transformers.utils.logging.set_verbosity_error() # Reset tp_size and pp_size to 1 if not using hybrid parallel. - if args.plugin != 'hybrid_parallel': + if args.plugin != "hybrid_parallel": args.tp_size = 1 args.pp_size = 1 # Prepare Dataset image_processor = ViTImageProcessor.from_pretrained(args.model_name_or_path) - train_dataset = BeansDataset(image_processor, args.tp_size, split='train') - eval_dataset = BeansDataset(image_processor, args.tp_size, split='validation') + train_dataset = BeansDataset(image_processor, args.tp_size, split="train") + eval_dataset = BeansDataset(image_processor, args.tp_size, split="validation") num_labels = train_dataset.num_labels # Load pretrained ViT model @@ -151,9 +165,9 @@ def main(): config.num_labels = num_labels config.id2label = {str(i): c for i, c in enumerate(train_dataset.label_names)} config.label2id = {c: str(i) for i, c in enumerate(train_dataset.label_names)} - model = ViTForImageClassification.from_pretrained(args.model_name_or_path, - config=config, - ignore_mismatched_sizes=True) + model = ViTForImageClassification.from_pretrained( + args.model_name_or_path, config=config, ignore_mismatched_sizes=True + ) logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0]) # Enable gradient checkpointing @@ -162,37 +176,35 @@ def main(): # Set plugin booster_kwargs = {} - if args.plugin == 'torch_ddp_fp16': - booster_kwargs['mixed_precision'] = 'fp16' - if args.plugin.startswith('torch_ddp'): + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() - elif args.plugin == 'gemini': + elif args.plugin == "gemini": plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5) - elif args.plugin == 'low_level_zero': + elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) - elif args.plugin == 'hybrid_parallel': - plugin = HybridParallelPlugin(tp_size=args.tp_size, - pp_size=args.pp_size, - num_microbatches=None, - microbatch_size=1, - enable_all_optimization=True, - precision='fp16', - initial_scale=1) + elif args.plugin == "hybrid_parallel": + plugin = HybridParallelPlugin( + tp_size=args.tp_size, + pp_size=args.pp_size, + num_microbatches=None, + microbatch_size=1, + enable_all_optimization=True, + precision="fp16", + initial_scale=1, + ) else: raise ValueError(f"Plugin with name {args.plugin} is not supported!") logger.info(f"Set plugin as {args.plugin}", ranks=[0]) # Prepare dataloader - train_dataloader = plugin.prepare_dataloader(train_dataset, - batch_size=args.batch_size, - shuffle=True, - drop_last=True, - 
collate_fn=beans_collator) - eval_dataloader = plugin.prepare_dataloader(eval_dataset, - batch_size=args.batch_size, - shuffle=True, - drop_last=True, - collate_fn=beans_collator) + train_dataloader = plugin.prepare_dataloader( + train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=beans_collator + ) + eval_dataloader = plugin.prepare_dataloader( + eval_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=beans_collator + ) # Set optimizer optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay) @@ -204,17 +216,15 @@ def criterion(outputs, inputs): # Set lr scheduler total_steps = len(train_dataloader) * args.num_epoch num_warmup_steps = int(args.warmup_ratio * total_steps) - lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, - total_steps=(len(train_dataloader) * args.num_epoch), - warmup_steps=num_warmup_steps) + lr_scheduler = CosineAnnealingWarmupLR( + optimizer=optimizer, total_steps=(len(train_dataloader) * args.num_epoch), warmup_steps=num_warmup_steps + ) # Set booster booster = Booster(plugin=plugin, **booster_kwargs) - model, optimizer, _criterion, train_dataloader, lr_scheduler = booster.boost(model=model, - optimizer=optimizer, - criterion=criterion, - dataloader=train_dataloader, - lr_scheduler=lr_scheduler) + model, optimizer, _criterion, train_dataloader, lr_scheduler = booster.boost( + model=model, optimizer=optimizer, criterion=criterion, dataloader=train_dataloader, lr_scheduler=lr_scheduler + ) # Finetuning logger.info(f"Start finetuning", ranks=[0]) diff --git a/examples/inference/bench_bloom.py b/examples/inference/bench_bloom.py index 67ff13bb5f5e..738f43dc0619 100644 --- a/examples/inference/bench_bloom.py +++ b/examples/inference/bench_bloom.py @@ -11,7 +11,7 @@ from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn -os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" def print_perf_stats(latency_set, config, bs, warmup=3): @@ -25,7 +25,7 @@ def print_perf_stats(latency_set, config, bs, warmup=3): avg = sum(latency_set) / count num_layers = getattr(config, "num_layers", config.num_hidden_layers) num_parameters = num_layers * config.hidden_size * config.hidden_size * 12 - num_bytes = 2 # float16 + num_bytes = 2 # float16 print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9)) @@ -53,7 +53,7 @@ def bench_bloom(args): generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) input_tokens = { "input_ids": torch.randint(10, 1000, (max_batch_size, max_input_len)), - "attention_mask": torch.ones((max_batch_size, max_input_len)) + "attention_mask": torch.ones((max_batch_size, max_input_len)), } for t in input_tokens: if torch.is_tensor(input_tokens[t]): @@ -77,7 +77,7 @@ def bench_bloom(args): def check_bloom(rank, world_size, port, args): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") bench_bloom(args) @@ -89,11 +89,11 @@ def test_bloom(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('-p', '--path', type=str, help='Model path', required=True) - parser.add_argument('-tp', 
'--tp_size', type=int, default=1, help='Tensor parallel size') - parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size') - parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length') - parser.add_argument('--output_len', type=int, default=128, help='Maximum output length') + parser.add_argument("-p", "--path", type=str, help="Model path", required=True) + parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size") + parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size") + parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length") + parser.add_argument("--output_len", type=int, default=128, help="Maximum output length") args = parser.parse_args() diff --git a/examples/inference/bench_llama.py b/examples/inference/bench_llama.py index d2016a4587e6..6e49fa80c812 100644 --- a/examples/inference/bench_llama.py +++ b/examples/inference/bench_llama.py @@ -12,7 +12,7 @@ from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn -os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" def init_to_get_rotary(self, base=10000): @@ -28,8 +28,9 @@ def init_to_get_rotary(self, base=10000): else: max_seq_len = 2048 * rope_scaling_factor base = float(base) - inv_freq = 1.0 / (base**(torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / - self.config.head_dim_)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_) + ) t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor freqs = torch.outer(t, inv_freq) @@ -75,8 +76,8 @@ def run_llama_test(args): generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) input_tokens = { - "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device='cuda'), - "attention_mask": torch.ones((max_batch_size, max_input_len), device='cuda') + "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"), + "attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"), } iters = 10 @@ -105,7 +106,7 @@ def run_llama_test(args): def check_llama(rank, world_size, port, args): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_llama_test(args) @@ -117,11 +118,11 @@ def test_llama(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('-p', '--path', type=str, help='Model path', required=True) - parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size') - parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size') - parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length') - parser.add_argument('--output_len', type=int, default=128, help='Maximum output length') + parser.add_argument("-p", "--path", type=str, help="Model path", required=True) + parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size") + parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size") + parser.add_argument("--input_len", 
type=int, default=1024, help="Maximum input length") + parser.add_argument("--output_len", type=int, default=128, help="Maximum output length") args = parser.parse_args() diff --git a/examples/language/bert/benchmark.py b/examples/language/bert/benchmark.py index ae8b2269a534..10bd367fda5b 100644 --- a/examples/language/bert/benchmark.py +++ b/examples/language/bert/benchmark.py @@ -32,9 +32,7 @@ class RandintDataset(Dataset): - def __init__(self, dataset_length: int, sequence_length: int, vocab_size: int, n_class: int): - self._sequence_length = sequence_length self._vocab_size = vocab_size self._n_class = n_class @@ -42,10 +40,13 @@ def __init__(self, dataset_length: int, sequence_length: int, vocab_size: int, n self._datas = torch.randint( low=0, high=self._vocab_size, - size=(self._dataset_length, self._sequence_length,), + size=( + self._dataset_length, + self._sequence_length, + ), dtype=torch.long, ) - self._labels = torch.randint(low=0, high=self._n_class, size=(self._dataset_length, 1), dtype=torch.long) + self._labels = torch.randint(low=0, high=self._n_class, size=(self._dataset_length, 1), dtype=torch.long) def __len__(self): return self._dataset_length @@ -59,13 +60,15 @@ def main(): # Parse Arguments # ============================== parser = argparse.ArgumentParser() - parser.add_argument('-t', '--task', default='mrpc', help="GLUE task to run") - parser.add_argument('-p', - '--plugin', - type=str, - default='torch_ddp', - choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'], - help="plugin to use") + parser.add_argument("-t", "--task", default="mrpc", help="GLUE task to run") + parser.add_argument( + "-p", + "--plugin", + type=str, + default="torch_ddp", + choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero"], + help="plugin to use", + ) parser.add_argument( "--model_type", type=str, @@ -88,13 +91,13 @@ def main(): # Instantiate Plugin and Booster # ============================== booster_kwargs = {} - if args.plugin == 'torch_ddp_fp16': - booster_kwargs['mixed_precision'] = 'fp16' - if args.plugin.startswith('torch_ddp'): + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() - elif args.plugin == 'gemini': - plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5) - elif args.plugin == 'low_level_zero': + elif args.plugin == "gemini": + plugin = GeminiPlugin(placement_policy="cuda", strict_ddp_mode=True, initial_scale=2**5) + elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) booster = Booster(plugin=plugin, **booster_kwargs) @@ -103,10 +106,9 @@ def main(): # Prepare Dataloader # ============================== - train_dataset = RandintDataset(dataset_length=DATASET_LEN, - sequence_length=SEQ_LEN, - vocab_size=VOCAB_SIZE, - n_class=NUM_LABELS) + train_dataset = RandintDataset( + dataset_length=DATASET_LEN, sequence_length=SEQ_LEN, vocab_size=VOCAB_SIZE, n_class=NUM_LABELS + ) train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE) # ==================================== @@ -159,16 +161,12 @@ def main(): # Benchmark model # ============================== - results = benchmark(model, - booster, - optimizer, - lr_scheduler, - train_dataloader, - criterion=criterion, - epoch_num=NUM_EPOCHS) + results = benchmark( + model, booster, optimizer, lr_scheduler, train_dataloader, criterion=criterion, epoch_num=NUM_EPOCHS + ) coordinator.print_on_master(results) -if __name__ == 
'__main__': +if __name__ == "__main__": main() diff --git a/examples/language/bert/benchmark_utils.py b/examples/language/bert/benchmark_utils.py index 886017a41826..04d55cb2e7b6 100644 --- a/examples/language/bert/benchmark_utils.py +++ b/examples/language/bert/benchmark_utils.py @@ -112,8 +112,9 @@ def benchmark( start_time = time() for epoch in range(epoch_num): - with tqdm(dataloader, desc=f'Epoch [{epoch + 1}/{epoch_num}]', - disable=not DistCoordinator().is_master()) as pbar: + with tqdm( + dataloader, desc=f"Epoch [{epoch + 1}/{epoch_num}]", disable=not DistCoordinator().is_master() + ) as pbar: for data in pbar: inputs, labels = data[0].cuda(), data[1].cuda() outputs = model(inputs, labels=labels) @@ -137,7 +138,9 @@ def benchmark( } logger.info(fmt({f"Memory results (batch_size={batch_size})": memory[f"batch_size_{batch_size}"]})) - throughput[f"batch_size_{batch_size}"] = {"throughput:": "{:.1f}".format(all_sample * DistCoordinator().world_size / (end_time - start_time))} + throughput[f"batch_size_{batch_size}"] = { + "throughput:": "{:.1f}".format(all_sample * DistCoordinator().world_size / (end_time - start_time)) + } logger.info(fmt({f"Throughput results (batch_size={batch_size})": throughput[f"batch_size_{batch_size}"]})) results["throughput"] = throughput diff --git a/examples/language/bert/data.py b/examples/language/bert/data.py index 981cedcca8c2..ef51f938dc4f 100644 --- a/examples/language/bert/data.py +++ b/examples/language/bert/data.py @@ -5,7 +5,6 @@ class GLUEDataBuilder: - task_text_field_map = { "cola": ["sentence"], "sst2": ["sentence"], @@ -84,10 +83,9 @@ def prepare_data(self): AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) def train_dataloader(self): - return self.plugin.prepare_dataloader(self.dataset["train"], - batch_size=self.train_batch_size, - shuffle=True, - drop_last=True) + return self.plugin.prepare_dataloader( + self.dataset["train"], batch_size=self.train_batch_size, shuffle=True, drop_last=True + ) def val_dataloader(self): if len(self.eval_splits) == 1: @@ -108,7 +106,6 @@ def test_dataloader(self): ] def convert_to_features(self, example_batch): - # Either encode single sentence or sentence pairs if len(self.text_fields) > 1: texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]])) @@ -116,10 +113,9 @@ def convert_to_features(self, example_batch): texts_or_text_pairs = example_batch[self.text_fields[0]] # Tokenize the text/text pairs - features = self.tokenizer.batch_encode_plus(texts_or_text_pairs, - max_length=self.max_seq_length, - padding='max_length', - truncation=True) + features = self.tokenizer.batch_encode_plus( + texts_or_text_pairs, max_length=self.max_seq_length, padding="max_length", truncation=True + ) # Rename label to labels to make it easier to pass to model forward features["labels"] = example_batch["label"] diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index fb6e4332c2f9..563cfa58d5f6 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -1,5 +1,4 @@ import argparse -from contextlib import nullcontext from typing import Callable, List, Union import evaluate @@ -7,7 +6,7 @@ import torch.distributed as dist import torch.nn as nn from data import GLUEDataBuilder -from torch.optim import Adam, Optimizer +from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader from tqdm import tqdm @@ -22,7 +21,6 @@ from 
colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator -from colossalai.lazy import LazyInitContext from colossalai.nn.optimizer import HybridAdam from colossalai.utils import get_current_device @@ -109,7 +107,7 @@ def evaluate_subset(dataloader: DataLoader): results = metric.compute() dist.all_reduce(accum_loss.div_(len(dataloader))) if coordinator.is_master() and results is not None: - results['loss'] = accum_loss.item() / coordinator.world_size + results["loss"] = accum_loss.item() / coordinator.world_size return results @@ -120,13 +118,20 @@ def evaluate_subset(dataloader: DataLoader): final_results = {} for split, sub_loader in zip(eval_splits, test_dataloader): results = evaluate_subset(sub_loader) - final_results.update({f'{k}_{split}': v for k, v in results.items()}) + final_results.update({f"{k}_{split}": v for k, v in results.items()}) return final_results -def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: Callable, lr_scheduler: LRScheduler, - train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator): - +def train_epoch( + epoch: int, + model: nn.Module, + optimizer: Optimizer, + _criterion: Callable, + lr_scheduler: LRScheduler, + train_dataloader: DataLoader, + booster: Booster, + coordinator: DistCoordinator, +): use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage) @@ -135,20 +140,17 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: model.train() optimizer.zero_grad() train_dataloader_iter = iter(train_dataloader) - with tqdm(range(total_step), desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not print_flag) as pbar: + with tqdm(range(total_step), desc=f"Epoch [{epoch + 1}/{NUM_EPOCHS}]", disable=not print_flag) as pbar: # Forward pass for _ in pbar: if use_pipeline: - outputs = booster.execute_pipeline(train_dataloader_iter, - model, - _criterion, - optimizer, - return_loss=True, - return_outputs=True) + outputs = booster.execute_pipeline( + train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True + ) # Backward and optimize if is_pp_last_stage: - loss = outputs['loss'] - pbar.set_postfix({'loss': loss.item()}) + loss = outputs["loss"] + pbar.set_postfix({"loss": loss.item()}) else: data = next(train_dataloader_iter) data = move_to_cuda(data) @@ -156,7 +158,7 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: loss = _criterion(outputs, None) # Backward booster.backward(loss, optimizer) - pbar.set_postfix({'loss': loss.item()}) + pbar.set_postfix({"loss": loss.item()}) optimizer.step() optimizer.zero_grad() @@ -168,26 +170,28 @@ def main(): # Parse Arguments # ============================== parser = argparse.ArgumentParser() - parser.add_argument('-t', '--task', default='mrpc', help="GLUE task to run") - parser.add_argument('-p', - '--plugin', - type=str, - default='torch_ddp', - choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero', 'hybrid_parallel'], - help="plugin to use") + parser.add_argument("-t", "--task", default="mrpc", help="GLUE task to run") + parser.add_argument( + "-p", + "--plugin", + type=str, + default="torch_ddp", + choices=["torch_ddp", 
"torch_ddp_fp16", "gemini", "low_level_zero", "hybrid_parallel"], + help="plugin to use", + ) parser.add_argument( "--model_type", type=str, default="bert", help="bert or albert", ) - parser.add_argument('--target_f1', type=float, default=None, help="target f1 score. Raise exception if not reached") - parser.add_argument('--use_lazy_init', type=bool, default=False, help="for initiating lazy init context") + parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached") + parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context") args = parser.parse_args() - if args.model_type == 'bert': + if args.model_type == "bert": model_name = "bert-base-uncased" - elif args.model_type == 'albert': + elif args.model_type == "albert": model_name = "albert-xxlarge-v2" else: raise RuntimeError @@ -204,36 +208,35 @@ def main(): # Instantiate Plugin and Booster # ============================== booster_kwargs = {} - if args.plugin == 'torch_ddp_fp16': - booster_kwargs['mixed_precision'] = 'fp16' - if args.plugin.startswith('torch_ddp'): + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() - elif args.plugin == 'gemini': + elif args.plugin == "gemini": plugin = GeminiPlugin(initial_scale=2**5) - elif args.plugin == 'low_level_zero': + elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) - elif args.plugin == 'hybrid_parallel': - + elif args.plugin == "hybrid_parallel": # modify the param accordingly for finetuning test cases - plugin = HybridParallelPlugin(tp_size=1, - pp_size=2, - num_microbatches=None, - microbatch_size=1, - enable_all_optimization=True, - zero_stage=1, - precision='fp16', - initial_scale=1) + plugin = HybridParallelPlugin( + tp_size=1, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_all_optimization=True, + zero_stage=1, + precision="fp16", + initial_scale=1, + ) booster = Booster(plugin=plugin, **booster_kwargs) # ============================== # Prepare Dataloader # ============================== - data_builder = GLUEDataBuilder(model_name, - plugin, - args.task, - train_batch_size=BATCH_SIZE, - eval_batch_size=BATCH_SIZE) + data_builder = GLUEDataBuilder( + model_name, plugin, args.task, train_batch_size=BATCH_SIZE, eval_batch_size=BATCH_SIZE + ) train_dataloader = data_builder.train_dataloader() test_dataloader = data_builder.test_dataloader() @@ -283,10 +286,9 @@ def _criterion(outputs, inputs): # ============================== # Boost with ColossalAI # ============================== - model, optimizer, _criterion, _, lr_scheduler = booster.boost(model, - optimizer, - criterion=_criterion, - lr_scheduler=lr_scheduler) + model, optimizer, _criterion, _, lr_scheduler = booster.boost( + model, optimizer, criterion=_criterion, lr_scheduler=lr_scheduler + ) # ============================== # Train model @@ -294,14 +296,22 @@ def _criterion(outputs, inputs): for epoch in range(NUM_EPOCHS): train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator) - results = evaluate_model(model, _criterion, test_dataloader, data_builder.num_labels, args.task, - data_builder.eval_splits, booster, coordinator) + results = evaluate_model( + model, + _criterion, + test_dataloader, + data_builder.num_labels, + args.task, + data_builder.eval_splits, + booster, + coordinator, + ) if coordinator.is_master(): print(results) 
- if args.target_f1 is not None and 'f1' in results: - assert results['f1'] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}' + if args.target_f1 is not None and "f1" in results: + assert results["f1"] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}' -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/language/gpt/experiments/auto_offload/model_zoo.py b/examples/language/gpt/experiments/auto_offload/model_zoo.py index 35e44608f810..75968a0b1da9 100644 --- a/examples/language/gpt/experiments/auto_offload/model_zoo.py +++ b/examples/language/gpt/experiments/auto_offload/model_zoo.py @@ -2,22 +2,20 @@ import torch.nn as nn from transformers import GPT2Config, GPT2LMHeadModel -class GPTLMModel(nn.Module): - def __init__(self, - hidden_size=768, - num_layers=12, - num_attention_heads=12, - max_seq_len=1024, - vocab_size=50257): +class GPTLMModel(nn.Module): + def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257): super().__init__() self.model = GPT2LMHeadModel( - GPT2Config(n_embd=hidden_size, - n_layer=num_layers, - n_head=num_attention_heads, - n_positions=max_seq_len, - n_ctx=max_seq_len, - vocab_size=vocab_size)) + GPT2Config( + n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size, + ) + ) def forward(self, input_ids, attention_mask): # Only return lm_logits @@ -25,7 +23,6 @@ def forward(self, input_ids, attention_mask): class GPTLMLoss(nn.Module): - def __init__(self): super().__init__() self.loss_fn = nn.CrossEntropyLoss() @@ -36,6 +33,7 @@ def forward(self, logits, labels): # Flatten the tokens return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + def get_gpt2_components(model_type: str, batch_size: int): vocab_size = 1024 seq_len = 8 @@ -62,4 +60,4 @@ def gpt2_data_gen(device="cuda"): kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) return kwargs - return gpt2_model_builder, gpt2_data_gen \ No newline at end of file + return gpt2_model_builder, gpt2_data_gen diff --git a/examples/language/gpt/experiments/auto_offload/requirements.txt b/examples/language/gpt/experiments/auto_offload/requirements.txt index 3ebde8d460aa..137a69e80498 100644 --- a/examples/language/gpt/experiments/auto_offload/requirements.txt +++ b/examples/language/gpt/experiments/auto_offload/requirements.txt @@ -1,2 +1,2 @@ colossalai >= 0.1.12 -torch >= 1.8.1 \ No newline at end of file +torch >= 1.8.1 diff --git a/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py index 89415c23f93c..521527da51e0 100644 --- a/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py +++ b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py @@ -3,7 +3,6 @@ import pytest import torch -from model_zoo import GPTLMLoss, get_gpt2_components from torch.utils._pytree import tree_map import colossalai @@ -14,18 +13,19 @@ from colossalai.nn.optimizer import HybridAdam from colossalai.testing import spawn from colossalai.utils import get_current_device +from model_zoo import GPTLMLoss, get_gpt2_components def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('--model_type', type=str, default="gpt2_medium") - parser.add_argument('--batch_size', type=int, default=64) - parser.add_argument('--solver_type', type=str, 
default='asyn') - parser.add_argument('--memory_budget', type=float, default=16) + parser.add_argument("--model_type", type=str, default="gpt2_medium") + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--solver_type", type=str, default="asyn") + parser.add_argument("--memory_budget", type=float, default=16) return parser.parse_args() -@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed') +@pytest.mark.skipif(NOT_NVML, reason="pynvml is not installed") def train_gpt(args): memory_budget = args.memory_budget * 1024 * 1024 * 1024 solver_type = args.solver_type @@ -34,10 +34,15 @@ def train_gpt(args): # build model model_builder, data_gen = get_gpt2_components(model_type=model_type, batch_size=batch_size) - label = torch.randint(low=0, high=128, size=( - 64, - 8, - ), device=get_current_device()) + label = torch.randint( + low=0, + high=128, + size=( + 64, + 8, + ), + device=get_current_device(), + ) criterion = GPTLMLoss() start_time = time.time() @@ -80,18 +85,20 @@ def train_gpt(args): exec_time = sum(sorted(time_list)[:5]) / 5 runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2 runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2 - print(f'solver_type: {solver_type} | model_type: {model_type}') - print(f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' - f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|') + print(f"solver_type: {solver_type} | model_type: {model_type}") + print( + f"| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB " + f"| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|" + ) print(time_list) def run(rank, world_size, port, args): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") train_gpt(args) -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() spawn(run, 1, args=args) diff --git a/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py b/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py index 84b02633e775..f3d35dd9042b 100644 --- a/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py +++ b/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py @@ -29,8 +29,8 @@ def get_gpu_mem(): return torch.cuda.memory_allocated() / 1024**2 -def get_mem_info(prefix=''): - return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB' +def get_mem_info(prefix=""): + return f"{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB" def get_tflops(model_numel, batch_size, seq_len, step_time): @@ -51,14 +51,14 @@ def main(): logger = get_dist_logger() config = transformers.GPT2Config(n_position=SEQ_LENGTH, n_layer=NUM_LAYERS, n_head=NUM_HEADS, n_embd=HIDDEN_DIM) if FP16: - model = GPT2LMHeadModel(config=config).half().to('cuda') + model = GPT2LMHeadModel(config=config).half().to("cuda") else: - model = GPT2LMHeadModel(config=config).to('cuda') + model = GPT2LMHeadModel(config=config).to("cuda") global_numel = sum([p.numel() for p in model.parameters()]) meta_input_sample = { - 'input_ids': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'), - 'attention_mask': torch.zeros((BATCH_SIZE, SEQ_LENGTH), 
dtype=torch.int64).to('meta'), + "input_ids": torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to("meta"), + "attention_mask": torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to("meta"), } gm, solution = autoparallelize(model, meta_input_sample, return_solution=True) @@ -72,7 +72,7 @@ def main(): criterion = GPTLMLoss() optimizer = torch.optim.Adam(gm.parameters(), lr=0.01) - logger.info(get_mem_info(prefix='After init model, '), ranks=[0]) + logger.info(get_mem_info(prefix="After init model, "), ranks=[0]) get_tflops_func = partial(get_tflops, global_numel, BATCH_SIZE, SEQ_LENGTH) torch.cuda.synchronize() model.train() @@ -89,10 +89,11 @@ def main(): torch.cuda.synchronize() step_time = time() - start logger.info( - f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}', - ranks=[0]) + f"[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}", + ranks=[0], + ) torch.cuda.synchronize() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/language/gpt/experiments/auto_parallel/gpt_modules.py b/examples/language/gpt/experiments/auto_parallel/gpt_modules.py index 95feaec38c26..ad9a19777284 100644 --- a/examples/language/gpt/experiments/auto_parallel/gpt_modules.py +++ b/examples/language/gpt/experiments/auto_parallel/gpt_modules.py @@ -8,7 +8,6 @@ class GPT2MLP(nn.Module): - def __init__(self, intermediate_size, config): super().__init__() embed_dim = config.hidden_size @@ -30,15 +29,15 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl # 2. The order of split and view op has been changed in the customized GPT2Attention class, the new # order is same as megatron-lm gpt model. 
class GPT2Attention(nn.Module): - def __init__(self, config, layer_idx=None): super().__init__() max_positions = config.max_position_embeddings self.register_buffer( "bias", - torch.tril(torch.ones((max_positions, max_positions), - dtype=torch.uint8)).view(1, 1, max_positions, max_positions), + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( + 1, 1, max_positions, max_positions + ), ) self.register_buffer("masked_bias", torch.tensor(-1e4)) @@ -64,7 +63,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): attn_weights = torch.matmul(query, key.transpose(-1, -2)) if self.scale_attn_weights: - attn_weights = attn_weights / (value.size(-1)**0.5) + attn_weights = attn_weights / (value.size(-1) ** 0.5) # Layer-wise attention scaling if self.scale_attn_by_inverse_layer_idx: @@ -72,7 +71,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): # if only "normal" attention layer implements causal mask query_length, key_length = query.size(-2), key.size(-2) - causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].to(torch.bool) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool) attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) if attention_mask is not None: @@ -93,7 +92,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): def _split_heads(self, tensor, num_heads, attn_head_size): new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) tensor = tensor.view(new_shape) - return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) def _merge_heads(self, tensor, num_heads, attn_head_size): tensor = tensor.permute(0, 2, 1, 3).contiguous() @@ -106,10 +105,9 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: - qkv = self.c_attn(hidden_states) query, key, value = self._split_heads(qkv, self.num_heads, 3 * self.head_dim).split(self.head_dim, dim=3) - present = (key, value) + (key, value) attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) attn_output = self.c_proj(attn_output) @@ -117,7 +115,6 @@ def forward( class GPT2Block(nn.Module): - def __init__(self, config, layer_idx=None): super().__init__() hidden_size = config.hidden_size @@ -152,7 +149,6 @@ def forward( class GPT2Model(GPT2PreTrainedModel): - def __init__(self, config): super().__init__(config) @@ -189,11 +185,9 @@ def forward( # GPT2Attention mask. 
attention_mask = attention_mask.view(batch_size, -1) attention_mask = attention_mask[:, None, None, :] - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility attention_mask = (1.0 - attention_mask) * -10000.0 - encoder_attention_mask = None - # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N @@ -217,7 +211,6 @@ def forward( class GPT2LMHeadModel(GPT2PreTrainedModel): - def __init__(self, config): super().__init__(config) self.transformer = GPT2Model(config) @@ -241,7 +234,6 @@ def forward( class GPTLMLoss(nn.Module): - def __init__(self): super().__init__() self.loss_fn = nn.CrossEntropyLoss() diff --git a/examples/language/gpt/experiments/pipeline_parallel/model_zoo.py b/examples/language/gpt/experiments/pipeline_parallel/model_zoo.py index c31b3fa6d103..47cc87980556 100644 --- a/examples/language/gpt/experiments/pipeline_parallel/model_zoo.py +++ b/examples/language/gpt/experiments/pipeline_parallel/model_zoo.py @@ -4,22 +4,25 @@ ## Define the Model and Loss Based on Huggingface transformers GPT2LMHeadModel class GPTLMModel(nn.Module): - - def __init__(self, - hidden_size=768, - num_layers=12, - num_attention_heads=12, - max_seq_len=1024, - vocab_size=50257, - checkpoint=False): + def __init__( + self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50257, + checkpoint=False, + ): super().__init__() self.checkpoint = checkpoint - self.config = GPT2Config(n_embd=hidden_size, - n_layer=num_layers, - n_head=num_attention_heads, - n_positions=max_seq_len, - n_ctx=max_seq_len, - vocab_size=vocab_size) + self.config = GPT2Config( + n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size, + ) self.model = GPT2LMHeadModel(self.config) if checkpoint: self.model.gradient_checkpointing_enable() @@ -70,4 +73,4 @@ def model_builder(model_size: str) -> callable: raise TypeError(f"model_builder {model_size}") -__all__ = ['model_builder'] +__all__ = ["model_builder"] diff --git a/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py b/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py index 749243e57836..17692e90a03c 100644 --- a/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py +++ b/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py @@ -3,41 +3,34 @@ from functools import partial import torch -from model_zoo import model_builder from torch import nn -from tqdm import tqdm from colossalai.fx import ColoTracer -from colossalai.fx.passes.adding_split_node_pass import ( - avgnode_split_pass, - gpipe_dp_split_pass, - split_with_split_nodes_pass, -) +from colossalai.fx.passes.adding_split_node_pass import gpipe_dp_split_pass, split_with_split_nodes_pass from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.legacy.pipeline.middleware.adaptor import get_fx_topology -from colossalai.legacy.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine +from colossalai.legacy.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine from colossalai.legacy.pipeline.rpc.utils import rpc_run from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.nn.optimizer import HybridAdam +from model_zoo import model_builder def parse_args(): parser = 
argparse.ArgumentParser() - parser.add_argument('--model_type', type=str, default="gpt2_medium") - parser.add_argument('--world_size', type=int, default=2) - parser.add_argument('--batch_size', type=int, default=16) - parser.add_argument('--dp_degree', type=int, default=1) - parser.add_argument('--tp_degree', type=int, default=1) - parser.add_argument('--num_microbatches', type=int, default=2) - parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda') - parser.add_argument('--master_addr', type=str, default='localhost') - parser.add_argument('--master_port', type=str, default='29011') - parser.add_argument('--num_worker_threads', type=int, default=128) + parser.add_argument("--model_type", type=str, default="gpt2_medium") + parser.add_argument("--world_size", type=int, default=2) + parser.add_argument("--batch_size", type=int, default=16) + parser.add_argument("--dp_degree", type=int, default=1) + parser.add_argument("--tp_degree", type=int, default=1) + parser.add_argument("--num_microbatches", type=int, default=2) + parser.add_argument("--device", type=str, choices=["cpu", "cuda"], default="cuda") + parser.add_argument("--master_addr", type=str, default="localhost") + parser.add_argument("--master_port", type=str, default="29011") + parser.add_argument("--num_worker_threads", type=int, default=128) return parser.parse_args() class GPTLMLoss(nn.Module): - def __init__(self): super().__init__() self.loss_fn = nn.CrossEntropyLoss() @@ -63,16 +56,16 @@ def get_tflops(model_numel, batch_size, seq_len, step_time): # Create annotated model which is noted where to be splitted. def get_annotated_model(model, data_kwargs, num_stages, num_microbatches): tracer = ColoTracer() - meta_args = {k: v.to('meta') for k, v in data_kwargs.items()} + meta_args = {k: v.to("meta") for k, v in data_kwargs.items()} graph = tracer.trace(root=model, meta_args=meta_args) gm = torch.fx.GraphModule(model, graph, model.__class__.__name__) - interp_meta_args = tuple([v.to('meta') for k, v in data_kwargs.items()]) + interp_meta_args = tuple([v.to("meta") for k, v in data_kwargs.items()]) interp = MetaInfoProp(gm) interp.run(*interp_meta_args) - #annotated_model = avgnode_split_pass(gm, num_stages) - annotated_model = gpipe_dp_split_pass(gm, num_stages, num_microbatches, mode='block', block_limit=0.01) + # annotated_model = avgnode_split_pass(gm, num_stages) + annotated_model = gpipe_dp_split_pass(gm, num_stages, num_microbatches, mode="block", block_limit=0.01) return annotated_model @@ -83,7 +76,7 @@ def create_partition_module(pp_rank: int, num_stages: int, model, data_kwargs, n topo = get_fx_topology(top_module) for submodule in split_submodules: if isinstance(submodule, torch.fx.GraphModule): - setattr(submodule, '_topo', topo) + setattr(submodule, "_topo", topo) return split_submodules[pp_rank + 1] @@ -107,8 +100,10 @@ def run_master(args): disable_existing_loggers() logger = get_dist_logger() - logger.info(f"{args.model_type}, batch size {batch_size}, num stage {stage_num}, num microbatch {num_microbatches}", - ranks=[0]) + logger.info( + f"{args.model_type}, batch size {batch_size}, num stage {stage_num}, num microbatch {num_microbatches}", + ranks=[0], + ) torch.manual_seed(123) @@ -117,26 +112,28 @@ def run_master(args): # warm up pipeline fx partition input_ids, attn_mask = get_data(batch_size, SEQ_LEN, VOCAB_SIZE) - warmup_data_kwargs = {'input_ids': input_ids, 'attention_mask': attn_mask} + warmup_data_kwargs = {"input_ids": input_ids, "attention_mask": attn_mask} # create model - 
logger.info(f'start model_builder') + logger.info(f"start model_builder") model = model_builder(model_type)(checkpoint=False) - logger.info(f'end model_builder') + logger.info(f"end model_builder") # set 1f1b pipeline engine - pp_engine = FillDrainPipelineEngine(partition_fn=partial(partition, model, warmup_data_kwargs, num_microbatches), - stage_num=stage_num, - num_microbatches=num_microbatches, - device=device, - chunk=1, - criterion=criterion, - metric=None, - checkpoint=False) + pp_engine = FillDrainPipelineEngine( + partition_fn=partial(partition, model, warmup_data_kwargs, num_microbatches), + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + chunk=1, + criterion=criterion, + metric=None, + checkpoint=False, + ) partition_numels = pp_engine.remote_numels() for rank, numel in partition_numels.items(): - logger.info(f'{rank=} numel in the partition:{numel}') + logger.info(f"{rank=} numel in the partition:{numel}") # build optim pp_engine.initialize_optimizer(torch.optim.Adam, lr=1e-3) @@ -145,7 +142,7 @@ def run_master(args): for n in range(NUM_STEPS): # we just use randomly generated data here input_ids, attn_mask = get_data(batch_size, SEQ_LEN, VOCAB_SIZE) - batch = {'input_ids': input_ids, 'attention_mask': attn_mask} + batch = {"input_ids": input_ids, "attention_mask": attn_mask} start = time.time() outputs = pp_engine.forward_backward(batch=batch, labels=input_ids, forward_only=False) @@ -175,6 +172,6 @@ def run_master(args): logger.info(f"Avg TFLOPS per GPU is {sum(gpu_tflops) / world_size:.3f}") -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() rpc_run(args, run_master) diff --git a/examples/language/gpt/gemini/commons/model_zoo.py b/examples/language/gpt/gemini/commons/model_zoo.py index 65124d9e4884..0f4517549db2 100644 --- a/examples/language/gpt/gemini/commons/model_zoo.py +++ b/examples/language/gpt/gemini/commons/model_zoo.py @@ -4,22 +4,25 @@ ## Define the Model and Loss Based on Huggingface transformers GPT2LMHeadModel class GPTLMModel(nn.Module): - - def __init__(self, - hidden_size=768, - num_layers=12, - num_attention_heads=12, - max_seq_len=1024, - vocab_size=50257, - checkpoint=False): + def __init__( + self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50257, + checkpoint=False, + ): super().__init__() self.checkpoint = checkpoint - self.config = GPT2Config(n_embd=hidden_size, - n_layer=num_layers, - n_head=num_attention_heads, - n_positions=max_seq_len, - n_ctx=max_seq_len, - vocab_size=vocab_size) + self.config = GPT2Config( + n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size, + ) self.model = GPT2LMHeadModel(self.config) if checkpoint: self.model.gradient_checkpointing_enable() @@ -82,4 +85,4 @@ def model_builder(model_size: str) -> callable: raise TypeError(f"model_builder {model_size}") -__all__ = ['model_builder'] +__all__ = ["model_builder"] diff --git a/examples/language/gpt/gemini/commons/utils.py b/examples/language/gpt/gemini/commons/utils.py index 7bd098c1927c..7ed5fdb92b35 100644 --- a/examples/language/gpt/gemini/commons/utils.py +++ b/examples/language/gpt/gemini/commons/utils.py @@ -6,7 +6,6 @@ class DummyProfiler: - def __init__(self): self.step_number = 0 @@ -27,11 +26,13 @@ def get_tflops(model_numel, batch_size, seq_len, step_time): def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir): if enable_flag: - return 
profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps), - on_trace_ready=tensorboard_trace_handler(save_dir), - record_shapes=True, - profile_memory=True) + return profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps), + on_trace_ready=tensorboard_trace_handler(save_dir), + record_shapes=True, + profile_memory=True, + ) else: return nullcontext(DummyProfiler()) diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py index f9d30fd15c7b..88b76c654b1d 100644 --- a/examples/language/gpt/gemini/train_gpt_demo.py +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -27,7 +27,7 @@ def parse_args(): parser.add_argument( "--distplan", type=str, - default='CAI_Gemini', + default="CAI_Gemini", help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].", ) parser.add_argument( @@ -54,7 +54,6 @@ def parse_args(): class GPTLMLoss(nn.Module): - def __init__(self): super().__init__() self.loss_fn = nn.CrossEntropyLoss() @@ -74,8 +73,8 @@ def get_gpu_mem(): return torch.cuda.memory_allocated() / 1024**2 -def get_mem_info(prefix=''): - return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB' +def get_mem_info(prefix=""): + return f"{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB" def get_model_size(model: nn.Module): @@ -91,11 +90,11 @@ def model_size_formatter(numel: int) -> str: MB_SIZE = 10**6 KB_SIZE = 10**3 if numel >= GB_SIZE: - return f'{numel / GB_SIZE:.1f}B' + return f"{numel / GB_SIZE:.1f}B" elif numel >= MB_SIZE: - return f'{numel / MB_SIZE:.1f}M' + return f"{numel / MB_SIZE:.1f}M" elif numel >= KB_SIZE: - return f'{numel / KB_SIZE:.1f}K' + return f"{numel / KB_SIZE:.1f}K" else: return str(numel) @@ -103,7 +102,7 @@ def model_size_formatter(numel: int) -> str: def set_cpu_maximum_parallelism(): conf_str = torch.__config__.parallel_info() inter_str = conf_str.split("hardware_concurrency() : ")[1] - max_concurrency = inter_str.split('\n')[0] + max_concurrency = inter_str.split("\n")[0] os.environ["OMP_NUM_THREADS"] = max_concurrency print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.") @@ -130,7 +129,7 @@ def main(): WARMUP_STEPS = 1 assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps" assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median" - PROF_FLAG = False # The flag of profiling, False by default + PROF_FLAG = False # The flag of profiling, False by default disable_existing_loggers() colossalai.launch_from_torch(config={}) @@ -159,10 +158,9 @@ def main(): plugin = None if args.distplan.startswith("CAI_ZeRO"): - plugin = LowLevelZeroPlugin(stage=zero_stage, - reduce_bucket_size_in_m=12, - overlap_communication=True, - verbose=True) + plugin = LowLevelZeroPlugin( + stage=zero_stage, reduce_bucket_size_in_m=12, overlap_communication=True, verbose=True + ) elif args.distplan == "CAI_Gemini": plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd) else: @@ -171,7 +169,7 @@ def main(): # build a highly optimized gpu/cpu optimizer optimizer = HybridAdam(model.parameters(), lr=1e-3) - logger.info(get_mem_info(prefix='After init optim, '), ranks=[0]) + logger.info(get_mem_info(prefix="After init optim, "), ranks=[0]) elif args.distplan.startswith("Pytorch"): 
assert args.tp_degree == 1, "The degree of TP should be 1 for DDP examples." model = model_builder(args.model_type)(checkpoint=True).cuda() @@ -180,6 +178,7 @@ def main(): optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) elif args.distplan.endswith("ZeRO"): from torch.distributed.optim import ZeroRedundancyOptimizer + optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3) else: @@ -191,7 +190,7 @@ def main(): # model is shared after TP numel = get_model_size(model) logger.info(f"the size of testing model size is {model_size_formatter(numel)}.") - logger.info(get_mem_info(prefix='After init model, '), ranks=[0]) + logger.info(get_mem_info(prefix="After init model, "), ranks=[0]) # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu # = (batch_per_DP_group * dp_degree) * (numel * tp_degree) * seq_len * 8 / (tp_degree * dp_degree) @@ -213,19 +212,19 @@ def train_step(): torch.cuda.synchronize() fwd_end = time() fwd_time = fwd_end - start - logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0]) + logger.info(get_mem_info(prefix=f"[{n + 1}/{NUM_STEPS}] Forward "), ranks=[0]) booster.backward(loss, optimizer) torch.cuda.synchronize() bwd_end = time() bwd_time = bwd_end - fwd_end - logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Backward '), ranks=[0]) + logger.info(get_mem_info(prefix=f"[{n + 1}/{NUM_STEPS}] Backward "), ranks=[0]) optimizer.step() torch.cuda.synchronize() optim_time = time() - bwd_end step_time = time() - start - logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Optimizer step '), ranks=[0]) + logger.info(get_mem_info(prefix=f"[{n + 1}/{NUM_STEPS}] Optimizer step "), ranks=[0]) step_tflops = get_tflops_func(step_time) logger.info( @@ -235,10 +234,9 @@ def train_step(): if n >= WARMUP_STEPS: tflops_list.append(step_tflops) - demo_profiler = get_profile_context(PROF_FLAG, - WARMUP_STEPS, - NUM_STEPS - WARMUP_STEPS, - save_dir=f"profile/{get_time_stamp()}-demo") + demo_profiler = get_profile_context( + PROF_FLAG, WARMUP_STEPS, NUM_STEPS - WARMUP_STEPS, save_dir=f"profile/{get_time_stamp()}-demo" + ) with demo_profiler as prof: for n in range(NUM_STEPS): @@ -251,5 +249,5 @@ def train_step(): torch.cuda.synchronize() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/language/gpt/hybridparallelism/data.py b/examples/language/gpt/hybridparallelism/data.py index 981cedcca8c2..ef51f938dc4f 100644 --- a/examples/language/gpt/hybridparallelism/data.py +++ b/examples/language/gpt/hybridparallelism/data.py @@ -5,7 +5,6 @@ class GLUEDataBuilder: - task_text_field_map = { "cola": ["sentence"], "sst2": ["sentence"], @@ -84,10 +83,9 @@ def prepare_data(self): AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) def train_dataloader(self): - return self.plugin.prepare_dataloader(self.dataset["train"], - batch_size=self.train_batch_size, - shuffle=True, - drop_last=True) + return self.plugin.prepare_dataloader( + self.dataset["train"], batch_size=self.train_batch_size, shuffle=True, drop_last=True + ) def val_dataloader(self): if len(self.eval_splits) == 1: @@ -108,7 +106,6 @@ def test_dataloader(self): ] def convert_to_features(self, example_batch): - # Either encode single sentence or sentence pairs if len(self.text_fields) > 1: texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]])) @@ -116,10 +113,9 @@ def convert_to_features(self, example_batch): texts_or_text_pairs = 
example_batch[self.text_fields[0]] # Tokenize the text/text pairs - features = self.tokenizer.batch_encode_plus(texts_or_text_pairs, - max_length=self.max_seq_length, - padding='max_length', - truncation=True) + features = self.tokenizer.batch_encode_plus( + texts_or_text_pairs, max_length=self.max_seq_length, padding="max_length", truncation=True + ) # Rename label to labels to make it easier to pass to model forward features["labels"] = example_batch["label"] diff --git a/examples/language/gpt/hybridparallelism/finetune.py b/examples/language/gpt/hybridparallelism/finetune.py index 03e5ec91b3fe..62804eff8ea5 100644 --- a/examples/language/gpt/hybridparallelism/finetune.py +++ b/examples/language/gpt/hybridparallelism/finetune.py @@ -1,5 +1,4 @@ import argparse -from contextlib import nullcontext from typing import Callable, List, Union import evaluate @@ -7,7 +6,7 @@ import torch.distributed as dist import torch.nn as nn from data import GLUEDataBuilder -from torch.optim import Adam, Optimizer +from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader from tqdm import tqdm @@ -17,7 +16,6 @@ from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator -from colossalai.lazy import LazyInitContext from colossalai.nn.optimizer import HybridAdam from colossalai.utils import get_current_device @@ -104,7 +102,7 @@ def evaluate_subset(dataloader: DataLoader): results = metric.compute() dist.all_reduce(accum_loss.div_(len(dataloader))) if coordinator.is_master() and results is not None: - results['loss'] = accum_loss.item() / coordinator.world_size + results["loss"] = accum_loss.item() / coordinator.world_size return results @@ -115,13 +113,20 @@ def evaluate_subset(dataloader: DataLoader): final_results = {} for split, sub_loader in zip(eval_splits, test_dataloader): results = evaluate_subset(sub_loader) - final_results.update({f'{k}_{split}': v for k, v in results.items()}) + final_results.update({f"{k}_{split}": v for k, v in results.items()}) return final_results -def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: Callable, lr_scheduler: LRScheduler, - train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator): - +def train_epoch( + epoch: int, + model: nn.Module, + optimizer: Optimizer, + _criterion: Callable, + lr_scheduler: LRScheduler, + train_dataloader: DataLoader, + booster: Booster, + coordinator: DistCoordinator, +): use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() total_step = len(train_dataloader) @@ -129,22 +134,21 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: model.train() optimizer.zero_grad() train_dataloader_iter = iter(train_dataloader) - with tqdm(range(total_step), - desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', - disable=not (coordinator.is_master() or is_pp_last_stage)) as pbar: + with tqdm( + range(total_step), + desc=f"Epoch [{epoch + 1}/{NUM_EPOCHS}]", + disable=not (coordinator.is_master() or is_pp_last_stage), + ) as pbar: # Forward pass for _ in pbar: if use_pipeline: - outputs = booster.execute_pipeline(train_dataloader_iter, - model, - _criterion, - optimizer, - return_loss=True, - return_outputs=True) + outputs = booster.execute_pipeline( + 
train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True + ) # Backward and optimize if is_pp_last_stage: - loss = outputs['loss'] - pbar.set_postfix({'loss': loss.item()}) + loss = outputs["loss"] + pbar.set_postfix({"loss": loss.item()}) else: data = next(train_dataloader_iter) data = move_to_cuda(data) @@ -152,7 +156,7 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: loss = _criterion(outputs, None) # Backward booster.backward(loss, optimizer) - pbar.set_postfix({'loss': loss.item()}) + pbar.set_postfix({"loss": loss.item()}) optimizer.step() optimizer.zero_grad() @@ -164,24 +168,26 @@ def main(): # Parse Arguments # ============================== parser = argparse.ArgumentParser() - parser.add_argument('-t', '--task', default='mrpc', help="GLUE task to run") - parser.add_argument('-p', - '--plugin', - type=str, - default='torch_ddp', - choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero', 'hybrid_parallel'], - help="plugin to use") + parser.add_argument("-t", "--task", default="mrpc", help="GLUE task to run") + parser.add_argument( + "-p", + "--plugin", + type=str, + default="torch_ddp", + choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero", "hybrid_parallel"], + help="plugin to use", + ) parser.add_argument( "--model_type", type=str, default="gpt2", help="only gpt2 now", ) - parser.add_argument('--target_f1', type=float, default=None, help="target f1 score. Raise exception if not reached") - parser.add_argument('--use_lazy_init', type=bool, default=False, help="for initiating lazy init context") + parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached") + parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context") args = parser.parse_args() - if args.model_type == 'gpt2': + if args.model_type == "gpt2": model_name = "gpt2" else: raise RuntimeError @@ -198,36 +204,35 @@ def main(): # Instantiate Plugin and Booster # ============================== booster_kwargs = {} - if args.plugin == 'torch_ddp_fp16': - booster_kwargs['mixed_precision'] = 'fp16' - if args.plugin.startswith('torch_ddp'): + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() - elif args.plugin == 'gemini': + elif args.plugin == "gemini": plugin = GeminiPlugin(initial_scale=2**5) - elif args.plugin == 'low_level_zero': + elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) - elif args.plugin == 'hybrid_parallel': - + elif args.plugin == "hybrid_parallel": # modify the param accordingly for finetuning test cases - plugin = HybridParallelPlugin(tp_size=1, - pp_size=2, - num_microbatches=None, - microbatch_size=1, - enable_all_optimization=True, - zero_stage=1, - precision='fp16', - initial_scale=1) + plugin = HybridParallelPlugin( + tp_size=1, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_all_optimization=True, + zero_stage=1, + precision="fp16", + initial_scale=1, + ) booster = Booster(plugin=plugin, **booster_kwargs) # ============================== # Prepare Dataloader # ============================== - data_builder = GLUEDataBuilder(model_name, - plugin, - args.task, - train_batch_size=BATCH_SIZE, - eval_batch_size=BATCH_SIZE) + data_builder = GLUEDataBuilder( + model_name, plugin, args.task, train_batch_size=BATCH_SIZE, eval_batch_size=BATCH_SIZE + ) train_dataloader = 
data_builder.train_dataloader() test_dataloader = data_builder.test_dataloader() @@ -275,10 +280,9 @@ def _criterion(outputs, inputs): # ============================== # Boost with ColossalAI # ============================== - model, optimizer, _criterion, _, lr_scheduler = booster.boost(model, - optimizer, - criterion=_criterion, - lr_scheduler=lr_scheduler) + model, optimizer, _criterion, _, lr_scheduler = booster.boost( + model, optimizer, criterion=_criterion, lr_scheduler=lr_scheduler + ) # ============================== # Train model @@ -286,14 +290,22 @@ def _criterion(outputs, inputs): for epoch in range(NUM_EPOCHS): train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator) - results = evaluate_model(model, _criterion, test_dataloader, data_builder.num_labels, args.task, - data_builder.eval_splits, booster, coordinator) + results = evaluate_model( + model, + _criterion, + test_dataloader, + data_builder.num_labels, + args.task, + data_builder.eval_splits, + booster, + coordinator, + ) if coordinator.is_master(): print(results) - if args.target_f1 is not None and 'f1' in results: - assert results['f1'] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}' + if args.target_f1 is not None and "f1" in results: + assert results["f1"] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}' -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/language/gpt/titans/configs/gpt2_small_zero3_pp1d.py b/examples/language/gpt/titans/configs/gpt2_small_zero3_pp1d.py index 7bf53303948a..bc3dcb85cf1a 100644 --- a/examples/language/gpt/titans/configs/gpt2_small_zero3_pp1d.py +++ b/examples/language/gpt/titans/configs/gpt2_small_zero3_pp1d.py @@ -11,8 +11,10 @@ TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, HIDDEN_SIZE) # if you do no want zero, just comment out this dictionary -zero = dict(model_config=dict(tensor_placement_policy='cuda', shard_strategy=TensorShardStrategy()), - optimizer_config=dict(initial_scale=2**5)) +zero = dict( + model_config=dict(tensor_placement_policy="cuda", shard_strategy=TensorShardStrategy()), + optimizer_config=dict(initial_scale=2**5), +) optimizer = dict( type=HybridAdam, @@ -27,5 +29,5 @@ # for the current model implementation, mode can only be 1D or None parallel = dict( pipeline=1, - tensor=dict(size=2, mode='1d'), + tensor=dict(size=2, mode="1d"), ) diff --git a/examples/language/gpt/titans/configs/gpt3_zero3_pp1d.py b/examples/language/gpt/titans/configs/gpt3_zero3_pp1d.py index 9f9816b3004f..7413764dad81 100644 --- a/examples/language/gpt/titans/configs/gpt3_zero3_pp1d.py +++ b/examples/language/gpt/titans/configs/gpt3_zero3_pp1d.py @@ -11,8 +11,10 @@ TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, HIDDEN_SIZE) # if you do no want zero, just comment out this dictionary -zero = dict(model_config=dict(tensor_placement_policy='cuda', shard_strategy=TensorShardStrategy()), - optimizer_config=dict(initial_scale=2**16)) +zero = dict( + model_config=dict(tensor_placement_policy="cuda", shard_strategy=TensorShardStrategy()), + optimizer_config=dict(initial_scale=2**16), +) optimizer = dict( type=HybridAdam, @@ -27,5 +29,5 @@ # for the current model implementation, mode can only be 1D or None parallel = dict( pipeline=1, - tensor=dict(size=2, mode='1d'), # for the current model implementation, mode can only be 1D or None + tensor=dict(size=2, mode="1d"), # for the current model implementation, mode can only be 1D 
or None ) diff --git a/examples/language/gpt/titans/dataset/webtext.py b/examples/language/gpt/titans/dataset/webtext.py index fdfc57e9ba22..e61f73fd9eba 100644 --- a/examples/language/gpt/titans/dataset/webtext.py +++ b/examples/language/gpt/titans/dataset/webtext.py @@ -11,12 +11,11 @@ @DATASETS.register_module class WebtextDataset(Dataset): - def __init__(self, path: Optional[str] = None, seq_len=1024) -> None: super().__init__() if path is not None: root = os.path.dirname(path) - encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt') + encoded_data_cache_path = os.path.join(root, f"gpt_webtext_{seq_len}.pt") if os.path.isfile(encoded_data_cache_path): seq_len_, data, attention_mask = torch.load(encoded_data_cache_path) if seq_len_ == seq_len: @@ -26,12 +25,12 @@ def __init__(self, path: Optional[str] = None, seq_len=1024) -> None: raw_data = [] with open(path) as f: for line in f.readlines(): - raw_data.append(json.loads(line)['text']) - tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + raw_data.append(json.loads(line)["text"]) + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") tokenizer.pad_token = tokenizer.unk_token - encoded_data = tokenizer(raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt') - self.data = encoded_data['input_ids'] - self.attention_mask = encoded_data['attention_mask'] + encoded_data = tokenizer(raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors="pt") + self.data = encoded_data["input_ids"] + self.attention_mask = encoded_data["attention_mask"] else: self.data = torch.randint(0, 50257, (10240, seq_len)) self.attention_mask = torch.ones_like(self.data) @@ -40,4 +39,4 @@ def __len__(self): return len(self.data) def __getitem__(self, index): - return {'input_ids': self.data[index], 'attention_mask': self.attention_mask[index]}, self.data[index] + return {"input_ids": self.data[index], "attention_mask": self.attention_mask[index]}, self.data[index] diff --git a/examples/language/gpt/titans/model/embed.py b/examples/language/gpt/titans/model/embed.py index a6c80394c50f..b2e3f71a5387 100644 --- a/examples/language/gpt/titans/model/embed.py +++ b/examples/language/gpt/titans/model/embed.py @@ -1,7 +1,6 @@ import torch import torch.nn.init as init from torch import Tensor -from torch import distributed as dist from torch import nn as nn from torch.nn import functional as F from torch.nn.parameter import Parameter @@ -12,7 +11,7 @@ from colossalai.legacy.nn.layer.parallel_1d._utils import gather_forward_split_backward, reduce_grad, reduce_input from colossalai.legacy.nn.layer.parallel_1d.layers import Linear1D_Row from colossalai.legacy.nn.layer.utils import divide -from colossalai.legacy.registry import LAYERS, LOSSES, MODELS +from colossalai.legacy.registry import LAYERS, LOSSES from colossalai.utils import get_current_device @@ -30,13 +29,9 @@ class VocabParallelEmbedding(torch.nn.Module): will ignore this embedding """ - def __init__(self, - hidden_size, - vocab_size, - max_sequence_length, - embedding_dropout_prob, - num_tokentypes=0, - dtype=torch.float): + def __init__( + self, hidden_size, vocab_size, max_sequence_length, embedding_dropout_prob, num_tokentypes=0, dtype=torch.float + ): super(VocabParallelEmbedding, self).__init__() self.hidden_size = hidden_size @@ -44,11 +39,11 @@ def __init__(self, # Word embeddings (parallel). 
self.word_embeddings = VocabParallelEmbedding1D(vocab_size, self.hidden_size, dtype=dtype) - self._word_embeddings_key = 'word_embeddings' + self._word_embeddings_key = "word_embeddings" # Position embedding (serial). self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size, dtype=dtype) - self._position_embeddings_key = 'position_embeddings' + self._position_embeddings_key = "position_embeddings" # Initialize the position embeddings. # self.init_method(self.position_embeddings.weight) @@ -56,7 +51,7 @@ def __init__(self, # Add this as an optional field that can be added through # method call so we can load a pretrain model without # token types and add them as needed. - self._tokentype_embeddings_key = 'tokentype_embeddings' + self._tokentype_embeddings_key = "tokentype_embeddings" if self.num_tokentypes > 0: self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size, dtype=dtype) # Initialize the token-type embeddings. @@ -83,9 +78,9 @@ def add_tokentype_embeddings(self, num_tokentypes): This allows us to load the model normally and then add this embedding. """ if self.tokentype_embeddings is not None: - raise Exception('tokentype embeddings is already initialized') + raise Exception("tokentype embeddings is already initialized") if torch.distributed.get_rank() == 0: - print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True) + print("adding embedding for {} tokentypes".format(num_tokentypes), flush=True) self.num_tokentypes = num_tokentypes self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size) # Initialize the token-type embeddings. @@ -112,19 +107,16 @@ def forward(self, input_ids, position_ids=None, tokentype_ids=None): embeddings = self.embedding_dropout(embeddings) return embeddings - def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False): + def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False): """For easy load.""" state_dict_ = {} - state_dict_[self._word_embeddings_key] \ - = self.word_embeddings.state_dict(destination, prefix, keep_vars) - state_dict_[self._position_embeddings_key] \ - = self.position_embeddings.state_dict( - destination, prefix, keep_vars) + state_dict_[self._word_embeddings_key] = self.word_embeddings.state_dict(destination, prefix, keep_vars) + state_dict_[self._position_embeddings_key] = self.position_embeddings.state_dict(destination, prefix, keep_vars) if self.num_tokentypes > 0: - state_dict_[self._tokentype_embeddings_key] \ - = self.tokentype_embeddings.state_dict( - destination, prefix, keep_vars) + state_dict_[self._tokentype_embeddings_key] = self.tokentype_embeddings.state_dict( + destination, prefix, keep_vars + ) return state_dict_ @@ -138,9 +130,8 @@ def load_state_dict(self, state_dict, strict=True): # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): - if 'word_embeddings' in key: - state_dict_[key.split('word_embeddings.')[1]] \ - = state_dict[key] + if "word_embeddings" in key: + state_dict_[key.split("word_embeddings.")[1]] = state_dict[key] self.word_embeddings.load_state_dict(state_dict_, strict=strict) # Position embedding. @@ -150,9 +141,8 @@ def load_state_dict(self, state_dict, strict=True): # for backward compatibility. 
state_dict_ = {} for key in state_dict.keys(): - if 'position_embeddings' in key: - state_dict_[key.split('position_embeddings.')[1]] \ - = state_dict[key] + if "position_embeddings" in key: + state_dict_[key.split("position_embeddings.")[1]] = state_dict[key] self.position_embeddings.load_state_dict(state_dict_, strict=strict) # Tokentype embedding. @@ -163,15 +153,14 @@ def load_state_dict(self, state_dict, strict=True): else: # for backward compatibility. for key in state_dict.keys(): - if 'tokentype_embeddings' in key: - state_dict_[key.split('tokentype_embeddings.')[1]] \ - = state_dict[key] + if "tokentype_embeddings" in key: + state_dict_[key.split("tokentype_embeddings.")[1]] = state_dict[key] if len(state_dict_.keys()) > 0: self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict) else: - print('***WARNING*** expected tokentype embeddings in the ' - 'checkpoint but could not find it', - flush=True) + print( + "***WARNING*** expected tokentype embeddings in the " "checkpoint but could not find it", flush=True + ) class VocabParallelEmbedding1D(torch.nn.Module): @@ -193,37 +182,41 @@ def __init__(self, num_embeddings, embedding_dim, dtype=None, init_method=None): # Set the details for compatibility. self.padding_idx = None self.max_norm = None - self.norm_type = 2. + self.norm_type = 2.0 self.scale_grad_by_freq = False self.sparse = False self._weight = None self.tensor_model_parallel_size = gpc.tensor_parallel_size # Divide the weight matrix along the vocabulary dimension. - self.vocab_start_index, self.vocab_end_index = \ - VocabUtility.vocab_range_from_global_vocab_size( - self.num_embeddings, gpc.get_local_rank(ParallelMode.PARALLEL_1D), - self.tensor_model_parallel_size) - self.num_embeddings_per_partition = self.vocab_end_index - \ - self.vocab_start_index + self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size( + self.num_embeddings, gpc.get_local_rank(ParallelMode.PARALLEL_1D), self.tensor_model_parallel_size + ) + self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index # Allocate weights and initialize. - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + factory_kwargs = {"device": get_current_device(), "dtype": dtype} self.weight = Parameter(torch.empty(self.num_embeddings_per_partition, self.embedding_dim, **factory_kwargs)) init.uniform_(self.weight, -1, 1) def forward(self, input_): if self.tensor_model_parallel_size > 1: # Build the mask. - input_mask = (input_ < self.vocab_start_index) | \ - (input_ >= self.vocab_end_index) + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) # Mask the input. masked_input = input_.clone() - self.vocab_start_index masked_input[input_mask] = 0 else: masked_input = input_ # Get the embeddings. - output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, self.max_norm, self.norm_type, - self.scale_grad_by_freq, self.sparse) + output_parallel = F.embedding( + masked_input, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) # Mask the output embedding. 
if self.tensor_model_parallel_size > 1: output_parallel[input_mask, :] = 0.0 @@ -234,7 +227,6 @@ def forward(self, input_): @LOSSES.register_module class vocab_parallel_cross_entropy(nn.Module): - def __init__(self): super().__init__() @@ -242,20 +234,19 @@ def forward(self, vocab_parallel_logits, target): """Helper function for the cross entropy.""" vocab_parallel_logits = vocab_parallel_logits[..., :-1, :].contiguous() target = target[..., 1:].contiguous() - return _VocabParallelCrossEntropy.apply(vocab_parallel_logits.view(-1, vocab_parallel_logits.size(-1)), - target.view(-1)) + return _VocabParallelCrossEntropy.apply( + vocab_parallel_logits.view(-1, vocab_parallel_logits.size(-1)), target.view(-1) + ) class _VocabParallelCrossEntropy(torch.autograd.Function): - @staticmethod def forward(ctx, vocab_parallel_logits, target): - # Maximum value along vocab dimension across all GPUs. logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] - torch.distributed.all_reduce(logits_max, - op=torch.distributed.ReduceOp.MAX, - group=gpc.get_group(ParallelMode.PARALLEL_1D)) + torch.distributed.all_reduce( + logits_max, op=torch.distributed.ReduceOp.MAX, group=gpc.get_group(ParallelMode.PARALLEL_1D) + ) # Subtract the maximum value. vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) @@ -282,17 +273,17 @@ def forward(ctx, vocab_parallel_logits, target): predicted_logits = predicted_logits_1d.view_as(target) predicted_logits[target_mask] = 0.0 # All reduce is needed to get the chunks from other GPUs. - torch.distributed.all_reduce(predicted_logits, - op=torch.distributed.ReduceOp.SUM, - group=gpc.get_group(ParallelMode.PARALLEL_1D)) + torch.distributed.all_reduce( + predicted_logits, op=torch.distributed.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PARALLEL_1D) + ) # Sum of exponential of logits along vocab dimension across all GPUs. exp_logits = vocab_parallel_logits torch.exp(vocab_parallel_logits, out=exp_logits) sum_exp_logits = exp_logits.sum(dim=-1) - torch.distributed.all_reduce(sum_exp_logits, - op=torch.distributed.ReduceOp.SUM, - group=gpc.get_group(ParallelMode.PARALLEL_1D)) + torch.distributed.all_reduce( + sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PARALLEL_1D) + ) # Loss = log(sum(exp(logits))) - predicted-logit. loss = torch.log(sum_exp_logits) - predicted_logits @@ -304,7 +295,6 @@ def forward(ctx, vocab_parallel_logits, target): @staticmethod def backward(ctx, grad_output): - # Retrieve tensors from the forward path. softmax, target_mask, masked_target_1d = ctx.saved_tensors @@ -316,7 +306,7 @@ def backward(ctx, grad_output): # Add the gradient from matching classes. arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) - grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float()) + grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float() # Finally elementwise multiplication with the output gradients. grad_input.mul_(grad_output.unsqueeze(dim=-1)) @@ -326,8 +316,8 @@ def backward(ctx, grad_output): class VocabUtility: """Split the vocabulary into `world_size` chunks amd return the - first and last index of the vocabulary belonging to the `rank` - partition: Note that indices in [fist, last)""" + first and last index of the vocabulary belonging to the `rank` + partition: Note that indices in [fist, last)""" @staticmethod def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size): @@ -393,11 +383,11 @@ def __init__( # Word embeddings (parallel). 
self.word_embeddings = HiddenParallelEmbedding1D(vocab_size, hidden_size, dtype, padding_idx) - self._word_embeddings_key = 'word_embeddings' + self._word_embeddings_key = "word_embeddings" # Position embedding (serial). self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size) - self._position_embeddings_key = 'position_embeddings' + self._position_embeddings_key = "position_embeddings" # Initialize the position embeddings. # self.init_method(self.position_embeddings.weight) @@ -405,7 +395,7 @@ def __init__( # Add this as an optional field that can be added through # method call so we can load a pretrain model without # token types and add them as needed. - self._tokentype_embeddings_key = 'tokentype_embeddings' + self._tokentype_embeddings_key = "tokentype_embeddings" if self.num_tokentypes > 0: self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size) # Initialize the token-type embeddings. @@ -432,9 +422,9 @@ def add_tokentype_embeddings(self, num_tokentypes): This allows us to load the model normally and then add this embedding. """ if self.tokentype_embeddings is not None: - raise Exception('tokentype embeddings is already initialized') + raise Exception("tokentype embeddings is already initialized") if torch.distributed.get_rank() == 0: - print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True) + print("adding embedding for {} tokentypes".format(num_tokentypes), flush=True) self.num_tokentypes = num_tokentypes self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size) # Initialize the token-type embeddings. @@ -460,19 +450,16 @@ def forward(self, input_ids, position_ids=None, tokentype_ids=None): embeddings = self.embedding_dropout(embeddings) return embeddings - def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False): + def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False): """For easy load.""" state_dict_ = {} - state_dict_[self._word_embeddings_key] \ - = self.word_embeddings.state_dict(destination, prefix, keep_vars) - state_dict_[self._position_embeddings_key] \ - = self.position_embeddings.state_dict( - destination, prefix, keep_vars) + state_dict_[self._word_embeddings_key] = self.word_embeddings.state_dict(destination, prefix, keep_vars) + state_dict_[self._position_embeddings_key] = self.position_embeddings.state_dict(destination, prefix, keep_vars) if self.num_tokentypes > 0: - state_dict_[self._tokentype_embeddings_key] \ - = self.tokentype_embeddings.state_dict( - destination, prefix, keep_vars) + state_dict_[self._tokentype_embeddings_key] = self.tokentype_embeddings.state_dict( + destination, prefix, keep_vars + ) return state_dict_ @@ -486,9 +473,8 @@ def load_state_dict(self, state_dict, strict=True): # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): - if 'word_embeddings' in key: - state_dict_[key.split('word_embeddings.')[1]] \ - = state_dict[key] + if "word_embeddings" in key: + state_dict_[key.split("word_embeddings.")[1]] = state_dict[key] self.word_embeddings.load_state_dict(state_dict_, strict=strict) # Position embedding. @@ -498,9 +484,8 @@ def load_state_dict(self, state_dict, strict=True): # for backward compatibility. 
state_dict_ = {} for key in state_dict.keys(): - if 'position_embeddings' in key: - state_dict_[key.split('position_embeddings.')[1]] \ - = state_dict[key] + if "position_embeddings" in key: + state_dict_[key.split("position_embeddings.")[1]] = state_dict[key] self.position_embeddings.load_state_dict(state_dict_, strict=strict) # Tokentype embedding. @@ -511,15 +496,14 @@ def load_state_dict(self, state_dict, strict=True): else: # for backward compatibility. for key in state_dict.keys(): - if 'tokentype_embeddings' in key: - state_dict_[key.split('tokentype_embeddings.')[1]] \ - = state_dict[key] + if "tokentype_embeddings" in key: + state_dict_[key.split("tokentype_embeddings.")[1]] = state_dict[key] if len(state_dict_.keys()) > 0: self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict) else: - print('***WARNING*** expected tokentype embeddings in the ' - 'checkpoint but could not find it', - flush=True) + print( + "***WARNING*** expected tokentype embeddings in the " "checkpoint but could not find it", flush=True + ) class HiddenParallelEmbedding1D(torch.nn.Module): @@ -542,21 +526,21 @@ def __init__(self, num_embeddings, embedding_dim, dtype=torch.float, padding_idx # Set the details for compatibility. self.padding_idx = padding_idx self.max_norm = None - self.norm_type = 2. + self.norm_type = 2.0 self.scale_grad_by_freq = False self.sparse = False self._weight = None # Allocate weights and initialize. - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + factory_kwargs = {"device": get_current_device(), "dtype": dtype} self.weight = Parameter(torch.empty(num_embeddings, embed_dim_per_partition, **factory_kwargs)) init.uniform_(self.weight, -1, 1) def forward(self, input_): - # Get the embeddings. - output_parallel = F.embedding(input_, self.weight, self.padding_idx, self.max_norm, self.norm_type, - self.scale_grad_by_freq, self.sparse) + output_parallel = F.embedding( + input_, self.weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse + ) # Reduce across all the model parallel GPUs. 
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) @@ -584,11 +568,9 @@ def __init__( # self.embedding = HiddenParallelEmbedding1D(vocab_size, hidden_size, dtype, padding_idx) # (hidden_size/q, vocab_size) self.synced_embed = False - self.head = Linear1D_Row(in_features=embed_dim, - out_features=vocab_size, - bias=False, - dtype=dtype, - parallel_input=False) + self.head = Linear1D_Row( + in_features=embed_dim, out_features=vocab_size, bias=False, dtype=dtype, parallel_input=False + ) def forward(self, x: Tensor) -> Tensor: if self.synced_embed: diff --git a/examples/language/gpt/titans/model/gpt1d.py b/examples/language/gpt/titans/model/gpt1d.py index 746acbf7dccd..f8e2f42e11cb 100644 --- a/examples/language/gpt/titans/model/gpt1d.py +++ b/examples/language/gpt/titans/model/gpt1d.py @@ -18,18 +18,21 @@ from colossalai.utils import checkpoint __all__ = [ - 'GPTMLP1D', 'GPTSelfAttention1D', 'GPTTransformerLayer1D', 'FusedGPTSelfAttention1D', 'FusedGPTTransformerLayer1D' + "GPTMLP1D", + "GPTSelfAttention1D", + "GPTTransformerLayer1D", + "FusedGPTSelfAttention1D", + "FusedGPTTransformerLayer1D", ] class GPTMLP1D(ParallelLayer): - def __init__( self, in_features: int, mlp_ratio: int, - act_func: str = 'gelu', - dropout_prob: float = 0., + act_func: str = "gelu", + dropout_prob: float = 0.0, dtype=None, checkpoint: bool = False, skip_bias_add: bool = False, @@ -82,7 +85,6 @@ def forward(self, hidden_states: Tensor) -> Tensor: class GenericGPTSelfAttention1D(ParallelLayer): - def __init__( self, hidden_size: int, @@ -118,8 +120,10 @@ def softmax_forward(self, attention_scores, attention_mask, query_layer, key_lay def _forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor: query_key_value = self.query_key_value(hidden_states) - new_qkv_shape = query_key_value.shape[:-1] + \ - (self.num_attention_heads_per_partition, 3 * self.attention_head_size) + new_qkv_shape = query_key_value.shape[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.attention_head_size, + ) query_key_value = query_key_value.view(new_qkv_shape) query_key_value = query_key_value.permute((0, 2, 1, 3)) query_layer, key_layer, value_layer = torch.chunk(query_key_value, 3, dim=-1) @@ -152,28 +156,32 @@ def forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor: class GPTSelfAttention1D(GenericGPTSelfAttention1D): - - def __init__(self, - hidden_size: int, - num_attention_heads: int, - attention_dropout_prob: float, - hidden_dropout_prob: float, - dtype=None, - checkpoint: bool = False, - max_position_embeddings=1024): - super().__init__(hidden_size, - num_attention_heads, - attention_dropout_prob, - hidden_dropout_prob, - dtype=dtype, - checkpoint=checkpoint, - max_position_embeddings=max_position_embeddings) + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + attention_dropout_prob: float, + hidden_dropout_prob: float, + dtype=None, + checkpoint: bool = False, + max_position_embeddings=1024, + ): + super().__init__( + hidden_size, + num_attention_heads, + attention_dropout_prob, + hidden_dropout_prob, + dtype=dtype, + checkpoint=checkpoint, + max_position_embeddings=max_position_embeddings, + ) self.softmax = nn.Softmax(dim=-1) max_positions = max_position_embeddings self.register_buffer( "bias", - torch.tril(torch.ones((max_positions, max_positions), - dtype=torch.uint8)).view(1, 1, max_positions, max_positions), + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( + 1, 1, max_positions, max_positions + ), 
) self.register_buffer("masked_bias", torch.tensor(-1e4)) @@ -181,7 +189,7 @@ def softmax_forward(self, attention_scores, attention_mask, query_layer, key_lay attention_scores = attention_scores / math.sqrt(self.attention_head_size) # causal mask query_length, key_length = query_layer.size(-2), key_layer.size(-2) - causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].bool() + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool() attention_scores = torch.where(causal_mask, attention_scores, self.masked_bias.to(attention_scores)) if attention_mask is not None: # Apply the attention mask @@ -191,50 +199,56 @@ def softmax_forward(self, attention_scores, attention_mask, query_layer, key_lay class FusedGPTSelfAttention1D(GenericGPTSelfAttention1D): - - def __init__(self, - hidden_size: int, - num_attention_heads: int, - attention_dropout_prob: float, - hidden_dropout_prob: float, - dtype=None, - checkpoint: bool = False, - max_position_embeddings=1024): - super().__init__(hidden_size, - num_attention_heads, - attention_dropout_prob, - hidden_dropout_prob, - dtype=dtype, - checkpoint=checkpoint, - max_position_embeddings=max_position_embeddings) - self.softmax = kernel.FusedScaleMaskSoftmax(input_in_fp16=True, - input_in_bf16=False, - attn_mask_type=AttnMaskType.causal, - scaled_masked_softmax_fusion=True, - mask_func=None, - softmax_in_fp32=True, - scale=math.sqrt(self.attention_head_size)) + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + attention_dropout_prob: float, + hidden_dropout_prob: float, + dtype=None, + checkpoint: bool = False, + max_position_embeddings=1024, + ): + super().__init__( + hidden_size, + num_attention_heads, + attention_dropout_prob, + hidden_dropout_prob, + dtype=dtype, + checkpoint=checkpoint, + max_position_embeddings=max_position_embeddings, + ) + self.softmax = kernel.FusedScaleMaskSoftmax( + input_in_fp16=True, + input_in_bf16=False, + attn_mask_type=AttnMaskType.causal, + scaled_masked_softmax_fusion=True, + mask_func=None, + softmax_in_fp32=True, + scale=math.sqrt(self.attention_head_size), + ) def softmax_forward(self, attention_scores, attention_mask, query_layer, key_layer): return self.softmax(attention_scores, attention_mask) class GenericGPTTransformerLayer1D(ParallelLayer): - - def __init__(self, - hidden_size: int, - num_attention_heads: int, - act_func: str = 'gelu', - mlp_ratio: float = 4.0, - attention_dropout_prob: float = 0., - hidden_dropout_prob: float = 0., - dtype=None, - checkpoint: bool = False, - max_position_embeddings: int = 1024, - layer_norm_epsilon: float = 1e-5, - apply_post_layer_norm: bool = False, - attention=None, - layer_norm=None): + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + act_func: str = "gelu", + mlp_ratio: float = 4.0, + attention_dropout_prob: float = 0.0, + hidden_dropout_prob: float = 0.0, + dtype=None, + checkpoint: bool = False, + max_position_embeddings: int = 1024, + layer_norm_epsilon: float = 1e-5, + apply_post_layer_norm: bool = False, + attention=None, + layer_norm=None, + ): super().__init__() self.checkpoint = checkpoint self.dtype = dtype @@ -288,62 +302,68 @@ def forward(self, hidden_states, attention_mask): class GPTTransformerLayer1D(GenericGPTTransformerLayer1D): - - def __init__(self, - hidden_size: int, - num_attention_heads: int, - act_func: str = 'gelu', - mlp_ratio: float = 4, - attention_dropout_prob: float = 0, - hidden_dropout_prob: float = 0, - dtype=None, - checkpoint: bool = False, 
- max_position_embeddings: int = 1024, - layer_norm_epsilon: float = 0.00001, - apply_post_layer_norm: bool = False): + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + act_func: str = "gelu", + mlp_ratio: float = 4, + attention_dropout_prob: float = 0, + hidden_dropout_prob: float = 0, + dtype=None, + checkpoint: bool = False, + max_position_embeddings: int = 1024, + layer_norm_epsilon: float = 0.00001, + apply_post_layer_norm: bool = False, + ): attention = GPTSelfAttention1D layer_norm = nn.LayerNorm - super().__init__(hidden_size, - num_attention_heads, - act_func=act_func, - mlp_ratio=mlp_ratio, - attention_dropout_prob=attention_dropout_prob, - hidden_dropout_prob=hidden_dropout_prob, - dtype=dtype, - checkpoint=checkpoint, - max_position_embeddings=max_position_embeddings, - layer_norm_epsilon=layer_norm_epsilon, - apply_post_layer_norm=apply_post_layer_norm, - attention=attention, - layer_norm=layer_norm) + super().__init__( + hidden_size, + num_attention_heads, + act_func=act_func, + mlp_ratio=mlp_ratio, + attention_dropout_prob=attention_dropout_prob, + hidden_dropout_prob=hidden_dropout_prob, + dtype=dtype, + checkpoint=checkpoint, + max_position_embeddings=max_position_embeddings, + layer_norm_epsilon=layer_norm_epsilon, + apply_post_layer_norm=apply_post_layer_norm, + attention=attention, + layer_norm=layer_norm, + ) class FusedGPTTransformerLayer1D(GenericGPTTransformerLayer1D): - - def __init__(self, - hidden_size: int, - num_attention_heads: int, - act_func: str = 'gelu', - mlp_ratio: float = 4, - attention_dropout_prob: float = 0, - hidden_dropout_prob: float = 0, - dtype=None, - checkpoint: bool = False, - max_position_embeddings: int = 1024, - layer_norm_epsilon: float = 0.00001, - apply_post_layer_norm: bool = False): + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + act_func: str = "gelu", + mlp_ratio: float = 4, + attention_dropout_prob: float = 0, + hidden_dropout_prob: float = 0, + dtype=None, + checkpoint: bool = False, + max_position_embeddings: int = 1024, + layer_norm_epsilon: float = 0.00001, + apply_post_layer_norm: bool = False, + ): attention = FusedGPTSelfAttention1D layer_norm = kernel.LayerNorm - super().__init__(hidden_size, - num_attention_heads, - act_func=act_func, - mlp_ratio=mlp_ratio, - attention_dropout_prob=attention_dropout_prob, - hidden_dropout_prob=hidden_dropout_prob, - dtype=dtype, - checkpoint=checkpoint, - max_position_embeddings=max_position_embeddings, - layer_norm_epsilon=layer_norm_epsilon, - apply_post_layer_norm=apply_post_layer_norm, - attention=attention, - layer_norm=layer_norm) + super().__init__( + hidden_size, + num_attention_heads, + act_func=act_func, + mlp_ratio=mlp_ratio, + attention_dropout_prob=attention_dropout_prob, + hidden_dropout_prob=hidden_dropout_prob, + dtype=dtype, + checkpoint=checkpoint, + max_position_embeddings=max_position_embeddings, + layer_norm_epsilon=layer_norm_epsilon, + apply_post_layer_norm=apply_post_layer_norm, + attention=attention, + layer_norm=layer_norm, + ) diff --git a/examples/language/gpt/titans/model/pipeline_gpt1d.py b/examples/language/gpt/titans/model/pipeline_gpt1d.py index a9da246faf82..83158cb44e0c 100644 --- a/examples/language/gpt/titans/model/pipeline_gpt1d.py +++ b/examples/language/gpt/titans/model/pipeline_gpt1d.py @@ -17,17 +17,16 @@ from .gpt1d import FusedGPTTransformerLayer1D, GPTTransformerLayer1D __all__ = [ - 'GPT2_small_pipeline_1D', - 'GPT2_exlarge_pipeline_1D', - 'GPT3_pipeline_1D', - 'GPT2_exlarge_pipeline_hybrid', - 
'GPT2_small_pipeline_hybrid', - 'GPT3_pipeline_hybrid', + "GPT2_small_pipeline_1D", + "GPT2_exlarge_pipeline_1D", + "GPT3_pipeline_1D", + "GPT2_exlarge_pipeline_hybrid", + "GPT2_small_pipeline_hybrid", + "GPT3_pipeline_hybrid", ] class GenericPipelineGPT(nn.Module): - def __init__(self, embedding=None, blocks=None, norm=None, head=None) -> None: super().__init__() self.embedding = embedding @@ -44,7 +43,7 @@ def forward(self, hidden_states=None, input_ids=None, attention_mask=None): batch_size = hidden_states.shape[0] attention_mask = attention_mask.view(batch_size, -1) attention_mask = attention_mask[:, None, None, :] - attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility + attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility attention_mask = (1.0 - attention_mask) * -10000.0 for block in self.blocks: hidden_states, attention_mask = block(hidden_states, attention_mask) @@ -54,25 +53,26 @@ def forward(self, hidden_states=None, input_ids=None, attention_mask=None): class PipelineGPT1D(GenericPipelineGPT): - - def __init__(self, - num_layers: int = 12, - hidden_size: int = 768, - num_attention_heads: int = 12, - vocab_size: int = 50304, - embed_drop_rate: float = 0., - act_func: str = 'gelu', - mlp_ratio: int = 4.0, - attn_drop_rate: float = 0., - drop_rate: float = 0., - dtype: torch.dtype = torch.float, - checkpoint: bool = False, - max_position_embeddings: int = 1024, - layer_norm_epsilon: float = 1e-5, - apply_post_layer_norm: bool = False, - first: bool = False, - last: bool = False, - embed_split_hidden=False): + def __init__( + self, + num_layers: int = 12, + hidden_size: int = 768, + num_attention_heads: int = 12, + vocab_size: int = 50304, + embed_drop_rate: float = 0.0, + act_func: str = "gelu", + mlp_ratio: int = 4.0, + attn_drop_rate: float = 0.0, + drop_rate: float = 0.0, + dtype: torch.dtype = torch.float, + checkpoint: bool = False, + max_position_embeddings: int = 1024, + layer_norm_epsilon: float = 1e-5, + apply_post_layer_norm: bool = False, + first: bool = False, + last: bool = False, + embed_split_hidden=False, + ): embedding = None norm = None head = None @@ -83,19 +83,24 @@ def __init__(self, head_cls = HiddenParallelGPTLMHead1D if first: embedding = embed_cls(hidden_size, vocab_size, max_position_embeddings, embed_drop_rate, dtype=dtype) - blocks = nn.ModuleList([ - GPTTransformerLayer1D(hidden_size, - num_attention_heads, - act_func=act_func, - mlp_ratio=mlp_ratio, - attention_dropout_prob=attn_drop_rate, - hidden_dropout_prob=drop_rate, - dtype=dtype, - checkpoint=checkpoint, - max_position_embeddings=max_position_embeddings, - layer_norm_epsilon=layer_norm_epsilon, - apply_post_layer_norm=apply_post_layer_norm) for _ in range(num_layers) - ]) + blocks = nn.ModuleList( + [ + GPTTransformerLayer1D( + hidden_size, + num_attention_heads, + act_func=act_func, + mlp_ratio=mlp_ratio, + attention_dropout_prob=attn_drop_rate, + hidden_dropout_prob=drop_rate, + dtype=dtype, + checkpoint=checkpoint, + max_position_embeddings=max_position_embeddings, + layer_norm_epsilon=layer_norm_epsilon, + apply_post_layer_norm=apply_post_layer_norm, + ) + for _ in range(num_layers) + ] + ) if last: norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) head = head_cls(vocab_size=vocab_size, embed_dim=hidden_size, dtype=dtype) @@ -103,25 +108,26 @@ def __init__(self, class FusedPipelineGPT1D(GenericPipelineGPT): - - def __init__(self, - num_layers: int = 12, - hidden_size: int = 768, - num_attention_heads: int = 12, - vocab_size: 
int = 50304, - embed_drop_rate: float = 0., - act_func: str = 'gelu', - mlp_ratio: int = 4.0, - attn_drop_rate: float = 0., - drop_rate: float = 0., - dtype: torch.dtype = torch.float, - checkpoint: bool = False, - max_position_embeddings: int = 1024, - layer_norm_epsilon: float = 1e-5, - apply_post_layer_norm: bool = False, - first: bool = False, - last: bool = False, - embed_split_hidden=False): + def __init__( + self, + num_layers: int = 12, + hidden_size: int = 768, + num_attention_heads: int = 12, + vocab_size: int = 50304, + embed_drop_rate: float = 0.0, + act_func: str = "gelu", + mlp_ratio: int = 4.0, + attn_drop_rate: float = 0.0, + drop_rate: float = 0.0, + dtype: torch.dtype = torch.float, + checkpoint: bool = False, + max_position_embeddings: int = 1024, + layer_norm_epsilon: float = 1e-5, + apply_post_layer_norm: bool = False, + first: bool = False, + last: bool = False, + embed_split_hidden=False, + ): embedding = None norm = None head = None @@ -132,19 +138,24 @@ def __init__(self, head_cls = HiddenParallelGPTLMHead1D if first: embedding = embed_cls(hidden_size, vocab_size, max_position_embeddings, embed_drop_rate, dtype=dtype) - blocks = nn.ModuleList([ - FusedGPTTransformerLayer1D(hidden_size, - num_attention_heads, - act_func=act_func, - mlp_ratio=mlp_ratio, - attention_dropout_prob=attn_drop_rate, - hidden_dropout_prob=drop_rate, - dtype=dtype, - checkpoint=checkpoint, - max_position_embeddings=max_position_embeddings, - layer_norm_epsilon=layer_norm_epsilon, - apply_post_layer_norm=apply_post_layer_norm) for _ in range(num_layers) - ]) + blocks = nn.ModuleList( + [ + FusedGPTTransformerLayer1D( + hidden_size, + num_attention_heads, + act_func=act_func, + mlp_ratio=mlp_ratio, + attention_dropout_prob=attn_drop_rate, + hidden_dropout_prob=drop_rate, + dtype=dtype, + checkpoint=checkpoint, + max_position_embeddings=max_position_embeddings, + layer_norm_epsilon=layer_norm_epsilon, + apply_post_layer_norm=apply_post_layer_norm, + ) + for _ in range(num_layers) + ] + ) if last: norm = kernel.LayerNorm(hidden_size, eps=layer_norm_epsilon) head = head_cls(vocab_size=vocab_size, embed_dim=hidden_size, dtype=dtype) @@ -153,7 +164,7 @@ def __init__(self, def forward(self, hidden_states=None, input_ids=None, attention_mask=None): if self.embedding is not None: hidden_states = self.embedding(input_ids=input_ids) - attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility + attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility for block in self.blocks: hidden_states, attention_mask = block(hidden_states, attention_mask) if self.norm is not None: @@ -162,44 +173,48 @@ def forward(self, hidden_states=None, input_ids=None, attention_mask=None): class PipelineGPTHybrid(GenericPipelineGPT): - - def __init__(self, - num_layers: int = 12, - hidden_size: int = 768, - num_attention_heads: int = 12, - vocab_size: int = 50304, - embed_drop_rate: float = 0., - act_func: str = 'gelu', - mlp_ratio: int = 4, - attn_drop_rate: float = 0., - drop_rate: float = 0., - dtype: torch.dtype = torch.float, - checkpoint: bool = False, - max_position_embeddings: int = 1024, - layer_norm_epsilon: float = 1e-5, - apply_post_layer_norm: bool = False, - first: bool = False, - last: bool = False, - embed_split_hidden=False): + def __init__( + self, + num_layers: int = 12, + hidden_size: int = 768, + num_attention_heads: int = 12, + vocab_size: int = 50304, + embed_drop_rate: float = 0.0, + act_func: str = "gelu", + mlp_ratio: int = 4, + attn_drop_rate: float = 
0.0, + drop_rate: float = 0.0, + dtype: torch.dtype = torch.float, + checkpoint: bool = False, + max_position_embeddings: int = 1024, + layer_norm_epsilon: float = 1e-5, + apply_post_layer_norm: bool = False, + first: bool = False, + last: bool = False, + embed_split_hidden=False, + ): embedding = None norm = None head = None if first: - embedding = col_gpt.GPTEmbedding(hidden_size, - vocab_size, - max_position_embeddings, - dropout=embed_drop_rate, - dtype=dtype) - blocks = nn.ModuleList([ - col_gpt.GPTBlock(hidden_size, - num_attention_heads, - mlp_ratio=mlp_ratio, - attention_dropout=attn_drop_rate, - dropout=drop_rate, - dtype=dtype, - checkpoint=checkpoint, - activation=nn.functional.gelu) for _ in range(num_layers) - ]) + embedding = col_gpt.GPTEmbedding( + hidden_size, vocab_size, max_position_embeddings, dropout=embed_drop_rate, dtype=dtype + ) + blocks = nn.ModuleList( + [ + col_gpt.GPTBlock( + hidden_size, + num_attention_heads, + mlp_ratio=mlp_ratio, + attention_dropout=attn_drop_rate, + dropout=drop_rate, + dtype=dtype, + checkpoint=checkpoint, + activation=nn.functional.gelu, + ) + for _ in range(num_layers) + ] + ) if last: norm = col_nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) # head = col_gpt.GPTLMHead(vocab_size=vocab_size, @@ -215,7 +230,7 @@ def _filter_kwargs(func, kwargs): return {k: v for k, v in kwargs.items() if k in sig.parameters} -def _build_generic_gpt_pipeline_1d(module_cls, num_layers, num_chunks, device=torch.device('cuda'), **kwargs): +def _build_generic_gpt_pipeline_1d(module_cls, num_layers, num_chunks, device=torch.device("cuda"), **kwargs): logger = get_dist_logger() if gpc.is_initialized(ParallelMode.PIPELINE): @@ -233,10 +248,10 @@ def _build_generic_gpt_pipeline_1d(module_cls, num_layers, num_chunks, device=to parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank] models = [] for start, end in parts: - kwargs['num_layers'] = end - start - kwargs['first'] = start == 0 - kwargs['last'] = end == num_layers - logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers') + kwargs["num_layers"] = end - start + kwargs["first"] = start == 0 + kwargs["last"] = end == num_layers + logger.info(f"Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers") chunk = module_cls(**_filter_kwargs(module_cls.__init__, kwargs)).to(device) if wrapper is not None: @@ -253,70 +268,82 @@ def _build_generic_gpt_pipeline_1d(module_cls, num_layers, num_chunks, device=to numel = 0 for _, param in model.named_parameters(recurse=True): numel += param.numel() - logger.info(f'Rank{rank}/{pipeline_rank} model size = {numel * 2 / 1e9} GB') + logger.info(f"Rank{rank}/{pipeline_rank} model size = {numel * 2 / 1e9} GB") return model -def _build_gpt_pipeline_1d(num_layers, num_chunks, device=torch.device('cuda'), fused=False, **kwargs): +def _build_gpt_pipeline_1d(num_layers, num_chunks, device=torch.device("cuda"), fused=False, **kwargs): model = FusedPipelineGPT1D if fused else PipelineGPT1D return _build_generic_gpt_pipeline_1d(model, num_layers, num_chunks, device, **kwargs) -def _build_gpt_pipeline_hybrid(num_layers, num_chunks, device=torch.device('cuda'), **kwargs): +def _build_gpt_pipeline_hybrid(num_layers, num_chunks, device=torch.device("cuda"), **kwargs): return _build_generic_gpt_pipeline_1d(PipelineGPTHybrid, num_layers, num_chunks, device, **kwargs) def GPT2_small_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False): - cfg = dict(hidden_size=768, - 
num_attention_heads=12, - checkpoint=checkpoint, - dtype=dtype, - embed_split_hidden=embed_split_hidden) + cfg = dict( + hidden_size=768, + num_attention_heads=12, + checkpoint=checkpoint, + dtype=dtype, + embed_split_hidden=embed_split_hidden, + ) return _build_gpt_pipeline_1d(12, num_chunks, fused=fused, **cfg) def GPT2_exlarge_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False): - cfg = dict(hidden_size=1600, - num_attention_heads=32, - checkpoint=checkpoint, - dtype=dtype, - embed_split_hidden=embed_split_hidden) + cfg = dict( + hidden_size=1600, + num_attention_heads=32, + checkpoint=checkpoint, + dtype=dtype, + embed_split_hidden=embed_split_hidden, + ) return _build_gpt_pipeline_1d(48, num_chunks, fused=fused, **cfg) def GPT3_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False): - cfg = dict(hidden_size=12288, - num_attention_heads=96, - checkpoint=checkpoint, - max_position_embeddings=2048, - dtype=dtype, - embed_split_hidden=embed_split_hidden) + cfg = dict( + hidden_size=12288, + num_attention_heads=96, + checkpoint=checkpoint, + max_position_embeddings=2048, + dtype=dtype, + embed_split_hidden=embed_split_hidden, + ) return _build_gpt_pipeline_1d(96, num_chunks, fused=fused, **cfg) def GPT2_exlarge_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False): - cfg = dict(hidden_size=1600, - num_attention_heads=32, - checkpoint=checkpoint, - dtype=dtype, - embed_split_hidden=embed_split_hidden) + cfg = dict( + hidden_size=1600, + num_attention_heads=32, + checkpoint=checkpoint, + dtype=dtype, + embed_split_hidden=embed_split_hidden, + ) return _build_gpt_pipeline_hybrid(48, num_chunks, **cfg) def GPT2_small_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False): - cfg = dict(hidden_size=768, - num_attention_heads=12, - checkpoint=checkpoint, - dtype=dtype, - embed_split_hidden=embed_split_hidden) + cfg = dict( + hidden_size=768, + num_attention_heads=12, + checkpoint=checkpoint, + dtype=dtype, + embed_split_hidden=embed_split_hidden, + ) return _build_gpt_pipeline_hybrid(12, num_chunks, **cfg) def GPT3_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False): - cfg = dict(hidden_size=12288, - num_attention_heads=96, - checkpoint=checkpoint, - max_position_embeddings=2048, - dtype=dtype, - embed_split_hidden=embed_split_hidden) + cfg = dict( + hidden_size=12288, + num_attention_heads=96, + checkpoint=checkpoint, + max_position_embeddings=2048, + dtype=dtype, + embed_split_hidden=embed_split_hidden, + ) return _build_gpt_pipeline_hybrid(96, num_chunks, **cfg) diff --git a/examples/language/gpt/titans/train_gpt.py b/examples/language/gpt/titans/train_gpt.py index 3ed18b21fff5..b9d802f01cc9 100644 --- a/examples/language/gpt/titans/train_gpt.py +++ b/examples/language/gpt/titans/train_gpt.py @@ -14,7 +14,7 @@ from colossalai.legacy.zero.init_ctx import ZeroInitContext from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn import LinearWarmupLR -from colossalai.utils import colo_set_process_memory_fraction, is_using_pp +from colossalai.utils import is_using_pp from colossalai.utils.timer import MultiTimer @@ -30,8 +30,8 @@ def calc_local_model_size(model: torch.nn.Module): def main(): parser = colossalai.get_default_parser() - parser.add_argument('--from_torch', default=False, action='store_true') - parser.add_argument('--use_dummy_dataset', 
default=False, action='store_true') + parser.add_argument("--from_torch", default=False, action="store_true") + parser.add_argument("--use_dummy_dataset", default=False, action="store_true") args = parser.parse_args() disable_existing_loggers() if args.from_torch: @@ -40,28 +40,27 @@ def main(): colossalai.launch_from_slurm(config=args.config, host=args.host, port=29500, seed=42) logger = get_dist_logger() - data_path = None if args.use_dummy_dataset else os.environ['DATA'] - logger.info(f'Build data loader from path {data_path}', ranks=[0]) + data_path = None if args.use_dummy_dataset else os.environ["DATA"] + logger.info(f"Build data loader from path {data_path}", ranks=[0]) train_ds = WebtextDataset(path=data_path, seq_len=gpc.config.SEQ_LEN) - train_dataloader = utils.get_dataloader(train_ds, - seed=42, - batch_size=gpc.config.BATCH_SIZE, - pin_memory=True, - shuffle=True, - drop_last=True) - - logger.info('Build model', ranks=[0]) + train_dataloader = utils.get_dataloader( + train_ds, seed=42, batch_size=gpc.config.BATCH_SIZE, pin_memory=True, shuffle=True, drop_last=True + ) + + logger.info("Build model", ranks=[0]) use_pipeline = is_using_pp() - use_interleaved = hasattr(gpc.config.model, 'num_chunks') - use_zero3 = hasattr(gpc.config, 'zero') + use_interleaved = hasattr(gpc.config.model, "num_chunks") + use_zero3 = hasattr(gpc.config, "zero") ctx = contextlib.nullcontext() if use_zero3: - ctx = ZeroInitContext(target_device=torch.cuda.current_device(), - shard_strategy=gpc.config.zero.model_config.shard_strategy, - shard_param=True) + ctx = ZeroInitContext( + target_device=torch.cuda.current_device(), + shard_strategy=gpc.config.zero.model_config.shard_strategy, + shard_param=True, + ) with ctx: - model = gpc.config.model.pop('type')(**gpc.config.model) + model = gpc.config.model.pop("type")(**gpc.config.model) if use_pipeline and use_interleaved and not isinstance(model, nn.ModuleList): model = nn.ModuleList([model]) @@ -70,25 +69,31 @@ def main(): else: numel = calc_local_model_size(model) - tflop = numel * gpc.config.BATCH_SIZE * gpc.config.SEQ_LEN \ - * gpc.get_world_size(ParallelMode.MODEL) * gpc.get_world_size(ParallelMode.DATA) * 8 / (1024 ** 4) - - criterion = getattr(gpc.config, 'loss_fn', None) + tflop = ( + numel + * gpc.config.BATCH_SIZE + * gpc.config.SEQ_LEN + * gpc.get_world_size(ParallelMode.MODEL) + * gpc.get_world_size(ParallelMode.DATA) + * 8 + / (1024**4) + ) + + criterion = getattr(gpc.config, "loss_fn", None) if criterion is not None: criterion = criterion.type() else: criterion = GPTLMLoss() - logger.info('Build optimizer', ranks=[0]) - optimizer = gpc.config.optimizer.pop('type')(model.parameters(), **gpc.config.optimizer) + logger.info("Build optimizer", ranks=[0]) + optimizer = gpc.config.optimizer.pop("type")(model.parameters(), **gpc.config.optimizer) lr_scheduler = LinearWarmupLR(optimizer, total_steps=gpc.config.NUM_EPOCHS, warmup_steps=5) - engine, train_dataloader, _, lr_scheduler = colossalai.initialize(model, - optimizer, - criterion, - train_dataloader=train_dataloader, - lr_scheduler=lr_scheduler) - global_batch_size = gpc.config.BATCH_SIZE * \ - gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1) - logger.info(f'Init done, global batch size = {global_batch_size}', ranks=[0]) + engine, train_dataloader, _, lr_scheduler = colossalai.initialize( + model, optimizer, criterion, train_dataloader=train_dataloader, lr_scheduler=lr_scheduler + ) + global_batch_size = ( + gpc.config.BATCH_SIZE * 
gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1) + ) + logger.info(f"Init done, global batch size = {global_batch_size}", ranks=[0]) timier = MultiTimer() trainer = Trainer(engine=engine, logger=logger, timer=timier) hook_list = [ @@ -98,16 +103,18 @@ def main(): hooks.ThroughputHook(ignored_steps=10, tflop_per_step=tflop), hooks.LogMetricByStepHook(), hooks.LogMemoryByEpochHook(logger), - # hooks.LogMemoryByEpochHook(logger), - # hooks.LogTimingByEpochHook(timer, logger), + # hooks.LogMemoryByEpochHook(logger), + # hooks.LogTimingByEpochHook(timer, logger), ] - trainer.fit(train_dataloader=train_dataloader, - epochs=gpc.config.NUM_EPOCHS, - test_interval=1, - hooks=hook_list, - display_progress=True, - return_output_label=False) + trainer.fit( + train_dataloader=train_dataloader, + epochs=gpc.config.NUM_EPOCHS, + test_interval=1, + hooks=hook_list, + display_progress=True, + return_output_label=False, + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/language/llama2/attn.py b/examples/language/llama2/attn.py index 15f76647c87b..2b2356b18b70 100644 --- a/examples/language/llama2/attn.py +++ b/examples/language/llama2/attn.py @@ -9,12 +9,14 @@ SUPPORT_FLASH2 = False try: import xformers.ops as xops + SUPPORT_XFORMERS = True except ImportError: pass try: from flash_attn import flash_attn_func + SUPPORT_FLASH2 = True except ImportError: pass @@ -62,10 +64,9 @@ def llama_flash_attention( if SUPPORT_FLASH2: attn_output = flash_attn_func(query_states, key_states, value_states, causal=True) else: - attn_output = xops.memory_efficient_attention(query_states, - key_states, - value_states, - attn_bias=xops.LowerTriangularMask()) + attn_output = xops.memory_efficient_attention( + query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask() + ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) diff --git a/examples/language/llama2/benchmark.py b/examples/language/llama2/benchmark.py index 1b947cef9080..ce13ebbf617d 100644 --- a/examples/language/llama2/benchmark.py +++ b/examples/language/llama2/benchmark.py @@ -25,21 +25,22 @@ # ============================== MODEL_CONFIGS = { - '7b': - LlamaConfig(max_position_embeddings=4096), - '13b': - LlamaConfig(hidden_size=5120, - intermediate_size=13824, - num_hidden_layers=40, - num_attention_heads=40, - max_position_embeddings=4096), - '70b': - LlamaConfig(hidden_size=8192, - intermediate_size=28672, - num_hidden_layers=80, - num_attention_heads=64, - max_position_embeddings=4096, - num_key_value_heads=8), + "7b": LlamaConfig(max_position_embeddings=4096), + "13b": LlamaConfig( + hidden_size=5120, + intermediate_size=13824, + num_hidden_layers=40, + num_attention_heads=40, + max_position_embeddings=4096, + ), + "70b": LlamaConfig( + hidden_size=8192, + intermediate_size=28672, + num_hidden_layers=80, + num_attention_heads=64, + max_position_embeddings=4096, + num_key_value_heads=8, + ), } @@ -48,31 +49,31 @@ def main(): # Parse Arguments # ============================== parser = argparse.ArgumentParser() - parser.add_argument('-c', '--config', type=str, default='7b', help='Model configuration') - parser.add_argument('-p', - '--plugin', - choices=['gemini', 'gemini_auto', 'fsdp', 'fsdp_cpu', '3d', '3d_cpu'], - default='gemini', - help='Choose which plugin to use') - parser.add_argument('-b', '--batch_size', type=int, default=2, help='Batch size') - parser.add_argument('-s', '--num_steps', type=int, default=5, help='Number of steps to run') - 
parser.add_argument('-i', '--ignore_steps', type=int, default=2, help='Number of steps to ignore') - parser.add_argument('-g', '--grad_checkpoint', action='store_true', help='Use gradient checkpointing') - parser.add_argument('-l', '--max_length', type=int, default=4096, help='Max sequence length') - parser.add_argument('-w', - '--warmup_ratio', - type=float, - default=0.8, - help='warm up ratio of non-model data. Only for gemini-auto') - parser.add_argument('-m', '--memory_limit', type=int, help='Gemini memory limit in mb') - parser.add_argument('-x', '--xformers', action='store_true', help='Use xformers') - parser.add_argument('--shard_param_frac', type=float, default=1.0, help='Shard param fraction. Only for gemini') - parser.add_argument('--offload_optim_frac', type=float, default=0.0, help='Offload optim fraction. Only for gemini') - parser.add_argument('--offload_param_frac', type=float, default=0.0, help='Offload param fraction. Only for gemini') - parser.add_argument('--tp', type=int, default=1, help='Tensor parallel size') - parser.add_argument('--pp', type=int, default=1, help='Pipeline parallel size') - parser.add_argument('--mbs', type=int, default=1) - parser.add_argument('--zero', type=int, default=0) + parser.add_argument("-c", "--config", type=str, default="7b", help="Model configuration") + parser.add_argument( + "-p", + "--plugin", + choices=["gemini", "gemini_auto", "fsdp", "fsdp_cpu", "3d", "3d_cpu"], + default="gemini", + help="Choose which plugin to use", + ) + parser.add_argument("-b", "--batch_size", type=int, default=2, help="Batch size") + parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run") + parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore") + parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing") + parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length") + parser.add_argument( + "-w", "--warmup_ratio", type=float, default=0.8, help="warm up ratio of non-model data. Only for gemini-auto" + ) + parser.add_argument("-m", "--memory_limit", type=int, help="Gemini memory limit in mb") + parser.add_argument("-x", "--xformers", action="store_true", help="Use xformers") + parser.add_argument("--shard_param_frac", type=float, default=1.0, help="Shard param fraction. Only for gemini") + parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini") + parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. 
Only for gemini") + parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size") + parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size") + parser.add_argument("--mbs", type=int, default=1) + parser.add_argument("--zero", type=int, default=0) args = parser.parse_args() colossalai.launch_from_torch({}) @@ -85,56 +86,67 @@ def empty_init(): # Initialize Booster # ============================== use_empty_init = True - if args.plugin == 'gemini': - plugin = GeminiPlugin(precision='bf16', - shard_param_frac=args.shard_param_frac, - offload_optim_frac=args.offload_optim_frac, - offload_param_frac=args.offload_param_frac) - elif args.plugin == 'gemini_auto': - plugin = GeminiPlugin(placement_policy='auto', precision='bf16', warmup_non_model_data_ratio=args.warmup_ratio) - elif args.plugin == 'fsdp': + if args.plugin == "gemini": + plugin = GeminiPlugin( + precision="bf16", + shard_param_frac=args.shard_param_frac, + offload_optim_frac=args.offload_optim_frac, + offload_param_frac=args.offload_param_frac, + ) + elif args.plugin == "gemini_auto": + plugin = GeminiPlugin(placement_policy="auto", precision="bf16", warmup_non_model_data_ratio=args.warmup_ratio) + elif args.plugin == "fsdp": if use_empty_init: plugin = TorchFSDPPlugin( - mixed_precision=MixedPrecision(param_dtype=torch.float16, - reduce_dtype=torch.float16, - buffer_dtype=torch.float16), + mixed_precision=MixedPrecision( + param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16 + ), param_init_fn=empty_init(), ) else: - plugin = TorchFSDPPlugin(mixed_precision=MixedPrecision( - param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16)) - elif args.plugin == 'fsdp_cpu': + plugin = TorchFSDPPlugin( + mixed_precision=MixedPrecision( + param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16 + ) + ) + elif args.plugin == "fsdp_cpu": if use_empty_init: plugin = TorchFSDPPlugin( - mixed_precision=MixedPrecision(param_dtype=torch.float16, - reduce_dtype=torch.float16, - buffer_dtype=torch.float16), + mixed_precision=MixedPrecision( + param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16 + ), cpu_offload=CPUOffload(offload_params=True), param_init_fn=empty_init(), ) else: - plugin = TorchFSDPPlugin(mixed_precision=MixedPrecision(param_dtype=torch.float16, - reduce_dtype=torch.float16, - buffer_dtype=torch.float16), - cpu_offload=CPUOffload(offload_params=True)) - elif args.plugin == '3d': - plugin = HybridParallelPlugin(tp_size=args.tp, - pp_size=args.pp, - zero_stage=args.zero, - enable_fused_normalization=True, - num_microbatches=args.mbs, - precision='bf16') - elif args.plugin == '3d_cpu': - plugin = HybridParallelPlugin(tp_size=args.tp, - pp_size=args.pp, - zero_stage=args.zero, - cpu_offload=True, - enable_fused_normalization=True, - num_microbatches=args.mbs, - initial_scale=2**8, - precision='bf16') + plugin = TorchFSDPPlugin( + mixed_precision=MixedPrecision( + param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16 + ), + cpu_offload=CPUOffload(offload_params=True), + ) + elif args.plugin == "3d": + plugin = HybridParallelPlugin( + tp_size=args.tp, + pp_size=args.pp, + zero_stage=args.zero, + enable_fused_normalization=True, + num_microbatches=args.mbs, + precision="bf16", + ) + elif args.plugin == "3d_cpu": + plugin = HybridParallelPlugin( + tp_size=args.tp, + pp_size=args.pp, + zero_stage=args.zero, + cpu_offload=True, + enable_fused_normalization=True, + 
num_microbatches=args.mbs, + initial_scale=2**8, + precision="bf16", + ) else: - raise ValueError(f'Unknown plugin {args.plugin}') + raise ValueError(f"Unknown plugin {args.plugin}") booster = Booster(plugin=plugin) @@ -144,17 +156,19 @@ def empty_init(): dp_size = plugin.dp_size if isinstance(plugin, HybridParallelPlugin) else coordinator.world_size config = MODEL_CONFIGS[args.config] - dataset = RandomDataset(num_samples=args.batch_size * args.num_steps * dp_size, - max_length=args.max_length, - vocab_size=config.vocab_size) + dataset = RandomDataset( + num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size + ) dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True) # ============================== # Initialize Model and Optimizer # ============================== - init_ctx = LazyInitContext( - default_device=get_current_device()) if isinstance(plugin, - (GeminiPlugin, HybridParallelPlugin)) else nullcontext() + init_ctx = ( + LazyInitContext(default_device=get_current_device()) + if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) + else nullcontext() + ) with init_ctx: model = LlamaForCausalLM(config) @@ -163,38 +177,36 @@ def empty_init(): model.gradient_checkpointing_enable() if args.xformers: - assert SUPPORT_FLASH, 'Use flash attention while xfomers is not installed' + assert SUPPORT_FLASH, "Use flash attention while xfomers is not installed" replace_xformers(model) model_numel = get_model_numel(model) - coordinator.print_on_master(f'Model params: {format_numel_str(model_numel)}') - performance_evaluator = PerformanceEvaluator(model_numel, - args.grad_checkpoint, - args.ignore_steps, - dp_world_size=dp_size) + coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") + performance_evaluator = PerformanceEvaluator( + model_numel, args.grad_checkpoint, args.ignore_steps, dp_world_size=dp_size + ) optimizer = HybridAdam(model.parameters()) torch.set_default_dtype(torch.bfloat16) model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader) torch.set_default_dtype(torch.float) - coordinator.print_on_master(f'Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB') + coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") coordinator.print_on_master( - f'Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB') + f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" + ) if isinstance(plugin, HybridParallelPlugin) and args.pp > 1: data_iter = iter(dataloader) - for step in tqdm(range(len(dataloader)), desc='Step', disable=not coordinator.is_master()): + for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()): performance_evaluator.on_step_start(step) - booster.execute_pipeline(data_iter, - model, - criterion=lambda outputs, inputs: outputs[0], - optimizer=optimizer, - return_loss=False) + booster.execute_pipeline( + data_iter, model, criterion=lambda outputs, inputs: outputs[0], optimizer=optimizer, return_loss=False + ) optimizer.step() optimizer.zero_grad() performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length)) else: - for step, batch in enumerate(tqdm(dataloader, desc='Step', disable=not coordinator.is_master())): + for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not 
coordinator.is_master())): performance_evaluator.on_step_start(step) outputs = model(**batch) loss = outputs[0] @@ -204,8 +216,8 @@ def empty_init(): performance_evaluator.on_step_end(**batch) performance_evaluator.on_fit_end() - coordinator.print_on_master(f'Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB') + coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/language/llama2/data_utils.py b/examples/language/llama2/data_utils.py index 25d0e1bd9f46..a438833e1680 100644 --- a/examples/language/llama2/data_utils.py +++ b/examples/language/llama2/data_utils.py @@ -12,21 +12,22 @@ class StatefulDistributedSampler(DistributedSampler): - - def __init__(self, - dataset: Dataset, - num_replicas: Optional[int] = None, - rank: Optional[int] = None, - shuffle: bool = True, - seed: int = 0, - drop_last: bool = False) -> None: + def __init__( + self, + dataset: Dataset, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + ) -> None: super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last) self.start_index: int = 0 def __iter__(self) -> Iterator: iterator = super().__iter__() indices = list(iterator) - indices = indices[self.start_index:] + indices = indices[self.start_index :] return iter(indices) def __len__(self) -> int: @@ -36,15 +37,17 @@ def set_start_index(self, start_index: int) -> None: self.start_index = start_index -def prepare_dataloader(dataset, - batch_size, - shuffle=False, - seed=1024, - drop_last=False, - pin_memory=False, - num_workers=0, - process_group: Optional[ProcessGroup] = None, - **kwargs): +def prepare_dataloader( + dataset, + batch_size, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + process_group: Optional[ProcessGroup] = None, + **kwargs, +): r""" Prepare a dataloader for distributed training. The dataloader will be wrapped by `torch.utils.data.DataLoader` and `StatefulDistributedSampler`. 
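
For readers skimming this hunk: `StatefulDistributedSampler` exists so that a run restored from a checkpoint can skip the samples it has already consumed (the `sample_start_index` value written by `save()` in the training scripts above) and so that each epoch can reset the offset, as the loops later do with `dataloader.sampler.set_start_index(0)`. The snippet below is a minimal, self-contained sketch of that resume behaviour, not part of this patch; the toy dataset and the explicit `num_replicas`/`rank` arguments are illustrative assumptions so the example runs without initializing `torch.distributed`.

```python
# Minimal sketch of mid-epoch resumption with a stateful distributed sampler.
# Assumptions (not from the patch): a toy in-memory dataset, and num_replicas/rank
# passed explicitly so no torch.distributed process group is required.
from typing import Iterator, Optional

import torch
from torch.utils.data import DataLoader, Dataset, DistributedSampler


class ToyDataset(Dataset):
    def __init__(self, n: int = 16):
        self.data = torch.arange(n)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.data[idx]


class StatefulDistributedSampler(DistributedSampler):
    """Same idea as the sampler in data_utils.py: skip already-consumed samples."""

    def __init__(
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
    ) -> None:
        super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
        self.start_index: int = 0

    def __iter__(self) -> Iterator:
        # Drop the first `start_index` indices assigned to this rank.
        indices = list(super().__iter__())
        return iter(indices[self.start_index :])

    def __len__(self) -> int:
        return self.num_samples - self.start_index

    def set_start_index(self, start_index: int) -> None:
        self.start_index = start_index


dataset = ToyDataset()
sampler = StatefulDistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)

# Pretend a checkpoint recorded sample_start_index = step * batch_size = 4.
sampler.set_start_index(4)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)
for batch in loader:
    print(batch)  # resumes after the first 4 samples owned by this rank

# After the epoch finishes, reset the offset so the next epoch sees all samples,
# mirroring `dataloader.sampler.set_start_index(0)` in the training loops above.
sampler.set_start_index(0)
```
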
@@ -68,10 +71,9 @@ def prepare_dataloader(dataset, """ _kwargs = kwargs.copy() process_group = process_group or _get_default_group() - sampler = StatefulDistributedSampler(dataset, - num_replicas=process_group.size(), - rank=process_group.rank(), - shuffle=shuffle) + sampler = StatefulDistributedSampler( + dataset, num_replicas=process_group.size(), rank=process_group.rank(), shuffle=shuffle + ) # Deterministic dataloader def seed_worker(worker_id): @@ -80,28 +82,29 @@ def seed_worker(worker_id): torch.manual_seed(worker_seed) random.seed(worker_seed) - return DataLoader(dataset, - batch_size=batch_size, - sampler=sampler, - worker_init_fn=seed_worker, - drop_last=drop_last, - pin_memory=pin_memory, - num_workers=num_workers, - **_kwargs) + return DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs, + ) def load_json(file_path: str): - with open(file_path, 'r') as f: + with open(file_path, "r") as f: return json.load(f) def save_json(data, file_path: str): - with open(file_path, 'w') as f: + with open(file_path, "w") as f: json.dump(data, f, indent=4) class RandomDataset(Dataset): - def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000): self.num_samples = num_samples self.max_length = max_length @@ -113,7 +116,7 @@ def __len__(self): def __getitem__(self, idx): return { - 'input_ids': self.input_ids[idx], - 'attention_mask': self.attention_mask[idx], - 'labels': self.input_ids[idx] + "input_ids": self.input_ids[idx], + "attention_mask": self.attention_mask[idx], + "labels": self.input_ids[idx], } diff --git a/examples/language/llama2/finetune.py b/examples/language/llama2/finetune.py index 0efbf193c9a9..33aa1d33e6ba 100644 --- a/examples/language/llama2/finetune.py +++ b/examples/language/llama2/finetune.py @@ -39,20 +39,20 @@ def format_numel_str(numel: int) -> str: M = 1024**2 K = 1024 if numel >= B: - return f'{numel / B:.2f} B' + return f"{numel / B:.2f} B" elif numel >= M: - return f'{numel / M:.2f} M' + return f"{numel / M:.2f} M" elif numel >= K: - return f'{numel / K:.2f} K' + return f"{numel / K:.2f} K" else: - return f'{numel}' + return f"{numel}" def tokenize_batch_for_finetune(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048): - texts = [sample['prompt'] + sample['completion'] for sample in batch] - data = tokenizer(texts, return_tensors="pt", padding='max_length', truncation=True, max_length=max_length) + texts = [sample["prompt"] + sample["completion"] for sample in batch] + data = tokenizer(texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length) data = {k: v.cuda() for k, v in data.items()} - data['labels'] = data['input_ids'].clone() + data["labels"] = data["input_ids"].clone() return data @@ -62,30 +62,40 @@ def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: return tensor -def save(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, epoch: int, step: int, - batch_size: int, coordinator: DistCoordinator, save_dir: str): - save_dir = os.path.join(save_dir, f'epoch{epoch}-step{step}') - os.makedirs(os.path.join(save_dir, 'model'), exist_ok=True) - - booster.save_model(model, os.path.join(save_dir, 'model'), shard=True) - booster.save_optimizer(optimizer, os.path.join(save_dir, 'optimizer'), shard=True) - booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, 'lr_scheduler')) +def save( + booster: 
Booster, + model: nn.Module, + optimizer: Optimizer, + lr_scheduler: _LRScheduler, + epoch: int, + step: int, + batch_size: int, + coordinator: DistCoordinator, + save_dir: str, +): + save_dir = os.path.join(save_dir, f"epoch{epoch}-step{step}") + os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) + + booster.save_model(model, os.path.join(save_dir, "model"), shard=True) + booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True) + booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler")) running_states = { - 'epoch': epoch, - 'step': step, - 'sample_start_index': step * batch_size, + "epoch": epoch, + "step": step, + "sample_start_index": step * batch_size, } if coordinator.is_master(): - save_json(running_states, os.path.join(save_dir, 'running_states.json')) + save_json(running_states, os.path.join(save_dir, "running_states.json")) -def load(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, - load_dir: str) -> Tuple[int, int, int]: - booster.load_model(model, os.path.join(load_dir, 'model')) - booster.load_optimizer(optimizer, os.path.join(load_dir, 'optimizer')) - booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, 'lr_scheduler')) - running_states = load_json(os.path.join(load_dir, 'running_states.json')) - return running_states['epoch'], running_states['step'], running_states['sample_start_index'] +def load( + booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, load_dir: str +) -> Tuple[int, int, int]: + booster.load_model(model, os.path.join(load_dir, "model")) + booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer")) + booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler")) + running_states = load_json(os.path.join(load_dir, "running_states.json")) + return running_states["epoch"], running_states["step"], running_states["sample_start_index"] def _criterion(outputs, inputs): @@ -97,27 +107,29 @@ def main(): # Parse Arguments # ============================== parser = argparse.ArgumentParser() - parser.add_argument('--model_path', type=str, help="pretrained checkpoint path, used with mode==finetune") - parser.add_argument('-p', - '--plugin', - choices=['gemini', 'gemini_auto', 'zero2', 'zero2_cpu', 'hybrid_parallel'], - default='gemini', - help='Choose which plugin to use') - parser.add_argument('-d', '--dataset', type=str, default='yizhongw/self_instruct', help='Data set path') - parser.add_argument('--task_name', type=str, default="super_natural_instructions", help='task to run') - parser.add_argument('-e', '--num_epochs', type=int, default=1, help='Number of epochs') - parser.add_argument('-b', '--batch_size', type=int, default=2, help='Local batch size') - parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate') - parser.add_argument('-w', '--weigth_decay', type=float, default=0.1, help='Weight decay') - parser.add_argument('-g', '--grad_checkpoint', action='store_true', help='Use gradient checkpointing') - parser.add_argument('-l', '--max_length', type=int, default=4096, help='Max sequence length') - parser.add_argument('-x', '--mixed_precision', default='fp16', choices=['fp16', 'bf16'], help='Mixed precision') - parser.add_argument('-i', '--save_interval', type=int, default=1000, help='Save interval') - parser.add_argument('-o', '--save_dir', type=str, default='checkpoint', help='Checkpoint directory') - parser.add_argument('-f', '--load', type=str, default=None, help='Load checkpoint') 
- parser.add_argument('--grad_clip', type=float, default=1.0, help='Gradient clipping') - parser.add_argument('-t', '--tensorboard_dir', type=str, default='tb_logs', help='Tensorboard directory') - parser.add_argument('-a', '--flash_attention', action='store_true', help='Use Flash Attention') + parser.add_argument("--model_path", type=str, help="pretrained checkpoint path, used with mode==finetune") + parser.add_argument( + "-p", + "--plugin", + choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "hybrid_parallel"], + default="gemini", + help="Choose which plugin to use", + ) + parser.add_argument("-d", "--dataset", type=str, default="yizhongw/self_instruct", help="Data set path") + parser.add_argument("--task_name", type=str, default="super_natural_instructions", help="task to run") + parser.add_argument("-e", "--num_epochs", type=int, default=1, help="Number of epochs") + parser.add_argument("-b", "--batch_size", type=int, default=2, help="Local batch size") + parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate") + parser.add_argument("-w", "--weigth_decay", type=float, default=0.1, help="Weight decay") + parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing") + parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length") + parser.add_argument("-x", "--mixed_precision", default="fp16", choices=["fp16", "bf16"], help="Mixed precision") + parser.add_argument("-i", "--save_interval", type=int, default=1000, help="Save interval") + parser.add_argument("-o", "--save_dir", type=str, default="checkpoint", help="Checkpoint directory") + parser.add_argument("-f", "--load", type=str, default=None, help="Load checkpoint") + parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping") + parser.add_argument("-t", "--tensorboard_dir", type=str, default="tb_logs", help="Tensorboard directory") + parser.add_argument("-a", "--flash_attention", action="store_true", help="Use Flash Attention") args = parser.parse_args() # ============================== @@ -129,36 +141,34 @@ def main(): # ============================== # Initialize Booster # ============================== - if args.plugin == 'gemini': + if args.plugin == "gemini": plugin = GeminiPlugin(precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip) - elif args.plugin == 'gemini_auto': - plugin = GeminiPlugin(precision=args.mixed_precision, - placement_policy='auto', - initial_scale=2**16, - max_norm=args.grad_clip) - elif args.plugin == 'zero2': - plugin = LowLevelZeroPlugin(stage=2, - precision=args.mixed_precision, - initial_scale=2**16, - max_norm=args.grad_clip) - elif args.plugin == 'zero2_cpu': - plugin = LowLevelZeroPlugin(stage=2, - precision=args.mixed_precision, - initial_scale=2**16, - cpu_offload=True, - max_norm=args.grad_clip) - elif args.plugin == 'hybrid_parallel': + elif args.plugin == "gemini_auto": + plugin = GeminiPlugin( + precision=args.mixed_precision, placement_policy="auto", initial_scale=2**16, max_norm=args.grad_clip + ) + elif args.plugin == "zero2": + plugin = LowLevelZeroPlugin( + stage=2, precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip + ) + elif args.plugin == "zero2_cpu": + plugin = LowLevelZeroPlugin( + stage=2, precision=args.mixed_precision, initial_scale=2**16, cpu_offload=True, max_norm=args.grad_clip + ) + elif args.plugin == "hybrid_parallel": # modify the param accordingly, default configuration is for llama2-7b - plugin = 
HybridParallelPlugin(tp_size=4, - pp_size=2, - num_microbatches=None, - microbatch_size=1, - enable_jit_fused=False, - zero_stage=0, - precision='fp32', - initial_scale=1) + plugin = HybridParallelPlugin( + tp_size=4, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_jit_fused=False, + zero_stage=0, + precision="fp32", + initial_scale=1, + ) else: - raise ValueError(f'Unknown plugin {args.plugin}') + raise ValueError(f"Unknown plugin {args.plugin}") booster = Booster(plugin=plugin) @@ -179,8 +189,9 @@ def main(): config = LlamaConfig.from_pretrained(args.model_path) # use lazy init when using GeminiPlugin - init_ctx = LazyInitContext( - default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext() + init_ctx = ( + LazyInitContext(default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext() + ) with init_ctx: model = LlamaForCausalLM(config) @@ -188,57 +199,56 @@ def main(): # ============================== # Initialize Tokenizer, Dataset and Dataloader # ============================== - tokenizer = LlamaTokenizer.from_pretrained('hf-internal-testing/llama-tokenizer') + tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") # follows fast chat: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L257 tokenizer.pad_token = tokenizer.unk_token dataset = load_dataset(args.dataset, args.task_name) - train_ds = dataset['train'] - dataloader = prepare_dataloader(train_ds, - batch_size=args.batch_size, - shuffle=True, - drop_last=True, - collate_fn=partial(tokenize_batch_for_finetune, - tokenizer=tokenizer, - max_length=args.max_length)) + train_ds = dataset["train"] + dataloader = prepare_dataloader( + train_ds, + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + collate_fn=partial(tokenize_batch_for_finetune, tokenizer=tokenizer, max_length=args.max_length), + ) if args.grad_checkpoint: model.gradient_checkpointing_enable() if args.flash_attention: - assert SUPPORT_XFORMERS, 'Use flash attention while xfomers is not installed' + assert SUPPORT_XFORMERS, "Use flash attention while xfomers is not installed" replace_xformers(model) model_numel = get_model_numel(model) - coordinator.print_on_master(f'Model params: {format_numel_str(model_numel)}') + coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") optimizer = HybridAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weigth_decay) total_step = args.num_epochs * len(dataloader) - lr_scheduler = CosineAnnealingWarmupLR(optimizer, - total_steps=total_step, - warmup_steps=math.ceil(total_step * 0.03), - eta_min=0.1 * args.lr) - default_dtype = torch.float16 if args.mixed_precision == 'fp16' else torch.bfloat16 + lr_scheduler = CosineAnnealingWarmupLR( + optimizer, total_steps=total_step, warmup_steps=math.ceil(total_step * 0.03), eta_min=0.1 * args.lr + ) + default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 torch.set_default_dtype(default_dtype) - model, optimizer, _, dataloader, lr_scheduler = booster.boost(model, - optimizer, - dataloader=dataloader, - lr_scheduler=lr_scheduler) + model, optimizer, _, dataloader, lr_scheduler = booster.boost( + model, optimizer, dataloader=dataloader, lr_scheduler=lr_scheduler + ) torch.set_default_dtype(torch.float) booster.load_model(model, args.model_path) - coordinator.print_on_master(f'Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB') + 
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") coordinator.print_on_master( - f'Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB') + f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" + ) # load checkpoint if specified start_epoch = 0 start_step = 0 sampler_start_idx = 0 if args.load is not None: - coordinator.print_on_master('Loading checkpoint') + coordinator.print_on_master("Loading checkpoint") start_epoch, start_step, sampler_start_idx = load(booster, model, optimizer, lr_scheduler, args.load) - coordinator.print_on_master(f'Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}') + coordinator.print_on_master(f"Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}") num_steps_per_epoch = len(dataloader) @@ -249,19 +259,18 @@ def main(): step_nums = num_steps_per_epoch - start_step dataloader_iter = iter(dataloader) - with tqdm(range(step_nums), - desc=f'Epoch {epoch}', - disable=not print_flag, - total=num_steps_per_epoch, - initial=start_step) as pbar: + with tqdm( + range(step_nums), + desc=f"Epoch {epoch}", + disable=not print_flag, + total=num_steps_per_epoch, + initial=start_step, + ) as pbar: for step in pbar: if use_pipeline: - outputs = booster.execute_pipeline(dataloader_iter, - model, - _criterion, - optimizer, - return_loss=True, - return_outputs=True) + outputs = booster.execute_pipeline( + dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True + ) loss = outputs["loss"] else: batch = next(dataloader_iter) @@ -276,20 +285,29 @@ def main(): if not use_pipeline: all_reduce_mean(loss) if print_flag: - pbar.set_postfix({'loss': loss.item()}) - writer.add_scalar('loss', loss.item(), epoch * num_steps_per_epoch + step) + pbar.set_postfix({"loss": loss.item()}) + writer.add_scalar("loss", loss.item(), epoch * num_steps_per_epoch + step) if args.save_interval > 0 and (step + 1) % args.save_interval == 0: - coordinator.print_on_master(f'Saving checkpoint') - save(booster, model, optimizer, lr_scheduler, epoch, step + 1, args.batch_size, coordinator, - args.save_dir) - coordinator.print_on_master(f'Saved checkpoint at epoch {epoch} step {step + 1}') + coordinator.print_on_master(f"Saving checkpoint") + save( + booster, + model, + optimizer, + lr_scheduler, + epoch, + step + 1, + args.batch_size, + coordinator, + args.save_dir, + ) + coordinator.print_on_master(f"Saved checkpoint at epoch {epoch} step {step + 1}") # the continue epochs are not resumed, so we need to reset the sampler start index and start step dataloader.sampler.set_start_index(0) start_step = 0 - coordinator.print_on_master(f'Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB') + coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/language/llama2/model_utils.py b/examples/language/llama2/model_utils.py index 431ff5cfb446..63569bc61143 100644 --- a/examples/language/llama2/model_utils.py +++ b/examples/language/llama2/model_utils.py @@ -23,10 +23,10 @@ def format_numel_str(numel: int) -> str: M = 1024**2 K = 1024 if numel >= B: - return f'{numel / B:.2f} B' + return f"{numel / B:.2f} B" elif numel >= M: - return f'{numel / M:.2f} M' + return f"{numel / M:.2f} M" elif numel >= K: - return f'{numel / K:.2f} K' + return f"{numel / K:.2f} K" 
else: - return f'{numel}' + return f"{numel}" diff --git a/examples/language/llama2/performance_evaluator.py b/examples/language/llama2/performance_evaluator.py index 711b99c54360..a57c1e0e9ae3 100644 --- a/examples/language/llama2/performance_evaluator.py +++ b/examples/language/llama2/performance_evaluator.py @@ -10,9 +10,9 @@ def divide(x: float, y: float) -> float: if y == 0: - return float('inf') - elif y == float('inf'): - return float('nan') + return float("inf") + elif y == float("inf"): + return float("nan") return x / y @@ -27,10 +27,9 @@ def all_reduce_mean(x: float, world_size: int) -> float: class Timer: - def __init__(self) -> None: self.start_time: Optional[float] = None - self.duration: float = 0. + self.duration: float = 0.0 def start(self) -> None: self.start_time = time() @@ -41,7 +40,7 @@ def end(self) -> None: self.start_time = None def reset(self) -> None: - self.duration = 0. + self.duration = 0.0 class PerformanceEvaluator: @@ -56,11 +55,13 @@ class PerformanceEvaluator: ignore_episodes: The number of episodes to ignore when calculating the performance. """ - def __init__(self, - model_numel: int, - enable_grad_checkpoint: bool = False, - ignore_steps: int = 0, - dp_world_size: Optional[int] = None) -> None: + def __init__( + self, + model_numel: int, + enable_grad_checkpoint: bool = False, + ignore_steps: int = 0, + dp_world_size: Optional[int] = None, + ) -> None: self.model_numel = model_numel self.enable_grad_checkpoint = enable_grad_checkpoint self.ignore_steps = ignore_steps @@ -96,7 +97,9 @@ def on_fit_end(self) -> None: mp_world_size = self.coordinator.world_size // self.dp_world_size avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size self.coordinator.print_on_master( - f'num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, ' - f'avg_throughput: {avg_throughput}') + f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, " + f"avg_throughput: {avg_throughput}" + ) self.coordinator.print_on_master( - f'Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}') + f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}" + ) diff --git a/examples/language/llama2/pretrain.py b/examples/language/llama2/pretrain.py index 0eeac4035401..6cc73b6265a4 100644 --- a/examples/language/llama2/pretrain.py +++ b/examples/language/llama2/pretrain.py @@ -29,21 +29,22 @@ from colossalai.utils import get_current_device MODEL_CONFIGS = { - '7b': - LlamaConfig(max_position_embeddings=4096), - '13b': - LlamaConfig(hidden_size=5120, - intermediate_size=13824, - num_hidden_layers=40, - num_attention_heads=40, - max_position_embeddings=4096), - '70b': - LlamaConfig(hidden_size=8192, - intermediate_size=28672, - num_hidden_layers=80, - num_attention_heads=64, - max_position_embeddings=4096, - num_key_value_heads=8), + "7b": LlamaConfig(max_position_embeddings=4096), + "13b": LlamaConfig( + hidden_size=5120, + intermediate_size=13824, + num_hidden_layers=40, + num_attention_heads=40, + max_position_embeddings=4096, + ), + "70b": LlamaConfig( + hidden_size=8192, + intermediate_size=28672, + num_hidden_layers=80, + num_attention_heads=64, + max_position_embeddings=4096, + num_key_value_heads=8, + ), } @@ -56,20 +57,20 @@ def format_numel_str(numel: int) -> str: M = 1024**2 K = 1024 if numel >= B: - return f'{numel / B:.2f} B' + return f"{numel / B:.2f} B" elif numel >= 
M: - return f'{numel / M:.2f} M' + return f"{numel / M:.2f} M" elif numel >= K: - return f'{numel / K:.2f} K' + return f"{numel / K:.2f} K" else: - return f'{numel}' + return f"{numel}" def tokenize_batch_for_pretrain(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048): - texts = [sample['text'] for sample in batch] - data = tokenizer(texts, return_tensors="pt", padding='max_length', truncation=True, max_length=max_length) + texts = [sample["text"] for sample in batch] + data = tokenizer(texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length) data = {k: v.cuda() for k, v in data.items()} - data['labels'] = data['input_ids'].clone() + data["labels"] = data["input_ids"].clone() return data @@ -79,30 +80,40 @@ def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: return tensor -def save(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, epoch: int, step: int, - batch_size: int, coordinator: DistCoordinator, save_dir: str): - save_dir = os.path.join(save_dir, f'epoch{epoch}-step{step}') - os.makedirs(os.path.join(save_dir, 'model'), exist_ok=True) - - booster.save_model(model, os.path.join(save_dir, 'model'), shard=True) - booster.save_optimizer(optimizer, os.path.join(save_dir, 'optimizer'), shard=True) - booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, 'lr_scheduler')) +def save( + booster: Booster, + model: nn.Module, + optimizer: Optimizer, + lr_scheduler: _LRScheduler, + epoch: int, + step: int, + batch_size: int, + coordinator: DistCoordinator, + save_dir: str, +): + save_dir = os.path.join(save_dir, f"epoch{epoch}-step{step}") + os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) + + booster.save_model(model, os.path.join(save_dir, "model"), shard=True) + booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True) + booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler")) running_states = { - 'epoch': epoch, - 'step': step, - 'sample_start_index': step * batch_size, + "epoch": epoch, + "step": step, + "sample_start_index": step * batch_size, } if coordinator.is_master(): - save_json(running_states, os.path.join(save_dir, 'running_states.json')) + save_json(running_states, os.path.join(save_dir, "running_states.json")) -def load(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, - load_dir: str) -> Tuple[int, int, int]: - booster.load_model(model, os.path.join(load_dir, 'model')) - booster.load_optimizer(optimizer, os.path.join(load_dir, 'optimizer')) - booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, 'lr_scheduler')) - running_states = load_json(os.path.join(load_dir, 'running_states.json')) - return running_states['epoch'], running_states['step'], running_states['sample_start_index'] +def load( + booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, load_dir: str +) -> Tuple[int, int, int]: + booster.load_model(model, os.path.join(load_dir, "model")) + booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer")) + booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler")) + running_states = load_json(os.path.join(load_dir, "running_states.json")) + return running_states["epoch"], running_states["step"], running_states["sample_start_index"] def _criterion(outputs, inputs): @@ -114,31 +125,31 @@ def main(): # Parse Arguments # ============================== parser = argparse.ArgumentParser() - 
parser.add_argument('-c', '--config', type=str, default='7b', help='Model configuration') - parser.add_argument('-p', - '--plugin', - choices=['gemini', 'gemini_auto', 'zero2', 'zero2_cpu', 'hybrid_parallel'], - default='gemini', - help='Choose which plugin to use') - parser.add_argument('-d', - '--dataset', - type=str, - default='togethercomputer/RedPajama-Data-1T-Sample', - help='Data set path') - parser.add_argument('-e', '--num_epochs', type=int, default=1, help='Number of epochs') - parser.add_argument('-b', '--batch_size', type=int, default=2, help='Local batch size') - parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate') - parser.add_argument('-w', '--weigth_decay', type=float, default=0.1, help='Weight decay') - parser.add_argument('-s', '--warmup_steps', type=int, default=2000, help='Warmup steps') - parser.add_argument('-g', '--grad_checkpoint', action='store_true', help='Use gradient checkpointing') - parser.add_argument('-l', '--max_length', type=int, default=4096, help='Max sequence length') - parser.add_argument('-x', '--mixed_precision', default='fp16', choices=['fp16', 'bf16'], help='Mixed precision') - parser.add_argument('-i', '--save_interval', type=int, default=1000, help='Save interval') - parser.add_argument('-o', '--save_dir', type=str, default='checkpoint', help='Checkpoint directory') - parser.add_argument('-f', '--load', type=str, default=None, help='Load checkpoint') - parser.add_argument('--grad_clip', type=float, default=1.0, help='Gradient clipping') - parser.add_argument('-t', '--tensorboard_dir', type=str, default='tb_logs', help='Tensorboard directory') - parser.add_argument('-a', '--flash_attention', action='store_true', help='Use Flash Attention') + parser.add_argument("-c", "--config", type=str, default="7b", help="Model configuration") + parser.add_argument( + "-p", + "--plugin", + choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "hybrid_parallel"], + default="gemini", + help="Choose which plugin to use", + ) + parser.add_argument( + "-d", "--dataset", type=str, default="togethercomputer/RedPajama-Data-1T-Sample", help="Data set path" + ) + parser.add_argument("-e", "--num_epochs", type=int, default=1, help="Number of epochs") + parser.add_argument("-b", "--batch_size", type=int, default=2, help="Local batch size") + parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate") + parser.add_argument("-w", "--weigth_decay", type=float, default=0.1, help="Weight decay") + parser.add_argument("-s", "--warmup_steps", type=int, default=2000, help="Warmup steps") + parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing") + parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length") + parser.add_argument("-x", "--mixed_precision", default="fp16", choices=["fp16", "bf16"], help="Mixed precision") + parser.add_argument("-i", "--save_interval", type=int, default=1000, help="Save interval") + parser.add_argument("-o", "--save_dir", type=str, default="checkpoint", help="Checkpoint directory") + parser.add_argument("-f", "--load", type=str, default=None, help="Load checkpoint") + parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping") + parser.add_argument("-t", "--tensorboard_dir", type=str, default="tb_logs", help="Tensorboard directory") + parser.add_argument("-a", "--flash_attention", action="store_true", help="Use Flash Attention") args = parser.parse_args() # ============================== @@ -150,36 +161,34 
@@ def main(): # ============================== # Initialize Booster # ============================== - if args.plugin == 'gemini': + if args.plugin == "gemini": plugin = GeminiPlugin(precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip) - elif args.plugin == 'gemini_auto': - plugin = GeminiPlugin(precision=args.mixed_precision, - placement_policy='auto', - initial_scale=2**16, - max_norm=args.grad_clip) - elif args.plugin == 'zero2': - plugin = LowLevelZeroPlugin(stage=2, - precision=args.mixed_precision, - initial_scale=2**16, - max_norm=args.grad_clip) - elif args.plugin == 'zero2_cpu': - plugin = LowLevelZeroPlugin(stage=2, - precision=args.mixed_precision, - initial_scale=2**16, - cpu_offload=True, - max_norm=args.grad_clip) - elif args.plugin == 'hybrid_parallel': + elif args.plugin == "gemini_auto": + plugin = GeminiPlugin( + precision=args.mixed_precision, placement_policy="auto", initial_scale=2**16, max_norm=args.grad_clip + ) + elif args.plugin == "zero2": + plugin = LowLevelZeroPlugin( + stage=2, precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip + ) + elif args.plugin == "zero2_cpu": + plugin = LowLevelZeroPlugin( + stage=2, precision=args.mixed_precision, initial_scale=2**16, cpu_offload=True, max_norm=args.grad_clip + ) + elif args.plugin == "hybrid_parallel": # modify the param accordingly, default configuration is for llama2-7b - plugin = HybridParallelPlugin(tp_size=4, - pp_size=2, - num_microbatches=None, - microbatch_size=1, - enable_jit_fused=False, - zero_stage=0, - precision='fp32', - initial_scale=1) + plugin = HybridParallelPlugin( + tp_size=4, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_jit_fused=False, + zero_stage=0, + precision="fp32", + initial_scale=1, + ) else: - raise ValueError(f'Unknown plugin {args.plugin}') + raise ValueError(f"Unknown plugin {args.plugin}") booster = Booster(plugin=plugin) @@ -197,27 +206,28 @@ def main(): # ============================== # Initialize Tokenizer, Dataset and Dataloader # ============================== - tokenizer = LlamaTokenizer.from_pretrained('hf-internal-testing/llama-tokenizer') + tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") # follows fast chat: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L257 tokenizer.pad_token = tokenizer.unk_token dataset = load_dataset(args.dataset) - train_ds = dataset['train'] - dataloader = prepare_dataloader(train_ds, - batch_size=args.batch_size, - shuffle=True, - drop_last=True, - collate_fn=partial(tokenize_batch_for_pretrain, - tokenizer=tokenizer, - max_length=args.max_length)) + train_ds = dataset["train"] + dataloader = prepare_dataloader( + train_ds, + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + collate_fn=partial(tokenize_batch_for_pretrain, tokenizer=tokenizer, max_length=args.max_length), + ) # ============================== # Initialize Model, Optimizer and LR Scheduler # ============================== config = MODEL_CONFIGS[args.config] # use lazy init when using GeminiPlugin - init_ctx = LazyInitContext( - default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext() + init_ctx = ( + LazyInitContext(default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext() + ) with init_ctx: model = LlamaForCausalLM(config) @@ -225,37 +235,36 @@ def main(): if args.grad_checkpoint: model.gradient_checkpointing_enable() if args.flash_attention: - assert SUPPORT_XFORMERS, 'Use 
flash attention while xfomers is not installed' + assert SUPPORT_XFORMERS, "Use flash attention while xfomers is not installed" replace_xformers(model) model_numel = get_model_numel(model) - coordinator.print_on_master(f'Model params: {format_numel_str(model_numel)}') + coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") optimizer = HybridAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weigth_decay) - lr_scheduler = CosineAnnealingWarmupLR(optimizer, - total_steps=args.num_epochs * len(dataloader), - warmup_steps=args.warmup_steps, - eta_min=0.1 * args.lr) - default_dtype = torch.float16 if args.mixed_precision == 'fp16' else torch.bfloat16 + lr_scheduler = CosineAnnealingWarmupLR( + optimizer, total_steps=args.num_epochs * len(dataloader), warmup_steps=args.warmup_steps, eta_min=0.1 * args.lr + ) + default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 torch.set_default_dtype(default_dtype) - model, optimizer, _, dataloader, lr_scheduler = booster.boost(model, - optimizer, - dataloader=dataloader, - lr_scheduler=lr_scheduler) + model, optimizer, _, dataloader, lr_scheduler = booster.boost( + model, optimizer, dataloader=dataloader, lr_scheduler=lr_scheduler + ) torch.set_default_dtype(torch.float) - coordinator.print_on_master(f'Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB') + coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") coordinator.print_on_master( - f'Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB') + f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" + ) # load checkpoint if specified start_epoch = 0 start_step = 0 sampler_start_idx = 0 if args.load is not None: - coordinator.print_on_master('Loading checkpoint') + coordinator.print_on_master("Loading checkpoint") start_epoch, start_step, sampler_start_idx = load(booster, model, optimizer, lr_scheduler, args.load) - coordinator.print_on_master(f'Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}') + coordinator.print_on_master(f"Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}") num_steps_per_epoch = len(dataloader) @@ -266,19 +275,18 @@ def main(): step_nums = num_steps_per_epoch - start_step dataloader_iter = iter(dataloader) - with tqdm(range(step_nums), - desc=f'Epoch {epoch}', - disable=not print_flag, - total=num_steps_per_epoch, - initial=start_step) as pbar: + with tqdm( + range(step_nums), + desc=f"Epoch {epoch}", + disable=not print_flag, + total=num_steps_per_epoch, + initial=start_step, + ) as pbar: for step in pbar: if use_pipeline: - outputs = booster.execute_pipeline(dataloader_iter, - model, - _criterion, - optimizer, - return_loss=True, - return_outputs=True) + outputs = booster.execute_pipeline( + dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True + ) loss = outputs["loss"] else: batch = next(dataloader_iter) @@ -293,20 +301,29 @@ def main(): if not use_pipeline: all_reduce_mean(loss) if print_flag: - pbar.set_postfix({'loss': loss.item()}) - writer.add_scalar('loss', loss.item(), epoch * num_steps_per_epoch + step) + pbar.set_postfix({"loss": loss.item()}) + writer.add_scalar("loss", loss.item(), epoch * num_steps_per_epoch + step) if args.save_interval > 0 and (step + 1) % args.save_interval == 0: - coordinator.print_on_master(f'Saving checkpoint') - save(booster, 
model, optimizer, lr_scheduler, epoch, step + 1, args.batch_size, coordinator, - args.save_dir) - coordinator.print_on_master(f'Saved checkpoint at epoch {epoch} step {step + 1}') + coordinator.print_on_master(f"Saving checkpoint") + save( + booster, + model, + optimizer, + lr_scheduler, + epoch, + step + 1, + args.batch_size, + coordinator, + args.save_dir, + ) + coordinator.print_on_master(f"Saved checkpoint at epoch {epoch} step {step + 1}") # the continue epochs are not resumed, so we need to reset the sampler start index and start step dataloader.sampler.set_start_index(0) start_step = 0 - coordinator.print_on_master(f'Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB') + coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/language/opt/args.py b/examples/language/opt/args.py index 77fa12bc8a0c..1ec19094e19e 100644 --- a/examples/language/opt/args.py +++ b/examples/language/opt/args.py @@ -2,36 +2,35 @@ def parse_demo_args(): - parser = get_default_parser() - parser.add_argument("--model_name_or_path", - type=str, - default="facebook/opt-350m", - help="Path to pretrained model or model identifier from huggingface.co/models.") - parser.add_argument("--output_path", - type=str, - default="./output_model.bin", - help="The path of your saved model after finetuning.") + parser.add_argument( + "--model_name_or_path", + type=str, + default="facebook/opt-350m", + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--output_path", type=str, default="./output_model.bin", help="The path of your saved model after finetuning." + ) parser.add_argument( "--plugin", type=str, default="gemini", - help= - "Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero', 'hybrid_parallel'." + help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero', 'hybrid_parallel'.", ) parser.add_argument("--num_epoch", type=int, default=10, help="Number of epochs.") - parser.add_argument("--batch_size", - type=int, - default=32, - help="Batch size (per dp group) for the training dataloader.") - parser.add_argument("--learning_rate", - type=float, - default=5e-5, - help="Initial learning rate (after the potential warmup period) to use.") - parser.add_argument("--warmup_ratio", - type=float, - default=0.1, - help="Ratio of warmup steps against total training steps.") + parser.add_argument( + "--batch_size", type=int, default=32, help="Batch size (per dp group) for the training dataloader." + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--warmup_ratio", type=float, default=0.1, help="Ratio of warmup steps against total training steps." 
+ ) parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to use.") parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") @@ -40,25 +39,28 @@ def parse_demo_args(): def parse_benchmark_args(): - parser = get_default_parser() - parser.add_argument("--model_name_or_path", - type=str, - default="facebook/opt-125m", - help="Path to pretrained model or model identifier from huggingface.co/models.") + parser.add_argument( + "--model_name_or_path", + type=str, + default="facebook/opt-125m", + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) parser.add_argument( "--plugin", type=str, default="gemini", - help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'.") - parser.add_argument("--batch_size", - type=int, - default=32, - help="Batch size (per dp group) for the training dataloader.") - parser.add_argument("--learning_rate", - type=float, - default=5e-5, - help="Initial learning rate (after the potential warmup period) to use.") + help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'.", + ) + parser.add_argument( + "--batch_size", type=int, default=32, help="Batch size (per dp group) for the training dataloader." + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") parser.add_argument("--max_train_steps", type=int, default=20, help="Total number of training steps to perform.") parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") diff --git a/examples/language/opt/data.py b/examples/language/opt/data.py index 6cfffb5fc95b..9b9cc59518ab 100644 --- a/examples/language/opt/data.py +++ b/examples/language/opt/data.py @@ -1,37 +1,38 @@ import torch -from torch.utils.data import Dataset from datasets import load_dataset +from torch.utils.data import Dataset class NetflixDataset(Dataset): - def __init__(self, tokenizer): - super().__init__() self.tokenizer = tokenizer self.input_ids = [] self.attn_masks = [] self.labels = [] - self.txt_list = netflix_descriptions = load_dataset("hugginglearners/netflix-shows", split="train")['description'] + self.txt_list = netflix_descriptions = load_dataset("hugginglearners/netflix-shows", split="train")[ + "description" + ] self.max_length = max([len(self.tokenizer.encode(description)) for description in netflix_descriptions]) for txt in self.txt_list: - encodings_dict = self.tokenizer('' + txt + '', - truncation=True, - max_length=self.max_length, - padding="max_length") - self.input_ids.append(torch.tensor(encodings_dict['input_ids'])) - self.attn_masks.append(torch.tensor(encodings_dict['attention_mask'])) + encodings_dict = self.tokenizer( + "" + txt + "", truncation=True, max_length=self.max_length, padding="max_length" + ) + self.input_ids.append(torch.tensor(encodings_dict["input_ids"])) + self.attn_masks.append(torch.tensor(encodings_dict["attention_mask"])) def __len__(self): return len(self.input_ids) def __getitem__(self, idx): return self.input_ids[idx], self.attn_masks[idx] - + def netflix_collator(data): - return {'input_ids': torch.stack([x[0] for x in data]), - 'attention_mask': torch.stack([x[1] for x in data]), - 'labels': torch.stack([x[0] for x in data])} + return { + "input_ids": torch.stack([x[0] for x in data]), + 
"attention_mask": torch.stack([x[1] for x in data]), + "labels": torch.stack([x[0] for x in data]), + } diff --git a/examples/language/opt/opt_benchmark.py b/examples/language/opt/opt_benchmark.py index 90ed10ec7cca..d16c9fdf99ad 100755 --- a/examples/language/opt/opt_benchmark.py +++ b/examples/language/opt/opt_benchmark.py @@ -35,6 +35,7 @@ def get_data(batch_size, seq_len, vocab_size): def colo_memory_cap(size_in_GB): from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device + cuda_capacity = colo_device_memory_capacity(get_current_device()) if size_in_GB * (1024**3) < cuda_capacity: colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity) @@ -42,7 +43,6 @@ def colo_memory_cap(size_in_GB): def main(): - args = parse_benchmark_args() # Launch ColossalAI @@ -72,13 +72,13 @@ def main(): # Set plugin booster_kwargs = {} - if args.plugin == 'torch_ddp_fp16': - booster_kwargs['mixed_precision'] = 'fp16' - if args.plugin.startswith('torch_ddp'): + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() - elif args.plugin == 'gemini': + elif args.plugin == "gemini": plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5) - elif args.plugin == 'low_level_zero': + elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) logger.info(f"Set plugin as {args.plugin}", ranks=[0]) @@ -101,11 +101,10 @@ def main(): start_time = time.time() for _ in range(args.max_train_steps): - input_ids, attn_mask = get_data(args.batch_size, SEQ_LEN, VOCAB_SIZE) optimizer.zero_grad() outputs = model(input_ids=input_ids, attention_mask=attn_mask, labels=input_ids, use_cache=False) - loss = outputs['loss'] + loss = outputs["loss"] booster.backward(loss, optimizer) optimizer.step() @@ -123,7 +122,8 @@ def main(): f"plugin: {args.plugin}, " f"throughput: {throughput}, " f"maximum memory usage per gpu: {max_mem}.", - ranks=[0]) + ranks=[0], + ) if __name__ == "__main__": diff --git a/examples/language/opt/opt_train_demo.py b/examples/language/opt/opt_train_demo.py index 7d6bdfb9f31c..fddbc1b408e7 100644 --- a/examples/language/opt/opt_train_demo.py +++ b/examples/language/opt/opt_train_demo.py @@ -1,5 +1,3 @@ -import time - import datasets import torch import transformers @@ -12,7 +10,6 @@ import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin -from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule from colossalai.cluster import DistCoordinator from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam @@ -29,7 +26,6 @@ def move_to_cuda(batch, device): def train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, dataloader, booster, coordinator): - torch.cuda.synchronize() use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 @@ -39,22 +35,19 @@ def train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, dataloader, b model.train() optimizer.zero_grad() dataloader = iter(dataloader) - with tqdm(range(total_step), desc=f'Epoch [{epoch + 1}]', - disable=not (coordinator.is_master() or is_pp_last_stage)) as pbar: - + with tqdm( + range(total_step), desc=f"Epoch [{epoch + 1}]", disable=not (coordinator.is_master() or is_pp_last_stage) + ) as pbar: # Forward pass for 
_ in pbar: if use_pipeline: - outputs = booster.execute_pipeline(dataloader, - model, - _criterion, - optimizer, - return_loss=True, - return_outputs=True) + outputs = booster.execute_pipeline( + dataloader, model, _criterion, optimizer, return_loss=True, return_outputs=True + ) # Backward and optimize if is_pp_last_stage: - loss = outputs['loss'] - pbar.set_postfix({'loss': loss.item()}) + loss = outputs["loss"] + pbar.set_postfix({"loss": loss.item()}) else: data = next(dataloader) data = move_to_cuda(data) @@ -62,7 +55,7 @@ def train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, dataloader, b loss = _criterion(outputs, None) # Backward booster.backward(loss, optimizer) - pbar.set_postfix({'loss': loss.item()}) + pbar.set_postfix({"loss": loss.item()}) optimizer.step() optimizer.zero_grad() @@ -70,7 +63,6 @@ def train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, dataloader, b def main(): - args = parse_demo_args() # Launch ColossalAI @@ -98,34 +90,34 @@ def main(): # Set plugin booster_kwargs = {} - if args.plugin == 'torch_ddp_fp16': - booster_kwargs['mixed_precision'] = 'fp16' - if args.plugin.startswith('torch_ddp'): + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() - elif args.plugin == 'gemini': + elif args.plugin == "gemini": plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5) - elif args.plugin == 'low_level_zero': + elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) - elif args.plugin == 'hybrid_parallel': + elif args.plugin == "hybrid_parallel": # modify the param accordingly for finetuning test cases - plugin = HybridParallelPlugin(tp_size=2, - pp_size=2, - num_microbatches=2, - enable_all_optimization=True, - zero_stage=0, - precision='fp16', - initial_scale=1) + plugin = HybridParallelPlugin( + tp_size=2, + pp_size=2, + num_microbatches=2, + enable_all_optimization=True, + zero_stage=0, + precision="fp16", + initial_scale=1, + ) logger.info(f"Set plugin as {args.plugin}", ranks=[0]) # Prepare tokenizer and dataloader tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) dataset = NetflixDataset(tokenizer) - dataloader = plugin.prepare_dataloader(dataset, - batch_size=args.batch_size, - shuffle=True, - drop_last=True, - collate_fn=netflix_collator) + dataloader = plugin.prepare_dataloader( + dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=netflix_collator + ) # Set optimizer optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay) @@ -133,9 +125,9 @@ def main(): # Set lr scheduler total_steps = len(dataloader) * args.num_epoch num_warmup_steps = int(args.warmup_ratio * total_steps) - lr_scheduler = get_linear_schedule_with_warmup(optimizer, - num_warmup_steps=num_warmup_steps, - num_training_steps=len(dataloader) * args.num_epoch) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=len(dataloader) * args.num_epoch + ) # Define criterion def _criterion(outputs, inputs): @@ -145,11 +137,9 @@ def _criterion(outputs, inputs): # Set booster booster = Booster(plugin=plugin, **booster_kwargs) - model, optimizer, _criterion, dataloader, lr_scheduler = booster.boost(model=model, - optimizer=optimizer, - dataloader=dataloader, - criterion=_criterion, - lr_scheduler=lr_scheduler) + model, optimizer, _criterion, dataloader, lr_scheduler = 
booster.boost( + model=model, optimizer=optimizer, dataloader=dataloader, criterion=_criterion, lr_scheduler=lr_scheduler + ) # Start finetuning logger.info(f"Start finetuning", ranks=[0]) diff --git a/examples/language/opt/run_benchmark.sh b/examples/language/opt/run_benchmark.sh index 76c5e8601989..b94ee61f277c 100644 --- a/examples/language/opt/run_benchmark.sh +++ b/examples/language/opt/run_benchmark.sh @@ -24,7 +24,7 @@ torchrun \ --mem_cap ${MEMCAP} \ --plugin ${PLUGIN} \ --batch_size ${BS} - + done done done diff --git a/examples/language/palm/palm_pytorch/autoregressive_wrapper.py b/examples/language/palm/palm_pytorch/autoregressive_wrapper.py index dc4f3d856fec..17251c2f4fb3 100644 --- a/examples/language/palm/palm_pytorch/autoregressive_wrapper.py +++ b/examples/language/palm/palm_pytorch/autoregressive_wrapper.py @@ -11,7 +11,6 @@ def exists(val): def eval_decorator(fn): - def inner(model, *args, **kwargs): was_training = model.training model.eval() @@ -34,7 +33,6 @@ def top_k(logits, thres=0.9): class AutoregressiveWrapper(nn.Module): - def __init__(self, net, max_seq_len=2048, pad_value=0): super().__init__() self.max_seq_len = max_seq_len diff --git a/examples/language/palm/palm_pytorch/palm_pytorch.py b/examples/language/palm/palm_pytorch/palm_pytorch.py index c37974711e11..6be966d67790 100644 --- a/examples/language/palm/palm_pytorch/palm_pytorch.py +++ b/examples/language/palm/palm_pytorch/palm_pytorch.py @@ -1,14 +1,13 @@ import torch import torch.nn.functional as F from einops import rearrange -from torch import einsum, matmul, nn +from torch import matmul, nn # normalization # they use layernorm without bias, something that pytorch does not offer class LayerNorm(nn.Module): - def __init__(self, dim, eps=1e-5): super().__init__() self.eps = eps @@ -24,7 +23,6 @@ def forward(self, x): class ParallelResidual(nn.Module): - def __init__(self, *fns): super().__init__() self.fns = nn.ModuleList(fns) @@ -38,16 +36,15 @@ def forward(self, x): class RotaryEmbedding(nn.Module): - def __init__(self, dim): super().__init__() - inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2).float() / dim)) + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("inv_freq", inv_freq) def forward(self, max_seq_len, *, device): seq = torch.arange(max_seq_len, device=device) - #freqs = einsum("i , j -> i j", seq.type_as(self.inv_freq), self.inv_freq) - #freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq) + # freqs = einsum("i , j -> i j", seq.type_as(self.inv_freq), self.inv_freq) + # freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq) i, j = len(seq.type_as(self.inv_freq)), len(self.inv_freq) freqs = matmul(seq.type_as(self.inv_freq).reshape(i, 1), self.inv_freq.reshape(1, j)) return torch.cat((freqs, freqs), dim=-1) @@ -69,7 +66,6 @@ def apply_rotary_pos_emb(pos, t): class SwiGLU(nn.Module): - def forward(self, x): x, gate = x.chunk(2, dim=-1) return F.silu(gate) * x @@ -87,7 +83,6 @@ def FeedForward(dim, mult=4): # attention class Attention(nn.Module): - def __init__(self, dim, dim_head=64, heads=8): super().__init__() inner_dim = dim_head * heads @@ -160,7 +155,7 @@ def forward(self, x): # similarity - #sim = einsum("b h i d, b j d -> b h i j", q, k) + # sim = einsum("b h i d, b j d -> b h i j", q, k) sim = matmul(q.reshape(b, h * i, d), k.transpose(1, 2)) sim = sim.reshape(b, h, i, j) @@ -178,7 +173,7 @@ def forward(self, x): # aggregate values - #out = einsum("b h i j, b j d -> b h i d", attn, v) + # out = einsum("b h i j, b j d -> b h i 
d", attn, v) out = matmul(attn.reshape(b_, h_ * i_, j_), v) out = out.reshape(b_, h_, i_, d_) @@ -193,12 +188,17 @@ def forward(self, x): def PaLM(*, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4): net = nn.Sequential( - nn.Embedding(num_tokens, dim), *[ + nn.Embedding(num_tokens, dim), + *[ ParallelResidual( Attention(dim=dim, dim_head=dim_head, heads=heads), FeedForward(dim=dim, mult=ff_mult), - ) for _ in range(depth) - ], LayerNorm(dim), nn.Linear(dim, num_tokens, bias=False)) + ) + for _ in range(depth) + ], + LayerNorm(dim), + nn.Linear(dim, num_tokens, bias=False), + ) # they used embedding weight tied projection out to logits, not common, but works net[-1].weight = net[0].weight diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py index 526f791403ff..e7af88c55121 100644 --- a/examples/language/palm/train.py +++ b/examples/language/palm/train.py @@ -37,7 +37,7 @@ def parse_args(): parser.add_argument( "--distplan", type=str, - default='colossalai', + default="colossalai", help="The distributed plan [colossalai, pytorch].", ) parser.add_argument( @@ -46,12 +46,14 @@ def parse_args(): default=1.0, help="Fraction of optimizer states to be offloaded. This is only used for gemini.", ) - parser.add_argument('-p', - '--plugin', - type=str, - default='torch_ddp', - choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'], - help="plugin to use") + parser.add_argument( + "-p", + "--plugin", + type=str, + default="torch_ddp", + choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero"], + help="plugin to use", + ) parser.add_argument( "--batch_size", type=int, @@ -122,7 +124,6 @@ def generate_dataset(dummy_data: bool = False): class TextSamplerDataset(Dataset): - def __init__(self, data, seq_len): super().__init__() self.data = data @@ -130,7 +131,7 @@ def __init__(self, data, seq_len): def __getitem__(self, index): rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,)) - full_seq = self.data[rand_start:rand_start + self.seq_len + 1].long() + full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long() return full_seq.cuda() def __len__(self): @@ -146,18 +147,18 @@ def __len__(self): # instantiate GPT-like decoder model booster_kwargs = {} - if args.plugin == 'torch_ddp_fp16': - booster_kwargs['mixed_precision'] = 'fp16' - if args.plugin.startswith('torch_ddp'): + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() - elif args.plugin == 'gemini': + elif args.plugin == "gemini": plugin = GeminiPlugin(offload_optim_frac=args.offload_optim_frac, initial_scale=2**5) - elif args.plugin == 'low_level_zero': + elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) logger.info(f"plugin: {plugin}") booster = Booster(plugin=plugin, **booster_kwargs) - ctx = LazyInitContext(default_device=get_current_device()) if args.plugin == 'gemini' else nullcontext() + ctx = LazyInitContext(default_device=get_current_device()) if args.plugin == "gemini" else nullcontext() with ctx: model = PaLM(num_tokens=50304, dim=4096, depth=64) @@ -182,7 +183,6 @@ def __len__(self): model.train() tflops_list = [] for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"): - if args.distplan == "colossalai": optimizer.zero_grad() start = time() @@ -231,12 +231,12 @@ def __len__(self): # loss = model(next(val_loader)) # print(f"validation loss: {loss.item()}") - # if i % GENERATE_EVERY == 0: - # model.eval() 
- # inp = random.choice(val_dataset)[:-1] - # prime = decode_tokens(inp) - # print(f"%s \n\n %s", (prime, "*" * 100)) +# if i % GENERATE_EVERY == 0: +# model.eval() +# inp = random.choice(val_dataset)[:-1] +# prime = decode_tokens(inp) +# print(f"%s \n\n %s", (prime, "*" * 100)) - # sample = model.generate(inp[None, ...], GENERATE_LENGTH) - # output_str = decode_tokens(sample[0]) - # print(output_str) +# sample = model.generate(inp[None, ...], GENERATE_LENGTH) +# output_str = decode_tokens(sample[0]) +# print(output_str) diff --git a/examples/tutorial/README.md b/examples/tutorial/README.md index 7b5668612818..a54c7b4da3bd 100644 --- a/examples/tutorial/README.md +++ b/examples/tutorial/README.md @@ -4,7 +4,7 @@ ## Introduction -Welcome to the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) tutorial, which has been accepted as official tutorials by top conference [NeurIPS](https://nips.cc/), [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), +Welcome to the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) tutorial, which has been accepted as official tutorials by top conference [NeurIPS](https://nips.cc/), [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), [NVIDIA GTC](https://www.nvidia.com/en-us/on-demand/session/gtcspring23-S51482/) ,etc. diff --git a/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py b/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py index 5a68aae18041..29101ce08434 100644 --- a/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py +++ b/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py @@ -20,20 +20,22 @@ def _benchmark(rank, world_size, port): only result in minor performance drop. So at last we might be able to find better training batch size for our model (combine with large batch training optimizer such as LAMB). 
""" - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = tm.resnet152() gm = symbolic_trace(model) raw_graph = deepcopy(gm.graph) peak_mems, through_puts, batch_sizes = [], [], [512, 1024, 2048] for batch_size in batch_sizes: batch_size = int(batch_size) - gm = metainfo_trace(gm, torch.empty(batch_size, 3, 224, 224, device='meta')) + gm = metainfo_trace(gm, torch.empty(batch_size, 3, 224, 224, device="meta")) solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info()[0] * 0.95) gm.graph = solver.solve() - peak_mem, step_time = bench(gm, - torch.nn.CrossEntropyLoss(), - partial(data_gen_resnet, batch_size=batch_size, shape=(3, 224, 224)), - num_steps=5) + peak_mem, step_time = bench( + gm, + torch.nn.CrossEntropyLoss(), + partial(data_gen_resnet, batch_size=batch_size, shape=(3, 224, 224)), + num_steps=5, + ) peak_mems.append(peak_mem) through_puts.append(batch_size / step_time * 1.0e3) gm.graph = deepcopy(raw_graph) @@ -41,7 +43,7 @@ def _benchmark(rank, world_size, port): # print results print("===============benchmark summary================") for batch_size, peak_mem, through_put in zip(batch_sizes, peak_mems, through_puts): - print(f'batch_size: {int(batch_size)}, peak memory: {peak_mem:.3f} MB, through put: {through_put:.3f} images/s') + print(f"batch_size: {int(batch_size)}, peak memory: {peak_mem:.3f} MB, through put: {through_put:.3f} images/s") def auto_activation_checkpoint_batchsize_benchmark(): diff --git a/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py b/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py index aa5c47294a82..cd03a917912e 100644 --- a/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py +++ b/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py @@ -1,4 +1,3 @@ -import time from argparse import ArgumentParser from functools import partial @@ -8,7 +7,6 @@ from bench_utils import GPTLMLoss, bench_rotor, data_gen_gpt2, data_gen_resnet, gpt2_medium import colossalai -from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor from colossalai.fx import metainfo_trace, symbolic_trace from colossalai.testing import spawn @@ -19,37 +17,33 @@ def _benchmark(rank, world_size, port, args): The benchmark will sample in a range of memory budget for each model and output the benchmark summary and data visualization of peak memory vs. budget memory and relative step time vs. peak memory. 
""" - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - if args.model == 'resnet50': + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + if args.model == "resnet50": model = tm.resnet50() data_gen = partial(data_gen_resnet, batch_size=128, shape=(3, 224, 224)) gm = symbolic_trace(model) - gm = metainfo_trace(gm, torch.empty(128, 3, 224, 224, device='meta')) + gm = metainfo_trace(gm, torch.empty(128, 3, 224, 224, device="meta")) loss = torch.nn.CrossEntropyLoss() else: model = gpt2_medium() data_gen = partial(data_gen_gpt2, batch_size=8, seq_len=1024, vocab_size=50257) - data, mask = data_gen(device='meta')[0] - gm = symbolic_trace(model, meta_args={'input_ids': data, 'attention_mask': mask}) + data, mask = data_gen(device="meta")[0] + gm = symbolic_trace(model, meta_args={"input_ids": data, "attention_mask": mask}) gm = metainfo_trace(gm, data, mask) loss = GPTLMLoss() - free_memory = 11000 * 1024**2 if args.model == 'resnet50' else 56000 * 1024**2 - start_factor = 4 if args.model == 'resnet50' else 10 + free_memory = 11000 * 1024**2 if args.model == "resnet50" else 56000 * 1024**2 + start_factor = 4 if args.model == "resnet50" else 10 # trace and benchmark - budgets, peak_hist, step_hist = bench_rotor(gm, - loss, - data_gen, - num_steps=5, - sample_points=15, - free_memory=free_memory, - start_factor=start_factor) + budgets, peak_hist, step_hist = bench_rotor( + gm, loss, data_gen, num_steps=5, sample_points=15, free_memory=free_memory, start_factor=start_factor + ) # print summary print("==============benchmark summary==============") for budget, peak, step in zip(budgets, peak_hist, step_hist): - print(f'memory budget: {budget:.3f} MB, peak memory: {peak:.3f} MB, step time: {step:.3f} MS') + print(f"memory budget: {budget:.3f} MB, peak memory: {peak:.3f} MB, step time: {step:.3f} MS") # plot valid results fig, axs = plt.subplots(1, 2, figsize=(16, 8)) @@ -57,14 +51,14 @@ def _benchmark(rank, world_size, port, args): # plot peak memory vs. budget memory axs[0].plot(budgets[valid_idx:], peak_hist[valid_idx:]) - axs[0].plot([budgets[valid_idx], budgets[-1]], [budgets[valid_idx], budgets[-1]], linestyle='--') + axs[0].plot([budgets[valid_idx], budgets[-1]], [budgets[valid_idx], budgets[-1]], linestyle="--") axs[0].set_xlabel("Budget Memory (MB)") axs[0].set_ylabel("Peak Memory (MB)") axs[0].set_title("Peak Memory vs. Budget Memory") # plot relative step time vs. budget memory axs[1].plot(peak_hist[valid_idx:], [step_time / step_hist[-1] for step_time in step_hist[valid_idx:]]) - axs[1].plot([peak_hist[valid_idx], peak_hist[-1]], [1.0, 1.0], linestyle='--') + axs[1].plot([peak_hist[valid_idx], peak_hist[-1]], [1.0, 1.0], linestyle="--") axs[1].set_xlabel("Peak Memory (MB)") axs[1].set_ylabel("Relative Step Time") axs[1].set_title("Step Time vs. 
Peak Memory") @@ -81,7 +75,7 @@ def auto_activation_checkpoint_benchmark(args): if __name__ == "__main__": parser = ArgumentParser("Auto Activation Checkpoint Solver Benchmark") - parser.add_argument("--model", type=str, default='gpt2', choices=['gpt2', 'resnet50']) + parser.add_argument("--model", type=str, default="gpt2", choices=["gpt2", "resnet50"]) args = parser.parse_args() auto_activation_checkpoint_benchmark(args) diff --git a/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py b/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py index 33aa5990f7c1..3c5b786b561a 100644 --- a/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py +++ b/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py @@ -17,14 +17,14 @@ def synthesize_data(): def main(): - colossalai.launch_from_torch(config='./config.py') + colossalai.launch_from_torch(config="./config.py") logger = get_dist_logger() # trace the model with meta data model = resnet50(num_classes=10).cuda() - input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')} + input_sample = {"x": torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to("meta")} device_mesh = DeviceMesh(physical_mesh_id=torch.tensor([0, 1, 2, 3]), mesh_shape=[2, 2], init_process_group=True) model, solution = initialize_model(model, input_sample, device_mesh=device_mesh, return_solution=True) @@ -88,8 +88,9 @@ def main(): logger.info( f"Epoch {epoch} - train loss: {train_loss:.5}, test loss: {test_loss:.5}, acc: {correct / total:.5}, lr: {lr_scheduler.get_last_lr()[0]:.5g}", - ranks=[0]) + ranks=[0], + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/tutorial/auto_parallel/bench_utils.py b/examples/tutorial/auto_parallel/bench_utils.py index 69859f885ae6..96cfd49c6787 100644 --- a/examples/tutorial/auto_parallel/bench_utils.py +++ b/examples/tutorial/auto_parallel/bench_utils.py @@ -1,22 +1,19 @@ import time from copy import deepcopy -from functools import partial from typing import Callable, Tuple import numpy as np import torch import torch.nn as nn -import torchvision.models as tm from transformers import GPT2Config, GPT2LMHeadModel from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor from colossalai.fx import metainfo_trace -def bench(gm: torch.fx.GraphModule, - criterion: torch.nn.Module, - data_gen: Callable, - num_steps: int = 5) -> Tuple[int, int]: +def bench( + gm: torch.fx.GraphModule, criterion: torch.nn.Module, data_gen: Callable, num_steps: int = 5 +) -> Tuple[int, int]: """Benchmarking a given graph module Args: gm (torch.fx.GraphModule): The graph module to benchmark. 
@@ -28,7 +25,7 @@ def bench(gm: torch.fx.GraphModule, """ gm.train() gm.cuda() - step_time = float('inf') + step_time = float("inf") torch.cuda.synchronize() torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() @@ -58,13 +55,15 @@ def bench(gm: torch.fx.GraphModule, return peak_mem, step_time * 1.0e3 -def bench_rotor(gm: torch.fx.GraphModule, - criterion: torch.nn.Module, - data_gen: Callable, - num_steps: int = 5, - sample_points: int = 20, - free_memory: int = torch.cuda.mem_get_info()[0], - start_factor: int = 4) -> Tuple[np.array, list, list]: +def bench_rotor( + gm: torch.fx.GraphModule, + criterion: torch.nn.Module, + data_gen: Callable, + num_steps: int = 5, + sample_points: int = 20, + free_memory: int = torch.cuda.mem_get_info()[0], + start_factor: int = 4, +) -> Tuple[np.array, list, list]: """Auto Checkpoint Rotor Algorithm benchmarking Benchmarks the Auto Checkpoint Rotor Algorithm for a given graph module and data. Args: @@ -88,7 +87,7 @@ def bench_rotor(gm: torch.fx.GraphModule, gm.graph = solver.solve(verbose=False) peak_memory, step_time = bench(gm, criterion, data_gen, num_steps=num_steps) except: - peak_memory, step_time = budget / 1024**2, float('inf') + peak_memory, step_time = budget / 1024**2, float("inf") peak_hist.append(peak_memory) step_hist.append(step_time) gm.graph = deepcopy(raw_graph) @@ -100,22 +99,27 @@ class GPTLMModel(nn.Module): GPT Model """ - def __init__(self, - hidden_size=768, - num_layers=12, - num_attention_heads=12, - max_seq_len=1024, - vocab_size=50257, - checkpoint=False): + def __init__( + self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50257, + checkpoint=False, + ): super().__init__() self.checkpoint = checkpoint self.model = GPT2LMHeadModel( - GPT2Config(n_embd=hidden_size, - n_layer=num_layers, - n_head=num_attention_heads, - n_positions=max_seq_len, - n_ctx=max_seq_len, - vocab_size=vocab_size)) + GPT2Config( + n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size, + ) + ) if checkpoint: self.model.gradient_checkpointing_enable() @@ -152,7 +156,7 @@ def gpt2_6b(checkpoint=False): return GPTLMModel(hidden_size=4096, num_layers=30, num_attention_heads=16, checkpoint=checkpoint) -def data_gen_gpt2(batch_size, seq_len, vocab_size, device='cuda:0'): +def data_gen_gpt2(batch_size, seq_len, vocab_size, device="cuda:0"): """ Generate random data for gpt2 benchmarking """ @@ -161,7 +165,7 @@ def data_gen_gpt2(batch_size, seq_len, vocab_size, device='cuda:0'): return (input_ids, attention_mask), attention_mask -def data_gen_resnet(batch_size, shape, device='cuda:0'): +def data_gen_resnet(batch_size, shape, device="cuda:0"): """ Generate random data for resnet benchmarking """ diff --git a/examples/tutorial/auto_parallel/setup.py b/examples/tutorial/auto_parallel/setup.py index 6e6cff32ed23..94d5ec0c0e9e 100644 --- a/examples/tutorial/auto_parallel/setup.py +++ b/examples/tutorial/auto_parallel/setup.py @@ -1,13 +1,13 @@ from setuptools import find_packages, setup setup( - name='auto_parallel', - version='0.0.1', - description='', + name="auto_parallel", + version="0.0.1", + description="", packages=find_packages(), install_requires=[ - 'torch', - 'numpy', - 'tqdm', + "torch", + "numpy", + "tqdm", ], ) diff --git a/examples/tutorial/download_cifar10.py b/examples/tutorial/download_cifar10.py index 5c6b6988ade5..78ea3d1e062e 100644 --- a/examples/tutorial/download_cifar10.py +++ 
b/examples/tutorial/download_cifar10.py @@ -5,9 +5,9 @@ def main(): dir_path = os.path.dirname(os.path.realpath(__file__)) - data_root = os.path.join(dir_path, 'data') + data_root = os.path.join(dir_path, "data") dataset = CIFAR10(root=data_root, download=True) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/tutorial/hybrid_parallel/config.py b/examples/tutorial/hybrid_parallel/config.py index 287f62aa7a90..15f9d0bc75ee 100644 --- a/examples/tutorial/hybrid_parallel/config.py +++ b/examples/tutorial/hybrid_parallel/config.py @@ -18,11 +18,11 @@ MLP_RATIO = 2 NUM_CLASSES = 10 CHECKPOINT = False -SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token +SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE) ** 2 + 1 # add 1 for cls token # parallel setting TENSOR_PARALLEL_SIZE = 2 -TENSOR_PARALLEL_MODE = '1d' +TENSOR_PARALLEL_MODE = "1d" parallel = dict( pipeline=2, @@ -33,4 +33,4 @@ clip_grad_norm = 1.0 # pipeline config -NUM_MICRO_BATCHES = parallel['pipeline'] +NUM_MICRO_BATCHES = parallel["pipeline"] diff --git a/examples/tutorial/hybrid_parallel/train.py b/examples/tutorial/hybrid_parallel/train.py index 21a568168e33..95f1bf8ee17c 100644 --- a/examples/tutorial/hybrid_parallel/train.py +++ b/examples/tutorial/hybrid_parallel/train.py @@ -14,8 +14,7 @@ from colossalai.utils import is_using_pp -class DummyDataloader(): - +class DummyDataloader: def __init__(self, length, batch_size): self.length = length self.batch_size = batch_size @@ -50,7 +49,7 @@ def main(): logger = get_dist_logger() logger.info("initialized distributed environment", ranks=[0]) - if hasattr(gpc.config, 'LOG_PATH'): + if hasattr(gpc.config, "LOG_PATH"): if gpc.get_global_rank() == 0: log_path = gpc.config.LOG_PATH if not os.path.exists(log_path): @@ -60,15 +59,17 @@ def main(): use_pipeline = is_using_pp() # create model - model_kwargs = dict(img_size=gpc.config.IMG_SIZE, - patch_size=gpc.config.PATCH_SIZE, - hidden_size=gpc.config.HIDDEN_SIZE, - depth=gpc.config.DEPTH, - num_heads=gpc.config.NUM_HEADS, - mlp_ratio=gpc.config.MLP_RATIO, - num_classes=10, - init_method='jax', - checkpoint=gpc.config.CHECKPOINT) + model_kwargs = dict( + img_size=gpc.config.IMG_SIZE, + patch_size=gpc.config.PATCH_SIZE, + hidden_size=gpc.config.HIDDEN_SIZE, + depth=gpc.config.DEPTH, + num_heads=gpc.config.NUM_HEADS, + mlp_ratio=gpc.config.MLP_RATIO, + num_classes=10, + init_method="jax", + checkpoint=gpc.config.CHECKPOINT, + ) if use_pipeline: pipelinable = PipelinableContext() @@ -102,16 +103,18 @@ def main(): optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY) # create lr scheduler - lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, - total_steps=gpc.config.NUM_EPOCHS, - warmup_steps=gpc.config.WARMUP_EPOCHS) + lr_scheduler = CosineAnnealingWarmupLR( + optimizer=optimizer, total_steps=gpc.config.NUM_EPOCHS, warmup_steps=gpc.config.WARMUP_EPOCHS + ) # initialize - engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model, - optimizer=optimizer, - criterion=criterion, - train_dataloader=train_dataloader, - test_dataloader=test_dataloader) + engine, train_dataloader, test_dataloader, _ = colossalai.initialize( + model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader, + test_dataloader=test_dataloader, + ) logger.info("Engine is built", ranks=[0]) @@ -121,7 +124,7 @@ def main(): data_iter = iter(train_dataloader) if gpc.get_global_rank() == 0: - description = 'Epoch {} / 
{}'.format(epoch, gpc.config.NUM_EPOCHS) + description = "Epoch {} / {}".format(epoch, gpc.config.NUM_EPOCHS) progress = tqdm(range(len(train_dataloader)), desc=description) else: progress = range(len(train_dataloader)) @@ -133,5 +136,5 @@ def main(): gpc.destroy() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/tutorial/large_batch_optimizer/train.py b/examples/tutorial/large_batch_optimizer/train.py index 6ebd8d68083d..dd114b5af86d 100644 --- a/examples/tutorial/large_batch_optimizer/train.py +++ b/examples/tutorial/large_batch_optimizer/train.py @@ -10,8 +10,7 @@ from colossalai.nn.optimizer import Lamb, Lars -class DummyDataloader(): - +class DummyDataloader: def __init__(self, length, batch_size): self.length = length self.batch_size = batch_size @@ -39,10 +38,9 @@ def __len__(self): def main(): # initialize distributed setting parser = colossalai.get_default_parser() - parser.add_argument('--optimizer', - choices=['lars', 'lamb'], - help="Choose your large-batch optimizer", - required=True) + parser.add_argument( + "--optimizer", choices=["lars", "lamb"], help="Choose your large-batch optimizer", required=True + ) args = parser.parse_args() # launch from torch @@ -70,16 +68,18 @@ def main(): optimizer = optim_cls(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY) # create lr scheduler - lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, - total_steps=gpc.config.NUM_EPOCHS, - warmup_steps=gpc.config.WARMUP_EPOCHS) + lr_scheduler = CosineAnnealingWarmupLR( + optimizer=optimizer, total_steps=gpc.config.NUM_EPOCHS, warmup_steps=gpc.config.WARMUP_EPOCHS + ) # initialize - engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model, - optimizer=optimizer, - criterion=criterion, - train_dataloader=train_dataloader, - test_dataloader=test_dataloader) + engine, train_dataloader, test_dataloader, _ = colossalai.initialize( + model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader, + test_dataloader=test_dataloader, + ) logger.info("Engine is built", ranks=[0]) @@ -89,7 +89,7 @@ def main(): data_iter = iter(train_dataloader) if gpc.get_global_rank() == 0: - description = 'Epoch {} / {}'.format(epoch, gpc.config.NUM_EPOCHS) + description = "Epoch {} / {}".format(epoch, gpc.config.NUM_EPOCHS) progress = tqdm(range(len(train_dataloader)), desc=description) else: progress = range(len(train_dataloader)) @@ -100,5 +100,5 @@ def main(): lr_scheduler.step() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/tutorial/new_api/cifar_resnet/eval.py b/examples/tutorial/new_api/cifar_resnet/eval.py index 657708ec3ff2..526e41a2850f 100644 --- a/examples/tutorial/new_api/cifar_resnet/eval.py +++ b/examples/tutorial/new_api/cifar_resnet/eval.py @@ -1,7 +1,6 @@ import argparse import torch -import torch.nn as nn import torchvision import torchvision.transforms as transforms @@ -9,15 +8,15 @@ # Parse Arguments # ============================== parser = argparse.ArgumentParser() -parser.add_argument('-e', '--epoch', type=int, default=80, help="resume from the epoch's checkpoint") -parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory") +parser.add_argument("-e", "--epoch", type=int, default=80, help="resume from the epoch's checkpoint") +parser.add_argument("-c", "--checkpoint", type=str, default="./checkpoint", help="checkpoint directory") args = parser.parse_args() # 
============================== # Prepare Test Dataset # ============================== # CIFAR-10 dataset -test_dataset = torchvision.datasets.CIFAR10(root='./data/', train=False, transform=transforms.ToTensor()) +test_dataset = torchvision.datasets.CIFAR10(root="./data/", train=False, transform=transforms.ToTensor()) # Data loader test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=128, shuffle=False) @@ -26,7 +25,7 @@ # Load Model # ============================== model = torchvision.models.resnet18(num_classes=10).cuda() -state_dict = torch.load(f'{args.checkpoint}/model_{args.epoch}.pth') +state_dict = torch.load(f"{args.checkpoint}/model_{args.epoch}.pth") model.load_state_dict(state_dict) # ============================== @@ -45,4 +44,4 @@ total += labels.size(0) correct += (predicted == labels).sum().item() - print('Accuracy of the model on the test images: {} %'.format(100 * correct / total)) + print("Accuracy of the model on the test images: {} %".format(100 * correct / total)) diff --git a/examples/tutorial/new_api/cifar_resnet/train.py b/examples/tutorial/new_api/cifar_resnet/train.py index fe0dabf08377..6ae2d8b0412f 100644 --- a/examples/tutorial/new_api/cifar_resnet/train.py +++ b/examples/tutorial/new_api/cifar_resnet/train.py @@ -30,23 +30,19 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPluginBase): # transform transform_train = transforms.Compose( - [transforms.Pad(4), - transforms.RandomHorizontalFlip(), - transforms.RandomCrop(32), - transforms.ToTensor()]) + [transforms.Pad(4), transforms.RandomHorizontalFlip(), transforms.RandomCrop(32), transforms.ToTensor()] + ) transform_test = transforms.ToTensor() # CIFAR-10 dataset - data_path = os.environ.get('DATA', './data') + data_path = os.environ.get("DATA", "./data") with coordinator.priority_execution(): - train_dataset = torchvision.datasets.CIFAR10(root=data_path, - train=True, - transform=transform_train, - download=True) - test_dataset = torchvision.datasets.CIFAR10(root=data_path, - train=False, - transform=transform_test, - download=True) + train_dataset = torchvision.datasets.CIFAR10( + root=data_path, train=True, transform=transform_train, download=True + ) + test_dataset = torchvision.datasets.CIFAR10( + root=data_path, train=False, transform=transform_test, download=True + ) # Data loader train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True) @@ -70,14 +66,21 @@ def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoo dist.all_reduce(total) accuracy = correct.item() / total.item() if coordinator.is_master(): - print(f'Accuracy of the model on the test images: {accuracy * 100:.2f} %') + print(f"Accuracy of the model on the test images: {accuracy * 100:.2f} %") return accuracy -def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: nn.Module, train_dataloader: DataLoader, - booster: Booster, coordinator: DistCoordinator): +def train_epoch( + epoch: int, + model: nn.Module, + optimizer: Optimizer, + criterion: nn.Module, + train_dataloader: DataLoader, + booster: Booster, + coordinator: DistCoordinator, +): model.train() - with tqdm(train_dataloader, desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not coordinator.is_master()) as pbar: + with tqdm(train_dataloader, desc=f"Epoch [{epoch + 1}/{NUM_EPOCHS}]", disable=not coordinator.is_master()) as pbar: for images, labels in pbar: images = images.cuda() labels = labels.cuda() @@ -91,7 +94,7 @@ def 
train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: n optimizer.zero_grad() # Print log info - pbar.set_postfix({'loss': loss.item()}) + pbar.set_postfix({"loss": loss.item()}) def main(): @@ -100,19 +103,20 @@ def main(): # ============================== parser = argparse.ArgumentParser() # FIXME(ver217): gemini is not supported resnet now - parser.add_argument('-p', - '--plugin', - type=str, - default='torch_ddp', - choices=['torch_ddp', 'torch_ddp_fp16', 'low_level_zero'], - help="plugin to use") - parser.add_argument('-r', '--resume', type=int, default=-1, help="resume from the epoch's checkpoint") - parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory") - parser.add_argument('-i', '--interval', type=int, default=5, help="interval of saving checkpoint") - parser.add_argument('--target_acc', - type=float, - default=None, - help="target accuracy. Raise exception if not reached") + parser.add_argument( + "-p", + "--plugin", + type=str, + default="torch_ddp", + choices=["torch_ddp", "torch_ddp_fp16", "low_level_zero"], + help="plugin to use", + ) + parser.add_argument("-r", "--resume", type=int, default=-1, help="resume from the epoch's checkpoint") + parser.add_argument("-c", "--checkpoint", type=str, default="./checkpoint", help="checkpoint directory") + parser.add_argument("-i", "--interval", type=int, default=5, help="interval of saving checkpoint") + parser.add_argument( + "--target_acc", type=float, default=None, help="target accuracy. Raise exception if not reached" + ) args = parser.parse_args() # ============================== @@ -136,13 +140,13 @@ def main(): # Instantiate Plugin and Booster # ============================== booster_kwargs = {} - if args.plugin == 'torch_ddp_fp16': - booster_kwargs['mixed_precision'] = 'fp16' - if args.plugin.startswith('torch_ddp'): + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() - elif args.plugin == 'gemini': - plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5) - elif args.plugin == 'low_level_zero': + elif args.plugin == "gemini": + plugin = GeminiPlugin(placement_policy="cuda", strict_ddp_mode=True, initial_scale=2**5) + elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) booster = Booster(plugin=plugin, **booster_kwargs) @@ -168,18 +172,17 @@ def main(): # ============================== # Boost with ColossalAI # ============================== - model, optimizer, criterion, _, lr_scheduler = booster.boost(model, - optimizer, - criterion=criterion, - lr_scheduler=lr_scheduler) + model, optimizer, criterion, _, lr_scheduler = booster.boost( + model, optimizer, criterion=criterion, lr_scheduler=lr_scheduler + ) # ============================== # Resume from checkpoint # ============================== if args.resume >= 0: - booster.load_model(model, f'{args.checkpoint}/model_{args.resume}.pth') - booster.load_optimizer(optimizer, f'{args.checkpoint}/optimizer_{args.resume}.pth') - booster.load_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{args.resume}.pth') + booster.load_model(model, f"{args.checkpoint}/model_{args.resume}.pth") + booster.load_optimizer(optimizer, f"{args.checkpoint}/optimizer_{args.resume}.pth") + booster.load_lr_scheduler(lr_scheduler, f"{args.checkpoint}/lr_scheduler_{args.resume}.pth") # ============================== # Train model @@ -191,14 +194,14 @@ def main(): # 
save checkpoint if args.interval > 0 and (epoch + 1) % args.interval == 0: - booster.save_model(model, f'{args.checkpoint}/model_{epoch + 1}.pth') - booster.save_optimizer(optimizer, f'{args.checkpoint}/optimizer_{epoch + 1}.pth') - booster.save_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{epoch + 1}.pth') + booster.save_model(model, f"{args.checkpoint}/model_{epoch + 1}.pth") + booster.save_optimizer(optimizer, f"{args.checkpoint}/optimizer_{epoch + 1}.pth") + booster.save_lr_scheduler(lr_scheduler, f"{args.checkpoint}/lr_scheduler_{epoch + 1}.pth") accuracy = evaluate(model, test_dataloader, coordinator) if args.target_acc is not None: - assert accuracy >= args.target_acc, f'Accuracy {accuracy} is lower than target accuracy {args.target_acc}' + assert accuracy >= args.target_acc, f"Accuracy {accuracy} is lower than target accuracy {args.target_acc}" -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/tutorial/new_api/cifar_vit/train.py b/examples/tutorial/new_api/cifar_vit/train.py index 82a8f2ed97e4..226a4b320961 100644 --- a/examples/tutorial/new_api/cifar_vit/train.py +++ b/examples/tutorial/new_api/cifar_vit/train.py @@ -32,35 +32,37 @@ def vit_cifar(**kwargs): pretrained_cfg = _cfg(num_classes=10, input_size=(3, 32, 32), crop_pct=1.0) model_kwargs = dict(patch_size=4, embed_dim=512, depth=6, num_heads=8, drop_rate=0.1, mlp_ratio=1.0, **kwargs) - model = _create_vision_transformer('vit_cifar', pretrained_cfg=pretrained_cfg, **model_kwargs) + model = _create_vision_transformer("vit_cifar", pretrained_cfg=pretrained_cfg, **model_kwargs) return model def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPluginBase): # transform - transform_train = transforms.Compose([ - transforms.RandomCrop(32, padding=4), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)), - ]) - transform_test = transforms.Compose([ - transforms.Resize(32), - transforms.ToTensor(), - transforms.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)), - ]) + transform_train = transforms.Compose( + [ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)), + ] + ) + transform_test = transforms.Compose( + [ + transforms.Resize(32), + transforms.ToTensor(), + transforms.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)), + ] + ) # CIFAR-10 dataset - data_path = os.environ.get('DATA', './data') + data_path = os.environ.get("DATA", "./data") with coordinator.priority_execution(): - train_dataset = torchvision.datasets.CIFAR10(root=data_path, - train=True, - transform=transform_train, - download=True) - test_dataset = torchvision.datasets.CIFAR10(root=data_path, - train=False, - transform=transform_test, - download=True) + train_dataset = torchvision.datasets.CIFAR10( + root=data_path, train=True, transform=transform_train, download=True + ) + test_dataset = torchvision.datasets.CIFAR10( + root=data_path, train=False, transform=transform_test, download=True + ) # Data loader train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True) @@ -84,14 +86,21 @@ def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoo dist.all_reduce(total) accuracy = correct.item() / 
total.item() if coordinator.is_master(): - print(f'Accuracy of the model on the test images: {accuracy * 100:.2f} %') + print(f"Accuracy of the model on the test images: {accuracy * 100:.2f} %") return accuracy -def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: nn.Module, train_dataloader: DataLoader, - booster: Booster, coordinator: DistCoordinator): +def train_epoch( + epoch: int, + model: nn.Module, + optimizer: Optimizer, + criterion: nn.Module, + train_dataloader: DataLoader, + booster: Booster, + coordinator: DistCoordinator, +): model.train() - with tqdm(train_dataloader, desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not coordinator.is_master()) as pbar: + with tqdm(train_dataloader, desc=f"Epoch [{epoch + 1}/{NUM_EPOCHS}]", disable=not coordinator.is_master()) as pbar: for images, labels in pbar: images = images.cuda() labels = labels.cuda() @@ -105,7 +114,7 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: n optimizer.zero_grad() # Print log info - pbar.set_postfix({'loss': loss.item()}) + pbar.set_postfix({"loss": loss.item()}) def main(): @@ -114,19 +123,20 @@ def main(): # ============================== parser = argparse.ArgumentParser() # FIXME(ver217): gemini is not supported resnet now - parser.add_argument('-p', - '--plugin', - type=str, - default='torch_ddp', - choices=['torch_ddp', 'torch_ddp_fp16', 'low_level_zero'], - help="plugin to use") - parser.add_argument('-r', '--resume', type=int, default=-1, help="resume from the epoch's checkpoint") - parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory") - parser.add_argument('-i', '--interval', type=int, default=5, help="interval of saving checkpoint") - parser.add_argument('--target_acc', - type=float, - default=None, - help="target accuracy. Raise exception if not reached") + parser.add_argument( + "-p", + "--plugin", + type=str, + default="torch_ddp", + choices=["torch_ddp", "torch_ddp_fp16", "low_level_zero"], + help="plugin to use", + ) + parser.add_argument("-r", "--resume", type=int, default=-1, help="resume from the epoch's checkpoint") + parser.add_argument("-c", "--checkpoint", type=str, default="./checkpoint", help="checkpoint directory") + parser.add_argument("-i", "--interval", type=int, default=5, help="interval of saving checkpoint") + parser.add_argument( + "--target_acc", type=float, default=None, help="target accuracy. 
Raise exception if not reached" + ) args = parser.parse_args() # ============================== @@ -150,13 +160,13 @@ def main(): # Instantiate Plugin and Booster # ============================== booster_kwargs = {} - if args.plugin == 'torch_ddp_fp16': - booster_kwargs['mixed_precision'] = 'fp16' - if args.plugin.startswith('torch_ddp'): + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() - elif args.plugin == 'gemini': - plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5) - elif args.plugin == 'low_level_zero': + elif args.plugin == "gemini": + plugin = GeminiPlugin(placement_policy="cuda", strict_ddp_mode=True, initial_scale=2**5) + elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) booster = Booster(plugin=plugin, **booster_kwargs) @@ -182,19 +192,17 @@ def main(): # ============================== # Boost with ColossalAI # ============================== - model, optimizer, criterion, train_dataloader, lr_scheduler = booster.boost(model, - optimizer, - criterion=criterion, - dataloader=train_dataloader, - lr_scheduler=lr_scheduler) + model, optimizer, criterion, train_dataloader, lr_scheduler = booster.boost( + model, optimizer, criterion=criterion, dataloader=train_dataloader, lr_scheduler=lr_scheduler + ) # ============================== # Resume from checkpoint # ============================== if args.resume >= 0: - booster.load_model(model, f'{args.checkpoint}/model_{args.resume}.pth') - booster.load_optimizer(optimizer, f'{args.checkpoint}/optimizer_{args.resume}.pth') - booster.load_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{args.resume}.pth') + booster.load_model(model, f"{args.checkpoint}/model_{args.resume}.pth") + booster.load_optimizer(optimizer, f"{args.checkpoint}/optimizer_{args.resume}.pth") + booster.load_lr_scheduler(lr_scheduler, f"{args.checkpoint}/lr_scheduler_{args.resume}.pth") # ============================== # Train model @@ -206,14 +214,14 @@ def main(): # save checkpoint if args.interval > 0 and (epoch + 1) % args.interval == 0: - booster.save_model(model, f'{args.checkpoint}/model_{epoch + 1}.pth') - booster.save_optimizer(optimizer, f'{args.checkpoint}/optimizer_{epoch + 1}.pth') - booster.save_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{epoch + 1}.pth') + booster.save_model(model, f"{args.checkpoint}/model_{epoch + 1}.pth") + booster.save_optimizer(optimizer, f"{args.checkpoint}/optimizer_{epoch + 1}.pth") + booster.save_lr_scheduler(lr_scheduler, f"{args.checkpoint}/lr_scheduler_{epoch + 1}.pth") accuracy = evaluate(model, test_dataloader, coordinator) if args.target_acc is not None: - assert accuracy >= args.target_acc, f'Accuracy {accuracy} is lower than target accuracy {args.target_acc}' + assert accuracy >= args.target_acc, f"Accuracy {accuracy} is lower than target accuracy {args.target_acc}" -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/tutorial/new_api/glue_bert/data.py b/examples/tutorial/new_api/glue_bert/data.py index 981cedcca8c2..ef51f938dc4f 100644 --- a/examples/tutorial/new_api/glue_bert/data.py +++ b/examples/tutorial/new_api/glue_bert/data.py @@ -5,7 +5,6 @@ class GLUEDataBuilder: - task_text_field_map = { "cola": ["sentence"], "sst2": ["sentence"], @@ -84,10 +83,9 @@ def prepare_data(self): AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) def train_dataloader(self): - return 
self.plugin.prepare_dataloader(self.dataset["train"], - batch_size=self.train_batch_size, - shuffle=True, - drop_last=True) + return self.plugin.prepare_dataloader( + self.dataset["train"], batch_size=self.train_batch_size, shuffle=True, drop_last=True + ) def val_dataloader(self): if len(self.eval_splits) == 1: @@ -108,7 +106,6 @@ def test_dataloader(self): ] def convert_to_features(self, example_batch): - # Either encode single sentence or sentence pairs if len(self.text_fields) > 1: texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]])) @@ -116,10 +113,9 @@ def convert_to_features(self, example_batch): texts_or_text_pairs = example_batch[self.text_fields[0]] # Tokenize the text/text pairs - features = self.tokenizer.batch_encode_plus(texts_or_text_pairs, - max_length=self.max_seq_length, - padding='max_length', - truncation=True) + features = self.tokenizer.batch_encode_plus( + texts_or_text_pairs, max_length=self.max_seq_length, padding="max_length", truncation=True + ) # Rename label to labels to make it easier to pass to model forward features["labels"] = example_batch["label"] diff --git a/examples/tutorial/new_api/glue_bert/finetune.py b/examples/tutorial/new_api/glue_bert/finetune.py index 63bdfc5d02cf..7d69dbc066b3 100644 --- a/examples/tutorial/new_api/glue_bert/finetune.py +++ b/examples/tutorial/new_api/glue_bert/finetune.py @@ -33,8 +33,14 @@ def move_to_cuda(batch): @torch.no_grad() -def evaluate(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int, task_name: str, - eval_splits: List[str], coordinator: DistCoordinator): +def evaluate( + model: nn.Module, + test_dataloader: Union[DataLoader, List[DataLoader]], + num_labels: int, + task_name: str, + eval_splits: List[str], + coordinator: DistCoordinator, +): metric = datasets.load_metric("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size) model.eval() @@ -58,7 +64,7 @@ def evaluate_subset(dataloader: DataLoader): results = metric.compute() dist.all_reduce(accum_loss.div_(len(dataloader))) if coordinator.is_master(): - results['loss'] = accum_loss.item() / coordinator.world_size + results["loss"] = accum_loss.item() / coordinator.world_size return results if isinstance(test_dataloader, DataLoader): @@ -68,14 +74,21 @@ def evaluate_subset(dataloader: DataLoader): final_results = {} for split, sub_loader in zip(eval_splits, test_dataloader): results = evaluate_subset(sub_loader) - final_results.update({f'{k}_{split}': v for k, v in results.items()}) + final_results.update({f"{k}_{split}": v for k, v in results.items()}) return final_results -def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, lr_scheduler, train_dataloader: DataLoader, - booster: Booster, coordinator: DistCoordinator): +def train_epoch( + epoch: int, + model: nn.Module, + optimizer: Optimizer, + lr_scheduler, + train_dataloader: DataLoader, + booster: Booster, + coordinator: DistCoordinator, +): model.train() - with tqdm(train_dataloader, desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not coordinator.is_master()) as pbar: + with tqdm(train_dataloader, desc=f"Epoch [{epoch + 1}/{NUM_EPOCHS}]", disable=not coordinator.is_master()) as pbar: for batch in pbar: # Forward pass batch = move_to_cuda(batch) @@ -89,7 +102,7 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, lr_scheduler lr_scheduler.step() # Print log info - pbar.set_postfix({'loss': loss.item()}) + pbar.set_postfix({"loss": loss.item()}) def main(): @@ 
-97,14 +110,16 @@ def main(): # Parse Arguments # ============================== parser = argparse.ArgumentParser() - parser.add_argument('-t', '--task', default='mrpc', help="GLUE task to run") - parser.add_argument('-p', - '--plugin', - type=str, - default='torch_ddp', - choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'], - help="plugin to use") - parser.add_argument('--target_f1', type=float, default=None, help="target f1 score. Raise exception if not reached") + parser.add_argument("-t", "--task", default="mrpc", help="GLUE task to run") + parser.add_argument( + "-p", + "--plugin", + type=str, + default="torch_ddp", + choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero"], + help="plugin to use", + ) + parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached") args = parser.parse_args() # ============================== @@ -115,19 +130,19 @@ def main(): # local_batch_size = BATCH_SIZE // coordinator.world_size lr = LEARNING_RATE * coordinator.world_size - model_name = 'bert-base-uncased' + model_name = "bert-base-uncased" # ============================== # Instantiate Plugin and Booster # ============================== booster_kwargs = {} - if args.plugin == 'torch_ddp_fp16': - booster_kwargs['mixed_precision'] = 'fp16' - if args.plugin.startswith('torch_ddp'): + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() - elif args.plugin == 'gemini': - plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5) - elif args.plugin == 'low_level_zero': + elif args.plugin == "gemini": + plugin = GeminiPlugin(placement_policy="cuda", strict_ddp_mode=True, initial_scale=2**5) + elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) booster = Booster(plugin=plugin, **booster_kwargs) @@ -135,11 +150,9 @@ def main(): # ============================== # Prepare Dataloader # ============================== - data_builder = GLUEDataBuilder(model_name, - plugin, - args.task, - train_batch_size=BATCH_SIZE, - eval_batch_size=BATCH_SIZE) + data_builder = GLUEDataBuilder( + model_name, plugin, args.task, train_batch_size=BATCH_SIZE, eval_batch_size=BATCH_SIZE + ) train_dataloader = data_builder.train_dataloader() test_dataloader = data_builder.test_dataloader() @@ -185,14 +198,15 @@ def main(): for epoch in range(NUM_EPOCHS): train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator) - results = evaluate(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits, - coordinator) + results = evaluate( + model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits, coordinator + ) if coordinator.is_master(): print(results) - if args.target_f1 is not None and 'f1' in results: - assert results['f1'] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}' + if args.target_f1 is not None and "f1" in results: + assert results["f1"] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}' -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/tutorial/opt/inference/batch.py b/examples/tutorial/opt/inference/batch.py index 1a0876ca8338..e4e857b264a0 100644 --- a/examples/tutorial/opt/inference/batch.py +++ b/examples/tutorial/opt/inference/batch.py @@ -1,5 +1,6 @@ +from typing import Any, Deque, Hashable, List, 
Tuple + import torch -from typing import List, Deque, Tuple, Hashable, Any from energonai import BatchManager, SubmitEntry, TaskEntry @@ -10,15 +11,15 @@ def __init__(self, max_batch_size: int = 1, pad_token_id: int = 0) -> None: self.pad_token_id = pad_token_id def _left_padding(self, batch_inputs): - max_len = max(len(inputs['input_ids']) for inputs in batch_inputs) - outputs = {'input_ids': [], 'attention_mask': []} + max_len = max(len(inputs["input_ids"]) for inputs in batch_inputs) + outputs = {"input_ids": [], "attention_mask": []} for inputs in batch_inputs: - input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask'] + input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"] padding_len = max_len - len(input_ids) input_ids = [self.pad_token_id] * padding_len + input_ids attention_mask = [0] * padding_len + attention_mask - outputs['input_ids'].append(input_ids) - outputs['attention_mask'].append(attention_mask) + outputs["input_ids"].append(input_ids) + outputs["attention_mask"].append(attention_mask) for k in outputs: outputs[k] = torch.tensor(outputs[k]) return outputs, max_len @@ -26,7 +27,7 @@ def _left_padding(self, batch_inputs): @staticmethod def _make_batch_key(entry: SubmitEntry) -> tuple: data = entry.data - return (data['top_k'], data['top_p'], data['temperature']) + return (data["top_k"], data["top_p"], data["temperature"]) def make_batch(self, q: Deque[SubmitEntry]) -> Tuple[TaskEntry, dict]: entry = q.popleft() @@ -37,7 +38,7 @@ def make_batch(self, q: Deque[SubmitEntry]) -> Tuple[TaskEntry, dict]: break if self._make_batch_key(entry) != self._make_batch_key(q[0]): break - if q[0].data['max_tokens'] > entry.data['max_tokens']: + if q[0].data["max_tokens"] > entry.data["max_tokens"]: break e = q.popleft() batch.append(e.data) @@ -45,12 +46,12 @@ def make_batch(self, q: Deque[SubmitEntry]) -> Tuple[TaskEntry, dict]: inputs, max_len = self._left_padding(batch) trunc_lens = [] for data in batch: - trunc_lens.append(max_len + data['max_tokens']) - inputs['top_k'] = entry.data['top_k'] - inputs['top_p'] = entry.data['top_p'] - inputs['temperature'] = entry.data['temperature'] - inputs['max_tokens'] = max_len + entry.data['max_tokens'] - return TaskEntry(tuple(uids), inputs), {'trunc_lens': trunc_lens} + trunc_lens.append(max_len + data["max_tokens"]) + inputs["top_k"] = entry.data["top_k"] + inputs["top_p"] = entry.data["top_p"] + inputs["temperature"] = entry.data["temperature"] + inputs["max_tokens"] = max_len + entry.data["max_tokens"] + return TaskEntry(tuple(uids), inputs), {"trunc_lens": trunc_lens} def split_batch(self, task_entry: TaskEntry, trunc_lens: List[int] = []) -> List[Tuple[Hashable, Any]]: retval = [] diff --git a/examples/tutorial/opt/inference/benchmark/locustfile.py b/examples/tutorial/opt/inference/benchmark/locustfile.py index 4d829e5d83bf..76ef9d8cb3d6 100644 --- a/examples/tutorial/opt/inference/benchmark/locustfile.py +++ b/examples/tutorial/opt/inference/benchmark/locustfile.py @@ -1,15 +1,14 @@ from locust import HttpUser, task -from json import JSONDecodeError class GenerationUser(HttpUser): @task def generate(self): - prompt = 'Question: What is the longest river on the earth? Answer:' + prompt = "Question: What is the longest river on the earth? 
Answer:" for i in range(4, 9): - data = {'max_tokens': 2**i, 'prompt': prompt} - with self.client.post('/generation', json=data, catch_response=True) as response: + data = {"max_tokens": 2**i, "prompt": prompt} + with self.client.post("/generation", json=data, catch_response=True) as response: if response.status_code in (200, 406): response.success() else: - response.failure('Response wrong') + response.failure("Response wrong") diff --git a/examples/tutorial/opt/inference/cache.py b/examples/tutorial/opt/inference/cache.py index 30febc44fbb3..1eb7dac2ea04 100644 --- a/examples/tutorial/opt/inference/cache.py +++ b/examples/tutorial/opt/inference/cache.py @@ -1,7 +1,7 @@ from collections import OrderedDict -from threading import Lock from contextlib import contextmanager -from typing import List, Any, Hashable, Dict +from threading import Lock +from typing import Any, Dict, Hashable, List class MissCacheError(Exception): diff --git a/examples/tutorial/opt/inference/opt_fastapi.py b/examples/tutorial/opt/inference/opt_fastapi.py index cbfc2a22e7c0..6475284e535b 100644 --- a/examples/tutorial/opt/inference/opt_fastapi.py +++ b/examples/tutorial/opt/inference/opt_fastapi.py @@ -4,20 +4,21 @@ from typing import Optional import uvicorn +from batch import BatchManagerForGeneration +from cache import ListCache, MissCacheError from energonai import QueueFullError, launch_engine from energonai.model import opt_6B, opt_30B, opt_125M, opt_175B from fastapi import FastAPI, HTTPException, Request from pydantic import BaseModel, Field from transformers import GPT2Tokenizer -from batch import BatchManagerForGeneration -from cache import ListCache, MissCacheError - class GenerationTaskReq(BaseModel): max_tokens: int = Field(gt=0, le=256, example=64) prompt: str = Field( - min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:') + min_length=1, + example="Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:", + ) top_k: Optional[int] = Field(default=None, gt=0, example=50) top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5) temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7) @@ -26,7 +27,7 @@ class GenerationTaskReq(BaseModel): app = FastAPI() -@app.post('/generation') +@app.post("/generation") async def generate(data: GenerationTaskReq, request: Request): logger.info(f'{request.client.host}:{request.client.port} - "{request.method} {request.url.path}" - {data}') key = (data.prompt, data.max_tokens) @@ -35,13 +36,13 @@ async def generate(data: GenerationTaskReq, request: Request): raise MissCacheError() outputs = cache.get(key) output = random.choice(outputs) - logger.info('Cache hit') + logger.info("Cache hit") except MissCacheError: inputs = tokenizer(data.prompt, truncation=True, max_length=512) - inputs['max_tokens'] = data.max_tokens - inputs['top_k'] = data.top_k - inputs['top_p'] = data.top_p - inputs['temperature'] = data.temperature + inputs["max_tokens"] = data.max_tokens + inputs["top_k"] = data.top_k + inputs["top_p"] = data.top_p + inputs["temperature"] = data.temperature try: uid = id(data) engine.submit(uid, inputs) @@ -52,7 +53,7 @@ async def generate(data: GenerationTaskReq, request: Request): except QueueFullError as e: raise HTTPException(status_code=406, detail=e.args[0]) - return {'text': output} + return {"text": output} @app.on_event("shutdown") @@ -64,60 
+65,72 @@ async def shutdown(*_): def get_model_fn(model_name: str): - model_map = { - 'opt-125m': opt_125M, - 'opt-6.7b': opt_6B, - 'opt-30b': opt_30B, - 'opt-175b': opt_175B - } + model_map = {"opt-125m": opt_125M, "opt-6.7b": opt_6B, "opt-30b": opt_30B, "opt-175b": opt_175B} return model_map[model_name] def print_args(args: argparse.Namespace): - print('\n==> Args:') + print("\n==> Args:") for k, v in args.__dict__.items(): - print(f'{k} = {v}') + print(f"{k} = {v}") FIXED_CACHE_KEYS = [ - ('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64), - ('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64), - ("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64) + ( + "Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:", + 64, + ), + ( + "A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.", + 64, + ), + ( + "English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", + 64, + ), ] -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b']) - parser.add_argument('--tp', type=int, default=1) - parser.add_argument('--master_host', default='localhost') - parser.add_argument('--master_port', type=int, default=19990) - parser.add_argument('--rpc_port', type=int, default=19980) - parser.add_argument('--max_batch_size', type=int, default=8) - parser.add_argument('--pipe_size', type=int, default=1) - parser.add_argument('--queue_size', type=int, default=0) - parser.add_argument('--http_host', default='0.0.0.0') - parser.add_argument('--http_port', type=int, default=7070) - parser.add_argument('--checkpoint', default=None) - parser.add_argument('--cache_size', type=int, default=0) - parser.add_argument('--cache_list_size', type=int, default=1) + parser.add_argument("model", choices=["opt-125m", "opt-6.7b", "opt-30b", "opt-175b"]) + parser.add_argument("--tp", type=int, default=1) + parser.add_argument("--master_host", default="localhost") + parser.add_argument("--master_port", type=int, default=19990) + parser.add_argument("--rpc_port", type=int, default=19980) + parser.add_argument("--max_batch_size", type=int, default=8) + parser.add_argument("--pipe_size", type=int, default=1) + parser.add_argument("--queue_size", type=int, default=0) + parser.add_argument("--http_host", default="0.0.0.0") + parser.add_argument("--http_port", type=int, default=7070) + parser.add_argument("--checkpoint", default=None) + parser.add_argument("--cache_size", type=int, default=0) + parser.add_argument("--cache_list_size", type=int, default=1) args = 
parser.parse_args() print_args(args) model_kwargs = {} if args.checkpoint is not None: - model_kwargs['checkpoint'] = args.checkpoint + model_kwargs["checkpoint"] = args.checkpoint logger = logging.getLogger(__name__) - tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b') + tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-30b") if args.cache_size > 0: cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS) else: cache = None - engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model), - batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size, - pad_token_id=tokenizer.pad_token_id), - pipe_size=args.pipe_size, - queue_size=args.queue_size, - **model_kwargs) + engine = launch_engine( + args.tp, + 1, + args.master_host, + args.master_port, + args.rpc_port, + get_model_fn(args.model), + batch_manager=BatchManagerForGeneration( + max_batch_size=args.max_batch_size, pad_token_id=tokenizer.pad_token_id + ), + pipe_size=args.pipe_size, + queue_size=args.queue_size, + **model_kwargs, + ) config = uvicorn.Config(app, host=args.http_host, port=args.http_port) server = uvicorn.Server(config=config) server.run() diff --git a/examples/tutorial/opt/inference/opt_server.py b/examples/tutorial/opt/inference/opt_server.py index 8dab82622c59..7f591b9be111 100644 --- a/examples/tutorial/opt/inference/opt_server.py +++ b/examples/tutorial/opt/inference/opt_server.py @@ -1,33 +1,36 @@ -import logging import argparse +import logging import random -from torch import Tensor -from pydantic import BaseModel, Field from typing import Optional -from energonai.model import opt_125M, opt_30B, opt_175B, opt_6B -from transformers import GPT2Tokenizer -from energonai import launch_engine, QueueFullError + +from batch import BatchManagerForGeneration +from cache import ListCache, MissCacheError +from energonai import QueueFullError, launch_engine +from energonai.model import opt_6B, opt_30B, opt_125M, opt_175B +from pydantic import BaseModel, Field from sanic import Sanic from sanic.request import Request from sanic.response import json -from sanic_ext import validate, openapi -from batch import BatchManagerForGeneration -from cache import ListCache, MissCacheError +from sanic_ext import openapi, validate +from torch import Tensor +from transformers import GPT2Tokenizer class GenerationTaskReq(BaseModel): max_tokens: int = Field(gt=0, le=256, example=64) prompt: str = Field( - min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:') + min_length=1, + example="Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:", + ) top_k: Optional[int] = Field(default=None, gt=0, example=50) top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5) temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7) -app = Sanic('opt') +app = Sanic("opt") -@app.post('/generation') +@app.post("/generation") @openapi.body(GenerationTaskReq) @validate(json=GenerationTaskReq) async def generate(request: Request, body: GenerationTaskReq): @@ -38,13 +41,13 @@ async def generate(request: Request, body: GenerationTaskReq): raise MissCacheError() outputs = cache.get(key) output = random.choice(outputs) - logger.info('Cache hit') + logger.info("Cache hit") except MissCacheError: inputs = tokenizer(body.prompt, truncation=True, 
max_length=512) - inputs['max_tokens'] = body.max_tokens - inputs['top_k'] = body.top_k - inputs['top_p'] = body.top_p - inputs['temperature'] = body.temperature + inputs["max_tokens"] = body.max_tokens + inputs["top_k"] = body.top_k + inputs["top_p"] = body.top_p + inputs["temperature"] = body.temperature try: uid = id(body) engine.submit(uid, inputs) @@ -54,9 +57,9 @@ async def generate(request: Request, body: GenerationTaskReq): if cache is not None: cache.add(key, output) except QueueFullError as e: - return json({'detail': e.args[0]}, status=406) + return json({"detail": e.args[0]}, status=406) - return json({'text': output}) + return json({"text": output}) @app.after_server_stop @@ -65,58 +68,70 @@ def shutdown(*_): def get_model_fn(model_name: str): - model_map = { - 'opt-125m': opt_125M, - 'opt-6.7b': opt_6B, - 'opt-30b': opt_30B, - 'opt-175b': opt_175B - } + model_map = {"opt-125m": opt_125M, "opt-6.7b": opt_6B, "opt-30b": opt_30B, "opt-175b": opt_175B} return model_map[model_name] def print_args(args: argparse.Namespace): - print('\n==> Args:') + print("\n==> Args:") for k, v in args.__dict__.items(): - print(f'{k} = {v}') + print(f"{k} = {v}") FIXED_CACHE_KEYS = [ - ('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64), - ('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64), - ("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64) + ( + "Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:", + 64, + ), + ( + "A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? 
\nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.", + 64, + ), + ( + "English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", + 64, + ), ] -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b']) - parser.add_argument('--tp', type=int, default=1) - parser.add_argument('--master_host', default='localhost') - parser.add_argument('--master_port', type=int, default=19990) - parser.add_argument('--rpc_port', type=int, default=19980) - parser.add_argument('--max_batch_size', type=int, default=8) - parser.add_argument('--pipe_size', type=int, default=1) - parser.add_argument('--queue_size', type=int, default=0) - parser.add_argument('--http_host', default='0.0.0.0') - parser.add_argument('--http_port', type=int, default=7070) - parser.add_argument('--checkpoint', default=None) - parser.add_argument('--cache_size', type=int, default=0) - parser.add_argument('--cache_list_size', type=int, default=1) + parser.add_argument("model", choices=["opt-125m", "opt-6.7b", "opt-30b", "opt-175b"]) + parser.add_argument("--tp", type=int, default=1) + parser.add_argument("--master_host", default="localhost") + parser.add_argument("--master_port", type=int, default=19990) + parser.add_argument("--rpc_port", type=int, default=19980) + parser.add_argument("--max_batch_size", type=int, default=8) + parser.add_argument("--pipe_size", type=int, default=1) + parser.add_argument("--queue_size", type=int, default=0) + parser.add_argument("--http_host", default="0.0.0.0") + parser.add_argument("--http_port", type=int, default=7070) + parser.add_argument("--checkpoint", default=None) + parser.add_argument("--cache_size", type=int, default=0) + parser.add_argument("--cache_list_size", type=int, default=1) args = parser.parse_args() print_args(args) model_kwargs = {} if args.checkpoint is not None: - model_kwargs['checkpoint'] = args.checkpoint + model_kwargs["checkpoint"] = args.checkpoint logger = logging.getLogger(__name__) - tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b') + tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-30b") if args.cache_size > 0: cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS) else: cache = None - engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model), - batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size, - pad_token_id=tokenizer.pad_token_id), - pipe_size=args.pipe_size, - queue_size=args.queue_size, - **model_kwargs) + engine = launch_engine( + args.tp, + 1, + args.master_host, + args.master_port, + args.rpc_port, + get_model_fn(args.model), + batch_manager=BatchManagerForGeneration( + max_batch_size=args.max_batch_size, pad_token_id=tokenizer.pad_token_id + ), + pipe_size=args.pipe_size, + queue_size=args.queue_size, + **model_kwargs, + ) app.run(args.http_host, args.http_port) diff --git a/examples/tutorial/opt/inference/script/process-opt-175b/README.md b/examples/tutorial/opt/inference/script/process-opt-175b/README.md index bc3cba72df33..665c459fec69 100644 --- a/examples/tutorial/opt/inference/script/process-opt-175b/README.md +++ b/examples/tutorial/opt/inference/script/process-opt-175b/README.md @@ -43,4 +43,3 @@ Finally, you 
will get 8 files in `` with following checksums: 5d63b8750d827a1aa7c8ae5b02a3a2ca reshard-model_part-6.pt f888bd41e009096804fe9a4b48c7ffe8 reshard-model_part-7.pt ``` - diff --git a/examples/tutorial/opt/inference/script/process-opt-175b/convert_ckpt.py b/examples/tutorial/opt/inference/script/process-opt-175b/convert_ckpt.py index a17ddd4fa173..36c9001fe3f1 100644 --- a/examples/tutorial/opt/inference/script/process-opt-175b/convert_ckpt.py +++ b/examples/tutorial/opt/inference/script/process-opt-175b/convert_ckpt.py @@ -14,42 +14,45 @@ def load_json(path: str): def parse_shape_info(flat_dir: str): - data = load_json(os.path.join(flat_dir, 'shape.json')) + data = load_json(os.path.join(flat_dir, "shape.json")) flat_info = defaultdict(lambda: defaultdict(list)) for k, shape in data.items(): - matched = re.match(r'decoder.layers.\d+', k) + matched = re.match(r"decoder.layers.\d+", k) if matched is None: - flat_key = 'flat_param_0' + flat_key = "flat_param_0" else: - flat_key = f'{matched[0]}.flat_param_0' - flat_info[flat_key]['names'].append(k) - flat_info[flat_key]['shapes'].append(shape) - flat_info[flat_key]['numels'].append(int(np.prod(shape))) + flat_key = f"{matched[0]}.flat_param_0" + flat_info[flat_key]["names"].append(k) + flat_info[flat_key]["shapes"].append(shape) + flat_info[flat_key]["numels"].append(int(np.prod(shape))) return flat_info def convert(flat_dir: str, output_dir: str, part: int): - flat_path = os.path.join(flat_dir, f'reshard-model_part-{part}-shard0.pt') - output_path = os.path.join(output_dir, f'reshard-model_part-{part}.pt') - flat_meta = load_json(os.path.join(flat_dir, 'flat-meta.json')) + flat_path = os.path.join(flat_dir, f"reshard-model_part-{part}-shard0.pt") + output_path = os.path.join(output_dir, f"reshard-model_part-{part}.pt") + flat_meta = load_json(os.path.join(flat_dir, "flat-meta.json")) flat_sd = torch.load(flat_path) - print(f'Loaded flat state dict from {flat_path}') + print(f"Loaded flat state dict from {flat_path}") output_sd = {} for flat_key, param_meta in flat_meta.items(): - flat_param = flat_sd['model'][flat_key] - assert sum(param_meta['numels']) == flat_param.numel( + flat_param = flat_sd["model"][flat_key] + assert ( + sum(param_meta["numels"]) == flat_param.numel() ), f'flat {flat_key} {flat_param.numel()} vs {sum(param_meta["numels"])}' - for name, shape, param in zip(param_meta['names'], param_meta['shapes'], flat_param.split(param_meta['numels'])): + for name, shape, param in zip( + param_meta["names"], param_meta["shapes"], flat_param.split(param_meta["numels"]) + ): output_sd[name] = param.view(shape) torch.save(output_sd, output_path) - print(f'Saved unflat state dict to {output_path}') + print(f"Saved unflat state dict to {output_path}") if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('flat_dir') - parser.add_argument('output_dir') - parser.add_argument('part', type=int) + parser.add_argument("flat_dir") + parser.add_argument("output_dir") + parser.add_argument("part", type=int) args = parser.parse_args() convert(args.flat_dir, args.output_dir, args.part) diff --git a/examples/tutorial/opt/inference/script/process-opt-175b/flat-meta.json b/examples/tutorial/opt/inference/script/process-opt-175b/flat-meta.json index 59d285565cfd..ce70451cc4e5 100644 --- a/examples/tutorial/opt/inference/script/process-opt-175b/flat-meta.json +++ b/examples/tutorial/opt/inference/script/process-opt-175b/flat-meta.json @@ -1 +1,6944 @@ -{"flat_param_0": {"names": ["decoder.embed_tokens.weight", 
"decoder.embed_positions.weight", "decoder.layer_norm.weight", "decoder.layer_norm.bias"], "shapes": [[6284, 12288], [2050, 12288], [12288], [12288]], "numels": [77217792, 25190400, 12288, 12288]}, "decoder.layers.0.flat_param_0": {"names": ["decoder.layers.0.self_attn.qkv_proj.weight", "decoder.layers.0.self_attn.qkv_proj.bias", "decoder.layers.0.self_attn.out_proj.weight", "decoder.layers.0.self_attn.out_proj.bias", "decoder.layers.0.self_attn_layer_norm.weight", "decoder.layers.0.self_attn_layer_norm.bias", "decoder.layers.0.fc1.weight", "decoder.layers.0.fc1.bias", "decoder.layers.0.fc2.weight", "decoder.layers.0.fc2.bias", "decoder.layers.0.final_layer_norm.weight", "decoder.layers.0.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.1.flat_param_0": {"names": ["decoder.layers.1.self_attn.qkv_proj.weight", "decoder.layers.1.self_attn.qkv_proj.bias", "decoder.layers.1.self_attn.out_proj.weight", "decoder.layers.1.self_attn.out_proj.bias", "decoder.layers.1.self_attn_layer_norm.weight", "decoder.layers.1.self_attn_layer_norm.bias", "decoder.layers.1.fc1.weight", "decoder.layers.1.fc1.bias", "decoder.layers.1.fc2.weight", "decoder.layers.1.fc2.bias", "decoder.layers.1.final_layer_norm.weight", "decoder.layers.1.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.2.flat_param_0": {"names": ["decoder.layers.2.self_attn.qkv_proj.weight", "decoder.layers.2.self_attn.qkv_proj.bias", "decoder.layers.2.self_attn.out_proj.weight", "decoder.layers.2.self_attn.out_proj.bias", "decoder.layers.2.self_attn_layer_norm.weight", "decoder.layers.2.self_attn_layer_norm.bias", "decoder.layers.2.fc1.weight", "decoder.layers.2.fc1.bias", "decoder.layers.2.fc2.weight", "decoder.layers.2.fc2.bias", "decoder.layers.2.final_layer_norm.weight", "decoder.layers.2.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.3.flat_param_0": {"names": ["decoder.layers.3.self_attn.qkv_proj.weight", "decoder.layers.3.self_attn.qkv_proj.bias", "decoder.layers.3.self_attn.out_proj.weight", "decoder.layers.3.self_attn.out_proj.bias", "decoder.layers.3.self_attn_layer_norm.weight", "decoder.layers.3.self_attn_layer_norm.bias", "decoder.layers.3.fc1.weight", "decoder.layers.3.fc1.bias", "decoder.layers.3.fc2.weight", "decoder.layers.3.fc2.bias", "decoder.layers.3.final_layer_norm.weight", "decoder.layers.3.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.4.flat_param_0": {"names": ["decoder.layers.4.self_attn.qkv_proj.weight", "decoder.layers.4.self_attn.qkv_proj.bias", "decoder.layers.4.self_attn.out_proj.weight", "decoder.layers.4.self_attn.out_proj.bias", 
"decoder.layers.4.self_attn_layer_norm.weight", "decoder.layers.4.self_attn_layer_norm.bias", "decoder.layers.4.fc1.weight", "decoder.layers.4.fc1.bias", "decoder.layers.4.fc2.weight", "decoder.layers.4.fc2.bias", "decoder.layers.4.final_layer_norm.weight", "decoder.layers.4.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.5.flat_param_0": {"names": ["decoder.layers.5.self_attn.qkv_proj.weight", "decoder.layers.5.self_attn.qkv_proj.bias", "decoder.layers.5.self_attn.out_proj.weight", "decoder.layers.5.self_attn.out_proj.bias", "decoder.layers.5.self_attn_layer_norm.weight", "decoder.layers.5.self_attn_layer_norm.bias", "decoder.layers.5.fc1.weight", "decoder.layers.5.fc1.bias", "decoder.layers.5.fc2.weight", "decoder.layers.5.fc2.bias", "decoder.layers.5.final_layer_norm.weight", "decoder.layers.5.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.6.flat_param_0": {"names": ["decoder.layers.6.self_attn.qkv_proj.weight", "decoder.layers.6.self_attn.qkv_proj.bias", "decoder.layers.6.self_attn.out_proj.weight", "decoder.layers.6.self_attn.out_proj.bias", "decoder.layers.6.self_attn_layer_norm.weight", "decoder.layers.6.self_attn_layer_norm.bias", "decoder.layers.6.fc1.weight", "decoder.layers.6.fc1.bias", "decoder.layers.6.fc2.weight", "decoder.layers.6.fc2.bias", "decoder.layers.6.final_layer_norm.weight", "decoder.layers.6.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.7.flat_param_0": {"names": ["decoder.layers.7.self_attn.qkv_proj.weight", "decoder.layers.7.self_attn.qkv_proj.bias", "decoder.layers.7.self_attn.out_proj.weight", "decoder.layers.7.self_attn.out_proj.bias", "decoder.layers.7.self_attn_layer_norm.weight", "decoder.layers.7.self_attn_layer_norm.bias", "decoder.layers.7.fc1.weight", "decoder.layers.7.fc1.bias", "decoder.layers.7.fc2.weight", "decoder.layers.7.fc2.bias", "decoder.layers.7.final_layer_norm.weight", "decoder.layers.7.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.8.flat_param_0": {"names": ["decoder.layers.8.self_attn.qkv_proj.weight", "decoder.layers.8.self_attn.qkv_proj.bias", "decoder.layers.8.self_attn.out_proj.weight", "decoder.layers.8.self_attn.out_proj.bias", "decoder.layers.8.self_attn_layer_norm.weight", "decoder.layers.8.self_attn_layer_norm.bias", "decoder.layers.8.fc1.weight", "decoder.layers.8.fc1.bias", "decoder.layers.8.fc2.weight", "decoder.layers.8.fc2.bias", "decoder.layers.8.final_layer_norm.weight", "decoder.layers.8.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": 
[56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.9.flat_param_0": {"names": ["decoder.layers.9.self_attn.qkv_proj.weight", "decoder.layers.9.self_attn.qkv_proj.bias", "decoder.layers.9.self_attn.out_proj.weight", "decoder.layers.9.self_attn.out_proj.bias", "decoder.layers.9.self_attn_layer_norm.weight", "decoder.layers.9.self_attn_layer_norm.bias", "decoder.layers.9.fc1.weight", "decoder.layers.9.fc1.bias", "decoder.layers.9.fc2.weight", "decoder.layers.9.fc2.bias", "decoder.layers.9.final_layer_norm.weight", "decoder.layers.9.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.10.flat_param_0": {"names": ["decoder.layers.10.self_attn.qkv_proj.weight", "decoder.layers.10.self_attn.qkv_proj.bias", "decoder.layers.10.self_attn.out_proj.weight", "decoder.layers.10.self_attn.out_proj.bias", "decoder.layers.10.self_attn_layer_norm.weight", "decoder.layers.10.self_attn_layer_norm.bias", "decoder.layers.10.fc1.weight", "decoder.layers.10.fc1.bias", "decoder.layers.10.fc2.weight", "decoder.layers.10.fc2.bias", "decoder.layers.10.final_layer_norm.weight", "decoder.layers.10.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.11.flat_param_0": {"names": ["decoder.layers.11.self_attn.qkv_proj.weight", "decoder.layers.11.self_attn.qkv_proj.bias", "decoder.layers.11.self_attn.out_proj.weight", "decoder.layers.11.self_attn.out_proj.bias", "decoder.layers.11.self_attn_layer_norm.weight", "decoder.layers.11.self_attn_layer_norm.bias", "decoder.layers.11.fc1.weight", "decoder.layers.11.fc1.bias", "decoder.layers.11.fc2.weight", "decoder.layers.11.fc2.bias", "decoder.layers.11.final_layer_norm.weight", "decoder.layers.11.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.12.flat_param_0": {"names": ["decoder.layers.12.self_attn.qkv_proj.weight", "decoder.layers.12.self_attn.qkv_proj.bias", "decoder.layers.12.self_attn.out_proj.weight", "decoder.layers.12.self_attn.out_proj.bias", "decoder.layers.12.self_attn_layer_norm.weight", "decoder.layers.12.self_attn_layer_norm.bias", "decoder.layers.12.fc1.weight", "decoder.layers.12.fc1.bias", "decoder.layers.12.fc2.weight", "decoder.layers.12.fc2.bias", "decoder.layers.12.final_layer_norm.weight", "decoder.layers.12.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.13.flat_param_0": {"names": ["decoder.layers.13.self_attn.qkv_proj.weight", "decoder.layers.13.self_attn.qkv_proj.bias", "decoder.layers.13.self_attn.out_proj.weight", "decoder.layers.13.self_attn.out_proj.bias", "decoder.layers.13.self_attn_layer_norm.weight", 
"decoder.layers.13.self_attn_layer_norm.bias", "decoder.layers.13.fc1.weight", "decoder.layers.13.fc1.bias", "decoder.layers.13.fc2.weight", "decoder.layers.13.fc2.bias", "decoder.layers.13.final_layer_norm.weight", "decoder.layers.13.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.14.flat_param_0": {"names": ["decoder.layers.14.self_attn.qkv_proj.weight", "decoder.layers.14.self_attn.qkv_proj.bias", "decoder.layers.14.self_attn.out_proj.weight", "decoder.layers.14.self_attn.out_proj.bias", "decoder.layers.14.self_attn_layer_norm.weight", "decoder.layers.14.self_attn_layer_norm.bias", "decoder.layers.14.fc1.weight", "decoder.layers.14.fc1.bias", "decoder.layers.14.fc2.weight", "decoder.layers.14.fc2.bias", "decoder.layers.14.final_layer_norm.weight", "decoder.layers.14.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.15.flat_param_0": {"names": ["decoder.layers.15.self_attn.qkv_proj.weight", "decoder.layers.15.self_attn.qkv_proj.bias", "decoder.layers.15.self_attn.out_proj.weight", "decoder.layers.15.self_attn.out_proj.bias", "decoder.layers.15.self_attn_layer_norm.weight", "decoder.layers.15.self_attn_layer_norm.bias", "decoder.layers.15.fc1.weight", "decoder.layers.15.fc1.bias", "decoder.layers.15.fc2.weight", "decoder.layers.15.fc2.bias", "decoder.layers.15.final_layer_norm.weight", "decoder.layers.15.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.16.flat_param_0": {"names": ["decoder.layers.16.self_attn.qkv_proj.weight", "decoder.layers.16.self_attn.qkv_proj.bias", "decoder.layers.16.self_attn.out_proj.weight", "decoder.layers.16.self_attn.out_proj.bias", "decoder.layers.16.self_attn_layer_norm.weight", "decoder.layers.16.self_attn_layer_norm.bias", "decoder.layers.16.fc1.weight", "decoder.layers.16.fc1.bias", "decoder.layers.16.fc2.weight", "decoder.layers.16.fc2.bias", "decoder.layers.16.final_layer_norm.weight", "decoder.layers.16.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.17.flat_param_0": {"names": ["decoder.layers.17.self_attn.qkv_proj.weight", "decoder.layers.17.self_attn.qkv_proj.bias", "decoder.layers.17.self_attn.out_proj.weight", "decoder.layers.17.self_attn.out_proj.bias", "decoder.layers.17.self_attn_layer_norm.weight", "decoder.layers.17.self_attn_layer_norm.bias", "decoder.layers.17.fc1.weight", "decoder.layers.17.fc1.bias", "decoder.layers.17.fc2.weight", "decoder.layers.17.fc2.bias", "decoder.layers.17.final_layer_norm.weight", "decoder.layers.17.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], 
"numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.18.flat_param_0": {"names": ["decoder.layers.18.self_attn.qkv_proj.weight", "decoder.layers.18.self_attn.qkv_proj.bias", "decoder.layers.18.self_attn.out_proj.weight", "decoder.layers.18.self_attn.out_proj.bias", "decoder.layers.18.self_attn_layer_norm.weight", "decoder.layers.18.self_attn_layer_norm.bias", "decoder.layers.18.fc1.weight", "decoder.layers.18.fc1.bias", "decoder.layers.18.fc2.weight", "decoder.layers.18.fc2.bias", "decoder.layers.18.final_layer_norm.weight", "decoder.layers.18.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.19.flat_param_0": {"names": ["decoder.layers.19.self_attn.qkv_proj.weight", "decoder.layers.19.self_attn.qkv_proj.bias", "decoder.layers.19.self_attn.out_proj.weight", "decoder.layers.19.self_attn.out_proj.bias", "decoder.layers.19.self_attn_layer_norm.weight", "decoder.layers.19.self_attn_layer_norm.bias", "decoder.layers.19.fc1.weight", "decoder.layers.19.fc1.bias", "decoder.layers.19.fc2.weight", "decoder.layers.19.fc2.bias", "decoder.layers.19.final_layer_norm.weight", "decoder.layers.19.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.20.flat_param_0": {"names": ["decoder.layers.20.self_attn.qkv_proj.weight", "decoder.layers.20.self_attn.qkv_proj.bias", "decoder.layers.20.self_attn.out_proj.weight", "decoder.layers.20.self_attn.out_proj.bias", "decoder.layers.20.self_attn_layer_norm.weight", "decoder.layers.20.self_attn_layer_norm.bias", "decoder.layers.20.fc1.weight", "decoder.layers.20.fc1.bias", "decoder.layers.20.fc2.weight", "decoder.layers.20.fc2.bias", "decoder.layers.20.final_layer_norm.weight", "decoder.layers.20.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.21.flat_param_0": {"names": ["decoder.layers.21.self_attn.qkv_proj.weight", "decoder.layers.21.self_attn.qkv_proj.bias", "decoder.layers.21.self_attn.out_proj.weight", "decoder.layers.21.self_attn.out_proj.bias", "decoder.layers.21.self_attn_layer_norm.weight", "decoder.layers.21.self_attn_layer_norm.bias", "decoder.layers.21.fc1.weight", "decoder.layers.21.fc1.bias", "decoder.layers.21.fc2.weight", "decoder.layers.21.fc2.bias", "decoder.layers.21.final_layer_norm.weight", "decoder.layers.21.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.22.flat_param_0": {"names": ["decoder.layers.22.self_attn.qkv_proj.weight", "decoder.layers.22.self_attn.qkv_proj.bias", "decoder.layers.22.self_attn.out_proj.weight", "decoder.layers.22.self_attn.out_proj.bias", "decoder.layers.22.self_attn_layer_norm.weight", 
"decoder.layers.22.self_attn_layer_norm.bias", "decoder.layers.22.fc1.weight", "decoder.layers.22.fc1.bias", "decoder.layers.22.fc2.weight", "decoder.layers.22.fc2.bias", "decoder.layers.22.final_layer_norm.weight", "decoder.layers.22.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.23.flat_param_0": {"names": ["decoder.layers.23.self_attn.qkv_proj.weight", "decoder.layers.23.self_attn.qkv_proj.bias", "decoder.layers.23.self_attn.out_proj.weight", "decoder.layers.23.self_attn.out_proj.bias", "decoder.layers.23.self_attn_layer_norm.weight", "decoder.layers.23.self_attn_layer_norm.bias", "decoder.layers.23.fc1.weight", "decoder.layers.23.fc1.bias", "decoder.layers.23.fc2.weight", "decoder.layers.23.fc2.bias", "decoder.layers.23.final_layer_norm.weight", "decoder.layers.23.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.24.flat_param_0": {"names": ["decoder.layers.24.self_attn.qkv_proj.weight", "decoder.layers.24.self_attn.qkv_proj.bias", "decoder.layers.24.self_attn.out_proj.weight", "decoder.layers.24.self_attn.out_proj.bias", "decoder.layers.24.self_attn_layer_norm.weight", "decoder.layers.24.self_attn_layer_norm.bias", "decoder.layers.24.fc1.weight", "decoder.layers.24.fc1.bias", "decoder.layers.24.fc2.weight", "decoder.layers.24.fc2.bias", "decoder.layers.24.final_layer_norm.weight", "decoder.layers.24.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.25.flat_param_0": {"names": ["decoder.layers.25.self_attn.qkv_proj.weight", "decoder.layers.25.self_attn.qkv_proj.bias", "decoder.layers.25.self_attn.out_proj.weight", "decoder.layers.25.self_attn.out_proj.bias", "decoder.layers.25.self_attn_layer_norm.weight", "decoder.layers.25.self_attn_layer_norm.bias", "decoder.layers.25.fc1.weight", "decoder.layers.25.fc1.bias", "decoder.layers.25.fc2.weight", "decoder.layers.25.fc2.bias", "decoder.layers.25.final_layer_norm.weight", "decoder.layers.25.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.26.flat_param_0": {"names": ["decoder.layers.26.self_attn.qkv_proj.weight", "decoder.layers.26.self_attn.qkv_proj.bias", "decoder.layers.26.self_attn.out_proj.weight", "decoder.layers.26.self_attn.out_proj.bias", "decoder.layers.26.self_attn_layer_norm.weight", "decoder.layers.26.self_attn_layer_norm.bias", "decoder.layers.26.fc1.weight", "decoder.layers.26.fc1.bias", "decoder.layers.26.fc2.weight", "decoder.layers.26.fc2.bias", "decoder.layers.26.final_layer_norm.weight", "decoder.layers.26.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], 
"numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.27.flat_param_0": {"names": ["decoder.layers.27.self_attn.qkv_proj.weight", "decoder.layers.27.self_attn.qkv_proj.bias", "decoder.layers.27.self_attn.out_proj.weight", "decoder.layers.27.self_attn.out_proj.bias", "decoder.layers.27.self_attn_layer_norm.weight", "decoder.layers.27.self_attn_layer_norm.bias", "decoder.layers.27.fc1.weight", "decoder.layers.27.fc1.bias", "decoder.layers.27.fc2.weight", "decoder.layers.27.fc2.bias", "decoder.layers.27.final_layer_norm.weight", "decoder.layers.27.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.28.flat_param_0": {"names": ["decoder.layers.28.self_attn.qkv_proj.weight", "decoder.layers.28.self_attn.qkv_proj.bias", "decoder.layers.28.self_attn.out_proj.weight", "decoder.layers.28.self_attn.out_proj.bias", "decoder.layers.28.self_attn_layer_norm.weight", "decoder.layers.28.self_attn_layer_norm.bias", "decoder.layers.28.fc1.weight", "decoder.layers.28.fc1.bias", "decoder.layers.28.fc2.weight", "decoder.layers.28.fc2.bias", "decoder.layers.28.final_layer_norm.weight", "decoder.layers.28.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.29.flat_param_0": {"names": ["decoder.layers.29.self_attn.qkv_proj.weight", "decoder.layers.29.self_attn.qkv_proj.bias", "decoder.layers.29.self_attn.out_proj.weight", "decoder.layers.29.self_attn.out_proj.bias", "decoder.layers.29.self_attn_layer_norm.weight", "decoder.layers.29.self_attn_layer_norm.bias", "decoder.layers.29.fc1.weight", "decoder.layers.29.fc1.bias", "decoder.layers.29.fc2.weight", "decoder.layers.29.fc2.bias", "decoder.layers.29.final_layer_norm.weight", "decoder.layers.29.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.30.flat_param_0": {"names": ["decoder.layers.30.self_attn.qkv_proj.weight", "decoder.layers.30.self_attn.qkv_proj.bias", "decoder.layers.30.self_attn.out_proj.weight", "decoder.layers.30.self_attn.out_proj.bias", "decoder.layers.30.self_attn_layer_norm.weight", "decoder.layers.30.self_attn_layer_norm.bias", "decoder.layers.30.fc1.weight", "decoder.layers.30.fc1.bias", "decoder.layers.30.fc2.weight", "decoder.layers.30.fc2.bias", "decoder.layers.30.final_layer_norm.weight", "decoder.layers.30.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.31.flat_param_0": {"names": ["decoder.layers.31.self_attn.qkv_proj.weight", "decoder.layers.31.self_attn.qkv_proj.bias", "decoder.layers.31.self_attn.out_proj.weight", "decoder.layers.31.self_attn.out_proj.bias", "decoder.layers.31.self_attn_layer_norm.weight", 
"decoder.layers.31.self_attn_layer_norm.bias", "decoder.layers.31.fc1.weight", "decoder.layers.31.fc1.bias", "decoder.layers.31.fc2.weight", "decoder.layers.31.fc2.bias", "decoder.layers.31.final_layer_norm.weight", "decoder.layers.31.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.32.flat_param_0": {"names": ["decoder.layers.32.self_attn.qkv_proj.weight", "decoder.layers.32.self_attn.qkv_proj.bias", "decoder.layers.32.self_attn.out_proj.weight", "decoder.layers.32.self_attn.out_proj.bias", "decoder.layers.32.self_attn_layer_norm.weight", "decoder.layers.32.self_attn_layer_norm.bias", "decoder.layers.32.fc1.weight", "decoder.layers.32.fc1.bias", "decoder.layers.32.fc2.weight", "decoder.layers.32.fc2.bias", "decoder.layers.32.final_layer_norm.weight", "decoder.layers.32.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.33.flat_param_0": {"names": ["decoder.layers.33.self_attn.qkv_proj.weight", "decoder.layers.33.self_attn.qkv_proj.bias", "decoder.layers.33.self_attn.out_proj.weight", "decoder.layers.33.self_attn.out_proj.bias", "decoder.layers.33.self_attn_layer_norm.weight", "decoder.layers.33.self_attn_layer_norm.bias", "decoder.layers.33.fc1.weight", "decoder.layers.33.fc1.bias", "decoder.layers.33.fc2.weight", "decoder.layers.33.fc2.bias", "decoder.layers.33.final_layer_norm.weight", "decoder.layers.33.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.34.flat_param_0": {"names": ["decoder.layers.34.self_attn.qkv_proj.weight", "decoder.layers.34.self_attn.qkv_proj.bias", "decoder.layers.34.self_attn.out_proj.weight", "decoder.layers.34.self_attn.out_proj.bias", "decoder.layers.34.self_attn_layer_norm.weight", "decoder.layers.34.self_attn_layer_norm.bias", "decoder.layers.34.fc1.weight", "decoder.layers.34.fc1.bias", "decoder.layers.34.fc2.weight", "decoder.layers.34.fc2.bias", "decoder.layers.34.final_layer_norm.weight", "decoder.layers.34.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.35.flat_param_0": {"names": ["decoder.layers.35.self_attn.qkv_proj.weight", "decoder.layers.35.self_attn.qkv_proj.bias", "decoder.layers.35.self_attn.out_proj.weight", "decoder.layers.35.self_attn.out_proj.bias", "decoder.layers.35.self_attn_layer_norm.weight", "decoder.layers.35.self_attn_layer_norm.bias", "decoder.layers.35.fc1.weight", "decoder.layers.35.fc1.bias", "decoder.layers.35.fc2.weight", "decoder.layers.35.fc2.bias", "decoder.layers.35.final_layer_norm.weight", "decoder.layers.35.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], 
"numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.36.flat_param_0": {"names": ["decoder.layers.36.self_attn.qkv_proj.weight", "decoder.layers.36.self_attn.qkv_proj.bias", "decoder.layers.36.self_attn.out_proj.weight", "decoder.layers.36.self_attn.out_proj.bias", "decoder.layers.36.self_attn_layer_norm.weight", "decoder.layers.36.self_attn_layer_norm.bias", "decoder.layers.36.fc1.weight", "decoder.layers.36.fc1.bias", "decoder.layers.36.fc2.weight", "decoder.layers.36.fc2.bias", "decoder.layers.36.final_layer_norm.weight", "decoder.layers.36.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.37.flat_param_0": {"names": ["decoder.layers.37.self_attn.qkv_proj.weight", "decoder.layers.37.self_attn.qkv_proj.bias", "decoder.layers.37.self_attn.out_proj.weight", "decoder.layers.37.self_attn.out_proj.bias", "decoder.layers.37.self_attn_layer_norm.weight", "decoder.layers.37.self_attn_layer_norm.bias", "decoder.layers.37.fc1.weight", "decoder.layers.37.fc1.bias", "decoder.layers.37.fc2.weight", "decoder.layers.37.fc2.bias", "decoder.layers.37.final_layer_norm.weight", "decoder.layers.37.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.38.flat_param_0": {"names": ["decoder.layers.38.self_attn.qkv_proj.weight", "decoder.layers.38.self_attn.qkv_proj.bias", "decoder.layers.38.self_attn.out_proj.weight", "decoder.layers.38.self_attn.out_proj.bias", "decoder.layers.38.self_attn_layer_norm.weight", "decoder.layers.38.self_attn_layer_norm.bias", "decoder.layers.38.fc1.weight", "decoder.layers.38.fc1.bias", "decoder.layers.38.fc2.weight", "decoder.layers.38.fc2.bias", "decoder.layers.38.final_layer_norm.weight", "decoder.layers.38.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.39.flat_param_0": {"names": ["decoder.layers.39.self_attn.qkv_proj.weight", "decoder.layers.39.self_attn.qkv_proj.bias", "decoder.layers.39.self_attn.out_proj.weight", "decoder.layers.39.self_attn.out_proj.bias", "decoder.layers.39.self_attn_layer_norm.weight", "decoder.layers.39.self_attn_layer_norm.bias", "decoder.layers.39.fc1.weight", "decoder.layers.39.fc1.bias", "decoder.layers.39.fc2.weight", "decoder.layers.39.fc2.bias", "decoder.layers.39.final_layer_norm.weight", "decoder.layers.39.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.40.flat_param_0": {"names": ["decoder.layers.40.self_attn.qkv_proj.weight", "decoder.layers.40.self_attn.qkv_proj.bias", "decoder.layers.40.self_attn.out_proj.weight", "decoder.layers.40.self_attn.out_proj.bias", "decoder.layers.40.self_attn_layer_norm.weight", 
"decoder.layers.40.self_attn_layer_norm.bias", "decoder.layers.40.fc1.weight", "decoder.layers.40.fc1.bias", "decoder.layers.40.fc2.weight", "decoder.layers.40.fc2.bias", "decoder.layers.40.final_layer_norm.weight", "decoder.layers.40.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.41.flat_param_0": {"names": ["decoder.layers.41.self_attn.qkv_proj.weight", "decoder.layers.41.self_attn.qkv_proj.bias", "decoder.layers.41.self_attn.out_proj.weight", "decoder.layers.41.self_attn.out_proj.bias", "decoder.layers.41.self_attn_layer_norm.weight", "decoder.layers.41.self_attn_layer_norm.bias", "decoder.layers.41.fc1.weight", "decoder.layers.41.fc1.bias", "decoder.layers.41.fc2.weight", "decoder.layers.41.fc2.bias", "decoder.layers.41.final_layer_norm.weight", "decoder.layers.41.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.42.flat_param_0": {"names": ["decoder.layers.42.self_attn.qkv_proj.weight", "decoder.layers.42.self_attn.qkv_proj.bias", "decoder.layers.42.self_attn.out_proj.weight", "decoder.layers.42.self_attn.out_proj.bias", "decoder.layers.42.self_attn_layer_norm.weight", "decoder.layers.42.self_attn_layer_norm.bias", "decoder.layers.42.fc1.weight", "decoder.layers.42.fc1.bias", "decoder.layers.42.fc2.weight", "decoder.layers.42.fc2.bias", "decoder.layers.42.final_layer_norm.weight", "decoder.layers.42.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.43.flat_param_0": {"names": ["decoder.layers.43.self_attn.qkv_proj.weight", "decoder.layers.43.self_attn.qkv_proj.bias", "decoder.layers.43.self_attn.out_proj.weight", "decoder.layers.43.self_attn.out_proj.bias", "decoder.layers.43.self_attn_layer_norm.weight", "decoder.layers.43.self_attn_layer_norm.bias", "decoder.layers.43.fc1.weight", "decoder.layers.43.fc1.bias", "decoder.layers.43.fc2.weight", "decoder.layers.43.fc2.bias", "decoder.layers.43.final_layer_norm.weight", "decoder.layers.43.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.44.flat_param_0": {"names": ["decoder.layers.44.self_attn.qkv_proj.weight", "decoder.layers.44.self_attn.qkv_proj.bias", "decoder.layers.44.self_attn.out_proj.weight", "decoder.layers.44.self_attn.out_proj.bias", "decoder.layers.44.self_attn_layer_norm.weight", "decoder.layers.44.self_attn_layer_norm.bias", "decoder.layers.44.fc1.weight", "decoder.layers.44.fc1.bias", "decoder.layers.44.fc2.weight", "decoder.layers.44.fc2.bias", "decoder.layers.44.final_layer_norm.weight", "decoder.layers.44.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], 
"numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.45.flat_param_0": {"names": ["decoder.layers.45.self_attn.qkv_proj.weight", "decoder.layers.45.self_attn.qkv_proj.bias", "decoder.layers.45.self_attn.out_proj.weight", "decoder.layers.45.self_attn.out_proj.bias", "decoder.layers.45.self_attn_layer_norm.weight", "decoder.layers.45.self_attn_layer_norm.bias", "decoder.layers.45.fc1.weight", "decoder.layers.45.fc1.bias", "decoder.layers.45.fc2.weight", "decoder.layers.45.fc2.bias", "decoder.layers.45.final_layer_norm.weight", "decoder.layers.45.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.46.flat_param_0": {"names": ["decoder.layers.46.self_attn.qkv_proj.weight", "decoder.layers.46.self_attn.qkv_proj.bias", "decoder.layers.46.self_attn.out_proj.weight", "decoder.layers.46.self_attn.out_proj.bias", "decoder.layers.46.self_attn_layer_norm.weight", "decoder.layers.46.self_attn_layer_norm.bias", "decoder.layers.46.fc1.weight", "decoder.layers.46.fc1.bias", "decoder.layers.46.fc2.weight", "decoder.layers.46.fc2.bias", "decoder.layers.46.final_layer_norm.weight", "decoder.layers.46.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.47.flat_param_0": {"names": ["decoder.layers.47.self_attn.qkv_proj.weight", "decoder.layers.47.self_attn.qkv_proj.bias", "decoder.layers.47.self_attn.out_proj.weight", "decoder.layers.47.self_attn.out_proj.bias", "decoder.layers.47.self_attn_layer_norm.weight", "decoder.layers.47.self_attn_layer_norm.bias", "decoder.layers.47.fc1.weight", "decoder.layers.47.fc1.bias", "decoder.layers.47.fc2.weight", "decoder.layers.47.fc2.bias", "decoder.layers.47.final_layer_norm.weight", "decoder.layers.47.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.48.flat_param_0": {"names": ["decoder.layers.48.self_attn.qkv_proj.weight", "decoder.layers.48.self_attn.qkv_proj.bias", "decoder.layers.48.self_attn.out_proj.weight", "decoder.layers.48.self_attn.out_proj.bias", "decoder.layers.48.self_attn_layer_norm.weight", "decoder.layers.48.self_attn_layer_norm.bias", "decoder.layers.48.fc1.weight", "decoder.layers.48.fc1.bias", "decoder.layers.48.fc2.weight", "decoder.layers.48.fc2.bias", "decoder.layers.48.final_layer_norm.weight", "decoder.layers.48.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.49.flat_param_0": {"names": ["decoder.layers.49.self_attn.qkv_proj.weight", "decoder.layers.49.self_attn.qkv_proj.bias", "decoder.layers.49.self_attn.out_proj.weight", "decoder.layers.49.self_attn.out_proj.bias", "decoder.layers.49.self_attn_layer_norm.weight", 
"decoder.layers.49.self_attn_layer_norm.bias", "decoder.layers.49.fc1.weight", "decoder.layers.49.fc1.bias", "decoder.layers.49.fc2.weight", "decoder.layers.49.fc2.bias", "decoder.layers.49.final_layer_norm.weight", "decoder.layers.49.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.50.flat_param_0": {"names": ["decoder.layers.50.self_attn.qkv_proj.weight", "decoder.layers.50.self_attn.qkv_proj.bias", "decoder.layers.50.self_attn.out_proj.weight", "decoder.layers.50.self_attn.out_proj.bias", "decoder.layers.50.self_attn_layer_norm.weight", "decoder.layers.50.self_attn_layer_norm.bias", "decoder.layers.50.fc1.weight", "decoder.layers.50.fc1.bias", "decoder.layers.50.fc2.weight", "decoder.layers.50.fc2.bias", "decoder.layers.50.final_layer_norm.weight", "decoder.layers.50.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.51.flat_param_0": {"names": ["decoder.layers.51.self_attn.qkv_proj.weight", "decoder.layers.51.self_attn.qkv_proj.bias", "decoder.layers.51.self_attn.out_proj.weight", "decoder.layers.51.self_attn.out_proj.bias", "decoder.layers.51.self_attn_layer_norm.weight", "decoder.layers.51.self_attn_layer_norm.bias", "decoder.layers.51.fc1.weight", "decoder.layers.51.fc1.bias", "decoder.layers.51.fc2.weight", "decoder.layers.51.fc2.bias", "decoder.layers.51.final_layer_norm.weight", "decoder.layers.51.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.52.flat_param_0": {"names": ["decoder.layers.52.self_attn.qkv_proj.weight", "decoder.layers.52.self_attn.qkv_proj.bias", "decoder.layers.52.self_attn.out_proj.weight", "decoder.layers.52.self_attn.out_proj.bias", "decoder.layers.52.self_attn_layer_norm.weight", "decoder.layers.52.self_attn_layer_norm.bias", "decoder.layers.52.fc1.weight", "decoder.layers.52.fc1.bias", "decoder.layers.52.fc2.weight", "decoder.layers.52.fc2.bias", "decoder.layers.52.final_layer_norm.weight", "decoder.layers.52.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.53.flat_param_0": {"names": ["decoder.layers.53.self_attn.qkv_proj.weight", "decoder.layers.53.self_attn.qkv_proj.bias", "decoder.layers.53.self_attn.out_proj.weight", "decoder.layers.53.self_attn.out_proj.bias", "decoder.layers.53.self_attn_layer_norm.weight", "decoder.layers.53.self_attn_layer_norm.bias", "decoder.layers.53.fc1.weight", "decoder.layers.53.fc1.bias", "decoder.layers.53.fc2.weight", "decoder.layers.53.fc2.bias", "decoder.layers.53.final_layer_norm.weight", "decoder.layers.53.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], 
"numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.54.flat_param_0": {"names": ["decoder.layers.54.self_attn.qkv_proj.weight", "decoder.layers.54.self_attn.qkv_proj.bias", "decoder.layers.54.self_attn.out_proj.weight", "decoder.layers.54.self_attn.out_proj.bias", "decoder.layers.54.self_attn_layer_norm.weight", "decoder.layers.54.self_attn_layer_norm.bias", "decoder.layers.54.fc1.weight", "decoder.layers.54.fc1.bias", "decoder.layers.54.fc2.weight", "decoder.layers.54.fc2.bias", "decoder.layers.54.final_layer_norm.weight", "decoder.layers.54.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.55.flat_param_0": {"names": ["decoder.layers.55.self_attn.qkv_proj.weight", "decoder.layers.55.self_attn.qkv_proj.bias", "decoder.layers.55.self_attn.out_proj.weight", "decoder.layers.55.self_attn.out_proj.bias", "decoder.layers.55.self_attn_layer_norm.weight", "decoder.layers.55.self_attn_layer_norm.bias", "decoder.layers.55.fc1.weight", "decoder.layers.55.fc1.bias", "decoder.layers.55.fc2.weight", "decoder.layers.55.fc2.bias", "decoder.layers.55.final_layer_norm.weight", "decoder.layers.55.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.56.flat_param_0": {"names": ["decoder.layers.56.self_attn.qkv_proj.weight", "decoder.layers.56.self_attn.qkv_proj.bias", "decoder.layers.56.self_attn.out_proj.weight", "decoder.layers.56.self_attn.out_proj.bias", "decoder.layers.56.self_attn_layer_norm.weight", "decoder.layers.56.self_attn_layer_norm.bias", "decoder.layers.56.fc1.weight", "decoder.layers.56.fc1.bias", "decoder.layers.56.fc2.weight", "decoder.layers.56.fc2.bias", "decoder.layers.56.final_layer_norm.weight", "decoder.layers.56.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.57.flat_param_0": {"names": ["decoder.layers.57.self_attn.qkv_proj.weight", "decoder.layers.57.self_attn.qkv_proj.bias", "decoder.layers.57.self_attn.out_proj.weight", "decoder.layers.57.self_attn.out_proj.bias", "decoder.layers.57.self_attn_layer_norm.weight", "decoder.layers.57.self_attn_layer_norm.bias", "decoder.layers.57.fc1.weight", "decoder.layers.57.fc1.bias", "decoder.layers.57.fc2.weight", "decoder.layers.57.fc2.bias", "decoder.layers.57.final_layer_norm.weight", "decoder.layers.57.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.58.flat_param_0": {"names": ["decoder.layers.58.self_attn.qkv_proj.weight", "decoder.layers.58.self_attn.qkv_proj.bias", "decoder.layers.58.self_attn.out_proj.weight", "decoder.layers.58.self_attn.out_proj.bias", "decoder.layers.58.self_attn_layer_norm.weight", 
"decoder.layers.58.self_attn_layer_norm.bias", "decoder.layers.58.fc1.weight", "decoder.layers.58.fc1.bias", "decoder.layers.58.fc2.weight", "decoder.layers.58.fc2.bias", "decoder.layers.58.final_layer_norm.weight", "decoder.layers.58.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.59.flat_param_0": {"names": ["decoder.layers.59.self_attn.qkv_proj.weight", "decoder.layers.59.self_attn.qkv_proj.bias", "decoder.layers.59.self_attn.out_proj.weight", "decoder.layers.59.self_attn.out_proj.bias", "decoder.layers.59.self_attn_layer_norm.weight", "decoder.layers.59.self_attn_layer_norm.bias", "decoder.layers.59.fc1.weight", "decoder.layers.59.fc1.bias", "decoder.layers.59.fc2.weight", "decoder.layers.59.fc2.bias", "decoder.layers.59.final_layer_norm.weight", "decoder.layers.59.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.60.flat_param_0": {"names": ["decoder.layers.60.self_attn.qkv_proj.weight", "decoder.layers.60.self_attn.qkv_proj.bias", "decoder.layers.60.self_attn.out_proj.weight", "decoder.layers.60.self_attn.out_proj.bias", "decoder.layers.60.self_attn_layer_norm.weight", "decoder.layers.60.self_attn_layer_norm.bias", "decoder.layers.60.fc1.weight", "decoder.layers.60.fc1.bias", "decoder.layers.60.fc2.weight", "decoder.layers.60.fc2.bias", "decoder.layers.60.final_layer_norm.weight", "decoder.layers.60.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.61.flat_param_0": {"names": ["decoder.layers.61.self_attn.qkv_proj.weight", "decoder.layers.61.self_attn.qkv_proj.bias", "decoder.layers.61.self_attn.out_proj.weight", "decoder.layers.61.self_attn.out_proj.bias", "decoder.layers.61.self_attn_layer_norm.weight", "decoder.layers.61.self_attn_layer_norm.bias", "decoder.layers.61.fc1.weight", "decoder.layers.61.fc1.bias", "decoder.layers.61.fc2.weight", "decoder.layers.61.fc2.bias", "decoder.layers.61.final_layer_norm.weight", "decoder.layers.61.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.62.flat_param_0": {"names": ["decoder.layers.62.self_attn.qkv_proj.weight", "decoder.layers.62.self_attn.qkv_proj.bias", "decoder.layers.62.self_attn.out_proj.weight", "decoder.layers.62.self_attn.out_proj.bias", "decoder.layers.62.self_attn_layer_norm.weight", "decoder.layers.62.self_attn_layer_norm.bias", "decoder.layers.62.fc1.weight", "decoder.layers.62.fc1.bias", "decoder.layers.62.fc2.weight", "decoder.layers.62.fc2.bias", "decoder.layers.62.final_layer_norm.weight", "decoder.layers.62.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], 
"numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.63.flat_param_0": {"names": ["decoder.layers.63.self_attn.qkv_proj.weight", "decoder.layers.63.self_attn.qkv_proj.bias", "decoder.layers.63.self_attn.out_proj.weight", "decoder.layers.63.self_attn.out_proj.bias", "decoder.layers.63.self_attn_layer_norm.weight", "decoder.layers.63.self_attn_layer_norm.bias", "decoder.layers.63.fc1.weight", "decoder.layers.63.fc1.bias", "decoder.layers.63.fc2.weight", "decoder.layers.63.fc2.bias", "decoder.layers.63.final_layer_norm.weight", "decoder.layers.63.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.64.flat_param_0": {"names": ["decoder.layers.64.self_attn.qkv_proj.weight", "decoder.layers.64.self_attn.qkv_proj.bias", "decoder.layers.64.self_attn.out_proj.weight", "decoder.layers.64.self_attn.out_proj.bias", "decoder.layers.64.self_attn_layer_norm.weight", "decoder.layers.64.self_attn_layer_norm.bias", "decoder.layers.64.fc1.weight", "decoder.layers.64.fc1.bias", "decoder.layers.64.fc2.weight", "decoder.layers.64.fc2.bias", "decoder.layers.64.final_layer_norm.weight", "decoder.layers.64.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.65.flat_param_0": {"names": ["decoder.layers.65.self_attn.qkv_proj.weight", "decoder.layers.65.self_attn.qkv_proj.bias", "decoder.layers.65.self_attn.out_proj.weight", "decoder.layers.65.self_attn.out_proj.bias", "decoder.layers.65.self_attn_layer_norm.weight", "decoder.layers.65.self_attn_layer_norm.bias", "decoder.layers.65.fc1.weight", "decoder.layers.65.fc1.bias", "decoder.layers.65.fc2.weight", "decoder.layers.65.fc2.bias", "decoder.layers.65.final_layer_norm.weight", "decoder.layers.65.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.66.flat_param_0": {"names": ["decoder.layers.66.self_attn.qkv_proj.weight", "decoder.layers.66.self_attn.qkv_proj.bias", "decoder.layers.66.self_attn.out_proj.weight", "decoder.layers.66.self_attn.out_proj.bias", "decoder.layers.66.self_attn_layer_norm.weight", "decoder.layers.66.self_attn_layer_norm.bias", "decoder.layers.66.fc1.weight", "decoder.layers.66.fc1.bias", "decoder.layers.66.fc2.weight", "decoder.layers.66.fc2.bias", "decoder.layers.66.final_layer_norm.weight", "decoder.layers.66.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.67.flat_param_0": {"names": ["decoder.layers.67.self_attn.qkv_proj.weight", "decoder.layers.67.self_attn.qkv_proj.bias", "decoder.layers.67.self_attn.out_proj.weight", "decoder.layers.67.self_attn.out_proj.bias", "decoder.layers.67.self_attn_layer_norm.weight", 
"decoder.layers.67.self_attn_layer_norm.bias", "decoder.layers.67.fc1.weight", "decoder.layers.67.fc1.bias", "decoder.layers.67.fc2.weight", "decoder.layers.67.fc2.bias", "decoder.layers.67.final_layer_norm.weight", "decoder.layers.67.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.68.flat_param_0": {"names": ["decoder.layers.68.self_attn.qkv_proj.weight", "decoder.layers.68.self_attn.qkv_proj.bias", "decoder.layers.68.self_attn.out_proj.weight", "decoder.layers.68.self_attn.out_proj.bias", "decoder.layers.68.self_attn_layer_norm.weight", "decoder.layers.68.self_attn_layer_norm.bias", "decoder.layers.68.fc1.weight", "decoder.layers.68.fc1.bias", "decoder.layers.68.fc2.weight", "decoder.layers.68.fc2.bias", "decoder.layers.68.final_layer_norm.weight", "decoder.layers.68.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.69.flat_param_0": {"names": ["decoder.layers.69.self_attn.qkv_proj.weight", "decoder.layers.69.self_attn.qkv_proj.bias", "decoder.layers.69.self_attn.out_proj.weight", "decoder.layers.69.self_attn.out_proj.bias", "decoder.layers.69.self_attn_layer_norm.weight", "decoder.layers.69.self_attn_layer_norm.bias", "decoder.layers.69.fc1.weight", "decoder.layers.69.fc1.bias", "decoder.layers.69.fc2.weight", "decoder.layers.69.fc2.bias", "decoder.layers.69.final_layer_norm.weight", "decoder.layers.69.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.70.flat_param_0": {"names": ["decoder.layers.70.self_attn.qkv_proj.weight", "decoder.layers.70.self_attn.qkv_proj.bias", "decoder.layers.70.self_attn.out_proj.weight", "decoder.layers.70.self_attn.out_proj.bias", "decoder.layers.70.self_attn_layer_norm.weight", "decoder.layers.70.self_attn_layer_norm.bias", "decoder.layers.70.fc1.weight", "decoder.layers.70.fc1.bias", "decoder.layers.70.fc2.weight", "decoder.layers.70.fc2.bias", "decoder.layers.70.final_layer_norm.weight", "decoder.layers.70.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.71.flat_param_0": {"names": ["decoder.layers.71.self_attn.qkv_proj.weight", "decoder.layers.71.self_attn.qkv_proj.bias", "decoder.layers.71.self_attn.out_proj.weight", "decoder.layers.71.self_attn.out_proj.bias", "decoder.layers.71.self_attn_layer_norm.weight", "decoder.layers.71.self_attn_layer_norm.bias", "decoder.layers.71.fc1.weight", "decoder.layers.71.fc1.bias", "decoder.layers.71.fc2.weight", "decoder.layers.71.fc2.bias", "decoder.layers.71.final_layer_norm.weight", "decoder.layers.71.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], 
"numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.72.flat_param_0": {"names": ["decoder.layers.72.self_attn.qkv_proj.weight", "decoder.layers.72.self_attn.qkv_proj.bias", "decoder.layers.72.self_attn.out_proj.weight", "decoder.layers.72.self_attn.out_proj.bias", "decoder.layers.72.self_attn_layer_norm.weight", "decoder.layers.72.self_attn_layer_norm.bias", "decoder.layers.72.fc1.weight", "decoder.layers.72.fc1.bias", "decoder.layers.72.fc2.weight", "decoder.layers.72.fc2.bias", "decoder.layers.72.final_layer_norm.weight", "decoder.layers.72.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.73.flat_param_0": {"names": ["decoder.layers.73.self_attn.qkv_proj.weight", "decoder.layers.73.self_attn.qkv_proj.bias", "decoder.layers.73.self_attn.out_proj.weight", "decoder.layers.73.self_attn.out_proj.bias", "decoder.layers.73.self_attn_layer_norm.weight", "decoder.layers.73.self_attn_layer_norm.bias", "decoder.layers.73.fc1.weight", "decoder.layers.73.fc1.bias", "decoder.layers.73.fc2.weight", "decoder.layers.73.fc2.bias", "decoder.layers.73.final_layer_norm.weight", "decoder.layers.73.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.74.flat_param_0": {"names": ["decoder.layers.74.self_attn.qkv_proj.weight", "decoder.layers.74.self_attn.qkv_proj.bias", "decoder.layers.74.self_attn.out_proj.weight", "decoder.layers.74.self_attn.out_proj.bias", "decoder.layers.74.self_attn_layer_norm.weight", "decoder.layers.74.self_attn_layer_norm.bias", "decoder.layers.74.fc1.weight", "decoder.layers.74.fc1.bias", "decoder.layers.74.fc2.weight", "decoder.layers.74.fc2.bias", "decoder.layers.74.final_layer_norm.weight", "decoder.layers.74.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.75.flat_param_0": {"names": ["decoder.layers.75.self_attn.qkv_proj.weight", "decoder.layers.75.self_attn.qkv_proj.bias", "decoder.layers.75.self_attn.out_proj.weight", "decoder.layers.75.self_attn.out_proj.bias", "decoder.layers.75.self_attn_layer_norm.weight", "decoder.layers.75.self_attn_layer_norm.bias", "decoder.layers.75.fc1.weight", "decoder.layers.75.fc1.bias", "decoder.layers.75.fc2.weight", "decoder.layers.75.fc2.bias", "decoder.layers.75.final_layer_norm.weight", "decoder.layers.75.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.76.flat_param_0": {"names": ["decoder.layers.76.self_attn.qkv_proj.weight", "decoder.layers.76.self_attn.qkv_proj.bias", "decoder.layers.76.self_attn.out_proj.weight", "decoder.layers.76.self_attn.out_proj.bias", "decoder.layers.76.self_attn_layer_norm.weight", 
"decoder.layers.76.self_attn_layer_norm.bias", "decoder.layers.76.fc1.weight", "decoder.layers.76.fc1.bias", "decoder.layers.76.fc2.weight", "decoder.layers.76.fc2.bias", "decoder.layers.76.final_layer_norm.weight", "decoder.layers.76.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.77.flat_param_0": {"names": ["decoder.layers.77.self_attn.qkv_proj.weight", "decoder.layers.77.self_attn.qkv_proj.bias", "decoder.layers.77.self_attn.out_proj.weight", "decoder.layers.77.self_attn.out_proj.bias", "decoder.layers.77.self_attn_layer_norm.weight", "decoder.layers.77.self_attn_layer_norm.bias", "decoder.layers.77.fc1.weight", "decoder.layers.77.fc1.bias", "decoder.layers.77.fc2.weight", "decoder.layers.77.fc2.bias", "decoder.layers.77.final_layer_norm.weight", "decoder.layers.77.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.78.flat_param_0": {"names": ["decoder.layers.78.self_attn.qkv_proj.weight", "decoder.layers.78.self_attn.qkv_proj.bias", "decoder.layers.78.self_attn.out_proj.weight", "decoder.layers.78.self_attn.out_proj.bias", "decoder.layers.78.self_attn_layer_norm.weight", "decoder.layers.78.self_attn_layer_norm.bias", "decoder.layers.78.fc1.weight", "decoder.layers.78.fc1.bias", "decoder.layers.78.fc2.weight", "decoder.layers.78.fc2.bias", "decoder.layers.78.final_layer_norm.weight", "decoder.layers.78.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.79.flat_param_0": {"names": ["decoder.layers.79.self_attn.qkv_proj.weight", "decoder.layers.79.self_attn.qkv_proj.bias", "decoder.layers.79.self_attn.out_proj.weight", "decoder.layers.79.self_attn.out_proj.bias", "decoder.layers.79.self_attn_layer_norm.weight", "decoder.layers.79.self_attn_layer_norm.bias", "decoder.layers.79.fc1.weight", "decoder.layers.79.fc1.bias", "decoder.layers.79.fc2.weight", "decoder.layers.79.fc2.bias", "decoder.layers.79.final_layer_norm.weight", "decoder.layers.79.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.80.flat_param_0": {"names": ["decoder.layers.80.self_attn.qkv_proj.weight", "decoder.layers.80.self_attn.qkv_proj.bias", "decoder.layers.80.self_attn.out_proj.weight", "decoder.layers.80.self_attn.out_proj.bias", "decoder.layers.80.self_attn_layer_norm.weight", "decoder.layers.80.self_attn_layer_norm.bias", "decoder.layers.80.fc1.weight", "decoder.layers.80.fc1.bias", "decoder.layers.80.fc2.weight", "decoder.layers.80.fc2.bias", "decoder.layers.80.final_layer_norm.weight", "decoder.layers.80.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], 
"numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.81.flat_param_0": {"names": ["decoder.layers.81.self_attn.qkv_proj.weight", "decoder.layers.81.self_attn.qkv_proj.bias", "decoder.layers.81.self_attn.out_proj.weight", "decoder.layers.81.self_attn.out_proj.bias", "decoder.layers.81.self_attn_layer_norm.weight", "decoder.layers.81.self_attn_layer_norm.bias", "decoder.layers.81.fc1.weight", "decoder.layers.81.fc1.bias", "decoder.layers.81.fc2.weight", "decoder.layers.81.fc2.bias", "decoder.layers.81.final_layer_norm.weight", "decoder.layers.81.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.82.flat_param_0": {"names": ["decoder.layers.82.self_attn.qkv_proj.weight", "decoder.layers.82.self_attn.qkv_proj.bias", "decoder.layers.82.self_attn.out_proj.weight", "decoder.layers.82.self_attn.out_proj.bias", "decoder.layers.82.self_attn_layer_norm.weight", "decoder.layers.82.self_attn_layer_norm.bias", "decoder.layers.82.fc1.weight", "decoder.layers.82.fc1.bias", "decoder.layers.82.fc2.weight", "decoder.layers.82.fc2.bias", "decoder.layers.82.final_layer_norm.weight", "decoder.layers.82.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.83.flat_param_0": {"names": ["decoder.layers.83.self_attn.qkv_proj.weight", "decoder.layers.83.self_attn.qkv_proj.bias", "decoder.layers.83.self_attn.out_proj.weight", "decoder.layers.83.self_attn.out_proj.bias", "decoder.layers.83.self_attn_layer_norm.weight", "decoder.layers.83.self_attn_layer_norm.bias", "decoder.layers.83.fc1.weight", "decoder.layers.83.fc1.bias", "decoder.layers.83.fc2.weight", "decoder.layers.83.fc2.bias", "decoder.layers.83.final_layer_norm.weight", "decoder.layers.83.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.84.flat_param_0": {"names": ["decoder.layers.84.self_attn.qkv_proj.weight", "decoder.layers.84.self_attn.qkv_proj.bias", "decoder.layers.84.self_attn.out_proj.weight", "decoder.layers.84.self_attn.out_proj.bias", "decoder.layers.84.self_attn_layer_norm.weight", "decoder.layers.84.self_attn_layer_norm.bias", "decoder.layers.84.fc1.weight", "decoder.layers.84.fc1.bias", "decoder.layers.84.fc2.weight", "decoder.layers.84.fc2.bias", "decoder.layers.84.final_layer_norm.weight", "decoder.layers.84.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.85.flat_param_0": {"names": ["decoder.layers.85.self_attn.qkv_proj.weight", "decoder.layers.85.self_attn.qkv_proj.bias", "decoder.layers.85.self_attn.out_proj.weight", "decoder.layers.85.self_attn.out_proj.bias", "decoder.layers.85.self_attn_layer_norm.weight", 
"decoder.layers.85.self_attn_layer_norm.bias", "decoder.layers.85.fc1.weight", "decoder.layers.85.fc1.bias", "decoder.layers.85.fc2.weight", "decoder.layers.85.fc2.bias", "decoder.layers.85.final_layer_norm.weight", "decoder.layers.85.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.86.flat_param_0": {"names": ["decoder.layers.86.self_attn.qkv_proj.weight", "decoder.layers.86.self_attn.qkv_proj.bias", "decoder.layers.86.self_attn.out_proj.weight", "decoder.layers.86.self_attn.out_proj.bias", "decoder.layers.86.self_attn_layer_norm.weight", "decoder.layers.86.self_attn_layer_norm.bias", "decoder.layers.86.fc1.weight", "decoder.layers.86.fc1.bias", "decoder.layers.86.fc2.weight", "decoder.layers.86.fc2.bias", "decoder.layers.86.final_layer_norm.weight", "decoder.layers.86.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.87.flat_param_0": {"names": ["decoder.layers.87.self_attn.qkv_proj.weight", "decoder.layers.87.self_attn.qkv_proj.bias", "decoder.layers.87.self_attn.out_proj.weight", "decoder.layers.87.self_attn.out_proj.bias", "decoder.layers.87.self_attn_layer_norm.weight", "decoder.layers.87.self_attn_layer_norm.bias", "decoder.layers.87.fc1.weight", "decoder.layers.87.fc1.bias", "decoder.layers.87.fc2.weight", "decoder.layers.87.fc2.bias", "decoder.layers.87.final_layer_norm.weight", "decoder.layers.87.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.88.flat_param_0": {"names": ["decoder.layers.88.self_attn.qkv_proj.weight", "decoder.layers.88.self_attn.qkv_proj.bias", "decoder.layers.88.self_attn.out_proj.weight", "decoder.layers.88.self_attn.out_proj.bias", "decoder.layers.88.self_attn_layer_norm.weight", "decoder.layers.88.self_attn_layer_norm.bias", "decoder.layers.88.fc1.weight", "decoder.layers.88.fc1.bias", "decoder.layers.88.fc2.weight", "decoder.layers.88.fc2.bias", "decoder.layers.88.final_layer_norm.weight", "decoder.layers.88.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.89.flat_param_0": {"names": ["decoder.layers.89.self_attn.qkv_proj.weight", "decoder.layers.89.self_attn.qkv_proj.bias", "decoder.layers.89.self_attn.out_proj.weight", "decoder.layers.89.self_attn.out_proj.bias", "decoder.layers.89.self_attn_layer_norm.weight", "decoder.layers.89.self_attn_layer_norm.bias", "decoder.layers.89.fc1.weight", "decoder.layers.89.fc1.bias", "decoder.layers.89.fc2.weight", "decoder.layers.89.fc2.bias", "decoder.layers.89.final_layer_norm.weight", "decoder.layers.89.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], 
"numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.90.flat_param_0": {"names": ["decoder.layers.90.self_attn.qkv_proj.weight", "decoder.layers.90.self_attn.qkv_proj.bias", "decoder.layers.90.self_attn.out_proj.weight", "decoder.layers.90.self_attn.out_proj.bias", "decoder.layers.90.self_attn_layer_norm.weight", "decoder.layers.90.self_attn_layer_norm.bias", "decoder.layers.90.fc1.weight", "decoder.layers.90.fc1.bias", "decoder.layers.90.fc2.weight", "decoder.layers.90.fc2.bias", "decoder.layers.90.final_layer_norm.weight", "decoder.layers.90.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.91.flat_param_0": {"names": ["decoder.layers.91.self_attn.qkv_proj.weight", "decoder.layers.91.self_attn.qkv_proj.bias", "decoder.layers.91.self_attn.out_proj.weight", "decoder.layers.91.self_attn.out_proj.bias", "decoder.layers.91.self_attn_layer_norm.weight", "decoder.layers.91.self_attn_layer_norm.bias", "decoder.layers.91.fc1.weight", "decoder.layers.91.fc1.bias", "decoder.layers.91.fc2.weight", "decoder.layers.91.fc2.bias", "decoder.layers.91.final_layer_norm.weight", "decoder.layers.91.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.92.flat_param_0": {"names": ["decoder.layers.92.self_attn.qkv_proj.weight", "decoder.layers.92.self_attn.qkv_proj.bias", "decoder.layers.92.self_attn.out_proj.weight", "decoder.layers.92.self_attn.out_proj.bias", "decoder.layers.92.self_attn_layer_norm.weight", "decoder.layers.92.self_attn_layer_norm.bias", "decoder.layers.92.fc1.weight", "decoder.layers.92.fc1.bias", "decoder.layers.92.fc2.weight", "decoder.layers.92.fc2.bias", "decoder.layers.92.final_layer_norm.weight", "decoder.layers.92.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.93.flat_param_0": {"names": ["decoder.layers.93.self_attn.qkv_proj.weight", "decoder.layers.93.self_attn.qkv_proj.bias", "decoder.layers.93.self_attn.out_proj.weight", "decoder.layers.93.self_attn.out_proj.bias", "decoder.layers.93.self_attn_layer_norm.weight", "decoder.layers.93.self_attn_layer_norm.bias", "decoder.layers.93.fc1.weight", "decoder.layers.93.fc1.bias", "decoder.layers.93.fc2.weight", "decoder.layers.93.fc2.bias", "decoder.layers.93.final_layer_norm.weight", "decoder.layers.93.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.94.flat_param_0": {"names": ["decoder.layers.94.self_attn.qkv_proj.weight", "decoder.layers.94.self_attn.qkv_proj.bias", "decoder.layers.94.self_attn.out_proj.weight", "decoder.layers.94.self_attn.out_proj.bias", "decoder.layers.94.self_attn_layer_norm.weight", 
"decoder.layers.94.self_attn_layer_norm.bias", "decoder.layers.94.fc1.weight", "decoder.layers.94.fc1.bias", "decoder.layers.94.fc2.weight", "decoder.layers.94.fc2.bias", "decoder.layers.94.final_layer_norm.weight", "decoder.layers.94.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.95.flat_param_0": {"names": ["decoder.layers.95.self_attn.qkv_proj.weight", "decoder.layers.95.self_attn.qkv_proj.bias", "decoder.layers.95.self_attn.out_proj.weight", "decoder.layers.95.self_attn.out_proj.bias", "decoder.layers.95.self_attn_layer_norm.weight", "decoder.layers.95.self_attn_layer_norm.bias", "decoder.layers.95.fc1.weight", "decoder.layers.95.fc1.bias", "decoder.layers.95.fc2.weight", "decoder.layers.95.fc2.bias", "decoder.layers.95.final_layer_norm.weight", "decoder.layers.95.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}} \ No newline at end of file +{ + "flat_param_0": { + "names": [ + "decoder.embed_tokens.weight", + "decoder.embed_positions.weight", + "decoder.layer_norm.weight", + "decoder.layer_norm.bias" + ], + "shapes": [ + [ + 6284, + 12288 + ], + [ + 2050, + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 77217792, + 25190400, + 12288, + 12288 + ] + }, + "decoder.layers.0.flat_param_0": { + "names": [ + "decoder.layers.0.self_attn.qkv_proj.weight", + "decoder.layers.0.self_attn.qkv_proj.bias", + "decoder.layers.0.self_attn.out_proj.weight", + "decoder.layers.0.self_attn.out_proj.bias", + "decoder.layers.0.self_attn_layer_norm.weight", + "decoder.layers.0.self_attn_layer_norm.bias", + "decoder.layers.0.fc1.weight", + "decoder.layers.0.fc1.bias", + "decoder.layers.0.fc2.weight", + "decoder.layers.0.fc2.bias", + "decoder.layers.0.final_layer_norm.weight", + "decoder.layers.0.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.1.flat_param_0": { + "names": [ + "decoder.layers.1.self_attn.qkv_proj.weight", + "decoder.layers.1.self_attn.qkv_proj.bias", + "decoder.layers.1.self_attn.out_proj.weight", + "decoder.layers.1.self_attn.out_proj.bias", + "decoder.layers.1.self_attn_layer_norm.weight", + "decoder.layers.1.self_attn_layer_norm.bias", + "decoder.layers.1.fc1.weight", + "decoder.layers.1.fc1.bias", + "decoder.layers.1.fc2.weight", + "decoder.layers.1.fc2.bias", + "decoder.layers.1.final_layer_norm.weight", + "decoder.layers.1.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] 
+ }, + "decoder.layers.2.flat_param_0": { + "names": [ + "decoder.layers.2.self_attn.qkv_proj.weight", + "decoder.layers.2.self_attn.qkv_proj.bias", + "decoder.layers.2.self_attn.out_proj.weight", + "decoder.layers.2.self_attn.out_proj.bias", + "decoder.layers.2.self_attn_layer_norm.weight", + "decoder.layers.2.self_attn_layer_norm.bias", + "decoder.layers.2.fc1.weight", + "decoder.layers.2.fc1.bias", + "decoder.layers.2.fc2.weight", + "decoder.layers.2.fc2.bias", + "decoder.layers.2.final_layer_norm.weight", + "decoder.layers.2.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.3.flat_param_0": { + "names": [ + "decoder.layers.3.self_attn.qkv_proj.weight", + "decoder.layers.3.self_attn.qkv_proj.bias", + "decoder.layers.3.self_attn.out_proj.weight", + "decoder.layers.3.self_attn.out_proj.bias", + "decoder.layers.3.self_attn_layer_norm.weight", + "decoder.layers.3.self_attn_layer_norm.bias", + "decoder.layers.3.fc1.weight", + "decoder.layers.3.fc1.bias", + "decoder.layers.3.fc2.weight", + "decoder.layers.3.fc2.bias", + "decoder.layers.3.final_layer_norm.weight", + "decoder.layers.3.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.4.flat_param_0": { + "names": [ + "decoder.layers.4.self_attn.qkv_proj.weight", + "decoder.layers.4.self_attn.qkv_proj.bias", + "decoder.layers.4.self_attn.out_proj.weight", + "decoder.layers.4.self_attn.out_proj.bias", + "decoder.layers.4.self_attn_layer_norm.weight", + "decoder.layers.4.self_attn_layer_norm.bias", + "decoder.layers.4.fc1.weight", + "decoder.layers.4.fc1.bias", + "decoder.layers.4.fc2.weight", + "decoder.layers.4.fc2.bias", + "decoder.layers.4.final_layer_norm.weight", + "decoder.layers.4.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.5.flat_param_0": { + "names": [ + "decoder.layers.5.self_attn.qkv_proj.weight", + "decoder.layers.5.self_attn.qkv_proj.bias", + "decoder.layers.5.self_attn.out_proj.weight", + "decoder.layers.5.self_attn.out_proj.bias", + "decoder.layers.5.self_attn_layer_norm.weight", + "decoder.layers.5.self_attn_layer_norm.bias", + "decoder.layers.5.fc1.weight", + "decoder.layers.5.fc1.bias", + "decoder.layers.5.fc2.weight", + "decoder.layers.5.fc2.bias", + "decoder.layers.5.final_layer_norm.weight", + "decoder.layers.5.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 
6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.6.flat_param_0": { + "names": [ + "decoder.layers.6.self_attn.qkv_proj.weight", + "decoder.layers.6.self_attn.qkv_proj.bias", + "decoder.layers.6.self_attn.out_proj.weight", + "decoder.layers.6.self_attn.out_proj.bias", + "decoder.layers.6.self_attn_layer_norm.weight", + "decoder.layers.6.self_attn_layer_norm.bias", + "decoder.layers.6.fc1.weight", + "decoder.layers.6.fc1.bias", + "decoder.layers.6.fc2.weight", + "decoder.layers.6.fc2.bias", + "decoder.layers.6.final_layer_norm.weight", + "decoder.layers.6.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.7.flat_param_0": { + "names": [ + "decoder.layers.7.self_attn.qkv_proj.weight", + "decoder.layers.7.self_attn.qkv_proj.bias", + "decoder.layers.7.self_attn.out_proj.weight", + "decoder.layers.7.self_attn.out_proj.bias", + "decoder.layers.7.self_attn_layer_norm.weight", + "decoder.layers.7.self_attn_layer_norm.bias", + "decoder.layers.7.fc1.weight", + "decoder.layers.7.fc1.bias", + "decoder.layers.7.fc2.weight", + "decoder.layers.7.fc2.bias", + "decoder.layers.7.final_layer_norm.weight", + "decoder.layers.7.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.8.flat_param_0": { + "names": [ + "decoder.layers.8.self_attn.qkv_proj.weight", + "decoder.layers.8.self_attn.qkv_proj.bias", + "decoder.layers.8.self_attn.out_proj.weight", + "decoder.layers.8.self_attn.out_proj.bias", + "decoder.layers.8.self_attn_layer_norm.weight", + "decoder.layers.8.self_attn_layer_norm.bias", + "decoder.layers.8.fc1.weight", + "decoder.layers.8.fc1.bias", + "decoder.layers.8.fc2.weight", + "decoder.layers.8.fc2.bias", + "decoder.layers.8.final_layer_norm.weight", + "decoder.layers.8.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.9.flat_param_0": { + "names": [ + "decoder.layers.9.self_attn.qkv_proj.weight", + "decoder.layers.9.self_attn.qkv_proj.bias", + "decoder.layers.9.self_attn.out_proj.weight", + "decoder.layers.9.self_attn.out_proj.bias", + "decoder.layers.9.self_attn_layer_norm.weight", + "decoder.layers.9.self_attn_layer_norm.bias", + "decoder.layers.9.fc1.weight", + "decoder.layers.9.fc1.bias", + "decoder.layers.9.fc2.weight", + 
"decoder.layers.9.fc2.bias", + "decoder.layers.9.final_layer_norm.weight", + "decoder.layers.9.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.10.flat_param_0": { + "names": [ + "decoder.layers.10.self_attn.qkv_proj.weight", + "decoder.layers.10.self_attn.qkv_proj.bias", + "decoder.layers.10.self_attn.out_proj.weight", + "decoder.layers.10.self_attn.out_proj.bias", + "decoder.layers.10.self_attn_layer_norm.weight", + "decoder.layers.10.self_attn_layer_norm.bias", + "decoder.layers.10.fc1.weight", + "decoder.layers.10.fc1.bias", + "decoder.layers.10.fc2.weight", + "decoder.layers.10.fc2.bias", + "decoder.layers.10.final_layer_norm.weight", + "decoder.layers.10.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.11.flat_param_0": { + "names": [ + "decoder.layers.11.self_attn.qkv_proj.weight", + "decoder.layers.11.self_attn.qkv_proj.bias", + "decoder.layers.11.self_attn.out_proj.weight", + "decoder.layers.11.self_attn.out_proj.bias", + "decoder.layers.11.self_attn_layer_norm.weight", + "decoder.layers.11.self_attn_layer_norm.bias", + "decoder.layers.11.fc1.weight", + "decoder.layers.11.fc1.bias", + "decoder.layers.11.fc2.weight", + "decoder.layers.11.fc2.bias", + "decoder.layers.11.final_layer_norm.weight", + "decoder.layers.11.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.12.flat_param_0": { + "names": [ + "decoder.layers.12.self_attn.qkv_proj.weight", + "decoder.layers.12.self_attn.qkv_proj.bias", + "decoder.layers.12.self_attn.out_proj.weight", + "decoder.layers.12.self_attn.out_proj.bias", + "decoder.layers.12.self_attn_layer_norm.weight", + "decoder.layers.12.self_attn_layer_norm.bias", + "decoder.layers.12.fc1.weight", + "decoder.layers.12.fc1.bias", + "decoder.layers.12.fc2.weight", + "decoder.layers.12.fc2.bias", + "decoder.layers.12.final_layer_norm.weight", + "decoder.layers.12.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.13.flat_param_0": { + "names": [ + "decoder.layers.13.self_attn.qkv_proj.weight", + "decoder.layers.13.self_attn.qkv_proj.bias", + 
"decoder.layers.13.self_attn.out_proj.weight", + "decoder.layers.13.self_attn.out_proj.bias", + "decoder.layers.13.self_attn_layer_norm.weight", + "decoder.layers.13.self_attn_layer_norm.bias", + "decoder.layers.13.fc1.weight", + "decoder.layers.13.fc1.bias", + "decoder.layers.13.fc2.weight", + "decoder.layers.13.fc2.bias", + "decoder.layers.13.final_layer_norm.weight", + "decoder.layers.13.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.14.flat_param_0": { + "names": [ + "decoder.layers.14.self_attn.qkv_proj.weight", + "decoder.layers.14.self_attn.qkv_proj.bias", + "decoder.layers.14.self_attn.out_proj.weight", + "decoder.layers.14.self_attn.out_proj.bias", + "decoder.layers.14.self_attn_layer_norm.weight", + "decoder.layers.14.self_attn_layer_norm.bias", + "decoder.layers.14.fc1.weight", + "decoder.layers.14.fc1.bias", + "decoder.layers.14.fc2.weight", + "decoder.layers.14.fc2.bias", + "decoder.layers.14.final_layer_norm.weight", + "decoder.layers.14.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.15.flat_param_0": { + "names": [ + "decoder.layers.15.self_attn.qkv_proj.weight", + "decoder.layers.15.self_attn.qkv_proj.bias", + "decoder.layers.15.self_attn.out_proj.weight", + "decoder.layers.15.self_attn.out_proj.bias", + "decoder.layers.15.self_attn_layer_norm.weight", + "decoder.layers.15.self_attn_layer_norm.bias", + "decoder.layers.15.fc1.weight", + "decoder.layers.15.fc1.bias", + "decoder.layers.15.fc2.weight", + "decoder.layers.15.fc2.bias", + "decoder.layers.15.final_layer_norm.weight", + "decoder.layers.15.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.16.flat_param_0": { + "names": [ + "decoder.layers.16.self_attn.qkv_proj.weight", + "decoder.layers.16.self_attn.qkv_proj.bias", + "decoder.layers.16.self_attn.out_proj.weight", + "decoder.layers.16.self_attn.out_proj.bias", + "decoder.layers.16.self_attn_layer_norm.weight", + "decoder.layers.16.self_attn_layer_norm.bias", + "decoder.layers.16.fc1.weight", + "decoder.layers.16.fc1.bias", + "decoder.layers.16.fc2.weight", + "decoder.layers.16.fc2.bias", + "decoder.layers.16.final_layer_norm.weight", + "decoder.layers.16.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 
+ ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.17.flat_param_0": { + "names": [ + "decoder.layers.17.self_attn.qkv_proj.weight", + "decoder.layers.17.self_attn.qkv_proj.bias", + "decoder.layers.17.self_attn.out_proj.weight", + "decoder.layers.17.self_attn.out_proj.bias", + "decoder.layers.17.self_attn_layer_norm.weight", + "decoder.layers.17.self_attn_layer_norm.bias", + "decoder.layers.17.fc1.weight", + "decoder.layers.17.fc1.bias", + "decoder.layers.17.fc2.weight", + "decoder.layers.17.fc2.bias", + "decoder.layers.17.final_layer_norm.weight", + "decoder.layers.17.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.18.flat_param_0": { + "names": [ + "decoder.layers.18.self_attn.qkv_proj.weight", + "decoder.layers.18.self_attn.qkv_proj.bias", + "decoder.layers.18.self_attn.out_proj.weight", + "decoder.layers.18.self_attn.out_proj.bias", + "decoder.layers.18.self_attn_layer_norm.weight", + "decoder.layers.18.self_attn_layer_norm.bias", + "decoder.layers.18.fc1.weight", + "decoder.layers.18.fc1.bias", + "decoder.layers.18.fc2.weight", + "decoder.layers.18.fc2.bias", + "decoder.layers.18.final_layer_norm.weight", + "decoder.layers.18.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.19.flat_param_0": { + "names": [ + "decoder.layers.19.self_attn.qkv_proj.weight", + "decoder.layers.19.self_attn.qkv_proj.bias", + "decoder.layers.19.self_attn.out_proj.weight", + "decoder.layers.19.self_attn.out_proj.bias", + "decoder.layers.19.self_attn_layer_norm.weight", + "decoder.layers.19.self_attn_layer_norm.bias", + "decoder.layers.19.fc1.weight", + "decoder.layers.19.fc1.bias", + "decoder.layers.19.fc2.weight", + "decoder.layers.19.fc2.bias", + "decoder.layers.19.final_layer_norm.weight", + "decoder.layers.19.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.20.flat_param_0": { + "names": [ + "decoder.layers.20.self_attn.qkv_proj.weight", + "decoder.layers.20.self_attn.qkv_proj.bias", + "decoder.layers.20.self_attn.out_proj.weight", + "decoder.layers.20.self_attn.out_proj.bias", + "decoder.layers.20.self_attn_layer_norm.weight", + "decoder.layers.20.self_attn_layer_norm.bias", + "decoder.layers.20.fc1.weight", + "decoder.layers.20.fc1.bias", + "decoder.layers.20.fc2.weight", + "decoder.layers.20.fc2.bias", + 
"decoder.layers.20.final_layer_norm.weight", + "decoder.layers.20.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.21.flat_param_0": { + "names": [ + "decoder.layers.21.self_attn.qkv_proj.weight", + "decoder.layers.21.self_attn.qkv_proj.bias", + "decoder.layers.21.self_attn.out_proj.weight", + "decoder.layers.21.self_attn.out_proj.bias", + "decoder.layers.21.self_attn_layer_norm.weight", + "decoder.layers.21.self_attn_layer_norm.bias", + "decoder.layers.21.fc1.weight", + "decoder.layers.21.fc1.bias", + "decoder.layers.21.fc2.weight", + "decoder.layers.21.fc2.bias", + "decoder.layers.21.final_layer_norm.weight", + "decoder.layers.21.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.22.flat_param_0": { + "names": [ + "decoder.layers.22.self_attn.qkv_proj.weight", + "decoder.layers.22.self_attn.qkv_proj.bias", + "decoder.layers.22.self_attn.out_proj.weight", + "decoder.layers.22.self_attn.out_proj.bias", + "decoder.layers.22.self_attn_layer_norm.weight", + "decoder.layers.22.self_attn_layer_norm.bias", + "decoder.layers.22.fc1.weight", + "decoder.layers.22.fc1.bias", + "decoder.layers.22.fc2.weight", + "decoder.layers.22.fc2.bias", + "decoder.layers.22.final_layer_norm.weight", + "decoder.layers.22.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.23.flat_param_0": { + "names": [ + "decoder.layers.23.self_attn.qkv_proj.weight", + "decoder.layers.23.self_attn.qkv_proj.bias", + "decoder.layers.23.self_attn.out_proj.weight", + "decoder.layers.23.self_attn.out_proj.bias", + "decoder.layers.23.self_attn_layer_norm.weight", + "decoder.layers.23.self_attn_layer_norm.bias", + "decoder.layers.23.fc1.weight", + "decoder.layers.23.fc1.bias", + "decoder.layers.23.fc2.weight", + "decoder.layers.23.fc2.bias", + "decoder.layers.23.final_layer_norm.weight", + "decoder.layers.23.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.24.flat_param_0": { + "names": [ + "decoder.layers.24.self_attn.qkv_proj.weight", + "decoder.layers.24.self_attn.qkv_proj.bias", + 
"decoder.layers.24.self_attn.out_proj.weight", + "decoder.layers.24.self_attn.out_proj.bias", + "decoder.layers.24.self_attn_layer_norm.weight", + "decoder.layers.24.self_attn_layer_norm.bias", + "decoder.layers.24.fc1.weight", + "decoder.layers.24.fc1.bias", + "decoder.layers.24.fc2.weight", + "decoder.layers.24.fc2.bias", + "decoder.layers.24.final_layer_norm.weight", + "decoder.layers.24.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.25.flat_param_0": { + "names": [ + "decoder.layers.25.self_attn.qkv_proj.weight", + "decoder.layers.25.self_attn.qkv_proj.bias", + "decoder.layers.25.self_attn.out_proj.weight", + "decoder.layers.25.self_attn.out_proj.bias", + "decoder.layers.25.self_attn_layer_norm.weight", + "decoder.layers.25.self_attn_layer_norm.bias", + "decoder.layers.25.fc1.weight", + "decoder.layers.25.fc1.bias", + "decoder.layers.25.fc2.weight", + "decoder.layers.25.fc2.bias", + "decoder.layers.25.final_layer_norm.weight", + "decoder.layers.25.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.26.flat_param_0": { + "names": [ + "decoder.layers.26.self_attn.qkv_proj.weight", + "decoder.layers.26.self_attn.qkv_proj.bias", + "decoder.layers.26.self_attn.out_proj.weight", + "decoder.layers.26.self_attn.out_proj.bias", + "decoder.layers.26.self_attn_layer_norm.weight", + "decoder.layers.26.self_attn_layer_norm.bias", + "decoder.layers.26.fc1.weight", + "decoder.layers.26.fc1.bias", + "decoder.layers.26.fc2.weight", + "decoder.layers.26.fc2.bias", + "decoder.layers.26.final_layer_norm.weight", + "decoder.layers.26.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.27.flat_param_0": { + "names": [ + "decoder.layers.27.self_attn.qkv_proj.weight", + "decoder.layers.27.self_attn.qkv_proj.bias", + "decoder.layers.27.self_attn.out_proj.weight", + "decoder.layers.27.self_attn.out_proj.bias", + "decoder.layers.27.self_attn_layer_norm.weight", + "decoder.layers.27.self_attn_layer_norm.bias", + "decoder.layers.27.fc1.weight", + "decoder.layers.27.fc1.bias", + "decoder.layers.27.fc2.weight", + "decoder.layers.27.fc2.bias", + "decoder.layers.27.final_layer_norm.weight", + "decoder.layers.27.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 
+ ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.28.flat_param_0": { + "names": [ + "decoder.layers.28.self_attn.qkv_proj.weight", + "decoder.layers.28.self_attn.qkv_proj.bias", + "decoder.layers.28.self_attn.out_proj.weight", + "decoder.layers.28.self_attn.out_proj.bias", + "decoder.layers.28.self_attn_layer_norm.weight", + "decoder.layers.28.self_attn_layer_norm.bias", + "decoder.layers.28.fc1.weight", + "decoder.layers.28.fc1.bias", + "decoder.layers.28.fc2.weight", + "decoder.layers.28.fc2.bias", + "decoder.layers.28.final_layer_norm.weight", + "decoder.layers.28.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.29.flat_param_0": { + "names": [ + "decoder.layers.29.self_attn.qkv_proj.weight", + "decoder.layers.29.self_attn.qkv_proj.bias", + "decoder.layers.29.self_attn.out_proj.weight", + "decoder.layers.29.self_attn.out_proj.bias", + "decoder.layers.29.self_attn_layer_norm.weight", + "decoder.layers.29.self_attn_layer_norm.bias", + "decoder.layers.29.fc1.weight", + "decoder.layers.29.fc1.bias", + "decoder.layers.29.fc2.weight", + "decoder.layers.29.fc2.bias", + "decoder.layers.29.final_layer_norm.weight", + "decoder.layers.29.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.30.flat_param_0": { + "names": [ + "decoder.layers.30.self_attn.qkv_proj.weight", + "decoder.layers.30.self_attn.qkv_proj.bias", + "decoder.layers.30.self_attn.out_proj.weight", + "decoder.layers.30.self_attn.out_proj.bias", + "decoder.layers.30.self_attn_layer_norm.weight", + "decoder.layers.30.self_attn_layer_norm.bias", + "decoder.layers.30.fc1.weight", + "decoder.layers.30.fc1.bias", + "decoder.layers.30.fc2.weight", + "decoder.layers.30.fc2.bias", + "decoder.layers.30.final_layer_norm.weight", + "decoder.layers.30.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.31.flat_param_0": { + "names": [ + "decoder.layers.31.self_attn.qkv_proj.weight", + "decoder.layers.31.self_attn.qkv_proj.bias", + "decoder.layers.31.self_attn.out_proj.weight", + "decoder.layers.31.self_attn.out_proj.bias", + "decoder.layers.31.self_attn_layer_norm.weight", + "decoder.layers.31.self_attn_layer_norm.bias", + "decoder.layers.31.fc1.weight", + "decoder.layers.31.fc1.bias", + "decoder.layers.31.fc2.weight", + "decoder.layers.31.fc2.bias", + 
"decoder.layers.31.final_layer_norm.weight", + "decoder.layers.31.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.32.flat_param_0": { + "names": [ + "decoder.layers.32.self_attn.qkv_proj.weight", + "decoder.layers.32.self_attn.qkv_proj.bias", + "decoder.layers.32.self_attn.out_proj.weight", + "decoder.layers.32.self_attn.out_proj.bias", + "decoder.layers.32.self_attn_layer_norm.weight", + "decoder.layers.32.self_attn_layer_norm.bias", + "decoder.layers.32.fc1.weight", + "decoder.layers.32.fc1.bias", + "decoder.layers.32.fc2.weight", + "decoder.layers.32.fc2.bias", + "decoder.layers.32.final_layer_norm.weight", + "decoder.layers.32.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.33.flat_param_0": { + "names": [ + "decoder.layers.33.self_attn.qkv_proj.weight", + "decoder.layers.33.self_attn.qkv_proj.bias", + "decoder.layers.33.self_attn.out_proj.weight", + "decoder.layers.33.self_attn.out_proj.bias", + "decoder.layers.33.self_attn_layer_norm.weight", + "decoder.layers.33.self_attn_layer_norm.bias", + "decoder.layers.33.fc1.weight", + "decoder.layers.33.fc1.bias", + "decoder.layers.33.fc2.weight", + "decoder.layers.33.fc2.bias", + "decoder.layers.33.final_layer_norm.weight", + "decoder.layers.33.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.34.flat_param_0": { + "names": [ + "decoder.layers.34.self_attn.qkv_proj.weight", + "decoder.layers.34.self_attn.qkv_proj.bias", + "decoder.layers.34.self_attn.out_proj.weight", + "decoder.layers.34.self_attn.out_proj.bias", + "decoder.layers.34.self_attn_layer_norm.weight", + "decoder.layers.34.self_attn_layer_norm.bias", + "decoder.layers.34.fc1.weight", + "decoder.layers.34.fc1.bias", + "decoder.layers.34.fc2.weight", + "decoder.layers.34.fc2.bias", + "decoder.layers.34.final_layer_norm.weight", + "decoder.layers.34.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.35.flat_param_0": { + "names": [ + "decoder.layers.35.self_attn.qkv_proj.weight", + "decoder.layers.35.self_attn.qkv_proj.bias", + 
"decoder.layers.35.self_attn.out_proj.weight", + "decoder.layers.35.self_attn.out_proj.bias", + "decoder.layers.35.self_attn_layer_norm.weight", + "decoder.layers.35.self_attn_layer_norm.bias", + "decoder.layers.35.fc1.weight", + "decoder.layers.35.fc1.bias", + "decoder.layers.35.fc2.weight", + "decoder.layers.35.fc2.bias", + "decoder.layers.35.final_layer_norm.weight", + "decoder.layers.35.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.36.flat_param_0": { + "names": [ + "decoder.layers.36.self_attn.qkv_proj.weight", + "decoder.layers.36.self_attn.qkv_proj.bias", + "decoder.layers.36.self_attn.out_proj.weight", + "decoder.layers.36.self_attn.out_proj.bias", + "decoder.layers.36.self_attn_layer_norm.weight", + "decoder.layers.36.self_attn_layer_norm.bias", + "decoder.layers.36.fc1.weight", + "decoder.layers.36.fc1.bias", + "decoder.layers.36.fc2.weight", + "decoder.layers.36.fc2.bias", + "decoder.layers.36.final_layer_norm.weight", + "decoder.layers.36.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.37.flat_param_0": { + "names": [ + "decoder.layers.37.self_attn.qkv_proj.weight", + "decoder.layers.37.self_attn.qkv_proj.bias", + "decoder.layers.37.self_attn.out_proj.weight", + "decoder.layers.37.self_attn.out_proj.bias", + "decoder.layers.37.self_attn_layer_norm.weight", + "decoder.layers.37.self_attn_layer_norm.bias", + "decoder.layers.37.fc1.weight", + "decoder.layers.37.fc1.bias", + "decoder.layers.37.fc2.weight", + "decoder.layers.37.fc2.bias", + "decoder.layers.37.final_layer_norm.weight", + "decoder.layers.37.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.38.flat_param_0": { + "names": [ + "decoder.layers.38.self_attn.qkv_proj.weight", + "decoder.layers.38.self_attn.qkv_proj.bias", + "decoder.layers.38.self_attn.out_proj.weight", + "decoder.layers.38.self_attn.out_proj.bias", + "decoder.layers.38.self_attn_layer_norm.weight", + "decoder.layers.38.self_attn_layer_norm.bias", + "decoder.layers.38.fc1.weight", + "decoder.layers.38.fc1.bias", + "decoder.layers.38.fc2.weight", + "decoder.layers.38.fc2.bias", + "decoder.layers.38.final_layer_norm.weight", + "decoder.layers.38.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 
+ ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.39.flat_param_0": { + "names": [ + "decoder.layers.39.self_attn.qkv_proj.weight", + "decoder.layers.39.self_attn.qkv_proj.bias", + "decoder.layers.39.self_attn.out_proj.weight", + "decoder.layers.39.self_attn.out_proj.bias", + "decoder.layers.39.self_attn_layer_norm.weight", + "decoder.layers.39.self_attn_layer_norm.bias", + "decoder.layers.39.fc1.weight", + "decoder.layers.39.fc1.bias", + "decoder.layers.39.fc2.weight", + "decoder.layers.39.fc2.bias", + "decoder.layers.39.final_layer_norm.weight", + "decoder.layers.39.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.40.flat_param_0": { + "names": [ + "decoder.layers.40.self_attn.qkv_proj.weight", + "decoder.layers.40.self_attn.qkv_proj.bias", + "decoder.layers.40.self_attn.out_proj.weight", + "decoder.layers.40.self_attn.out_proj.bias", + "decoder.layers.40.self_attn_layer_norm.weight", + "decoder.layers.40.self_attn_layer_norm.bias", + "decoder.layers.40.fc1.weight", + "decoder.layers.40.fc1.bias", + "decoder.layers.40.fc2.weight", + "decoder.layers.40.fc2.bias", + "decoder.layers.40.final_layer_norm.weight", + "decoder.layers.40.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.41.flat_param_0": { + "names": [ + "decoder.layers.41.self_attn.qkv_proj.weight", + "decoder.layers.41.self_attn.qkv_proj.bias", + "decoder.layers.41.self_attn.out_proj.weight", + "decoder.layers.41.self_attn.out_proj.bias", + "decoder.layers.41.self_attn_layer_norm.weight", + "decoder.layers.41.self_attn_layer_norm.bias", + "decoder.layers.41.fc1.weight", + "decoder.layers.41.fc1.bias", + "decoder.layers.41.fc2.weight", + "decoder.layers.41.fc2.bias", + "decoder.layers.41.final_layer_norm.weight", + "decoder.layers.41.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.42.flat_param_0": { + "names": [ + "decoder.layers.42.self_attn.qkv_proj.weight", + "decoder.layers.42.self_attn.qkv_proj.bias", + "decoder.layers.42.self_attn.out_proj.weight", + "decoder.layers.42.self_attn.out_proj.bias", + "decoder.layers.42.self_attn_layer_norm.weight", + "decoder.layers.42.self_attn_layer_norm.bias", + "decoder.layers.42.fc1.weight", + "decoder.layers.42.fc1.bias", + "decoder.layers.42.fc2.weight", + "decoder.layers.42.fc2.bias", + 
"decoder.layers.42.final_layer_norm.weight", + "decoder.layers.42.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.43.flat_param_0": { + "names": [ + "decoder.layers.43.self_attn.qkv_proj.weight", + "decoder.layers.43.self_attn.qkv_proj.bias", + "decoder.layers.43.self_attn.out_proj.weight", + "decoder.layers.43.self_attn.out_proj.bias", + "decoder.layers.43.self_attn_layer_norm.weight", + "decoder.layers.43.self_attn_layer_norm.bias", + "decoder.layers.43.fc1.weight", + "decoder.layers.43.fc1.bias", + "decoder.layers.43.fc2.weight", + "decoder.layers.43.fc2.bias", + "decoder.layers.43.final_layer_norm.weight", + "decoder.layers.43.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.44.flat_param_0": { + "names": [ + "decoder.layers.44.self_attn.qkv_proj.weight", + "decoder.layers.44.self_attn.qkv_proj.bias", + "decoder.layers.44.self_attn.out_proj.weight", + "decoder.layers.44.self_attn.out_proj.bias", + "decoder.layers.44.self_attn_layer_norm.weight", + "decoder.layers.44.self_attn_layer_norm.bias", + "decoder.layers.44.fc1.weight", + "decoder.layers.44.fc1.bias", + "decoder.layers.44.fc2.weight", + "decoder.layers.44.fc2.bias", + "decoder.layers.44.final_layer_norm.weight", + "decoder.layers.44.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.45.flat_param_0": { + "names": [ + "decoder.layers.45.self_attn.qkv_proj.weight", + "decoder.layers.45.self_attn.qkv_proj.bias", + "decoder.layers.45.self_attn.out_proj.weight", + "decoder.layers.45.self_attn.out_proj.bias", + "decoder.layers.45.self_attn_layer_norm.weight", + "decoder.layers.45.self_attn_layer_norm.bias", + "decoder.layers.45.fc1.weight", + "decoder.layers.45.fc1.bias", + "decoder.layers.45.fc2.weight", + "decoder.layers.45.fc2.bias", + "decoder.layers.45.final_layer_norm.weight", + "decoder.layers.45.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.46.flat_param_0": { + "names": [ + "decoder.layers.46.self_attn.qkv_proj.weight", + "decoder.layers.46.self_attn.qkv_proj.bias", + 
"decoder.layers.46.self_attn.out_proj.weight", + "decoder.layers.46.self_attn.out_proj.bias", + "decoder.layers.46.self_attn_layer_norm.weight", + "decoder.layers.46.self_attn_layer_norm.bias", + "decoder.layers.46.fc1.weight", + "decoder.layers.46.fc1.bias", + "decoder.layers.46.fc2.weight", + "decoder.layers.46.fc2.bias", + "decoder.layers.46.final_layer_norm.weight", + "decoder.layers.46.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.47.flat_param_0": { + "names": [ + "decoder.layers.47.self_attn.qkv_proj.weight", + "decoder.layers.47.self_attn.qkv_proj.bias", + "decoder.layers.47.self_attn.out_proj.weight", + "decoder.layers.47.self_attn.out_proj.bias", + "decoder.layers.47.self_attn_layer_norm.weight", + "decoder.layers.47.self_attn_layer_norm.bias", + "decoder.layers.47.fc1.weight", + "decoder.layers.47.fc1.bias", + "decoder.layers.47.fc2.weight", + "decoder.layers.47.fc2.bias", + "decoder.layers.47.final_layer_norm.weight", + "decoder.layers.47.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.48.flat_param_0": { + "names": [ + "decoder.layers.48.self_attn.qkv_proj.weight", + "decoder.layers.48.self_attn.qkv_proj.bias", + "decoder.layers.48.self_attn.out_proj.weight", + "decoder.layers.48.self_attn.out_proj.bias", + "decoder.layers.48.self_attn_layer_norm.weight", + "decoder.layers.48.self_attn_layer_norm.bias", + "decoder.layers.48.fc1.weight", + "decoder.layers.48.fc1.bias", + "decoder.layers.48.fc2.weight", + "decoder.layers.48.fc2.bias", + "decoder.layers.48.final_layer_norm.weight", + "decoder.layers.48.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.49.flat_param_0": { + "names": [ + "decoder.layers.49.self_attn.qkv_proj.weight", + "decoder.layers.49.self_attn.qkv_proj.bias", + "decoder.layers.49.self_attn.out_proj.weight", + "decoder.layers.49.self_attn.out_proj.bias", + "decoder.layers.49.self_attn_layer_norm.weight", + "decoder.layers.49.self_attn_layer_norm.bias", + "decoder.layers.49.fc1.weight", + "decoder.layers.49.fc1.bias", + "decoder.layers.49.fc2.weight", + "decoder.layers.49.fc2.bias", + "decoder.layers.49.final_layer_norm.weight", + "decoder.layers.49.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 
+ ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.50.flat_param_0": { + "names": [ + "decoder.layers.50.self_attn.qkv_proj.weight", + "decoder.layers.50.self_attn.qkv_proj.bias", + "decoder.layers.50.self_attn.out_proj.weight", + "decoder.layers.50.self_attn.out_proj.bias", + "decoder.layers.50.self_attn_layer_norm.weight", + "decoder.layers.50.self_attn_layer_norm.bias", + "decoder.layers.50.fc1.weight", + "decoder.layers.50.fc1.bias", + "decoder.layers.50.fc2.weight", + "decoder.layers.50.fc2.bias", + "decoder.layers.50.final_layer_norm.weight", + "decoder.layers.50.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.51.flat_param_0": { + "names": [ + "decoder.layers.51.self_attn.qkv_proj.weight", + "decoder.layers.51.self_attn.qkv_proj.bias", + "decoder.layers.51.self_attn.out_proj.weight", + "decoder.layers.51.self_attn.out_proj.bias", + "decoder.layers.51.self_attn_layer_norm.weight", + "decoder.layers.51.self_attn_layer_norm.bias", + "decoder.layers.51.fc1.weight", + "decoder.layers.51.fc1.bias", + "decoder.layers.51.fc2.weight", + "decoder.layers.51.fc2.bias", + "decoder.layers.51.final_layer_norm.weight", + "decoder.layers.51.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.52.flat_param_0": { + "names": [ + "decoder.layers.52.self_attn.qkv_proj.weight", + "decoder.layers.52.self_attn.qkv_proj.bias", + "decoder.layers.52.self_attn.out_proj.weight", + "decoder.layers.52.self_attn.out_proj.bias", + "decoder.layers.52.self_attn_layer_norm.weight", + "decoder.layers.52.self_attn_layer_norm.bias", + "decoder.layers.52.fc1.weight", + "decoder.layers.52.fc1.bias", + "decoder.layers.52.fc2.weight", + "decoder.layers.52.fc2.bias", + "decoder.layers.52.final_layer_norm.weight", + "decoder.layers.52.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.53.flat_param_0": { + "names": [ + "decoder.layers.53.self_attn.qkv_proj.weight", + "decoder.layers.53.self_attn.qkv_proj.bias", + "decoder.layers.53.self_attn.out_proj.weight", + "decoder.layers.53.self_attn.out_proj.bias", + "decoder.layers.53.self_attn_layer_norm.weight", + "decoder.layers.53.self_attn_layer_norm.bias", + "decoder.layers.53.fc1.weight", + "decoder.layers.53.fc1.bias", + "decoder.layers.53.fc2.weight", + "decoder.layers.53.fc2.bias", + 
"decoder.layers.53.final_layer_norm.weight", + "decoder.layers.53.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.54.flat_param_0": { + "names": [ + "decoder.layers.54.self_attn.qkv_proj.weight", + "decoder.layers.54.self_attn.qkv_proj.bias", + "decoder.layers.54.self_attn.out_proj.weight", + "decoder.layers.54.self_attn.out_proj.bias", + "decoder.layers.54.self_attn_layer_norm.weight", + "decoder.layers.54.self_attn_layer_norm.bias", + "decoder.layers.54.fc1.weight", + "decoder.layers.54.fc1.bias", + "decoder.layers.54.fc2.weight", + "decoder.layers.54.fc2.bias", + "decoder.layers.54.final_layer_norm.weight", + "decoder.layers.54.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.55.flat_param_0": { + "names": [ + "decoder.layers.55.self_attn.qkv_proj.weight", + "decoder.layers.55.self_attn.qkv_proj.bias", + "decoder.layers.55.self_attn.out_proj.weight", + "decoder.layers.55.self_attn.out_proj.bias", + "decoder.layers.55.self_attn_layer_norm.weight", + "decoder.layers.55.self_attn_layer_norm.bias", + "decoder.layers.55.fc1.weight", + "decoder.layers.55.fc1.bias", + "decoder.layers.55.fc2.weight", + "decoder.layers.55.fc2.bias", + "decoder.layers.55.final_layer_norm.weight", + "decoder.layers.55.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.56.flat_param_0": { + "names": [ + "decoder.layers.56.self_attn.qkv_proj.weight", + "decoder.layers.56.self_attn.qkv_proj.bias", + "decoder.layers.56.self_attn.out_proj.weight", + "decoder.layers.56.self_attn.out_proj.bias", + "decoder.layers.56.self_attn_layer_norm.weight", + "decoder.layers.56.self_attn_layer_norm.bias", + "decoder.layers.56.fc1.weight", + "decoder.layers.56.fc1.bias", + "decoder.layers.56.fc2.weight", + "decoder.layers.56.fc2.bias", + "decoder.layers.56.final_layer_norm.weight", + "decoder.layers.56.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.57.flat_param_0": { + "names": [ + "decoder.layers.57.self_attn.qkv_proj.weight", + "decoder.layers.57.self_attn.qkv_proj.bias", + 
"decoder.layers.57.self_attn.out_proj.weight", + "decoder.layers.57.self_attn.out_proj.bias", + "decoder.layers.57.self_attn_layer_norm.weight", + "decoder.layers.57.self_attn_layer_norm.bias", + "decoder.layers.57.fc1.weight", + "decoder.layers.57.fc1.bias", + "decoder.layers.57.fc2.weight", + "decoder.layers.57.fc2.bias", + "decoder.layers.57.final_layer_norm.weight", + "decoder.layers.57.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.58.flat_param_0": { + "names": [ + "decoder.layers.58.self_attn.qkv_proj.weight", + "decoder.layers.58.self_attn.qkv_proj.bias", + "decoder.layers.58.self_attn.out_proj.weight", + "decoder.layers.58.self_attn.out_proj.bias", + "decoder.layers.58.self_attn_layer_norm.weight", + "decoder.layers.58.self_attn_layer_norm.bias", + "decoder.layers.58.fc1.weight", + "decoder.layers.58.fc1.bias", + "decoder.layers.58.fc2.weight", + "decoder.layers.58.fc2.bias", + "decoder.layers.58.final_layer_norm.weight", + "decoder.layers.58.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.59.flat_param_0": { + "names": [ + "decoder.layers.59.self_attn.qkv_proj.weight", + "decoder.layers.59.self_attn.qkv_proj.bias", + "decoder.layers.59.self_attn.out_proj.weight", + "decoder.layers.59.self_attn.out_proj.bias", + "decoder.layers.59.self_attn_layer_norm.weight", + "decoder.layers.59.self_attn_layer_norm.bias", + "decoder.layers.59.fc1.weight", + "decoder.layers.59.fc1.bias", + "decoder.layers.59.fc2.weight", + "decoder.layers.59.fc2.bias", + "decoder.layers.59.final_layer_norm.weight", + "decoder.layers.59.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.60.flat_param_0": { + "names": [ + "decoder.layers.60.self_attn.qkv_proj.weight", + "decoder.layers.60.self_attn.qkv_proj.bias", + "decoder.layers.60.self_attn.out_proj.weight", + "decoder.layers.60.self_attn.out_proj.bias", + "decoder.layers.60.self_attn_layer_norm.weight", + "decoder.layers.60.self_attn_layer_norm.bias", + "decoder.layers.60.fc1.weight", + "decoder.layers.60.fc1.bias", + "decoder.layers.60.fc2.weight", + "decoder.layers.60.fc2.bias", + "decoder.layers.60.final_layer_norm.weight", + "decoder.layers.60.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 
+ ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.61.flat_param_0": { + "names": [ + "decoder.layers.61.self_attn.qkv_proj.weight", + "decoder.layers.61.self_attn.qkv_proj.bias", + "decoder.layers.61.self_attn.out_proj.weight", + "decoder.layers.61.self_attn.out_proj.bias", + "decoder.layers.61.self_attn_layer_norm.weight", + "decoder.layers.61.self_attn_layer_norm.bias", + "decoder.layers.61.fc1.weight", + "decoder.layers.61.fc1.bias", + "decoder.layers.61.fc2.weight", + "decoder.layers.61.fc2.bias", + "decoder.layers.61.final_layer_norm.weight", + "decoder.layers.61.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.62.flat_param_0": { + "names": [ + "decoder.layers.62.self_attn.qkv_proj.weight", + "decoder.layers.62.self_attn.qkv_proj.bias", + "decoder.layers.62.self_attn.out_proj.weight", + "decoder.layers.62.self_attn.out_proj.bias", + "decoder.layers.62.self_attn_layer_norm.weight", + "decoder.layers.62.self_attn_layer_norm.bias", + "decoder.layers.62.fc1.weight", + "decoder.layers.62.fc1.bias", + "decoder.layers.62.fc2.weight", + "decoder.layers.62.fc2.bias", + "decoder.layers.62.final_layer_norm.weight", + "decoder.layers.62.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.63.flat_param_0": { + "names": [ + "decoder.layers.63.self_attn.qkv_proj.weight", + "decoder.layers.63.self_attn.qkv_proj.bias", + "decoder.layers.63.self_attn.out_proj.weight", + "decoder.layers.63.self_attn.out_proj.bias", + "decoder.layers.63.self_attn_layer_norm.weight", + "decoder.layers.63.self_attn_layer_norm.bias", + "decoder.layers.63.fc1.weight", + "decoder.layers.63.fc1.bias", + "decoder.layers.63.fc2.weight", + "decoder.layers.63.fc2.bias", + "decoder.layers.63.final_layer_norm.weight", + "decoder.layers.63.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.64.flat_param_0": { + "names": [ + "decoder.layers.64.self_attn.qkv_proj.weight", + "decoder.layers.64.self_attn.qkv_proj.bias", + "decoder.layers.64.self_attn.out_proj.weight", + "decoder.layers.64.self_attn.out_proj.bias", + "decoder.layers.64.self_attn_layer_norm.weight", + "decoder.layers.64.self_attn_layer_norm.bias", + "decoder.layers.64.fc1.weight", + "decoder.layers.64.fc1.bias", + "decoder.layers.64.fc2.weight", + "decoder.layers.64.fc2.bias", + 
"decoder.layers.64.final_layer_norm.weight", + "decoder.layers.64.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.65.flat_param_0": { + "names": [ + "decoder.layers.65.self_attn.qkv_proj.weight", + "decoder.layers.65.self_attn.qkv_proj.bias", + "decoder.layers.65.self_attn.out_proj.weight", + "decoder.layers.65.self_attn.out_proj.bias", + "decoder.layers.65.self_attn_layer_norm.weight", + "decoder.layers.65.self_attn_layer_norm.bias", + "decoder.layers.65.fc1.weight", + "decoder.layers.65.fc1.bias", + "decoder.layers.65.fc2.weight", + "decoder.layers.65.fc2.bias", + "decoder.layers.65.final_layer_norm.weight", + "decoder.layers.65.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.66.flat_param_0": { + "names": [ + "decoder.layers.66.self_attn.qkv_proj.weight", + "decoder.layers.66.self_attn.qkv_proj.bias", + "decoder.layers.66.self_attn.out_proj.weight", + "decoder.layers.66.self_attn.out_proj.bias", + "decoder.layers.66.self_attn_layer_norm.weight", + "decoder.layers.66.self_attn_layer_norm.bias", + "decoder.layers.66.fc1.weight", + "decoder.layers.66.fc1.bias", + "decoder.layers.66.fc2.weight", + "decoder.layers.66.fc2.bias", + "decoder.layers.66.final_layer_norm.weight", + "decoder.layers.66.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.67.flat_param_0": { + "names": [ + "decoder.layers.67.self_attn.qkv_proj.weight", + "decoder.layers.67.self_attn.qkv_proj.bias", + "decoder.layers.67.self_attn.out_proj.weight", + "decoder.layers.67.self_attn.out_proj.bias", + "decoder.layers.67.self_attn_layer_norm.weight", + "decoder.layers.67.self_attn_layer_norm.bias", + "decoder.layers.67.fc1.weight", + "decoder.layers.67.fc1.bias", + "decoder.layers.67.fc2.weight", + "decoder.layers.67.fc2.bias", + "decoder.layers.67.final_layer_norm.weight", + "decoder.layers.67.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.68.flat_param_0": { + "names": [ + "decoder.layers.68.self_attn.qkv_proj.weight", + "decoder.layers.68.self_attn.qkv_proj.bias", + 
"decoder.layers.68.self_attn.out_proj.weight", + "decoder.layers.68.self_attn.out_proj.bias", + "decoder.layers.68.self_attn_layer_norm.weight", + "decoder.layers.68.self_attn_layer_norm.bias", + "decoder.layers.68.fc1.weight", + "decoder.layers.68.fc1.bias", + "decoder.layers.68.fc2.weight", + "decoder.layers.68.fc2.bias", + "decoder.layers.68.final_layer_norm.weight", + "decoder.layers.68.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.69.flat_param_0": { + "names": [ + "decoder.layers.69.self_attn.qkv_proj.weight", + "decoder.layers.69.self_attn.qkv_proj.bias", + "decoder.layers.69.self_attn.out_proj.weight", + "decoder.layers.69.self_attn.out_proj.bias", + "decoder.layers.69.self_attn_layer_norm.weight", + "decoder.layers.69.self_attn_layer_norm.bias", + "decoder.layers.69.fc1.weight", + "decoder.layers.69.fc1.bias", + "decoder.layers.69.fc2.weight", + "decoder.layers.69.fc2.bias", + "decoder.layers.69.final_layer_norm.weight", + "decoder.layers.69.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.70.flat_param_0": { + "names": [ + "decoder.layers.70.self_attn.qkv_proj.weight", + "decoder.layers.70.self_attn.qkv_proj.bias", + "decoder.layers.70.self_attn.out_proj.weight", + "decoder.layers.70.self_attn.out_proj.bias", + "decoder.layers.70.self_attn_layer_norm.weight", + "decoder.layers.70.self_attn_layer_norm.bias", + "decoder.layers.70.fc1.weight", + "decoder.layers.70.fc1.bias", + "decoder.layers.70.fc2.weight", + "decoder.layers.70.fc2.bias", + "decoder.layers.70.final_layer_norm.weight", + "decoder.layers.70.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.71.flat_param_0": { + "names": [ + "decoder.layers.71.self_attn.qkv_proj.weight", + "decoder.layers.71.self_attn.qkv_proj.bias", + "decoder.layers.71.self_attn.out_proj.weight", + "decoder.layers.71.self_attn.out_proj.bias", + "decoder.layers.71.self_attn_layer_norm.weight", + "decoder.layers.71.self_attn_layer_norm.bias", + "decoder.layers.71.fc1.weight", + "decoder.layers.71.fc1.bias", + "decoder.layers.71.fc2.weight", + "decoder.layers.71.fc2.bias", + "decoder.layers.71.final_layer_norm.weight", + "decoder.layers.71.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 
+ ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.72.flat_param_0": { + "names": [ + "decoder.layers.72.self_attn.qkv_proj.weight", + "decoder.layers.72.self_attn.qkv_proj.bias", + "decoder.layers.72.self_attn.out_proj.weight", + "decoder.layers.72.self_attn.out_proj.bias", + "decoder.layers.72.self_attn_layer_norm.weight", + "decoder.layers.72.self_attn_layer_norm.bias", + "decoder.layers.72.fc1.weight", + "decoder.layers.72.fc1.bias", + "decoder.layers.72.fc2.weight", + "decoder.layers.72.fc2.bias", + "decoder.layers.72.final_layer_norm.weight", + "decoder.layers.72.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.73.flat_param_0": { + "names": [ + "decoder.layers.73.self_attn.qkv_proj.weight", + "decoder.layers.73.self_attn.qkv_proj.bias", + "decoder.layers.73.self_attn.out_proj.weight", + "decoder.layers.73.self_attn.out_proj.bias", + "decoder.layers.73.self_attn_layer_norm.weight", + "decoder.layers.73.self_attn_layer_norm.bias", + "decoder.layers.73.fc1.weight", + "decoder.layers.73.fc1.bias", + "decoder.layers.73.fc2.weight", + "decoder.layers.73.fc2.bias", + "decoder.layers.73.final_layer_norm.weight", + "decoder.layers.73.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.74.flat_param_0": { + "names": [ + "decoder.layers.74.self_attn.qkv_proj.weight", + "decoder.layers.74.self_attn.qkv_proj.bias", + "decoder.layers.74.self_attn.out_proj.weight", + "decoder.layers.74.self_attn.out_proj.bias", + "decoder.layers.74.self_attn_layer_norm.weight", + "decoder.layers.74.self_attn_layer_norm.bias", + "decoder.layers.74.fc1.weight", + "decoder.layers.74.fc1.bias", + "decoder.layers.74.fc2.weight", + "decoder.layers.74.fc2.bias", + "decoder.layers.74.final_layer_norm.weight", + "decoder.layers.74.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.75.flat_param_0": { + "names": [ + "decoder.layers.75.self_attn.qkv_proj.weight", + "decoder.layers.75.self_attn.qkv_proj.bias", + "decoder.layers.75.self_attn.out_proj.weight", + "decoder.layers.75.self_attn.out_proj.bias", + "decoder.layers.75.self_attn_layer_norm.weight", + "decoder.layers.75.self_attn_layer_norm.bias", + "decoder.layers.75.fc1.weight", + "decoder.layers.75.fc1.bias", + "decoder.layers.75.fc2.weight", + "decoder.layers.75.fc2.bias", + 
"decoder.layers.75.final_layer_norm.weight", + "decoder.layers.75.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.76.flat_param_0": { + "names": [ + "decoder.layers.76.self_attn.qkv_proj.weight", + "decoder.layers.76.self_attn.qkv_proj.bias", + "decoder.layers.76.self_attn.out_proj.weight", + "decoder.layers.76.self_attn.out_proj.bias", + "decoder.layers.76.self_attn_layer_norm.weight", + "decoder.layers.76.self_attn_layer_norm.bias", + "decoder.layers.76.fc1.weight", + "decoder.layers.76.fc1.bias", + "decoder.layers.76.fc2.weight", + "decoder.layers.76.fc2.bias", + "decoder.layers.76.final_layer_norm.weight", + "decoder.layers.76.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.77.flat_param_0": { + "names": [ + "decoder.layers.77.self_attn.qkv_proj.weight", + "decoder.layers.77.self_attn.qkv_proj.bias", + "decoder.layers.77.self_attn.out_proj.weight", + "decoder.layers.77.self_attn.out_proj.bias", + "decoder.layers.77.self_attn_layer_norm.weight", + "decoder.layers.77.self_attn_layer_norm.bias", + "decoder.layers.77.fc1.weight", + "decoder.layers.77.fc1.bias", + "decoder.layers.77.fc2.weight", + "decoder.layers.77.fc2.bias", + "decoder.layers.77.final_layer_norm.weight", + "decoder.layers.77.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.78.flat_param_0": { + "names": [ + "decoder.layers.78.self_attn.qkv_proj.weight", + "decoder.layers.78.self_attn.qkv_proj.bias", + "decoder.layers.78.self_attn.out_proj.weight", + "decoder.layers.78.self_attn.out_proj.bias", + "decoder.layers.78.self_attn_layer_norm.weight", + "decoder.layers.78.self_attn_layer_norm.bias", + "decoder.layers.78.fc1.weight", + "decoder.layers.78.fc1.bias", + "decoder.layers.78.fc2.weight", + "decoder.layers.78.fc2.bias", + "decoder.layers.78.final_layer_norm.weight", + "decoder.layers.78.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.79.flat_param_0": { + "names": [ + "decoder.layers.79.self_attn.qkv_proj.weight", + "decoder.layers.79.self_attn.qkv_proj.bias", + 
"decoder.layers.79.self_attn.out_proj.weight", + "decoder.layers.79.self_attn.out_proj.bias", + "decoder.layers.79.self_attn_layer_norm.weight", + "decoder.layers.79.self_attn_layer_norm.bias", + "decoder.layers.79.fc1.weight", + "decoder.layers.79.fc1.bias", + "decoder.layers.79.fc2.weight", + "decoder.layers.79.fc2.bias", + "decoder.layers.79.final_layer_norm.weight", + "decoder.layers.79.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.80.flat_param_0": { + "names": [ + "decoder.layers.80.self_attn.qkv_proj.weight", + "decoder.layers.80.self_attn.qkv_proj.bias", + "decoder.layers.80.self_attn.out_proj.weight", + "decoder.layers.80.self_attn.out_proj.bias", + "decoder.layers.80.self_attn_layer_norm.weight", + "decoder.layers.80.self_attn_layer_norm.bias", + "decoder.layers.80.fc1.weight", + "decoder.layers.80.fc1.bias", + "decoder.layers.80.fc2.weight", + "decoder.layers.80.fc2.bias", + "decoder.layers.80.final_layer_norm.weight", + "decoder.layers.80.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.81.flat_param_0": { + "names": [ + "decoder.layers.81.self_attn.qkv_proj.weight", + "decoder.layers.81.self_attn.qkv_proj.bias", + "decoder.layers.81.self_attn.out_proj.weight", + "decoder.layers.81.self_attn.out_proj.bias", + "decoder.layers.81.self_attn_layer_norm.weight", + "decoder.layers.81.self_attn_layer_norm.bias", + "decoder.layers.81.fc1.weight", + "decoder.layers.81.fc1.bias", + "decoder.layers.81.fc2.weight", + "decoder.layers.81.fc2.bias", + "decoder.layers.81.final_layer_norm.weight", + "decoder.layers.81.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.82.flat_param_0": { + "names": [ + "decoder.layers.82.self_attn.qkv_proj.weight", + "decoder.layers.82.self_attn.qkv_proj.bias", + "decoder.layers.82.self_attn.out_proj.weight", + "decoder.layers.82.self_attn.out_proj.bias", + "decoder.layers.82.self_attn_layer_norm.weight", + "decoder.layers.82.self_attn_layer_norm.bias", + "decoder.layers.82.fc1.weight", + "decoder.layers.82.fc1.bias", + "decoder.layers.82.fc2.weight", + "decoder.layers.82.fc2.bias", + "decoder.layers.82.final_layer_norm.weight", + "decoder.layers.82.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 
+ ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.83.flat_param_0": { + "names": [ + "decoder.layers.83.self_attn.qkv_proj.weight", + "decoder.layers.83.self_attn.qkv_proj.bias", + "decoder.layers.83.self_attn.out_proj.weight", + "decoder.layers.83.self_attn.out_proj.bias", + "decoder.layers.83.self_attn_layer_norm.weight", + "decoder.layers.83.self_attn_layer_norm.bias", + "decoder.layers.83.fc1.weight", + "decoder.layers.83.fc1.bias", + "decoder.layers.83.fc2.weight", + "decoder.layers.83.fc2.bias", + "decoder.layers.83.final_layer_norm.weight", + "decoder.layers.83.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.84.flat_param_0": { + "names": [ + "decoder.layers.84.self_attn.qkv_proj.weight", + "decoder.layers.84.self_attn.qkv_proj.bias", + "decoder.layers.84.self_attn.out_proj.weight", + "decoder.layers.84.self_attn.out_proj.bias", + "decoder.layers.84.self_attn_layer_norm.weight", + "decoder.layers.84.self_attn_layer_norm.bias", + "decoder.layers.84.fc1.weight", + "decoder.layers.84.fc1.bias", + "decoder.layers.84.fc2.weight", + "decoder.layers.84.fc2.bias", + "decoder.layers.84.final_layer_norm.weight", + "decoder.layers.84.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.85.flat_param_0": { + "names": [ + "decoder.layers.85.self_attn.qkv_proj.weight", + "decoder.layers.85.self_attn.qkv_proj.bias", + "decoder.layers.85.self_attn.out_proj.weight", + "decoder.layers.85.self_attn.out_proj.bias", + "decoder.layers.85.self_attn_layer_norm.weight", + "decoder.layers.85.self_attn_layer_norm.bias", + "decoder.layers.85.fc1.weight", + "decoder.layers.85.fc1.bias", + "decoder.layers.85.fc2.weight", + "decoder.layers.85.fc2.bias", + "decoder.layers.85.final_layer_norm.weight", + "decoder.layers.85.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.86.flat_param_0": { + "names": [ + "decoder.layers.86.self_attn.qkv_proj.weight", + "decoder.layers.86.self_attn.qkv_proj.bias", + "decoder.layers.86.self_attn.out_proj.weight", + "decoder.layers.86.self_attn.out_proj.bias", + "decoder.layers.86.self_attn_layer_norm.weight", + "decoder.layers.86.self_attn_layer_norm.bias", + "decoder.layers.86.fc1.weight", + "decoder.layers.86.fc1.bias", + "decoder.layers.86.fc2.weight", + "decoder.layers.86.fc2.bias", + 
"decoder.layers.86.final_layer_norm.weight", + "decoder.layers.86.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.87.flat_param_0": { + "names": [ + "decoder.layers.87.self_attn.qkv_proj.weight", + "decoder.layers.87.self_attn.qkv_proj.bias", + "decoder.layers.87.self_attn.out_proj.weight", + "decoder.layers.87.self_attn.out_proj.bias", + "decoder.layers.87.self_attn_layer_norm.weight", + "decoder.layers.87.self_attn_layer_norm.bias", + "decoder.layers.87.fc1.weight", + "decoder.layers.87.fc1.bias", + "decoder.layers.87.fc2.weight", + "decoder.layers.87.fc2.bias", + "decoder.layers.87.final_layer_norm.weight", + "decoder.layers.87.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.88.flat_param_0": { + "names": [ + "decoder.layers.88.self_attn.qkv_proj.weight", + "decoder.layers.88.self_attn.qkv_proj.bias", + "decoder.layers.88.self_attn.out_proj.weight", + "decoder.layers.88.self_attn.out_proj.bias", + "decoder.layers.88.self_attn_layer_norm.weight", + "decoder.layers.88.self_attn_layer_norm.bias", + "decoder.layers.88.fc1.weight", + "decoder.layers.88.fc1.bias", + "decoder.layers.88.fc2.weight", + "decoder.layers.88.fc2.bias", + "decoder.layers.88.final_layer_norm.weight", + "decoder.layers.88.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.89.flat_param_0": { + "names": [ + "decoder.layers.89.self_attn.qkv_proj.weight", + "decoder.layers.89.self_attn.qkv_proj.bias", + "decoder.layers.89.self_attn.out_proj.weight", + "decoder.layers.89.self_attn.out_proj.bias", + "decoder.layers.89.self_attn_layer_norm.weight", + "decoder.layers.89.self_attn_layer_norm.bias", + "decoder.layers.89.fc1.weight", + "decoder.layers.89.fc1.bias", + "decoder.layers.89.fc2.weight", + "decoder.layers.89.fc2.bias", + "decoder.layers.89.final_layer_norm.weight", + "decoder.layers.89.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.90.flat_param_0": { + "names": [ + "decoder.layers.90.self_attn.qkv_proj.weight", + "decoder.layers.90.self_attn.qkv_proj.bias", + 
"decoder.layers.90.self_attn.out_proj.weight", + "decoder.layers.90.self_attn.out_proj.bias", + "decoder.layers.90.self_attn_layer_norm.weight", + "decoder.layers.90.self_attn_layer_norm.bias", + "decoder.layers.90.fc1.weight", + "decoder.layers.90.fc1.bias", + "decoder.layers.90.fc2.weight", + "decoder.layers.90.fc2.bias", + "decoder.layers.90.final_layer_norm.weight", + "decoder.layers.90.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.91.flat_param_0": { + "names": [ + "decoder.layers.91.self_attn.qkv_proj.weight", + "decoder.layers.91.self_attn.qkv_proj.bias", + "decoder.layers.91.self_attn.out_proj.weight", + "decoder.layers.91.self_attn.out_proj.bias", + "decoder.layers.91.self_attn_layer_norm.weight", + "decoder.layers.91.self_attn_layer_norm.bias", + "decoder.layers.91.fc1.weight", + "decoder.layers.91.fc1.bias", + "decoder.layers.91.fc2.weight", + "decoder.layers.91.fc2.bias", + "decoder.layers.91.final_layer_norm.weight", + "decoder.layers.91.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.92.flat_param_0": { + "names": [ + "decoder.layers.92.self_attn.qkv_proj.weight", + "decoder.layers.92.self_attn.qkv_proj.bias", + "decoder.layers.92.self_attn.out_proj.weight", + "decoder.layers.92.self_attn.out_proj.bias", + "decoder.layers.92.self_attn_layer_norm.weight", + "decoder.layers.92.self_attn_layer_norm.bias", + "decoder.layers.92.fc1.weight", + "decoder.layers.92.fc1.bias", + "decoder.layers.92.fc2.weight", + "decoder.layers.92.fc2.bias", + "decoder.layers.92.final_layer_norm.weight", + "decoder.layers.92.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.93.flat_param_0": { + "names": [ + "decoder.layers.93.self_attn.qkv_proj.weight", + "decoder.layers.93.self_attn.qkv_proj.bias", + "decoder.layers.93.self_attn.out_proj.weight", + "decoder.layers.93.self_attn.out_proj.bias", + "decoder.layers.93.self_attn_layer_norm.weight", + "decoder.layers.93.self_attn_layer_norm.bias", + "decoder.layers.93.fc1.weight", + "decoder.layers.93.fc1.bias", + "decoder.layers.93.fc2.weight", + "decoder.layers.93.fc2.bias", + "decoder.layers.93.final_layer_norm.weight", + "decoder.layers.93.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 
+ ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.94.flat_param_0": { + "names": [ + "decoder.layers.94.self_attn.qkv_proj.weight", + "decoder.layers.94.self_attn.qkv_proj.bias", + "decoder.layers.94.self_attn.out_proj.weight", + "decoder.layers.94.self_attn.out_proj.bias", + "decoder.layers.94.self_attn_layer_norm.weight", + "decoder.layers.94.self_attn_layer_norm.bias", + "decoder.layers.94.fc1.weight", + "decoder.layers.94.fc1.bias", + "decoder.layers.94.fc2.weight", + "decoder.layers.94.fc2.bias", + "decoder.layers.94.final_layer_norm.weight", + "decoder.layers.94.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.95.flat_param_0": { + "names": [ + "decoder.layers.95.self_attn.qkv_proj.weight", + "decoder.layers.95.self_attn.qkv_proj.bias", + "decoder.layers.95.self_attn.out_proj.weight", + "decoder.layers.95.self_attn.out_proj.bias", + "decoder.layers.95.self_attn_layer_norm.weight", + "decoder.layers.95.self_attn_layer_norm.bias", + "decoder.layers.95.fc1.weight", + "decoder.layers.95.fc1.bias", + "decoder.layers.95.fc2.weight", + "decoder.layers.95.fc2.bias", + "decoder.layers.95.final_layer_norm.weight", + "decoder.layers.95.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + } +} diff --git a/examples/tutorial/opt/inference/script/processing_ckpt_66b.py b/examples/tutorial/opt/inference/script/processing_ckpt_66b.py index 0494647d7bcc..576daacdb471 100644 --- a/examples/tutorial/opt/inference/script/processing_ckpt_66b.py +++ b/examples/tutorial/opt/inference/script/processing_ckpt_66b.py @@ -1,7 +1,8 @@ import os -import torch from multiprocessing import Pool +import torch + # download pytorch model ckpt in https://huggingface.co/facebook/opt-66b/tree/main # you can use whether wget or git lfs @@ -20,14 +21,14 @@ restored = {} for ckpt in ckpts: - for k,v in ckpt.items(): - if(k[0] == 'm'): - k = k[6:] - if(k == "lm_head.weight"): + for k, v in ckpt.items(): + if k[0] == "m": + k = k[6:] + if k == "lm_head.weight": k = "head.dense.weight" - if(k == "decoder.final_layer_norm.weight"): + if k == "decoder.final_layer_norm.weight": k = "decoder.layer_norm.weight" - if(k == "decoder.final_layer_norm.bias"): + if k == "decoder.final_layer_norm.bias": k = "decoder.layer_norm.bias" restored[k] = v restored["decoder.version"] = "0.0" @@ -37,11 +38,11 @@ count = 0 file_count = 1 tmp = {} -for k,v in restored.items(): +for k, v in restored.items(): print(k) tmp[k] = v - count = count + 1 - if(count == split_num): + count = count + 1 + if count == split_num: filename = str(file_count) + "-restored.pt" torch.save(tmp, os.path.join(new_path, filename)) file_count = file_count + 1 @@ -50,6 +51,3 @@ filename = str(file_count) + "-restored.pt" 
torch.save(tmp, os.path.join(new_path, filename)) - - - diff --git a/examples/tutorial/opt/opt/colossalai_zero.py b/examples/tutorial/opt/opt/colossalai_zero.py index 8fbed6e83d52..75516bba560f 100644 --- a/examples/tutorial/opt/opt/colossalai_zero.py +++ b/examples/tutorial/opt/opt/colossalai_zero.py @@ -4,7 +4,7 @@ # colossalai > 0.2.8 from colossalai.legacy.zero import TensorShardStrategy -zero = dict(model_config=dict(shard_strategy=TensorShardStrategy(), - tensor_placement_policy="auto", - reuse_fp16_shard=True), - optimizer_config=dict(gpu_margin_mem_ratio=0.8, initial_scale=16384)) +zero = dict( + model_config=dict(shard_strategy=TensorShardStrategy(), tensor_placement_policy="auto", reuse_fp16_shard=True), + optimizer_config=dict(gpu_margin_mem_ratio=0.8, initial_scale=16384), +) diff --git a/examples/tutorial/opt/opt/context.py b/examples/tutorial/opt/opt/context.py index dfcd3b382d3c..7172408f3cbc 100644 --- a/examples/tutorial/opt/opt/context.py +++ b/examples/tutorial/opt/opt/context.py @@ -4,7 +4,7 @@ from colossalai.legacy.core import global_context as gpc -class barrier_context(): +class barrier_context: """ This context manager is used to allow one process to execute while blocking all other processes in the same process group. This is often useful when downloading is required diff --git a/examples/tutorial/opt/opt/run_clm.py b/examples/tutorial/opt/opt/run_clm.py index 8cbf3d2a2850..9bd23ffc8aba 100755 --- a/examples/tutorial/opt/opt/run_clm.py +++ b/examples/tutorial/opt/opt/run_clm.py @@ -86,14 +86,12 @@ def parse_args(): default=None, help="The configuration name of the dataset to use (via the datasets library).", ) - parser.add_argument("--train_file", - type=str, - default=None, - help="A csv or a json file containing the training data.") - parser.add_argument("--validation_file", - type=str, - default=None, - help="A csv or a json file containing the validation data.") + parser.add_argument( + "--train_file", type=str, default=None, help="A csv or a json file containing the training data." + ) + parser.add_argument( + "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data." + ) parser.add_argument( "--validation_split_percentage", default=5, @@ -161,10 +159,9 @@ def parse_args(): help="The scheduler type to use.", choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], ) - parser.add_argument("--num_warmup_steps", - type=int, - default=0, - help="Number of steps for the warmup in the lr scheduler.") + parser.add_argument( + "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." + ) parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") parser.add_argument( @@ -178,9 +175,11 @@ def parse_args(): "--block_size", type=int, default=None, - help=("Optional input sequence length after tokenization. The training dataset will be truncated in block of" - " this size for training. Default to the model max input length for single sentence inputs (take into" - " account special tokens)."), + help=( + "Optional input sequence length after tokenization. The training dataset will be truncated in block of" + " this size for training. Default to the model max input length for single sentence inputs (take into" + " account special tokens)." 
+ ), ) parser.add_argument( "--preprocessing_num_workers", @@ -188,17 +187,16 @@ def parse_args(): default=None, help="The number of processes to use for the preprocessing.", ) - parser.add_argument("--overwrite_cache", - type=bool, - default=False, - help="Overwrite the cached training and evaluation sets") - parser.add_argument("--no_keep_linebreaks", - action="store_true", - help="Do not keep line breaks when using TXT files.") + parser.add_argument( + "--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets" + ) + parser.add_argument( + "--no_keep_linebreaks", action="store_true", help="Do not keep line breaks when using TXT files." + ) parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") - parser.add_argument("--hub_model_id", - type=str, - help="The name of the repository to keep in sync with the local `output_dir`.") + parser.add_argument( + "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`." + ) parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.") parser.add_argument( "--checkpointing_steps", @@ -221,13 +219,15 @@ def parse_args(): "--report_to", type=str, default="all", - help=('The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' - ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.' - "Only applicable when `--with_tracking` is passed."), + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' + ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.' + "Only applicable when `--with_tracking` is passed." 
+ ), ) parser.add_argument("--mem_cap", type=int, default=0, help="use mem cap") - parser.add_argument("--init_in_cpu", action='store_true', default=False, help="init training model in cpu") + parser.add_argument("--init_in_cpu", action="store_true", default=False, help="init training model in cpu") args = parser.parse_args() # Sanity checks @@ -250,6 +250,7 @@ def parse_args(): def colo_memory_cap(size_in_GB): from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device + cuda_capacity = colo_device_memory_capacity(get_current_device()) if size_in_GB * (1024**3) < cuda_capacity: colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity) @@ -257,7 +258,6 @@ def colo_memory_cap(size_in_GB): class DummyDataloader: - def __init__(self, length, batch_size, seq_len, vocab_size): self.length = length self.batch_size = batch_size @@ -380,32 +380,36 @@ def main(): logger.warning("You are instantiating a new config instance from scratch.") logger.info("Model config has been created", ranks=[0]) - if args.model_name_or_path == 'facebook/opt-13b': + if args.model_name_or_path == "facebook/opt-13b": tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path) else: - print(f'load model from {args.model_name_or_path}') + print(f"load model from {args.model_name_or_path}") tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer) logger.info(f"{tokenizer.__class__.__name__} has been created", ranks=[0]) if args.init_in_cpu: - init_dev = torch.device('cpu') + init_dev = torch.device("cpu") else: init_dev = get_current_device() cai_version = colossalai.__version__ - logger.info(f'using Colossal-AI version {cai_version}') + logger.info(f"using Colossal-AI version {cai_version}") # build model if version.parse(cai_version) >= version.parse("0.3.1"): from contextlib import nullcontext from colossalai.lazy import LazyInitContext - ctx = LazyInitContext( - default_device=init_dev - ) if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b' else nullcontext() + + ctx = ( + LazyInitContext(default_device=init_dev) + if args.model_name_or_path is None or args.model_name_or_path == "facebook/opt-13b" + else nullcontext() + ) else: from colossalai.zero import ColoInitContext + ctx = ColoInitContext(device=init_dev) - if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b': + if args.model_name_or_path is None or args.model_name_or_path == "facebook/opt-13b": # currently, there has a bug in pretrained opt-13b # we can not import it until huggingface fix it logger.info("Train a new model from scratch", ranks=[0]) @@ -414,17 +418,20 @@ def main(): else: logger.info("Finetune a pre-trained model", ranks=[0]) with ctx: - model = OPTForCausalLM.from_pretrained(args.model_name_or_path, - from_tf=bool(".ckpt" in args.model_name_or_path), - config=config, - local_files_only=False) + model = OPTForCausalLM.from_pretrained( + args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + local_files_only=False, + ) # enable graident checkpointing model.gradient_checkpointing_enable() - PLACEMENT_POLICY = 'auto' + PLACEMENT_POLICY = "auto" if version.parse(cai_version) >= version.parse("0.3.1"): from colossalai.zero import GeminiDDP + model = GeminiDDP(model, offload_optim_frac=1.0, pin_memory=True) elif version.parse(cai_version) > version.parse("0.1.10"): try: @@ -435,16 +442,19 @@ def main(): model = GeminiDDP(model, 
device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True) elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"): from colossalai.gemini import ChunkManager, GeminiManager + pg = ProcessGroup() chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32) - chunk_manager = ChunkManager(chunk_size, - pg, - enable_distributed_storage=True, - init_device=GeminiManager.get_default_device(PLACEMENT_POLICY)) + chunk_manager = ChunkManager( + chunk_size, + pg, + enable_distributed_storage=True, + init_device=GeminiManager.get_default_device(PLACEMENT_POLICY), + ) gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager) model = ZeroDDP(model, gemini_manager) - logger.info(f'{model.__class__.__name__} has been created', ranks=[0]) + logger.info(f"{model.__class__.__name__} has been created", ranks=[0]) if not args.synthetic: # Preprocessing the datasets. @@ -470,12 +480,15 @@ def tokenize_function(examples): if block_size > 1024: logger.warning( f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " - "Picking 1024 instead. You can change that default value by passing --block_size xxx.") + "Picking 1024 instead. You can change that default value by passing --block_size xxx." + ) block_size = 1024 else: if args.block_size > tokenizer.model_max_length: - logger.warning(f"The block_size passed ({args.block_size}) is larger than the maximum length for the model" - f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}.") + logger.warning( + f"The block_size passed ({args.block_size}) is larger than the maximum length for the model" + f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." + ) block_size = min(args.block_size, tokenizer.model_max_length) # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. @@ -489,8 +502,8 @@ def group_texts(examples): total_length = (total_length // block_size) * block_size # Split by chunks of max_len. 
result = { - k: [t[i:i + block_size] for i in range(0, total_length, block_size) - ] for k, t in concatenated_examples.items() + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() } result["labels"] = result["input_ids"].copy() return result @@ -520,19 +533,23 @@ def group_texts(examples): # logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") # DataLoaders creation: - train_dataloader = get_dataloader(train_dataset, - shuffle=True, - add_sampler=True, - collate_fn=default_data_collator, - batch_size=args.per_device_train_batch_size) - eval_dataloader = DataLoader(eval_dataset, - collate_fn=default_data_collator, - batch_size=args.per_device_eval_batch_size) + train_dataloader = get_dataloader( + train_dataset, + shuffle=True, + add_sampler=True, + collate_fn=default_data_collator, + batch_size=args.per_device_train_batch_size, + ) + eval_dataloader = DataLoader( + eval_dataset, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size + ) else: - train_dataloader = DummyDataloader(30, args.per_device_train_batch_size, config.max_position_embeddings, - config.vocab_size) - eval_dataloader = DummyDataloader(10, args.per_device_train_batch_size, config.max_position_embeddings, - config.vocab_size) + train_dataloader = DummyDataloader( + 30, args.per_device_train_batch_size, config.max_position_embeddings, config.vocab_size + ) + eval_dataloader = DummyDataloader( + 10, args.per_device_train_batch_size, config.max_position_embeddings, config.vocab_size + ) logger.info("Dataloaders have been created", ranks=[0]) # Optimizer @@ -593,7 +610,6 @@ def group_texts(examples): global_step = 0 for epoch in range(starting_epoch, args.num_train_epochs): - if completed_steps >= args.max_train_steps: break @@ -601,7 +617,7 @@ def group_texts(examples): for step, batch in enumerate(train_dataloader): batch = {k: v.cuda() for k, v in batch.items()} outputs = model(use_cache=False, **batch) - loss = outputs['loss'] + loss = outputs["loss"] optimizer.backward(loss) if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: @@ -624,7 +640,7 @@ def group_texts(examples): batch = {k: v.cuda() for k, v in batch.items()} outputs = model(**batch) - loss = outputs['loss'].unsqueeze(0) + loss = outputs["loss"].unsqueeze(0) losses.append(loss) losses = torch.cat(losses) @@ -640,7 +656,7 @@ def group_texts(examples): if args.output_dir is not None: model_state = model.state_dict() if is_main_process: - torch.save(model_state, args.output_dir + '/epoch_{}_model.pth'.format(completed_steps)) + torch.save(model_state, args.output_dir + "/epoch_{}_model.pth".format(completed_steps)) dist.barrier() # load_state = torch.load(args.output_dir + '/epoch_{}_model.pth'.format(completed_steps)) # model.load_state_dict(load_state, strict=False) diff --git a/examples/tutorial/sequence_parallel/config.py b/examples/tutorial/sequence_parallel/config.py index 887de7164e12..859f6e25e845 100644 --- a/examples/tutorial/sequence_parallel/config.py +++ b/examples/tutorial/sequence_parallel/config.py @@ -4,7 +4,7 @@ TRAIN_ITERS = 10 DECAY_ITERS = 4 WARMUP_FRACTION = 0.01 -GLOBAL_BATCH_SIZE = 32 # dp world size * sentences per GPU +GLOBAL_BATCH_SIZE = 32 # dp world size * sentences per GPU EVAL_ITERS = 10 EVAL_INTERVAL = 10 LR = 0.0001 @@ -28,8 +28,8 @@ NUM_MICRO_BATCHES = 4 # colossalai config -parallel = dict(pipeline=1, tensor=dict(size=2, mode='sequence')) +parallel = dict(pipeline=1, tensor=dict(size=2, 
mode="sequence")) fp16 = dict(mode=AMP_TYPE.NAIVE, verbose=True) -gradient_handler = [dict(type='SequenceParallelGradientHandler')] +gradient_handler = [dict(type="SequenceParallelGradientHandler")] diff --git a/examples/tutorial/sequence_parallel/data/__init__.py b/examples/tutorial/sequence_parallel/data/__init__.py index 6fdf07ba5b69..137f3cf0267b 100644 --- a/examples/tutorial/sequence_parallel/data/__init__.py +++ b/examples/tutorial/sequence_parallel/data/__init__.py @@ -15,16 +15,13 @@ def cyclic_iter(iter): yield x -def build_train_valid_test_data_iterators(train_iters, - global_batch_size, - eval_interval, - eval_iters, - dataloader_type='single', - **kwargs): +def build_train_valid_test_data_iterators( + train_iters, global_batch_size, eval_interval, eval_iters, dataloader_type="single", **kwargs +): (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None) logger = get_dist_logger() - logger.info('> building train, validation, and test datasets ...', ranks=[0]) + logger.info("> building train, validation, and test datasets ...", ranks=[0]) # Backward compatibility, assume fixed batch size. # if iteration > 0 and consumed_train_samples == 0: @@ -38,29 +35,29 @@ def build_train_valid_test_data_iterators(train_iters, # Data loader only on rank 0 of each model parallel group. if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # Number of train/valid/test samples. train_samples = train_iters * global_batch_size eval_iters_ = (train_iters // eval_interval + 1) * eval_iters test_iters = eval_iters train_val_test_num_samples = [train_samples, eval_iters_ * global_batch_size, test_iters * global_batch_size] - logger.info(' > datasets target sizes (minimum size):') - logger.info(' train: {}'.format(train_val_test_num_samples[0]), ranks=[0]) - logger.info(' validation: {}'.format(train_val_test_num_samples[1]), ranks=[0]) - logger.info(' test: {}'.format(train_val_test_num_samples[2]), ranks=[0]) + logger.info(" > datasets target sizes (minimum size):") + logger.info(" train: {}".format(train_val_test_num_samples[0]), ranks=[0]) + logger.info(" validation: {}".format(train_val_test_num_samples[1]), ranks=[0]) + logger.info(" test: {}".format(train_val_test_num_samples[2]), ranks=[0]) # Build the datasets. train_ds, valid_ds, test_ds = build_train_valid_test_datasets( - train_valid_test_num_samples=train_val_test_num_samples, **kwargs) + train_valid_test_num_samples=train_val_test_num_samples, **kwargs + ) # Build dataloaders. dp_size = gpc.get_world_size(ParallelMode.DATA) - train_dataloader = build_pretraining_data_loader(train_ds, - consumed_samples=0, - micro_batch_size=global_batch_size // dp_size) - valid_dataloader = build_pretraining_data_loader(valid_ds, - consumed_samples=0, - micro_batch_size=global_batch_size // dp_size) + train_dataloader = build_pretraining_data_loader( + train_ds, consumed_samples=0, micro_batch_size=global_batch_size // dp_size + ) + valid_dataloader = build_pretraining_data_loader( + valid_ds, consumed_samples=0, micro_batch_size=global_batch_size // dp_size + ) test_dataloader = build_pretraining_data_loader(test_ds, 0, micro_batch_size=global_batch_size // dp_size) # Flags to know if we need to do training/validation/testing. @@ -73,29 +70,26 @@ def build_train_valid_test_data_iterators(train_iters, flags = torch.cuda.LongTensor([0, 0, 0]) # Broadcast num tokens. 
- torch.distributed.broadcast(flags, - gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], - group=gpc.get_group(ParallelMode.TENSOR)) + torch.distributed.broadcast( + flags, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], group=gpc.get_group(ParallelMode.TENSOR) + ) # Build iterators. dl_type = dataloader_type - assert dl_type in ['single', 'cyclic'] + assert dl_type in ["single", "cyclic"] if train_dataloader is not None: - train_data_iterator = iter(train_dataloader) if dl_type == 'single' \ - else iter(cyclic_iter(train_dataloader)) + train_data_iterator = iter(train_dataloader) if dl_type == "single" else iter(cyclic_iter(train_dataloader)) else: train_data_iterator = None if valid_dataloader is not None: - valid_data_iterator = iter(valid_dataloader) if dl_type == 'single' \ - else iter(cyclic_iter(valid_dataloader)) + valid_data_iterator = iter(valid_dataloader) if dl_type == "single" else iter(cyclic_iter(valid_dataloader)) else: valid_data_iterator = None if test_dataloader is not None: - test_data_iterator = iter(test_dataloader) if dl_type == 'single' \ - else iter(cyclic_iter(test_dataloader)) + test_data_iterator = iter(test_dataloader) if dl_type == "single" else iter(cyclic_iter(test_dataloader)) else: test_data_iterator = None diff --git a/examples/tutorial/sequence_parallel/data/bert_helper.py b/examples/tutorial/sequence_parallel/data/bert_helper.py index b65ca1e64f3c..471be19bb123 100644 --- a/examples/tutorial/sequence_parallel/data/bert_helper.py +++ b/examples/tutorial/sequence_parallel/data/bert_helper.py @@ -15,7 +15,7 @@ def _build_key_size_numel_dictionaries(keys, data): if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0: offset = 0 for key in keys: - assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM' + assert data[key].dim() < max_dim, "you should increase MAX_DATA_DIM" size = data[key].size() for i, s in enumerate(size): sizes[i + offset] = s @@ -23,9 +23,9 @@ def _build_key_size_numel_dictionaries(keys, data): # Move to GPU and broadcast. sizes_cuda = torch.cuda.LongTensor(sizes) - torch.distributed.broadcast(sizes_cuda, - gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], - group=gpc.get_group(ParallelMode.TENSOR)) + torch.distributed.broadcast( + sizes_cuda, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], group=gpc.get_group(ParallelMode.TENSOR) + ) # Move back to cpu and unpack. sizes_cpu = sizes_cuda.cpu() @@ -73,9 +73,9 @@ def broadcast_data(keys, data, datatype): flatten_data = torch.empty(total_numel, device=torch.cuda.current_device(), dtype=datatype) # Broadcast - torch.distributed.broadcast(flatten_data, - gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], - group=gpc.get_group(ParallelMode.TENSOR)) + torch.distributed.broadcast( + flatten_data, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], group=gpc.get_group(ParallelMode.TENSOR) + ) # Unpack output = {} @@ -93,7 +93,7 @@ def get_batch(data_iterator): """Build the batch.""" # Items and their type. - keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask'] + keys = ["text", "types", "labels", "is_random", "loss_mask", "padding_mask"] datatype = torch.int64 # Broadcast data. @@ -104,12 +104,12 @@ def get_batch(data_iterator): data_b = broadcast_data(keys, data, datatype) # Unpack. 
- tokens = data_b['text'].long() - types = data_b['types'].long() - sentence_order = data_b['is_random'].long() - loss_mask = data_b['loss_mask'].float() - lm_labels = data_b['labels'].long() - padding_mask = data_b['padding_mask'].long() + tokens = data_b["text"].long() + types = data_b["types"].long() + sentence_order = data_b["is_random"].long() + loss_mask = data_b["loss_mask"].float() + lm_labels = data_b["labels"].long() + padding_mask = data_b["padding_mask"].long() return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask @@ -118,7 +118,7 @@ def get_batch_for_sequence_parallel(data_iterator): """Build the batch.""" # Items and their type. - keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask'] + keys = ["text", "types", "labels", "is_random", "loss_mask", "padding_mask"] datatype = torch.int64 # Broadcast data. @@ -134,24 +134,23 @@ def get_batch_for_sequence_parallel(data_iterator): global_rank = torch.distributed.get_rank() local_world_size = 1 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_world_size(ParallelMode.TENSOR) local_rank = global_rank % local_world_size - seq_length = data_b['text'].size(1) + seq_length = data_b["text"].size(1) sub_seq_length = seq_length // local_world_size sub_seq_start = local_rank * sub_seq_length sub_seq_end = (local_rank + 1) * sub_seq_length # # # Unpack. - tokens = data_b['text'][:, sub_seq_start:sub_seq_end].long() - types = data_b['types'][:, sub_seq_start:sub_seq_end].long() - sentence_order = data_b['is_random'].long() - loss_mask = data_b['loss_mask'][:, sub_seq_start:sub_seq_end].float() - lm_labels = data_b['labels'][:, sub_seq_start:sub_seq_end].long() - padding_mask = data_b['padding_mask'].long() + tokens = data_b["text"][:, sub_seq_start:sub_seq_end].long() + types = data_b["types"][:, sub_seq_start:sub_seq_end].long() + sentence_order = data_b["is_random"].long() + loss_mask = data_b["loss_mask"][:, sub_seq_start:sub_seq_end].float() + lm_labels = data_b["labels"][:, sub_seq_start:sub_seq_end].long() + padding_mask = data_b["padding_mask"].long() return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask class SequenceParallelDataIterator: - def __init__(self, data_iter): self.data_iter = data_iter diff --git a/examples/tutorial/sequence_parallel/data/datasets/bert_dataset.py b/examples/tutorial/sequence_parallel/data/datasets/bert_dataset.py index 70c1269122dc..afab202e0927 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/bert_dataset.py +++ b/examples/tutorial/sequence_parallel/data/datasets/bert_dataset.py @@ -41,10 +41,19 @@ class BertDataset(Dataset): - - def __init__(self, name, indexed_dataset, data_prefix, num_epochs, max_num_samples, masked_lm_prob, max_seq_length, - short_seq_prob, seed, binary_head): - + def __init__( + self, + name, + indexed_dataset, + data_prefix, + num_epochs, + max_num_samples, + masked_lm_prob, + max_seq_length, + short_seq_prob, + seed, + binary_head, + ): # Params to store. self.name = name self.seed = seed @@ -61,11 +70,12 @@ def __init__(self, name, indexed_dataset, data_prefix, num_epochs, max_num_sampl data_prefix, num_epochs, max_num_samples, - self.max_seq_length - 3, # account for added tokens, + self.max_seq_length - 3, # account for added tokens, short_seq_prob, self.seed, self.name, - self.binary_head) + self.binary_head, + ) # Vocab stuff. 
tokenizer = get_tokenizer() @@ -89,7 +99,7 @@ def __getitem__(self, idx): return build_training_sample( sample, seq_length, - self.max_seq_length, # needed for padding + self.max_seq_length, # needed for padding self.vocab_id_list, self.vocab_id_to_token_dict, self.cls_id, @@ -98,37 +108,39 @@ def __getitem__(self, idx): self.pad_id, self.masked_lm_prob, np_rng, - self.binary_head) + self.binary_head, + ) -def get_samples_mapping_(indexed_dataset, data_prefix, num_epochs, max_num_samples, max_seq_length, short_seq_prob, - seed, name, binary_head): +def get_samples_mapping_( + indexed_dataset, data_prefix, num_epochs, max_num_samples, max_seq_length, short_seq_prob, seed, name, binary_head +): logger = get_dist_logger() if not num_epochs: if not max_num_samples: - raise ValueError("Need to specify either max_num_samples " - "or num_epochs") + raise ValueError("Need to specify either max_num_samples " "or num_epochs") num_epochs = np.iinfo(np.int32).max - 1 if not max_num_samples: max_num_samples = np.iinfo(np.int64).max - 1 # Filename of the index mapping indexmap_filename = data_prefix - indexmap_filename += '_{}_indexmap'.format(name) + indexmap_filename += "_{}_indexmap".format(name) if num_epochs != (np.iinfo(np.int32).max - 1): - indexmap_filename += '_{}ep'.format(num_epochs) + indexmap_filename += "_{}ep".format(num_epochs) if max_num_samples != (np.iinfo(np.int64).max - 1): - indexmap_filename += '_{}mns'.format(max_num_samples) - indexmap_filename += '_{}msl'.format(max_seq_length) - indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob) - indexmap_filename += '_{}s'.format(seed) - indexmap_filename += '.npy' + indexmap_filename += "_{}mns".format(max_num_samples) + indexmap_filename += "_{}msl".format(max_seq_length) + indexmap_filename += "_{:0.2f}ssp".format(short_seq_prob) + indexmap_filename += "_{}s".format(seed) + indexmap_filename += ".npy" # Build the indexed mapping if not exist. - if torch.distributed.get_rank() == 0 and \ - not os.path.isfile(indexmap_filename): - print(' > WARNING: could not find index map file {}, building ' - 'the indices on rank 0 ...'.format(indexmap_filename)) + if torch.distributed.get_rank() == 0 and not os.path.isfile(indexmap_filename): + print( + " > WARNING: could not find index map file {}, building " + "the indices on rank 0 ...".format(indexmap_filename) + ) # Make sure the types match the helpers input types. assert indexed_dataset.doc_idx.dtype == np.int64 @@ -137,18 +149,27 @@ def get_samples_mapping_(indexed_dataset, data_prefix, num_epochs, max_num_sampl # Build samples mapping verbose = torch.distributed.get_rank() == 0 start_time = time.time() - logger.info('\n > building samples index mapping for {} ...'.format(name), ranks=[0]) + logger.info("\n > building samples index mapping for {} ...".format(name), ranks=[0]) # First compile and then import. 
- samples_mapping = helpers.build_mapping(indexed_dataset.doc_idx, indexed_dataset.sizes, num_epochs, - max_num_samples, max_seq_length, short_seq_prob, seed, verbose, - 2 if binary_head else 1) - logger.info('\n > done building samples index maping', ranks=[0]) + samples_mapping = helpers.build_mapping( + indexed_dataset.doc_idx, + indexed_dataset.sizes, + num_epochs, + max_num_samples, + max_seq_length, + short_seq_prob, + seed, + verbose, + 2 if binary_head else 1, + ) + logger.info("\n > done building samples index maping", ranks=[0]) np.save(indexmap_filename, samples_mapping, allow_pickle=True) - logger.info('\n > saved the index mapping in {}'.format(indexmap_filename), ranks=[0]) + logger.info("\n > saved the index mapping in {}".format(indexmap_filename), ranks=[0]) # Make sure all the ranks have built the mapping - logger.info('\n > elapsed time to build and save samples mapping ' - '(seconds): {:4f}'.format(time.time() - start_time), - ranks=[0]) + logger.info( + "\n > elapsed time to build and save samples mapping " "(seconds): {:4f}".format(time.time() - start_time), + ranks=[0], + ) # This should be a barrier but nccl barrier assumes # device_index=rank which is not the case for model # parallel case @@ -156,22 +177,38 @@ def get_samples_mapping_(indexed_dataset, data_prefix, num_epochs, max_num_sampl torch.distributed.all_reduce(counts, group=gpc.get_group(ParallelMode.DATA)) if gpc.is_initialized(ParallelMode.PIPELINE): torch.distributed.all_reduce(counts, group=gpc.get_group(ParallelMode.PIPELINE)) - assert counts[0].item() == (torch.distributed.get_world_size() // - torch.distributed.get_world_size(group=gpc.get_group(ParallelMode.SEQUENCE))) + assert counts[0].item() == ( + torch.distributed.get_world_size() + // torch.distributed.get_world_size(group=gpc.get_group(ParallelMode.SEQUENCE)) + ) # Load indexed dataset. start_time = time.time() - samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r') - logger.info('\n > loading indexed mapping from {}'.format(indexmap_filename) + - '\n loaded indexed file in {:3.3f} seconds'.format(time.time() - start_time) + - '\n total number of samples: {}'.format(samples_mapping.shape[0]), - ranks=[0]) + samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode="r") + logger.info( + "\n > loading indexed mapping from {}".format(indexmap_filename) + + "\n loaded indexed file in {:3.3f} seconds".format(time.time() - start_time) + + "\n total number of samples: {}".format(samples_mapping.shape[0]), + ranks=[0], + ) return samples_mapping -def build_training_sample(sample, target_seq_length, max_seq_length, vocab_id_list, vocab_id_to_token_dict, cls_id, - sep_id, mask_id, pad_id, masked_lm_prob, np_rng, binary_head): +def build_training_sample( + sample, + target_seq_length, + max_seq_length, + vocab_id_list, + vocab_id_to_token_dict, + cls_id, + sep_id, + mask_id, + pad_id, + masked_lm_prob, + np_rng, + binary_head, +): """Build training sample. Arguments: @@ -215,22 +252,30 @@ def build_training_sample(sample, target_seq_length, max_seq_length, vocab_id_li # Masking. 
max_predictions_per_seq = masked_lm_prob * max_num_tokens - (tokens, masked_positions, masked_labels, - _) = create_masked_lm_predictions(tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob, cls_id, sep_id, - mask_id, max_predictions_per_seq, np_rng) + (tokens, masked_positions, masked_labels, _) = create_masked_lm_predictions( + tokens, + vocab_id_list, + vocab_id_to_token_dict, + masked_lm_prob, + cls_id, + sep_id, + mask_id, + max_predictions_per_seq, + np_rng, + ) # Padding. - tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \ - = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, - masked_labels, pad_id, max_seq_length) + tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np = pad_and_convert_to_numpy( + tokens, tokentypes, masked_positions, masked_labels, pad_id, max_seq_length + ) train_sample = { - 'text': tokens_np, - 'types': tokentypes_np, - 'labels': labels_np, - 'is_random': int(is_next_random), - 'loss_mask': loss_mask_np, - 'padding_mask': padding_mask_np, - 'truncated': int(truncated) + "text": tokens_np, + "types": tokentypes_np, + "labels": labels_np, + "is_random": int(is_next_random), + "loss_mask": loss_mask_np, + "padding_mask": padding_mask_np, + "truncated": int(truncated), } return train_sample diff --git a/examples/tutorial/sequence_parallel/data/datasets/blendable_dataset.py b/examples/tutorial/sequence_parallel/data/datasets/blendable_dataset.py index 6a06c869d8c8..1fa9c85fce0a 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/blendable_dataset.py +++ b/examples/tutorial/sequence_parallel/data/datasets/blendable_dataset.py @@ -22,9 +22,7 @@ class BlendableDataset(torch.utils.data.Dataset): - def __init__(self, datasets, weights): - self.datasets = datasets num_datasets = len(datasets) assert num_datasets == len(weights) @@ -46,12 +44,16 @@ def __init__(self, datasets, weights): self.dataset_sample_index = np.zeros(self.size, dtype=np.int64) from . 
import helpers - helpers.build_blending_indices(self.dataset_index, - self.dataset_sample_index, - weights, num_datasets, self.size, - torch.distributed.get_rank() == 0) - print('> elapsed time for building blendable dataset indices: ' - '{:.2f} (sec)'.format(time.time() - start_time)) + + helpers.build_blending_indices( + self.dataset_index, + self.dataset_sample_index, + weights, + num_datasets, + self.size, + torch.distributed.get_rank() == 0, + ) + print("> elapsed time for building blendable dataset indices: " "{:.2f} (sec)".format(time.time() - start_time)) def __len__(self): return self.size diff --git a/examples/tutorial/sequence_parallel/data/datasets/builder.py b/examples/tutorial/sequence_parallel/data/datasets/builder.py index 6106f833b462..edf4c3d70cbf 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/builder.py +++ b/examples/tutorial/sequence_parallel/data/datasets/builder.py @@ -1,29 +1,34 @@ +from colossalai.logging import get_dist_logger + +from .bert_dataset import BertDataset from .blendable_dataset import BlendableDataset from .dataset_utils import get_datasets_weights_and_num_samples, get_indexed_dataset_, get_train_valid_test_split_ -from .bert_dataset import BertDataset -from colossalai.logging import get_dist_logger -DSET_TYPE_BERT = 'standard_bert' -DSET_TYPE_ICT = 'ict' -DSET_TYPE_T5 = 't5' +DSET_TYPE_BERT = "standard_bert" +DSET_TYPE_ICT = "ict" +DSET_TYPE_T5 = "t5" DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5] -def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, - train_valid_test_num_samples, - max_seq_length, masked_lm_prob, - short_seq_prob, seed, skip_warmup, - binary_head, - dataset_type='standard_bert'): - +def _build_train_valid_test_datasets( + data_prefix, + data_impl, + splits_string, + train_valid_test_num_samples, + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + dataset_type="standard_bert", +): if dataset_type not in DSET_TYPES: raise ValueError("Invalid dataset_type: ", dataset_type) # Indexed dataset. - indexed_dataset = get_indexed_dataset_(data_prefix, - data_impl, - skip_warmup) + indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup) # Get start and end indices of train/valid/train into doc-idx # Note that doc-idx is designed to be num-docs + 1 so we can @@ -34,22 +39,25 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, logger = get_dist_logger() # Print stats about the splits. 
- logger.info('\n > dataset split:', ranks=[0]) + logger.info("\n > dataset split:", ranks=[0]) def print_split_stats(name, index): start_index = indexed_dataset.doc_idx[splits[index]] end_index = indexed_dataset.doc_idx[splits[index + 1]] - logger.info('\n {}:'.format(name) + - '\n document indices in [{}, {}) total of {} documents'.format( - splits[index], splits[index + 1], - splits[index + 1] - splits[index]) + - '\n sentence indices in [{}, {}) total of {} sentences'.format( - start_index, end_index, - end_index - start_index), - ranks=[0]) - print_split_stats('train', 0) - print_split_stats('validation', 1) - print_split_stats('test', 2) + logger.info( + "\n {}:".format(name) + + "\n document indices in [{}, {}) total of {} documents".format( + splits[index], splits[index + 1], splits[index + 1] - splits[index] + ) + + "\n sentence indices in [{}, {}) total of {} sentences".format( + start_index, end_index, end_index - start_index + ), + ranks=[0], + ) + + print_split_stats("train", 0) + print_split_stats("validation", 1) + print_split_stats("test", 2) def build_dataset(index, name): dataset = None @@ -80,44 +88,53 @@ def build_dataset(index, name): masked_lm_prob=masked_lm_prob, short_seq_prob=short_seq_prob, binary_head=binary_head, - **kwargs + **kwargs, ) # Set the original pointer so dataset remains the main dataset. indexed_dataset.set_doc_idx(doc_idx_ptr) # Checks. assert indexed_dataset.doc_idx[0] == 0 - assert indexed_dataset.doc_idx.shape[0] == \ - (total_num_of_documents + 1) + assert indexed_dataset.doc_idx.shape[0] == (total_num_of_documents + 1) return dataset - train_dataset = build_dataset(0, 'train') - valid_dataset = build_dataset(1, 'valid') - test_dataset = build_dataset(2, 'test') + train_dataset = build_dataset(0, "train") + valid_dataset = build_dataset(1, "valid") + test_dataset = build_dataset(2, "test") return (train_dataset, valid_dataset, test_dataset) -def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, - train_valid_test_num_samples, - max_seq_length, masked_lm_prob, - short_seq_prob, seed, skip_warmup, - binary_head, - dataset_type='standard_bert'): - +def build_train_valid_test_datasets( + data_prefix, + data_impl, + splits_string, + train_valid_test_num_samples, + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + dataset_type="standard_bert", +): if len(data_prefix) == 1: - return _build_train_valid_test_datasets(data_prefix[0], - data_impl, splits_string, - train_valid_test_num_samples, - max_seq_length, masked_lm_prob, - short_seq_prob, seed, - skip_warmup, - binary_head, - dataset_type=dataset_type) + return _build_train_valid_test_datasets( + data_prefix[0], + data_impl, + splits_string, + train_valid_test_num_samples, + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + dataset_type=dataset_type, + ) # Blending dataset. # Parse the values. - output = get_datasets_weights_and_num_samples(data_prefix, - train_valid_test_num_samples) + output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples) prefixes, weights, datasets_train_valid_test_num_samples = output # Build individual datasets. 
@@ -126,10 +143,18 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, test_datasets = [] for i in range(len(prefixes)): train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( - prefixes[i], data_impl, splits_string, + prefixes[i], + data_impl, + splits_string, datasets_train_valid_test_num_samples[i], - max_seq_length, masked_lm_prob, short_seq_prob, - seed, skip_warmup, binary_head, dataset_type=dataset_type) + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + dataset_type=dataset_type, + ) if train_ds: train_datasets.append(train_ds) if valid_ds: @@ -148,5 +173,4 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, if test_datasets: blending_test_dataset = BlendableDataset(test_datasets, weights) - return (blending_train_dataset, blending_valid_dataset, - blending_test_dataset) + return (blending_train_dataset, blending_valid_dataset, blending_test_dataset) diff --git a/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py b/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py index b9c197c95ae3..8ba598529ebc 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py +++ b/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py @@ -14,7 +14,6 @@ # limitations under the License. """Dataloaders.""" -import random import torch @@ -22,61 +21,60 @@ from colossalai.legacy.core import global_context as gpc -def build_pretraining_data_loader(dataset, consumed_samples, micro_batch_size, dataloader_type='single', num_workers=0): +def build_pretraining_data_loader(dataset, consumed_samples, micro_batch_size, dataloader_type="single", num_workers=0): """Build dataloader given an input dataset.""" if dataset is None: return None # Megatron sampler - if dataloader_type == 'single': - batch_sampler = MegatronPretrainingSampler(total_samples=len(dataset), - consumed_samples=consumed_samples, - micro_batch_size=micro_batch_size, - data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA), - data_parallel_size=gpc.get_world_size(ParallelMode.DATA)) - elif dataloader_type == 'cyclic': - batch_sampler = MegatronPretrainingRandomSampler(total_samples=len(dataset), - consumed_samples=consumed_samples, - micro_batch_size=micro_batch_size, - data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA), - data_parallel_size=gpc.get_world_size(ParallelMode.DATA)) + if dataloader_type == "single": + batch_sampler = MegatronPretrainingSampler( + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=micro_batch_size, + data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA), + data_parallel_size=gpc.get_world_size(ParallelMode.DATA), + ) + elif dataloader_type == "cyclic": + batch_sampler = MegatronPretrainingRandomSampler( + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=micro_batch_size, + data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA), + data_parallel_size=gpc.get_world_size(ParallelMode.DATA), + ) else: - raise Exception('{} dataloader type is not supported.'.format(dataloader_type)) + raise Exception("{} dataloader type is not supported.".format(dataloader_type)) # Torch dataloader. 
return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=num_workers, pin_memory=True) class MegatronPretrainingSampler: - - def __init__(self, - total_samples, - consumed_samples, - micro_batch_size, - data_parallel_rank, - data_parallel_size, - drop_last=True): + def __init__( + self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size, drop_last=True + ): # Keep a copy of input params for later use. self.total_samples = total_samples self.consumed_samples = consumed_samples self.micro_batch_size = micro_batch_size self.data_parallel_rank = data_parallel_rank - self.micro_batch_times_data_parallel_size = \ - self.micro_batch_size * data_parallel_size + self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size self.drop_last = drop_last # Sanity checks. - assert self.total_samples > 0, \ - 'no sample to consume: {}'.format(self.total_samples) - assert self.consumed_samples < self.total_samples, \ - 'no samples left to consume: {}, {}'.format(self.consumed_samples, - self.total_samples) + assert self.total_samples > 0, "no sample to consume: {}".format(self.total_samples) + assert self.consumed_samples < self.total_samples, "no samples left to consume: {}, {}".format( + self.consumed_samples, self.total_samples + ) assert self.micro_batch_size > 0 assert data_parallel_size > 0 - assert self.data_parallel_rank < data_parallel_size, \ - 'data_parallel_rank should be smaller than data size: {}, ' \ - '{}'.format(self.data_parallel_rank, data_parallel_size) + assert ( + self.data_parallel_rank < data_parallel_size + ), "data_parallel_rank should be smaller than data size: {}, " "{}".format( + self.data_parallel_rank, data_parallel_size + ) def __len__(self): return self.total_samples @@ -103,7 +101,6 @@ def __iter__(self): class MegatronPretrainingRandomSampler: - def __init__(self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size): # Keep a copy of input params for later use. self.total_samples = total_samples @@ -111,19 +108,18 @@ def __init__(self, total_samples, consumed_samples, micro_batch_size, data_paral self.micro_batch_size = micro_batch_size self.data_parallel_rank = data_parallel_rank self.data_parallel_size = data_parallel_size - self.micro_batch_times_data_parallel_size = \ - self.micro_batch_size * data_parallel_size - self.last_batch_size = \ - self.total_samples % self.micro_batch_times_data_parallel_size + self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size + self.last_batch_size = self.total_samples % self.micro_batch_times_data_parallel_size # Sanity checks. 
- assert self.total_samples > 0, \ - 'no sample to consume: {}'.format(self.total_samples) + assert self.total_samples > 0, "no sample to consume: {}".format(self.total_samples) assert self.micro_batch_size > 0 assert data_parallel_size > 0 - assert self.data_parallel_rank < data_parallel_size, \ - 'data_parallel_rank should be smaller than data size: {}, ' \ - '{}'.format(self.data_parallel_rank, data_parallel_size) + assert ( + self.data_parallel_rank < data_parallel_size + ), "data_parallel_rank should be smaller than data size: {}, " "{}".format( + self.data_parallel_rank, data_parallel_size + ) def __len__(self): return self.total_samples @@ -135,8 +131,7 @@ def __iter__(self): assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0 # data sharding and random sampling - bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \ - * self.micro_batch_size + bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) * self.micro_batch_size bucket_offset = current_epoch_samples // self.data_parallel_size start_idx = self.data_parallel_rank * bucket_size diff --git a/examples/tutorial/sequence_parallel/data/datasets/dataset_utils.py b/examples/tutorial/sequence_parallel/data/datasets/dataset_utils.py index cf4e4763fc10..3e197ff96c0c 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/dataset_utils.py +++ b/examples/tutorial/sequence_parallel/data/datasets/dataset_utils.py @@ -18,32 +18,33 @@ # https://github.com/google-research/albert/blob/master/create_pretraining_data.py # with some modifications. +import collections import math import time -import collections -from colossalai.logging import get_dist_logger + import numpy as np + +from colossalai.logging import get_dist_logger + from .blendable_dataset import BlendableDataset from .indexed_dataset import make_dataset as make_indexed_dataset -DSET_TYPE_STD = 'standard_bert' -DSET_TYPE_ICT = 'ict' +DSET_TYPE_STD = "standard_bert" +DSET_TYPE_ICT = "ict" DSET_TYPES = [DSET_TYPE_ICT, DSET_TYPE_STD] -def get_datasets_weights_and_num_samples(data_prefix, - train_valid_test_num_samples): - +def get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples): # The data prefix should be in the format of: # weight-1, data-prefix-1, weight-2, data-prefix-2, .. 
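A worked example of that interleaved `weight, prefix` format, using the same arithmetic the reformatted body below applies (the prefix names are made up; the 1.005 factor is the 0.5% oversampling margin visible further down in this hunk):

```python
import math

data_prefix = ["0.3", "corpus-a", "0.7", "corpus-b"]       # illustrative prefixes
train_valid_test_num_samples = [1000, 100, 10]

weights = [0.3, 0.7]                                        # already normalized in this example
per_dataset = [
    [int(math.ceil(n * w * 1.005)) for n in train_valid_test_num_samples]
    for w in weights
]
# per_dataset == [[302, 31, 4], [704, 71, 8]]
```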
assert len(data_prefix) % 2 == 0 num_datasets = len(data_prefix) // 2 - weights = [0]*num_datasets - prefixes = [0]*num_datasets + weights = [0] * num_datasets + prefixes = [0] * num_datasets for i in range(num_datasets): - weights[i] = float(data_prefix[2*i]) - prefixes[i] = (data_prefix[2*i+1]).strip() + weights[i] = float(data_prefix[2 * i]) + prefixes[i] = (data_prefix[2 * i + 1]).strip() # Normalize weights weight_sum = 0.0 for weight in weights: @@ -57,8 +58,8 @@ def get_datasets_weights_and_num_samples(data_prefix, datasets_train_valid_test_num_samples = [] for weight in weights: datasets_train_valid_test_num_samples.append( - [int(math.ceil(val * weight * 1.005)) - for val in train_valid_test_num_samples]) + [int(math.ceil(val * weight * 1.005)) for val in train_valid_test_num_samples] + ) return prefixes, weights, datasets_train_valid_test_num_samples @@ -68,11 +69,13 @@ def compile_helper(): is invoked on a single process.""" import os import subprocess + path = os.path.abspath(os.path.dirname(__file__)) - ret = subprocess.run(['make', '-C', path]) + ret = subprocess.run(["make", "-C", path]) if ret.returncode != 0: print("Making C++ dataset helpers module failed, exiting.") import sys + sys.exit(1) @@ -82,7 +85,7 @@ def get_a_and_b_segments(sample, np_rng): # Number of sentences in the sample. n_sentences = len(sample) # Make sure we always have two sentences. - assert n_sentences > 1, 'make sure each sample has at least two sentences.' + assert n_sentences > 1, "make sure each sample has at least two sentences." # First part: # `a_end` is how many sentences go into the `A`. @@ -110,7 +113,7 @@ def get_a_and_b_segments(sample, np_rng): def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng): """Truncates a pair of sequences to a maximum sequence length.""" - #print(len_a, len_b, max_num_tokens) + # print(len_a, len_b, max_num_tokens) assert len_a > 0 if len_a + len_b <= max_num_tokens: return False @@ -155,8 +158,7 @@ def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id): return tokens, tokentypes -MaskedLmInstance = collections.namedtuple("MaskedLmInstance", - ["index", "label"]) +MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ["index", "label"]) def is_start_piece(piece): @@ -168,16 +170,21 @@ def is_start_piece(piece): return not piece.startswith("##") -def create_masked_lm_predictions(tokens, - vocab_id_list, vocab_id_to_token_dict, - masked_lm_prob, - cls_id, sep_id, mask_id, - max_predictions_per_seq, - np_rng, - max_ngrams=3, - do_whole_word_mask=True, - favor_longer_ngram=False, - do_permutation=False): +def create_masked_lm_predictions( + tokens, + vocab_id_list, + vocab_id_to_token_dict, + masked_lm_prob, + cls_id, + sep_id, + mask_id, + max_predictions_per_seq, + np_rng, + max_ngrams=3, + do_whole_word_mask=True, + favor_longer_ngram=False, + do_permutation=False, +): """Creates the predictions for the masked LM objective. Note: Tokens here are vocab ids and not text tokens.""" @@ -187,7 +194,7 @@ def create_masked_lm_predictions(tokens, # on-the-fly whole word masking is possible. token_boundary = [0] * len(tokens) - for (i, token) in enumerate(tokens): + for i, token in enumerate(tokens): if token == cls_id or token == sep_id: token_boundary[i] = 1 continue @@ -197,8 +204,7 @@ def create_masked_lm_predictions(tokens, # Note that Whole Word Masking does *not* change the training code # at all -- we still predict each WordPiece independently, softmaxed # over the entire vocabulary. 
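The grouping performed by the lines that follow hinges on the WordPiece `##` convention checked by `is_start_piece` above: continuation pieces are folded into the preceding word so whole-word masking can treat the word as one unit. A self-contained illustration with invented token strings (the real code additionally gates this on `do_whole_word_mask`):

```python
def is_start_piece(piece):
    return not piece.startswith("##")

tokens = ["[CLS]", "un", "##believ", "##able", "results", "[SEP]"]
cand_indexes = []
for i, tok in enumerate(tokens):
    if tok in ("[CLS]", "[SEP]"):
        continue
    if cand_indexes and not is_start_piece(tok):
        cand_indexes[-1].append(i)      # continuation piece joins the previous word
    else:
        cand_indexes.append([i])        # a new whole-word candidate starts here

assert cand_indexes == [[1, 2, 3], [4]]  # "un ##believ ##able" is one maskable unit
```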
- if (do_whole_word_mask and len(cand_indexes) >= 1 and - not is_start_piece(vocab_id_to_token_dict[token])): + if do_whole_word_mask and len(cand_indexes) >= 1 and not is_start_piece(vocab_id_to_token_dict[token]): cand_indexes[-1].append(i) else: cand_indexes.append([i]) @@ -211,16 +217,14 @@ def create_masked_lm_predictions(tokens, masked_lm_labels = [] if masked_lm_prob == 0: - return (output_tokens, masked_lm_positions, - masked_lm_labels, token_boundary) + return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary) - num_to_predict = min(max_predictions_per_seq, - max(1, int(round(len(tokens) * masked_lm_prob)))) + num_to_predict = min(max_predictions_per_seq, max(1, int(round(len(tokens) * masked_lm_prob)))) # Note(mingdachen): # By default, we set the probabilities to favor shorter ngram sequences. ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64) - pvals = 1. / np.arange(1, max_ngrams + 1) + pvals = 1.0 / np.arange(1, max_ngrams + 1) pvals /= pvals.sum(keepdims=True) if favor_longer_ngram: @@ -230,7 +234,7 @@ def create_masked_lm_predictions(tokens, for idx in range(len(cand_indexes)): ngram_index = [] for n in ngrams: - ngram_index.append(cand_indexes[idx:idx + n]) + ngram_index.append(cand_indexes[idx : idx + n]) ngram_indexes.append(ngram_index) np_rng.shuffle(ngram_indexes) @@ -249,9 +253,10 @@ def create_masked_lm_predictions(tokens, if index in covered_indexes: continue - n = np_rng.choice(ngrams[:len(cand_index_set)], - p=pvals[:len(cand_index_set)] / - pvals[:len(cand_index_set)].sum(keepdims=True)) + n = np_rng.choice( + ngrams[: len(cand_index_set)], + p=pvals[: len(cand_index_set)] / pvals[: len(cand_index_set)].sum(keepdims=True), + ) index_set = sum(cand_index_set[n - 1], []) n -= 1 # Note(mingdachen): @@ -309,9 +314,10 @@ def create_masked_lm_predictions(tokens, if index in covered_indexes or index in select_indexes: continue - n = np.random.choice(ngrams[:len(cand_index_set)], - p=pvals[:len(cand_index_set)] / - pvals[:len(cand_index_set)].sum(keepdims=True)) + n = np.random.choice( + ngrams[: len(cand_index_set)], + p=pvals[: len(cand_index_set)] / pvals[: len(cand_index_set)].sum(keepdims=True), + ) index_set = sum(cand_index_set[n - 1], []) n -= 1 @@ -353,8 +359,7 @@ def create_masked_lm_predictions(tokens, return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary) -def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, - masked_labels, pad_id, max_seq_length): +def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, masked_labels, pad_id, max_seq_length): """Pad sequences and convert them to numpy.""" # Some checks. @@ -370,8 +375,7 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, tokentypes_np = np.array(tokentypes + filler, dtype=np.int64) # Padding mask. - padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, - dtype=np.int64) + padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, dtype=np.int64) # Lables and loss mask. 
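Stepping back to the n-gram sampling reformatted above: span lengths are drawn with probability proportional to 1/n, then truncated and renormalized for each candidate set, so short spans dominate unless `favor_longer_ngram` reverses the weights. A quick numeric sketch of those probabilities:

```python
import numpy as np

max_ngrams = 3
ngrams = np.arange(1, max_ngrams + 1)            # [1, 2, 3]
pvals = 1.0 / ngrams
pvals /= pvals.sum(keepdims=True)                # ~[0.545, 0.273, 0.182]

# When only 2 word-level positions remain in a candidate set, the same weights
# are truncated and renormalized before the span length is sampled:
k = 2
p = pvals[:k] / pvals[:k].sum(keepdims=True)     # ~[0.667, 0.333]
n = np.random.default_rng(0).choice(ngrams[:k], p=p)
```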
labels = [-1] * max_seq_length @@ -386,26 +390,36 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np -def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, - train_valid_test_num_samples, - max_seq_length, masked_lm_prob, - short_seq_prob, seed, skip_warmup, - binary_head, - dataset_type='standard_bert'): - +def build_train_valid_test_datasets( + data_prefix, + data_impl, + splits_string, + train_valid_test_num_samples, + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + dataset_type="standard_bert", +): if len(data_prefix) == 1: - return _build_train_valid_test_datasets(data_prefix[0], - data_impl, splits_string, - train_valid_test_num_samples, - max_seq_length, masked_lm_prob, - short_seq_prob, seed, - skip_warmup, - binary_head, - dataset_type=dataset_type) + return _build_train_valid_test_datasets( + data_prefix[0], + data_impl, + splits_string, + train_valid_test_num_samples, + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + dataset_type=dataset_type, + ) # Blending dataset. # Parse the values. - output = get_datasets_weights_and_num_samples(data_prefix, - train_valid_test_num_samples) + output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples) prefixes, weights, datasets_train_valid_test_num_samples = output # Build individual datasets. @@ -414,10 +428,18 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, test_datasets = [] for i in range(len(prefixes)): train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( - prefixes[i], data_impl, splits_string, + prefixes[i], + data_impl, + splits_string, datasets_train_valid_test_num_samples[i], - max_seq_length, masked_lm_prob, short_seq_prob, - seed, skip_warmup, binary_head, dataset_type=dataset_type) + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + dataset_type=dataset_type, + ) if train_ds: train_datasets.append(train_ds) if valid_ds: @@ -436,31 +458,33 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, if test_datasets: blending_test_dataset = BlendableDataset(test_datasets, weights) - return (blending_train_dataset, blending_valid_dataset, - blending_test_dataset) - - -def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, - train_valid_test_num_samples, - max_seq_length, masked_lm_prob, - short_seq_prob, seed, skip_warmup, - binary_head, - dataset_type='standard_bert'): + return (blending_train_dataset, blending_valid_dataset, blending_test_dataset) + + +def _build_train_valid_test_datasets( + data_prefix, + data_impl, + splits_string, + train_valid_test_num_samples, + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + dataset_type="standard_bert", +): logger = get_dist_logger() if dataset_type not in DSET_TYPES: raise ValueError("Invalid dataset_type: ", dataset_type) # Indexed dataset. 
- indexed_dataset = get_indexed_dataset_(data_prefix, - data_impl, - skip_warmup) + indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup) if dataset_type == DSET_TYPE_ICT: args = get_args() - title_dataset = get_indexed_dataset_(args.titles_data_path, - data_impl, - skip_warmup) + title_dataset = get_indexed_dataset_(args.titles_data_path, data_impl, skip_warmup) # Get start and end indices of train/valid/train into doc-idx # Note that doc-idx is designed to be num-docs + 1 so we can @@ -469,27 +493,29 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, splits = get_train_valid_test_split_(splits_string, total_num_of_documents) # Print stats about the splits. - logger.info('\n > dataset split:') + logger.info("\n > dataset split:") def print_split_stats(name, index): start_index = indexed_dataset.doc_idx[splits[index]] end_index = indexed_dataset.doc_idx[splits[index + 1]] - logger.info('\n {}:'.format(name) + - '\n document indices in [{}, {}) total of {} documents'.format( - splits[index], - splits[index + 1], - splits[index + 1] - splits[index]) + - '\n sentence indices in [{}, {}) total of {} sentences'.format( - start_index, - end_index, - end_index - start_index), - ranks=[0]) - print_split_stats('train', 0) - print_split_stats('validation', 1) - print_split_stats('test', 2) + logger.info( + "\n {}:".format(name) + + "\n document indices in [{}, {}) total of {} documents".format( + splits[index], splits[index + 1], splits[index + 1] - splits[index] + ) + + "\n sentence indices in [{}, {}) total of {} sentences".format( + start_index, end_index, end_index - start_index + ), + ranks=[0], + ) + + print_split_stats("train", 0) + print_split_stats("validation", 1) + print_split_stats("test", 2) def build_dataset(index, name): from .bert_dataset import BertDataset + dataset = None if splits[index + 1] > splits[index]: # Get the pointer to the original doc-idx so we can set it later. @@ -508,7 +534,7 @@ def build_dataset(index, name): max_num_samples=train_valid_test_num_samples[index], max_seq_length=max_seq_length, seed=seed, - binary_head=binary_head + binary_head=binary_head, ) if dataset_type == DSET_TYPE_ICT: @@ -518,27 +544,26 @@ def build_dataset(index, name): title_dataset=title_dataset, query_in_block_prob=args.query_in_block_prob, use_one_sent_docs=args.use_one_sent_docs, - **kwargs + **kwargs, ) else: dataset = BertDataset( indexed_dataset=indexed_dataset, masked_lm_prob=masked_lm_prob, short_seq_prob=short_seq_prob, - **kwargs + **kwargs, ) # Set the original pointer so dataset remains the main dataset. indexed_dataset.set_doc_idx(doc_idx_ptr) # Checks. 
assert indexed_dataset.doc_idx[0] == 0 - assert indexed_dataset.doc_idx.shape[0] == \ - (total_num_of_documents + 1) + assert indexed_dataset.doc_idx.shape[0] == (total_num_of_documents + 1) return dataset - train_dataset = build_dataset(0, 'train') - valid_dataset = build_dataset(1, 'valid') - test_dataset = build_dataset(2, 'test') + train_dataset = build_dataset(0, "train") + valid_dataset = build_dataset(1, "valid") + test_dataset = build_dataset(2, "test") return (train_dataset, valid_dataset, test_dataset) @@ -546,44 +571,41 @@ def build_dataset(index, name): def get_indexed_dataset_(data_prefix, data_impl, skip_warmup): logger = get_dist_logger() start_time = time.time() - indexed_dataset = make_indexed_dataset(data_prefix, - data_impl, - skip_warmup) + indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup) assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1] - logger.info('\n > building dataset index ...', ranks=[0]) - logger.info('\n > finished creating indexed dataset in {:4f} ' - 'seconds'.format(time.time() - start_time), ranks=[0]) - logger.info('\n > indexed dataset stats:' + - '\n number of documents: {}'.format( - indexed_dataset.doc_idx.shape[0] - 1) + - '\n number of sentences: {}'.format( - indexed_dataset.sizes.shape[0]), - ranks=[0] - ) + logger.info("\n > building dataset index ...", ranks=[0]) + logger.info( + "\n > finished creating indexed dataset in {:4f} " "seconds".format(time.time() - start_time), ranks=[0] + ) + logger.info( + "\n > indexed dataset stats:" + + "\n number of documents: {}".format(indexed_dataset.doc_idx.shape[0] - 1) + + "\n number of sentences: {}".format(indexed_dataset.sizes.shape[0]), + ranks=[0], + ) return indexed_dataset def get_train_valid_test_split_(splits_string, size): - """ Get dataset splits from comma or '/' separated string list.""" + """Get dataset splits from comma or '/' separated string list.""" splits = [] - if splits_string.find(',') != -1: - splits = [float(s) for s in splits_string.split(',')] - elif splits_string.find('/') != -1: - splits = [float(s) for s in splits_string.split('/')] + if splits_string.find(",") != -1: + splits = [float(s) for s in splits_string.split(",")] + elif splits_string.find("/") != -1: + splits = [float(s) for s in splits_string.split("/")] else: splits = [float(splits_string)] while len(splits) < 3: - splits.append(0.) + splits.append(0.0) splits = splits[:3] splits_sum = sum(splits) assert splits_sum > 0.0 splits = [split / splits_sum for split in splits] splits_index = [0] for index, split in enumerate(splits): - splits_index.append(splits_index[index] + - int(round(split * float(size)))) + splits_index.append(splits_index[index] + int(round(split * float(size)))) diff = splits_index[-1] - size for index in range(1, len(splits_index)): splits_index[index] -= diff diff --git a/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp b/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp index e45926a97696..52977e63181f 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp +++ b/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp @@ -15,29 +15,28 @@ limitations under the License. 
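Before moving into the C++ helpers, a worked example of `get_train_valid_test_split_` as reformatted above, since its cumulative-index output is easy to misread: the splits string is normalized to fractions and then turned into document-index boundaries.

```python
splits_string, size = "949,50,1", 10_000           # e.g. 10k documents
splits = [float(s) for s in splits_string.split(",")]
splits = [s / sum(splits) for s in splits]          # [0.949, 0.05, 0.001]

splits_index = [0]
for s in splits:
    splits_index.append(splits_index[-1] + int(round(s * size)))

# splits_index == [0, 9490, 9990, 10000]:
# train uses documents [0, 9490), validation [9490, 9990), test [9990, 10000).
```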
*/ - /* Helper methods for fast index mapping builds */ +#include +#include +#include + #include #include #include -#include -#include -#include -#include #include +#include namespace py = pybind11; using namespace std; const int32_t LONG_SENTENCE_LEN = 512; - void build_blending_indices(py::array_t& dataset_index, - py::array_t& dataset_sample_index, - const py::array_t& weights, - const int32_t num_datasets, - const int64_t size, const bool verbose) { + py::array_t& dataset_sample_index, + const py::array_t& weights, + const int32_t num_datasets, const int64_t size, + const bool verbose) { /* Given multiple datasets and a weighting array, build samples such that it follows those wieghts.*/ @@ -52,24 +51,23 @@ void build_blending_indices(py::array_t& dataset_index, // Initialize buffer for number of samples used for each dataset. int64_t current_samples[num_datasets]; - for(int64_t i = 0; i < num_datasets; ++i) { + for (int64_t i = 0; i < num_datasets; ++i) { current_samples[i] = 0; } // For each sample: - for(int64_t sample_idx = 0; sample_idx < size; ++sample_idx) { - + for (int64_t sample_idx = 0; sample_idx < size; ++sample_idx) { // Determine where the max error in sampling is happening. auto sample_idx_double = std::max(static_cast(sample_idx), 1.0); int64_t max_error_index = 0; double max_error = weights_ptr[0] * sample_idx_double - - static_cast(current_samples[0]); + static_cast(current_samples[0]); for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) { double error = weights_ptr[dataset_idx] * sample_idx_double - - static_cast(current_samples[dataset_idx]); + static_cast(current_samples[dataset_idx]); if (error > max_error) { - max_error = error; - max_error_index = dataset_idx; + max_error = error; + max_error_index = dataset_idx; } } @@ -79,7 +77,6 @@ void build_blending_indices(py::array_t& dataset_index, // Update the total samples. current_samples[max_error_index] += 1; - } // print info @@ -87,631 +84,607 @@ void build_blending_indices(py::array_t& dataset_index, std::cout << " > sample ratios:" << std::endl; for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) { auto ratio = static_cast(current_samples[dataset_idx]) / - static_cast(size); - std::cout << " dataset " << dataset_idx << ", input: " << - weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl; + static_cast(size); + std::cout << " dataset " << dataset_idx + << ", input: " << weights_ptr[dataset_idx] + << ", achieved: " << ratio << std::endl; } } - } - py::array build_sample_idx(const py::array_t& sizes_, - const py::array_t& doc_idx_, - const int32_t seq_length, - const int32_t num_epochs, - const int64_t tokens_per_epoch) { - /* Sample index (sample_idx) is used for gpt2 like dataset for which - the documents are flattened and the samples are built based on this - 1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2] - where [..., 0] contains the index into `doc_idx` and [..., 1] is the - starting offset in that document.*/ - - // Consistency checks. - assert(seq_length > 1); - assert(num_epochs > 0); - assert(tokens_per_epoch > 1); - - // Remove bound checks. - auto sizes = sizes_.unchecked<1>(); - auto doc_idx = doc_idx_.unchecked<1>(); - - // Mapping and it's length (1D). 
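The `build_blending_indices` routine above implements a greedy rule: each output position goes to the dataset whose achieved share lags its target weight the most. A Python rendering of the same rule (illustrative only, not a binding of the compiled extension):

```python
def blend(weights, size):
    """For each output slot, pick the dataset furthest behind its target share."""
    current = [0] * len(weights)
    dataset_index, dataset_sample_index = [], []
    for sample_idx in range(size):
        t = max(sample_idx, 1)                               # mirrors the max(..., 1.0) above
        errors = [w * t - c for w, c in zip(weights, current)]
        d = max(range(len(weights)), key=lambda i: errors[i])
        dataset_index.append(d)
        dataset_sample_index.append(current[d])              # index within the chosen dataset
        current[d] += 1
    return dataset_index, dataset_sample_index

# blend([0.7, 0.3], 10) keeps the running ratio close to 70/30.
```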
- int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length; - int32_t* sample_idx = new int32_t[2*(num_samples+1)]; - - cout << " using:" << endl << std::flush; - cout << " number of documents: " << - doc_idx_.shape(0) / num_epochs << endl << std::flush; - cout << " number of epochs: " << num_epochs << - endl << std::flush; - cout << " sequence length: " << seq_length << - endl << std::flush; - cout << " total number of samples: " << num_samples << - endl << std::flush; - - // Index into sample_idx. - int64_t sample_index = 0; - // Index into doc_idx. - int64_t doc_idx_index = 0; - // Begining offset for each document. - int32_t doc_offset = 0; - // Start with first document and no offset. + const py::array_t& doc_idx_, + const int32_t seq_length, const int32_t num_epochs, + const int64_t tokens_per_epoch) { + /* Sample index (sample_idx) is used for gpt2 like dataset for which + the documents are flattened and the samples are built based on this + 1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2] + where [..., 0] contains the index into `doc_idx` and [..., 1] is the + starting offset in that document.*/ + + // Consistency checks. + assert(seq_length > 1); + assert(num_epochs > 0); + assert(tokens_per_epoch > 1); + + // Remove bound checks. + auto sizes = sizes_.unchecked<1>(); + auto doc_idx = doc_idx_.unchecked<1>(); + + // Mapping and it's length (1D). + int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length; + int32_t* sample_idx = new int32_t[2 * (num_samples + 1)]; + + cout << " using:" << endl << std::flush; + cout << " number of documents: " << doc_idx_.shape(0) / num_epochs + << endl + << std::flush; + cout << " number of epochs: " << num_epochs << endl + << std::flush; + cout << " sequence length: " << seq_length << endl + << std::flush; + cout << " total number of samples: " << num_samples << endl + << std::flush; + + // Index into sample_idx. + int64_t sample_index = 0; + // Index into doc_idx. + int64_t doc_idx_index = 0; + // Begining offset for each document. + int32_t doc_offset = 0; + // Start with first document and no offset. + sample_idx[2 * sample_index] = doc_idx_index; + sample_idx[2 * sample_index + 1] = doc_offset; + ++sample_index; + + while (sample_index <= num_samples) { + // Start with a fresh sequence. + int32_t remaining_seq_length = seq_length + 1; + while (remaining_seq_length != 0) { + // Get the document length. + auto doc_id = doc_idx[doc_idx_index]; + auto doc_length = sizes[doc_id] - doc_offset; + // And add it to the current sequence. + remaining_seq_length -= doc_length; + // If we have more than a full sequence, adjust offset and set + // remaining length to zero so we return from the while loop. + // Note that -1 here is for the same reason we have -1 in + // `_num_epochs` calculations. + if (remaining_seq_length <= 0) { + doc_offset += (remaining_seq_length + doc_length - 1); + remaining_seq_length = 0; + } else { + // Otherwise, start from the begining of the next document. + ++doc_idx_index; + doc_offset = 0; + } + } + // Record the sequence. sample_idx[2 * sample_index] = doc_idx_index; sample_idx[2 * sample_index + 1] = doc_offset; ++sample_index; + } - while (sample_index <= num_samples) { - // Start with a fresh sequence. - int32_t remaining_seq_length = seq_length + 1; - while (remaining_seq_length != 0) { - // Get the document length. - auto doc_id = doc_idx[doc_idx_index]; - auto doc_length = sizes[doc_id] - doc_offset; - // And add it to the current sequence. 
- remaining_seq_length -= doc_length; - // If we have more than a full sequence, adjust offset and set - // remaining length to zero so we return from the while loop. - // Note that -1 here is for the same reason we have -1 in - // `_num_epochs` calculations. - if (remaining_seq_length <= 0) { - doc_offset += (remaining_seq_length + doc_length - 1); - remaining_seq_length = 0; - } else { - // Otherwise, start from the begining of the next document. - ++doc_idx_index; - doc_offset = 0; - } - } - // Record the sequence. - sample_idx[2 * sample_index] = doc_idx_index; - sample_idx[2 * sample_index + 1] = doc_offset; - ++sample_index; - } - - // Method to deallocate memory. - py::capsule free_when_done(sample_idx, [](void *mem_) { - int32_t *mem = reinterpret_cast(mem_); - delete[] mem; - }); - - // Return the numpy array. - const auto byte_size = sizeof(int32_t); - return py::array(std::vector{num_samples+1, 2}, // shape - {2*byte_size, byte_size}, // C-style contiguous strides - sample_idx, // the data pointer - free_when_done); // numpy array references - + // Method to deallocate memory. + py::capsule free_when_done(sample_idx, [](void* mem_) { + int32_t* mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(int32_t); + return py::array(std::vector{num_samples + 1, 2}, // shape + {2 * byte_size, byte_size}, // C-style contiguous strides + sample_idx, // the data pointer + free_when_done); // numpy array references } - inline int32_t get_target_sample_len(const int32_t short_seq_ratio, - const int32_t max_length, - std::mt19937& rand32_gen) { - /* Training sample length. */ - if (short_seq_ratio == 0) { - return max_length; - } - const auto random_number = rand32_gen(); - if ((random_number % short_seq_ratio) == 0) { - return 2 + random_number % (max_length - 1); - } + const int32_t max_length, + std::mt19937& rand32_gen) { + /* Training sample length. */ + if (short_seq_ratio == 0) { return max_length; + } + const auto random_number = rand32_gen(); + if ((random_number % short_seq_ratio) == 0) { + return 2 + random_number % (max_length - 1); + } + return max_length; } - -template +template py::array build_mapping_impl(const py::array_t& docs_, const py::array_t& sizes_, const int32_t num_epochs, const uint64_t max_num_samples, const int32_t max_seq_length, - const double short_seq_prob, - const int32_t seed, - const bool verbose, - const int32_t min_num_sent) { - /* Build a mapping of (start-index, end-index, sequence-length) where - start and end index are the indices of the sentences in the sample - and sequence-length is the target sequence length. - */ - - // Consistency checks. - assert(num_epochs > 0); - assert(max_seq_length > 1); - assert(short_seq_prob >= 0.0); - assert(short_seq_prob <= 1.0); - assert(seed > 0); - - // Remove bound checks. - auto docs = docs_.unchecked<1>(); - auto sizes = sizes_.unchecked<1>(); - - // For efficiency, convert probability to ratio. Note: rand() generates int. - int32_t short_seq_ratio = 0; - if (short_seq_prob > 0) { - short_seq_ratio = static_cast(round(1.0 / short_seq_prob)); - } + const double short_seq_prob, const int32_t seed, + const bool verbose, const int32_t min_num_sent) { + /* Build a mapping of (start-index, end-index, sequence-length) where + start and end index are the indices of the sentences in the sample + and sequence-length is the target sequence length. + */ + + // Consistency checks. 
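For orientation on `build_sample_idx`, reindented above: it packs flattened documents into fixed-length GPT-2 style samples, and each row of the result stores the document index and intra-document offset at which a sample begins. A NumPy sketch mirroring the same walk (slow and unoptimized, for illustration only):

```python
import numpy as np

def build_sample_idx(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch):
    num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length
    sample_idx = np.zeros((num_samples + 1, 2), dtype=np.int32)
    doc_index, doc_offset = 0, 0
    sample_idx[0] = (doc_index, doc_offset)
    for s in range(1, num_samples + 1):
        remaining = seq_length + 1                  # one extra token (targets are inputs shifted by one)
        while remaining != 0:
            doc_length = sizes[doc_idx[doc_index]] - doc_offset
            remaining -= doc_length
            if remaining <= 0:                      # the sample ends inside this document
                doc_offset += remaining + doc_length - 1
                remaining = 0
            else:                                   # consume the document, move to the next
                doc_index += 1
                doc_offset = 0
        sample_idx[s] = (doc_index, doc_offset)
    return sample_idx
```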
+ assert(num_epochs > 0); + assert(max_seq_length > 1); + assert(short_seq_prob >= 0.0); + assert(short_seq_prob <= 1.0); + assert(seed > 0); + + // Remove bound checks. + auto docs = docs_.unchecked<1>(); + auto sizes = sizes_.unchecked<1>(); + + // For efficiency, convert probability to ratio. Note: rand() generates int. + int32_t short_seq_ratio = 0; + if (short_seq_prob > 0) { + short_seq_ratio = static_cast(round(1.0 / short_seq_prob)); + } - if (verbose) { - const auto sent_start_index = docs[0]; - const auto sent_end_index = docs[docs_.shape(0) - 1]; - const auto num_sentences = sent_end_index - sent_start_index; - cout << " using:" << endl << std::flush; - cout << " number of documents: " << docs_.shape(0) - 1 << - endl << std::flush; - cout << " sentences range: [" << sent_start_index << - ", " << sent_end_index << ")" << endl << std::flush; - cout << " total number of sentences: " << num_sentences << - endl << std::flush; - cout << " number of epochs: " << num_epochs << - endl << std::flush; - cout << " maximum number of samples: " << max_num_samples << - endl << std::flush; - cout << " maximum sequence length: " << max_seq_length << - endl << std::flush; - cout << " short sequence probability: " << short_seq_prob << - endl << std::flush; - cout << " short sequence ration (1/prob): " << short_seq_ratio << - endl << std::flush; - cout << " seed: " << seed << endl << - std::flush; - } + if (verbose) { + const auto sent_start_index = docs[0]; + const auto sent_end_index = docs[docs_.shape(0) - 1]; + const auto num_sentences = sent_end_index - sent_start_index; + cout << " using:" << endl << std::flush; + cout << " number of documents: " << docs_.shape(0) - 1 + << endl + << std::flush; + cout << " sentences range: [" << sent_start_index << ", " + << sent_end_index << ")" << endl + << std::flush; + cout << " total number of sentences: " << num_sentences << endl + << std::flush; + cout << " number of epochs: " << num_epochs << endl + << std::flush; + cout << " maximum number of samples: " << max_num_samples << endl + << std::flush; + cout << " maximum sequence length: " << max_seq_length << endl + << std::flush; + cout << " short sequence probability: " << short_seq_prob << endl + << std::flush; + cout << " short sequence ration (1/prob): " << short_seq_ratio << endl + << std::flush; + cout << " seed: " << seed << endl + << std::flush; + } - // Mapping and it's length (1D). - int64_t num_samples = -1; - DocIdx* maps = NULL; - - // Perform two iterations, in the first iteration get the size - // and allocate memory and in the second iteration populate the map. - bool second = false; - for (int32_t iteration=0; iteration<2; ++iteration) { - - // Set the seed so both iterations produce the same results. - std::mt19937 rand32_gen(seed); - - // Set the flag on second iteration. - second = (iteration == 1); - - // Counters: - uint64_t empty_docs = 0; - uint64_t one_sent_docs = 0; - uint64_t long_sent_docs = 0; - - // Current map index. - uint64_t map_index = 0; - - // For each epoch: - for (int32_t epoch=0; epoch= max_num_samples) { - if (verbose && (!second)) { - cout << " reached " << max_num_samples << " samples after " - << epoch << " epochs ..." << endl << std::flush; - } - break; + // Mapping and it's length (1D). + int64_t num_samples = -1; + DocIdx* maps = NULL; + + // Perform two iterations, in the first iteration get the size + // and allocate memory and in the second iteration populate the map. 
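That two-iteration comment is the heart of the helper: the first pass only counts samples so the buffer can be sized, the second pass fills it, and reseeding the generator at the top of each pass keeps the two passes' decisions identical. A stripped-down sketch of the pattern (the acceptance test is a stand-in, not the real sentence-packing logic):

```python
import numpy as np

def two_pass_build(docs, seed):
    maps, num_samples = None, -1
    for iteration in range(2):
        rng = np.random.default_rng(seed)       # same seed, so both passes agree
        second = iteration == 1
        map_index = 0
        for doc in docs:
            if rng.random() < 0.5:               # stand-in for the real acceptance logic
                if second:
                    maps[map_index] = doc        # only the second pass writes
                map_index += 1
        if not second:                            # after the counting pass: allocate
            maps = np.empty(map_index, dtype=np.int64)
            num_samples = map_index
    return maps[:num_samples]
```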
+ bool second = false; + for (int32_t iteration = 0; iteration < 2; ++iteration) { + // Set the seed so both iterations produce the same results. + std::mt19937 rand32_gen(seed); + + // Set the flag on second iteration. + second = (iteration == 1); + + // Counters: + uint64_t empty_docs = 0; + uint64_t one_sent_docs = 0; + uint64_t long_sent_docs = 0; + + // Current map index. + uint64_t map_index = 0; + + // For each epoch: + for (int32_t epoch = 0; epoch < num_epochs; ++epoch) { + if (map_index >= max_num_samples) { + if (verbose && (!second)) { + cout << " reached " << max_num_samples << " samples after " + << epoch << " epochs ..." << endl + << std::flush; + } + break; + } + // For each document: + for (int32_t doc = 0; doc < (docs.shape(0) - 1); ++doc) { + // Document sentences are in [sent_index_first, sent_index_last) + const auto sent_index_first = docs[doc]; + const auto sent_index_last = docs[doc + 1]; + + // At the begining of the document previous index is the + // start index. + auto prev_start_index = sent_index_first; + + // Remaining documents. + auto num_remain_sent = sent_index_last - sent_index_first; + + // Some bookkeeping + if ((epoch == 0) && (!second)) { + if (num_remain_sent == 0) { + ++empty_docs; + } + if (num_remain_sent == 1) { + ++one_sent_docs; + } + } + + // Detect documents with long sentences. + bool contains_long_sentence = false; + if (num_remain_sent > 1) { + for (auto sent_index = sent_index_first; sent_index < sent_index_last; + ++sent_index) { + if (sizes[sent_index] > LONG_SENTENCE_LEN) { + if ((epoch == 0) && (!second)) { + ++long_sent_docs; + } + contains_long_sentence = true; + break; } - // For each document: - for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) { - - // Document sentences are in [sent_index_first, sent_index_last) - const auto sent_index_first = docs[doc]; - const auto sent_index_last = docs[doc + 1]; - - // At the begining of the document previous index is the - // start index. - auto prev_start_index = sent_index_first; - - // Remaining documents. - auto num_remain_sent = sent_index_last - sent_index_first; - - // Some bookkeeping - if ((epoch == 0) && (!second)) { - if (num_remain_sent == 0) { - ++empty_docs; - } - if (num_remain_sent == 1) { - ++one_sent_docs; - } - } - - // Detect documents with long sentences. - bool contains_long_sentence = false; - if (num_remain_sent > 1) { - for (auto sent_index=sent_index_first; - sent_index < sent_index_last; ++sent_index) { - if (sizes[sent_index] > LONG_SENTENCE_LEN){ - if ((epoch == 0) && (!second)) { - ++long_sent_docs; - } - contains_long_sentence = true; - break; - } - } - } - - // If we have more than two sentences. - if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { - - // Set values. - auto seq_len = int32_t{0}; - auto num_sent = int32_t{0}; - auto target_seq_len = get_target_sample_len(short_seq_ratio, - max_seq_length, - rand32_gen); - - // Loop through sentences. - for (auto sent_index=sent_index_first; - sent_index < sent_index_last; ++sent_index) { - - // Add the size and number of sentences. - seq_len += sizes[sent_index]; - ++num_sent; - --num_remain_sent; - - // If we have reached the target length. - // and if not only one sentence is left in the document. - // and if we have at least two sentneces. - // and if we have reached end of the document. - if (((seq_len >= target_seq_len) && - (num_remain_sent > 1) && - (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) { - - // Check for overflow. 
- if ((3 * map_index + 2) > - std::numeric_limits::max()) { - cout << "number of samples exceeded maximum " - << "allowed by type int64: " - << std::numeric_limits::max() - << endl; - throw std::overflow_error("Number of samples"); - } - - // Populate the map. - if (second) { - const auto map_index_0 = 3 * map_index; - maps[map_index_0] = static_cast(prev_start_index); - maps[map_index_0 + 1] = static_cast(sent_index + 1); - maps[map_index_0 + 2] = static_cast(target_seq_len); - } - - // Update indices / counters. - ++map_index; - prev_start_index = sent_index + 1; - target_seq_len = get_target_sample_len(short_seq_ratio, - max_seq_length, - rand32_gen); - seq_len = 0; - num_sent = 0; - } - - } // for (auto sent_index=sent_index_first; ... - } // if (num_remain_sent > 1) { - } // for (int doc=0; doc < num_docs; ++doc) { - } // for (int epoch=0; epoch < num_epochs; ++epoch) { - - if (!second) { - if (verbose) { - cout << " number of empty documents: " << empty_docs << - endl << std::flush; - cout << " number of documents with one sentence: " << - one_sent_docs << endl << std::flush; - cout << " number of documents with long sentences: " << - long_sent_docs << endl << std::flush; - cout << " will create mapping for " << map_index << - " samples" << endl << std::flush; - } - assert(maps == NULL); - assert(num_samples < 0); - maps = new DocIdx[3*map_index]; - num_samples = static_cast(map_index); + } } - } // for (int iteration=0; iteration < 2; ++iteration) { - - // Shuffle. - // We need a 64 bit random number generator as we might have more - // than 2 billion samples. - std::mt19937_64 rand64_gen(seed + 1); - for (auto i=(num_samples - 1); i > 0; --i) { - const auto j = static_cast(rand64_gen() % (i + 1)); - const auto i0 = 3 * i; - const auto j0 = 3 * j; - // Swap values. - swap(maps[i0], maps[j0]); - swap(maps[i0 + 1], maps[j0 + 1]); - swap(maps[i0 + 2], maps[j0 + 2]); - } + // If we have more than two sentences. + if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { + // Set values. + auto seq_len = int32_t{0}; + auto num_sent = int32_t{0}; + auto target_seq_len = get_target_sample_len( + short_seq_ratio, max_seq_length, rand32_gen); + + // Loop through sentences. + for (auto sent_index = sent_index_first; sent_index < sent_index_last; + ++sent_index) { + // Add the size and number of sentences. + seq_len += sizes[sent_index]; + ++num_sent; + --num_remain_sent; + + // If we have reached the target length. + // and if not only one sentence is left in the document. + // and if we have at least two sentneces. + // and if we have reached end of the document. + if (((seq_len >= target_seq_len) && (num_remain_sent > 1) && + (num_sent >= min_num_sent)) || + (num_remain_sent == 0)) { + // Check for overflow. + if ((3 * map_index + 2) > std::numeric_limits::max()) { + cout << "number of samples exceeded maximum " + << "allowed by type int64: " + << std::numeric_limits::max() << endl; + throw std::overflow_error("Number of samples"); + } + + // Populate the map. + if (second) { + const auto map_index_0 = 3 * map_index; + maps[map_index_0] = static_cast(prev_start_index); + maps[map_index_0 + 1] = static_cast(sent_index + 1); + maps[map_index_0 + 2] = static_cast(target_seq_len); + } + + // Update indices / counters. + ++map_index; + prev_start_index = sent_index + 1; + target_seq_len = get_target_sample_len( + short_seq_ratio, max_seq_length, rand32_gen); + seq_len = 0; + num_sent = 0; + } - // Method to deallocate memory. 
- py::capsule free_when_done(maps, [](void *mem_) { - DocIdx *mem = reinterpret_cast(mem_); - delete[] mem; - }); + } // for (auto sent_index=sent_index_first; ... + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { + + if (!second) { + if (verbose) { + cout << " number of empty documents: " << empty_docs << endl + << std::flush; + cout << " number of documents with one sentence: " << one_sent_docs + << endl + << std::flush; + cout << " number of documents with long sentences: " << long_sent_docs + << endl + << std::flush; + cout << " will create mapping for " << map_index << " samples" << endl + << std::flush; + } + assert(maps == NULL); + assert(num_samples < 0); + maps = new DocIdx[3 * map_index]; + num_samples = static_cast(map_index); + } - // Return the numpy array. - const auto byte_size = sizeof(DocIdx); - return py::array(std::vector{num_samples, 3}, // shape - {3*byte_size, byte_size}, // C-style contiguous strides - maps, // the data pointer - free_when_done); // numpy array references + } // for (int iteration=0; iteration < 2; ++iteration) { + + // Shuffle. + // We need a 64 bit random number generator as we might have more + // than 2 billion samples. + std::mt19937_64 rand64_gen(seed + 1); + for (auto i = (num_samples - 1); i > 0; --i) { + const auto j = static_cast(rand64_gen() % (i + 1)); + const auto i0 = 3 * i; + const auto j0 = 3 * j; + // Swap values. + swap(maps[i0], maps[j0]); + swap(maps[i0 + 1], maps[j0 + 1]); + swap(maps[i0 + 2], maps[j0 + 2]); + } + // Method to deallocate memory. + py::capsule free_when_done(maps, [](void* mem_) { + DocIdx* mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(DocIdx); + return py::array(std::vector{num_samples, 3}, // shape + {3 * byte_size, byte_size}, // C-style contiguous strides + maps, // the data pointer + free_when_done); // numpy array references } - py::array build_mapping(const py::array_t& docs_, - const py::array_t& sizes_, - const int num_epochs, + const py::array_t& sizes_, const int num_epochs, const uint64_t max_num_samples, - const int max_seq_length, - const double short_seq_prob, - const int seed, - const bool verbose, - const int32_t min_num_sent) { - - if (sizes_.size() > std::numeric_limits::max()) { - if (verbose) { - cout << " using uint64 for data mapping..." << endl << std::flush; - } - return build_mapping_impl(docs_, sizes_, num_epochs, - max_num_samples, max_seq_length, - short_seq_prob, seed, verbose, - min_num_sent); - } else { - if (verbose) { - cout << " using uint32 for data mapping..." << endl << std::flush; - } - return build_mapping_impl(docs_, sizes_, num_epochs, - max_num_samples, max_seq_length, - short_seq_prob, seed, verbose, - min_num_sent); + const int max_seq_length, const double short_seq_prob, + const int seed, const bool verbose, + const int32_t min_num_sent) { + if (sizes_.size() > std::numeric_limits::max()) { + if (verbose) { + cout << " using uint64 for data mapping..." << endl << std::flush; + } + return build_mapping_impl( + docs_, sizes_, num_epochs, max_num_samples, max_seq_length, + short_seq_prob, seed, verbose, min_num_sent); + } else { + if (verbose) { + cout << " using uint32 for data mapping..." 
<< endl << std::flush; } + return build_mapping_impl( + docs_, sizes_, num_epochs, max_num_samples, max_seq_length, + short_seq_prob, seed, verbose, min_num_sent); + } } -template -py::array build_blocks_mapping_impl(const py::array_t& docs_, - const py::array_t& sizes_, - const py::array_t& titles_sizes_, - const int32_t num_epochs, - const uint64_t max_num_samples, - const int32_t max_seq_length, - const int32_t seed, - const bool verbose, - const bool use_one_sent_blocks) { - /* Build a mapping of (start-index, end-index, sequence-length) where - start and end index are the indices of the sentences in the sample - and sequence-length is the target sequence length. - */ - - // Consistency checks. - assert(num_epochs > 0); - assert(max_seq_length > 1); - assert(seed > 0); - - // Remove bound checks. - auto docs = docs_.unchecked<1>(); - auto sizes = sizes_.unchecked<1>(); - auto titles_sizes = titles_sizes_.unchecked<1>(); +template +py::array build_blocks_mapping_impl( + const py::array_t& docs_, const py::array_t& sizes_, + const py::array_t& titles_sizes_, const int32_t num_epochs, + const uint64_t max_num_samples, const int32_t max_seq_length, + const int32_t seed, const bool verbose, const bool use_one_sent_blocks) { + /* Build a mapping of (start-index, end-index, sequence-length) where + start and end index are the indices of the sentences in the sample + and sequence-length is the target sequence length. + */ + + // Consistency checks. + assert(num_epochs > 0); + assert(max_seq_length > 1); + assert(seed > 0); + + // Remove bound checks. + auto docs = docs_.unchecked<1>(); + auto sizes = sizes_.unchecked<1>(); + auto titles_sizes = titles_sizes_.unchecked<1>(); - if (verbose) { - const auto sent_start_index = docs[0]; - const auto sent_end_index = docs[docs_.shape(0) - 1]; - const auto num_sentences = sent_end_index - sent_start_index; - cout << " using:" << endl << std::flush; - cout << " number of documents: " << docs_.shape(0) - 1 << - endl << std::flush; - cout << " sentences range: [" << sent_start_index << - ", " << sent_end_index << ")" << endl << std::flush; - cout << " total number of sentences: " << num_sentences << - endl << std::flush; - cout << " number of epochs: " << num_epochs << - endl << std::flush; - cout << " maximum number of samples: " << max_num_samples << - endl << std::flush; - cout << " maximum sequence length: " << max_seq_length << - endl << std::flush; - cout << " seed: " << seed << endl << - std::flush; - } + if (verbose) { + const auto sent_start_index = docs[0]; + const auto sent_end_index = docs[docs_.shape(0) - 1]; + const auto num_sentences = sent_end_index - sent_start_index; + cout << " using:" << endl << std::flush; + cout << " number of documents: " << docs_.shape(0) - 1 + << endl + << std::flush; + cout << " sentences range: [" << sent_start_index << ", " + << sent_end_index << ")" << endl + << std::flush; + cout << " total number of sentences: " << num_sentences << endl + << std::flush; + cout << " number of epochs: " << num_epochs << endl + << std::flush; + cout << " maximum number of samples: " << max_num_samples << endl + << std::flush; + cout << " maximum sequence length: " << max_seq_length << endl + << std::flush; + cout << " seed: " << seed << endl + << std::flush; + } - // Mapping and its length (1D). - int64_t num_samples = -1; - DocIdx* maps = NULL; + // Mapping and its length (1D). + int64_t num_samples = -1; + DocIdx* maps = NULL; - // Acceptable number of sentences per block. 
- int min_num_sent = 2; - if (use_one_sent_blocks) { - min_num_sent = 1; - } + // Acceptable number of sentences per block. + int min_num_sent = 2; + if (use_one_sent_blocks) { + min_num_sent = 1; + } - // Perform two iterations, in the first iteration get the size - // and allocate memory and in the second iteration populate the map. - bool second = false; - for (int32_t iteration=0; iteration<2; ++iteration) { - - // Set the flag on second iteration. - second = (iteration == 1); - - // Current map index. - uint64_t map_index = 0; - - uint64_t empty_docs = 0; - uint64_t one_sent_docs = 0; - uint64_t long_sent_docs = 0; - // For each epoch: - for (int32_t epoch=0; epoch= max_num_samples) { - if (verbose && (!second)) { - cout << " reached " << max_num_samples << " samples after " - << epoch << " epochs ..." << endl << std::flush; - } - break; - } - // For each document: - for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) { - - // Document sentences are in [sent_index_first, sent_index_last) - const auto sent_index_first = docs[doc]; - const auto sent_index_last = docs[doc + 1]; - const auto target_seq_len = max_seq_length - titles_sizes[doc]; - - // At the begining of the document previous index is the - // start index. - auto prev_start_index = sent_index_first; - - // Remaining documents. - auto num_remain_sent = sent_index_last - sent_index_first; - - // Some bookkeeping - if ((epoch == 0) && (!second)) { - if (num_remain_sent == 0) { - ++empty_docs; - } - if (num_remain_sent == 1) { - ++one_sent_docs; - } - } - // Detect documents with long sentences. - bool contains_long_sentence = false; - if (num_remain_sent >= min_num_sent) { - for (auto sent_index=sent_index_first; - sent_index < sent_index_last; ++sent_index) { - if (sizes[sent_index] > LONG_SENTENCE_LEN){ - if ((epoch == 0) && (!second)) { - ++long_sent_docs; - } - contains_long_sentence = true; - break; - } - } - } - // If we have enough sentences and no long sentences. - if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { - - // Set values. - auto seq_len = int32_t{0}; - auto num_sent = int32_t{0}; - - // Loop through sentences. - for (auto sent_index=sent_index_first; - sent_index < sent_index_last; ++sent_index) { - - // Add the size and number of sentences. - seq_len += sizes[sent_index]; - ++num_sent; - --num_remain_sent; - - // If we have reached the target length. - // and there are an acceptable number of sentences left - // and if we have at least the minimum number of sentences. - // or if we have reached end of the document. - if (((seq_len >= target_seq_len) && - (num_remain_sent >= min_num_sent) && - (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) { - - // Populate the map. - if (second) { - const auto map_index_0 = 4 * map_index; - // Each sample has 4 items: the starting sentence index, ending sentence index, - // the index of the document from which the block comes (used for fetching titles) - // and the unique id of the block (used for creating block indexes) - - maps[map_index_0] = static_cast(prev_start_index); - maps[map_index_0 + 1] = static_cast(sent_index + 1); - maps[map_index_0 + 2] = static_cast(doc); - maps[map_index_0 + 3] = static_cast(block_id); - } - - // Update indices / counters. - ++map_index; - ++block_id; - prev_start_index = sent_index + 1; - seq_len = 0; - num_sent = 0; - } - } // for (auto sent_index=sent_index_first; ... 
- } // if (num_remain_sent > 1) { - } // for (int doc=0; doc < num_docs; ++doc) { - } // for (int epoch=0; epoch < num_epochs; ++epoch) { - - if (!second) { - if (verbose) { - cout << " number of empty documents: " << empty_docs << - endl << std::flush; - cout << " number of documents with one sentence: " << - one_sent_docs << endl << std::flush; - cout << " number of documents with long sentences: " << - long_sent_docs << endl << std::flush; - cout << " will create mapping for " << map_index << - " samples" << endl << std::flush; + // Perform two iterations, in the first iteration get the size + // and allocate memory and in the second iteration populate the map. + bool second = false; + for (int32_t iteration = 0; iteration < 2; ++iteration) { + // Set the flag on second iteration. + second = (iteration == 1); + + // Current map index. + uint64_t map_index = 0; + + uint64_t empty_docs = 0; + uint64_t one_sent_docs = 0; + uint64_t long_sent_docs = 0; + // For each epoch: + for (int32_t epoch = 0; epoch < num_epochs; ++epoch) { + // assign every block a unique id + int32_t block_id = 0; + + if (map_index >= max_num_samples) { + if (verbose && (!second)) { + cout << " reached " << max_num_samples << " samples after " + << epoch << " epochs ..." << endl + << std::flush; + } + break; + } + // For each document: + for (int32_t doc = 0; doc < (docs.shape(0) - 1); ++doc) { + // Document sentences are in [sent_index_first, sent_index_last) + const auto sent_index_first = docs[doc]; + const auto sent_index_last = docs[doc + 1]; + const auto target_seq_len = max_seq_length - titles_sizes[doc]; + + // At the begining of the document previous index is the + // start index. + auto prev_start_index = sent_index_first; + + // Remaining documents. + auto num_remain_sent = sent_index_last - sent_index_first; + + // Some bookkeeping + if ((epoch == 0) && (!second)) { + if (num_remain_sent == 0) { + ++empty_docs; + } + if (num_remain_sent == 1) { + ++one_sent_docs; + } + } + // Detect documents with long sentences. + bool contains_long_sentence = false; + if (num_remain_sent >= min_num_sent) { + for (auto sent_index = sent_index_first; sent_index < sent_index_last; + ++sent_index) { + if (sizes[sent_index] > LONG_SENTENCE_LEN) { + if ((epoch == 0) && (!second)) { + ++long_sent_docs; + } + contains_long_sentence = true; + break; } - assert(maps == NULL); - assert(num_samples < 0); - maps = new DocIdx[4*map_index]; - num_samples = static_cast(map_index); + } } - - } // for (int iteration=0; iteration < 2; ++iteration) { - - // Shuffle. - // We need a 64 bit random number generator as we might have more - // than 2 billion samples. - std::mt19937_64 rand64_gen(seed + 1); - for (auto i=(num_samples - 1); i > 0; --i) { - const auto j = static_cast(rand64_gen() % (i + 1)); - const auto i0 = 4 * i; - const auto j0 = 4 * j; - // Swap values. - swap(maps[i0], maps[j0]); - swap(maps[i0 + 1], maps[j0 + 1]); - swap(maps[i0 + 2], maps[j0 + 2]); - swap(maps[i0 + 3], maps[j0 + 3]); + // If we have enough sentences and no long sentences. + if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { + // Set values. + auto seq_len = int32_t{0}; + auto num_sent = int32_t{0}; + + // Loop through sentences. + for (auto sent_index = sent_index_first; sent_index < sent_index_last; + ++sent_index) { + // Add the size and number of sentences. + seq_len += sizes[sent_index]; + ++num_sent; + --num_remain_sent; + + // If we have reached the target length. 
+ // and there are an acceptable number of sentences left + // and if we have at least the minimum number of sentences. + // or if we have reached end of the document. + if (((seq_len >= target_seq_len) && + (num_remain_sent >= min_num_sent) && + (num_sent >= min_num_sent)) || + (num_remain_sent == 0)) { + // Populate the map. + if (second) { + const auto map_index_0 = 4 * map_index; + // Each sample has 4 items: the starting sentence index, ending + // sentence index, the index of the document from which the + // block comes (used for fetching titles) and the unique id of + // the block (used for creating block indexes) + + maps[map_index_0] = static_cast(prev_start_index); + maps[map_index_0 + 1] = static_cast(sent_index + 1); + maps[map_index_0 + 2] = static_cast(doc); + maps[map_index_0 + 3] = static_cast(block_id); + } + + // Update indices / counters. + ++map_index; + ++block_id; + prev_start_index = sent_index + 1; + seq_len = 0; + num_sent = 0; + } + } // for (auto sent_index=sent_index_first; ... + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { + + if (!second) { + if (verbose) { + cout << " number of empty documents: " << empty_docs << endl + << std::flush; + cout << " number of documents with one sentence: " << one_sent_docs + << endl + << std::flush; + cout << " number of documents with long sentences: " << long_sent_docs + << endl + << std::flush; + cout << " will create mapping for " << map_index << " samples" << endl + << std::flush; + } + assert(maps == NULL); + assert(num_samples < 0); + maps = new DocIdx[4 * map_index]; + num_samples = static_cast(map_index); } - // Method to deallocate memory. - py::capsule free_when_done(maps, [](void *mem_) { - DocIdx *mem = reinterpret_cast(mem_); - delete[] mem; - }); - - // Return the numpy array. - const auto byte_size = sizeof(DocIdx); - return py::array(std::vector{num_samples, 4}, // shape - {4*byte_size, byte_size}, // C-style contiguous strides - maps, // the data pointer - free_when_done); // numpy array references + } // for (int iteration=0; iteration < 2; ++iteration) { + + // Shuffle. + // We need a 64 bit random number generator as we might have more + // than 2 billion samples. + std::mt19937_64 rand64_gen(seed + 1); + for (auto i = (num_samples - 1); i > 0; --i) { + const auto j = static_cast(rand64_gen() % (i + 1)); + const auto i0 = 4 * i; + const auto j0 = 4 * j; + // Swap values. + swap(maps[i0], maps[j0]); + swap(maps[i0 + 1], maps[j0 + 1]); + swap(maps[i0 + 2], maps[j0 + 2]); + swap(maps[i0 + 3], maps[j0 + 3]); + } + // Method to deallocate memory. + py::capsule free_when_done(maps, [](void* mem_) { + DocIdx* mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(DocIdx); + return py::array(std::vector{num_samples, 4}, // shape + {4 * byte_size, byte_size}, // C-style contiguous strides + maps, // the data pointer + free_when_done); // numpy array references } -py::array build_blocks_mapping(const py::array_t& docs_, - const py::array_t& sizes_, - const py::array_t& titles_sizes_, - const int num_epochs, - const uint64_t max_num_samples, - const int max_seq_length, - const int seed, - const bool verbose, - const bool use_one_sent_blocks) { - - if (sizes_.size() > std::numeric_limits::max()) { - if (verbose) { - cout << " using uint64 for data mapping..." 
<< endl << std::flush; - } - return build_blocks_mapping_impl(docs_, sizes_, titles_sizes_, - num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks); - } else { - if (verbose) { - cout << " using uint32 for data mapping..." << endl << std::flush; - } - return build_blocks_mapping_impl(docs_, sizes_, titles_sizes_, - num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks); +py::array build_blocks_mapping( + const py::array_t& docs_, const py::array_t& sizes_, + const py::array_t& titles_sizes_, const int num_epochs, + const uint64_t max_num_samples, const int max_seq_length, const int seed, + const bool verbose, const bool use_one_sent_blocks) { + if (sizes_.size() > std::numeric_limits::max()) { + if (verbose) { + cout << " using uint64 for data mapping..." << endl << std::flush; + } + return build_blocks_mapping_impl( + docs_, sizes_, titles_sizes_, num_epochs, max_num_samples, + max_seq_length, seed, verbose, use_one_sent_blocks); + } else { + if (verbose) { + cout << " using uint32 for data mapping..." << endl << std::flush; } + return build_blocks_mapping_impl( + docs_, sizes_, titles_sizes_, num_epochs, max_num_samples, + max_seq_length, seed, verbose, use_one_sent_blocks); + } } PYBIND11_MODULE(helpers, m) { - m.def("build_mapping", &build_mapping); - m.def("build_blocks_mapping", &build_blocks_mapping); - m.def("build_sample_idx", &build_sample_idx); - m.def("build_blending_indices", &build_blending_indices); + m.def("build_mapping", &build_mapping); + m.def("build_blocks_mapping", &build_blocks_mapping); + m.def("build_sample_idx", &build_sample_idx); + m.def("build_blending_indices", &build_blending_indices); } diff --git a/examples/tutorial/sequence_parallel/data/datasets/ict_dataset.py b/examples/tutorial/sequence_parallel/data/datasets/ict_dataset.py index 6dac35ff9d41..220099f9ba32 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/ict_dataset.py +++ b/examples/tutorial/sequence_parallel/data/datasets/ict_dataset.py @@ -2,12 +2,11 @@ import random import numpy as np -from torch.utils.data import Dataset - -from megatron import get_tokenizer -from megatron import get_args +from megatron import get_args, get_tokenizer from megatron.data.dataset_utils import get_indexed_dataset_ from megatron.data.realm_dataset_utils import get_block_samples_mapping +from torch.utils.data import Dataset + def make_attention_mask(source_block, target_block): """ @@ -20,16 +19,17 @@ def make_attention_mask(source_block, target_block): # (source_length, target_length) return mask + def get_ict_dataset(use_titles=True, query_in_block_prob=1): """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block()) rather than for training, since it is only built with a single epoch sample mapping. 
""" args = get_args() - block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True) - titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True) + block_dataset = get_indexed_dataset_(args.data_path, "mmap", True) + titles_dataset = get_indexed_dataset_(args.titles_data_path, "mmap", True) kwargs = dict( - name='full', + name="full", block_dataset=block_dataset, title_dataset=titles_dataset, data_prefix=args.data_path, @@ -39,7 +39,7 @@ def get_ict_dataset(use_titles=True, query_in_block_prob=1): seed=1, query_in_block_prob=query_in_block_prob, use_titles=use_titles, - use_one_sent_docs=args.use_one_sent_docs + use_one_sent_docs=args.use_one_sent_docs, ) dataset = ICTDataset(**kwargs) return dataset @@ -47,9 +47,22 @@ def get_ict_dataset(use_titles=True, query_in_block_prob=1): class ICTDataset(Dataset): """Dataset containing sentences and their blocks for an inverse cloze task.""" - def __init__(self, name, block_dataset, title_dataset, data_prefix, - num_epochs, max_num_samples, max_seq_length, query_in_block_prob, - seed, use_titles=True, use_one_sent_docs=False, binary_head=False): + + def __init__( + self, + name, + block_dataset, + title_dataset, + data_prefix, + num_epochs, + max_num_samples, + max_seq_length, + query_in_block_prob, + seed, + use_titles=True, + use_one_sent_docs=False, + binary_head=False, + ): self.name = name self.seed = seed self.max_seq_length = max_seq_length @@ -61,8 +74,16 @@ def __init__(self, name, block_dataset, title_dataset, data_prefix, self.use_one_sent_docs = use_one_sent_docs self.samples_mapping = get_block_samples_mapping( - block_dataset, title_dataset, data_prefix, num_epochs, - max_num_samples, max_seq_length, seed, name, use_one_sent_docs) + block_dataset, + title_dataset, + data_prefix, + num_epochs, + max_num_samples, + max_seq_length, + seed, + name, + use_one_sent_docs, + ) self.tokenizer = get_tokenizer() self.vocab_id_list = list(self.tokenizer.inv_vocab.keys()) self.vocab_id_to_token_list = self.tokenizer.inv_vocab @@ -99,8 +120,8 @@ def __getitem__(self, idx): # still need to truncate because blocks are concluded when # the sentence lengths have exceeded max_seq_length. 
- query = query[:self.max_seq_length - 2] - block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset] + query = query[: self.max_seq_length - 2] + block = list(itertools.chain(*block))[: self.max_seq_length - title_pad_offset] query_tokens, query_pad_mask = self.concat_and_pad_tokens(query) context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title) @@ -111,13 +132,13 @@ def __getitem__(self, idx): block_data = sample_data.as_array() sample = { - 'query_tokens': query_tokens, - 'query_mask': query_mask, - 'query_pad_mask': query_pad_mask, - 'context_tokens': context_tokens, - 'context_mask': context_mask, - 'context_pad_mask': context_pad_mask, - 'block_data': block_data, + "query_tokens": query_tokens, + "query_mask": query_mask, + "query_pad_mask": query_pad_mask, + "context_tokens": context_tokens, + "context_mask": context_mask, + "context_pad_mask": context_pad_mask, + "block_data": block_data, } return sample @@ -127,7 +148,7 @@ def get_block(self, start_idx, end_idx, doc_idx): block = [self.block_dataset[i] for i in range(start_idx, end_idx)] title = self.title_dataset[int(doc_idx)] - block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))] + block = list(itertools.chain(*block))[: self.max_seq_length - (3 + len(title))] block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title) return block_tokens, block_pad_mask diff --git a/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py b/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py index 9a25dc453c24..961a1650bd74 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py +++ b/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py @@ -27,17 +27,17 @@ def __best_fitting_dtype(vocab_size=None): def get_available_dataset_impl(): - return ['lazy', 'cached', 'mmap'] + return ["lazy", "cached", "mmap"] def infer_dataset_impl(path): if IndexedDataset.exists(path): - with open(index_file_path(path), 'rb') as f: + with open(index_file_path(path), "rb") as f: magic = f.read(8) if magic == IndexedDataset._HDR_MAGIC: - return 'cached' + return "cached" elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]: - return 'mmap' + return "mmap" else: return None else: @@ -47,7 +47,7 @@ def infer_dataset_impl(path): def make_builder(out_file, impl, vocab_size=None): - if impl == 'mmap': + if impl == "mmap": return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size)) else: return IndexedDatasetBuilder(out_file) @@ -58,20 +58,20 @@ def make_dataset(path, impl, skip_warmup=False): print(f"Dataset does not exist: {path}") print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") return None - if impl == 'infer': + if impl == "infer": impl = infer_dataset_impl(path) - if impl == 'lazy' and IndexedDataset.exists(path): + if impl == "lazy" and IndexedDataset.exists(path): return IndexedDataset(path) - elif impl == 'cached' and IndexedDataset.exists(path): + elif impl == "cached" and IndexedDataset.exists(path): return IndexedCachedDataset(path) - elif impl == 'mmap' and MMapIndexedDataset.exists(path): + elif impl == "mmap" and MMapIndexedDataset.exists(path): return MMapIndexedDataset(path, skip_warmup) print(f"Unknown dataset implementation: {impl}") return None def dataset_exists(path, impl): - if impl == 'mmap': + if impl == "mmap": return MMapIndexedDataset.exists(path) else: return IndexedDataset.exists(path) @@ -98,11 +98,11 @@ def code(dtype): def 
index_file_path(prefix_path): - return prefix_path + '.idx' + return prefix_path + ".idx" def data_file_path(prefix_path): - return prefix_path + '.bin' + return prefix_path + ".bin" def create_doc_idx(sizes): @@ -115,7 +115,8 @@ def create_doc_idx(sizes): class IndexedDataset(torch.utils.data.Dataset): """Loader for IndexedDataset""" - _HDR_MAGIC = b'TNTIDX\x00\x00' + + _HDR_MAGIC = b"TNTIDX\x00\x00" def __init__(self, path): super().__init__() @@ -124,27 +125,28 @@ def __init__(self, path): self.read_index(path) def read_index(self, path): - with open(index_file_path(path), 'rb') as f: + with open(index_file_path(path), "rb") as f: magic = f.read(8) - assert magic == self._HDR_MAGIC, ('Index file doesn\'t match expected format. ' - 'Make sure that --dataset-impl is configured properly.') + assert magic == self._HDR_MAGIC, ( + "Index file doesn't match expected format. " "Make sure that --dataset-impl is configured properly." + ) version = f.read(8) - assert struct.unpack('= self._len: - raise IndexError('index out of range') + raise IndexError("index out of range") def __del__(self): if self.data_file: @@ -157,7 +159,7 @@ def __getitem__(self, idx): if isinstance(idx, int): i = idx self.check_index(i) - tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] a = np.empty(tensor_size, dtype=self.dtype) self.data_file.seek(self.data_offsets[i] * self.element_size) self.data_file.readinto(a) @@ -166,7 +168,7 @@ def __getitem__(self, idx): start, stop, step = idx.indices(len(self)) if step != 1: raise ValueError("Slices into indexed_dataset must be contiguous") - sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]] + sizes = self.sizes[self.dim_offsets[start] : self.dim_offsets[stop]] size = sum(sizes) a = np.empty(size, dtype=self.dtype) self.data_file.seek(self.data_offsets[start] * self.element_size) @@ -186,15 +188,14 @@ def size(self, index): @staticmethod def exists(path): - return (os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))) + return os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)) @property def supports_prefetch(self): - return False # avoid prefetching to save memory + return False # avoid prefetching to save memory class IndexedCachedDataset(IndexedDataset): - def __init__(self, path): super().__init__(path) self.cache = None @@ -219,7 +220,7 @@ def prefetch(self, indices): for i in indices: self.cache_index[i] = ptx size = self.data_offsets[i + 1] - self.data_offsets[i] - a = self.cache[ptx:ptx + size] + a = self.cache[ptx : ptx + size] self.data_file.seek(self.data_offsets[i] * self.element_size) self.data_file.readinto(a) ptx += size @@ -233,10 +234,10 @@ def __getitem__(self, idx): if isinstance(idx, int): i = idx self.check_index(i) - tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] a = np.empty(tensor_size, dtype=self.dtype) ptx = self.cache_index[i] - np.copyto(a, self.cache[ptx:ptx + a.size]) + np.copyto(a, self.cache[ptx : ptx + a.size]) return a elif isinstance(idx, slice): # Hack just to make this work, can optimizer later if necessary @@ -250,7 +251,7 @@ class IndexedDatasetBuilder(object): element_sizes = {np.uint8: 1, np.int8: 1, np.int16: 2, np.int32: 4, np.int64: 8, float: 4, np.double: 8} def __init__(self, out_file, dtype=np.int32): - self.out_file = open(out_file, 'wb') + self.out_file = open(out_file, 
"wb") self.dtype = dtype self.data_offsets = [0] self.dim_offsets = [0] @@ -280,7 +281,7 @@ def merge_file_(self, another_file): for dim_offset in index.dim_offsets[1:]: self.dim_offsets.append(begin + dim_offset) - with open(data_file_path(another_file), 'rb') as f: + with open(data_file_path(another_file), "rb") as f: while True: data = f.read(1024) if data: @@ -290,12 +291,12 @@ def merge_file_(self, another_file): def finalize(self, index_file): self.out_file.close() - index = open(index_file, 'wb') - index.write(b'TNTIDX\x00\x00') - index.write(struct.pack('= 0x4E00 and cp <= 0x9FFF) or # - (cp >= 0x3400 and cp <= 0x4DBF) or # - (cp >= 0x20000 and cp <= 0x2A6DF) or # - (cp >= 0x2A700 and cp <= 0x2B73F) or # - (cp >= 0x2B740 and cp <= 0x2B81F) or # - (cp >= 0x2B820 and cp <= 0x2CEAF) or - (cp >= 0xF900 and cp <= 0xFAFF) or # - (cp >= 0x2F800 and cp <= 0x2FA1F)): # + if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) # + or (cp >= 0x20000 and cp <= 0x2A6DF) # + or (cp >= 0x2A700 and cp <= 0x2B73F) # + or (cp >= 0x2B740 and cp <= 0x2B81F) # + or (cp >= 0x2B820 and cp <= 0x2CEAF) # + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) # + ): # return True return False @@ -320,7 +320,7 @@ def _clean_text(self, text): output = [] for char in text: cp = ord(char) - if cp == 0 or cp == 0xfffd or _is_control(char): + if cp == 0 or cp == 0xFFFD or _is_control(char): continue if _is_whitespace(char): output.append(" ") @@ -422,8 +422,7 @@ def _is_punctuation(char): # Characters such as "^", "$", and "`" are not in the Unicode # Punctuation class but we treat them as punctuation anyways, for # consistency. - if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or - (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): + if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126): return True cat = unicodedata.category(char) if cat.startswith("P"): diff --git a/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py b/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py index ba832b5cdce9..6c7bfd69f3f9 100644 --- a/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py +++ b/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py @@ -25,16 +25,15 @@ def build_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids=0): """Initialize tokenizer.""" if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0: - print('> building {} tokenizer ...'.format(tokenizer_type), flush=True) + print("> building {} tokenizer ...".format(tokenizer_type), flush=True) # Select and instantiate the tokenizer. - if tokenizer_type == 'BertWordPieceLowerCase': + if tokenizer_type == "BertWordPieceLowerCase": tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, lower_case=True, vocab_extra_ids=vocab_extra_ids) - elif tokenizer_type == 'BertWordPieceCase': + elif tokenizer_type == "BertWordPieceCase": tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, lower_case=False, vocab_extra_ids=vocab_extra_ids) else: - raise NotImplementedError('{} tokenizer is not ' - 'implemented.'.format(tokenizer_type)) + raise NotImplementedError("{} tokenizer is not " "implemented.".format(tokenizer_type)) # Add vocab size. 
padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size) @@ -55,9 +54,11 @@ def _vocab_size_with_padding(orig_vocab_size, make_vocab_size_divisible_by=128): while (after % multiple) != 0: after += 1 if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0: - print(' > padded vocab (size: {}) with {} dummy tokens ' - '(new size: {})'.format(orig_vocab_size, after - orig_vocab_size, after), - flush=True) + print( + " > padded vocab (size: {}) with {} dummy tokens " + "(new size: {})".format(orig_vocab_size, after - orig_vocab_size, after), + flush=True, + ) return after @@ -77,46 +78,38 @@ def vocab_size(self): @abstractmethod def vocab(self): """Dictionary from vocab text token to id token.""" - pass @property @abstractmethod def inv_vocab(self): """Dictionary from vocab id token to text token.""" - pass @abstractmethod def tokenize(self, text): pass def detokenize(self, token_ids): - raise NotImplementedError('detokenizer is not implemented for {} ' - 'tokenizer'.format(self.name)) + raise NotImplementedError("detokenizer is not implemented for {} " "tokenizer".format(self.name)) @property def cls(self): - raise NotImplementedError('CLS is not provided for {} ' - 'tokenizer'.format(self.name)) + raise NotImplementedError("CLS is not provided for {} " "tokenizer".format(self.name)) @property def sep(self): - raise NotImplementedError('SEP is not provided for {} ' - 'tokenizer'.format(self.name)) + raise NotImplementedError("SEP is not provided for {} " "tokenizer".format(self.name)) @property def pad(self): - raise NotImplementedError('PAD is not provided for {} ' - 'tokenizer'.format(self.name)) + raise NotImplementedError("PAD is not provided for {} " "tokenizer".format(self.name)) @property def eod(self): - raise NotImplementedError('EOD is not provided for {} ' - 'tokenizer'.format(self.name)) + raise NotImplementedError("EOD is not provided for {} " "tokenizer".format(self.name)) @property def mask(self): - raise NotImplementedError('MASK is not provided for {} ' - 'tokenizer'.format(self.name)) + raise NotImplementedError("MASK is not provided for {} " "tokenizer".format(self.name)) class _BertWordPieceTokenizer(AbstractTokenizer): @@ -124,24 +117,24 @@ class _BertWordPieceTokenizer(AbstractTokenizer): def __init__(self, vocab_file, lower_case=True, vocab_extra_ids=0): if lower_case: - name = 'BERT Lower Case' + name = "BERT Lower Case" else: - name = 'BERT Upper Case' + name = "BERT Upper Case" super().__init__(name) self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=lower_case) - self.cls_id = self.tokenizer.vocab['[CLS]'] - self.sep_id = self.tokenizer.vocab['[SEP]'] - self.pad_id = self.tokenizer.vocab['[PAD]'] - self.mask_id = self.tokenizer.vocab['[MASK]'] + self.cls_id = self.tokenizer.vocab["[CLS]"] + self.sep_id = self.tokenizer.vocab["[SEP]"] + self.pad_id = self.tokenizer.vocab["[PAD]"] + self.mask_id = self.tokenizer.vocab["[MASK]"] self._additional_special_tokens = [] # (dsachan) Add BOS and EOS tokens - SPECIAL_TOKENS = {'eos_token': '[EOS]', 'bos_token': '[BOS]'} - self._bos_token = '[BOS]' + SPECIAL_TOKENS = {"eos_token": "[EOS]", "bos_token": "[BOS]"} + self._bos_token = "[BOS]" self.add_token(self._bos_token) self._bos_token_id = self.vocab.get(self._bos_token) - self._eos_token = '[EOS]' + self._eos_token = "[EOS]" self.add_token(self._eos_token) self._eos_token_id = self.vocab.get(self._eos_token) @@ -185,7 +178,7 @@ def decode(self, ids): def decode_token_ids(self, token_ids): tokens = 
self.tokenizer.convert_ids_to_tokens(token_ids) - exclude_list = ['[PAD]', '[CLS]'] + exclude_list = ["[PAD]", "[CLS]"] non_pads = [t for t in tokens if t not in exclude_list] result = "" @@ -215,32 +208,32 @@ def mask(self): @property def bos_token(self): - """ Beginning of sentence token id """ + """Beginning of sentence token id""" return self._bos_token @property def eos_token(self): - """ End of sentence token id """ + """End of sentence token id""" return self._eos_token @property def additional_special_tokens(self): - """ All the additional special tokens you may want to use (list of strings).""" + """All the additional special tokens you may want to use (list of strings).""" return self._additional_special_tokens @property def bos_token_id(self): - """ Id of the beginning of sentence token in the vocabulary.""" + """Id of the beginning of sentence token in the vocabulary.""" return self._bos_token_id @property def eos_token_id(self): - """ Id of the end of sentence token in the vocabulary.""" + """Id of the end of sentence token in the vocabulary.""" return self._eos_token_id @property def additional_special_tokens_ids(self): - """ Ids of all the additional special tokens in the vocabulary (list of integers).""" + """Ids of all the additional special tokens in the vocabulary (list of integers).""" return [self.vocab.get(token) for token in self._additional_special_tokens] @additional_special_tokens.setter diff --git a/examples/tutorial/sequence_parallel/loss_func/bert_loss.py b/examples/tutorial/sequence_parallel/loss_func/bert_loss.py index b3f2487a438b..869ff720f4b0 100644 --- a/examples/tutorial/sequence_parallel/loss_func/bert_loss.py +++ b/examples/tutorial/sequence_parallel/loss_func/bert_loss.py @@ -1,17 +1,12 @@ import torch -import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc -from colossalai.logging import get_dist_logger - -from .cross_entropy import vocab_cross_entropy class BertLoss(nn.Module): - def forward(self, lm_loss, sop_logits, loss_mask, sentence_order): lm_loss_ = lm_loss.float() loss_mask = loss_mask.float() diff --git a/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py b/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py index ed15c6ea8054..b5d9ea919261 100644 --- a/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py +++ b/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py @@ -1,11 +1,8 @@ import torch from torch.cuda.amp import custom_bwd, custom_fwd -from colossalai.legacy.context.parallel_mode import ParallelMode - class _VocabCrossEntropy(torch.autograd.Function): - @staticmethod @custom_fwd def forward(ctx, vocab_parallel_logits, target): @@ -59,7 +56,7 @@ def backward(ctx, grad_output): # Add the gradient from matching classes. arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) - grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float()) + grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float() # Finally elementwise multiplication with the output gradients. 
grad_input.mul_(grad_output.unsqueeze(dim=-1)) diff --git a/examples/tutorial/sequence_parallel/loss_func/utils.py b/examples/tutorial/sequence_parallel/loss_func/utils.py index a3d92f294326..35fa73896c46 100644 --- a/examples/tutorial/sequence_parallel/loss_func/utils.py +++ b/examples/tutorial/sequence_parallel/loss_func/utils.py @@ -1,11 +1,9 @@ - import torch def ensure_divisibility(numerator, denominator): """Ensure that numerator is divisible by the denominator.""" - assert numerator % denominator == 0, '{} is not divisible by {}'.format( - numerator, denominator) + assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator) def divide(numerator, denominator): @@ -15,8 +13,7 @@ def divide(numerator, denominator): return numerator // denominator -def split_tensor_along_last_dim(tensor, num_partitions, - contiguous_split_chunks=False): +def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False): """Split a tensor along its last dimension. Arguments: tensor: input tensor. @@ -38,12 +35,11 @@ def split_tensor_along_last_dim(tensor, num_partitions, class VocabUtility: """Split the vocabulary into `world_size` chunks amd return the - first and last index of the vocabulary belonging to the `rank` - partition: Note that indices in [fist, last)""" + first and last index of the vocabulary belonging to the `rank` + partition: Note that indices in [fist, last)""" @staticmethod - def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, - rank, world_size): + def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size): index_f = rank * per_partition_vocab_size index_l = index_f + per_partition_vocab_size return index_f, index_l @@ -51,5 +47,4 @@ def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, @staticmethod def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size): per_partition_vocab_size = divide(global_vocab_size, world_size) - return VocabUtility.vocab_range_from_per_partition_vocab_size( - per_partition_vocab_size, rank, world_size) + return VocabUtility.vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size) diff --git a/examples/tutorial/sequence_parallel/lr_scheduler/annealing_lr.py b/examples/tutorial/sequence_parallel/lr_scheduler/annealing_lr.py index 8d95679ff76d..866d0d54583b 100644 --- a/examples/tutorial/sequence_parallel/lr_scheduler/annealing_lr.py +++ b/examples/tutorial/sequence_parallel/lr_scheduler/annealing_lr.py @@ -21,16 +21,17 @@ class AnnealingLR(object): """Anneals the learning rate.""" - def __init__(self, - optimizer, - max_lr, - min_lr, - warmup_steps, - decay_steps, - decay_style, - use_checkpoint_lr_scheduler=True, - override_lr_scheduler=False): - + def __init__( + self, + optimizer, + max_lr, + min_lr, + warmup_steps, + decay_steps, + decay_style, + use_checkpoint_lr_scheduler=True, + override_lr_scheduler=False, + ): # Class values. self.optimizer = optimizer @@ -50,23 +51,21 @@ def __init__(self, self.override_lr_scheduler = override_lr_scheduler self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler if self.override_lr_scheduler: - assert not self.use_checkpoint_lr_scheduler, 'both override and '\ - 'use-checkpoint are set.' + assert not self.use_checkpoint_lr_scheduler, "both override and " "use-checkpoint are set." # Set the learning rate self.step(0) def get_lr(self): """Learning rate decay functions from: - https://openreview.net/pdf?id=BJYwwY9ll pg. 
4""" + https://openreview.net/pdf?id=BJYwwY9ll pg. 4""" # Use linear warmup for the initial part. if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps: - return self.max_lr * float(self.num_steps) / \ - float(self.warmup_steps) + return self.max_lr * float(self.num_steps) / float(self.warmup_steps) # If the learning rate is constant, just return the initial value. - if self.decay_style == 'constant': + if self.decay_style == "constant": return self.max_lr # For any steps larger than `self.decay_steps`, use `self.min_lr`. @@ -81,13 +80,12 @@ def get_lr(self): assert decay_ratio <= 1.0 delta_lr = self.max_lr - self.min_lr - if self.decay_style == 'linear': - coeff = (1.0 - decay_ratio) - elif self.decay_style == 'cosine': + if self.decay_style == "linear": + coeff = 1.0 - decay_ratio + elif self.decay_style == "cosine": coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0) else: - raise Exception('{} decay style is not supported.'.format( - self.decay_style)) + raise Exception("{} decay style is not supported.".format(self.decay_style)) return self.min_lr + coeff * delta_lr @@ -96,16 +94,16 @@ def step(self, increment=1): self.num_steps += increment new_lr = self.get_lr() for group in self.optimizer.param_groups: - group['lr'] = new_lr + group["lr"] = new_lr def state_dict(self): state_dict = { - 'max_lr': self.max_lr, - 'warmup_steps': self.warmup_steps, - 'num_steps': self.num_steps, - 'decay_style': self.decay_style, - 'decay_steps': self.decay_steps, - 'min_lr': self.min_lr + "max_lr": self.max_lr, + "warmup_steps": self.warmup_steps, + "num_steps": self.num_steps, + "decay_style": self.decay_style, + "decay_steps": self.decay_steps, + "min_lr": self.min_lr, } return state_dict @@ -116,43 +114,35 @@ def _check_and_set(self, cls_value, sd_value, name): return cls_value if not self.use_checkpoint_lr_scheduler: - assert cls_value == sd_value, \ - f'AnnealingLR: class input value {cls_value} and checkpoint' \ - f'value {sd_value} for {name} do not match' + assert cls_value == sd_value, ( + f"AnnealingLR: class input value {cls_value} and checkpoint" f"value {sd_value} for {name} do not match" + ) return sd_value def load_state_dict(self, sd): - - if 'start_lr' in sd: - max_lr_ = sd['start_lr'] + if "start_lr" in sd: + max_lr_ = sd["start_lr"] else: - max_lr_ = sd['max_lr'] - self.max_lr = self._check_and_set(self.max_lr, max_lr_, - 'learning rate') + max_lr_ = sd["max_lr"] + self.max_lr = self._check_and_set(self.max_lr, max_lr_, "learning rate") - self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'], - 'minimum learning rate') + self.min_lr = self._check_and_set(self.min_lr, sd["min_lr"], "minimum learning rate") - if 'warmup_iter' in sd: - warmup_steps_ = sd['warmup_iter'] + if "warmup_iter" in sd: + warmup_steps_ = sd["warmup_iter"] else: - warmup_steps_ = sd['warmup_steps'] - self.warmup_steps = self._check_and_set(self.warmup_steps, - warmup_steps_, - 'warmup iterations') + warmup_steps_ = sd["warmup_steps"] + self.warmup_steps = self._check_and_set(self.warmup_steps, warmup_steps_, "warmup iterations") - if 'end_iter' in sd: - decay_steps_ = sd['end_iter'] + if "end_iter" in sd: + decay_steps_ = sd["end_iter"] else: - decay_steps_ = sd['decay_steps'] - self.decay_steps = self._check_and_set(self.decay_steps, decay_steps_, - 'total number of iterations') - self.decay_style = self._check_and_set(self.decay_style, - sd['decay_style'], - 'decay style') - - if 'num_iters' in sd: - num_steps = sd['num_iters'] + decay_steps_ = sd["decay_steps"] + self.decay_steps = 
self._check_and_set(self.decay_steps, decay_steps_, "total number of iterations") + self.decay_style = self._check_and_set(self.decay_style, sd["decay_style"], "decay style") + + if "num_iters" in sd: + num_steps = sd["num_iters"] else: - num_steps = sd['num_steps'] + num_steps = sd["num_steps"] self.step(increment=num_steps) diff --git a/examples/tutorial/sequence_parallel/model/__init__.py b/examples/tutorial/sequence_parallel/model/__init__.py index 139597f9cb07..e69de29bb2d1 100644 --- a/examples/tutorial/sequence_parallel/model/__init__.py +++ b/examples/tutorial/sequence_parallel/model/__init__.py @@ -1,2 +0,0 @@ - - diff --git a/examples/tutorial/sequence_parallel/model/bert.py b/examples/tutorial/sequence_parallel/model/bert.py index 4ba64bbe2b9f..7b0e93d958ca 100644 --- a/examples/tutorial/sequence_parallel/model/bert.py +++ b/examples/tutorial/sequence_parallel/model/bert.py @@ -16,7 +16,6 @@ class BertForPretrain(nn.Module): - def __init__( self, vocab_size, @@ -34,7 +33,9 @@ def __init__( ): super().__init__() self.seq_parallel_size = gpc.get_world_size(ParallelMode.SEQUENCE) - assert max_sequence_length % self.seq_parallel_size == 0, 'sequence length is not divisible by the sequence parallel size' + assert ( + max_sequence_length % self.seq_parallel_size == 0 + ), "sequence length is not divisible by the sequence parallel size" self.sub_seq_length = max_sequence_length // self.seq_parallel_size self.init_std = init_std self.num_layers = num_layers @@ -43,28 +44,32 @@ def __init__( num_tokentypes = 0 self.preprocessor = PreProcessor(self.sub_seq_length) - self.embedding = Embedding(hidden_size=hidden_size, - vocab_size=vocab_size, - max_sequence_length=max_sequence_length, - embedding_dropout_prob=dropout_prob, - num_tokentypes=num_tokentypes) + self.embedding = Embedding( + hidden_size=hidden_size, + vocab_size=vocab_size, + max_sequence_length=max_sequence_length, + embedding_dropout_prob=dropout_prob, + num_tokentypes=num_tokentypes, + ) self.bert_layers = nn.ModuleList() for i in range(num_layers): - bert_layer = BertLayer(layer_number=i + 1, - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - attention_dropout=dropout_prob, - mlp_ratio=mlp_ratio, - hidden_dropout=dropout_prob, - convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax, - is_naive_fp16=is_naive_fp16) + bert_layer = BertLayer( + layer_number=i + 1, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + attention_dropout=dropout_prob, + mlp_ratio=mlp_ratio, + hidden_dropout=dropout_prob, + convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax, + is_naive_fp16=is_naive_fp16, + ) self.bert_layers.append(bert_layer) self.layer_norm = LayerNorm(hidden_size) - self.head = BertDualHead(hidden_size, - self.embedding.word_embedding_weight.size(0), - add_binary_head=add_binary_head) + self.head = BertDualHead( + hidden_size, self.embedding.word_embedding_weight.size(0), add_binary_head=add_binary_head + ) self.reset_parameters() def _init_normal(self, tensor): @@ -122,27 +127,30 @@ def forward(self, input_ids, attention_masks, tokentype_ids, lm_labels): class PipelineBertForPretrain(nn.Module): - - def __init__(self, - vocab_size, - hidden_size, - max_sequence_length, - num_attention_heads, - num_layers, - add_binary_head, - is_naive_fp16, - num_tokentypes=2, - dropout_prob=0.1, - mlp_ratio=4, - init_std=0.02, - convert_fp16_to_fp32_in_softmax=False, - first_stage=True, - last_stage=True, - start_idx=None, - end_idx=None): + def __init__( + self, + vocab_size, + 
hidden_size, + max_sequence_length, + num_attention_heads, + num_layers, + add_binary_head, + is_naive_fp16, + num_tokentypes=2, + dropout_prob=0.1, + mlp_ratio=4, + init_std=0.02, + convert_fp16_to_fp32_in_softmax=False, + first_stage=True, + last_stage=True, + start_idx=None, + end_idx=None, + ): super().__init__() self.seq_parallel_size = gpc.get_world_size(ParallelMode.SEQUENCE) - assert max_sequence_length % self.seq_parallel_size == 0, 'sequence length is not divisible by the sequence parallel size' + assert ( + max_sequence_length % self.seq_parallel_size == 0 + ), "sequence length is not divisible by the sequence parallel size" self.sub_seq_length = max_sequence_length // self.seq_parallel_size self.init_std = init_std self.num_layers = num_layers @@ -156,11 +164,13 @@ def __init__(self, self.preprocessor = PreProcessor(self.sub_seq_length) if self.first_stage: - self.embedding = Embedding(hidden_size=hidden_size, - vocab_size=vocab_size, - max_sequence_length=max_sequence_length, - embedding_dropout_prob=dropout_prob, - num_tokentypes=num_tokentypes) + self.embedding = Embedding( + hidden_size=hidden_size, + vocab_size=vocab_size, + max_sequence_length=max_sequence_length, + embedding_dropout_prob=dropout_prob, + num_tokentypes=num_tokentypes, + ) # transformer layers self.bert_layers = nn.ModuleList() @@ -170,14 +180,16 @@ def __init__(self, end_idx = num_layers for i in range(start_idx, end_idx): - bert_layer = BertLayer(layer_number=i + 1, - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - attention_dropout=dropout_prob, - mlp_ratio=mlp_ratio, - hidden_dropout=dropout_prob, - convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax, - is_naive_fp16=is_naive_fp16) + bert_layer = BertLayer( + layer_number=i + 1, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + attention_dropout=dropout_prob, + mlp_ratio=mlp_ratio, + hidden_dropout=dropout_prob, + convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax, + is_naive_fp16=is_naive_fp16, + ) self.bert_layers.append(bert_layer) if self.last_stage: @@ -256,7 +268,7 @@ def _filter_kwargs(func, kwargs): return {k: v for k, v in kwargs.items() if k in sig.parameters} -def build_pipeline_bert(num_layers, num_chunks, device=torch.device('cuda'), **kwargs): +def build_pipeline_bert(num_layers, num_chunks, device=torch.device("cuda"), **kwargs): logger = get_dist_logger() pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) @@ -265,12 +277,12 @@ def build_pipeline_bert(num_layers, num_chunks, device=torch.device('cuda'), **k parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank] models = [] for start, end in parts: - kwargs['num_layers'] = num_layers - kwargs['start_idx'] = start - kwargs['end_idx'] = end - kwargs['first_stage'] = start == 0 - kwargs['last_stage'] = end == num_layers - logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers') + kwargs["num_layers"] = num_layers + kwargs["start_idx"] = start + kwargs["end_idx"] = end + kwargs["first_stage"] = start == 0 + kwargs["last_stage"] = end == num_layers + logger.info(f"Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers") chunk = PipelineBertForPretrain(**_filter_kwargs(PipelineBertForPretrain.__init__, kwargs)).to(device) if start == 0: wrapper.register_module(chunk.embedding.word_embeddings) diff --git a/examples/tutorial/sequence_parallel/model/layers/__init__.py 
b/examples/tutorial/sequence_parallel/model/layers/__init__.py index 3a8823caa81b..58495c516239 100644 --- a/examples/tutorial/sequence_parallel/model/layers/__init__.py +++ b/examples/tutorial/sequence_parallel/model/layers/__init__.py @@ -1,4 +1,4 @@ -from .embedding import VocabEmbedding, Embedding from .bert_layer import BertLayer +from .embedding import Embedding, VocabEmbedding from .head import BertDualHead from .preprocess import PreProcessor diff --git a/examples/tutorial/sequence_parallel/model/layers/bert_layer.py b/examples/tutorial/sequence_parallel/model/layers/bert_layer.py index 56ba511d8274..1ef16ee6ad79 100644 --- a/examples/tutorial/sequence_parallel/model/layers/bert_layer.py +++ b/examples/tutorial/sequence_parallel/model/layers/bert_layer.py @@ -20,18 +20,20 @@ class BertLayer(nn.Module): output of the same size. """ - def __init__(self, - layer_number, - hidden_size, - num_attention_heads, - attention_dropout, - mlp_ratio, - hidden_dropout, - is_naive_fp16, - apply_residual_connection_post_layernorm=False, - fp32_residual_connection=False, - bias_dropout_fusion: bool = True, - convert_fp16_to_fp32_in_softmax: bool = False): + def __init__( + self, + layer_number, + hidden_size, + num_attention_heads, + attention_dropout, + mlp_ratio, + hidden_dropout, + is_naive_fp16, + apply_residual_connection_post_layernorm=False, + fp32_residual_connection=False, + bias_dropout_fusion: bool = True, + convert_fp16_to_fp32_in_softmax: bool = False, + ): super().__init__() self.layer_number = layer_number @@ -50,7 +52,8 @@ def __init__(self, layer_number=layer_number, apply_query_key_layer_scaling=True, convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax, - fp16=is_naive_fp16) + fp16=is_naive_fp16, + ) self.hidden_dropout = hidden_dropout self.bias_dropout_fusion = bias_dropout_fusion @@ -90,8 +93,9 @@ def forward(self, hidden_states, attention_mask): # re-enable torch grad to enable fused optimization. with torch.enable_grad(): - layernorm_input = bias_dropout_add_func(attention_output, attention_bias.expand_as(residual), residual, - self.hidden_dropout) + layernorm_input = bias_dropout_add_func( + attention_output, attention_bias.expand_as(residual), residual, self.hidden_dropout + ) # Layer norm post the self attention. 
layernorm_output = self.post_attention_layernorm(layernorm_input) diff --git a/examples/tutorial/sequence_parallel/model/layers/dropout.py b/examples/tutorial/sequence_parallel/model/layers/dropout.py index 0e99105b8f7e..18eae0d63cd1 100644 --- a/examples/tutorial/sequence_parallel/model/layers/dropout.py +++ b/examples/tutorial/sequence_parallel/model/layers/dropout.py @@ -1,5 +1,6 @@ import torch + def bias_dropout_add(x, bias, residual, prob, training): # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor out = torch.nn.functional.dropout(x + bias, p=prob, training=training) @@ -10,4 +11,5 @@ def bias_dropout_add(x, bias, residual, prob, training): def get_bias_dropout_add(training): def _bias_dropout_add(x, bias, residual, prob): return bias_dropout_add(x, bias, residual, prob, training) - return _bias_dropout_add \ No newline at end of file + + return _bias_dropout_add diff --git a/examples/tutorial/sequence_parallel/model/layers/embedding.py b/examples/tutorial/sequence_parallel/model/layers/embedding.py index 0700d960d845..03183c55243f 100644 --- a/examples/tutorial/sequence_parallel/model/layers/embedding.py +++ b/examples/tutorial/sequence_parallel/model/layers/embedding.py @@ -5,7 +5,6 @@ class VocabEmbedding(torch.nn.Module): - def __init__(self, num_embeddings, embedding_dim): super(VocabEmbedding, self).__init__() # Keep the input dimensions. @@ -13,26 +12,29 @@ def __init__(self, num_embeddings, embedding_dim): self.embedding_dim = embedding_dim self.padding_idx = None self.max_norm = None - self.norm_type = 2. + self.norm_type = 2.0 self.scale_grad_by_freq = False self.sparse = False self._weight = None # Allocate weights and initialize. - self.weight = nn.Parameter(torch.empty( - self.num_embeddings, self.embedding_dim)) + self.weight = nn.Parameter(torch.empty(self.num_embeddings, self.embedding_dim)) init.xavier_uniform_(self.weight) def forward(self, hidden_state): - output = F.embedding(hidden_state, self.weight, - self.padding_idx, self.max_norm, - self.norm_type, self.scale_grad_by_freq, - self.sparse) + output = F.embedding( + hidden_state, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) return output def __repr__(self): - return f'VocabEmbedding(num_embeddings={self.num_embeddings}, ' \ - f'embedding_dim={self.embedding_dim})' + return f"VocabEmbedding(num_embeddings={self.num_embeddings}, " f"embedding_dim={self.embedding_dim})" class Embedding(nn.Module): @@ -48,12 +50,7 @@ class Embedding(nn.Module): will ignore this embedding """ - def __init__(self, - hidden_size, - vocab_size, - max_sequence_length, - embedding_dropout_prob, - num_tokentypes): + def __init__(self, hidden_size, vocab_size, max_sequence_length, embedding_dropout_prob, num_tokentypes): super(Embedding, self).__init__() self.hidden_size = hidden_size @@ -62,16 +59,14 @@ def __init__(self, self.word_embeddings = VocabEmbedding(vocab_size, self.hidden_size) # Position embedding (serial). - self.position_embeddings = torch.nn.Embedding( - max_sequence_length, self.hidden_size) + self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size) # Token type embedding. # Add this as an optional field that can be added through # method call so we can load a pretrain model without # token types and add them as needed. 
if self.num_tokentypes > 0: - self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, - self.hidden_size) + self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size) else: self.tokentype_embeddings = None diff --git a/examples/tutorial/sequence_parallel/model/layers/head.py b/examples/tutorial/sequence_parallel/model/layers/head.py index 9e25157e1b40..75afeee60ad4 100644 --- a/examples/tutorial/sequence_parallel/model/layers/head.py +++ b/examples/tutorial/sequence_parallel/model/layers/head.py @@ -3,12 +3,10 @@ import torch.nn.functional as F from loss_func.cross_entropy import vocab_cross_entropy -import colossalai from colossalai.kernel import LayerNorm from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc -from .embedding import VocabEmbedding from .linear import Linear from .pooler import Pooler @@ -26,7 +24,6 @@ def __init__( vocab_size, hidden_size, ): - super(BertLMHead, self).__init__() self.bias = torch.nn.Parameter(torch.zeros(vocab_size)) @@ -46,7 +43,6 @@ def forward(self, hidden_states, word_embeddings_weight, lm_labels): class BertBinaryHead(nn.Module): - def __init__(self, hidden_size): super().__init__() self.pooler = Pooler(hidden_size) @@ -62,7 +58,6 @@ def forward(self, hidden_states): class BertDualHead(nn.Module): - def __init__(self, hidden_size, vocab_size, add_binary_head): super().__init__() self.lm_head = BertLMHead(vocab_size, hidden_size) diff --git a/examples/tutorial/sequence_parallel/model/layers/init_method.py b/examples/tutorial/sequence_parallel/model/layers/init_method.py index 1b409dfe4054..22d12a504fab 100644 --- a/examples/tutorial/sequence_parallel/model/layers/init_method.py +++ b/examples/tutorial/sequence_parallel/model/layers/init_method.py @@ -1,6 +1,8 @@ -import torch import math +import torch + + def init_normal(tensor, sigma): """Init method based on N(0, sigma).""" torch.nn.init.normal_(tensor, mean=0.0, std=sigma) diff --git a/examples/tutorial/sequence_parallel/model/layers/linear.py b/examples/tutorial/sequence_parallel/model/layers/linear.py index 5ae7d671e2bf..5592f6e8c209 100644 --- a/examples/tutorial/sequence_parallel/model/layers/linear.py +++ b/examples/tutorial/sequence_parallel/model/layers/linear.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn -from torch.nn import Parameter import torch.nn.functional as F import torch.nn.init as init +from torch.nn import Parameter class Linear(nn.Module): @@ -24,11 +24,7 @@ class Linear(nn.Module): adding bias but instead return it. """ - def __init__(self, - input_size, - output_size, - bias=True, - skip_bias_add=False): + def __init__(self, input_size, output_size, bias=True, skip_bias_add=False): super(Linear, self).__init__() # Keep input parameters @@ -36,9 +32,12 @@ def __init__(self, self.output_size = output_size self.skip_bias_add = skip_bias_add - self.weight = Parameter(torch.empty(self.output_size, - self.input_size, - )) + self.weight = Parameter( + torch.empty( + self.output_size, + self.input_size, + ) + ) init.normal_(self.weight) if bias: self.bias = Parameter(torch.empty(self.output_size)) @@ -46,7 +45,7 @@ def __init__(self, with torch.no_grad(): self.bias.zero_() else: - self.register_parameter('bias', None) + self.register_parameter("bias", None) def forward(self, input_): # Matrix multiply. 
@@ -59,5 +58,7 @@ def forward(self, input_): return output def __repr__(self): - return f'Linear(in_features={self.input_size}, out_features={self.output_size}, ' + \ - f'bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})' + return ( + f"Linear(in_features={self.input_size}, out_features={self.output_size}, " + + f"bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})" + ) diff --git a/examples/tutorial/sequence_parallel/model/layers/mlp.py b/examples/tutorial/sequence_parallel/model/layers/mlp.py index a255de813d13..54a695fda402 100644 --- a/examples/tutorial/sequence_parallel/model/layers/mlp.py +++ b/examples/tutorial/sequence_parallel/model/layers/mlp.py @@ -1,10 +1,10 @@ -import torch import torch.nn as nn import torch.nn.functional as F -from .linear import Linear from colossalai.kernel.jit import bias_gelu_impl +from .linear import Linear + class TransformerMLP(nn.Module): """MLP. @@ -18,19 +18,13 @@ def __init__(self, hidden_size, mlp_ratio, fuse_gelu=True): super(TransformerMLP, self).__init__() # Project to 4h. - self.dense_h_to_4h = Linear( - hidden_size, - int(hidden_size*mlp_ratio), - skip_bias_add=True) + self.dense_h_to_4h = Linear(hidden_size, int(hidden_size * mlp_ratio), skip_bias_add=True) self.bias_gelu_fusion = fuse_gelu self.activation_func = F.gelu # Project back to h. - self.dense_4h_to_h = Linear( - int(hidden_size*mlp_ratio), - hidden_size, - skip_bias_add=True) + self.dense_4h_to_h = Linear(int(hidden_size * mlp_ratio), hidden_size, skip_bias_add=True) def forward(self, hidden_states): # hidden states should be in the shape of [s, b, h] @@ -39,11 +33,9 @@ def forward(self, hidden_states): intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states) if self.bias_gelu_fusion: - intermediate_parallel = \ - bias_gelu_impl(intermediate_parallel, bias_parallel) + intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel) else: - intermediate_parallel = \ - self.activation_func(intermediate_parallel + bias_parallel) + intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel) # [s, b, h] output, output_bias = self.dense_4h_to_h(intermediate_parallel) diff --git a/examples/tutorial/sequence_parallel/model/layers/pooler.py b/examples/tutorial/sequence_parallel/model/layers/pooler.py index 282ed114790b..c3397787aecf 100644 --- a/examples/tutorial/sequence_parallel/model/layers/pooler.py +++ b/examples/tutorial/sequence_parallel/model/layers/pooler.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn + from .linear import Linear diff --git a/examples/tutorial/sequence_parallel/model/layers/preprocess.py b/examples/tutorial/sequence_parallel/model/layers/preprocess.py index dd66bfe13585..55dd20e1e948 100644 --- a/examples/tutorial/sequence_parallel/model/layers/preprocess.py +++ b/examples/tutorial/sequence_parallel/model/layers/preprocess.py @@ -6,7 +6,6 @@ class PreProcessor(nn.Module): - def __init__(self, sub_seq_length): super().__init__() self.sub_seq_length = sub_seq_length @@ -15,10 +14,9 @@ def bert_position_ids(self, token_ids): # Create position ids seq_length = token_ids.size(1) local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) - position_ids = torch.arange(seq_length * local_rank, - seq_length * (local_rank + 1), - dtype=torch.long, - device=token_ids.device) + position_ids = torch.arange( + seq_length * local_rank, seq_length * (local_rank + 1), dtype=torch.long, device=token_ids.device + ) position_ids = position_ids.unsqueeze(0).expand_as(token_ids) return position_ids @@ 
-42,7 +40,7 @@ def bert_extended_attention_mask(self, attention_mask): extended_attention_mask = attention_mask_bss.unsqueeze(1) # Convert attention mask to binary: - extended_attention_mask = (extended_attention_mask < 0.5) + extended_attention_mask = extended_attention_mask < 0.5 return extended_attention_mask diff --git a/examples/tutorial/sequence_parallel/train.py b/examples/tutorial/sequence_parallel/train.py index b8b89cda5525..e9ceb8d70cb8 100644 --- a/examples/tutorial/sequence_parallel/train.py +++ b/examples/tutorial/sequence_parallel/train.py @@ -12,7 +12,6 @@ from colossalai.legacy.amp import AMP_TYPE from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc -from colossalai.legacy.engine.schedule import PipelineSchedule from colossalai.legacy.utils import is_using_pp from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import FusedAdam @@ -31,7 +30,7 @@ def process_batch_data(batch_data): def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('-s', '--synthetic', action="store_true", help="whether use synthetic data") + parser.add_argument("-s", "--synthetic", action="store_true", help="whether use synthetic data") return parser.parse_args() @@ -48,37 +47,39 @@ def pipeline_data_process_func(stage_output, micro_batch_data): def main(): # initialize - args = parse_args() - colossalai.launch_from_torch(config='./config.py', seed=1234, backend='nccl') + parse_args() + colossalai.launch_from_torch(config="./config.py", seed=1234, backend="nccl") logger = get_dist_logger() # build synthetic dataloader BATCH_SIZE_PER_GPUS = gpc.config.GLOBAL_BATCH_SIZE // gpc.get_world_size(ParallelMode.DATA) VOCAB_SIZE = 30528 - trainloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS, - vocab_size=VOCAB_SIZE, - seq_length=gpc.config.SEQ_LENGTH) - validloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS, - vocab_size=VOCAB_SIZE, - seq_length=gpc.config.SEQ_LENGTH) + trainloader = DummyDataloader( + batch_size=BATCH_SIZE_PER_GPUS, vocab_size=VOCAB_SIZE, seq_length=gpc.config.SEQ_LENGTH + ) + validloader = DummyDataloader( + batch_size=BATCH_SIZE_PER_GPUS, vocab_size=VOCAB_SIZE, seq_length=gpc.config.SEQ_LENGTH + ) logger.info("Dataloaders are built", ranks=[0]) # build model - if hasattr(gpc.config, 'fp16') and gpc.config.fp16.get('mode') == AMP_TYPE.NAIVE: + if hasattr(gpc.config, "fp16") and gpc.config.fp16.get("mode") == AMP_TYPE.NAIVE: is_naive_fp16 = True else: is_naive_fp16 = False use_pipeline = is_using_pp() - kwargs = dict(vocab_size=VOCAB_SIZE, - hidden_size=gpc.config.HIDDEN_SIZE, - max_sequence_length=gpc.config.SEQ_LENGTH, - num_attention_heads=gpc.config.NUM_ATTENTION_HEADS, - convert_fp16_to_fp32_in_softmax=True, - is_naive_fp16=is_naive_fp16, - add_binary_head=gpc.config.ADD_BINARY_HEAD) + kwargs = dict( + vocab_size=VOCAB_SIZE, + hidden_size=gpc.config.HIDDEN_SIZE, + max_sequence_length=gpc.config.SEQ_LENGTH, + num_attention_heads=gpc.config.NUM_ATTENTION_HEADS, + convert_fp16_to_fp32_in_softmax=True, + is_naive_fp16=is_naive_fp16, + add_binary_head=gpc.config.ADD_BINARY_HEAD, + ) if use_pipeline: model = build_pipeline_bert(num_layers=gpc.config.DEPTH, num_chunks=1, **kwargs) @@ -99,35 +100,39 @@ def main(): logger.info("Criterion is built", ranks=[0]) # layernorm and bias has no weight decay - weight_decay_params = {'params': []} - no_weight_decay_params = {'params': [], 'weight_decay': 0.0} + weight_decay_params = {"params": []} + no_weight_decay_params = {"params": 
[], "weight_decay": 0.0} for module_ in model.modules(): if isinstance(module_, LayerNorm): - no_weight_decay_params['params'].extend([p for p in list(module_._parameters.values()) if p is not None]) + no_weight_decay_params["params"].extend([p for p in list(module_._parameters.values()) if p is not None]) else: - weight_decay_params['params'].extend( - [p for n, p in list(module_._parameters.items()) if p is not None and n != 'bias']) - no_weight_decay_params['params'].extend( - [p for n, p in list(module_._parameters.items()) if p is not None and n == 'bias']) + weight_decay_params["params"].extend( + [p for n, p in list(module_._parameters.items()) if p is not None and n != "bias"] + ) + no_weight_decay_params["params"].extend( + [p for n, p in list(module_._parameters.items()) if p is not None and n == "bias"] + ) logger.info( f"without weight decay param: {len(no_weight_decay_params['params'])}, with weight decay param: {len(weight_decay_params['params'])}" ) # optimizer - optimizer = FusedAdam((weight_decay_params, no_weight_decay_params), - lr=gpc.config.LR, - weight_decay=gpc.config.WEIGHT_DECAY) + optimizer = FusedAdam( + (weight_decay_params, no_weight_decay_params), lr=gpc.config.LR, weight_decay=gpc.config.WEIGHT_DECAY + ) logger.info("Optimizer is built", ranks=[0]) # lr scheduler # follow Megatron-LM setting warmup_steps = int(gpc.config.DECAY_ITERS * gpc.config.WARMUP_FRACTION) - lr_scheduler = AnnealingLR(optimizer=optimizer, - max_lr=gpc.config.LR, - min_lr=gpc.config.MIN_LR, - warmup_steps=warmup_steps, - decay_steps=gpc.config.DECAY_ITERS, - decay_style='linear') + lr_scheduler = AnnealingLR( + optimizer=optimizer, + max_lr=gpc.config.LR, + min_lr=gpc.config.MIN_LR, + warmup_steps=warmup_steps, + decay_steps=gpc.config.DECAY_ITERS, + decay_style="linear", + ) logger.info(f"LR Scheduler is built with {warmup_steps} warmup steps and {gpc.config.DECAY_ITERS} decay steps") # # init @@ -135,7 +140,6 @@ def main(): # build timer timer = MultiTimer() - skip_iters = 0 # build loss tracker accumulated_train_loss = torch.zeros(1, dtype=torch.float32).cuda() @@ -150,7 +154,7 @@ def main(): logger.info("start training") for step in range(1, gpc.config.TRAIN_ITERS + 1): - timer.start('train-iterations') + timer.start("train-iterations") engine.train() if use_pipeline: engine.zero_grad() @@ -158,13 +162,14 @@ def main(): engine.step() else: tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch_for_sequence_parallel( - trainloader) + trainloader + ) engine.zero_grad() lm_loss, sop_output = engine(tokens, padding_mask, types, lm_labels) train_loss = engine.criterion(lm_loss, sop_output, loss_mask, sentence_order) engine.backward(train_loss) engine.step() - timer.stop('train-iterations', keep_in_history=True) + timer.stop("train-iterations", keep_in_history=True) if not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE): accumulated_train_loss += train_loss @@ -177,12 +182,18 @@ def main(): for j in range(gpc.config.EVAL_ITERS): with torch.no_grad(): if use_pipeline: - _, _, eval_loss = engine.execute_schedule(valid_data_iter, - forward_only=True, - return_output_label=False) + _, _, eval_loss = engine.execute_schedule( + valid_data_iter, forward_only=True, return_output_label=False + ) else: - tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch_for_sequence_parallel( - validloader) + ( + tokens, + types, + sentence_order, + loss_mask, + lm_labels, + padding_mask, + ) = 
get_batch_for_sequence_parallel(validloader) lm_loss, sop_output = engine(tokens, padding_mask, types, lm_labels) eval_loss = engine.criterion(lm_loss, sop_output, loss_mask, sentence_order) @@ -196,18 +207,22 @@ def main(): timer_string = [] for n, t in timer: timer_string.append(f"{n}: {t.get_history_mean()*1000:.5f}") - timer_string = ' | '.join(timer_string) - lr = list(engine.optimizer.param_groups)[0]['lr'] + timer_string = " | ".join(timer_string) + lr = list(engine.optimizer.param_groups)[0]["lr"] loss_scale = engine.optimizer.optim.loss_scale.item() if gpc.is_initialized(ParallelMode.PIPELINE): ranks = [gpc.get_ranks_in_group(ParallelMode.PIPELINE)[-1]] else: ranks = [0] - logger.info(f'Step {step} / {gpc.config.TRAIN_ITERS} | Train Loss: {accumulated_train_loss.item():.5g} ' + - f'| Eval Loss: {accumulated_eval_loss.item():.5g} ' + f'| Loss Scale: {loss_scale}' + - f"| Learning rate: {lr} | " + timer_string, - ranks=ranks) + logger.info( + f"Step {step} / {gpc.config.TRAIN_ITERS} | Train Loss: {accumulated_train_loss.item():.5g} " + + f"| Eval Loss: {accumulated_eval_loss.item():.5g} " + + f"| Loss Scale: {loss_scale}" + + f"| Learning rate: {lr} | " + + timer_string, + ranks=ranks, + ) for n, t in timer: t.reset() @@ -215,5 +230,5 @@ def main(): accumulated_train_loss.zero_() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/op_builder/__init__.py b/op_builder/__init__.py index 5ae7223b8c69..808559ec9c2d 100644 --- a/op_builder/__init__.py +++ b/op_builder/__init__.py @@ -7,17 +7,26 @@ from .scaled_upper_triangle_masked_softmax import ScaledUpperTrainglemaskedSoftmaxBuilder ALL_OPS = { - 'cpu_adam': CPUAdamBuilder, - 'fused_optim': FusedOptimBuilder, - 'moe': MOEBuilder, - 'multi_head_attn': MultiHeadAttnBuilder, - 'scaled_masked_softmax': ScaledMaskedSoftmaxBuilder, - 'scaled_upper_triangle_masked_softmax': ScaledUpperTrainglemaskedSoftmaxBuilder, - 'layernorm': LayerNormBuilder, + "cpu_adam": CPUAdamBuilder, + "fused_optim": FusedOptimBuilder, + "moe": MOEBuilder, + "multi_head_attn": MultiHeadAttnBuilder, + "scaled_masked_softmax": ScaledMaskedSoftmaxBuilder, + "scaled_upper_triangle_masked_softmax": ScaledUpperTrainglemaskedSoftmaxBuilder, + "layernorm": LayerNormBuilder, } __all__ = [ - 'ALL_OPS', 'CPUAdamBuilder', 'FusedOptimBuilder', 'MultiHeadAttnBuilder', 'ScaledMaskedSoftmaxBuilder', - 'ScaledUpperTrainglemaskedSoftmaxBuilder', 'MOEBuilder', 'MultiTensorSGDBuilder', 'MultiTensorAdamBuilder', - 'MultiTensorLambBuilder', 'MultiTensorScaleBuilder', 'MultiTensorL2NormBuilder' + "ALL_OPS", + "CPUAdamBuilder", + "FusedOptimBuilder", + "MultiHeadAttnBuilder", + "ScaledMaskedSoftmaxBuilder", + "ScaledUpperTrainglemaskedSoftmaxBuilder", + "MOEBuilder", + "MultiTensorSGDBuilder", + "MultiTensorAdamBuilder", + "MultiTensorLambBuilder", + "MultiTensorScaleBuilder", + "MultiTensorL2NormBuilder", ] diff --git a/op_builder/builder.py b/op_builder/builder.py index 8396235e5cfe..75823ef105c7 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -24,13 +24,14 @@ class Builder(ABC): def __init__(self, name: str, prebuilt_import_path: str): self.name = name self.prebuilt_import_path = prebuilt_import_path - self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'] + self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] # we store the op as an attribute to avoid repeated building and loading self.cached_op_module = None - assert prebuilt_import_path.startswith('colossalai._C'), \ 
- f'The prebuilt_import_path should start with colossalai._C, but got {self.prebuilt_import_path}' + assert prebuilt_import_path.startswith( + "colossalai._C" + ), f"The prebuilt_import_path should start with colossalai._C, but got {self.prebuilt_import_path}" def relative_to_abs_path(self, code_path: str) -> str: """ @@ -46,10 +47,10 @@ def relative_to_abs_path(self, code_path: str) -> str: # this symlink will be replaced with actual files if we install via pypi # thus we cannot tell the colossalai root directory by checking whether the op_builder # is a symlink, we can only tell whether it is inside or outside colossalai - if str(op_builder_module_path).endswith('colossalai/kernel/op_builder'): + if str(op_builder_module_path).endswith("colossalai/kernel/op_builder"): root_path = op_builder_module_path.parent.parent else: - root_path = op_builder_module_path.parent.joinpath('colossalai') + root_path = op_builder_module_path.parent.joinpath("colossalai") code_abs_path = root_path.joinpath(code_path) return str(code_abs_path) @@ -59,13 +60,14 @@ def get_cuda_home_include(self): return include path inside the cuda home. """ from torch.utils.cpp_extension import CUDA_HOME + if CUDA_HOME is None: raise RuntimeError("CUDA_HOME is None, please set CUDA_HOME to compile C++/CUDA kernels in ColossalAI.") cuda_include = os.path.join(CUDA_HOME, "include") return cuda_include def csrc_abs_path(self, path): - return os.path.join(self.relative_to_abs_path('kernel/cuda_native/csrc'), path) + return os.path.join(self.relative_to_abs_path("kernel/cuda_native/csrc"), path) # functions must be overrided begin @abstractmethod @@ -80,27 +82,24 @@ def include_dirs(self) -> List[str]: """ This function should return a list of include files for extensions. """ - pass @abstractmethod def cxx_flags(self) -> List[str]: """ This function should return a list of cxx compilation flags for extensions. """ - pass @abstractmethod def nvcc_flags(self) -> List[str]: """ This function should return a list of nvcc compilation flags for extensions. """ - pass # functions must be overrided over def strip_empty_entries(self, args): - ''' + """ Drop any empty strings from the list of compile and link flags - ''' + """ return [x for x in args if len(x) > 0] def import_op(self): @@ -114,8 +113,8 @@ def check_runtime_build_environment(self): Check whether the system environment is ready for extension compilation. """ try: - import torch from torch.utils.cpp_extension import CUDA_HOME + TORCH_AVAILABLE = True except ImportError: TORCH_AVAILABLE = False @@ -123,7 +122,8 @@ def check_runtime_build_environment(self): if not TORCH_AVAILABLE: raise ModuleNotFoundError( - "PyTorch is not found. You need to install PyTorch first in order to build CUDA extensions") + "PyTorch is not found. You need to install PyTorch first in order to build CUDA extensions" + ) if CUDA_HOME is None: raise RuntimeError( @@ -150,7 +150,7 @@ def load(self, verbose: Optional[bool] = None): verbose (bool, optional): show detailed info. Defaults to True. 
""" if verbose is None: - verbose = os.environ.get('CAI_KERNEL_VERBOSE', '0') == '1' + verbose = os.environ.get("CAI_KERNEL_VERBOSE", "0") == "1" # if the kernel has be compiled and cached, we directly use it if self.cached_op_module is not None: return self.cached_op_module @@ -161,7 +161,8 @@ def load(self, verbose: Optional[bool] = None): op_module = self.import_op() if verbose: print_rank_0( - f"[extension] OP {self.prebuilt_import_path} has been compiled ahead of time, skip building.") + f"[extension] OP {self.prebuilt_import_path} has been compiled ahead of time, skip building." + ) except ImportError: # check environment self.check_runtime_build_environment() @@ -172,10 +173,11 @@ def load(self, verbose: Optional[bool] = None): # construct the build directory import torch from torch.utils.cpp_extension import load - torch_version_major = torch.__version__.split('.')[0] - torch_version_minor = torch.__version__.split('.')[1] + + torch_version_major = torch.__version__.split(".")[0] + torch_version_minor = torch.__version__.split(".")[1] torch_cuda_version = torch.version.cuda - home_directory = os.path.expanduser('~') + home_directory = os.path.expanduser("~") extension_directory = f".cache/colossalai/torch_extensions/torch{torch_version_major}.{torch_version_minor}_cu{torch_cuda_version}" build_directory = os.path.join(home_directory, extension_directory) Path(build_directory).mkdir(parents=True, exist_ok=True) @@ -184,14 +186,16 @@ def load(self, verbose: Optional[bool] = None): print_rank_0(f"[extension] Compiling or loading the JIT-built {self.name} kernel during runtime now") # load the kernel - op_module = load(name=self.name, - sources=self.strip_empty_entries(self.sources_files()), - extra_include_paths=self.strip_empty_entries(self.include_dirs()), - extra_cflags=self.cxx_flags(), - extra_cuda_cflags=self.nvcc_flags(), - extra_ldflags=[], - build_directory=build_directory, - verbose=verbose) + op_module = load( + name=self.name, + sources=self.strip_empty_entries(self.sources_files()), + extra_include_paths=self.strip_empty_entries(self.include_dirs()), + extra_cflags=self.cxx_flags(), + extra_cuda_cflags=self.nvcc_flags(), + extra_ldflags=[], + build_directory=build_directory, + verbose=verbose, + ) build_duration = time.time() - start_build @@ -204,16 +208,18 @@ def load(self, verbose: Optional[bool] = None): return op_module - def builder(self) -> 'CUDAExtension': + def builder(self) -> "CUDAExtension": """ get a CUDAExtension instance used for setup.py """ from torch.utils.cpp_extension import CUDAExtension - return CUDAExtension(name=self.prebuilt_import_path, - sources=self.strip_empty_entries(self.sources_files()), - include_dirs=self.strip_empty_entries(self.include_dirs()), - extra_compile_args={ - 'cxx': self.strip_empty_entries(self.cxx_flags()), - 'nvcc': self.strip_empty_entries(self.nvcc_flags()) - }) + return CUDAExtension( + name=self.prebuilt_import_path, + sources=self.strip_empty_entries(self.sources_files()), + include_dirs=self.strip_empty_entries(self.include_dirs()), + extra_compile_args={ + "cxx": self.strip_empty_entries(self.cxx_flags()), + "nvcc": self.strip_empty_entries(self.nvcc_flags()), + }, + ) diff --git a/op_builder/cpu_adam.py b/op_builder/cpu_adam.py index 500e2cc0eddc..5a2a2e3e6a56 100644 --- a/op_builder/cpu_adam.py +++ b/op_builder/cpu_adam.py @@ -1,5 +1,3 @@ -import os - from .builder import Builder from .utils import append_nvcc_threads @@ -10,29 +8,29 @@ class CPUAdamBuilder(Builder): def __init__(self): 
super().__init__(name=CPUAdamBuilder.NAME, prebuilt_import_path=CPUAdamBuilder.PREBUILT_IMPORT_PATH) - self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'] + self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] # necessary 4 functions def sources_files(self): ret = [ - self.csrc_abs_path('cpu_adam.cpp'), + self.csrc_abs_path("cpu_adam.cpp"), ] return ret def include_dirs(self): - return [ - self.csrc_abs_path("includes"), - self.get_cuda_home_include() - ] + return [self.csrc_abs_path("includes"), self.get_cuda_home_include()] def cxx_flags(self): - extra_cxx_flags = ['-std=c++14', '-lcudart', '-lcublas', '-g', '-Wno-reorder', '-fopenmp', '-march=native'] - return ['-O3'] + self.version_dependent_macros + extra_cxx_flags + extra_cxx_flags = ["-std=c++14", "-lcudart", "-lcublas", "-g", "-Wno-reorder", "-fopenmp", "-march=native"] + return ["-O3"] + self.version_dependent_macros + extra_cxx_flags def nvcc_flags(self): extra_cuda_flags = [ - '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', - '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK' + "-std=c++14", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-DTHRUST_IGNORE_CUB_VERSION_CHECK", ] - ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags + ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags return append_nvcc_threads(ret) diff --git a/op_builder/fused_optim.py b/op_builder/fused_optim.py index 31ddfced1db2..3baa0880d801 100644 --- a/op_builder/fused_optim.py +++ b/op_builder/fused_optim.py @@ -1,5 +1,3 @@ -import os - from .builder import Builder from .utils import get_cuda_cc_flag @@ -10,25 +8,30 @@ class FusedOptimBuilder(Builder): def __init__(self): super().__init__(name=FusedOptimBuilder.NAME, prebuilt_import_path=FusedOptimBuilder.PREBUILT_IMPORT_PATH) - + def sources_files(self): ret = [ - self.csrc_abs_path(fname) for fname in [ - 'colossal_C_frontend.cpp', 'multi_tensor_sgd_kernel.cu', 'multi_tensor_scale_kernel.cu', - 'multi_tensor_adam.cu', 'multi_tensor_l2norm_kernel.cu', 'multi_tensor_lamb.cu' + self.csrc_abs_path(fname) + for fname in [ + "colossal_C_frontend.cpp", + "multi_tensor_sgd_kernel.cu", + "multi_tensor_scale_kernel.cu", + "multi_tensor_adam.cu", + "multi_tensor_l2norm_kernel.cu", + "multi_tensor_lamb.cu", ] ] return ret def include_dirs(self): - ret = [self.csrc_abs_path('kernels/include'), self.get_cuda_home_include()] + ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] return ret def cxx_flags(self): - version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'] - return ['-O3'] + version_dependent_macros + version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] + return ["-O3"] + version_dependent_macros def nvcc_flags(self): - extra_cuda_flags = ['-lineinfo'] + extra_cuda_flags = ["-lineinfo"] extra_cuda_flags.extend(get_cuda_cc_flag()) - return ['-O3', '--use_fast_math'] + extra_cuda_flags + return ["-O3", "--use_fast_math"] + extra_cuda_flags diff --git a/op_builder/layernorm.py b/op_builder/layernorm.py index 61d941741929..2684c6ddb7f7 100644 --- a/op_builder/layernorm.py +++ b/op_builder/layernorm.py @@ -1,5 +1,3 @@ -import os - from .builder import Builder from .utils import append_nvcc_threads, get_cuda_cc_flag @@ -12,18 +10,18 @@ def __init__(self): 
super().__init__(name=LayerNormBuilder.NAME, prebuilt_import_path=LayerNormBuilder.PREBUILT_IMPORT_PATH) def sources_files(self): - ret = [self.csrc_abs_path(fname) for fname in ['layer_norm_cuda.cpp', 'layer_norm_cuda_kernel.cu']] + ret = [self.csrc_abs_path(fname) for fname in ["layer_norm_cuda.cpp", "layer_norm_cuda_kernel.cu"]] return ret def include_dirs(self): - ret = [self.csrc_abs_path('kernels/include'), self.get_cuda_home_include()] + ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] return ret def cxx_flags(self): - return ['-O3'] + self.version_dependent_macros + return ["-O3"] + self.version_dependent_macros def nvcc_flags(self): - extra_cuda_flags = ['-maxrregcount=50'] + extra_cuda_flags = ["-maxrregcount=50"] extra_cuda_flags.extend(get_cuda_cc_flag()) - ret = ['-O3', '--use_fast_math'] + extra_cuda_flags + self.version_dependent_macros + ret = ["-O3", "--use_fast_math"] + extra_cuda_flags + self.version_dependent_macros return append_nvcc_threads(ret) diff --git a/op_builder/moe.py b/op_builder/moe.py index eeb7d8e3980c..6f8028b1720c 100644 --- a/op_builder/moe.py +++ b/op_builder/moe.py @@ -1,11 +1,8 @@ -import os - from .builder import Builder from .utils import append_nvcc_threads, get_cuda_cc_flag class MOEBuilder(Builder): - NAME = "moe" PREBUILT_IMPORT_PATH = "colossalai._C.moe" @@ -13,24 +10,23 @@ def __init__(self): super().__init__(name=MOEBuilder.NAME, prebuilt_import_path=MOEBuilder.PREBUILT_IMPORT_PATH) def include_dirs(self): - ret = [ - self.csrc_abs_path("kernels/include"), - self.get_cuda_home_include() - ] + ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] return ret def sources_files(self): - ret = [self.csrc_abs_path(fname) for fname in ['moe_cuda.cpp', 'moe_cuda_kernel.cu']] + ret = [self.csrc_abs_path(fname) for fname in ["moe_cuda.cpp", "moe_cuda_kernel.cu"]] return ret def cxx_flags(self): - return ['-O3'] + self.version_dependent_macros + return ["-O3"] + self.version_dependent_macros def nvcc_flags(self): extra_cuda_flags = [ - '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '--expt-relaxed-constexpr', - '--expt-extended-lambda' + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", ] extra_cuda_flags.extend(get_cuda_cc_flag()) - ret = ['-O3', '--use_fast_math'] + extra_cuda_flags + ret = ["-O3", "--use_fast_math"] + extra_cuda_flags return append_nvcc_threads(ret) diff --git a/op_builder/multi_head_attn.py b/op_builder/multi_head_attn.py index f9103fe94729..b70f041db7d6 100644 --- a/op_builder/multi_head_attn.py +++ b/op_builder/multi_head_attn.py @@ -1,18 +1,13 @@ -import os - from .builder import Builder from .utils import append_nvcc_threads, get_cuda_cc_flag class MultiHeadAttnBuilder(Builder): - NAME = "multihead_attention" PREBUILT_IMPORT_PATH = "colossalai._C.multihead_attention" def __init__(self): - super().__init__(name=MultiHeadAttnBuilder.NAME, - prebuilt_import_path=MultiHeadAttnBuilder.PREBUILT_IMPORT_PATH) - + super().__init__(name=MultiHeadAttnBuilder.NAME, prebuilt_import_path=MultiHeadAttnBuilder.PREBUILT_IMPORT_PATH) def include_dirs(self): ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] @@ -20,22 +15,31 @@ def include_dirs(self): def sources_files(self): ret = [ - self.csrc_abs_path(fname) for fname in [ - 'multihead_attention_1d.cpp', 'kernels/cublas_wrappers.cu', 'kernels/transform_kernels.cu', - 'kernels/dropout_kernels.cu', 'kernels/normalize_kernels.cu', 
'kernels/softmax_kernels.cu', - 'kernels/general_kernels.cu', 'kernels/cuda_util.cu' + self.csrc_abs_path(fname) + for fname in [ + "multihead_attention_1d.cpp", + "kernels/cublas_wrappers.cu", + "kernels/transform_kernels.cu", + "kernels/dropout_kernels.cu", + "kernels/normalize_kernels.cu", + "kernels/softmax_kernels.cu", + "kernels/general_kernels.cu", + "kernels/cuda_util.cu", ] ] return ret def cxx_flags(self): - return ['-O3'] + self.version_dependent_macros + return ["-O3"] + self.version_dependent_macros def nvcc_flags(self): extra_cuda_flags = [ - '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', - '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK' + "-std=c++14", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-DTHRUST_IGNORE_CUB_VERSION_CHECK", ] extra_cuda_flags.extend(get_cuda_cc_flag()) - ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags + ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags return append_nvcc_threads(ret) diff --git a/op_builder/scaled_masked_softmax.py b/op_builder/scaled_masked_softmax.py index 11cfda39a85c..b2f1de7792c8 100644 --- a/op_builder/scaled_masked_softmax.py +++ b/op_builder/scaled_masked_softmax.py @@ -1,5 +1,3 @@ -import os - from .builder import Builder from .utils import append_nvcc_threads @@ -9,29 +7,28 @@ class ScaledMaskedSoftmaxBuilder(Builder): PREBUILT_IMPORT_PATH = "colossalai._C.scaled_masked_softmax" def __init__(self): - super().__init__(name=ScaledMaskedSoftmaxBuilder.NAME, prebuilt_import_path=ScaledMaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH) + super().__init__( + name=ScaledMaskedSoftmaxBuilder.NAME, prebuilt_import_path=ScaledMaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH + ) # necessary 4 functions def sources_files(self): - ret = [ - self.csrc_abs_path(fname) for fname in - ['scaled_masked_softmax.cpp', 'scaled_masked_softmax_cuda.cu'] - ] + ret = [self.csrc_abs_path(fname) for fname in ["scaled_masked_softmax.cpp", "scaled_masked_softmax_cuda.cu"]] return ret def include_dirs(self): - return [ - self.csrc_abs_path("kernels/include"), - self.get_cuda_home_include() - ] + return [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] def cxx_flags(self): - return ['-O3'] + self.version_dependent_macros + return ["-O3"] + self.version_dependent_macros def nvcc_flags(self): extra_cuda_flags = [ - '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', - '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK' + "-std=c++14", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-DTHRUST_IGNORE_CUB_VERSION_CHECK", ] - ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags + ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags return append_nvcc_threads(ret) diff --git a/op_builder/scaled_upper_triangle_masked_softmax.py b/op_builder/scaled_upper_triangle_masked_softmax.py index d0d2433aa645..1445230acbc1 100644 --- a/op_builder/scaled_upper_triangle_masked_softmax.py +++ b/op_builder/scaled_upper_triangle_masked_softmax.py @@ -1,5 +1,3 @@ -import os - from .builder import Builder from .utils import append_nvcc_threads, get_cuda_cc_flag @@ -9,29 +7,31 @@ class ScaledUpperTrainglemaskedSoftmaxBuilder(Builder): PREBUILT_IMPORT_PATH = "colossalai._C.scaled_upper_triangle_masked_softmax" def __init__(self): - 
super().__init__(name=ScaledUpperTrainglemaskedSoftmaxBuilder.NAME, prebuilt_import_path=ScaledUpperTrainglemaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH) + super().__init__( + name=ScaledUpperTrainglemaskedSoftmaxBuilder.NAME, + prebuilt_import_path=ScaledUpperTrainglemaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH, + ) def include_dirs(self): - return [ - self.csrc_abs_path("kernels/include"), - self.get_cuda_home_include() - ] + return [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] def sources_files(self): ret = [ self.csrc_abs_path(fname) - for fname in ['scaled_upper_triang_masked_softmax.cpp', 'scaled_upper_triang_masked_softmax_cuda.cu'] + for fname in ["scaled_upper_triang_masked_softmax.cpp", "scaled_upper_triang_masked_softmax_cuda.cu"] ] return ret def cxx_flags(self): - return ['-O3'] + self.version_dependent_macros + return ["-O3"] + self.version_dependent_macros def nvcc_flags(self): extra_cuda_flags = [ - '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '--expt-relaxed-constexpr', - '--expt-extended-lambda' + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", ] extra_cuda_flags.extend(get_cuda_cc_flag()) - ret = ['-O3', '--use_fast_math'] + extra_cuda_flags + ret = ["-O3", "--use_fast_math"] + extra_cuda_flags return append_nvcc_threads(ret) diff --git a/op_builder/utils.py b/op_builder/utils.py index 9412c725baab..3f75f952d57b 100644 --- a/op_builder/utils.py +++ b/op_builder/utils.py @@ -11,6 +11,7 @@ def print_rank_0(message: str) -> None: """ try: import torch.distributed as dist + if not dist.is_initialized(): is_main_rank = True else: @@ -36,7 +37,8 @@ def get_cuda_version_in_pytorch() -> List[int]: torch_cuda_minor = torch.version.cuda.split(".")[1] except: raise ValueError( - "[extension] Cannot retrieve the CUDA version in the PyTorch binary given by torch.version.cuda") + "[extension] Cannot retrieve the CUDA version in the PyTorch binary given by torch.version.cuda" + ) return torch_cuda_major, torch_cuda_minor @@ -50,7 +52,7 @@ def get_cuda_bare_metal_version(cuda_dir) -> List[int]: Returns: The CUDA version required by PyTorch, in the form of tuple (major, minor). """ - nvcc_path = os.path.join(cuda_dir, 'bin/nvcc') + nvcc_path = os.path.join(cuda_dir, "bin/nvcc") if cuda_dir is None: raise ValueError( @@ -85,9 +87,9 @@ def check_system_pytorch_cuda_match(cuda_dir): if bare_metal_major != torch_cuda_major: raise Exception( - f'[extension] Failed to build PyTorch extension because the detected CUDA version ({bare_metal_major}.{bare_metal_minor}) ' - f'mismatches the version that was used to compile PyTorch ({torch_cuda_major}.{torch_cuda_minor}).' - 'Please make sure you have set the CUDA_HOME correctly and installed the correct PyTorch in https://pytorch.org/get-started/locally/ .' + f"[extension] Failed to build PyTorch extension because the detected CUDA version ({bare_metal_major}.{bare_metal_minor}) " + f"mismatches the version that was used to compile PyTorch ({torch_cuda_major}.{torch_cuda_minor})." + "Please make sure you have set the CUDA_HOME correctly and installed the correct PyTorch in https://pytorch.org/get-started/locally/ ." ) if bare_metal_minor != torch_cuda_minor: @@ -107,10 +109,11 @@ def get_pytorch_version() -> List[int]: A tuple of integers in the form of (major, minor, patch). 
""" import torch - torch_version = torch.__version__.split('+')[0] - TORCH_MAJOR = int(torch_version.split('.')[0]) - TORCH_MINOR = int(torch_version.split('.')[1]) - TORCH_PATCH = int(torch_version.split('.')[2], 16) + + torch_version = torch.__version__.split("+")[0] + TORCH_MAJOR = int(torch_version.split(".")[0]) + TORCH_MINOR = int(torch_version.split(".")[1]) + TORCH_PATCH = int(torch_version.split(".")[2], 16) return TORCH_MAJOR, TORCH_MINOR, TORCH_PATCH @@ -132,7 +135,8 @@ def check_pytorch_version(min_major_version, min_minor_version) -> bool: if torch_major < min_major_version or (torch_major == min_major_version and torch_minor < min_minor_version): raise RuntimeError( f"[extension] Colossal-AI requires Pytorch {min_major_version}.{min_minor_version} or newer.\n" - "The latest stable release can be obtained from https://pytorch.org/get-started/locally/") + "The latest stable release can be obtained from https://pytorch.org/get-started/locally/" + ) def check_cuda_availability(): @@ -143,6 +147,7 @@ def check_cuda_availability(): A boolean value. True if CUDA is available and False otherwise. """ import torch + return torch.cuda.is_available() @@ -155,29 +160,31 @@ def set_cuda_arch_list(cuda_dir): # we only need to set this when CUDA is not available for cross-compilation if not cuda_available: - warnings.warn('\n[extension] PyTorch did not find available GPUs on this system.\n' - 'If your intention is to cross-compile, this is not an error.\n' - 'By default, Colossal-AI will cross-compile for \n' - '1. Pascal (compute capabilities 6.0, 6.1, 6.2),\n' - '2. Volta (compute capability 7.0)\n' - '3. Turing (compute capability 7.5),\n' - '4. Ampere (compute capability 8.0, 8.6)if the CUDA version is >= 11.0\n' - '\nIf you wish to cross-compile for a single specific architecture,\n' - 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n') + warnings.warn( + "\n[extension] PyTorch did not find available GPUs on this system.\n" + "If your intention is to cross-compile, this is not an error.\n" + "By default, Colossal-AI will cross-compile for \n" + "1. Pascal (compute capabilities 6.0, 6.1, 6.2),\n" + "2. Volta (compute capability 7.0)\n" + "3. Turing (compute capability 7.5),\n" + "4. 
Ampere (compute capability 8.0, 8.6)if the CUDA version is >= 11.0\n" + "\nIf you wish to cross-compile for a single specific architecture,\n" + 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n' + ) if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None: bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir) - arch_list = ['6.0', '6.1', '6.2', '7.0', '7.5'] + arch_list = ["6.0", "6.1", "6.2", "7.0", "7.5"] if int(bare_metal_major) == 11: if int(bare_metal_minor) == 0: - arch_list.append('8.0') + arch_list.append("8.0") else: - arch_list.append('8.0') - arch_list.append('8.6') + arch_list.append("8.0") + arch_list.append("8.6") - arch_list_str = ';'.join(arch_list) + arch_list_str = ";".join(arch_list) os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list_str return False return True @@ -197,13 +204,13 @@ def get_cuda_cc_flag() -> List[str]: import torch cc_flag = [] - max_arch = ''.join(str(i) for i in torch.cuda.get_device_capability()) + max_arch = "".join(str(i) for i in torch.cuda.get_device_capability()) for arch in torch.cuda.get_arch_list(): - res = re.search(r'sm_(\d+)', arch) + res = re.search(r"sm_(\d+)", arch) if res: arch_cap = res[1] if int(arch_cap) >= 60 and int(arch_cap) <= int(max_arch): - cc_flag.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}']) + cc_flag.extend(["-gencode", f"arch=compute_{arch_cap},code={arch}"]) return cc_flag diff --git a/setup.py b/setup.py index 5d8f831218d9..cda1ba7ee7a6 100644 --- a/setup.py +++ b/setup.py @@ -15,8 +15,8 @@ ) try: - import torch from torch.utils.cpp_extension import CUDA_HOME, BuildExtension + TORCH_AVAILABLE = True except ImportError: TORCH_AVAILABLE = False @@ -26,14 +26,14 @@ MIN_PYTORCH_VERSION_MAJOR = 1 MIN_PYTORCH_VERSION_MINOR = 10 THIS_DIR = os.path.dirname(os.path.abspath(__file__)) -BUILD_CUDA_EXT = int(os.environ.get('CUDA_EXT', '0')) == 1 -IS_NIGHTLY = int(os.environ.get('NIGHTLY', '0')) == 1 +BUILD_CUDA_EXT = int(os.environ.get("CUDA_EXT", "0")) == 1 +IS_NIGHTLY = int(os.environ.get("NIGHTLY", "0")) == 1 # a variable to store the op builder ext_modules = [] # we do not support windows currently -if sys.platform == 'win32': +if sys.platform == "win32": raise RuntimeError("Windows is not supported yet. Please try again within the Windows Subsystem for Linux (WSL).") @@ -64,7 +64,7 @@ def fetch_requirements(path) -> List[str]: Returns: The lines in the requirements file. """ - with open(path, 'r') as fd: + with open(path, "r") as fd: return [r.strip() for r in fd.readlines()] @@ -75,7 +75,7 @@ def fetch_readme() -> str: Returns: The lines in the README file. 
""" - with open('README.md', encoding='utf-8') as f: + with open("README.md", encoding="utf-8") as f: return f.read() @@ -89,21 +89,21 @@ def get_version() -> str: setup_file_path = os.path.abspath(__file__) project_path = os.path.dirname(setup_file_path) - version_txt_path = os.path.join(project_path, 'version.txt') - version_py_path = os.path.join(project_path, 'colossalai/version.py') + version_txt_path = os.path.join(project_path, "version.txt") + version_py_path = os.path.join(project_path, "colossalai/version.py") with open(version_txt_path) as f: version = f.read().strip() # write version into version.py - with open(version_py_path, 'w') as f: + with open(version_py_path, "w") as f: f.write(f"__version__ = '{version}'\n") # look for pytorch and cuda version if BUILD_CUDA_EXT: torch_major, torch_minor, _ = get_pytorch_version() - torch_version = f'{torch_major}.{torch_minor}' - cuda_version = '.'.join(get_cuda_bare_metal_version(CUDA_HOME)) + torch_version = f"{torch_major}.{torch_minor}" + cuda_version = ".".join(get_cuda_bare_metal_version(CUDA_HOME)) else: torch_version = None cuda_version = None @@ -112,12 +112,12 @@ def get_version() -> str: if torch_version: f.write(f'torch = "{torch_version}"\n') else: - f.write('torch = None\n') + f.write("torch = None\n") if cuda_version: f.write(f'cuda = "{cuda_version}"\n') else: - f.write('cuda = None\n') + f.write("cuda = None\n") return version @@ -127,6 +127,7 @@ def get_version() -> str: set_cuda_arch_list(CUDA_HOME) from op_builder import ALL_OPS + op_names = [] # load all builders @@ -135,7 +136,7 @@ def get_version() -> str: ext_modules.append(builder_cls().builder()) # show log - op_name_list = ', '.join(op_names) + op_name_list = ", ".join(op_names) print(f"[extension] loaded builders for {op_name_list}") # always put not nightly branch as the if branch @@ -143,56 +144,62 @@ def get_version() -> str: # and it will mess up with the dependency graph insights if not IS_NIGHTLY: version = get_version() - package_name = 'colossalai' + package_name = "colossalai" else: # use date as the nightly version - version = datetime.today().strftime('%Y.%m.%d') - package_name = 'colossalai-nightly' - -setup(name=package_name, - version=version, - packages=find_packages(exclude=( - 'op_builder', - 'benchmark', - 'docker', - 'tests', - 'docs', - 'examples', - 'tests', - 'scripts', - 'requirements', - '*.egg-info', - )), - description='An integrated large-scale model training system with efficient parallelization techniques', - long_description=fetch_readme(), - long_description_content_type='text/markdown', - license='Apache Software License 2.0', - url='https://www.colossalai.org', - project_urls={ - 'Forum': 'https://github.com/hpcaitech/ColossalAI/discussions', - 'Bug Tracker': 'https://github.com/hpcaitech/ColossalAI/issues', - 'Examples': 'https://github.com/hpcaitech/ColossalAI-Examples', - 'Documentation': 'http://colossalai.readthedocs.io', - 'Github': 'https://github.com/hpcaitech/ColossalAI', - }, - ext_modules=ext_modules, - cmdclass={'build_ext': BuildExtension} if ext_modules else {}, - install_requires=fetch_requirements('requirements/requirements.txt'), - entry_points=''' + version = datetime.today().strftime("%Y.%m.%d") + package_name = "colossalai-nightly" + +setup( + name=package_name, + version=version, + packages=find_packages( + exclude=( + "op_builder", + "benchmark", + "docker", + "tests", + "docs", + "examples", + "tests", + "scripts", + "requirements", + "*.egg-info", + ) + ), + description="An integrated large-scale model 
training system with efficient parallelization techniques", + long_description=fetch_readme(), + long_description_content_type="text/markdown", + license="Apache Software License 2.0", + url="https://www.colossalai.org", + project_urls={ + "Forum": "https://github.com/hpcaitech/ColossalAI/discussions", + "Bug Tracker": "https://github.com/hpcaitech/ColossalAI/issues", + "Examples": "https://github.com/hpcaitech/ColossalAI-Examples", + "Documentation": "http://colossalai.readthedocs.io", + "Github": "https://github.com/hpcaitech/ColossalAI", + }, + ext_modules=ext_modules, + cmdclass={"build_ext": BuildExtension} if ext_modules else {}, + install_requires=fetch_requirements("requirements/requirements.txt"), + entry_points=""" [console_scripts] colossalai=colossalai.cli:cli - ''', - python_requires='>=3.6', - classifiers=[ - 'Programming Language :: Python :: 3', - 'License :: OSI Approved :: Apache Software License', - 'Environment :: GPU :: NVIDIA CUDA', - 'Topic :: Scientific/Engineering :: Artificial Intelligence', - 'Topic :: System :: Distributed Computing', - ], - package_data={ - 'colossalai': [ - '_C/*.pyi', 'kernel/cuda_native/csrc/*', 'kernel/cuda_native/csrc/kernel/*', - 'kernel/cuda_native/csrc/kernels/include/*' - ] - }) + """, + python_requires=">=3.6", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Environment :: GPU :: NVIDIA CUDA", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: System :: Distributed Computing", + ], + package_data={ + "colossalai": [ + "_C/*.pyi", + "kernel/cuda_native/csrc/*", + "kernel/cuda_native/csrc/kernel/*", + "kernel/cuda_native/csrc/kernels/include/*", + ] + }, +) diff --git a/tests/components_to_test/__init__.py b/tests/components_to_test/__init__.py index f29efefce4a4..65eaa72d6e84 100644 --- a/tests/components_to_test/__init__.py +++ b/tests/components_to_test/__init__.py @@ -11,9 +11,19 @@ ) from .utils import run_fwd, run_fwd_bwd -from . import albert # isort:skip +from . import albert # isort:skip __all__ = [ - 'bert', 'gpt2', 'hanging_param_model', 'inline_op_model', 'nested_model', 'repeated_computed_layers', 'resnet', - 'simple_net', 'run_fwd_bwd', 'albert', 'beit', 'run_fwd' + "bert", + "gpt2", + "hanging_param_model", + "inline_op_model", + "nested_model", + "repeated_computed_layers", + "resnet", + "simple_net", + "run_fwd_bwd", + "albert", + "beit", + "run_fwd", ] diff --git a/tests/components_to_test/albert.py b/tests/components_to_test/albert.py index 8924eb2fbc92..0ba4d19655cd 100644 --- a/tests/components_to_test/albert.py +++ b/tests/components_to_test/albert.py @@ -1,13 +1,11 @@ import torch -import transformers -from packaging import version from transformers import AlbertConfig, AlbertForSequenceClassification from .bert import get_bert_data_loader from .registry import non_distributed_component_funcs -@non_distributed_component_funcs.register(name='albert') +@non_distributed_component_funcs.register(name="albert") def get_training_components(): hidden_dim = 8 num_head = 4 @@ -16,20 +14,21 @@ def get_training_components(): vocab_size = 32 def bert_model_builder(checkpoint: bool = False): - config = AlbertConfig(vocab_size=vocab_size, - gradient_checkpointing=checkpoint, - hidden_size=hidden_dim, - intermediate_size=hidden_dim * 4, - num_attention_heads=num_head, - max_position_embeddings=sequence_length, - num_hidden_layers=num_layer, - hidden_dropout_prob=0., - attention_probs_dropout_prob=0.) 
- print('building AlbertForSequenceClassification model') + config = AlbertConfig( + vocab_size=vocab_size, + gradient_checkpointing=checkpoint, + hidden_size=hidden_dim, + intermediate_size=hidden_dim * 4, + num_attention_heads=num_head, + max_position_embeddings=sequence_length, + num_hidden_layers=num_layer, + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + ) + print("building AlbertForSequenceClassification model") # adapting huggingface BertForSequenceClassification for single unittest calling interface class ModelAdaptor(AlbertForSequenceClassification): - def forward(self, input_ids, labels): """ inputs: data, label @@ -44,16 +43,20 @@ def forward(self, input_ids, labels): return model is_distributed = torch.distributed.is_initialized() - trainloader = get_bert_data_loader(n_class=vocab_size, - batch_size=2, - total_samples=10000, - sequence_length=sequence_length, - is_distributed=is_distributed) - testloader = get_bert_data_loader(n_class=vocab_size, - batch_size=2, - total_samples=10000, - sequence_length=sequence_length, - is_distributed=is_distributed) + trainloader = get_bert_data_loader( + n_class=vocab_size, + batch_size=2, + total_samples=10000, + sequence_length=sequence_length, + is_distributed=is_distributed, + ) + testloader = get_bert_data_loader( + n_class=vocab_size, + batch_size=2, + total_samples=10000, + sequence_length=sequence_length, + is_distributed=is_distributed, + ) criterion = None return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/beit.py b/tests/components_to_test/beit.py index 2021ae6f6e35..d33474ea9a6b 100644 --- a/tests/components_to_test/beit.py +++ b/tests/components_to_test/beit.py @@ -14,25 +14,27 @@ class DummyDataLoader(DummyDataGenerator): batch_size = 4 def generate(self): - data = torch.randn((DummyDataLoader.batch_size, DummyDataLoader.num_channel, DummyDataLoader.img_size, - DummyDataLoader.img_size), - device=get_current_device()) - label = torch.randint(low=0, - high=DummyDataLoader.num_class, - size=(DummyDataLoader.batch_size,), - device=get_current_device()) + data = torch.randn( + ( + DummyDataLoader.batch_size, + DummyDataLoader.num_channel, + DummyDataLoader.img_size, + DummyDataLoader.img_size, + ), + device=get_current_device(), + ) + label = torch.randint( + low=0, high=DummyDataLoader.num_class, size=(DummyDataLoader.batch_size,), device=get_current_device() + ) return data, label -@non_distributed_component_funcs.register(name='beit') +@non_distributed_component_funcs.register(name="beit") def get_training_components(): - def model_builder(checkpoint=False): - model = Beit(img_size=DummyDataLoader.img_size, - num_classes=DummyDataLoader.num_class, - embed_dim=32, - depth=2, - num_heads=4) + model = Beit( + img_size=DummyDataLoader.img_size, num_classes=DummyDataLoader.num_class, embed_dim=32, depth=2, num_heads=4 + ) return model trainloader = DummyDataLoader() diff --git a/tests/components_to_test/bert.py b/tests/components_to_test/bert.py index e7d1d50806b8..f0061ad18c84 100644 --- a/tests/components_to_test/bert.py +++ b/tests/components_to_test/bert.py @@ -8,12 +8,12 @@ def get_bert_data_loader( - n_class, - batch_size, - total_samples, - sequence_length, - device=torch.device('cpu:0'), - is_distributed=False, + n_class, + batch_size, + total_samples, + sequence_length, + device=torch.device("cpu:0"), + is_distributed=False, ): train_data = torch.randint( low=0, @@ -32,7 +32,7 @@ def get_bert_data_loader( return train_loader 
-@non_distributed_component_funcs.register(name='bert') +@non_distributed_component_funcs.register(name="bert") def get_training_components(): hidden_dim = 8 num_head = 4 @@ -41,20 +41,21 @@ def get_training_components(): vocab_size = 32 def bert_model_builder(checkpoint: bool = False): - config = BertConfig(vocab_size=vocab_size, - gradient_checkpointing=checkpoint, - hidden_size=hidden_dim, - intermediate_size=hidden_dim * 4, - num_attention_heads=num_head, - max_position_embeddings=sequence_length, - num_hidden_layers=num_layer, - hidden_dropout_prob=0., - attention_probs_dropout_prob=0.) - print('building BertForSequenceClassification model') + config = BertConfig( + vocab_size=vocab_size, + gradient_checkpointing=checkpoint, + hidden_size=hidden_dim, + intermediate_size=hidden_dim * 4, + num_attention_heads=num_head, + max_position_embeddings=sequence_length, + num_hidden_layers=num_layer, + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + ) + print("building BertForSequenceClassification model") # adapting huggingface BertForSequenceClassification for single unittest calling interface class ModelAdaptor(BertForSequenceClassification): - def forward(self, input_ids, labels): """ inputs: data, label @@ -69,16 +70,20 @@ def forward(self, input_ids, labels): return model is_distributed = torch.distributed.is_initialized() - trainloader = get_bert_data_loader(n_class=vocab_size, - batch_size=2, - total_samples=10000, - sequence_length=sequence_length, - is_distributed=is_distributed) - testloader = get_bert_data_loader(n_class=vocab_size, - batch_size=2, - total_samples=10000, - sequence_length=sequence_length, - is_distributed=is_distributed) + trainloader = get_bert_data_loader( + n_class=vocab_size, + batch_size=2, + total_samples=10000, + sequence_length=sequence_length, + is_distributed=is_distributed, + ) + testloader = get_bert_data_loader( + n_class=vocab_size, + batch_size=2, + total_samples=10000, + sequence_length=sequence_length, + is_distributed=is_distributed, + ) criterion = None return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/gpt2.py b/tests/components_to_test/gpt2.py index fe25b4923fa2..7f826497d2ab 100644 --- a/tests/components_to_test/gpt2.py +++ b/tests/components_to_test/gpt2.py @@ -14,33 +14,40 @@ class DummyDataLoader(DummyDataGenerator): seq_len = 64 def generate(self): - input_ids = torch.randint(0, - DummyDataLoader.vocab_size, (DummyDataLoader.batch_size, DummyDataLoader.seq_len), - device=get_current_device()) + input_ids = torch.randint( + 0, + DummyDataLoader.vocab_size, + (DummyDataLoader.batch_size, DummyDataLoader.seq_len), + device=get_current_device(), + ) return input_ids, input_ids class GPTLMModel(nn.Module): - - def __init__(self, - hidden_size=768, - num_layers=12, - num_attention_heads=12, - max_seq_len=1024, - vocab_size=50304, - checkpoint=False): + def __init__( + self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50304, + checkpoint=False, + ): super().__init__() self.checkpoint = checkpoint self.model = GPT2LMHeadModel( - GPT2Config(n_embd=hidden_size, - n_layer=num_layers, - n_head=num_attention_heads, - n_positions=max_seq_len, - n_ctx=max_seq_len, - vocab_size=vocab_size, - resid_pdrop=0.0, - embd_pdrop=0.0, - attn_pdrop=0.0)) + GPT2Config( + n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size, + resid_pdrop=0.0, + 
embd_pdrop=0.0, + attn_pdrop=0.0, + ) + ) if checkpoint: self.model.gradient_checkpointing_enable() @@ -51,12 +58,9 @@ def forward(self, input_ids): def gpt2_micro(checkpoint=True): - return GPTLMModel(checkpoint=checkpoint, - hidden_size=32, - num_layers=2, - num_attention_heads=4, - max_seq_len=64, - vocab_size=128) + return GPTLMModel( + checkpoint=checkpoint, hidden_size=32, num_layers=2, num_attention_heads=4, max_seq_len=64, vocab_size=128 + ) def gpt2_s(checkpoint=True): @@ -68,7 +72,6 @@ def gpt2_m(checkpoint=True): class GPTLMLoss(nn.Module): - def __init__(self): super().__init__() self.loss_fn = nn.CrossEntropyLoss() @@ -80,9 +83,8 @@ def forward(self, logits, labels): return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) -@non_distributed_component_funcs.register(name='gpt2') +@non_distributed_component_funcs.register(name="gpt2") def get_training_components(): - trainloader = DummyDataLoader() testloader = DummyDataLoader() diff --git a/tests/components_to_test/hanging_param_model.py b/tests/components_to_test/hanging_param_model.py index 0e65431217c7..5531c8d081a0 100644 --- a/tests/components_to_test/hanging_param_model.py +++ b/tests/components_to_test/hanging_param_model.py @@ -28,16 +28,14 @@ def forward(self, x): class DummyDataLoader(DummyDataGenerator): - def generate(self): data = torch.rand(16, 4) label = torch.randint(low=0, high=2, size=(16,)) return data, label -@non_distributed_component_funcs.register(name='hanging_param_model') +@non_distributed_component_funcs.register(name="hanging_param_model") def get_training_components(): - def model_builder(checkpoint=False): return HangingParamModule(checkpoint) @@ -46,4 +44,5 @@ def model_builder(checkpoint=False): criterion = torch.nn.CrossEntropyLoss() from colossalai.nn.optimizer import HybridAdam + return model_builder, trainloader, testloader, HybridAdam, criterion diff --git a/tests/components_to_test/inline_op_model.py b/tests/components_to_test/inline_op_model.py index 80757f361d9e..8bfa9cf34353 100644 --- a/tests/components_to_test/inline_op_model.py +++ b/tests/components_to_test/inline_op_model.py @@ -1,6 +1,5 @@ import torch import torch.nn as nn -import torch.nn.functional as F from colossalai.legacy.nn import CheckpointModule @@ -19,7 +18,6 @@ def __init__(self, checkpoint=False) -> None: self.proj2 = nn.Linear(8, 8) def forward(self, x): - x = self.proj1(x) # inline add_ x.add_(10) @@ -31,16 +29,14 @@ def forward(self, x): class DummyDataLoader(DummyDataGenerator): - def generate(self): data = torch.rand(16, 4) label = torch.randint(low=0, high=2, size=(16,)) return data, label -@non_distributed_component_funcs.register(name='inline_op_model') +@non_distributed_component_funcs.register(name="inline_op_model") def get_training_components(): - def model_builder(checkpoint=False): return InlineOpModule(checkpoint) @@ -49,4 +45,5 @@ def model_builder(checkpoint=False): criterion = torch.nn.CrossEntropyLoss() from colossalai.nn.optimizer import HybridAdam + return model_builder, trainloader, testloader, HybridAdam, criterion diff --git a/tests/components_to_test/nested_model.py b/tests/components_to_test/nested_model.py index 3e779b0a6428..44577456dec5 100644 --- a/tests/components_to_test/nested_model.py +++ b/tests/components_to_test/nested_model.py @@ -9,7 +9,6 @@ class SubNet(nn.Module): - def __init__(self, out_features) -> None: super().__init__() self.bias = nn.Parameter(torch.zeros(out_features)) @@ -19,7 +18,6 @@ def forward(self, x, weight): class 
NestedNet(CheckpointModule): - def __init__(self, checkpoint=False) -> None: super().__init__(checkpoint) self.fc1 = nn.Linear(5, 5) @@ -35,16 +33,14 @@ def forward(self, x): class DummyDataLoader(DummyDataGenerator): - def generate(self): data = torch.rand(16, 5) label = torch.randint(low=0, high=2, size=(16,)) return data, label -@non_distributed_component_funcs.register(name='nested_model') +@non_distributed_component_funcs.register(name="nested_model") def get_training_components(): - def model_builder(checkpoint=False): return NestedNet(checkpoint) diff --git a/tests/components_to_test/registry.py b/tests/components_to_test/registry.py index edfcaaa7275b..ec561b7831ad 100644 --- a/tests/components_to_test/registry.py +++ b/tests/components_to_test/registry.py @@ -2,7 +2,6 @@ class Registry: - def __init__(self): self._registry = dict() @@ -36,4 +35,4 @@ def __next__(self): non_distributed_component_funcs = Registry() model_parallel_component_funcs = Registry() -__all__ = ['non_distributed_component_funcs', 'model_parallel_component_funcs'] +__all__ = ["non_distributed_component_funcs", "model_parallel_component_funcs"] diff --git a/tests/components_to_test/repeated_computed_layers.py b/tests/components_to_test/repeated_computed_layers.py index c1ef99aa07b4..3da64de3fb64 100644 --- a/tests/components_to_test/repeated_computed_layers.py +++ b/tests/components_to_test/repeated_computed_layers.py @@ -29,16 +29,14 @@ def forward(self, x): class DummyDataLoader(DummyDataGenerator): - def generate(self): data = torch.rand(16, 5) label = torch.randint(low=0, high=2, size=(16,)) return data, label -@non_distributed_component_funcs.register(name='repeated_computed_layers') +@non_distributed_component_funcs.register(name="repeated_computed_layers") def get_training_components(): - def model_builder(checkpoint=False): return NetWithRepeatedlyComputedLayers(checkpoint) diff --git a/tests/components_to_test/resnet.py b/tests/components_to_test/resnet.py index df01e4c4847e..a43becc16233 100644 --- a/tests/components_to_test/resnet.py +++ b/tests/components_to_test/resnet.py @@ -13,19 +13,20 @@ def get_cifar10_dataloader(train): # build dataloaders - dataset = CIFAR10(root=Path(os.environ['DATA']), - download=True, - train=train, - transform=transforms.Compose( - [transforms.ToTensor(), - transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])) + dataset = CIFAR10( + root=Path(os.environ["DATA"]), + download=True, + train=train, + transform=transforms.Compose( + [transforms.ToTensor(), transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))] + ), + ) dataloader = get_dataloader(dataset=dataset, shuffle=True, batch_size=16, drop_last=True) return dataloader -@non_distributed_component_funcs.register(name='resnet18') +@non_distributed_component_funcs.register(name="resnet18") def get_resnet_training_components(): - def model_builder(checkpoint=False): return resnet18(num_classes=10) diff --git a/tests/components_to_test/simple_net.py b/tests/components_to_test/simple_net.py index 064974a15a97..0f0ac5cff49a 100644 --- a/tests/components_to_test/simple_net.py +++ b/tests/components_to_test/simple_net.py @@ -33,16 +33,14 @@ def forward(self, x): class DummyDataLoader(DummyDataGenerator): - def generate(self): data = torch.randint(low=0, high=20, size=(16,), device=get_current_device()) label = torch.randint(low=0, high=2, size=(16,), device=get_current_device()) return data, label -@non_distributed_component_funcs.register(name='simple_net') 
+@non_distributed_component_funcs.register(name="simple_net") def get_training_components(): - def model_builder(checkpoint=False): return SimpleNet(checkpoint) @@ -51,4 +49,5 @@ def model_builder(checkpoint=False): criterion = torch.nn.CrossEntropyLoss() from colossalai.nn.optimizer import HybridAdam + return model_builder, trainloader, testloader, HybridAdam, criterion diff --git a/tests/components_to_test/utils/dummy_data_generator.py b/tests/components_to_test/utils/dummy_data_generator.py index 5ab33e86de23..7b3af46c8f35 100644 --- a/tests/components_to_test/utils/dummy_data_generator.py +++ b/tests/components_to_test/utils/dummy_data_generator.py @@ -2,7 +2,6 @@ class DummyDataGenerator(ABC): - def __init__(self, length=10): self.length = length diff --git a/tests/kit/model_zoo/__init__.py b/tests/kit/model_zoo/__init__.py index 466a2a558829..c08fd365d871 100644 --- a/tests/kit/model_zoo/__init__.py +++ b/tests/kit/model_zoo/__init__.py @@ -1,4 +1,4 @@ from . import diffusers, timm, torchaudio, torchrec, torchvision, transformers from .registry import model_zoo -__all__ = ['model_zoo'] +__all__ = ["model_zoo"] diff --git a/tests/kit/model_zoo/diffusers/diffusers.py b/tests/kit/model_zoo/diffusers/diffusers.py index 204c1d7773ca..895ee7967f6b 100644 --- a/tests/kit/model_zoo/diffusers/diffusers.py +++ b/tests/kit/model_zoo/diffusers/diffusers.py @@ -4,7 +4,7 @@ import torch import transformers -from ..registry import ModelAttribute, model_zoo +from ..registry import model_zoo BATCH_SIZE = 2 SEQ_LENGTH = 5 @@ -26,10 +26,9 @@ def data_clip_model(): attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) position_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) pixel_values = torch.zeros((BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH), dtype=torch.float32) - return dict(input_ids=input_ids, - pixel_values=pixel_values, - attention_mask=attention_mask, - position_ids=position_ids) + return dict( + input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, position_ids=position_ids + ) def data_clip_text(): @@ -43,32 +42,41 @@ def data_clip_vision(): return dict(pixel_values=pixel_values) -model_zoo.register(name='diffusers_auto_encoder_kl', - model_fn=diffusers.AutoencoderKL, - data_gen_fn=data_vae_fn, - output_transform_fn=identity_output) - -model_zoo.register(name='diffusers_vq_model', - model_fn=diffusers.VQModel, - data_gen_fn=data_vae_fn, - output_transform_fn=identity_output) - -model_zoo.register(name='diffusers_clip_model', - model_fn=partial(transformers.CLIPModel, config=transformers.CLIPConfig()), - data_gen_fn=data_clip_model, - output_transform_fn=identity_output) - -model_zoo.register(name='diffusers_clip_text_model', - model_fn=partial(transformers.CLIPTextModel, config=transformers.CLIPTextConfig()), - data_gen_fn=data_clip_text, - output_transform_fn=identity_output) - -model_zoo.register(name='diffusers_clip_vision_model', - model_fn=partial(transformers.CLIPVisionModel, config=transformers.CLIPVisionConfig()), - data_gen_fn=data_clip_vision, - output_transform_fn=clip_vision_model_output) - -model_zoo.register(name='diffusers_unet2d_model', - model_fn=diffusers.UNet2DModel, - data_gen_fn=data_unet_fn, - output_transform_fn=identity_output) +model_zoo.register( + name="diffusers_auto_encoder_kl", + model_fn=diffusers.AutoencoderKL, + data_gen_fn=data_vae_fn, + output_transform_fn=identity_output, +) + +model_zoo.register( + name="diffusers_vq_model", model_fn=diffusers.VQModel, data_gen_fn=data_vae_fn, 
output_transform_fn=identity_output +) + +model_zoo.register( + name="diffusers_clip_model", + model_fn=partial(transformers.CLIPModel, config=transformers.CLIPConfig()), + data_gen_fn=data_clip_model, + output_transform_fn=identity_output, +) + +model_zoo.register( + name="diffusers_clip_text_model", + model_fn=partial(transformers.CLIPTextModel, config=transformers.CLIPTextConfig()), + data_gen_fn=data_clip_text, + output_transform_fn=identity_output, +) + +model_zoo.register( + name="diffusers_clip_vision_model", + model_fn=partial(transformers.CLIPVisionModel, config=transformers.CLIPVisionConfig()), + data_gen_fn=data_clip_vision, + output_transform_fn=clip_vision_model_output, +) + +model_zoo.register( + name="diffusers_unet2d_model", + model_fn=diffusers.UNet2DModel, + data_gen_fn=data_unet_fn, + output_transform_fn=identity_output, +) diff --git a/tests/kit/model_zoo/registry.py b/tests/kit/model_zoo/registry.py index 1e7ef3b62736..b90972291870 100644 --- a/tests/kit/model_zoo/registry.py +++ b/tests/kit/model_zoo/registry.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from typing import Callable -__all__ = ['ModelZooRegistry', 'ModelAttribute', 'model_zoo'] +__all__ = ["ModelZooRegistry", "ModelAttribute", "model_zoo"] @dataclass @@ -14,6 +14,7 @@ class ModelAttribute: has_control_flow (bool): Whether the model contains branching in its forward method. has_stochastic_depth_prob (bool): Whether the model contains stochastic depth probability. Often seen in the torchvision models. """ + has_control_flow: bool = False has_stochastic_depth_prob: bool = False @@ -23,13 +24,15 @@ class ModelZooRegistry(dict): A registry to map model names to model and data generation functions. """ - def register(self, - name: str, - model_fn: Callable, - data_gen_fn: Callable, - output_transform_fn: Callable, - loss_fn: Callable = None, - model_attribute: ModelAttribute = None): + def register( + self, + name: str, + model_fn: Callable, + data_gen_fn: Callable, + output_transform_fn: Callable, + loss_fn: Callable = None, + model_attribute: ModelAttribute = None, + ): """ Register a model and data generation function. 
@@ -71,7 +74,7 @@ def get_sub_registry(self, keyword: str): if keyword in k: new_dict[k] = v - assert len(new_dict) > 0, f'No model found with keyword {keyword}' + assert len(new_dict) > 0, f"No model found with keyword {keyword}" return new_dict diff --git a/tests/kit/model_zoo/timm/timm.py b/tests/kit/model_zoo/timm/timm.py index b29ac12a6b53..eb6d2f6bc757 100644 --- a/tests/kit/model_zoo/timm/timm.py +++ b/tests/kit/model_zoo/timm/timm.py @@ -9,151 +9,183 @@ data_gen_fn = lambda: dict(x=torch.rand(2, 3, 224, 224)) output_transform_fn = lambda x: dict(output=x) -model_zoo.register(name='timm_resnet', - model_fn=tm.resnest.resnest50d, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_beit', - model_fn=tm.beit.beit_base_patch16_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_cait', - model_fn=tm.cait.cait_s24_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_convmixer', - model_fn=tm.convmixer.convmixer_768_32, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_efficientnetv2', - model_fn=tm.efficientnet.efficientnetv2_m, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_resmlp', - model_fn=tm.resmlp_12_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_vision_transformer', - model_fn=tm.vision_transformer.vit_base_patch16_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_deit', - model_fn=tm.deit_base_distilled_patch16_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_beitv2', - model_fn=tm.beitv2_base_patch16_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_coat', - model_fn=tm.coat.coat_lite_mini, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) +model_zoo.register( + name="timm_resnet", model_fn=tm.resnest.resnest50d, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="timm_beit", + model_fn=tm.beit.beit_base_patch16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_cait", model_fn=tm.cait.cait_s24_224, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="timm_convmixer", + model_fn=tm.convmixer.convmixer_768_32, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_efficientnetv2", + model_fn=tm.efficientnet.efficientnetv2_m, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_resmlp", model_fn=tm.resmlp_12_224, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="timm_vision_transformer", + model_fn=tm.vision_transformer.vit_base_patch16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_deit", + model_fn=tm.deit_base_distilled_patch16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_beitv2", + model_fn=tm.beitv2_base_patch16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_coat", 
model_fn=tm.coat.coat_lite_mini, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) -model_zoo.register(name='timm_deit3', - model_fn=tm.deit3_base_patch16_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) +model_zoo.register( + name="timm_deit3", + model_fn=tm.deit3_base_patch16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) -model_zoo.register(name='timm_eca_nfnet', - model_fn=tm.eca_nfnet_l0, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_efficientformer', - model_fn=tm.efficientformer_l1, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_ese_vovnet19b_dw', - model_fn=tm.ese_vovnet19b_dw, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_gmixer_12_224', - model_fn=tm.gmixer_12_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_gmlp_b16_224', - model_fn=tm.gmlp_b16_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_hardcorenas_a', - model_fn=tm.hardcorenas_a, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_hrnet_w18_small', - model_fn=tm.hrnet_w18_small, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_inception_v3', - model_fn=tm.inception_v3, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_mixer_b16_224', - model_fn=tm.mixer_b16_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_nf_ecaresnet101', - model_fn=tm.nf_ecaresnet101, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_nf_regnet_b0', - model_fn=tm.nf_regnet_b0, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_regnetv_040', - model_fn=tm.regnetv_040, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_skresnet18', - model_fn=tm.skresnet18, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_tnt_b_patch16_224', - model_fn=tm.tnt_b_patch16_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_wide_resnet50_2', - model_fn=tm.wide_resnet50_2, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_convit', - model_fn=tm.convit_base, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_dm_nfnet', - model_fn=tm.dm_nfnet_f0, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) +model_zoo.register( + name="timm_eca_nfnet", model_fn=tm.eca_nfnet_l0, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="timm_efficientformer", + model_fn=tm.efficientformer_l1, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_ese_vovnet19b_dw", + model_fn=tm.ese_vovnet19b_dw, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_gmixer_12_224", + model_fn=tm.gmixer_12_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_gmlp_b16_224", 
model_fn=tm.gmlp_b16_224, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="timm_hardcorenas_a", + model_fn=tm.hardcorenas_a, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_hrnet_w18_small", + model_fn=tm.hrnet_w18_small, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_inception_v3", model_fn=tm.inception_v3, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="timm_mixer_b16_224", + model_fn=tm.mixer_b16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_nf_ecaresnet101", + model_fn=tm.nf_ecaresnet101, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_nf_regnet_b0", model_fn=tm.nf_regnet_b0, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="timm_regnetv_040", model_fn=tm.regnetv_040, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="timm_skresnet18", model_fn=tm.skresnet18, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="timm_tnt_b_patch16_224", + model_fn=tm.tnt_b_patch16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_wide_resnet50_2", + model_fn=tm.wide_resnet50_2, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_convit", model_fn=tm.convit_base, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="timm_dm_nfnet", model_fn=tm.dm_nfnet_f0, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) # ============== # Register models with control flow # ============== -model_zoo.register(name='timm_convnext', - model_fn=tm.convnext.convnext_base, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='timm_vgg', - model_fn=tm.vgg.vgg11, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='timm_dpn', - model_fn=tm.dpn.dpn68, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='timm_densenet', - model_fn=tm.densenet.densenet121, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='timm_rexnet', - model_fn=tm.rexnet.rexnet_100, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='timm_swin_transformer', - model_fn=tm.swin_transformer.swin_base_patch4_window7_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="timm_convnext", + model_fn=tm.convnext.convnext_base, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="timm_vgg", + model_fn=tm.vgg.vgg11, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + 
model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="timm_dpn", + model_fn=tm.dpn.dpn68, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="timm_densenet", + model_fn=tm.densenet.densenet121, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="timm_rexnet", + model_fn=tm.rexnet.rexnet_100, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="timm_swin_transformer", + model_fn=tm.swin_transformer.swin_base_patch4_window7_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/torchaudio/torchaudio.py b/tests/kit/model_zoo/torchaudio/torchaudio.py index 9a244ac312c0..03f565c04553 100644 --- a/tests/kit/model_zoo/torchaudio/torchaudio.py +++ b/tests/kit/model_zoo/torchaudio/torchaudio.py @@ -23,24 +23,31 @@ def conformer_data_gen_fn(): transformer_output_transform_fn = lambda outputs: dict(frames=outputs[0], lengths=outputs[1]) -model_zoo.register(name='torchaudio_conformer', - model_fn=lambda: tm.Conformer( - input_dim=INPUT_DIM, num_heads=4, ffn_dim=128, num_layers=4, depthwise_conv_kernel_size=31), - data_gen_fn=conformer_data_gen_fn, - output_transform_fn=transformer_output_transform_fn) +model_zoo.register( + name="torchaudio_conformer", + model_fn=lambda: tm.Conformer( + input_dim=INPUT_DIM, num_heads=4, ffn_dim=128, num_layers=4, depthwise_conv_kernel_size=31 + ), + data_gen_fn=conformer_data_gen_fn, + output_transform_fn=transformer_output_transform_fn, +) single_output_transform_fn = lambda output: dict(output=output) -model_zoo.register(name='torchaudio_convtasnet', - model_fn=tm.ConvTasNet, - data_gen_fn=lambda: dict(input=torch.rand(4, 1, 8)), - output_transform_fn=single_output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="torchaudio_convtasnet", + model_fn=tm.ConvTasNet, + data_gen_fn=lambda: dict(input=torch.rand(4, 1, 8)), + output_transform_fn=single_output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) -model_zoo.register(name='torchaudio_deepspeech', - model_fn=lambda: tm.DeepSpeech(IN_FEATURES, n_hidden=128, n_class=4), - data_gen_fn=lambda: dict(x=torch.rand(4, 1, 10, IN_FEATURES)), - output_transform_fn=single_output_transform_fn) +model_zoo.register( + name="torchaudio_deepspeech", + model_fn=lambda: tm.DeepSpeech(IN_FEATURES, n_hidden=128, n_class=4), + data_gen_fn=lambda: dict(x=torch.rand(4, 1, 10, IN_FEATURES)), + output_transform_fn=single_output_transform_fn, +) def emformer_data_gen_fn(): @@ -50,21 +57,26 @@ def emformer_data_gen_fn(): model_zoo.register( - name='torchaudio_emformer', + name="torchaudio_emformer", model_fn=lambda: tm.Emformer(input_dim=IN_FEATURES, num_heads=4, ffn_dim=128, num_layers=4, segment_length=4), data_gen_fn=emformer_data_gen_fn, output_transform_fn=transformer_output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) + model_attribute=ModelAttribute(has_control_flow=True), +) -model_zoo.register(name='torchaudio_wav2letter_waveform', - model_fn=lambda: tm.Wav2Letter(input_type='waveform', num_features=40), - data_gen_fn=lambda: dict(x=torch.rand(4, 40, 400)), - 
output_transform_fn=single_output_transform_fn) +model_zoo.register( + name="torchaudio_wav2letter_waveform", + model_fn=lambda: tm.Wav2Letter(input_type="waveform", num_features=40), + data_gen_fn=lambda: dict(x=torch.rand(4, 40, 400)), + output_transform_fn=single_output_transform_fn, +) -model_zoo.register(name='torchaudio_wav2letter_mfcc', - model_fn=lambda: tm.Wav2Letter(input_type='mfcc', num_features=40), - data_gen_fn=lambda: dict(x=torch.rand(4, 40, 400)), - output_transform_fn=single_output_transform_fn) +model_zoo.register( + name="torchaudio_wav2letter_mfcc", + model_fn=lambda: tm.Wav2Letter(input_type="mfcc", num_features=40), + data_gen_fn=lambda: dict(x=torch.rand(4, 40, 400)), + output_transform_fn=single_output_transform_fn, +) def wavernn_data_gen_fn(): @@ -73,20 +85,24 @@ def wavernn_data_gen_fn(): return dict(waveform=waveform, specgram=specgram) -model_zoo.register(name='torchaudio_wavernn', - model_fn=lambda: tm.WaveRNN(upsample_scales=[2, 2, 5], - n_classes=N_CLASSES, - hop_length=HOP_LENGTH, - kernel_size=KERNEL_SIZE, - n_freq=N_FREQ, - n_res_block=2, - n_rnn=64, - n_fc=64, - n_hidden=16, - n_output=16), - data_gen_fn=wavernn_data_gen_fn, - output_transform_fn=single_output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="torchaudio_wavernn", + model_fn=lambda: tm.WaveRNN( + upsample_scales=[2, 2, 5], + n_classes=N_CLASSES, + hop_length=HOP_LENGTH, + kernel_size=KERNEL_SIZE, + n_freq=N_FREQ, + n_res_block=2, + n_rnn=64, + n_fc=64, + n_hidden=16, + n_output=16, + ), + data_gen_fn=wavernn_data_gen_fn, + output_transform_fn=single_output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) def tacotron_data_gen_fn(): @@ -97,17 +113,18 @@ def tacotron_data_gen_fn(): token_lengths = max_text_length * torch.ones((n_batch,)) mel_specgram = torch.rand(n_batch, N_MELS, max_mel_specgram_length) mel_specgram_lengths = max_mel_specgram_length * torch.ones((n_batch,)) - return dict(tokens=tokens, - token_lengths=token_lengths, - mel_specgram=mel_specgram, - mel_specgram_lengths=mel_specgram_lengths) + return dict( + tokens=tokens, token_lengths=token_lengths, mel_specgram=mel_specgram, mel_specgram_lengths=mel_specgram_lengths + ) -model_zoo.register(name='torchaudio_tacotron', - model_fn=lambda: tm.Tacotron2(n_mels=N_MELS), - data_gen_fn=tacotron_data_gen_fn, - output_transform_fn=lambda outputs: dict(summed_output=sum(x.sum() for x in outputs)), - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="torchaudio_tacotron", + model_fn=lambda: tm.Tacotron2(n_mels=N_MELS), + data_gen_fn=tacotron_data_gen_fn, + output_transform_fn=lambda outputs: dict(summed_output=sum(x.sum() for x in outputs)), + model_attribute=ModelAttribute(has_control_flow=True), +) def wav2vec_data_gen_fn(): @@ -117,14 +134,18 @@ def wav2vec_data_gen_fn(): return dict(waveforms=waveforms, lengths=lengths) -model_zoo.register(name='torchaudio_wav2vec2_base', - model_fn=partial(tm.wav2vec2_base, encoder_layer_drop=0.0), - data_gen_fn=wav2vec_data_gen_fn, - output_transform_fn=transformer_output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="torchaudio_wav2vec2_base", + model_fn=partial(tm.wav2vec2_base, encoder_layer_drop=0.0), + data_gen_fn=wav2vec_data_gen_fn, + output_transform_fn=transformer_output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) -model_zoo.register(name='torchaudio_hubert_base', - model_fn=tm.hubert_base, - 
data_gen_fn=wav2vec_data_gen_fn, - output_transform_fn=transformer_output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="torchaudio_hubert_base", + model_fn=tm.hubert_base, + data_gen_fn=wav2vec_data_gen_fn, + output_transform_fn=transformer_output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/torchrec/torchrec.py b/tests/kit/model_zoo/torchrec/torchrec.py index dda563155fca..d4baf576d54b 100644 --- a/tests/kit/model_zoo/torchrec/torchrec.py +++ b/tests/kit/model_zoo/torchrec/torchrec.py @@ -1,4 +1,3 @@ -from collections import namedtuple from functools import partial import torch @@ -7,7 +6,7 @@ from torchrec.modules.embedding_modules import EmbeddingBagCollection from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor -from ..registry import ModelAttribute, model_zoo +from ..registry import model_zoo BATCH = 2 SHAPE = 10 @@ -20,9 +19,9 @@ def gen_kt(): # KeyedJaggedTensor def gen_kjt(): - KJT = KeyedJaggedTensor.from_offsets_sync(keys=["f1", "f2"], - values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]), - offsets=torch.tensor([0, 2, 4, 6, 8])) + KJT = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2"], values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]), offsets=torch.tensor([0, 2, 4, 6, 8]) + ) return KJT @@ -68,7 +67,7 @@ def get_ebc(): # EmbeddingBagCollection eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"]) eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"]) - return EmbeddingBagCollection(tables=[eb1_config, eb2_config], device=torch.device('cpu')) + return EmbeddingBagCollection(tables=[eb1_config, eb2_config], device=torch.device("cpu")) def sparse_arch_model_fn(): @@ -91,52 +90,69 @@ def dlrm_sparsearch_model_fn(): return dlrm.SparseArch(ebc) -model_zoo.register(name='deepfm_densearch', - model_fn=partial(deepfm.DenseArch, SHAPE, SHAPE, SHAPE), - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) - -model_zoo.register(name='deepfm_interactionarch', - model_fn=partial(deepfm.FMInteractionArch, SHAPE * 3, ["f1", "f2"], SHAPE), - data_gen_fn=interaction_arch_data_gen_fn, - output_transform_fn=output_transform_fn) - -model_zoo.register(name='deepfm_overarch', - model_fn=partial(deepfm.OverArch, SHAPE), - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) - -model_zoo.register(name='deepfm_simpledeepfmnn', - model_fn=simple_deep_fmnn_model_fn, - data_gen_fn=simple_dfm_data_gen_fn, - output_transform_fn=output_transform_fn) - -model_zoo.register(name='deepfm_sparsearch', - model_fn=sparse_arch_model_fn, - data_gen_fn=sparse_arch_data_gen_fn, - output_transform_fn=output_transform_fn) - -model_zoo.register(name='dlrm', - model_fn=dlrm_model_fn, - data_gen_fn=simple_dfm_data_gen_fn, - output_transform_fn=output_transform_fn) - -model_zoo.register(name='dlrm_densearch', - model_fn=partial(dlrm.DenseArch, SHAPE, [SHAPE, SHAPE]), - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) - -model_zoo.register(name='dlrm_interactionarch', - model_fn=partial(dlrm.InteractionArch, 2), - data_gen_fn=interaction_arch_data_gen_fn, - output_transform_fn=output_transform_fn) - -model_zoo.register(name='dlrm_overarch', - model_fn=partial(dlrm.OverArch, SHAPE, [5, 1]), - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) - -model_zoo.register(name='dlrm_sparsearch', - model_fn=dlrm_sparsearch_model_fn, - 
data_gen_fn=sparse_arch_data_gen_fn, - output_transform_fn=output_transform_fn) +model_zoo.register( + name="deepfm_densearch", + model_fn=partial(deepfm.DenseArch, SHAPE, SHAPE, SHAPE), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) + +model_zoo.register( + name="deepfm_interactionarch", + model_fn=partial(deepfm.FMInteractionArch, SHAPE * 3, ["f1", "f2"], SHAPE), + data_gen_fn=interaction_arch_data_gen_fn, + output_transform_fn=output_transform_fn, +) + +model_zoo.register( + name="deepfm_overarch", + model_fn=partial(deepfm.OverArch, SHAPE), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) + +model_zoo.register( + name="deepfm_simpledeepfmnn", + model_fn=simple_deep_fmnn_model_fn, + data_gen_fn=simple_dfm_data_gen_fn, + output_transform_fn=output_transform_fn, +) + +model_zoo.register( + name="deepfm_sparsearch", + model_fn=sparse_arch_model_fn, + data_gen_fn=sparse_arch_data_gen_fn, + output_transform_fn=output_transform_fn, +) + +model_zoo.register( + name="dlrm", model_fn=dlrm_model_fn, data_gen_fn=simple_dfm_data_gen_fn, output_transform_fn=output_transform_fn +) + +model_zoo.register( + name="dlrm_densearch", + model_fn=partial(dlrm.DenseArch, SHAPE, [SHAPE, SHAPE]), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) + +model_zoo.register( + name="dlrm_interactionarch", + model_fn=partial(dlrm.InteractionArch, 2), + data_gen_fn=interaction_arch_data_gen_fn, + output_transform_fn=output_transform_fn, +) + +model_zoo.register( + name="dlrm_overarch", + model_fn=partial(dlrm.OverArch, SHAPE, [5, 1]), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) + +model_zoo.register( + name="dlrm_sparsearch", + model_fn=dlrm_sparsearch_model_fn, + data_gen_fn=sparse_arch_data_gen_fn, + output_transform_fn=output_transform_fn, +) diff --git a/tests/kit/model_zoo/torchvision/torchvision.py b/tests/kit/model_zoo/torchvision/torchvision.py index ddc3ec24b2ff..57b633e9d676 100644 --- a/tests/kit/model_zoo/torchvision/torchvision.py +++ b/tests/kit/model_zoo/torchvision/torchvision.py @@ -1,5 +1,3 @@ -from collections import namedtuple - import torch import torchvision import torchvision.models as tm @@ -29,103 +27,133 @@ def swin_s(): depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=[7, 7], - stochastic_depth_prob=0, # it is originally 0.2, but we set it to 0 to make it deterministic + stochastic_depth_prob=0, # it is originally 0.2, but we set it to 0 to make it deterministic weights=weights, progress=progress, ) # special output transform fn -google_net_output_transform_fn = lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.GoogLeNetOutputs - ) else dict(output=x) -swin_s_output_output_transform_fn = lambda x: {f'output{idx}': val - for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x) -inception_v3_output_transform_fn = lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.InceptionOutputs - ) else dict(output=x) +google_net_output_transform_fn = ( + lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.GoogLeNetOutputs) else dict(output=x) +) +swin_s_output_output_transform_fn = ( + lambda x: {f"output{idx}": val for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x) +) +inception_v3_output_transform_fn = ( + lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.InceptionOutputs) else dict(output=x) +) -model_zoo.register(name='torchvision_alexnet', - model_fn=tm.alexnet, - data_gen_fn=data_gen_fn, - 
output_transform_fn=output_transform_fn) -model_zoo.register(name='torchvision_densenet121', - model_fn=tm.densenet121, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='torchvision_efficientnet_b0', - model_fn=tm.efficientnet_b0, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_stochastic_depth_prob=True)) -model_zoo.register(name='torchvision_googlenet', - model_fn=tm.googlenet, - data_gen_fn=data_gen_fn, - output_transform_fn=google_net_output_transform_fn) -model_zoo.register(name='torchvision_inception_v3', - model_fn=tm.inception_v3, - data_gen_fn=inception_v3_data_gen_fn, - output_transform_fn=inception_v3_output_transform_fn) -model_zoo.register(name='torchvision_mobilenet_v2', - model_fn=tm.mobilenet_v2, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='torchvision_mobilenet_v3_small', - model_fn=tm.mobilenet_v3_small, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='torchvision_mnasnet0_5', - model_fn=tm.mnasnet0_5, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='torchvision_resnet18', - model_fn=tm.resnet18, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='torchvision_regnet_x_16gf', - model_fn=tm.regnet_x_16gf, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='torchvision_resnext50_32x4d', - model_fn=tm.resnext50_32x4d, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='torchvision_shufflenet_v2_x0_5', - model_fn=tm.shufflenet_v2_x0_5, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='torchvision_squeezenet1_0', - model_fn=tm.squeezenet1_0, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) +model_zoo.register( + name="torchvision_alexnet", model_fn=tm.alexnet, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="torchvision_densenet121", + model_fn=tm.densenet121, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="torchvision_efficientnet_b0", + model_fn=tm.efficientnet_b0, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_stochastic_depth_prob=True), +) +model_zoo.register( + name="torchvision_googlenet", + model_fn=tm.googlenet, + data_gen_fn=data_gen_fn, + output_transform_fn=google_net_output_transform_fn, +) +model_zoo.register( + name="torchvision_inception_v3", + model_fn=tm.inception_v3, + data_gen_fn=inception_v3_data_gen_fn, + output_transform_fn=inception_v3_output_transform_fn, +) +model_zoo.register( + name="torchvision_mobilenet_v2", + model_fn=tm.mobilenet_v2, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="torchvision_mobilenet_v3_small", + model_fn=tm.mobilenet_v3_small, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="torchvision_mnasnet0_5", + model_fn=tm.mnasnet0_5, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="torchvision_resnet18", model_fn=tm.resnet18, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="torchvision_regnet_x_16gf", + model_fn=tm.regnet_x_16gf, + 
data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="torchvision_resnext50_32x4d", + model_fn=tm.resnext50_32x4d, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="torchvision_shufflenet_v2_x0_5", + model_fn=tm.shufflenet_v2_x0_5, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="torchvision_squeezenet1_0", + model_fn=tm.squeezenet1_0, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) -model_zoo.register(name='torchvision_vgg11', - model_fn=tm.vgg11, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='torchvision_wide_resnet50_2', - model_fn=tm.wide_resnet50_2, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) +model_zoo.register( + name="torchvision_vgg11", model_fn=tm.vgg11, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="torchvision_wide_resnet50_2", + model_fn=tm.wide_resnet50_2, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) -if version.parse(torchvision.__version__) >= version.parse('0.12.0'): - model_zoo.register(name='torchvision_vit_b_16', - model_fn=tm.vit_b_16, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) - model_zoo.register(name='torchvision_convnext_base', - model_fn=tm.convnext_base, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_stochastic_depth_prob=True)) +if version.parse(torchvision.__version__) >= version.parse("0.12.0"): + model_zoo.register( + name="torchvision_vit_b_16", + model_fn=tm.vit_b_16, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + ) + model_zoo.register( + name="torchvision_convnext_base", + model_fn=tm.convnext_base, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_stochastic_depth_prob=True), + ) -if version.parse(torchvision.__version__) >= version.parse('0.13.0'): +if version.parse(torchvision.__version__) >= version.parse("0.13.0"): model_zoo.register( - name='torchvision_swin_s', + name="torchvision_swin_s", model_fn=swin_s, data_gen_fn=data_gen_fn, output_transform_fn=swin_s_output_output_transform_fn, ) - model_zoo.register(name='torchvision_efficientnet_v2_s', - model_fn=tm.efficientnet_v2_s, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_stochastic_depth_prob=True)) + model_zoo.register( + name="torchvision_efficientnet_v2_s", + model_fn=tm.efficientnet_v2_s, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_stochastic_depth_prob=True), + ) diff --git a/tests/kit/model_zoo/transformers/albert.py b/tests/kit/model_zoo/transformers/albert.py index 70f9ee11ad6e..d1c23703b3e4 100644 --- a/tests/kit/model_zoo/transformers/albert.py +++ b/tests/kit/model_zoo/transformers/albert.py @@ -19,44 +19,52 @@ def data_gen_fn(): def data_gen_for_pretrain(): inputs = data_gen_fn() - inputs['labels'] = inputs['input_ids'].clone() - inputs['sentence_order_label'] = torch.zeros(BATCH_SIZE, dtype=torch.int64) + inputs["labels"] = inputs["input_ids"].clone() + inputs["sentence_order_label"] = torch.zeros(BATCH_SIZE, dtype=torch.int64) return inputs output_transform_fn = lambda x: x -config = transformers.AlbertConfig(embedding_size=128, - hidden_size=128, - num_hidden_layers=2, - 
num_attention_heads=4, - intermediate_size=256) - -model_zoo.register(name='transformers_albert', - model_fn=lambda: transformers.AlbertModel(config, add_pooling_layer=False), - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_albert_for_pretraining', - model_fn=lambda: transformers.AlbertForPreTraining(config), - data_gen_fn=data_gen_for_pretrain, - output_transform_fn=lambda x: dict(loss=x.loss), - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_albert_for_masked_lm', - model_fn=lambda: transformers.AlbertForMaskedLM(config), - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_albert_for_sequence_classification', - model_fn=lambda: transformers.AlbertForSequenceClassification(config), - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_albert_for_token_classification', - model_fn=lambda: transformers.AlbertForTokenClassification(config), - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +config = transformers.AlbertConfig( + embedding_size=128, hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256 +) + +model_zoo.register( + name="transformers_albert", + model_fn=lambda: transformers.AlbertModel(config, add_pooling_layer=False), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_albert_for_pretraining", + model_fn=lambda: transformers.AlbertForPreTraining(config), + data_gen_fn=data_gen_for_pretrain, + output_transform_fn=lambda x: dict(loss=x.loss), + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_albert_for_masked_lm", + model_fn=lambda: transformers.AlbertForMaskedLM(config), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_albert_for_sequence_classification", + model_fn=lambda: transformers.AlbertForSequenceClassification(config), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_albert_for_token_classification", + model_fn=lambda: transformers.AlbertForTokenClassification(config), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) # =============================== # Register multi-sentence ALBERT @@ -80,13 +88,17 @@ def data_gen_for_mcq(): return encoding -model_zoo.register(name='transformers_albert_for_question_answering', - model_fn=lambda: transformers.AlbertForQuestionAnswering(config), - data_gen_fn=data_gen_for_qa, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_albert_for_multiple_choice', - model_fn=lambda: transformers.AlbertForMultipleChoice(config), - data_gen_fn=data_gen_for_mcq, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + 
name="transformers_albert_for_question_answering", + model_fn=lambda: transformers.AlbertForQuestionAnswering(config), + data_gen_fn=data_gen_for_qa, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_albert_for_multiple_choice", + model_fn=lambda: transformers.AlbertForMultipleChoice(config), + data_gen_fn=data_gen_for_mcq, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/transformers/bert.py b/tests/kit/model_zoo/transformers/bert.py index 993c90b0abc2..8b90a3c7372c 100644 --- a/tests/kit/model_zoo/transformers/bert.py +++ b/tests/kit/model_zoo/transformers/bert.py @@ -28,7 +28,7 @@ def data_gen_for_lm(): # LM data gen # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` data = data_gen() - data['labels'] = data['input_ids'].clone() + data["labels"] = data["input_ids"].clone() return data @@ -36,7 +36,7 @@ def data_gen_for_pretraining(): # pretraining data gen # `next_sentence_label` is the label for next sentence prediction, 0 or 1 data = data_gen_for_lm() - data['next_sentence_label'] = torch.tensor([1], dtype=torch.int64) + data["next_sentence_label"] = torch.tensor([1], dtype=torch.int64) return data @@ -44,7 +44,7 @@ def data_gen_for_sequence_classification(): # sequence classification data gen # `labels` is the label for sequence classification, 0 or 1 data = data_gen() - data['labels'] = torch.tensor([1], dtype=torch.int64) + data["labels"] = torch.tensor([1], dtype=torch.int64) return data @@ -52,7 +52,7 @@ def data_gen_for_token_classification(): # token classification data gen # `labels` is the type not the token id for token classification, 0 or 1 data = data_gen() - data['labels'] = torch.tensor([[1, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64) + data["labels"] = torch.tensor([[1, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64) return data @@ -67,32 +67,276 @@ def data_gen_for_mcq(): # data = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True) # data = {k: v.unsqueeze(0) for k, v in encoding.items()} # data['labels'] = torch.tensor([0], dtype=torch.int64) - input_ids = torch.tensor([[[ - 101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 4825, 1010, 2003, 3591, - 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102, 102, 5442, - 1012, 102, 102 - ], - [ - 101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, - 4825, 1010, 2003, 3591, 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2096, - 2218, 1999, 1996, 2192, 1012, 102, 0, 0, 1012, 102, 0, 0 - ]]]) - token_type_ids = torch.tensor([[[ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1 - ], - [ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0 - ]]]) - attention_mask = torch.tensor([[[ - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1 - ], - [ - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0 - ]]]) + input_ids = torch.tensor( + [ + [ + [ + 101, + 1999, + 3304, + 1010, + 10733, + 2366, + 1999, + 5337, + 10906, + 1010, + 2107, + 2004, + 2012, + 1037, + 4825, + 1010, + 2003, + 3591, + 
4895, + 14540, + 6610, + 2094, + 1012, + 102, + 2009, + 2003, + 8828, + 2007, + 1037, + 9292, + 1998, + 1037, + 5442, + 1012, + 102, + 102, + 5442, + 1012, + 102, + 102, + ], + [ + 101, + 1999, + 3304, + 1010, + 10733, + 2366, + 1999, + 5337, + 10906, + 1010, + 2107, + 2004, + 2012, + 1037, + 4825, + 1010, + 2003, + 3591, + 4895, + 14540, + 6610, + 2094, + 1012, + 102, + 2009, + 2003, + 8828, + 2096, + 2218, + 1999, + 1996, + 2192, + 1012, + 102, + 0, + 0, + 1012, + 102, + 0, + 0, + ], + ] + ] + ) + token_type_ids = torch.tensor( + [ + [ + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 0, + 0, + 1, + 1, + 0, + 0, + ], + ] + ] + ) + attention_mask = torch.tensor( + [ + [ + [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + ], + [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 0, + 0, + 1, + 1, + 0, + 0, + ], + ] + ] + ) labels = torch.tensor([0], dtype=torch.int64) return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels) @@ -103,9 +347,9 @@ def data_gen_for_qa(): # no need for labels and use start and end position instead data = data_gen() start_positions = torch.tensor([0], dtype=torch.int64) - data['start_positions'] = start_positions + data["start_positions"] = start_positions end_positions = torch.tensor([1], dtype=torch.int64) - data['end_positions'] = end_positions + data["end_positions"] = end_positions return data @@ -114,69 +358,90 @@ def data_gen_for_qa(): # define loss funciton -loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state - )) +loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss( + x.last_hidden_state, torch.ones_like(x.last_hidden_state) +) loss_fn = lambda x: x.loss -config = transformers.BertConfig(hidden_size=128, - num_hidden_layers=2, - num_attention_heads=4, - intermediate_size=256, - hidden_dropout_prob=0, - attention_probs_dropout_prob=0) +config = transformers.BertConfig( + hidden_size=128, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=256, + hidden_dropout_prob=0, + attention_probs_dropout_prob=0, +) # register the BERT variants -model_zoo.register(name='transformers_bert', - model_fn=lambda: transformers.BertModel(config, add_pooling_layer=False), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_bert_model, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_bert_for_pretraining', - model_fn=lambda: transformers.BertForPreTraining(config), - data_gen_fn=data_gen_for_pretraining, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_bert_lm_head_model', - model_fn=lambda: transformers.BertLMHeadModel(config), - data_gen_fn=data_gen_for_lm, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - 
model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_bert_for_masked_lm', - model_fn=lambda: transformers.BertForMaskedLM(config), - data_gen_fn=data_gen_for_lm, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_bert_for_sequence_classification', - model_fn=lambda: transformers.BertForSequenceClassification(config), - data_gen_fn=data_gen_for_sequence_classification, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_bert_for_token_classification', - model_fn=lambda: transformers.BertForTokenClassification(config), - data_gen_fn=data_gen_for_token_classification, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_bert_for_next_sentence', - model_fn=lambda: transformers.BertForNextSentencePrediction(config), - data_gen_fn=data_gen_for_sequence_classification, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_bert_for_mcq', - model_fn=lambda: transformers.BertForMultipleChoice(config), - data_gen_fn=data_gen_for_mcq, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_bert_for_question_answering', - model_fn=lambda: transformers.BertForQuestionAnswering(config), - data_gen_fn=data_gen_for_qa, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="transformers_bert", + model_fn=lambda: transformers.BertModel(config, add_pooling_layer=False), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_bert_model, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bert_for_pretraining", + model_fn=lambda: transformers.BertForPreTraining(config), + data_gen_fn=data_gen_for_pretraining, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bert_lm_head_model", + model_fn=lambda: transformers.BertLMHeadModel(config), + data_gen_fn=data_gen_for_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bert_for_masked_lm", + model_fn=lambda: transformers.BertForMaskedLM(config), + data_gen_fn=data_gen_for_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bert_for_sequence_classification", + model_fn=lambda: transformers.BertForSequenceClassification(config), + data_gen_fn=data_gen_for_sequence_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bert_for_token_classification", + model_fn=lambda: transformers.BertForTokenClassification(config), + data_gen_fn=data_gen_for_token_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), 
+) +model_zoo.register( + name="transformers_bert_for_next_sentence", + model_fn=lambda: transformers.BertForNextSentencePrediction(config), + data_gen_fn=data_gen_for_sequence_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bert_for_mcq", + model_fn=lambda: transformers.BertForMultipleChoice(config), + data_gen_fn=data_gen_for_mcq, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bert_for_question_answering", + model_fn=lambda: transformers.BertForQuestionAnswering(config), + data_gen_fn=data_gen_for_qa, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/transformers/blip2.py b/tests/kit/model_zoo/transformers/blip2.py index 984a6ffa920d..887b11c7f54e 100644 --- a/tests/kit/model_zoo/transformers/blip2.py +++ b/tests/kit/model_zoo/transformers/blip2.py @@ -47,16 +47,20 @@ def data_gen(): config.text_config.dropout = 0 # register the blip2 variants -model_zoo.register(name='transformers_blip2', - model_fn=lambda: transformers.Blip2Model(config), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_blip2_model, - model_attribute=ModelAttribute(has_control_flow=True)) - -model_zoo.register(name='transformers_blip2_conditional_gerneration', - model_fn=lambda: transformers.Blip2ForConditionalGeneration(config), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_blip2_model, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="transformers_blip2", + model_fn=lambda: transformers.Blip2Model(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_blip2_model, + model_attribute=ModelAttribute(has_control_flow=True), +) + +model_zoo.register( + name="transformers_blip2_conditional_gerneration", + model_fn=lambda: transformers.Blip2ForConditionalGeneration(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_blip2_model, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/transformers/bloom.py b/tests/kit/model_zoo/transformers/bloom.py index 2d9c882089cb..12dcd71d5d1b 100644 --- a/tests/kit/model_zoo/transformers/bloom.py +++ b/tests/kit/model_zoo/transformers/bloom.py @@ -25,7 +25,7 @@ def data_gen_for_lm(): # LM data gen # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` data = data_gen() - data['labels'] = data['input_ids'].clone() + data["labels"] = data["input_ids"].clone() return data @@ -33,14 +33,14 @@ def data_gen_for_token_classification(): # token classification data gen # `labels` is the type not the token id for token classification, 0 or 1 data = data_gen() - data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64) + data["labels"] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64) return data def data_gen_for_sequence_classification(): # sequence classification data gen data = data_gen() - data['labels'] = torch.tensor([0], dtype=torch.int64) + data["labels"] = torch.tensor([0], dtype=torch.int64) return data @@ -54,62 +54,69 @@ def data_gen_for_question_answering(): input_ids = torch.tensor( [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 
1620, 267, 35378, 48946, 18161, 48946, 18161]], - dtype=torch.int64) + dtype=torch.int64, + ) attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) start_positions = torch.tensor([1], dtype=torch.int64) end_positions = torch.tensor([10], dtype=torch.int64) - return dict(input_ids=input_ids, - attention_mask=attention_mask, - start_positions=start_positions, - end_positions=end_positions) + return dict( + input_ids=input_ids, attention_mask=attention_mask, start_positions=start_positions, end_positions=end_positions + ) # define output transform function output_transform_fn = lambda x: x # define loss function -loss_fn_for_bloom_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, - torch.ones_like(x.last_hidden_state)) +loss_fn_for_bloom_model = lambda x: torch.nn.functional.mse_loss( + x.last_hidden_state, torch.ones_like(x.last_hidden_state) +) loss_fn_for_causal_lm = lambda x: x.loss loss_fn_for_classification = lambda x: x.loss loss_fn_for_question_answering = lambda x: x.loss -config = transformers.BloomConfig(n_layer=2, - n_head=4, - vocab_size=250880, - hidden_dropout=0, - attention_dropout=0, - hidden_size=64, - pad_token_id=50256) +config = transformers.BloomConfig( + n_layer=2, n_head=4, vocab_size=250880, hidden_dropout=0, attention_dropout=0, hidden_size=64, pad_token_id=50256 +) # register the following models -model_zoo.register(name='transformers_bloom', - model_fn=lambda: transformers.BloomModel(config), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_bloom_model, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_bloom_for_causal_lm', - model_fn=lambda: transformers.BloomForCausalLM(config), - data_gen_fn=data_gen_for_lm, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_causal_lm, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_bloom_for_sequence_classification', - model_fn=lambda: transformers.BloomForSequenceClassification(config), - data_gen_fn=data_gen_for_sequence_classification, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_classification, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_bloom_for_token_classification', - model_fn=lambda: transformers.BloomForTokenClassification(config), - data_gen_fn=data_gen_for_token_classification, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_classification, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_bloom_for_question_answering', - model_fn=lambda: transformers.BloomForQuestionAnswering(config), - data_gen_fn=data_gen_for_question_answering, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_question_answering, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="transformers_bloom", + model_fn=lambda: transformers.BloomModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_bloom_model, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bloom_for_causal_lm", + model_fn=lambda: transformers.BloomForCausalLM(config), + data_gen_fn=data_gen_for_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_causal_lm, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + 
name="transformers_bloom_for_sequence_classification", + model_fn=lambda: transformers.BloomForSequenceClassification(config), + data_gen_fn=data_gen_for_sequence_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_classification, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bloom_for_token_classification", + model_fn=lambda: transformers.BloomForTokenClassification(config), + data_gen_fn=data_gen_for_token_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_classification, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bloom_for_question_answering", + model_fn=lambda: transformers.BloomForQuestionAnswering(config), + data_gen_fn=data_gen_for_question_answering, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_question_answering, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/transformers/chatglm2.py b/tests/kit/model_zoo/transformers/chatglm2.py index d543df00bdfa..22885bec224a 100644 --- a/tests/kit/model_zoo/transformers/chatglm2.py +++ b/tests/kit/model_zoo/transformers/chatglm2.py @@ -1,5 +1,4 @@ import torch -import transformers from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel @@ -21,8 +20,8 @@ def data_gen_for_conditional_generation(): # token classification data gen # `labels` is the type not the token id for token classification, 0 or 1 data = data_gen() - labels = data['input_ids'].clone() - data['labels'] = labels + labels = data["input_ids"].clone() + data["labels"] = labels return data @@ -30,29 +29,36 @@ def data_gen_for_conditional_generation(): output_transform_fn = lambda x: x # define loss function -loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, - torch.ones_like(x.last_hidden_state)) +loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss( + x.last_hidden_state, torch.ones_like(x.last_hidden_state) +) loss_fn = lambda x: x.loss -config = ChatGLMConfig(num_layers=2, - padded_vocab_size=65024, - hidden_size=64, - num_attention_heads=8, - rmsnorm=True, - original_rope=True, - use_cache=True, - torch_dtype=torch.float32) - -model_zoo.register(name='transformers_chatglm', - model_fn=lambda: ChatGLMModel(config, empty_init=False), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_chatglm_model, - model_attribute=ModelAttribute(has_control_flow=True)) - -model_zoo.register(name="transformers_chatglm_for_conditional_generation", - model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False), - data_gen_fn=data_gen_for_conditional_generation, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +config = ChatGLMConfig( + num_layers=2, + padded_vocab_size=65024, + hidden_size=64, + num_attention_heads=8, + rmsnorm=True, + original_rope=True, + use_cache=True, + torch_dtype=torch.float32, +) + +model_zoo.register( + name="transformers_chatglm", + model_fn=lambda: ChatGLMModel(config, empty_init=False), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_chatglm_model, + model_attribute=ModelAttribute(has_control_flow=True), +) + +model_zoo.register( + 
name="transformers_chatglm_for_conditional_generation", + model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False), + data_gen_fn=data_gen_for_conditional_generation, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index 0198e04689ea..2af6176fbe4a 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -27,7 +27,7 @@ def data_gen_for_lm(): # LM data gen # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` data = data_gen() - data['labels'] = data['input_ids'].clone() + data["labels"] = data["input_ids"].clone() return data @@ -36,9 +36,9 @@ def data_gen_for_question_answering(): # `labels` is the type not the token id for token classification, 0 or 1 data = data_gen() start_positions = torch.tensor([0], dtype=torch.int64) - data['start_positions'] = start_positions + data["start_positions"] = start_positions end_positions = torch.tensor([1], dtype=torch.int64) - data['end_positions'] = end_positions + data["end_positions"] = end_positions return data @@ -46,14 +46,14 @@ def data_gen_for_token_classification(): # token classification data gen # `labels` is the type not the token id for token classification, 0 or 1 data = data_gen() - data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 1]], dtype=torch.int64) + data["labels"] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 1]], dtype=torch.int64) return data def data_gen_for_sequence_classification(): # sequence classification data gen data = data_gen() - data['labels'] = torch.tensor([1], dtype=torch.int64) + data["labels"] = torch.tensor([1], dtype=torch.int64) return data @@ -62,7 +62,8 @@ def date_gen_for_double_heads(): batch_size = 2 input_ids = torch.tensor( [[15496, 11, 616, 3290, 318, 13779, 318, 13779], [15496, 11, 616, 3290, 318, 13779, 318, 13779]], - dtype=torch.int64) + dtype=torch.int64, + ) attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) mc_labels = torch.zeros(input_ids.shape[0], dtype=torch.int64) @@ -85,58 +86,73 @@ def date_gen_for_double_heads(): output_transform_fn = lambda x: x # define loss function -loss_fn_for_gpt2_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state - )) +loss_fn_for_gpt2_model = lambda x: torch.nn.functional.mse_loss( + x.last_hidden_state, torch.ones_like(x.last_hidden_state) +) loss_fn = lambda x: x.loss -config = transformers.GPT2Config(n_layer=2, - n_head=4, - vocab_size=50258, - attn_pdrop=0, - embd_pdrop=0, - resid_pdrop=0, - summary_first_dropout=0, - hidden_dropout=0, - problem_type="single_label_classification", - pad_token_id=50256) +config = transformers.GPT2Config( + n_layer=2, + n_head=4, + vocab_size=50258, + attn_pdrop=0, + embd_pdrop=0, + resid_pdrop=0, + summary_first_dropout=0, + hidden_dropout=0, + problem_type="single_label_classification", + pad_token_id=50256, +) config_for_token_classification = copy.deepcopy(config) config_for_token_classification.num_labels = 2 # register the following models -model_zoo.register(name='transformers_gpt', - model_fn=lambda: transformers.GPT2Model(config), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_gpt2_model, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_gpt_lm', - model_fn=lambda: 
transformers.GPT2LMHeadModel(config), - data_gen_fn=data_gen_for_lm, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_gpt_double_heads', - model_fn=lambda: transformers.GPT2DoubleHeadsModel(config), - data_gen_fn=date_gen_for_double_heads, - output_transform_fn=output_transform_fn, - loss_fn=lambda x: x.loss + x.mc_loss, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_gpt_for_question_answering', - model_fn=lambda: transformers.GPT2ForQuestionAnswering(config), - data_gen_fn=data_gen_for_question_answering, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_gpt_for_token_classification', - model_fn=lambda: transformers.GPT2ForTokenClassification(config_for_token_classification), - data_gen_fn=data_gen_for_token_classification, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_gpt_for_sequence_classification', - model_fn=lambda: transformers.GPT2ForSequenceClassification(config_for_token_classification), - data_gen_fn=data_gen_for_sequence_classification, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="transformers_gpt", + model_fn=lambda: transformers.GPT2Model(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_gpt2_model, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_gpt_lm", + model_fn=lambda: transformers.GPT2LMHeadModel(config), + data_gen_fn=data_gen_for_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_gpt_double_heads", + model_fn=lambda: transformers.GPT2DoubleHeadsModel(config), + data_gen_fn=date_gen_for_double_heads, + output_transform_fn=output_transform_fn, + loss_fn=lambda x: x.loss + x.mc_loss, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_gpt_for_question_answering", + model_fn=lambda: transformers.GPT2ForQuestionAnswering(config), + data_gen_fn=data_gen_for_question_answering, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_gpt_for_token_classification", + model_fn=lambda: transformers.GPT2ForTokenClassification(config_for_token_classification), + data_gen_fn=data_gen_for_token_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_gpt_for_sequence_classification", + model_fn=lambda: transformers.GPT2ForSequenceClassification(config_for_token_classification), + data_gen_fn=data_gen_for_sequence_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index 2018f3b4f440..bc229b17e08c 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -4,7 +4,8 @@ from ..registry import 
ModelAttribute, model_zoo try: - from transformers import LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel + from transformers import LlamaConfig + HAS_LLAMA = True except ImportError: HAS_LLAMA = False @@ -33,8 +34,8 @@ def data_gen(): # label is needed for casual lm def data_gen_for_casual_lm(): data = data_gen() - labels = data['input_ids'].clone() - data['labels'] = labels + labels = data["input_ids"].clone() + data["labels"] = labels return data # transform the output to a dict @@ -45,12 +46,14 @@ def data_gen_for_casual_lm(): loss_fn_for_casual_lm = lambda output: output.loss loss_fn_for_seq_classification = lambda output: output.logits.mean() - config = LlamaConfig(num_hidden_layers=4, - hidden_size=128, - intermediate_size=256, - num_attention_heads=4, - max_position_embeddings=128, - num_labels=16) + config = LlamaConfig( + num_hidden_layers=4, + hidden_size=128, + intermediate_size=256, + num_attention_heads=4, + max_position_embeddings=128, + num_labels=16, + ) if hasattr(config, "pad_token_id"): config.pad_token_id = config.eos_token_id @@ -59,21 +62,27 @@ def data_gen_for_casual_lm(): # transformers.LlamaModel, # transformers.LlamaForCausalLM, # transformers.LlamaForSequenceClassification, - model_zoo.register(name='transformers_llama', - model_fn=lambda: transformers.LlamaModel(config), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True)) - model_zoo.register(name='transformers_llama_for_casual_lm', - model_fn=lambda: transformers.LlamaForCausalLM(config), - data_gen_fn=data_gen_for_casual_lm, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_casual_lm, - model_attribute=ModelAttribute(has_control_flow=True)) - model_zoo.register(name='transformers_llama_for_sequence_classification', - model_fn=lambda: transformers.LlamaForSequenceClassification(config), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_seq_classification, - model_attribute=ModelAttribute(has_control_flow=True)) + model_zoo.register( + name="transformers_llama", + model_fn=lambda: transformers.LlamaModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), + ) + model_zoo.register( + name="transformers_llama_for_casual_lm", + model_fn=lambda: transformers.LlamaForCausalLM(config), + data_gen_fn=data_gen_for_casual_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_casual_lm, + model_attribute=ModelAttribute(has_control_flow=True), + ) + model_zoo.register( + name="transformers_llama_for_sequence_classification", + model_fn=lambda: transformers.LlamaForSequenceClassification(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_seq_classification, + model_attribute=ModelAttribute(has_control_flow=True), + ) diff --git a/tests/kit/model_zoo/transformers/opt.py b/tests/kit/model_zoo/transformers/opt.py index a258e12ac127..07ca41ef21ae 100644 --- a/tests/kit/model_zoo/transformers/opt.py +++ b/tests/kit/model_zoo/transformers/opt.py @@ -20,8 +20,8 @@ def data_gen_for_causal_lm(): # LM data gen # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` data = data_gen() - labels = data['input_ids'].clone() - data['labels'] = labels + labels = data["input_ids"].clone() + data["labels"] = labels return data @@ -29,8 +29,8 @@ def 
data_gen_for_sequence_classification(): # LM data gen # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` data = data_gen() - labels = data['input_ids'].clone() - data['labels'] = torch.tensor([1]) + data["input_ids"].clone() + data["labels"] = torch.tensor([1]) return data @@ -38,14 +38,15 @@ def data_gen_for_question_answering(): # LM data gen # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` data = data_gen() - data['start_positions'] = torch.tensor([0]) - data['end_positions'] = torch.tensor([1]) + data["start_positions"] = torch.tensor([0]) + data["end_positions"] = torch.tensor([1]) return data output_transform_fn = lambda x: x -loss_fn_for_opt_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state) - ) +loss_fn_for_opt_model = lambda x: torch.nn.functional.mse_loss( + x.last_hidden_state, torch.ones_like(x.last_hidden_state) +) loss_fn_for_lm = lambda x: x.loss config = transformers.OPTConfig( hidden_size=128, @@ -57,24 +58,30 @@ def data_gen_for_question_answering(): # register the following models # transformers.OPTModel, # transformers.OPTForCausalLM, -model_zoo.register(name='transformers_opt', - model_fn=lambda: transformers.OPTModel(config), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_opt_model, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_opt_for_causal_lm', - model_fn=lambda: transformers.OPTForCausalLM(config), - data_gen_fn=data_gen_for_causal_lm, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_lm, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_opt_for_question_answering', - model_fn=lambda: transformers.OPTForQuestionAnswering(config), - data_gen_fn=data_gen_for_question_answering, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_lm, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="transformers_opt", + model_fn=lambda: transformers.OPTModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_opt_model, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_opt_for_causal_lm", + model_fn=lambda: transformers.OPTForCausalLM(config), + data_gen_fn=data_gen_for_causal_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_lm, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_opt_for_question_answering", + model_fn=lambda: transformers.OPTForQuestionAnswering(config), + data_gen_fn=data_gen_for_question_answering, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_lm, + model_attribute=ModelAttribute(has_control_flow=True), +) # TODO The loss and gradient check in the test are failing, to be fixed. 
# model_zoo.register(name='transformers_opt_for_sequence_classification', diff --git a/tests/kit/model_zoo/transformers/sam.py b/tests/kit/model_zoo/transformers/sam.py index d850623f368f..b928a8f14e75 100644 --- a/tests/kit/model_zoo/transformers/sam.py +++ b/tests/kit/model_zoo/transformers/sam.py @@ -28,10 +28,12 @@ def data_gen(): original_sizes = torch.tensor([[1764, 2646]], dtype=torch.int64) reshaped_input_sizes = torch.tensor([[683, 1024]], dtype=torch.int64) input_points = torch.tensor([[[[174.1497, 232.3129]]]], dtype=torch.float64) - return dict(pixel_values=pixel_values, - original_sizes=original_sizes, - reshaped_input_sizes=reshaped_input_sizes, - input_points=input_points) + return dict( + pixel_values=pixel_values, + original_sizes=original_sizes, + reshaped_input_sizes=reshaped_input_sizes, + input_points=input_points, + ) # define output transform function @@ -44,9 +46,11 @@ def data_gen(): config.vision_config.num_hidden_layers = 2 # register the BERT variants -model_zoo.register(name='transformers_sam', - model_fn=lambda: transformers.SamModel(config), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="transformers_sam", + model_fn=lambda: transformers.SamModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/transformers/t5.py b/tests/kit/model_zoo/transformers/t5.py index 16a594f3950a..1b63cccc42ee 100644 --- a/tests/kit/model_zoo/transformers/t5.py +++ b/tests/kit/model_zoo/transformers/t5.py @@ -27,7 +27,7 @@ def data_gen_for_conditional_generation(): # labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids data = data_gen_for_encoder_only() labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1, 644, 4598, 229, 19250, 5, 1, 229, 19250, 5, 1]]).long() - data['labels'] = labels + data["labels"] = labels return data @@ -36,7 +36,7 @@ def data_gen_for_t5_model(): # decoder_input_ids = model._shift_right(input_ids) data = data_gen_for_encoder_only() decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 5, 19, 1627, 5, 5]]).long() - data['decoder_input_ids'] = decoder_input_ids + data["decoder_input_ids"] = decoder_input_ids return data @@ -55,21 +55,27 @@ def data_gen_for_t5_model(): # transformers.T5Model, # transformers.T5ForConditionalGeneration, # transformers.T5EncoderModel, -model_zoo.register(name='transformers_t5', - model_fn=lambda: transformers.T5Model(config), - data_gen_fn=data_gen_for_t5_model, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_t5_model, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_t5_for_conditional_generation', - model_fn=lambda: transformers.T5ForConditionalGeneration(config), - data_gen_fn=data_gen_for_conditional_generation, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_conditional_generation, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_t5_encoder_model', - model_fn=lambda: transformers.T5EncoderModel(config), - data_gen_fn=data_gen_for_encoder_only, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_encoder_only, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="transformers_t5", + model_fn=lambda: transformers.T5Model(config), + 
data_gen_fn=data_gen_for_t5_model, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_t5_model, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_t5_for_conditional_generation", + model_fn=lambda: transformers.T5ForConditionalGeneration(config), + data_gen_fn=data_gen_for_conditional_generation, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_conditional_generation, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_t5_encoder_model", + model_fn=lambda: transformers.T5EncoderModel(config), + data_gen_fn=data_gen_for_encoder_only, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_encoder_only, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/transformers/vit.py b/tests/kit/model_zoo/transformers/vit.py index a84b8d31c284..f1990751b016 100644 --- a/tests/kit/model_zoo/transformers/vit.py +++ b/tests/kit/model_zoo/transformers/vit.py @@ -18,15 +18,15 @@ def data_gen(): def data_gen_for_image_classification(): data = data_gen() - data['labels'] = torch.tensor([0]) + data["labels"] = torch.tensor([0]) return data def data_gen_for_masked_image_modeling(): data = data_gen() - num_patches = (config.image_size // config.patch_size)**2 + num_patches = (config.image_size // config.patch_size) ** 2 bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool() - data['bool_masked_pos'] = bool_masked_pos + data["bool_masked_pos"] = bool_masked_pos return data @@ -42,23 +42,29 @@ def data_gen_for_masked_image_modeling(): # transformers.ViTModel, # transformers.ViTForMaskedImageModeling, # transformers.ViTForImageClassification, -model_zoo.register(name='transformers_vit', - model_fn=lambda: transformers.ViTModel(config), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_vit_model, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="transformers_vit", + model_fn=lambda: transformers.ViTModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_vit_model, + model_attribute=ModelAttribute(has_control_flow=True), +) -model_zoo.register(name='transformers_vit_for_masked_image_modeling', - model_fn=lambda: transformers.ViTForMaskedImageModeling(config), - data_gen_fn=data_gen_for_masked_image_modeling, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_masked_image_modeling, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="transformers_vit_for_masked_image_modeling", + model_fn=lambda: transformers.ViTForMaskedImageModeling(config), + data_gen_fn=data_gen_for_masked_image_modeling, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_masked_image_modeling, + model_attribute=ModelAttribute(has_control_flow=True), +) -model_zoo.register(name='transformers_vit_for_image_classification', - model_fn=lambda: transformers.ViTForImageClassification(config), - data_gen_fn=data_gen_for_image_classification, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_image_classification, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="transformers_vit_for_image_classification", + model_fn=lambda: transformers.ViTForImageClassification(config), + data_gen_fn=data_gen_for_image_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_image_classification, + 
model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/transformers/whisper.py b/tests/kit/model_zoo/transformers/whisper.py index f7cdc052aaf0..928be4468c01 100644 --- a/tests/kit/model_zoo/transformers/whisper.py +++ b/tests/kit/model_zoo/transformers/whisper.py @@ -33,7 +33,7 @@ def data_gen_for_conditional_generation(): # or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is # only computed for the tokens with labels in `[0, ..., config.vocab_size]`. data = data_gen() - data['labels'] = torch.tensor([[0, 1]], dtype=torch.int64) + data["labels"] = torch.tensor([[0, 1]], dtype=torch.int64) return data @@ -44,8 +44,8 @@ def data_gen_for_audio_classification(): # `config.num_labels > 1` a classification loss is computed (Cross-Entropy). # `WhisperForAudioClassification` does not need `decoder_input_ids` data = data_gen() - data.pop('decoder_input_ids') - data['labels'] = torch.tensor([1], dtype=torch.int64) + data.pop("decoder_input_ids") + data["labels"] = torch.tensor([1], dtype=torch.int64) return data @@ -69,23 +69,29 @@ def data_gen_for_audio_classification(): ) # register the Whisper variants -model_zoo.register(name='transformers_whisper', - model_fn=lambda: transformers.WhisperModel(config), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True)) - -model_zoo.register(name='transformers_whisper_for_conditional_generation', - model_fn=lambda: transformers.WhisperForConditionalGeneration(config), - data_gen_fn=data_gen_for_conditional_generation, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_attr, - model_attribute=ModelAttribute(has_control_flow=True)) - -model_zoo.register(name='transformers_whisper_for_audio_classification', - model_fn=lambda: transformers.WhisperForAudioClassification(config), - data_gen_fn=data_gen_for_audio_classification, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_attr, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="transformers_whisper", + model_fn=lambda: transformers.WhisperModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) + +model_zoo.register( + name="transformers_whisper_for_conditional_generation", + model_fn=lambda: transformers.WhisperForConditionalGeneration(config), + data_gen_fn=data_gen_for_conditional_generation, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_attr, + model_attribute=ModelAttribute(has_control_flow=True), +) + +model_zoo.register( + name="transformers_whisper_for_audio_classification", + model_fn=lambda: transformers.WhisperForAudioClassification(config), + data_gen_fn=data_gen_for_audio_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_attr, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/test_analyzer/test_fx/test_bias_addition.py b/tests/test_analyzer/test_fx/test_bias_addition.py index f7b5eb140f24..f72c1cb3f533 100644 --- a/tests/test_analyzer/test_fx/test_bias_addition.py +++ b/tests/test_analyzer/test_fx/test_bias_addition.py @@ -12,7 +12,6 @@ class LinearModel(torch.nn.Module): - def __init__(self, in_features, out_features, bias): super().__init__() self.linear = torch.nn.Linear(in_features, out_features, bias=bias) @@ -23,25 +22,14 @@ def forward(self, x): class ConvModel(torch.nn.Module): - def 
__init__(self, in_channel, out_channels, kernel_size, bias) -> None: super().__init__() - self.conv = torch.nn.Conv2d(in_channel, - out_channels, - kernel_size, - bias=bias, - padding=1, - stride=2, - dilation=2, - groups=3) - self.conv_transpose = torch.nn.ConvTranspose2d(in_channel, - out_channels, - kernel_size, - bias=bias, - padding=1, - stride=2, - dilation=2, - groups=3) + self.conv = torch.nn.Conv2d( + in_channel, out_channels, kernel_size, bias=bias, padding=1, stride=2, dilation=2, groups=3 + ) + self.conv_transpose = torch.nn.ConvTranspose2d( + in_channel, out_channels, kernel_size, bias=bias, padding=1, stride=2, dilation=2, groups=3 + ) def forward(self, x, select=0): if select == 0: @@ -52,7 +40,6 @@ def forward(self, x, select=0): class SiuModel(torch.nn.Module): - def __init__(self, bias) -> None: super().__init__() self.linear = LinearModel(3, 3, bias) @@ -69,7 +56,6 @@ def forward(self, x, select=torch.Tensor([0])): class AddmmModel(torch.nn.Module): - def __init__(self, alpha, beta) -> None: super().__init__() self.alpha = alpha @@ -80,7 +66,7 @@ def forward(self, x): return x -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() @parameterize("bias", [True, False]) @parameterize("bias_addition_split", [True, False]) @@ -89,19 +75,21 @@ def forward(self, x): def test_siu_model(bias, bias_addition_split, shape, select): model = SiuModel(bias=bias) x = torch.rand(shape) - gm = symbolic_trace(model, - meta_args={'x': x}, - concrete_args={'select': select}, - trace_act_ckpt=True, - bias_addition_split=bias_addition_split) - assert torch.allclose(model(x, select), gm(x)), 'original model and traced model should be the same!' + gm = symbolic_trace( + model, + meta_args={"x": x}, + concrete_args={"select": select}, + trace_act_ckpt=True, + bias_addition_split=bias_addition_split, + ) + assert torch.allclose(model(x, select), gm(x)), "original model and traced model should be the same!" if bias and bias_addition_split: - assert '+' in gm.code, 'bias addition should be split!' + assert "+" in gm.code, "bias addition should be split!" else: - assert '+' not in gm.code, 'bias addition should not be split!' + assert "+" not in gm.code, "bias addition should not be split!" -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @parameterize("alpha", [1, 2]) @parameterize("beta", [1, 2]) @parameterize("bias_addition_split", [True, False]) @@ -109,14 +97,14 @@ def test_siu_model(bias, bias_addition_split, shape, select): def test_addmm_model(alpha, beta, bias_addition_split, shape): model = AddmmModel(alpha=alpha, beta=beta) x = torch.rand(shape) - gm = symbolic_trace(model, meta_args={'x': x}, trace_act_ckpt=True, bias_addition_split=bias_addition_split) - assert torch.allclose(model(x), gm(x)), 'original model and traced model should be the same!' + gm = symbolic_trace(model, meta_args={"x": x}, trace_act_ckpt=True, bias_addition_split=bias_addition_split) + assert torch.allclose(model(x), gm(x)), "original model and traced model should be the same!" if (alpha == 1 and beta == 1) or not bias_addition_split: - assert '*' not in gm.code, 'bias addition should not be split!' 
+ assert "*" not in gm.code, "bias addition should not be split!" elif bias_addition_split: - assert '+' in gm.code, 'bias addition should be split!' + assert "+" in gm.code, "bias addition should be split!" -if __name__ == '__main__': +if __name__ == "__main__": test_siu_model() test_addmm_model() diff --git a/tests/test_analyzer/test_fx/test_mod_dir.py b/tests/test_analyzer/test_fx/test_mod_dir.py index f62147b297a2..be151b1edd80 100644 --- a/tests/test_analyzer/test_fx/test_mod_dir.py +++ b/tests/test_analyzer/test_fx/test_mod_dir.py @@ -10,7 +10,6 @@ class LinearModel(torch.nn.Module): - def __init__(self, in_features, out_features, bias): super().__init__() self.linear = torch.nn.Linear(in_features, out_features, bias=bias) @@ -21,25 +20,14 @@ def forward(self, x): class ConvModel(torch.nn.Module): - def __init__(self, in_channel, out_channels, kernel_size, bias) -> None: super().__init__() - self.conv = torch.nn.Conv2d(in_channel, - out_channels, - kernel_size, - bias=bias, - padding=1, - stride=2, - dilation=2, - groups=3) - self.conv_transpose = torch.nn.ConvTranspose2d(out_channels, - out_channels, - kernel_size, - bias=bias, - padding=1, - stride=2, - dilation=2, - groups=3) + self.conv = torch.nn.Conv2d( + in_channel, out_channels, kernel_size, bias=bias, padding=1, stride=2, dilation=2, groups=3 + ) + self.conv_transpose = torch.nn.ConvTranspose2d( + out_channels, out_channels, kernel_size, bias=bias, padding=1, stride=2, dilation=2, groups=3 + ) def forward(self, x): x = self.conv(x) @@ -48,7 +36,6 @@ def forward(self, x): class AModel(torch.nn.Module): - def __init__(self, bias) -> None: super().__init__() self.linear_1 = LinearModel(3, 3, bias) @@ -63,7 +50,7 @@ def forward(self, x): return x -@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12') +@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="torch version < 12") @clear_cache_before_run() @parameterize("bias", [True, False]) @parameterize("bias_addition_split", [True, False]) @@ -71,11 +58,11 @@ def forward(self, x): def test_mod_dir(bias, bias_addition_split, shape): model = AModel(bias=bias) x = torch.rand(shape) - gm = symbolic_trace(model, meta_args={'x': x}, bias_addition_split=bias_addition_split) + gm = symbolic_trace(model, meta_args={"x": x}, bias_addition_split=bias_addition_split) for node in gm.graph.nodes: - assert len(node.meta['info'].mod_dir), f"{node} should have non-trivial ``mod_dir``." - print(node, node.meta['info'].mod_dir) + assert len(node.meta["info"].mod_dir), f"{node} should have non-trivial ``mod_dir``." 
+ print(node, node.meta["info"].mod_dir) -if __name__ == '__main__': +if __name__ == "__main__": test_mod_dir(bias=True, bias_addition_split=True, shape=(3, 3, 3)) diff --git a/tests/test_analyzer/test_fx/test_nested_ckpt.py b/tests/test_analyzer/test_fx/test_nested_ckpt.py index bd16f5a4f95d..d7b96fb9f043 100644 --- a/tests/test_analyzer/test_fx/test_nested_ckpt.py +++ b/tests/test_analyzer/test_fx/test_nested_ckpt.py @@ -12,7 +12,6 @@ class MyModule(nn.Module): - def __init__(self): super().__init__() self.a = nn.Linear(10, 10) @@ -43,14 +42,14 @@ def forward(self, x): return checkpoint(self.checkpoint_0, x) -@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12') +@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="torch version < 12") @clear_cache_before_run() def test_nested_ckpt(): model = MyModule() x = torch.rand(10, 10) - gm = symbolic_trace(model, meta_args={'x': x}, trace_act_ckpt=True) + gm = symbolic_trace(model, meta_args={"x": x}, trace_act_ckpt=True) assert torch.allclose(gm(x), model(x)), "The traced model should generate the same output as the original model." - for ckpt_def in filter(lambda s: s.startswith('checkpoint'), dir(model)): + for ckpt_def in filter(lambda s: s.startswith("checkpoint"), dir(model)): assert ckpt_def in gm.code, f"Checkpoint {ckpt_def} should be in the traced code.\n Traced code = {gm.code}" diff --git a/tests/test_analyzer/test_fx/test_shape_prop.py b/tests/test_analyzer/test_fx/test_shape_prop.py index a849feb795e5..609fc9c7b022 100644 --- a/tests/test_analyzer/test_fx/test_shape_prop.py +++ b/tests/test_analyzer/test_fx/test_shape_prop.py @@ -1,6 +1,5 @@ import pytest import torch -import torchvision.models as tm from packaging import version from colossalai.testing.utils import clear_cache_before_run, parameterize @@ -16,24 +15,25 @@ def linear_impl(*args, **kwargs): assert True return torch.nn.functional.linear(*args, **kwargs) + except: pass def _check_gm_validity(gm: torch.fx.GraphModule): for node in gm.graph.nodes: - assert node.meta['info'].outputs, f'In {gm.__class__.__name__}, {node} has no output shape.' + assert node.meta["info"].outputs, f"In {gm.__class__.__name__}, {node} has no output shape." if node.op in [ - 'call_module', # can apply to params - 'call_function', # can apply to params - 'call_method', # can apply to params + "call_module", # can apply to params + "call_function", # can apply to params + "call_method", # can apply to params ]: - assert hasattr(node.meta['info'], 'inputs'), f'In {gm.__class__.__name__}, {node} has no input shape.' + assert hasattr(node.meta["info"], "inputs"), f"In {gm.__class__.__name__}, {node} has no input shape." 
-@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() -@parameterize('m', tm_models) +@parameterize("m", tm_models) def test_torchvision_shape_prop(m): with MetaTensorMode(): model = m() @@ -46,9 +46,9 @@ def test_torchvision_shape_prop(m): _check_gm_validity(gm) -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() -@parameterize('m', tmm_models) +@parameterize("m", tmm_models) def test_timm_shape_prop(m): with MetaTensorMode(): model = m() diff --git a/tests/test_analyzer/test_fx/test_symbolic_profile.py b/tests/test_analyzer/test_fx/test_symbolic_profile.py index 17deee7a7118..8d8ee2445d58 100644 --- a/tests/test_analyzer/test_fx/test_symbolic_profile.py +++ b/tests/test_analyzer/test_fx/test_symbolic_profile.py @@ -1,6 +1,5 @@ import pytest import torch -import torchvision.models as tm from packaging import version from colossalai.testing.utils import clear_cache_before_run, parameterize @@ -15,12 +14,12 @@ def _check_gm_validity(gm: torch.fx.GraphModule): for node in gm.graph.nodes: - assert len(node.meta['info'].global_ctx), f'In {gm.__class__.__name__}, {node} has empty global context.' + assert len(node.meta["info"].global_ctx), f"In {gm.__class__.__name__}, {node} has empty global context." -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() -@parameterize('m', tm_models) +@parameterize("m", tm_models) def test_torchvision_profile(m, verbose=False, bias_addition_split=False): with MetaTensorMode(): model = m() @@ -33,9 +32,9 @@ def test_torchvision_profile(m, verbose=False, bias_addition_split=False): _check_gm_validity(gm) -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() -@parameterize('m', tmm_models) +@parameterize("m", tmm_models) def test_timm_profile(m, verbose=False, bias_addition_split=False): with MetaTensorMode(): model = m() diff --git a/tests/test_analyzer/test_subclasses/test_aten.py b/tests/test_analyzer/test_subclasses/test_aten.py index b7858110ac09..61c1d25f7b3d 100644 --- a/tests/test_analyzer/test_subclasses/test_aten.py +++ b/tests/test_analyzer/test_subclasses/test_aten.py @@ -14,35 +14,41 @@ aten = torch.ops.aten registered_meta = { - ('aten.convolution.default', True): [ # (aten ops, requires_backward) + ("aten.convolution.default", True): [ # (aten ops, requires_backward) (nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4)), (nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4, 4)), (nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4, 4, 4)), (nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4)), - (nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, - dilation=2), 
torch.rand(2, 3, 4, 4)), - (nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, - dilation=2), torch.rand(2, 3, 4, 4, 4)), + ( + nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), + torch.rand(2, 3, 4, 4), + ), + ( + nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), + torch.rand(2, 3, 4, 4, 4), + ), ], - ('aten.native_batch_norm.default', True): [ + ("aten.native_batch_norm.default", True): [ (nn.BatchNorm1d(4), torch.rand(2, 4)), (nn.BatchNorm2d(4), torch.rand(1, 4, 4, 4)), (nn.BatchNorm3d(4), torch.rand(1, 4, 4, 4, 4)), ], - ('aten.native_layer_norm.default', True): [(nn.LayerNorm(4), torch.rand(1, 2, 3, 4)),], - ('aten.avg_pool1d.default', True): [ + ("aten.native_layer_norm.default", True): [ + (nn.LayerNorm(4), torch.rand(1, 2, 3, 4)), + ], + ("aten.avg_pool1d.default", True): [ (nn.MaxPool1d(3, stride=2), torch.rand(4, 5, 5)), (nn.AvgPool1d(3, stride=2), torch.rand(4, 5, 5)), (nn.AdaptiveMaxPool1d(3), torch.rand(4, 5, 5)), (nn.AdaptiveAvgPool1d(3), torch.rand(4, 5, 5)), ], - ('aten.avg_pool2d.default', True): [ + ("aten.avg_pool2d.default", True): [ (nn.MaxPool2d((3, 2), stride=(2, 1)), torch.rand(2, 4, 5, 5)), (nn.AvgPool2d((3, 2), stride=(2, 1)), torch.rand(2, 4, 5, 5)), (nn.AdaptiveMaxPool2d((3, 2)), torch.rand(2, 4, 5, 5)), (nn.AdaptiveAvgPool2d((3, 2)), torch.rand(2, 4, 5, 5)), ], - ('aten.relu.default', True): [ + ("aten.relu.default", True): [ (nn.ReLU(), torch.rand(4, 3, 1, 2)), (nn.LeakyReLU(), torch.rand(4, 3, 1, 2)), (nn.SiLU(), torch.rand(4, 3, 1, 2)), @@ -51,15 +57,20 @@ (nn.Sigmoid(), torch.rand(4, 3, 1, 2)), (nn.Tanh(), torch.rand(4, 3, 1, 2)), (nn.Hardswish(), torch.rand(4, 3, 1, 2)), - ] + ], } def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor) -> Any: - assert tensor.shape == meta_tensor.shape, f'the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match.' - assert tensor.dtype == meta_tensor.dtype, f'the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match.' - assert tensor.stride() == meta_tensor.stride( - ), f'the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match.' + assert ( + tensor.shape == meta_tensor.shape + ), f"the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match." + assert ( + tensor.dtype == meta_tensor.dtype + ), f"the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match." + assert ( + tensor.stride() == meta_tensor.stride() + ), f"the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match." 
def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_backward=False) -> Any: @@ -73,7 +84,7 @@ def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_bac compare_all(x.grad, meta_x.grad) -@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12') +@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="torch version < 12") @clear_cache_before_run() def test_meta_aten(): for (aten_op, requires_backward), v in registered_meta.items(): @@ -81,5 +92,5 @@ def test_meta_aten(): run_and_compare(f, x, requires_backward) -if __name__ == '__main__': +if __name__ == "__main__": test_meta_aten() diff --git a/tests/test_analyzer/test_subclasses/test_flop_tensor.py b/tests/test_analyzer/test_subclasses/test_flop_tensor.py index 4e9c9852649b..b1b9a89fad97 100644 --- a/tests/test_analyzer/test_subclasses/test_flop_tensor.py +++ b/tests/test_analyzer/test_subclasses/test_flop_tensor.py @@ -4,7 +4,6 @@ import torchvision.models as tm from packaging import version -from colossalai.testing import clear_cache_before_run, parameterize from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models try: @@ -13,40 +12,44 @@ pass -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') -@pytest.mark.parametrize('m', tm_models + tmm_models) +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") +@pytest.mark.parametrize("m", tm_models + tmm_models) def test_flop_count_module(m): x = torch.rand(2, 3, 224, 224) - with MetaTensorMode(): # save time for testing + with MetaTensorMode(): # save time for testing module = m() rs_fwd, rs_bwd = flop_count(module, x, verbose=True) - assert rs_fwd > 0, f'fwd flop count of {m.__name__} is {rs_fwd}' - assert rs_bwd > 0, f'bwd flop count of {m.__name__} is {rs_bwd}' + assert rs_fwd > 0, f"fwd flop count of {m.__name__} is {rs_fwd}" + assert rs_bwd > 0, f"bwd flop count of {m.__name__} is {rs_bwd}" odd_cases = [ - (F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), { - 'inplace': True - }), - (F.max_pool2d, (torch.rand(2, 3, 224, 224, requires_grad=True),), { - 'kernel_size': 3, - 'stride': 2, - 'padding': 1, - 'dilation': 2 - }), - (torch.where, (torch.rand(2, 3, 224, 224) > 0.5, torch.rand(2, 3, 224, 224, requires_grad=True), - torch.rand(2, 3, 224, 224, requires_grad=True)), {}), + (F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {"inplace": True}), + ( + F.max_pool2d, + (torch.rand(2, 3, 224, 224, requires_grad=True),), + {"kernel_size": 3, "stride": 2, "padding": 1, "dilation": 2}, + ), + ( + torch.where, + ( + torch.rand(2, 3, 224, 224) > 0.5, + torch.rand(2, 3, 224, 224, requires_grad=True), + torch.rand(2, 3, 224, 224, requires_grad=True), + ), + {}, + ), ] -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') -@pytest.mark.parametrize('func, args, kwargs', odd_cases) +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") +@pytest.mark.parametrize("func, args, kwargs", odd_cases) def test_flop_count_function(func, args, kwargs): rs_fwd, rs_bwd = flop_count(func, *args, **kwargs, verbose=True) - assert rs_fwd > 0, f'fwd flop count of {func.__name__} is {rs_fwd}' - assert rs_bwd > 0, f'bwd flop count of {func.__name__} is {rs_bwd}' + assert rs_fwd > 0, f"fwd flop count of {func.__name__} is {rs_fwd}" + assert rs_bwd > 0, f"bwd flop count of {func.__name__} is {rs_bwd}" -if 
__name__ == '__main__': +if __name__ == "__main__": test_flop_count_module(tm.resnet18) - test_flop_count_function(F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {'inplace': True}) + test_flop_count_function(F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {"inplace": True}) diff --git a/tests/test_analyzer/test_subclasses/test_meta_mode.py b/tests/test_analyzer/test_subclasses/test_meta_mode.py index d2a0a1b9cfb5..c55c4ec42703 100644 --- a/tests/test_analyzer/test_subclasses/test_meta_mode.py +++ b/tests/test_analyzer/test_subclasses/test_meta_mode.py @@ -6,17 +6,22 @@ from colossalai.testing import clear_cache_before_run, parameterize try: - from colossalai._analyzer._subclasses import MetaTensor, MetaTensorMode + from colossalai._analyzer._subclasses import MetaTensorMode except: pass from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor): - assert tensor.shape == meta_tensor.shape, f'the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match.' - assert tensor.dtype == meta_tensor.dtype, f'the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match.' - assert tensor.stride() == meta_tensor.stride( - ), f'the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match.' + assert ( + tensor.shape == meta_tensor.shape + ), f"the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match." + assert ( + tensor.dtype == meta_tensor.dtype + ), f"the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match." + assert ( + tensor.stride() == meta_tensor.stride() + ), f"the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match." 
def run_and_compare(model): @@ -31,12 +36,12 @@ def run_and_compare(model): compare_all(x.grad, meta_x.grad) -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() -@parameterize('m', tm_models + tmm_models) +@parameterize("m", tm_models + tmm_models) def test_meta_mode_shape(m): run_and_compare(m()) -if __name__ == '__main__': +if __name__ == "__main__": test_meta_mode_shape(tm.resnet18) diff --git a/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py b/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py index b65e6d0d8863..03bba8e64772 100644 --- a/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py +++ b/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py @@ -8,6 +8,7 @@ import colossalai from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.fx._compatibility import is_compatible_with_meta + # from colossalai.fx.passes.algorithms import solver_rotor # from colossalai.fx.passes.algorithms.operation import Sequence from colossalai.fx.passes.meta_info_prop import MetaInfoProp @@ -19,18 +20,18 @@ try: from colossalai.fx.codegen import ActivationCheckpointCodeGen + withcodegen = True except: - from colossalai.fx.codegen import python_code_with_activation_checkpoint withcodegen = False def _run_C_solver_consistency_test(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") for M, mem_budget in [(tm.resnet50, 4000), (tm.densenet121, 8080)]: model = M() - data = torch.rand(128, 3, 224, 224, device='meta') + data = torch.rand(128, 3, 224, 224, device="meta") tracer = ColoTracer() graph = tracer.trace(model, meta_args={"x": data}) @@ -54,15 +55,17 @@ def _run_C_solver_consistency_test(rank, world_size, port): for m in range(len(opt_python)): for d in range(1, len(opt_python[0])): for i in range(len(opt_python[0]) - d): - assert opt_python[m][i][i + d] == opt_C[m][i][i + d], \ - f"item ({m}, {i}, {i + d}) is not consistent with python version!\npython version: {opt_python[m][i][i + d]}\nC version: {opt_C[m][i][i + d]}" + assert ( + opt_python[m][i][i + d] == opt_C[m][i][i + d] + ), f"item ({m}, {i}, {i + d}) is not consistent with python version!\npython version: {opt_python[m][i][i + d]}\nC version: {opt_C[m][i][i + d]}" sequence_python = sequence_python.list_operations() sequence_C = sequence_C.list_operations() # make sure the sequences are the same - assert len(sequence_python) == len(sequence_C) and \ - all(python_op.__repr__() == C_op.__repr__() for (python_op, C_op) in zip(sequence_python, sequence_C)) + assert len(sequence_python) == len(sequence_C) and all( + python_op.__repr__() == C_op.__repr__() for (python_op, C_op) in zip(sequence_python, sequence_C) + ) gpc.destroy() @@ -74,5 +77,5 @@ def test_C_solver_consistency(): spawn(_run_C_solver_consistency_test, 1) -if __name__ == '__main__': +if __name__ == "__main__": _run_C_solver_consistency_test(rank=0) diff --git a/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py index babdddfada18..c46f57f75303 100644 --- a/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py +++ 
b/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py @@ -11,6 +11,7 @@ from colossalai.fx import ColoTracer from colossalai.fx._compatibility import is_compatible_with_meta from colossalai.fx.graph_module import ColoGraphModule + # from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.legacy.core import global_context as gpc @@ -21,10 +22,12 @@ try: from colossalai.fx.codegen import ActivationCheckpointCodeGen + with_codegen = True except: # fall back to older pytorch version from colossalai.fx.codegen import python_code_with_activation_checkpoint + with_codegen = False # SOLVERS = [chen_greedy, solver_rotor] @@ -33,7 +36,7 @@ def _is_activation_checkpoint_available(gm: GraphModule): for n in gm.graph.nodes: - if hasattr(n, 'activation_checkpoint') and getattr(n, 'activation_checkpoint') is not None: + if hasattr(n, "activation_checkpoint") and getattr(n, "activation_checkpoint") is not None: return True @@ -47,15 +50,19 @@ def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule): def _is_graph_linearized(gm: GraphModule): code = gm.code # find patterns like r' return output_1, output_2', which is not expected on a linearized graph - pattern = re.compile(r' return [a-zA-Z0-9_]+(, [a-zA-Z0-9_]+)+') + pattern = re.compile(r" return [a-zA-Z0-9_]+(, [a-zA-Z0-9_]+)+") if pattern.findall(code): return False else: return True -def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Callable[[GraphModule], GraphModule], - model_cls: Callable[[], torch.nn.Module]): +def check_backward_consistency( + m: torch.nn.Module, + gm: GraphModule, + solver: Callable[[GraphModule], GraphModule], + model_cls: Callable[[], torch.nn.Module], +): criterion = torch.nn.MSELoss() m.cuda() data = torch.rand(2, 3, 32, 32).cuda() @@ -64,18 +71,18 @@ def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Call loss.backward() loss = criterion(gm(data), label) loss.backward() - assert _is_all_gradient_close(m, gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}' + assert _is_all_gradient_close(m, gm), f"Solver {solver} did not work correctly in backward pass on {model_cls}" def _run_ckpt_solver(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") MODEL_LIST = [tm.densenet121] torch.backends.cudnn.deterministic = True tracer = ColoTracer(trace_act_ckpt=False) - data = torch.rand(8, 3, 224, 224, device='meta') + data = torch.rand(8, 3, 224, 224, device="meta") for solver in SOLVERS: for model_cls in MODEL_LIST: m = model_cls(num_classes=5) @@ -90,27 +97,28 @@ def _run_ckpt_solver(rank, world_size, port): gm = solver(gm) assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner." 
assert _is_activation_checkpoint_available( - gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" + gm + ), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" check_backward_consistency(m, gm, solver, model_cls) gpc.destroy() @pytest.mark.skip("TODO(super-dainiu): refactor all tests.") -@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0") @rerun_if_address_is_in_use() def test_ckpt_solver(): spawn(_run_ckpt_solver, 1) def _run_ckpt_solver_torch11(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") MODEL_LIST = [tm.densenet121] torch.backends.cudnn.deterministic = True tracer = ColoTracer(trace_act_ckpt=False) - data = torch.rand(8, 3, 32, 32, device='meta') + data = torch.rand(8, 3, 32, 32, device="meta") for solver in SOLVERS: for model_cls in MODEL_LIST: m = model_cls(num_classes=5) @@ -124,19 +132,20 @@ def _run_ckpt_solver_torch11(rank, world_size, port): gm = solver(gm) assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner." assert _is_activation_checkpoint_available( - gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" + gm + ), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" check_backward_consistency(m, gm, solver, model_cls) gpc.destroy() -@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0') +@pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0") @pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done") @rerun_if_address_is_in_use() def test_ckpt_solver_torch11(): spawn(_run_ckpt_solver_torch11, 1) -if __name__ == '__main__': +if __name__ == "__main__": _run_ckpt_solver(rank=0) test_ckpt_solver() test_ckpt_solver_torch11() diff --git a/tests/test_auto_parallel/test_ckpt_solvers/test_linearize.py b/tests/test_auto_parallel/test_ckpt_solvers/test_linearize.py index 59880815dc5e..bb3be9344566 100644 --- a/tests/test_auto_parallel/test_ckpt_solvers/test_linearize.py +++ b/tests/test_auto_parallel/test_ckpt_solvers/test_linearize.py @@ -5,6 +5,7 @@ from colossalai.fx import ColoTracer from colossalai.fx._compatibility import is_compatible_with_meta from colossalai.fx.graph_module import ColoGraphModule + # from colossalai.fx.passes.algorithms import linearize, solver_rotor # from colossalai.fx.passes.algorithms.operation import (ForwardCheck, ForwardEnable, ForwardNograd, Loss) from colossalai.fx.passes.meta_info_prop import MetaInfoProp @@ -15,14 +16,16 @@ try: from colossalai.fx.codegen import ActivationCheckpointCodeGen + with_codegen = True except: # fall back to older pytorch version from colossalai.fx.codegen import python_code_with_activation_checkpoint + with_codegen = False -@pytest.mark.skip(reason='TODO: modify the logger') +@pytest.mark.skip(reason="TODO: modify the logger") @pytest.mark.skip("TODO(lyl): refactor all tests.") @pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0") @clear_cache_before_run() @@ -35,12 +38,12 @@ def test_linearize(): graph = tracer.trace(model) graph.set_codegen(ActivationCheckpointCodeGen()) gm = ColoGraphModule(model, graph, model.__class__.__name__) - 
MetaInfoProp(gm).run(MetaTensor(torch.rand(128, 3, 224, 224, device="meta"), fake_device='cpu')) + MetaInfoProp(gm).run(MetaTensor(torch.rand(128, 3, 224, 224, device="meta"), fake_device="cpu")) node_list = linearize(gm) gm = solver_rotor(gm, data=torch.rand(128, 3, 224, 224, device="meta"), mem_limit=budget * 1024**2) op_list = gm.__sequence__.list_operations() loss_op = next(op for op in op_list if isinstance(op, Loss)) - op_list = op_list[:op_list.index(loss_op)] + op_list = op_list[: op_list.index(loss_op)] in_ckpt = False ckpt_idx = 0 for idx, op in enumerate(op_list): @@ -48,8 +51,9 @@ def test_linearize(): if isinstance(op, ForwardNograd): for n in node_list[idx]: assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!" - assert n.activation_checkpoint[ - 0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!" + assert ( + n.activation_checkpoint[0] == ckpt_idx + ), f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!" continue @@ -65,8 +69,9 @@ def test_linearize(): ckpt_idx += 1 for n in node_list[idx]: assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!" - assert n.activation_checkpoint[ - 0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!" + assert ( + n.activation_checkpoint[0] == ckpt_idx + ), f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!" continue @@ -75,8 +80,9 @@ def test_linearize(): in_ckpt = True for n in node_list[idx]: assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!" - assert n.activation_checkpoint[ - 0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!" + assert ( + n.activation_checkpoint[0] == ckpt_idx + ), f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!" 
del model del gm @@ -100,7 +106,7 @@ def test_linearize_torch11(): gm = solver_rotor(gm, data=torch.rand(128, 3, 224, 224, device="meta"), mem_limit=budget * 1024**2) op_list = gm.__sequence__.list_operations() loss_op = next(op for op in op_list if isinstance(op, Loss)) - op_list = op_list[:op_list.index(loss_op)] + op_list = op_list[: op_list.index(loss_op)] in_ckpt = False ckpt_idx = 0 for idx, op in enumerate(op_list): diff --git a/tests/test_auto_parallel/test_offload/model_utils.py b/tests/test_auto_parallel/test_offload/model_utils.py index c22b17ae42ba..0efe84655aac 100644 --- a/tests/test_auto_parallel/test_offload/model_utils.py +++ b/tests/test_auto_parallel/test_offload/model_utils.py @@ -1,25 +1,23 @@ import torch import torch.nn as nn -from transformers import GPT2Config, GPT2LMHeadModel -from transformers import BertConfig, BertLMHeadModel +from transformers import BertConfig, BertLMHeadModel, GPT2Config, GPT2LMHeadModel + from tests.components_to_test.registry import non_distributed_component_funcs -class GPTLMModel(nn.Module): - def __init__(self, - hidden_size=768, - num_layers=12, - num_attention_heads=12, - max_seq_len=1024, - vocab_size=50257): +class GPTLMModel(nn.Module): + def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257): super().__init__() self.model = GPT2LMHeadModel( - GPT2Config(n_embd=hidden_size, - n_layer=num_layers, - n_head=num_attention_heads, - n_positions=max_seq_len, - n_ctx=max_seq_len, - vocab_size=vocab_size)) + GPT2Config( + n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size, + ) + ) def forward(self, input_ids, attention_mask): # Only return lm_logits @@ -27,7 +25,6 @@ def forward(self, input_ids, attention_mask): class LMLoss(nn.Module): - def __init__(self): super().__init__() self.loss_fn = nn.CrossEntropyLoss() @@ -38,18 +35,27 @@ def forward(self, logits, labels): # Flatten the tokens return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + class BertLMModel(nn.Module): def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=32, vocab_size=30522): super().__init__() - self.model = BertLMHeadModel(BertConfig(n_embd=hidden_size, num_hidden_layers=num_layers, hidden_size=hidden_size, - num_attention_heads=num_attention_heads, max_position_embeddings=hidden_size, - vocab_size=vocab_size)) + self.model = BertLMHeadModel( + BertConfig( + n_embd=hidden_size, + num_hidden_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + max_position_embeddings=hidden_size, + vocab_size=vocab_size, + ) + ) def forward(self, input_ids, attention_mask): # Only return lm_logits return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0] -@non_distributed_component_funcs.register(name='bert_') + +@non_distributed_component_funcs.register(name="bert_") def get_bert_components(): vocab_size = 1024 seq_len = 64 @@ -67,7 +73,8 @@ def bert_data_gen(device="meta"): return bert_model_builder, bert_data_gen -@non_distributed_component_funcs.register(name='gpt2_') + +@non_distributed_component_funcs.register(name="gpt2_") def get_gpt2_components(): vocab_size = 1024 seq_len = 8 @@ -83,4 +90,4 @@ def gpt2_data_gen(device="meta"): kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) return kwargs - return gpt2_model_builder, gpt2_data_gen \ No newline at end of file + return gpt2_model_builder, 
gpt2_data_gen diff --git a/tests/test_auto_parallel/test_offload/test_perf.py b/tests/test_auto_parallel/test_offload/test_perf.py index 45c22efc4127..2c8b260e6498 100644 --- a/tests/test_auto_parallel/test_offload/test_perf.py +++ b/tests/test_auto_parallel/test_offload/test_perf.py @@ -17,18 +17,22 @@ from tests.test_tensor.common_utils import set_seed -@parameterize('model_name', ['gpt2_']) -@parameterize('memory_budget', [5000]) -@parameterize('solver_name', ['asyn']) +@parameterize("model_name", ["gpt2_"]) +@parameterize("memory_budget", [5000]) +@parameterize("solver_name", ["asyn"]) def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str): - # build model get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, data_gen = get_components_func() - label = torch.randint(low=0, high=128, size=( - 64, - 8, - ), device=get_current_device()) + label = torch.randint( + low=0, + high=128, + size=( + 64, + 8, + ), + device=get_current_device(), + ) criterion = LMLoss() set_seed(42) @@ -50,17 +54,19 @@ def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str): hybrid_optimizer = HybridAdam(model.model.parameters(), lr=1e-3) optim = AMPOptimizer(hybrid_optimizer, model) - with ColoInitContext(device=torch.device('cpu')): + with ColoInitContext(device=torch.device("cpu")): gemini_model = model_builder() gemini_model.train() hybrid_optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3) - gemini_config = dict(strict_ddp_mode=False, - device=torch.device('cpu'), - placement_policy='cpu', - pin_memory=True, - hidden_dim=8192, - search_range_m=128) + gemini_config = dict( + strict_ddp_mode=False, + device=torch.device("cpu"), + placement_policy="cpu", + pin_memory=True, + hidden_dim=8192, + search_range_m=128, + ) gemini_model = zero_model_wrapper(gemini_model, 3, gemini_config) optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True) gemini_optim = zero_optim_wrapper(gemini_model, hybrid_optimizer, optim_config=optim_config) @@ -89,9 +95,11 @@ def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str): exec_time = sum(sorted(time_list)[:5]) / 5 runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2 runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2 - print(f'gemini | model_name: {model_name}') - print(f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' - f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|') + print(f"gemini | model_name: {model_name}") + print( + f"| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB " + f"| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|" + ) print(time_list) del data_args @@ -124,24 +132,26 @@ def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str): exec_time = sum(sorted(time_list)[:5]) / 5 runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2 runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2 - print(f'solver_name: {solver_name} | model_name: {model_name}') - print(f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' - f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|') + print(f"solver_name: {solver_name} | model_name: {model_name}") + print( + f"| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB " + f"| 
runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|" + ) print(time_list) def run_dist(rank, world_size, port): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_fwd_bwd() @pytest.mark.skip("this test failed") -@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed') +@pytest.mark.skipif(NOT_NVML, reason="pynvml is not installed") @rerun_if_address_is_in_use() def test_perf(): spawn(run_dist, 1) -if __name__ == '__main__': +if __name__ == "__main__": test_perf() diff --git a/tests/test_auto_parallel/test_offload/test_solver.py b/tests/test_auto_parallel/test_offload/test_solver.py index aa2c9a36849f..6bb53aa67495 100644 --- a/tests/test_auto_parallel/test_offload/test_solver.py +++ b/tests/test_auto_parallel/test_offload/test_solver.py @@ -11,13 +11,12 @@ from tests.test_auto_parallel.test_offload.model_utils import * -@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed') +@pytest.mark.skipif(NOT_NVML, reason="pynvml is not installed") @clear_cache_before_run() -@parameterize('model_name', ['gpt2_', 'bert_']) -@parameterize('memory_budget', [4000]) -@parameterize('solver_name', ['syn', 'asyn']) +@parameterize("model_name", ["gpt2_", "bert_"]) +@parameterize("memory_budget", [4000]) +@parameterize("solver_name", ["syn", "asyn"]) def solver_test(model_name: str, memory_budget: float, solver_name: str): - get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, data_gen = get_components_func() data_args = data_gen(device="cpu") @@ -53,15 +52,15 @@ def solver_test(model_name: str, memory_budget: float, solver_name: str): need_offload = region.need_offload to_prefetch = region.fwd_prefetch_region.r_id if region.fwd_prefetch_region is not None else None print( - f'| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}' + f"| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}" ) for region in region_list.__reversed__(): need_offload = region.need_offload to_prefetch = region.bwd_prefetch_region.r_id if region.bwd_prefetch_region is not None else None print( - f'| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}' + f"| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}" ) -if __name__ == '__main__': +if __name__ == "__main__": solver_test() diff --git a/tests/test_auto_parallel/test_pass/test_node_converting_pass.py b/tests/test_auto_parallel/test_pass/test_node_converting_pass.py index 429e89aae5d3..2b89a73656b1 100644 --- a/tests/test_auto_parallel/test_pass/test_node_converting_pass.py +++ b/tests/test_auto_parallel/test_pass/test_node_converting_pass.py @@ -1,5 +1,4 @@ import torch -import torch.nn.functional as F from colossalai.auto_parallel.passes.runtime_preparation_pass import node_args_converting_pass from colossalai.device.device_mesh import DeviceMesh @@ -10,7 +9,6 @@ class TestModule(torch.nn.Module): - def forward(self, x): x = x.view(4, 4, 2) return x @@ -19,7 +17,7 @@ def forward(self, x): def insert_narrow(gm, x_node): graph = gm.graph with graph.inserting_after(x_node): - shard_node = graph.create_node('call_method', 'narrow', args=(x_node, 0, 0, 2), 
kwargs={}) + shard_node = graph.create_node("call_method", "narrow", args=(x_node, 0, 0, 2), kwargs={}) view_node = list(x_node.users.keys())[0] new_args = list(view_node.args) new_args[0] = shard_node @@ -33,7 +31,7 @@ def test_node_args_converting_pass(): physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - meta_args = {'x': torch.rand(4, 8).to('meta')} + meta_args = {"x": torch.rand(4, 8).to("meta")} input = torch.rand(4, 8) tracer = ColoTracer() graph = tracer.trace(root=model, meta_args=meta_args) @@ -41,8 +39,8 @@ def test_node_args_converting_pass(): x_node = list(graph.nodes)[0] view_node = list(graph.nodes)[1] sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]}) - setattr(x_node, 'sharding_spec', sharding_spec) - setattr(view_node, 'sharding_spec', sharding_spec) + setattr(x_node, "sharding_spec", sharding_spec) + setattr(view_node, "sharding_spec", sharding_spec) gm = ColoGraphModule(model, graph) gm = node_args_converting_pass(gm, device_mesh) @@ -52,5 +50,5 @@ def test_node_args_converting_pass(): assert output.shape == torch.Size([2, 4, 2]) -if __name__ == '__main__': +if __name__ == "__main__": test_node_args_converting_pass() diff --git a/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py b/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py index bca81201c6ef..b6cc6c9b44fd 100644 --- a/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py +++ b/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py @@ -1,6 +1,5 @@ import pytest import torch -import torch.nn.functional as F from colossalai._analyzer.fx.graph_module import ColoGraphModule from colossalai._analyzer.fx.passes import shape_prop_pass @@ -12,7 +11,6 @@ class TestModule(torch.nn.Module): - def forward(self, x): size = x.size() return size @@ -21,7 +19,7 @@ def forward(self, x): def insert_narrow(gm, x_node): graph = gm.graph with graph.inserting_after(x_node): - shard_node = graph.create_node('call_method', 'narrow', args=(x_node, 0, 0, 2), kwargs={}) + shard_node = graph.create_node("call_method", "narrow", args=(x_node, 0, 0, 2), kwargs={}) size_node = list(x_node.users.keys())[0] size_node.args = (shard_node,) return gm @@ -36,20 +34,20 @@ def recover_narrow(gm, narrow_node): return gm -@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') +@pytest.mark.skip("ShapeProp is not compatible with PyTorch 1.11.0") @clear_cache_before_run() def test_size_value_converting_pass(): model = TestModule() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - meta_args = {'x': torch.rand(4, 8).to('meta')} + meta_args = {"x": torch.rand(4, 8).to("meta")} input = torch.rand(4, 8) tracer = ColoTracer(bias_addition_split=True) graph = tracer.trace(root=model, meta_args=meta_args) x_node = list(graph.nodes)[0] x_sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]}) - setattr(x_node, 'sharding_spec', x_sharding_spec) + setattr(x_node, "sharding_spec", x_sharding_spec) gm = ColoGraphModule(model, graph) gm = insert_narrow(gm, x_node) shape_prop_pass(gm, *meta_args.values()) @@ -66,5 +64,5 @@ def test_size_value_converting_pass(): assert size == torch.Size([4, 8]) -if __name__ == '__main__': +if __name__ == "__main__": test_size_value_converting_pass() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py 
b/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py index 9fbe674ef4f4..c41c66745012 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py @@ -1,10 +1,9 @@ -from functools import partial - import pytest import torch try: from colossalai.auto_parallel.tensor_shard.initialize import initialize_model + NO_CODEGEN = False except: NO_CODEGEN = True @@ -16,7 +15,6 @@ class LinearModel(torch.nn.Module): - def __init__(self, in_features, out_features): super().__init__() self.linear = torch.nn.Linear(in_features, out_features) @@ -29,13 +27,11 @@ def forward(self, x): class ConvModel(torch.nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, bias=True): super().__init__() - self.conv = torch.nn.Conv2d(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - bias=bias) + self.conv = torch.nn.Conv2d( + in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias + ) def forward(self, x): x = self.conv(x) @@ -46,7 +42,7 @@ def forward(self, x): def check_linear_module(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = LinearModel(4, 8).cuda() input = torch.rand(4, 4).cuda() output_compare = model(input) @@ -55,7 +51,7 @@ def check_linear_module(rank, world_size, port): # [[0, 1] # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - meta_args = {'x': torch.rand(4, 4).to('meta')} + meta_args = {"x": torch.rand(4, 4).to("meta")} gm = initialize_model(model, meta_args=meta_args, device_mesh=device_mesh) output = gm(input) assert_close(output, output_compare) @@ -63,7 +59,7 @@ def check_linear_module(rank, world_size, port): def check_conv_module(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = ConvModel(3, 6, 2).cuda() input = torch.rand(4, 3, 64, 64).cuda() output_compare = model(input) @@ -72,14 +68,14 @@ def check_conv_module(rank, world_size, port): # [[0, 1] # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - meta_args = {'x': torch.rand(4, 3, 64, 64).to('meta')} + meta_args = {"x": torch.rand(4, 3, 64, 64).to("meta")} gm = initialize_model(model, meta_args=meta_args, device_mesh=device_mesh) output = gm(input) assert_close(output, output_compare) -@run_on_environment_flag(name='AUTO_PARALLEL') -@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found') +@run_on_environment_flag(name="AUTO_PARALLEL") +@pytest.mark.skipif(NO_CODEGEN, reason="No codegen found") @pytest.mark.dist @rerun_if_address_is_in_use() def test_bias_addition_module(): @@ -87,5 +83,5 @@ def test_bias_addition_module(): spawn(check_conv_module, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_bias_addition_module() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_broadcast.py b/tests/test_auto_parallel/test_tensor_shard/test_broadcast.py index 5607587496f3..5cc1820837bb 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_broadcast.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_broadcast.py @@ -48,17 
+48,15 @@ def test_recover_sharding_spec_for_broadcast_shape(): device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) broadcast_shape = get_broadcast_shape(x1.shape, x2.shape) - logical_sharding_spec_for_x1 = ShardingSpec(device_mesh=device_mesh, - dim_partition_dict={ - 0: [0], - 1: [1] - }, - entire_shape=broadcast_shape) + logical_sharding_spec_for_x1 = ShardingSpec( + device_mesh=device_mesh, dim_partition_dict={0: [0], 1: [1]}, entire_shape=broadcast_shape + ) physical_sharding_spec_for_x1, removed_dims = recover_sharding_spec_for_broadcast_shape( - logical_sharding_spec_for_x1, broadcast_shape, x1.shape) + logical_sharding_spec_for_x1, broadcast_shape, x1.shape + ) print(physical_sharding_spec_for_x1) assert physical_sharding_spec_for_x1.entire_shape == x1.shape # dim 1 for the physical tensor is of broadcast type MULTIPLE, so should ignore assert physical_sharding_spec_for_x1.dim_partition_dict == {0: [0]} - assert physical_sharding_spec_for_x1.sharding_sequence == ['S0', 'R', 'R'] + assert physical_sharding_spec_for_x1.sharding_sequence == ["S0", "R", "R"] diff --git a/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py b/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py index 398458306e3d..c800f54da66c 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py @@ -8,6 +8,7 @@ try: from colossalai.auto_parallel.tensor_shard.initialize import initialize_model + NO_CODEGEN = False except: NO_CODEGEN = True @@ -21,7 +22,6 @@ class GPT2MLPWithCkpt(nn.Module): - def __init__(self, intermediate_size, hidden_size): super().__init__() embed_dim = hidden_size @@ -39,11 +39,11 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl def check_act_ckpt(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = GPT2MLPWithCkpt(intermediate_size=4 * HIDDEN_SIZE, hidden_size=HIDDEN_SIZE) - input = torch.rand(1, 64, HIDDEN_SIZE) + torch.rand(1, 64, HIDDEN_SIZE) input_sample = { - 'hidden_states': torch.rand(1, 64, HIDDEN_SIZE).to('meta'), + "hidden_states": torch.rand(1, 64, HIDDEN_SIZE).to("meta"), } physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -51,18 +51,24 @@ def check_act_ckpt(rank, world_size, port): # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) gm = initialize_model(model, input_sample, device_mesh) - code = gm.module.graph.python_code('self').src - assert "runtime_comm_spec_apply_1 = colossalai_auto_parallel_passes_runtime_apply_pass_runtime_comm_spec_apply(linear_1, comm_actions_dict, 12, 'linear_1')" in code - assert "view_3 = torch.utils.checkpoint.checkpoint(self.checkpoint_0, view_1, comm_actions_dict, use_reentrant=False)" in code + code = gm.module.graph.python_code("self").src + assert ( + "runtime_comm_spec_apply_1 = colossalai_auto_parallel_passes_runtime_apply_pass_runtime_comm_spec_apply(linear_1, comm_actions_dict, 12, 'linear_1')" + in code + ) + assert ( + "view_3 = torch.utils.checkpoint.checkpoint(self.checkpoint_0, view_1, comm_actions_dict, use_reentrant=False)" + in code + ) -@run_on_environment_flag(name='AUTO_PARALLEL') -@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found') +@run_on_environment_flag(name="AUTO_PARALLEL") +@pytest.mark.skipif(NO_CODEGEN, reason="No codegen 
found") @pytest.mark.dist @rerun_if_address_is_in_use() def test_mlp_layer(): spawn(check_act_ckpt, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_mlp_layer() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py index 6908a1781869..e8f175326bb1 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py @@ -6,6 +6,7 @@ try: from colossalai.auto_parallel.tensor_shard.initialize import initialize_model + NO_CODEGEN = False except: NO_CODEGEN = True @@ -17,7 +18,6 @@ class MLP(torch.nn.Module): - def __init__(self, in_features): super().__init__() self.linear_1 = torch.nn.Linear(in_features, 4 * in_features, bias=False) @@ -32,7 +32,7 @@ def forward(self, x): def check_compatibility_with_ddp(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = MLP(4).cuda() if rank in [0, 1]: input = torch.arange(0, 16, dtype=torch.float).reshape(4, 4).cuda() @@ -49,26 +49,28 @@ def check_compatibility_with_ddp(rank, world_size, port): # [[0, 1] # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - meta_args = {'x': torch.rand(4, 4).to('meta')} - gm, solution = initialize_model(model, - meta_args=meta_args, - device_mesh=device_mesh, - return_solution=True, - solver_preference='tp', - shard_option='shard_last_axis') - - msg = '| TP strategy combination chosen by auto-parallel solver |' + meta_args = {"x": torch.rand(4, 4).to("meta")} + gm, solution = initialize_model( + model, + meta_args=meta_args, + device_mesh=device_mesh, + return_solution=True, + solver_preference="tp", + shard_option="shard_last_axis", + ) + + msg = "| TP strategy combination chosen by auto-parallel solver |" msg_length = len(msg) if rank == 0: - print('=' * msg_length) + print("=" * msg_length) print(msg) - print('=' * msg_length) + print("=" * msg_length) for strategy in solution: print(strategy) - print('=' * msg_length) + print("=" * msg_length) dp_process_group = None - for (ranks, process_group_handle) in device_mesh.process_groups_dict[0]: + for ranks, process_group_handle in device_mesh.process_groups_dict[0]: if rank in ranks: dp_process_group = process_group_handle assert dp_process_group is not None @@ -79,7 +81,7 @@ def check_compatibility_with_ddp(rank, world_size, port): assert_close(output, output_compare.narrow(0, 0, 4)) else: assert_close(output, output_compare.narrow(0, 4, 4)) - print(f'output on rank{rank} is correct') + print(f"output on rank{rank} is correct") loss = output.sum() loss.backward() @@ -90,16 +92,16 @@ def check_compatibility_with_ddp(rank, world_size, port): if rank in (1, 3): assert_close(gm.module.module.linear_1.weight.grad, grad_compare.narrow(0, 8, 8)) - print(f'gradient on rank{rank} is correct') + print(f"gradient on rank{rank} is correct") -@run_on_environment_flag(name='AUTO_PARALLEL') -@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found') +@run_on_environment_flag(name="AUTO_PARALLEL") +@pytest.mark.skipif(NO_CODEGEN, reason="No codegen found") @pytest.mark.dist @rerun_if_address_is_in_use() def test_compatibility_with_ddp(): spawn(check_compatibility_with_ddp, 4) -if __name__ == '__main__': +if __name__ == "__main__": 
test_compatibility_with_ddp() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py index 715f62358e2d..aba746f1992d 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py @@ -5,6 +5,7 @@ try: from colossalai.auto_parallel.tensor_shard.initialize import initialize_model + NO_CODEGEN = False except: NO_CODEGEN = True @@ -19,7 +20,6 @@ class MLP(torch.nn.Module): - def __init__(self, in_features): super().__init__() self.linear_1 = torch.nn.Linear(in_features, 4 * in_features, bias=False) @@ -34,7 +34,7 @@ def forward(self, x): def check_auto_parallel_with_gemini(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = MLP(4).half().cuda() if rank in [0, 1]: input = torch.arange(0, 16).reshape(4, 4).half().cuda() @@ -51,29 +51,29 @@ def check_auto_parallel_with_gemini(rank, world_size, port): # [[0, 1] # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - meta_args = {'x': torch.rand(4, 4).half().to('meta')} - gm, solution = initialize_model(model, - meta_args=meta_args, - device_mesh=device_mesh, - return_solution=True, - solver_preference='tp', - shard_option='shard_last_axis') + meta_args = {"x": torch.rand(4, 4).half().to("meta")} + gm, solution = initialize_model( + model, + meta_args=meta_args, + device_mesh=device_mesh, + return_solution=True, + solver_preference="tp", + shard_option="shard_last_axis", + ) if rank == 0: - msg = '| TP strategy combination chosen by auto-parallel solver |' + msg = "| TP strategy combination chosen by auto-parallel solver |" msg_length = len(msg) - print('=' * msg_length) + print("=" * msg_length) print(msg) - print('=' * msg_length) + print("=" * msg_length) for strategy in solution: print(strategy) - print('=' * msg_length) + print("=" * msg_length) - gemini_config = dict(strict_ddp_mode=False, - device=get_current_device(), - placement_policy='cpu', - pin_memory=True, - search_range_m=128) + gemini_config = dict( + strict_ddp_mode=False, device=get_current_device(), placement_policy="cpu", pin_memory=True, search_range_m=128 + ) gm = zero_model_wrapper(gm, zero_stage=3, gemini_config=gemini_config) optimizer = HybridAdam(gm.parameters(), betas=(0, 0)) @@ -83,28 +83,28 @@ def check_auto_parallel_with_gemini(rank, world_size, port): assert_close(output, output_compare.narrow(0, 0, 4)) else: assert_close(output, output_compare.narrow(0, 4, 4)) - print(f'output on rank{rank} is correct') + print(f"output on rank{rank} is correct") loss = output.sum() optimizer.zero_grad() optimizer.backward(loss) optimizer.step() if rank in (0, 2): - assert_close(list(optimizer.optim.state.values())[0]['exp_avg'].half(), grad_compare.narrow(0, 0, 8).flatten()) + assert_close(list(optimizer.optim.state.values())[0]["exp_avg"].half(), grad_compare.narrow(0, 0, 8).flatten()) if rank in (1, 3): - assert_close(list(optimizer.optim.state.values())[0]['exp_avg'].half(), grad_compare.narrow(0, 8, 8).flatten()) + assert_close(list(optimizer.optim.state.values())[0]["exp_avg"].half(), grad_compare.narrow(0, 8, 8).flatten()) - print(f'gradient on rank{rank} is correct') + print(f"gradient on rank{rank} is correct") 
-@run_on_environment_flag(name='AUTO_PARALLEL') -@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found') +@run_on_environment_flag(name="AUTO_PARALLEL") +@pytest.mark.skipif(NO_CODEGEN, reason="No codegen found") @pytest.mark.dist @rerun_if_address_is_in_use() def test_auto_parallel_with_gemini(): spawn(check_auto_parallel_with_gemini, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_auto_parallel_with_gemini() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py b/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py index a0b407b240e1..a0276acc4293 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py @@ -5,8 +5,8 @@ from torch.fx import GraphModule from transformers.pytorch_utils import Conv1D -from colossalai._analyzer.fx.graph_module import ColoGraphModule from colossalai._analyzer.fx.passes import shape_prop_pass + # from colossalai.fx.tracer.tracer import ColoTracer from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks @@ -19,7 +19,6 @@ class RepeatBlock(nn.Module): - def __init__(self, intermediate_size, hidden_size): super().__init__() self.c_fc = Conv1D(intermediate_size, hidden_size) @@ -35,13 +34,11 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl class RepeatModel(nn.Module): - def __init__(self, intermediate_size, hidden_size, num_layers): super().__init__() self.blocks = nn.ModuleList([RepeatBlock(intermediate_size, hidden_size) for i in range(num_layers)]) def forward(self, x): - for block in self.blocks: x = block(x) @@ -49,10 +46,9 @@ def forward(self, x): class NonRepeatBlock(nn.Module): - def __init__(self, intermediate_size, hidden_size, layer_index): super().__init__() - intermediate_size //= (layer_index + 1) + intermediate_size //= layer_index + 1 self.c_fc = Conv1D(intermediate_size, hidden_size) self.c_proj = Conv1D(hidden_size, intermediate_size) self.act = torch.nn.ReLU() @@ -66,28 +62,25 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl class NonRepeatModel(nn.Module): - def __init__(self, intermediate_size, hidden_size, num_layers): super().__init__() self.blocks = nn.ModuleList([NonRepeatBlock(intermediate_size, hidden_size, i) for i in range(num_layers)]) def forward(self, x): - for block in self.blocks: x = block(x) return x -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @clear_cache_before_run() -@parameterize('model_cls', [RepeatModel, NonRepeatModel]) +@parameterize("model_cls", [RepeatModel, NonRepeatModel]) def test_repeat_blocks(model_cls): - model = model_cls(4 * HIDDEN_DIM, HIDDEN_DIM, NUM_REPEAT_BLOCKS) tracer = ColoTracer(bias_addition_split=True) - input_sample = {'x': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta')} + input_sample = {"x": torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to("meta")} graph = tracer.trace(root=model, meta_args=input_sample) gm = GraphModule(model, graph, model.__class__.__name__) @@ -110,5 +103,5 @@ def test_repeat_blocks(model_cls): assert len(common_blocks) == 0 -if __name__ == '__main__': +if __name__ == "__main__": test_repeat_blocks() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/gpt_modules.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/gpt_modules.py index 22a2371311f9..3bb7cc409938 100644 --- 
a/tests/test_auto_parallel/test_tensor_shard/test_gpt/gpt_modules.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/gpt_modules.py @@ -8,7 +8,6 @@ class GPT2MLP(nn.Module): - def __init__(self, intermediate_size, config): super().__init__() embed_dim = config.hidden_size @@ -34,15 +33,15 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl # 2. The order of split and view op has been changed in the customized GPT2Attention class, the new # order is same as megatron-lm gpt model. class GPT2Attention(nn.Module): - def __init__(self, config, layer_idx=None): super().__init__() max_positions = config.max_position_embeddings self.register_buffer( "bias", - torch.tril(torch.ones((max_positions, max_positions), - dtype=torch.uint8)).view(1, 1, max_positions, max_positions), + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( + 1, 1, max_positions, max_positions + ), ) self.register_buffer("masked_bias", torch.tensor(-1e4)) @@ -68,7 +67,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): attn_weights = torch.matmul(query, key.transpose(-1, -2)) if self.scale_attn_weights: - attn_weights = attn_weights / (value.size(-1)**0.5) + attn_weights = attn_weights / (value.size(-1) ** 0.5) # Layer-wise attention scaling if self.scale_attn_by_inverse_layer_idx: @@ -76,7 +75,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): # if only "normal" attention layer implements causal mask query_length, key_length = query.size(-2), key.size(-2) - causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].to(torch.bool) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool) attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) if attention_mask is not None: @@ -100,7 +99,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): def _split_heads(self, tensor, num_heads, attn_head_size): new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) tensor = tensor.view(new_shape) - return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) def _merge_heads(self, tensor, num_heads, attn_head_size): tensor = tensor.permute(0, 2, 1, 3).contiguous() @@ -113,7 +112,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: - # query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) qkv = self.c_attn(hidden_states) @@ -121,7 +119,7 @@ def forward( # key = self._split_heads(key, self.num_heads, self.head_dim) # value = self._split_heads(value, self.num_heads, self.head_dim) query, key, value = self._split_heads(qkv, self.num_heads, 3 * self.head_dim).split(self.head_dim, dim=3) - present = (key, value) + (key, value) attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) @@ -131,7 +129,6 @@ def forward( class GPT2Block(nn.Module): - def __init__(self, config, layer_idx=None): super().__init__() hidden_size = config.hidden_size @@ -205,11 +202,9 @@ def forward( # GPT2Attention mask. 
attention_mask = attention_mask.view(batch_size, -1) attention_mask = attention_mask[:, None, None, :] - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility attention_mask = (1.0 - attention_mask) * -10000.0 - encoder_attention_mask = None - # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N @@ -267,7 +262,6 @@ def forward( class GPTLMLoss(nn.Module): - def __init__(self): super().__init__() self.loss_fn = nn.CrossEntropyLoss() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py index 48d2672c6571..24968e670e3f 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py @@ -9,6 +9,7 @@ from torch.fx import GraphModule from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass + # from colossalai.fx.tracer.tracer import ColoTracer from colossalai._analyzer.fx.tracer.tracer import ColoTracer @@ -19,6 +20,7 @@ solve_solution, transform_to_sharded_model, ) + NO_CODEGEN = False except: NO_CODEGEN = True @@ -45,14 +47,17 @@ torch.backends.cudnn.benchmark = False -def _check_module_grad(module: torch.nn.Module, origin_param_dict: Dict[str, torch.Tensor], - best_sharding_spec_dict: Dict[str, ShardingSpec]): +def _check_module_grad( + module: torch.nn.Module, + origin_param_dict: Dict[str, torch.Tensor], + best_sharding_spec_dict: Dict[str, ShardingSpec], +): for name, param in module.named_parameters(): param_grad = param.grad - name = name.replace('module.', '') + name = name.replace("module.", "") origin_param_grad = origin_param_dict[name].grad - atoms = name.split('.') - new_name = '_'.join(atoms) + atoms = name.split(".") + new_name = "_".join(atoms) if new_name in best_sharding_spec_dict: param_sharding_spec = best_sharding_spec_dict[new_name] grad_to_compare = copy.deepcopy(param_grad) @@ -63,19 +68,19 @@ def _check_module_grad(module: torch.nn.Module, origin_param_dict: Dict[str, tor difference = param_grad_global - origin_param_grad avg_diff = difference.abs().sum() / difference.numel() assert avg_diff < 0.001 - print(f'{name} param has {avg_diff} average difference') + print(f"{name} param has {avg_diff} average difference") def check_attention_layer(rank, model_cls, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=16, n_embd=HIDDEN_DIM) if model_cls == GPT2MLP: - model = model_cls(intermediate_size=4 * config.hidden_size, config=config).to('cuda') + model = model_cls(intermediate_size=4 * config.hidden_size, config=config).to("cuda") else: - model = model_cls(config=config).to('cuda') + model = model_cls(config=config).to("cuda") test_model = copy.deepcopy(model) input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) @@ -84,30 +89,30 @@ def check_attention_layer(rank, model_cls, world_size, port): hidden_states = torch.rand((BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM), dtype=torch.float32) if model_cls == GPT2MLP: - input_sample = (hidden_states.to('cuda'),) + input_sample = 
(hidden_states.to("cuda"),) test_input_sample = copy.deepcopy(input_sample) meta_input_sample = { - 'hidden_states': hidden_states.to('meta'), + "hidden_states": hidden_states.to("meta"), } elif model_cls in (GPT2Attention, GPT2Block): input_sample = ( - hidden_states.to('cuda'), - attention_mask.to('cuda'), + hidden_states.to("cuda"), + attention_mask.to("cuda"), ) test_input_sample = copy.deepcopy(input_sample) meta_input_sample = { - 'hidden_states': hidden_states.to('meta'), - 'attention_mask': attention_mask.to('meta'), + "hidden_states": hidden_states.to("meta"), + "attention_mask": attention_mask.to("meta"), } else: input_sample = ( - input_ids.to('cuda'), - attention_mask.to('cuda'), + input_ids.to("cuda"), + attention_mask.to("cuda"), ) test_input_sample = copy.deepcopy(input_sample) meta_input_sample = { - 'input_ids': input_ids.to('meta'), - 'attention_mask': attention_mask.to('meta'), + "input_ids": input_ids.to("meta"), + "attention_mask": attention_mask.to("meta"), } physical_mesh_id = torch.arange(0, 4) @@ -122,10 +127,11 @@ def check_attention_layer(rank, model_cls, world_size, port): shape_prop_pass(gm, *meta_input_sample.values()) gm.recompile() - strategies_constructor = build_strategy_constructor(graph, device_mesh, 'standard', 'replicated', 'standard') + strategies_constructor = build_strategy_constructor(graph, device_mesh, "standard", "replicated", "standard") solution = solve_solution(gm, strategies_constructor, memory_budget=-1) - gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_input_sample, solution, device_mesh, - strategies_constructor) + gm, sharding_spec_dicts = transform_to_sharded_model( + gm, meta_input_sample, solution, device_mesh, strategies_constructor + ) gm = ModuleWrapper(gm, *sharding_spec_dicts) nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] @@ -141,7 +147,7 @@ def check_attention_layer(rank, model_cls, world_size, port): output = gm(*input_sample) assert_close(output, origin_output, rtol=1e-03, atol=1e-03) - #*******************backward starting******************* + # *******************backward starting******************* cuda_rng_state = torch.cuda.get_rng_state() cpu_rng_state = torch.get_rng_state() output.sum().backward() @@ -158,9 +164,9 @@ def check_attention_layer(rank, model_cls, world_size, port): if rank == 0: print("*******************backward finished*******************") - #*******************backward finished******************* + # *******************backward finished******************* - #*******************strategy selected******************* + # *******************strategy selected******************* if rank == 0: print("*******************strategy selected*******************") nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] @@ -176,19 +182,19 @@ def check_attention_layer(rank, model_cls, world_size, port): node_memory_cost = node_memory_cost[0] memory_cost += node_memory_cost.activation + node_memory_cost.parameter - print(f'computation cost is {computation_cost}') - print(f'communication cost is {communication_cost}') - print(f'memory cost is {memory_cost}') + print(f"computation cost is {computation_cost}") + print(f"communication cost is {communication_cost}") + print(f"memory cost is {memory_cost}") -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.skipif(NO_CODEGEN, reason="no codegen module") @pytest.mark.dist -@parameterize('model_cls', [GPT2MLP, 
GPT2Block, GPT2Attention, GPT2Model]) +@parameterize("model_cls", [GPT2MLP, GPT2Block, GPT2Attention, GPT2Model]) @rerun_if_address_is_in_use() def test_mlp_layer(model_cls): spawn(check_attention_layer, 4, model_cls=model_cls) -if __name__ == '__main__': +if __name__ == "__main__": test_mlp_layer() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py index 5a8c3c4bf5a0..b61cbe170820 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py @@ -4,7 +4,6 @@ from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass from colossalai._analyzer.fx.tracer.tracer import ColoTracer -from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP from colossalai.auto_parallel.tensor_shard.options import SolverOptions from colossalai.auto_parallel.tensor_shard.solver import CostGraph, Solver, StrategiesConstructor from colossalai.device.device_mesh import DeviceMesh @@ -18,9 +17,9 @@ HIDDEN_DIM = 384 -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @clear_cache_before_run() -@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model]) +@parameterize("model_cls", [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model]) def test_self_attention_block(model_cls): config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=16, n_embd=HIDDEN_DIM) if model_cls == GPT2MLP: @@ -32,23 +31,23 @@ def test_self_attention_block(model_cls): # [[0, 1] # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - shape_consistency_manager = ShapeConsistencyManager() + ShapeConsistencyManager() tracer = ColoTracer(bias_addition_split=True) if model_cls == GPT2MLP: input_sample = { - 'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'), + "hidden_states": torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to("meta"), } elif model_cls in (GPT2Attention, GPT2Block): input_sample = { - 'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'), - 'attention_mask': torch.rand(1, SEQ_LENGTH).to('meta'), + "hidden_states": torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to("meta"), + "attention_mask": torch.rand(1, SEQ_LENGTH).to("meta"), } else: input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) - input_sample = {k: v.to('meta') for k, v in kwargs.items()} + input_sample = {k: v.to("meta") for k, v in kwargs.items()} graph = tracer.trace(root=model, meta_args=input_sample) @@ -63,7 +62,7 @@ def test_self_attention_block(model_cls): cost_graph = CostGraph(strategies_constructor.leaf_strategies) cost_graph.simplify_graph() solver = Solver(gm.graph, strategies_constructor, cost_graph, memory_budget=-1) - ret = solver.call_solver_serialized_args() + solver.call_solver_serialized_args() strategies_list = solver.last_s_val nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] @@ -79,10 +78,10 @@ def test_self_attention_block(model_cls): node_memory_cost = node_memory_cost[0] memory_cost += node_memory_cost.activation + node_memory_cost.parameter - print(f'computation cost is {computation_cost}') - print(f'communication cost is {communication_cost}') - 
print(f'memory cost is {memory_cost}') + print(f"computation cost is {computation_cost}") + print(f"communication cost is {communication_cost}") + print(f"memory cost is {memory_cost}") -if __name__ == '__main__': +if __name__ == "__main__": test_self_attention_block() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py b/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py index d10b222c060d..4dd04c69c8a5 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py @@ -11,7 +11,6 @@ class LinearModel(nn.Module): - def __init__(self): super().__init__() self.linear1 = nn.Linear(4, 4) @@ -27,12 +26,12 @@ def forward(self, x1, x2): return out -@pytest.mark.skip('meta tensor has some bugs in 1.11') +@pytest.mark.skip("meta tensor has some bugs in 1.11") @clear_cache_before_run() def test_liveness_analysis(): model = LinearModel() tracer = ColoTracer(bias_addition_split=True) - meta_args = {'x1': torch.rand(4, 4, device='meta'), 'x2': torch.rand(4, 4, device='meta')} + meta_args = {"x1": torch.rand(4, 4, device="meta"), "x2": torch.rand(4, 4, device="meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(root=model, graph=graph, class_name=model.__class__.__name__) shape_prop_pass(gm, *meta_args.values()) @@ -46,8 +45,8 @@ def test_liveness_analysis(): # a variable named `relu` must exist # and this live var must have inplace = True - assert liveness_list[0].all_live_vars.exists('relu') - relu_var = liveness_list[0].all_live_vars.get('relu') + assert liveness_list[0].all_live_vars.exists("relu") + relu_var = liveness_list[0].all_live_vars.get("relu") assert relu_var.is_inplace # the unique vars must be fewer than the all vars since in-place ops exist @@ -56,5 +55,5 @@ def test_liveness_analysis(): assert len(unique_live_vars) + 1 == len(all_live_vars) -if __name__ == '__main__': +if __name__ == "__main__": test_liveness_analysis() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py index e0a2133e654e..8831a208cb2f 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py @@ -7,14 +7,17 @@ from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results -@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="need pytorch 1.12.0 or higher for aten level operations") @clear_cache_before_run() -@parameterize('func', [ - torch.nn.functional.softmax, - torch.nn.functional.relu, - torch.tanh, - torch.nn.functional.dropout, -]) +@parameterize( + "func", + [ + torch.nn.functional.softmax, + torch.nn.functional.relu, + torch.tanh, + torch.nn.functional.dropout, + ], +) def test_activation_meta_info(func): meta_func = meta_register.get(func) # construct meta tensors @@ -23,13 +26,13 @@ def test_activation_meta_info(func): softmax_dim = 0 # construct operation data - input_data = OperationData(name='input', type=OperationDataType.ARG, data=input_tensor) - output_data = OperationData(name='output', type=OperationDataType.OUTPUT, data=output_tensor) - softmax_dim_data = OperationData(name='softmax_dim', type=OperationDataType.ARG, 
data=softmax_dim) + input_data = OperationData(name="input", type=OperationDataType.ARG, data=input_tensor) + output_data = OperationData(name="output", type=OperationDataType.OUTPUT, data=output_tensor) + softmax_dim_data = OperationData(name="softmax_dim", type=OperationDataType.ARG, data=softmax_dim) # construct args and kwargs args = [input_data, softmax_dim_data, output_data] - kwargs = {'inplace': False} + kwargs = {"inplace": False} # estimated results compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs) @@ -54,9 +57,17 @@ def test_activation_meta_info(func): bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0 bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0 - print_results([input_real_tensor], [output_real_tensor], compute_cost, memory_cost, fwd_allocated, fwd_peak, - bwd_allocated, bwd_peak) + print_results( + [input_real_tensor], + [output_real_tensor], + compute_cost, + memory_cost, + fwd_allocated, + fwd_peak, + bwd_allocated, + bwd_peak, + ) -if __name__ == '__main__': +if __name__ == "__main__": test_activation_meta_info() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py index 68ccc7835bc3..ba9e282144b7 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py @@ -3,7 +3,6 @@ import torch.nn as nn from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.testing.pytest_wrapper import run_on_environment_flag @@ -12,7 +11,6 @@ class BinaryElementwiseOpModule(nn.Module): - def __init__(self, token=torch.add, shape=64) -> None: super().__init__() self.token = token @@ -33,7 +31,7 @@ def _binary_elementwise_mem_test(rank, world_size, port): port: port for initializing process group """ disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = BinaryElementwiseOpModule(token=torch.add, shape=1024).cuda() input = torch.rand(32, 1024).cuda() input.requires_grad = True @@ -45,21 +43,23 @@ def _binary_elementwise_mem_test(rank, world_size, port): node_index = 2 # total number of target node strategies strategy_number = 9 - mem_test_for_node_strategy(rank=rank, - model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input], - meta_arg_names=['input']) + mem_test_for_node_strategy( + rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=["input"], + ) -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_binary_elementwise_meta_concrete_info_match(): spawn(_binary_elementwise_mem_test, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_binary_elementwise_meta_concrete_info_match() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py 
index c6f7b88f44a5..45558154547f 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py @@ -11,7 +11,6 @@ class ConvFunctionModule(nn.Module): - def __init__(self, in_channels=4, out_channels=64, kernel_size=3): super().__init__() self.conv_weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size)) @@ -32,7 +31,7 @@ def _conv_module_mem_test(rank, world_size, port, bias): port: port for initializing process group """ disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.Conv2d(4, 64, 3, padding=1, bias=bias)).cuda() input = torch.rand(4, 4, 64, 64).cuda() input.requires_grad = True @@ -44,16 +43,18 @@ def _conv_module_mem_test(rank, world_size, port, bias): node_index = 1 # total number of target node strategies strategy_number = 16 - mem_test_for_node_strategy(rank=rank, - model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input], - meta_arg_names=['input']) - - -@run_on_environment_flag(name='AUTO_PARALLEL') + mem_test_for_node_strategy( + rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=["input"], + ) + + +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_conv_meta_concrete_info_match(bias=False): @@ -71,7 +72,7 @@ def _conv_function_mem_test(rank, world_size, port): port: port for initializing process group """ disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = ConvFunctionModule().cuda() input = torch.rand(4, 4, 64, 64).cuda() input.requires_grad = True @@ -83,22 +84,24 @@ def _conv_function_mem_test(rank, world_size, port): node_index = 2 # total number of target node strategies strategy_number = 16 - mem_test_for_node_strategy(rank=rank, - model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input], - meta_arg_names=['input']) - - -@run_on_environment_flag(name='AUTO_PARALLEL') + mem_test_for_node_strategy( + rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=["input"], + ) + + +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_conv_function_concrete_info_match(): spawn(_conv_function_mem_test, 4) -if __name__ == '__main__': +if __name__ == "__main__": # test_conv_meta_concrete_info_match() test_conv_function_concrete_info_match() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py index e3f76a95c4a5..5d830d769c2d 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py @@ -5,11 +5,11 @@ from colossalai.testing.utils import clear_cache_before_run from 
tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results -if torch.__version__ >= '1.12.0': +if torch.__version__ >= "1.12.0": from colossalai.auto_parallel.meta_profiler import meta_register -@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="need pytorch 1.12.0 or higher for aten level operations") @clear_cache_before_run() def test_embedding_meta_info(): meta_func = meta_register.get(torch.nn.Embedding) @@ -28,7 +28,7 @@ def test_embedding_meta_info(): # construct args and kwargs args = [input_data, weight_data, output_data] - kwargs = {'inplace': False} + kwargs = {"inplace": False} # estimated results compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs) @@ -52,9 +52,17 @@ def test_embedding_meta_info(): bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0 bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0 - print_results([input_real_tensor], [output_real_tensor], compute_cost, memory_cost, fwd_allocated, fwd_peak, - bwd_allocated, bwd_peak) + print_results( + [input_real_tensor], + [output_real_tensor], + compute_cost, + memory_cost, + fwd_allocated, + fwd_peak, + bwd_allocated, + bwd_peak, + ) -if __name__ == '__main__': +if __name__ == "__main__": test_embedding_meta_info() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py index fb3ded339ddf..639870c89a82 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py @@ -11,7 +11,6 @@ class MyModule(nn.Module): - def __init__(self, in_features=64, out_features=128): super().__init__() self.fc_weight = nn.Parameter(torch.randn(out_features, in_features)) @@ -31,7 +30,7 @@ def _linear_module_mem_test(rank, world_size, port): port: port for initializing process group """ disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.Linear(64, 128, bias=False)).cuda() input = torch.rand(8, 8, 16, 64).cuda() input.requires_grad = True @@ -40,16 +39,18 @@ def _linear_module_mem_test(rank, world_size, port): device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) # memory test - mem_test_for_node_strategy(rank=rank, - model=model, - device_mesh=device_mesh, - node_index=1, - strategy_number=13, - input_args=[input], - meta_arg_names=["input"]) - - -@run_on_environment_flag(name='AUTO_PARALLEL') + mem_test_for_node_strategy( + rank=rank, + model=model, + device_mesh=device_mesh, + node_index=1, + strategy_number=13, + input_args=[input], + meta_arg_names=["input"], + ) + + +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_linear_module_meta_concrete_info_match(): @@ -67,7 +68,7 @@ def _linear_function_mem_test(rank, world_size, port): port: port for initializing process group """ disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = MyModule().cuda() input = torch.rand(8, 
8, 16, 64).cuda() input.requires_grad = True @@ -76,22 +77,24 @@ def _linear_function_mem_test(rank, world_size, port): device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) # memory test - mem_test_for_node_strategy(rank=rank, - model=model, - device_mesh=device_mesh, - node_index=2, - strategy_number=24, - input_args=[input], - meta_arg_names=["input"]) - - -@run_on_environment_flag(name='AUTO_PARALLEL') + mem_test_for_node_strategy( + rank=rank, + model=model, + device_mesh=device_mesh, + node_index=2, + strategy_number=24, + input_args=[input], + meta_arg_names=["input"], + ) + + +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_linear_function_meta_concrete_info_match(): spawn(_linear_function_mem_test, 4) -if __name__ == '__main__': +if __name__ == "__main__": # test_linear_module_meta_concrete_info_match() test_linear_function_meta_concrete_info_match() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py index 2d2d77f0c637..b182dd02ca76 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py @@ -5,26 +5,27 @@ from colossalai.testing.utils import clear_cache_before_run, parameterize from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results -if torch.__version__ >= '1.12.0': - from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register +if torch.__version__ >= "1.12.0": + from colossalai.auto_parallel.meta_profiler import meta_register -@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="need pytorch 1.12.0 or higher for aten level operations") @clear_cache_before_run() @parameterize( - 'tensor_shapes', + "tensor_shapes", [ - [[128], [128]], # dot product - [[64, 128], [128]], # mat-vec - [[128], [128, 64]], # vec-mat - [[64, 64, 128], [128]], # batched mat-vec - [[128], [64, 128, 64]], # vec-batched mat - [[64, 128], [128, 192]], # mat-mat - [[64, 64, 128], [128, 192]], # batched mat-mat - [[64, 128], [64, 128, 192]], # mat-batched mat - [[64, 64, 128], [64, 128, 192]], # batched mat-batched mat (matched batch dims) - [[64, 1, 64, 128], [64, 128, 192]], # batched mat-batched mat (unmatched batch dims) - ]) + [[128], [128]], # dot product + [[64, 128], [128]], # mat-vec + [[128], [128, 64]], # vec-mat + [[64, 64, 128], [128]], # batched mat-vec + [[128], [64, 128, 64]], # vec-batched mat + [[64, 128], [128, 192]], # mat-mat + [[64, 64, 128], [128, 192]], # batched mat-mat + [[64, 128], [64, 128, 192]], # mat-batched mat + [[64, 64, 128], [64, 128, 192]], # batched mat-batched mat (matched batch dims) + [[64, 1, 64, 128], [64, 128, 192]], # batched mat-batched mat (unmatched batch dims) + ], +) def test_matmul_function_meta_info(tensor_shapes): meta_func = meta_register.get(torch.matmul) @@ -55,7 +56,7 @@ def test_matmul_function_meta_info(tensor_shapes): # construct args and kwargs args = [input_data, other_data, output_data] - kwargs = {'inplace': False} + kwargs = {"inplace": False} # estimated results compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs) @@ -85,9 +86,17 @@ def test_matmul_function_meta_info(tensor_shapes): compute_cost: TrainCycleItem 
memory_cost: TrainCycleItem - print_results([input_real_tensor, other_real_tensor], [output_real_tensor], compute_cost, memory_cost, - fwd_allocated, fwd_peak, bwd_allocated, bwd_peak) + print_results( + [input_real_tensor, other_real_tensor], + [output_real_tensor], + compute_cost, + memory_cost, + fwd_allocated, + fwd_peak, + bwd_allocated, + bwd_peak, + ) -if __name__ == '__main__': +if __name__ == "__main__": test_matmul_function_meta_info() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py index 808172977b60..ed809a758dfd 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py @@ -10,7 +10,7 @@ from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use, spawn from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results -if torch.__version__ >= '1.12.0': +if torch.__version__ >= "1.12.0": from colossalai.auto_parallel.meta_profiler import meta_register @@ -25,7 +25,7 @@ def _batchnorm_module_mem_test(rank, world_size, port): port: port for initializing process group """ disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.BatchNorm2d(128)).cuda() input = torch.rand(4, 128, 64, 64).cuda() input.requires_grad = True @@ -37,27 +37,32 @@ def _batchnorm_module_mem_test(rank, world_size, port): node_index = 1 # total number of target node strategies strategy_number = 9 - mem_test_for_node_strategy(rank=rank, - model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input], - meta_arg_names=['input']) - - -@run_on_environment_flag(name='AUTO_PARALLEL') + mem_test_for_node_strategy( + rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=["input"], + ) + + +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_batchnorm_meta_concrete_info_match(): spawn(_batchnorm_module_mem_test, 4) -@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='need pytorch 1.12.0 or higher for aten level operations') -@parameterize('tensor_shape', [ - [256, 1024], - [1024, 256], -]) +@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="need pytorch 1.12.0 or higher for aten level operations") +@parameterize( + "tensor_shape", + [ + [256, 1024], + [1024, 256], + ], +) def test_layernorm_meta_info(tensor_shape): meta_func = meta_register.get(torch.nn.LayerNorm) @@ -78,7 +83,7 @@ def test_layernorm_meta_info(tensor_shape): # construct args and kwargs args = [input_data, output_data, weight_data, bias_data] - kwargs = {'inplace': False} + kwargs = {"inplace": False} # estimated results compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs) @@ -108,10 +113,18 @@ def test_layernorm_meta_info(tensor_shape): compute_cost: TrainCycleItem memory_cost: TrainCycleItem - print_results([input_real_tensor], [output_real_tensor], compute_cost, memory_cost, fwd_allocated, fwd_peak, - bwd_allocated, bwd_peak) + print_results( + [input_real_tensor], + [output_real_tensor], + compute_cost, 
+ memory_cost, + fwd_allocated, + fwd_peak, + bwd_allocated, + bwd_peak, + ) -if __name__ == '__main__': +if __name__ == "__main__": test_batchnorm_meta_concrete_info_match() test_layernorm_meta_info() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py index 4cddf4e19fca..bd1deb40ca7b 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py @@ -21,7 +21,7 @@ def _adaptiveavgpool_module_mem_test(rank, world_size, port): port: port for initializing process group """ disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.AdaptiveAvgPool2d((16, 16))).cuda() input = torch.rand(4, 128, 64, 64).cuda() input.requires_grad = True @@ -33,16 +33,18 @@ def _adaptiveavgpool_module_mem_test(rank, world_size, port): node_index = 1 # total number of target strategies strategy_number = 1 - mem_test_for_node_strategy(rank=rank, - model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input], - meta_arg_names=['input']) + mem_test_for_node_strategy( + rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=["input"], + ) -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_adaptiveavgpool_meta_concrete_info_match(): @@ -60,7 +62,7 @@ def _maxpool_module_mem_test(rank, world_size, port): port: port for initializing process group """ disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.MaxPool2d((16, 16))).cuda() input = torch.rand(4, 128, 64, 64).cuda() input.requires_grad = True @@ -72,22 +74,24 @@ def _maxpool_module_mem_test(rank, world_size, port): node_index = 1 # total number of target node strategies strategy_number = 9 - mem_test_for_node_strategy(rank=rank, - model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input], - meta_arg_names=['input']) + mem_test_for_node_strategy( + rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=["input"], + ) -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_maxpool_meta_concrete_info_match(): spawn(_maxpool_module_mem_test, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_adaptiveavgpool_meta_concrete_info_match() test_maxpool_meta_concrete_info_match() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py index 6e8145885d67..a29291e9b4d9 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py +++ 
b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py @@ -6,12 +6,11 @@ from colossalai.testing.utils import clear_cache_before_run from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results -if torch.__version__ >= '1.12.0': - from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register +if torch.__version__ >= "1.12.0": + from colossalai.auto_parallel.meta_profiler import meta_register class SplitModule(nn.Module): - def __init__(self) -> None: super().__init__() @@ -19,7 +18,7 @@ def forward(self, x): return x.split(512, dim=0) -@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="need pytorch 1.12.0 or higher for aten level operations") @clear_cache_before_run() def test_tensor_meta_info(): """test tensor related meta information @@ -45,7 +44,7 @@ def test_tensor_meta_info(): logical_shape=input_tensor.shape, ) split_info_data = OperationData( - name='split_info', + name="split_info", type=OperationDataType.ARG, data=0, logical_shape=None, @@ -53,7 +52,7 @@ def test_tensor_meta_info(): # construct args args = [input_data, output_data, split_info_data] - kwargs = {'inplace': False} + kwargs = {"inplace": False} # estimated results compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs) @@ -79,8 +78,16 @@ def test_tensor_meta_info(): bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0 bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0 - print_results([input_real_tensor], output_real_tensor, compute_cost, memory_cost, fwd_allocated, fwd_peak, - bwd_allocated, bwd_peak) + print_results( + [input_real_tensor], + output_real_tensor, + compute_cost, + memory_cost, + fwd_allocated, + fwd_peak, + bwd_allocated, + bwd_peak, + ) if __name__ == "__main__": diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py index b4564312eeb4..64d9ccd3def2 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py @@ -5,11 +5,11 @@ from colossalai.testing.utils import clear_cache_before_run from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results -if torch.__version__ >= '1.12.0': - from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register +if torch.__version__ >= "1.12.0": + from colossalai.auto_parallel.meta_profiler import meta_register -@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="need pytorch 1.12.0 or higher for aten level operations") @clear_cache_before_run() def test_where_meta_info(): meta_func = meta_register.get(torch.where) @@ -49,7 +49,7 @@ def test_where_meta_info(): # construct args and kwargs args = [condition_data, x_data, y_data, output_data] - kwargs = {'inplace': False} + kwargs = {"inplace": False} # estimated results compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs) @@ -81,9 +81,17 @@ def test_where_meta_info(): compute_cost: TrainCycleItem memory_cost: TrainCycleItem - print_results([condition_real_tensor, x_real_tensor, y_real_tensor], [output_real_tensor], compute_cost, - memory_cost, fwd_allocated, 
fwd_peak, bwd_allocated, bwd_peak) + print_results( + [condition_real_tensor, x_real_tensor, y_real_tensor], + [output_real_tensor], + compute_cost, + memory_cost, + fwd_allocated, + fwd_peak, + bwd_allocated, + bwd_peak, + ) -if __name__ == '__main__': +if __name__ == "__main__": test_where_meta_info() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py index 4ca85d34da30..e58d15cec50b 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py @@ -7,6 +7,7 @@ from colossalai._analyzer.fx.graph_module import ColoGraphModule from colossalai._analyzer.fx.passes import shape_prop_pass + # from colossalai.fx.tracer.tracer import ColoTracer from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass @@ -16,29 +17,34 @@ from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor from colossalai.device.device_mesh import DeviceMesh -if torch.__version__ >= '1.12.0': +if torch.__version__ >= "1.12.0": from colossalai.auto_parallel.meta_profiler import ShardMetaInfo -def mem_test_for_node_strategy(rank: int, - model: torch.nn.Module, - device_mesh: DeviceMesh, - node_index: int, - strategy_number: int, - input_args: List[torch.Tensor], - meta_arg_names: List[str], - input_kwargs: Dict[str, torch.Tensor] = {}): +def mem_test_for_node_strategy( + rank: int, + model: torch.nn.Module, + device_mesh: DeviceMesh, + node_index: int, + strategy_number: int, + input_args: List[torch.Tensor], + meta_arg_names: List[str], + input_kwargs: Dict[str, torch.Tensor] = {}, +): for strategy_index in range(strategy_number): # We need to copy the model to avoid do backward more than once in same graph - model_to_shard, args_to_shard, kwargs_to_shard = copy.deepcopy(model), copy.deepcopy(input_args), copy.deepcopy( - input_kwargs) + model_to_shard, args_to_shard, kwargs_to_shard = ( + copy.deepcopy(model), + copy.deepcopy(input_args), + copy.deepcopy(input_kwargs), + ) tracer = ColoTracer(bias_addition_split=True) input_sample = {} for input_arg, meta_arg_name in zip(input_args, meta_arg_names): - input_sample[meta_arg_name] = torch.rand(input_arg.shape).to('meta') + input_sample[meta_arg_name] = torch.rand(input_arg.shape).to("meta") for meta_kwarg_name, input_kwarg in input_kwargs.items(): - input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta') + input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to("meta") graph = tracer.trace(root=model_to_shard, meta_args=input_sample) gm = ColoGraphModule(model_to_shard, graph, model_to_shard.__class__.__name__) shape_prop_pass(gm, *input_sample.values()) @@ -57,13 +63,18 @@ def mem_test_for_node_strategy(rank: int, # construct the strategy for the output node placeholder_strategy = list(graph.nodes)[-1].strategies_vector[0] - output_key = next(key for key in target_node.strategies_vector[strategy_index].sharding_specs.keys() - if key.type == OperationDataType.OUTPUT) + output_key = next( + key + for key in target_node.strategies_vector[strategy_index].sharding_specs.keys() + if key.type == OperationDataType.OUTPUT + ) placeholder_strategy.sharding_specs[output_key] = target_node.strategies_vector[strategy_index].sharding_specs[ - output_key] + output_key + ] gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass( - gm, solution, 
device_mesh, strategies_constructor) + gm, solution, device_mesh, strategies_constructor + ) gm = runtime_apply_pass(gm) gm.recompile() gm: GraphModule @@ -76,22 +87,26 @@ def mem_test_for_node_strategy(rank: int, # warmup with torch.no_grad(): - output = gm(*args_to_shard, - sharding_spec_convert_dict=sharding_spec_dict, - origin_node_sharding_spec_dict=origin_spec_dict, - comm_actions_dict=comm_actions_dict, - **kwargs_to_shard) + output = gm( + *args_to_shard, + sharding_spec_convert_dict=sharding_spec_dict, + origin_node_sharding_spec_dict=origin_spec_dict, + comm_actions_dict=comm_actions_dict, + **kwargs_to_shard, + ) del output # forward memory compare if rank == 0: torch.cuda.reset_peak_memory_stats() mem_stamp0 = torch.cuda.memory_allocated() - output = gm(*args_to_shard, - sharding_spec_convert_dict=sharding_spec_dict, - origin_node_sharding_spec_dict=origin_spec_dict, - comm_actions_dict=comm_actions_dict, - **kwargs_to_shard) + output = gm( + *args_to_shard, + sharding_spec_convert_dict=sharding_spec_dict, + origin_node_sharding_spec_dict=origin_spec_dict, + comm_actions_dict=comm_actions_dict, + **kwargs_to_shard, + ) if rank == 0: # print forward memory allocated and peak memory stats in kb @@ -113,8 +128,10 @@ def mem_test_for_node_strategy(rank: int, # estimated memory if target_node.op == "call_module": - metainfo = ShardMetaInfo(target_node.strategies_vector[strategy_index], - target_node.graph.owning_module.get_submodule(target_node.target)) + metainfo = ShardMetaInfo( + target_node.strategies_vector[strategy_index], + target_node.graph.owning_module.get_submodule(target_node.target), + ) else: metainfo = ShardMetaInfo(target_node.strategies_vector[strategy_index], target_node.target) @@ -134,8 +151,16 @@ def mem_test_for_node_strategy(rank: int, print("=======================") -def print_results(input: List[torch.Tensor], output: List[torch.Tensor], compute_cost: TrainCycleItem, - memory_cost: TrainCycleItem, fwd_allocated, fwd_peak, bwd_allocated, bwd_peak): +def print_results( + input: List[torch.Tensor], + output: List[torch.Tensor], + compute_cost: TrainCycleItem, + memory_cost: TrainCycleItem, + fwd_allocated, + fwd_peak, + bwd_allocated, + bwd_peak, +): """Print the results of the meta information test. 
Args: diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py index 80e6a6c1460c..73a15f3ba4de 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py @@ -13,7 +13,6 @@ class AddBMMTensorMethodModule(nn.Module): - def __init__(self, using_kwargs): super().__init__() self.using_kwargs = using_kwargs @@ -27,7 +26,6 @@ def forward(self, bias, x1, x2): class AddBMMTorchFunctionModule(nn.Module): - def __init__(self, using_kwargs): super().__init__() self.using_kwargs = using_kwargs @@ -42,7 +40,7 @@ def forward(self, bias, x1, x2): def check_2d_device_mesh(rank, world_size, port, module, bias_shape, using_kwargs): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = module(using_kwargs).cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -57,13 +55,15 @@ def check_2d_device_mesh(rank, world_size, port, module, bias_shape, using_kwarg # construct input args input_args = [bias, x1, x2] # construct meta arg names - meta_arg_names = ['bias', 'x1', 'x2'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names) + meta_arg_names = ["bias", "x1", "x2"] + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + ) tracer = ColoTracer() # graph(): # %bias : torch.Tensor [#users=1] = placeholder[target=bias] @@ -73,13 +73,15 @@ def check_2d_device_mesh(rank, world_size, port, module, bias_shape, using_kwarg # %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%bmm, 0), kwargs = {}) # %add : [#users=1] = call_function[target=operator.add](args = (%sum_1, %bias), kwargs = {}) # return add - graph = tracer.trace(model, - meta_args={ - 'bias': torch.rand(*bias_shape).to('meta'), - "x1": torch.rand(4, 8, 16).to('meta'), - 'x2': torch.rand(4, 16, 8).to('meta') - }) - gm = ColoGraphModule(model, graph) + graph = tracer.trace( + model, + meta_args={ + "bias": torch.rand(*bias_shape).to("meta"), + "x1": torch.rand(4, 8, 16).to("meta"), + "x2": torch.rand(4, 16, 8).to("meta"), + }, + ) + ColoGraphModule(model, graph) bmm_mod_node = list(graph.nodes)[3] strategies_vector = StrategiesVector(bmm_mod_node) @@ -96,49 +98,49 @@ def check_2d_device_mesh(rank, world_size, port, module, bias_shape, using_kwarg assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "x1" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 8, 16]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 8, 16]) + assert mapping["input"].name == "x1" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 8, 16]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 8, 16]) - assert mapping['other'].name == "x2" - assert mapping['other'].data.is_meta - assert 
mapping['other'].data.shape == torch.Size([4, 16, 8]) - assert mapping['other'].type == OperationDataType.ARG - assert mapping['other'].logical_shape == torch.Size([4, 16, 8]) + assert mapping["other"].name == "x2" + assert mapping["other"].data.is_meta + assert mapping["other"].data.shape == torch.Size([4, 16, 8]) + assert mapping["other"].type == OperationDataType.ARG + assert mapping["other"].logical_shape == torch.Size([4, 16, 8]) - assert mapping['output'].name == "bmm" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 8, 8]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "bmm" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 8, 8]) + assert mapping["output"].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] for name in strategy_name_list: print(name) # one batch dim - assert 'Sb0 = Sb0 x Sb0' not in strategy_name_list + assert "Sb0 = Sb0 x Sb0" not in strategy_name_list # two batch dim - assert 'Sb01 = Sb01 x Sb01' in strategy_name_list + assert "Sb01 = Sb01 x Sb01" in strategy_name_list # SbSi = SbSi x Sb - assert 'Sb0Si1 = Sb0Si1 x Sb0' in strategy_name_list - assert 'Sb1Si0 = Sb1Si0 x Sb1' in strategy_name_list + assert "Sb0Si1 = Sb0Si1 x Sb0" in strategy_name_list + assert "Sb1Si0 = Sb1Si0 x Sb1" in strategy_name_list # SbSj = SbR x SbSj - assert 'Sb0Sj1 = Sb0R x Sb0Sj1' in strategy_name_list - assert 'Sb1Sj0 = Sb1R x Sb1Sj0' in strategy_name_list + assert "Sb0Sj1 = Sb0R x Sb0Sj1" in strategy_name_list + assert "Sb1Sj0 = Sb1R x Sb1Sj0" in strategy_name_list # SbR = SbSk x SbSk - assert 'Sb0R = Sb0Sk1 x Sb0Sk1' in strategy_name_list - assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list + assert "Sb0R = Sb0Sk1 x Sb0Sk1" in strategy_name_list + assert "Sb1R = Sb1Sk0 x Sb1Sk0" in strategy_name_list for strategy in strategies_vector: - input_sharding_spec = strategy.get_sharding_spec_by_name('x1') - other_sharding_spec = strategy.get_sharding_spec_by_name('x2') - output_sharding_spec = strategy.get_sharding_spec_by_name('bmm') + input_sharding_spec = strategy.get_sharding_spec_by_name("x1") + other_sharding_spec = strategy.get_sharding_spec_by_name("x2") + output_sharding_spec = strategy.get_sharding_spec_by_name("bmm") # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0] @@ -148,7 +150,7 @@ def check_2d_device_mesh(rank, world_size, port, module, bias_shape, using_kwarg def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") physical_mesh_id = torch.arange(0, 4) mesh_shape = (1, 4) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) @@ -163,13 +165,15 @@ def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, por # construct input args input_args = [bias, x1, x2] # construct meta arg names - meta_arg_names = ['bias', 'x1', 'x2'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names) + 
meta_arg_names = ["bias", "x1", "x2"] + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + ) tracer = ColoTracer() # graph(): @@ -180,13 +184,15 @@ def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, por # %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%bmm, 0), kwargs = {}) # %add : [#users=1] = call_function[target=operator.add](args = (%sum_1, %bias), kwargs = {}) # return add - graph = tracer.trace(model, - meta_args={ - 'bias': torch.rand(*bias_shape).to('meta'), - "x1": torch.rand(4, 8, 16).to('meta'), - 'x2': torch.rand(4, 16, 8).to('meta') - }) - gm = ColoGraphModule(model, graph) + graph = tracer.trace( + model, + meta_args={ + "bias": torch.rand(*bias_shape).to("meta"), + "x1": torch.rand(4, 8, 16).to("meta"), + "x2": torch.rand(4, 16, 8).to("meta"), + }, + ) + ColoGraphModule(model, graph) bmm_mod_node = list(graph.nodes)[3] strategies_vector = StrategiesVector(bmm_mod_node) @@ -202,33 +208,33 @@ def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, por assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "x1" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 8, 16]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 8, 16]) + assert mapping["input"].name == "x1" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 8, 16]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 8, 16]) - assert mapping['other'].name == "x2" - assert mapping['other'].data.is_meta - assert mapping['other'].data.shape == torch.Size([4, 16, 8]) - assert mapping['other'].type == OperationDataType.ARG - assert mapping['other'].logical_shape == torch.Size([4, 16, 8]) + assert mapping["other"].name == "x2" + assert mapping["other"].data.is_meta + assert mapping["other"].data.shape == torch.Size([4, 16, 8]) + assert mapping["other"].type == OperationDataType.ARG + assert mapping["other"].logical_shape == torch.Size([4, 16, 8]) - assert mapping['output'].name == "bmm" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 8, 8]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "bmm" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 8, 8]) + assert mapping["output"].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] assert len(strategy_name_list) == 1 # one batch dim - assert 'Sb0 = Sb0 x Sb0' in strategy_name_list + assert "Sb0 = Sb0 x Sb0" in strategy_name_list for strategy in strategies_vector: - input_sharding_spec = strategy.get_sharding_spec_by_name('x1') - other_sharding_spec = strategy.get_sharding_spec_by_name('x2') - output_sharding_spec = strategy.get_sharding_spec_by_name('bmm') + input_sharding_spec = strategy.get_sharding_spec_by_name("x1") + other_sharding_spec = strategy.get_sharding_spec_by_name("x2") + output_sharding_spec = strategy.get_sharding_spec_by_name("bmm") # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[0] == 
output_sharding_spec.sharding_sequence[0] @@ -237,11 +243,11 @@ def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, por @pytest.mark.skip("skip due to bias cases not ready") -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist -@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule]) -@parameterize('bias_shape', [[8], [1, 8], [8, 8]]) -@parameterize('using_kwargs', [True, False]) +@parameterize("module", [AddBMMTorchFunctionModule, AddBMMTensorMethodModule]) +@parameterize("bias_shape", [[8], [1, 8], [8, 8]]) +@parameterize("using_kwargs", [True, False]) @rerun_if_address_is_in_use() def test_2d_device_mesh(module, bias_shape, using_kwargs): spawn( @@ -254,11 +260,11 @@ def test_2d_device_mesh(module, bias_shape, using_kwargs): @pytest.mark.skip("skip due to bias cases not ready") -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist -@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule]) -@parameterize('bias_shape', [[8], [1, 8], [8, 8]]) -@parameterize('using_kwargs', [True, False]) +@parameterize("module", [AddBMMTorchFunctionModule, AddBMMTensorMethodModule]) +@parameterize("bias_shape", [[8], [1, 8], [8, 8]]) +@parameterize("using_kwargs", [True, False]) @rerun_if_address_is_in_use() def test_1d_device_mesh(module, bias_shape, using_kwargs): spawn( @@ -270,6 +276,6 @@ def test_1d_device_mesh(module, bias_shape, using_kwargs): ) -if __name__ == '__main__': +if __name__ == "__main__": test_1d_device_mesh() test_2d_device_mesh() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py index fe6554cd81ee..26f9c4ab1e3c 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py @@ -19,7 +19,6 @@ class AddmmModel(nn.Module): - def __init__(self): super().__init__() @@ -29,7 +28,6 @@ def forward(self, input, m1, m2): class AddmmModel_with_param(nn.Module): - def __init__(self, weight_shape, bias_shape): super().__init__() self.weight = torch.nn.Parameter(torch.rand(weight_shape)) @@ -42,7 +40,7 @@ def forward(self, m1): def check_addmm_function_handler(rank, world_size, port, input_shape, model_cls): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") if model_cls == AddmmModel: model = AddmmModel().cuda() else: @@ -58,10 +56,10 @@ def check_addmm_function_handler(rank, world_size, port, input_shape, model_cls) # construct input args input_args = [input, m1, m2] # construct meta arg names - meta_arg_names = ['input', 'm1', 'm2'] + meta_arg_names = ["input", "m1", "m2"] meta_args_for_tracer = {} for meta_arg, input_arg in zip(meta_arg_names, input_args): - meta_args_for_tracer[meta_arg] = input_arg.to('meta') + meta_args_for_tracer[meta_arg] = input_arg.to("meta") # the index of addmm node in computation graph node_index = 4 @@ -72,22 +70,24 @@ def check_addmm_function_handler(rank, world_size, port, input_shape, model_cls) # construct input args input_args = [m1] # construct meta arg names - meta_arg_names = ['m1'] + meta_arg_names = ["m1"] # the index of addmm node in computation graph 
meta_args_for_tracer = {} for meta_arg, input_arg in zip(meta_arg_names, input_args): - meta_args_for_tracer[meta_arg] = input_arg.to('meta') + meta_args_for_tracer[meta_arg] = input_arg.to("meta") node_index = 4 # strategy number of linear node strategy_number = 14 - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names, - node_type='bias_module') + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + node_type="bias_module", + ) tracer = ColoTracer(bias_addition_split=True) # graph(): @@ -117,60 +117,60 @@ def check_addmm_function_handler(rank, world_size, port, input_shape, model_cls) # check operation data mapping mapping = handler.get_operation_data_mapping() - assert mapping['input'].name == "m1" - assert mapping['input'].data.shape == torch.Size([4, 8]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 8]) + assert mapping["input"].name == "m1" + assert mapping["input"].data.shape == torch.Size([4, 8]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 8]) - assert mapping['other'].name == "transpose" - assert mapping['other'].data.shape == torch.Size([16, 8]) + assert mapping["other"].name == "transpose" + assert mapping["other"].data.shape == torch.Size([16, 8]) if model_cls == AddmmModel: - assert mapping['other'].type == OperationDataType.ARG + assert mapping["other"].type == OperationDataType.ARG else: - assert mapping['other'].type == OperationDataType.PARAM - assert mapping['other'].logical_shape == torch.Size([8, 16]) + assert mapping["other"].type == OperationDataType.PARAM + assert mapping["other"].logical_shape == torch.Size([8, 16]) - assert mapping['output'].name == "linear" - assert mapping['output'].data.shape == torch.Size([4, 16]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "linear" + assert mapping["output"].data.shape == torch.Size([4, 16]) + assert mapping["output"].type == OperationDataType.OUTPUT # SS = SR x RS - assert 'S0S1 = S0R x RS1_0' in strategy_name_list - assert 'S1S0 = S1R x RS0_0' in strategy_name_list + assert "S0S1 = S0R x RS1_0" in strategy_name_list + assert "S1S0 = S1R x RS0_0" in strategy_name_list # SR = SS x SR - assert 'S0R = S0S1 x S1R_0' in strategy_name_list - assert 'S1R = S1S0 x S0R_0' in strategy_name_list + assert "S0R = S0S1 x S1R_0" in strategy_name_list + assert "S1R = S1S0 x S0R_0" in strategy_name_list # RS = RS x SS - assert 'RS0 = RS1 x S1S0' in strategy_name_list - assert 'RS1 = RS0 x S0S1' in strategy_name_list + assert "RS0 = RS1 x S1S0" in strategy_name_list + assert "RS1 = RS0 x S0S1" in strategy_name_list # RR = RS x SR - assert 'RR = RS0 x S0R' in strategy_name_list - assert 'RR = RS1 x S1R' in strategy_name_list + assert "RR = RS0 x S0R" in strategy_name_list + assert "RR = RS1 x S1R" in strategy_name_list # RS= RR x RS - assert 'RS0 = RR x RS0' in strategy_name_list - assert 'RS1 = RR x RS1' in strategy_name_list + assert "RS0 = RR x RS0" in strategy_name_list + assert "RS1 = RR x RS1" in strategy_name_list # S01R = S01R x RR - assert 'S01R = S01R x RR_0' in strategy_name_list + assert "S01R = S01R x RR_0" in strategy_name_list # RR = RS01 x S01R - assert 'RR = RS01 x 
S01R' in strategy_name_list + assert "RR = RS01 x S01R" in strategy_name_list # RS01 = RR x RS01 - assert 'RS01 = RR x RS01' in strategy_name_list + assert "RS01 = RR x RS01" in strategy_name_list # RR = RR x RR - assert 'RR = RR x RR' in strategy_name_list + assert "RR = RR x RR" in strategy_name_list for strategy in strategies_vector: strategy: ShardingStrategy - input_sharding_spec = strategy.get_sharding_spec_by_name('m1') - weight_sharding_spec = strategy.get_sharding_spec_by_name('transpose') - output_sharding_spec = strategy.get_sharding_spec_by_name('linear') + input_sharding_spec = strategy.get_sharding_spec_by_name("m1") + weight_sharding_spec = strategy.get_sharding_spec_by_name("transpose") + output_sharding_spec = strategy.get_sharding_spec_by_name("linear") # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] @@ -178,14 +178,14 @@ def check_addmm_function_handler(rank, world_size, port, input_shape, model_cls) assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[1] -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist -@parameterize('input_shape', [(16,), (4, 16)]) -@parameterize('model_cls', [AddmmModel, AddmmModel_with_param]) +@parameterize("input_shape", [(16,), (4, 16)]) +@parameterize("model_cls", [AddmmModel, AddmmModel_with_param]) @rerun_if_address_is_in_use() def test_addmm_handler(input_shape, model_cls): spawn(check_addmm_function_handler, 4, input_shape=input_shape, model_cls=model_cls) -if __name__ == '__main__': +if __name__ == "__main__": test_addmm_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py index c3ceef4c7adf..86df7237a219 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py @@ -16,7 +16,7 @@ def check_bn_module_handler(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.BatchNorm2d(16)).cuda() physical_mesh_id = torch.arange(0, 4) @@ -29,18 +29,20 @@ def check_bn_module_handler(rank, world_size, port): # the total number of bn strategies without sync bn mode # TODO: add sync bn strategies after related passes ready strategy_number = 4 - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input], - meta_arg_names=['input']) + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=["input"], + ) tracer = ColoTracer(bias_addition_split=True) # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {}) # return _0 - meta_args = {"input": torch.rand(4, 16, 64, 64).to('meta')} + meta_args = {"input": torch.rand(4, 16, 64, 64).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, 
*meta_args.values()) @@ -59,37 +61,37 @@ def check_bn_module_handler(rank, world_size, port): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "input_1" - assert mapping['input'].data.shape == torch.Size([4, 16, 64, 64]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 16, 64, 64]) + assert mapping["input"].name == "input_1" + assert mapping["input"].data.shape == torch.Size([4, 16, 64, 64]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 16, 64, 64]) - assert mapping['other'].name == "weight" - assert mapping['other'].data.shape == torch.Size([16]) - assert mapping['other'].type == OperationDataType.PARAM - assert mapping['other'].logical_shape == torch.Size([16]) + assert mapping["other"].name == "weight" + assert mapping["other"].data.shape == torch.Size([16]) + assert mapping["other"].type == OperationDataType.PARAM + assert mapping["other"].logical_shape == torch.Size([16]) - assert mapping['bias'].name == "bias" - assert mapping['bias'].data.shape == torch.Size([16]) - assert mapping['bias'].type == OperationDataType.PARAM - assert mapping['bias'].logical_shape == torch.Size([16]) + assert mapping["bias"].name == "bias" + assert mapping["bias"].data.shape == torch.Size([16]) + assert mapping["bias"].type == OperationDataType.PARAM + assert mapping["bias"].logical_shape == torch.Size([16]) - assert mapping['output'].name == "_0" - assert mapping['output'].data.shape == torch.Size([4, 16, 64, 64]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "_0" + assert mapping["output"].data.shape == torch.Size([4, 16, 64, 64]) + assert mapping["output"].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # RS = RS x S - assert 'RS0 = RS0 x S0' in strategy_name_list - assert 'RS1 = RS1 x S1' in strategy_name_list + assert "RS0 = RS0 x S0" in strategy_name_list + assert "RS1 = RS1 x S1" in strategy_name_list # RR = RR x R - assert 'RR = RR x R' in strategy_name_list + assert "RR = RR x R" in strategy_name_list # RS01 = RS01 x S01 - assert 'RS01 = RS01 x S01' in strategy_name_list + assert "RS01 = RS01 x S01" in strategy_name_list # temporarily skip the sync bn test # TODO: test sync bn after the implicit runtime pass completed @@ -105,12 +107,12 @@ def check_bn_module_handler(rank, world_size, port): # assert 'S01R = S01R x R WITH SYNC_BN' in strategy_name_list -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_bn_module_handler(): spawn(check_bn_module_handler, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_bn_module_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py index 800bc11a50e4..e06625e1c42c 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py @@ -5,7 +5,7 @@ from colossalai._analyzer.fx.graph_module import ColoGraphModule from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass from 
colossalai._analyzer.fx.tracer.tracer import ColoTracer -from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler +from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( OperationData, OperationDataType, @@ -22,7 +22,6 @@ class LinearModule(torch.nn.Module): - def __init__(self, weight_shape): super().__init__() self.weight = torch.nn.Parameter(torch.rand(*weight_shape)) @@ -35,7 +34,7 @@ def forward(self, x): def check_linear_module_handler(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = LinearModule(weight_shape=WEIGHT_SHAPE).cuda() physical_mesh_id = torch.arange(0, 4) @@ -49,14 +48,16 @@ def check_linear_module_handler(rank, world_size, port): # construct input args input_args = [input] # construct meta arg names - meta_arg_names = ['x'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names, - node_type='bias_module') + meta_arg_names = ["x"] + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + node_type="bias_module", + ) tracer = ColoTracer(bias_addition_split=True) # graph(): @@ -66,7 +67,7 @@ def check_linear_module_handler(rank, world_size, port): # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %weight), kwargs = {}) # %add : [#users=1] = call_function[target=operator.add](args = (%linear, %bias), kwargs = {}) # return add - meta_args = {"x": torch.rand(4, 4, 4, 16).to('meta')} + meta_args = {"x": torch.rand(4, 4, 4, 16).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -85,72 +86,72 @@ def check_linear_module_handler(rank, world_size, port): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "x" - assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([64, 16]) + assert mapping["input"].name == "x" + assert mapping["input"].data.shape == torch.Size([4, 4, 4, 16]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([64, 16]) - assert mapping['other'].name == "weight" - assert mapping['other'].data.shape == torch.Size([32, 16]) - assert mapping['other'].type == OperationDataType.PARAM - assert mapping['other'].logical_shape == torch.Size([16, 32]) + assert mapping["other"].name == "weight" + assert mapping["other"].data.shape == torch.Size([32, 16]) + assert mapping["other"].type == OperationDataType.PARAM + assert mapping["other"].logical_shape == torch.Size([16, 32]) - assert 'bias' not in mapping + assert "bias" not in mapping - assert mapping['output'].name == "linear" - assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "linear" + assert mapping["output"].data.shape == torch.Size([4, 4, 4, 
32]) + assert mapping["output"].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # SS = SR x RS - assert 'S0S1 = S0R x RS1_0' in strategy_name_list - assert 'S0S1 = S0R x RS1_1' in strategy_name_list - assert 'S0S1 = S0R x RS1_2' in strategy_name_list - assert 'S1S0 = S1R x RS0_0' in strategy_name_list - assert 'S1S0 = S1R x RS0_1' in strategy_name_list - assert 'S1S0 = S1R x RS0_2' in strategy_name_list + assert "S0S1 = S0R x RS1_0" in strategy_name_list + assert "S0S1 = S0R x RS1_1" in strategy_name_list + assert "S0S1 = S0R x RS1_2" in strategy_name_list + assert "S1S0 = S1R x RS0_0" in strategy_name_list + assert "S1S0 = S1R x RS0_1" in strategy_name_list + assert "S1S0 = S1R x RS0_2" in strategy_name_list # SR = SS x SR - assert 'S0R = S0S1 x S1R_0' in strategy_name_list - assert 'S0R = S0S1 x S1R_1' in strategy_name_list - assert 'S0R = S0S1 x S1R_2' in strategy_name_list - assert 'S1R = S1S0 x S0R_0' in strategy_name_list - assert 'S1R = S1S0 x S0R_1' in strategy_name_list - assert 'S1R = S1S0 x S0R_2' in strategy_name_list + assert "S0R = S0S1 x S1R_0" in strategy_name_list + assert "S0R = S0S1 x S1R_1" in strategy_name_list + assert "S0R = S0S1 x S1R_2" in strategy_name_list + assert "S1R = S1S0 x S0R_0" in strategy_name_list + assert "S1R = S1S0 x S0R_1" in strategy_name_list + assert "S1R = S1S0 x S0R_2" in strategy_name_list # RS = RS x SS - assert 'RS0 = RS1 x S1S0' in strategy_name_list - assert 'RS1 = RS0 x S0S1' in strategy_name_list + assert "RS0 = RS1 x S1S0" in strategy_name_list + assert "RS1 = RS0 x S0S1" in strategy_name_list # RR = RS x SR - assert 'RR = RS0 x S0R' in strategy_name_list - assert 'RR = RS1 x S1R' in strategy_name_list + assert "RR = RS0 x S0R" in strategy_name_list + assert "RR = RS1 x S1R" in strategy_name_list # RS= RR x RS - assert 'RS0 = RR x RS0' in strategy_name_list - assert 'RS1 = RR x RS1' in strategy_name_list + assert "RS0 = RR x RS0" in strategy_name_list + assert "RS1 = RR x RS1" in strategy_name_list # S01R = S01R x RR - assert 'S01R = S01R x RR_0' in strategy_name_list - assert 'S01R = S01R x RR_1' in strategy_name_list - assert 'S01R = S01R x RR_2' in strategy_name_list + assert "S01R = S01R x RR_0" in strategy_name_list + assert "S01R = S01R x RR_1" in strategy_name_list + assert "S01R = S01R x RR_2" in strategy_name_list # RR = RS01 x S01R - assert 'RR = RS01 x S01R' in strategy_name_list + assert "RR = RS01 x S01R" in strategy_name_list # RS01 = RR x RS01 - assert 'RS01 = RR x RS01' in strategy_name_list + assert "RS01 = RR x RS01" in strategy_name_list # RR = RR x RR - assert 'RR = RR x RR' in strategy_name_list + assert "RR = RR x RR" in strategy_name_list for strategy in strategies_vector: strategy: ShardingStrategy - input_sharding_spec = strategy.get_sharding_spec_by_name('x') - weight_sharding_spec = strategy.get_sharding_spec_by_name('weight') - output_sharding_spec = strategy.get_sharding_spec_by_name('linear') + input_sharding_spec = strategy.get_sharding_spec_by_name("x") + weight_sharding_spec = strategy.get_sharding_spec_by_name("weight") + output_sharding_spec = strategy.get_sharding_spec_by_name("linear") # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] @@ -158,12 +159,12 @@ def check_linear_module_handler(rank, world_size, port): assert weight_sharding_spec.sharding_sequence[0] == 
output_sharding_spec.sharding_sequence[-1] -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_linear_handler(): spawn(check_linear_module_handler) -if __name__ == '__main__': +if __name__ == "__main__": test_linear_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py index c29a065d10ba..690f0c12387c 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py @@ -19,7 +19,6 @@ class LinearModule(torch.nn.Module): - def __init__(self, in_features, out_features, bias): super().__init__() self.linear = torch.nn.Linear(in_features, out_features, bias=bias) @@ -31,7 +30,7 @@ def forward(self, x): def check_linear_module_handler(rank, world_size, port, bias): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = LinearModule(16, 32, bias=bias).cuda() physical_mesh_id = torch.arange(0, 4) @@ -45,17 +44,19 @@ def check_linear_module_handler(rank, world_size, port, bias): # construct input args input_args = [input] # construct meta arg names - meta_arg_names = ['x'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names, - node_type='bias_module') + meta_arg_names = ["x"] + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + node_type="bias_module", + ) tracer = ColoTracer(bias_addition_split=True) - meta_args = {"x": torch.rand(4, 4, 4, 16).to('meta')} + meta_args = {"x": torch.rand(4, 4, 4, 16).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -74,72 +75,72 @@ def check_linear_module_handler(rank, world_size, port, bias): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "x" - assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([64, 16]) + assert mapping["input"].name == "x" + assert mapping["input"].data.shape == torch.Size([4, 4, 4, 16]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([64, 16]) - assert mapping['other'].name == "linear_weight" - assert mapping['other'].data.shape == torch.Size([32, 16]) - assert mapping['other'].type == OperationDataType.PARAM - assert mapping['other'].logical_shape == torch.Size([16, 32]) + assert mapping["other"].name == "linear_weight" + assert mapping["other"].data.shape == torch.Size([32, 16]) + assert mapping["other"].type == OperationDataType.PARAM + assert mapping["other"].logical_shape == torch.Size([16, 32]) - assert 'bias' not in mapping + assert "bias" not in mapping - assert mapping['output'].name == "linear" - assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32]) - 
assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "linear" + assert mapping["output"].data.shape == torch.Size([4, 4, 4, 32]) + assert mapping["output"].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # SS = SR x RS - assert 'S0S1 = S0R x RS1_0' in strategy_name_list - assert 'S0S1 = S0R x RS1_1' in strategy_name_list - assert 'S0S1 = S0R x RS1_2' in strategy_name_list - assert 'S1S0 = S1R x RS0_0' in strategy_name_list - assert 'S1S0 = S1R x RS0_1' in strategy_name_list - assert 'S1S0 = S1R x RS0_2' in strategy_name_list + assert "S0S1 = S0R x RS1_0" in strategy_name_list + assert "S0S1 = S0R x RS1_1" in strategy_name_list + assert "S0S1 = S0R x RS1_2" in strategy_name_list + assert "S1S0 = S1R x RS0_0" in strategy_name_list + assert "S1S0 = S1R x RS0_1" in strategy_name_list + assert "S1S0 = S1R x RS0_2" in strategy_name_list # SR = SS x SR - assert 'S0R = S0S1 x S1R_0' in strategy_name_list - assert 'S0R = S0S1 x S1R_1' in strategy_name_list - assert 'S0R = S0S1 x S1R_2' in strategy_name_list - assert 'S1R = S1S0 x S0R_0' in strategy_name_list - assert 'S1R = S1S0 x S0R_1' in strategy_name_list - assert 'S1R = S1S0 x S0R_2' in strategy_name_list + assert "S0R = S0S1 x S1R_0" in strategy_name_list + assert "S0R = S0S1 x S1R_1" in strategy_name_list + assert "S0R = S0S1 x S1R_2" in strategy_name_list + assert "S1R = S1S0 x S0R_0" in strategy_name_list + assert "S1R = S1S0 x S0R_1" in strategy_name_list + assert "S1R = S1S0 x S0R_2" in strategy_name_list # RS = RS x SS - assert 'RS0 = RS1 x S1S0' in strategy_name_list - assert 'RS1 = RS0 x S0S1' in strategy_name_list + assert "RS0 = RS1 x S1S0" in strategy_name_list + assert "RS1 = RS0 x S0S1" in strategy_name_list # RR = RS x SR - assert 'RR = RS0 x S0R' in strategy_name_list - assert 'RR = RS1 x S1R' in strategy_name_list + assert "RR = RS0 x S0R" in strategy_name_list + assert "RR = RS1 x S1R" in strategy_name_list # RS= RR x RS - assert 'RS0 = RR x RS0' in strategy_name_list - assert 'RS1 = RR x RS1' in strategy_name_list + assert "RS0 = RR x RS0" in strategy_name_list + assert "RS1 = RR x RS1" in strategy_name_list # S01R = S01R x RR - assert 'S01R = S01R x RR_0' in strategy_name_list - assert 'S01R = S01R x RR_1' in strategy_name_list - assert 'S01R = S01R x RR_2' in strategy_name_list + assert "S01R = S01R x RR_0" in strategy_name_list + assert "S01R = S01R x RR_1" in strategy_name_list + assert "S01R = S01R x RR_2" in strategy_name_list # RR = RS01 x S01R - assert 'RR = RS01 x S01R' in strategy_name_list + assert "RR = RS01 x S01R" in strategy_name_list # RS01 = RR x RS01 - assert 'RS01 = RR x RS01' in strategy_name_list + assert "RS01 = RR x RS01" in strategy_name_list # RR = RR x RR - assert 'RR = RR x RR' in strategy_name_list + assert "RR = RR x RR" in strategy_name_list for strategy in strategies_vector: strategy: ShardingStrategy - input_sharding_spec = strategy.get_sharding_spec_by_name('x') - weight_sharding_spec = strategy.get_sharding_spec_by_name('linear_weight') - output_sharding_spec = strategy.get_sharding_spec_by_name('linear') + input_sharding_spec = strategy.get_sharding_spec_by_name("x") + weight_sharding_spec = strategy.get_sharding_spec_by_name("linear_weight") + output_sharding_spec = strategy.get_sharding_spec_by_name("linear") # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[:-1] == 
output_sharding_spec.sharding_sequence[:-1] @@ -147,12 +148,12 @@ def check_linear_module_handler(rank, world_size, port, bias): assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1] -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_linear_handler(bias=True): spawn(check_linear_module_handler, bias=bias) -if __name__ == '__main__': +if __name__ == "__main__": test_linear_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py index 83f3aafe220e..5b2e2ab49f6d 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py @@ -16,10 +16,9 @@ def check_binary_elementwise_handler_with_tensor(rank, world_size, port, op, other_dim): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") class BinaryElementwiseOpModel(nn.Module): - def __init__(self, op): super().__init__() self.op = op @@ -41,16 +40,18 @@ def forward(self, x1, x2): # construct input args input_args = [x1, x2] # construct meta arg names - meta_arg_names = ['x1', 'x2'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names) + meta_arg_names = ["x1", "x2"] + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + ) tracer = ColoTracer(bias_addition_split=True) - meta_args = {'x1': torch.rand(4, 4).to('meta'), 'x2': torch.rand([4] * other_dim).to('meta')} + meta_args = {"x1": torch.rand(4, 4).to("meta"), "x2": torch.rand([4] * other_dim).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -70,23 +71,23 @@ def forward(self, x1, x2): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "x1" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 4]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 4]) + assert mapping["input"].name == "x1" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 4]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 4]) - assert mapping['other'].name == "x2" - assert mapping['other'].data.is_meta - assert mapping['other'].data.shape == torch.Size([4] * other_dim) - assert mapping['other'].type == OperationDataType.ARG - assert mapping['other'].logical_shape == torch.Size([4, 4]) + assert mapping["other"].name == "x2" + assert mapping["other"].data.is_meta + assert mapping["other"].data.shape == torch.Size([4] * other_dim) + assert mapping["other"].type == OperationDataType.ARG + assert mapping["other"].logical_shape == torch.Size([4, 4]) - assert mapping['output'].name == 
str(op_node) - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 4]) - assert mapping['output'].type == OperationDataType.OUTPUT - assert mapping['output'].logical_shape == torch.Size([4, 4]) + assert mapping["output"].name == str(op_node) + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 4]) + assert mapping["output"].type == OperationDataType.OUTPUT + assert mapping["output"].logical_shape == torch.Size([4, 4]) strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] @@ -95,19 +96,19 @@ def forward(self, x1, x2): assert len(strategy_name_list) == 9 # check if the sharding strategy is correct - assert '[S0, S1] = [S0, S1] [S0, S1]' in strategy_name_list - assert '[S1, S0] = [S1, S0] [S1, S0]' in strategy_name_list - assert '[S01, R] = [S01, R] [S01, R]' in strategy_name_list - assert '[R, S01] = [R, S01] [R, S01]' in strategy_name_list - assert '[S0, R] = [S0, R] [S0, R]' in strategy_name_list - assert '[R, S0] = [R, S0] [R, S0]' in strategy_name_list - assert '[S1, R] = [S1, R] [S1, R]' in strategy_name_list - assert '[R, S1] = [R, S1] [R, S1]' in strategy_name_list - assert '[R, R] = [R, R] [R, R]' in strategy_name_list + assert "[S0, S1] = [S0, S1] [S0, S1]" in strategy_name_list + assert "[S1, S0] = [S1, S0] [S1, S0]" in strategy_name_list + assert "[S01, R] = [S01, R] [S01, R]" in strategy_name_list + assert "[R, S01] = [R, S01] [R, S01]" in strategy_name_list + assert "[S0, R] = [S0, R] [S0, R]" in strategy_name_list + assert "[R, S0] = [R, S0] [R, S0]" in strategy_name_list + assert "[S1, R] = [S1, R] [S1, R]" in strategy_name_list + assert "[R, S1] = [R, S1] [R, S1]" in strategy_name_list + assert "[R, R] = [R, R] [R, R]" in strategy_name_list for strategy in strategies_vector: - input_sharding_spec = strategy.get_sharding_spec_by_name('x1') - other_sharding_spec = strategy.get_sharding_spec_by_name('x2') + input_sharding_spec = strategy.get_sharding_spec_by_name("x1") + other_sharding_spec = strategy.get_sharding_spec_by_name("x2") output_sharding_spec = strategy.get_sharding_spec_by_name(str(op_node)) # make sure the sharding spec is the same for input and output @@ -121,7 +122,6 @@ def forward(self, x1, x2): class BEOpModelWithNodeConst(nn.Module): - def __init__(self, op): super().__init__() self.op = op @@ -133,7 +133,6 @@ def forward(self, x1): class BEOpModelWithIntConst(nn.Module): - def __init__(self, op, const): super().__init__() self.op = op @@ -146,7 +145,7 @@ def forward(self, x1): def check_binary_elementwise_handler_with_int(rank, world_size, port, op, other_dim, model_cls): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -163,15 +162,17 @@ def check_binary_elementwise_handler_with_int(rank, world_size, port, op, other_ # construct input args input_args = [x1] # construct meta arg names - meta_arg_names = ['x1'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names) + meta_arg_names = ["x1"] + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + 
input_args=input_args, + meta_arg_names=meta_arg_names, + ) tracer = ColoTracer(bias_addition_split=True) - meta_args = {'x1': torch.rand(4, 4).to('meta')} + meta_args = {"x1": torch.rand(4, 4).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -188,17 +189,17 @@ def check_binary_elementwise_handler_with_int(rank, world_size, port, op, other_ # check operation data mapping mapping = handler.get_operation_data_mapping() - assert mapping['input'].name == "x1" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 4]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 4]) + assert mapping["input"].name == "x1" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 4]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 4]) - assert mapping['output'].name == str(op_node) - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 4]) - assert mapping['output'].type == OperationDataType.OUTPUT - assert mapping['output'].logical_shape == torch.Size([4, 4]) + assert mapping["output"].name == str(op_node) + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 4]) + assert mapping["output"].type == OperationDataType.OUTPUT + assert mapping["output"].logical_shape == torch.Size([4, 4]) strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] @@ -207,27 +208,27 @@ def check_binary_elementwise_handler_with_int(rank, world_size, port, op, other_ assert len(strategy_name_list) == 9 # check if the sharding strategy is correct - assert '[S0, S1] = [S0, S1] [S0, S1]' in strategy_name_list - assert '[S1, S0] = [S1, S0] [S1, S0]' in strategy_name_list - assert '[S01, R] = [S01, R] [S01, R]' in strategy_name_list - assert '[R, S01] = [R, S01] [R, S01]' in strategy_name_list - assert '[S0, R] = [S0, R] [S0, R]' in strategy_name_list - assert '[R, S0] = [R, S0] [R, S0]' in strategy_name_list - assert '[S1, R] = [S1, R] [S1, R]' in strategy_name_list - assert '[R, S1] = [R, S1] [R, S1]' in strategy_name_list - assert '[R, R] = [R, R] [R, R]' in strategy_name_list + assert "[S0, S1] = [S0, S1] [S0, S1]" in strategy_name_list + assert "[S1, S0] = [S1, S0] [S1, S0]" in strategy_name_list + assert "[S01, R] = [S01, R] [S01, R]" in strategy_name_list + assert "[R, S01] = [R, S01] [R, S01]" in strategy_name_list + assert "[S0, R] = [S0, R] [S0, R]" in strategy_name_list + assert "[R, S0] = [R, S0] [R, S0]" in strategy_name_list + assert "[S1, R] = [S1, R] [S1, R]" in strategy_name_list + assert "[R, S1] = [R, S1] [R, S1]" in strategy_name_list + assert "[R, R] = [R, R] [R, R]" in strategy_name_list for strategy in strategies_vector: - input_sharding_spec = strategy.get_sharding_spec_by_name('x1') + input_sharding_spec = strategy.get_sharding_spec_by_name("x1") output_sharding_spec = strategy.get_sharding_spec_by_name(str(op_node)) # make sure the sharding spec is the same for input and output assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence -@run_on_environment_flag(name='AUTO_PARALLEL') -@parameterize('op', [torch.add]) -@parameterize('other_dim', [1, 2]) +@run_on_environment_flag(name="AUTO_PARALLEL") +@parameterize("op", [torch.add]) 
+@parameterize("other_dim", [1, 2]) @pytest.mark.dist @rerun_if_address_is_in_use() def test_binary_elementwise_handler_with_tensor(op, other_dim): @@ -239,10 +240,10 @@ def test_binary_elementwise_handler_with_tensor(op, other_dim): ) -@run_on_environment_flag(name='AUTO_PARALLEL') -@parameterize('op', [torch.add]) -@parameterize('other_dim', [1, 2]) -@parameterize('model_cls', [BEOpModelWithNodeConst, BEOpModelWithIntConst]) +@run_on_environment_flag(name="AUTO_PARALLEL") +@parameterize("op", [torch.add]) +@parameterize("other_dim", [1, 2]) +@parameterize("model_cls", [BEOpModelWithNodeConst, BEOpModelWithIntConst]) @pytest.mark.dist @rerun_if_address_is_in_use() def test_binary_elementwise_handler_with_int(op, model_cls, other_dim): @@ -255,6 +256,6 @@ def test_binary_elementwise_handler_with_int(op, model_cls, other_dim): ) -if __name__ == '__main__': +if __name__ == "__main__": test_binary_elementwise_handler_with_tensor() test_binary_elementwise_handler_with_int() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py index f4fdc458f80e..29df12832241 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py @@ -15,20 +15,18 @@ class BMMTensorMethodModule(nn.Module): - def forward(self, x1, x2): return x1.bmm(x2) class BMMTorchFunctionModule(nn.Module): - def forward(self, x1, x2): return torch.bmm(x1, x2) def check_2d_device_mesh(rank, module, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = module().cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -42,15 +40,17 @@ def check_2d_device_mesh(rank, module, world_size, port): # construct input args input_args = [x1, x2] # construct meta arg names - meta_arg_names = ['x1', 'x2'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names) + meta_arg_names = ["x1", "x2"] + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + ) tracer = ColoTracer(bias_addition_split=True) - meta_args = {'x1': torch.rand(4, 8, 16).to('meta'), 'x2': torch.rand(4, 16, 8).to('meta')} + meta_args = {"x1": torch.rand(4, 8, 16).to("meta"), "x2": torch.rand(4, 16, 8).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -70,48 +70,48 @@ def check_2d_device_mesh(rank, module, world_size, port): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "x1" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 8, 16]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 8, 16]) + assert mapping["input"].name == "x1" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 8, 16]) + assert mapping["input"].type == OperationDataType.ARG + assert 
mapping["input"].logical_shape == torch.Size([4, 8, 16]) - assert mapping['other'].name == "x2" - assert mapping['other'].data.is_meta - assert mapping['other'].data.shape == torch.Size([4, 16, 8]) - assert mapping['other'].type == OperationDataType.ARG - assert mapping['other'].logical_shape == torch.Size([4, 16, 8]) + assert mapping["other"].name == "x2" + assert mapping["other"].data.is_meta + assert mapping["other"].data.shape == torch.Size([4, 16, 8]) + assert mapping["other"].type == OperationDataType.ARG + assert mapping["other"].logical_shape == torch.Size([4, 16, 8]) - assert mapping['output'].name == "bmm" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 8, 8]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "bmm" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 8, 8]) + assert mapping["output"].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # one batch dim - assert 'Sb0 = Sb0 x Sb0' not in strategy_name_list + assert "Sb0 = Sb0 x Sb0" not in strategy_name_list # two batch dim - assert 'Sb01 = Sb01 x Sb01' in strategy_name_list + assert "Sb01 = Sb01 x Sb01" in strategy_name_list # SbSi = SbSi x Sb - assert 'Sb0Si1 = Sb0Si1 x Sb0' in strategy_name_list - assert 'Sb1Si0 = Sb1Si0 x Sb1' in strategy_name_list + assert "Sb0Si1 = Sb0Si1 x Sb0" in strategy_name_list + assert "Sb1Si0 = Sb1Si0 x Sb1" in strategy_name_list # SbSj = SbR x SbSj - assert 'Sb0Sj1 = Sb0R x Sb0Sj1' in strategy_name_list - assert 'Sb1Sj0 = Sb1R x Sb1Sj0' in strategy_name_list + assert "Sb0Sj1 = Sb0R x Sb0Sj1" in strategy_name_list + assert "Sb1Sj0 = Sb1R x Sb1Sj0" in strategy_name_list # SbR = SbSk x SbSk - assert 'Sb0R = Sb0Sk1 x Sb0Sk1' in strategy_name_list - assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list + assert "Sb0R = Sb0Sk1 x Sb0Sk1" in strategy_name_list + assert "Sb1R = Sb1Sk0 x Sb1Sk0" in strategy_name_list for strategy in strategies_vector: - input_sharding_spec = strategy.get_sharding_spec_by_name('x1') - other_sharding_spec = strategy.get_sharding_spec_by_name('x2') - output_sharding_spec = strategy.get_sharding_spec_by_name('bmm') + input_sharding_spec = strategy.get_sharding_spec_by_name("x1") + other_sharding_spec = strategy.get_sharding_spec_by_name("x2") + output_sharding_spec = strategy.get_sharding_spec_by_name("bmm") # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] @@ -121,7 +121,7 @@ def check_2d_device_mesh(rank, module, world_size, port): def check_1d_device_mesh(rank, module, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = module().cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (1, 4) @@ -135,15 +135,17 @@ def check_1d_device_mesh(rank, module, world_size, port): # construct input args input_args = [x1, x2] # construct meta arg names - meta_arg_names = ['x1', 'x2'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names) + meta_arg_names = ["x1", "x2"] + 
numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + ) tracer = ColoTracer(bias_addition_split=True) - meta_args = {'x1': torch.rand(4, 8, 16).to('meta'), 'x2': torch.rand(4, 16, 8).to('meta')} + meta_args = {"x1": torch.rand(4, 8, 16).to("meta"), "x2": torch.rand(4, 16, 8).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -162,33 +164,33 @@ def check_1d_device_mesh(rank, module, world_size, port): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "x1" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 8, 16]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 8, 16]) + assert mapping["input"].name == "x1" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 8, 16]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 8, 16]) - assert mapping['other'].name == "x2" - assert mapping['other'].data.is_meta - assert mapping['other'].data.shape == torch.Size([4, 16, 8]) - assert mapping['other'].type == OperationDataType.ARG - assert mapping['other'].logical_shape == torch.Size([4, 16, 8]) + assert mapping["other"].name == "x2" + assert mapping["other"].data.is_meta + assert mapping["other"].data.shape == torch.Size([4, 16, 8]) + assert mapping["other"].type == OperationDataType.ARG + assert mapping["other"].logical_shape == torch.Size([4, 16, 8]) - assert mapping['output'].name == "bmm" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 8, 8]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "bmm" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 8, 8]) + assert mapping["output"].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] assert len(strategy_name_list) == 1 # one batch dim - assert 'Sb0 = Sb0 x Sb0' in strategy_name_list + assert "Sb0 = Sb0 x Sb0" in strategy_name_list for strategy in strategies_vector: - input_sharding_spec = strategy.get_sharding_spec_by_name('x1') - other_sharding_spec = strategy.get_sharding_spec_by_name('x2') - output_sharding_spec = strategy.get_sharding_spec_by_name('bmm') + input_sharding_spec = strategy.get_sharding_spec_by_name("x1") + other_sharding_spec = strategy.get_sharding_spec_by_name("x2") + output_sharding_spec = strategy.get_sharding_spec_by_name("bmm") # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] @@ -196,9 +198,9 @@ def check_1d_device_mesh(rank, module, world_size, port): assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] -@run_on_environment_flag(name='AUTO_PARALLEL') -@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule]) -@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule]) +@run_on_environment_flag(name="AUTO_PARALLEL") +@parameterize("module", [BMMTensorMethodModule, BMMTorchFunctionModule]) 
+@parameterize("module", [BMMTensorMethodModule, BMMTorchFunctionModule]) @pytest.mark.dist @rerun_if_address_is_in_use() def test_bmm_handler(module): @@ -206,5 +208,5 @@ def test_bmm_handler(module): spawn(check_1d_device_mesh, 4, module=module) -if __name__ == '__main__': +if __name__ == "__main__": test_bmm_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py index f9632b1cd8f9..8a37dd9256dd 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py @@ -16,7 +16,7 @@ def check_conv_module_handler(rank, world_size, port, bias): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias)).cuda() # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] @@ -32,14 +32,16 @@ def check_conv_module_handler(rank, world_size, port, bias): node_index = 1 # total number of conv strategies strategy_number = 16 - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input], - meta_arg_names=['input']) + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=["input"], + ) tracer = ColoTracer(bias_addition_split=True) - meta_args = {'input': torch.rand(4, 4, 64, 64).to('meta')} + meta_args = {"input": torch.rand(4, 4, 64, 64).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -58,76 +60,76 @@ def check_conv_module_handler(rank, world_size, port, bias): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "input_1" + assert mapping["input"].name == "input_1" # assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 4, 64, 64]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 4, 64, 64]) + assert mapping["input"].data.shape == torch.Size([4, 4, 64, 64]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 4, 64, 64]) - assert mapping['other'].name == "weight" + assert mapping["other"].name == "weight" # assert mapping['other'].data.is_meta - assert mapping['other'].data.shape == torch.Size([16, 4, 3, 3]) - assert mapping['other'].type == OperationDataType.PARAM - assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3]) + assert mapping["other"].data.shape == torch.Size([16, 4, 3, 3]) + assert mapping["other"].type == OperationDataType.PARAM + assert mapping["other"].logical_shape == torch.Size([4, 16, 3, 3]) if bias: - assert mapping['bias'].name == "bias" + assert mapping["bias"].name == "bias" # assert mapping['bias'].data.is_meta - assert mapping['bias'].data.shape == torch.Size([16]) - assert mapping['bias'].type == OperationDataType.PARAM - assert mapping['bias'].logical_shape == torch.Size([16]) + assert mapping["bias"].data.shape == torch.Size([16]) + assert mapping["bias"].type 
== OperationDataType.PARAM + assert mapping["bias"].logical_shape == torch.Size([16]) - assert mapping['output'].name == "_0" + assert mapping["output"].name == "_0" # assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 16, 64, 64]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].data.shape == torch.Size([4, 16, 64, 64]) + assert mapping["output"].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # SS = SR x RS - assert 'S0S1 = S0R x RS1' in strategy_name_list - assert 'S1S0 = S1R x RS0' in strategy_name_list + assert "S0S1 = S0R x RS1" in strategy_name_list + assert "S1S0 = S1R x RS0" in strategy_name_list # SR = SR x RR - assert 'S0R = S0R x RR' in strategy_name_list - assert 'S1R = S1R x RR' in strategy_name_list + assert "S0R = S0R x RR" in strategy_name_list + assert "S1R = S1R x RR" in strategy_name_list # SR = SS x SR - assert 'S0R = S0S1 x S1R' in strategy_name_list - assert 'S1R = S1S0 x S0R' in strategy_name_list + assert "S0R = S0S1 x S1R" in strategy_name_list + assert "S1R = S1S0 x S0R" in strategy_name_list # RS = RS x SS - assert 'RS0 = RS1 x S1S0' in strategy_name_list - assert 'RS1 = RS0 x S0S1' in strategy_name_list + assert "RS0 = RS1 x S1S0" in strategy_name_list + assert "RS1 = RS0 x S0S1" in strategy_name_list # RR = RS x SR - assert 'RR = RS0 x S0R' in strategy_name_list - assert 'RR = RS1 x S1R' in strategy_name_list + assert "RR = RS0 x S0R" in strategy_name_list + assert "RR = RS1 x S1R" in strategy_name_list # RS= RR x RS - assert 'RS0 = RR x RS0' in strategy_name_list - assert 'RS1 = RR x RS1' in strategy_name_list + assert "RS0 = RR x RS0" in strategy_name_list + assert "RS1 = RR x RS1" in strategy_name_list # RR = RR x RR - assert 'RR = RR x RR' in strategy_name_list + assert "RR = RR x RR" in strategy_name_list # S01R = S01R x RR - assert 'S01R = S01R x RR' in strategy_name_list + assert "S01R = S01R x RR" in strategy_name_list # RR = RS01 x S01R - assert 'RR = RS01 x S01R' in strategy_name_list + assert "RR = RS01 x S01R" in strategy_name_list # RS01 = RR x RS01 - assert 'RS01 = RR x RS01' in strategy_name_list + assert "RS01 = RR x RS01" in strategy_name_list for strategy in strategies_vector: - input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') - weight_sharding_spec = strategy.get_sharding_spec_by_name('weight') - output_sharding_spec = strategy.get_sharding_spec_by_name('_0') + input_sharding_spec = strategy.get_sharding_spec_by_name("input_1") + weight_sharding_spec = strategy.get_sharding_spec_by_name("weight") + output_sharding_spec = strategy.get_sharding_spec_by_name("_0") if bias: - bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') + bias_sharding_spec = strategy.get_sharding_spec_by_name("bias") # make sure the sharding matches across different operation data assert output_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[0] @@ -141,7 +143,6 @@ def check_conv_module_handler(rank, world_size, port, bias): class ConvModel(nn.Module): - def __init__(self): super().__init__() @@ -152,7 +153,7 @@ def forward(self, input, others, bias=None): def check_conv_function_handler(rank, world_size, port, bias): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", 
port=port, backend="nccl") model = ConvModel().cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -160,22 +161,24 @@ def check_conv_function_handler(rank, world_size, port, bias): input = torch.rand(4, 4, 64, 64).cuda() others = torch.rand(16, 4, 3, 3).cuda() input_args = [input, others] - meta_arg_names = ['input', 'others'] + meta_arg_names = ["input", "others"] input_kwargs = {} # total number of conv strategies strategy_number = 16 node_index = 2 if bias: bias_tensor = torch.rand(16).cuda() - input_kwargs['bias'] = bias_tensor + input_kwargs["bias"] = bias_tensor node_index += 1 - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names, - input_kwargs=input_kwargs) + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + input_kwargs=input_kwargs, + ) tracer = ColoTracer(bias_addition_split=True) # graph(): @@ -183,9 +186,9 @@ def check_conv_function_handler(rank, world_size, port, bias): # %others : torch.Tensor [#users=1] = placeholder[target=others] # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %others), kwargs = {}) # return conv2d - meta_args = {"input": torch.rand(4, 4, 64, 64).to('meta'), "others": torch.rand(16, 4, 3, 3).to('meta')} + meta_args = {"input": torch.rand(4, 4, 64, 64).to("meta"), "others": torch.rand(16, 4, 3, 3).to("meta")} if bias: - meta_args['bias'] = torch.rand(16).to('meta') + meta_args["bias"] = torch.rand(16).to("meta") graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -208,76 +211,76 @@ def check_conv_function_handler(rank, world_size, port, bias): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "input_1" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 4, 64, 64]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 4, 64, 64]) + assert mapping["input"].name == "input_1" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 4, 64, 64]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 4, 64, 64]) - assert mapping['other'].name == "others" - assert mapping['other'].data.is_meta - assert mapping['other'].data.shape == torch.Size([16, 4, 3, 3]) - assert mapping['other'].type == OperationDataType.ARG - assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3]) + assert mapping["other"].name == "others" + assert mapping["other"].data.is_meta + assert mapping["other"].data.shape == torch.Size([16, 4, 3, 3]) + assert mapping["other"].type == OperationDataType.ARG + assert mapping["other"].logical_shape == torch.Size([4, 16, 3, 3]) if bias: - assert mapping['bias'].name == "bias" - assert mapping['bias'].data.is_meta - assert mapping['bias'].data.shape == torch.Size([16]) - assert mapping['bias'].type == OperationDataType.ARG - assert mapping['bias'].logical_shape == torch.Size([16]) + assert mapping["bias"].name == "bias" + assert mapping["bias"].data.is_meta + assert mapping["bias"].data.shape == torch.Size([16]) + assert mapping["bias"].type == OperationDataType.ARG + assert 
mapping["bias"].logical_shape == torch.Size([16]) - assert mapping['output'].name == "conv2d" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 16, 64, 64]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "conv2d" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 16, 64, 64]) + assert mapping["output"].type == OperationDataType.OUTPUT handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # SS = SR x RS - assert 'S0S1 = S0R x RS1' in strategy_name_list - assert 'S1S0 = S1R x RS0' in strategy_name_list + assert "S0S1 = S0R x RS1" in strategy_name_list + assert "S1S0 = S1R x RS0" in strategy_name_list # SR = SR x RR - assert 'S0R = S0R x RR' in strategy_name_list - assert 'S1R = S1R x RR' in strategy_name_list + assert "S0R = S0R x RR" in strategy_name_list + assert "S1R = S1R x RR" in strategy_name_list # SR = SS x SR - assert 'S0R = S0S1 x S1R' in strategy_name_list - assert 'S1R = S1S0 x S0R' in strategy_name_list + assert "S0R = S0S1 x S1R" in strategy_name_list + assert "S1R = S1S0 x S0R" in strategy_name_list # RS = RS x SS - assert 'RS0 = RS1 x S1S0' in strategy_name_list - assert 'RS1 = RS0 x S0S1' in strategy_name_list + assert "RS0 = RS1 x S1S0" in strategy_name_list + assert "RS1 = RS0 x S0S1" in strategy_name_list # RR = RS x SR - assert 'RR = RS0 x S0R' in strategy_name_list - assert 'RR = RS1 x S1R' in strategy_name_list + assert "RR = RS0 x S0R" in strategy_name_list + assert "RR = RS1 x S1R" in strategy_name_list # RS= RR x RS - assert 'RS0 = RR x RS0' in strategy_name_list - assert 'RS1 = RR x RS1' in strategy_name_list + assert "RS0 = RR x RS0" in strategy_name_list + assert "RS1 = RR x RS1" in strategy_name_list # RR = RR x RR - assert 'RR = RR x RR' in strategy_name_list + assert "RR = RR x RR" in strategy_name_list # S01R = S01R x RR - assert 'S01R = S01R x RR' in strategy_name_list + assert "S01R = S01R x RR" in strategy_name_list # RR = RS01 x S01R - assert 'RR = RS01 x S01R' in strategy_name_list + assert "RR = RS01 x S01R" in strategy_name_list # RS01 = RR x RS01 - assert 'RS01 = RR x RS01' in strategy_name_list + assert "RS01 = RR x RS01" in strategy_name_list for strategy in strategies_vector: - input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') - weight_sharding_spec = strategy.get_sharding_spec_by_name('others') - output_sharding_spec = strategy.get_sharding_spec_by_name('conv2d') + input_sharding_spec = strategy.get_sharding_spec_by_name("input_1") + weight_sharding_spec = strategy.get_sharding_spec_by_name("others") + output_sharding_spec = strategy.get_sharding_spec_by_name("conv2d") if bias: - bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') + bias_sharding_spec = strategy.get_sharding_spec_by_name("bias") # make sure the sharding matches across different operation data assert output_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[0] @@ -290,7 +293,7 @@ def check_conv_function_handler(rank, world_size, port, bias): assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1] -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist # We temporarily ban the bias option before doing bias add # before all reduce communication may encounter correctness issue. 
@@ -300,7 +303,7 @@ def test_conv_module_handler(bias=False): spawn(check_conv_module_handler, 4, bias=bias) -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist # We temporarily ban the bias option before doing bias add # before all reduce communication may encounter correctness issue. @@ -310,6 +313,6 @@ def test_conv_function_handler(bias=False): spawn(check_conv_function_handler, 4, bias=bias) -if __name__ == '__main__': +if __name__ == "__main__": test_conv_module_handler() test_conv_function_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py index 64f56ba98e2b..ce2ae4248fce 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py @@ -12,7 +12,6 @@ class ReshapeModel(nn.Module): - def __init__(self): super().__init__() @@ -22,7 +21,7 @@ def forward(self, input, other): return reshape_node -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @clear_cache_before_run() def test_reshape_handler(): model = ReshapeModel() @@ -34,8 +33,8 @@ def test_reshape_handler(): # %view : [#users=1] = call_method[target=view](args = (%conv2d, 2, -1), kwargs = {}) # return view meta_args = { - "input": torch.rand(4, 4, 64, 64).to('meta'), - "other": torch.rand(16, 4, 3, 3).to('meta'), + "input": torch.rand(4, 4, 64, 64).to("meta"), + "other": torch.rand(16, 4, 3, 3).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) @@ -50,14 +49,14 @@ def test_reshape_handler(): conv_strategies_vector = StrategiesVector(conv_mod_node) # build handler - conv_handler = ConvFunctionHandler(node=conv_mod_node, - device_mesh=device_mesh, - strategies_vector=conv_strategies_vector) + conv_handler = ConvFunctionHandler( + node=conv_mod_node, device_mesh=device_mesh, strategies_vector=conv_strategies_vector + ) conv_handler.register_strategy(compute_resharding_cost=False) - setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector) - reshape_handler = DefaultReshapeHandler(node=reshape_node, - device_mesh=device_mesh, - strategies_vector=reshape_strategies_vector) + setattr(conv_mod_node, "strategies_vector", conv_strategies_vector) + reshape_handler = DefaultReshapeHandler( + node=reshape_node, device_mesh=device_mesh, strategies_vector=reshape_strategies_vector + ) reshape_handler.register_strategy(compute_resharding_cost=False) @@ -69,20 +68,20 @@ def test_reshape_handler(): # make sure they have valid values assert op_data.data is not None - assert mapping['input'].name == "conv2d" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 16, 62, 62]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 16, 62, 62]) + assert mapping["input"].name == "conv2d" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 16, 62, 62]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 16, 62, 62]) - assert mapping['output'].name == "view" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([2, 123008]) - assert mapping['output'].type == 
OperationDataType.OUTPUT + assert mapping["output"].name == "view" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([2, 123008]) + assert mapping["output"].type == OperationDataType.OUTPUT # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. assert len(reshape_strategies_vector) == len(conv_strategies_vector) -if __name__ == '__main__': +if __name__ == "__main__": test_reshape_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py index 4fa0313b1cb5..9ac6ba95da48 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py @@ -22,7 +22,6 @@ class EmbeddingModule(nn.Module): - def __init__(self, num_embeddings, embedding_dims): super().__init__() self.embedding = nn.Embedding(num_embeddings, embedding_dims) @@ -34,7 +33,7 @@ def forward(self, input): def check_embedding_module_handler(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = EmbeddingModule(num_embeddings=NUM_EMBEDDINGS, embedding_dims=EMBEDDING_DIMS).cuda() # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] @@ -51,15 +50,17 @@ def check_embedding_module_handler(rank, world_size, port): node_index = 1 # total number of embedding strategies strategy_number = 19 - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input], - meta_arg_names=['input']) + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=["input"], + ) tracer = ColoTracer(bias_addition_split=True) - meta_args = {"input": torch.randint(NUM_EMBEDDINGS, (4, 16, 16)).to('meta')} + meta_args = {"input": torch.randint(NUM_EMBEDDINGS, (4, 16, 16)).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -78,60 +79,60 @@ def check_embedding_module_handler(rank, world_size, port): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "input_1" + assert mapping["input"].name == "input_1" # assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 16, 16]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([1024]) + assert mapping["input"].data.shape == torch.Size([4, 16, 16]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([1024]) - assert mapping['other'].name == "weight" - assert mapping['other'].data.shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) - assert mapping['other'].type == OperationDataType.PARAM - assert mapping['other'].logical_shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) + assert mapping["other"].name == "weight" + assert mapping["other"].data.shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) + assert mapping["other"].type == 
OperationDataType.PARAM + assert mapping["other"].logical_shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) - assert mapping['output'].name == "embedding" - assert mapping['output'].data.shape == torch.Size([4, 16, 16, EMBEDDING_DIMS]) - assert mapping['output'].type == OperationDataType.OUTPUT - assert mapping['output'].logical_shape == torch.Size([1024, EMBEDDING_DIMS]) + assert mapping["output"].name == "embedding" + assert mapping["output"].data.shape == torch.Size([4, 16, 16, EMBEDDING_DIMS]) + assert mapping["output"].type == OperationDataType.OUTPUT + assert mapping["output"].logical_shape == torch.Size([1024, EMBEDDING_DIMS]) strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # RR = RR x RR - assert 'RR = R x RR' in strategy_name_list + assert "RR = R x RR" in strategy_name_list # SR = SR x RR - assert 'S0R = S0 x RR_0' in strategy_name_list - assert 'S0R = S0 x RR_1' in strategy_name_list - assert 'S0R = S0 x RR_2' in strategy_name_list - assert 'S1R = S1 x RR_0' in strategy_name_list - assert 'S1R = S1 x RR_1' in strategy_name_list - assert 'S1R = S1 x RR_2' in strategy_name_list + assert "S0R = S0 x RR_0" in strategy_name_list + assert "S0R = S0 x RR_1" in strategy_name_list + assert "S0R = S0 x RR_2" in strategy_name_list + assert "S1R = S1 x RR_0" in strategy_name_list + assert "S1R = S1 x RR_1" in strategy_name_list + assert "S1R = S1 x RR_2" in strategy_name_list # SS = SR x RS - assert 'S0S1 = S0 x RS1_0' in strategy_name_list - assert 'S0S1 = S0 x RS1_1' in strategy_name_list - assert 'S0S1 = S0 x RS1_2' in strategy_name_list - assert 'S1S0 = S1 x RS0_0' in strategy_name_list - assert 'S1S0 = S1 x RS0_1' in strategy_name_list - assert 'S1S0 = S1 x RS0_2' in strategy_name_list + assert "S0S1 = S0 x RS1_0" in strategy_name_list + assert "S0S1 = S0 x RS1_1" in strategy_name_list + assert "S0S1 = S0 x RS1_2" in strategy_name_list + assert "S1S0 = S1 x RS0_0" in strategy_name_list + assert "S1S0 = S1 x RS0_1" in strategy_name_list + assert "S1S0 = S1 x RS0_2" in strategy_name_list # RS= RR x RS - assert 'RS0 = R x RS0' in strategy_name_list - assert 'RS1 = R x RS1' in strategy_name_list + assert "RS0 = R x RS0" in strategy_name_list + assert "RS1 = R x RS1" in strategy_name_list # S01R = S01R x RR - assert 'S01R = S01 x RR_0' in strategy_name_list - assert 'S01R = S01 x RR_1' in strategy_name_list - assert 'S01R = S01 x RR_2' in strategy_name_list + assert "S01R = S01 x RR_0" in strategy_name_list + assert "S01R = S01 x RR_1" in strategy_name_list + assert "S01R = S01 x RR_2" in strategy_name_list # RS01 = RR x RS01 - assert 'RS01 = R x RS01' in strategy_name_list + assert "RS01 = R x RS01" in strategy_name_list for strategy in strategies_vector: - input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') - weight_sharding_spec = strategy.get_sharding_spec_by_name('weight') - output_sharding_spec = strategy.get_sharding_spec_by_name('embedding') + input_sharding_spec = strategy.get_sharding_spec_by_name("input_1") + weight_sharding_spec = strategy.get_sharding_spec_by_name("weight") + output_sharding_spec = strategy.get_sharding_spec_by_name("embedding") # make sure the sharding matches across different operation data assert output_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[-1] @@ -139,7 +140,6 @@ def check_embedding_module_handler(rank, world_size, port): class EmbeddingFunction(nn.Module): - def __init__(self): super().__init__() @@ -150,7 
+150,7 @@ def forward(self, input, others): def check_embedding_function_handler(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = EmbeddingFunction().cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -159,18 +159,20 @@ def check_embedding_function_handler(rank, world_size, port): input = input.to(torch.int64).cuda() others = torch.rand(NUM_EMBEDDINGS, EMBEDDING_DIMS).cuda() input_args = [input, others] - meta_arg_names = ['input', 'others'] + meta_arg_names = ["input", "others"] input_kwargs = {} # total number of embedding strategies strategy_number = 19 node_index = 2 - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names, - input_kwargs=input_kwargs) + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + input_kwargs=input_kwargs, + ) tracer = ColoTracer(bias_addition_split=True) # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] @@ -178,8 +180,8 @@ def check_embedding_function_handler(rank, world_size, port): # %embedding : [#users=1] = call_function[target=torch.nn.functional.embedding](args = (%input_1, %others), kwargs = {padding_idx: None, max_norm: None, norm_type: 2.0, scale_grad_by_freq: False, sparse: False}) # return embedding meta_args = { - "input": torch.randint(NUM_EMBEDDINGS, (4, 16, 16)).to('meta'), - "others": torch.rand(NUM_EMBEDDINGS, EMBEDDING_DIMS).to('meta') + "input": torch.randint(NUM_EMBEDDINGS, (4, 16, 16)).to("meta"), + "others": torch.rand(NUM_EMBEDDINGS, EMBEDDING_DIMS).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) @@ -189,9 +191,9 @@ def check_embedding_function_handler(rank, world_size, port): strategies_vector = StrategiesVector(embedding_node) # build handler - handler = EmbeddingFunctionHandler(node=embedding_node, - device_mesh=device_mesh, - strategies_vector=strategies_vector) + handler = EmbeddingFunctionHandler( + node=embedding_node, device_mesh=device_mesh, strategies_vector=strategies_vector + ) # check operation data mapping mapping = handler.get_operation_data_mapping() @@ -202,82 +204,82 @@ def check_embedding_function_handler(rank, world_size, port): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "input_1" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 16, 16]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([1024]) + assert mapping["input"].name == "input_1" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 16, 16]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([1024]) - assert mapping['other'].name == "others" - assert mapping['other'].data.is_meta - assert mapping['other'].data.shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) - assert mapping['other'].type == OperationDataType.ARG - assert mapping['other'].logical_shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) + assert 
mapping["other"].name == "others" + assert mapping["other"].data.is_meta + assert mapping["other"].data.shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) + assert mapping["other"].type == OperationDataType.ARG + assert mapping["other"].logical_shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) - assert mapping['output'].name == "embedding" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 16, 16, EMBEDDING_DIMS]) - assert mapping['output'].type == OperationDataType.OUTPUT - assert mapping['output'].logical_shape == torch.Size([1024, EMBEDDING_DIMS]) + assert mapping["output"].name == "embedding" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 16, 16, EMBEDDING_DIMS]) + assert mapping["output"].type == OperationDataType.OUTPUT + assert mapping["output"].logical_shape == torch.Size([1024, EMBEDDING_DIMS]) handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # RR = RR x RR - assert 'RR = R x RR' in strategy_name_list + assert "RR = R x RR" in strategy_name_list # SR = SR x RR - assert 'S0R = S0 x RR_0' in strategy_name_list - assert 'S0R = S0 x RR_1' in strategy_name_list - assert 'S0R = S0 x RR_2' in strategy_name_list - assert 'S1R = S1 x RR_0' in strategy_name_list - assert 'S1R = S1 x RR_1' in strategy_name_list - assert 'S1R = S1 x RR_2' in strategy_name_list + assert "S0R = S0 x RR_0" in strategy_name_list + assert "S0R = S0 x RR_1" in strategy_name_list + assert "S0R = S0 x RR_2" in strategy_name_list + assert "S1R = S1 x RR_0" in strategy_name_list + assert "S1R = S1 x RR_1" in strategy_name_list + assert "S1R = S1 x RR_2" in strategy_name_list # SS = SR x RS - assert 'S0S1 = S0 x RS1_0' in strategy_name_list - assert 'S0S1 = S0 x RS1_1' in strategy_name_list - assert 'S0S1 = S0 x RS1_2' in strategy_name_list - assert 'S1S0 = S1 x RS0_0' in strategy_name_list - assert 'S1S0 = S1 x RS0_1' in strategy_name_list - assert 'S1S0 = S1 x RS0_2' in strategy_name_list + assert "S0S1 = S0 x RS1_0" in strategy_name_list + assert "S0S1 = S0 x RS1_1" in strategy_name_list + assert "S0S1 = S0 x RS1_2" in strategy_name_list + assert "S1S0 = S1 x RS0_0" in strategy_name_list + assert "S1S0 = S1 x RS0_1" in strategy_name_list + assert "S1S0 = S1 x RS0_2" in strategy_name_list # RS= RR x RS - assert 'RS0 = R x RS0' in strategy_name_list - assert 'RS1 = R x RS1' in strategy_name_list + assert "RS0 = R x RS0" in strategy_name_list + assert "RS1 = R x RS1" in strategy_name_list # S01R = S01R x RR - assert 'S01R = S01 x RR_0' in strategy_name_list - assert 'S01R = S01 x RR_1' in strategy_name_list - assert 'S01R = S01 x RR_2' in strategy_name_list + assert "S01R = S01 x RR_0" in strategy_name_list + assert "S01R = S01 x RR_1" in strategy_name_list + assert "S01R = S01 x RR_2" in strategy_name_list # RS01 = RR x RS01 - assert 'RS01 = R x RS01' in strategy_name_list + assert "RS01 = R x RS01" in strategy_name_list for strategy in strategies_vector: - input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') - weight_sharding_spec = strategy.get_sharding_spec_by_name('others') - output_sharding_spec = strategy.get_sharding_spec_by_name('embedding') + input_sharding_spec = strategy.get_sharding_spec_by_name("input_1") + weight_sharding_spec = strategy.get_sharding_spec_by_name("others") + output_sharding_spec = strategy.get_sharding_spec_by_name("embedding") # make sure the sharding matches across different operation data assert 
output_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[-1] assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence[:-1] -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_embedding_module_handler(): spawn(check_embedding_module_handler, 4) -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_embedding_function_handler(): spawn(check_embedding_function_handler, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_embedding_module_handler() test_embedding_function_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py index a089df743ec0..2c464f64d8ca 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py @@ -12,7 +12,6 @@ class GetattrModel(nn.Module): - def __init__(self): super().__init__() self.conv = nn.Conv2d(4, 16, 3, padding=1, bias=False) @@ -22,7 +21,7 @@ def forward(self, input): return weight -@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') +@pytest.mark.skip("ShapeProp is not compatible with PyTorch 1.11.0") @clear_cache_before_run() def test_getattr_handler(): model = GetattrModel() @@ -31,7 +30,7 @@ def test_getattr_handler(): # %input_1 : torch.Tensor [#users=0] = placeholder[target=input] # %conv_weight : [#users=1] = get_attr[target=conv.weight] # return conv_weight - meta_args = {'input': torch.rand(4, 4, 64, 64).to('meta')} + meta_args = {"input": torch.rand(4, 4, 64, 64).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -42,9 +41,9 @@ def test_getattr_handler(): getattr_strategies_vector = StrategiesVector(getattr_node) # build handler - getattr_handler = GetattrHandler(node=getattr_node, - device_mesh=device_mesh, - strategies_vector=getattr_strategies_vector) + getattr_handler = GetattrHandler( + node=getattr_node, device_mesh=device_mesh, strategies_vector=getattr_strategies_vector + ) getattr_handler.register_strategy(compute_resharding_cost=False) @@ -56,20 +55,20 @@ def test_getattr_handler(): # make sure they have valid values assert op_data.data is not None - assert mapping['output'].name == "conv_weight" - assert mapping['output'].data.shape == torch.Size((16, 4, 3, 3)) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "conv_weight" + assert mapping["output"].data.shape == torch.Size((16, 4, 3, 3)) + assert mapping["output"].type == OperationDataType.OUTPUT strategy_name_list = [val.name for val in getattr_handler.strategies_vector] - assert 'get_attr [S0, S1, R, R]' in strategy_name_list - assert 'get_attr [S1, S0, R, R]' in strategy_name_list - assert 'get_attr [S01, R, R, R]' in strategy_name_list - assert 'get_attr [R, S01, R, R]' in strategy_name_list - assert 'get_attr [S0, R, R, R]' in strategy_name_list - assert 'get_attr [R, S0, R, R]' in strategy_name_list - assert 'get_attr [S1, R, R, R]' in strategy_name_list - assert 'get_attr [R, S1, R, R]' in strategy_name_list - assert 'get_attr [R, R, R, R]' in strategy_name_list + assert "get_attr [S0, S1, R, R]" in 
strategy_name_list + assert "get_attr [S1, S0, R, R]" in strategy_name_list + assert "get_attr [S01, R, R, R]" in strategy_name_list + assert "get_attr [R, S01, R, R]" in strategy_name_list + assert "get_attr [S0, R, R, R]" in strategy_name_list + assert "get_attr [R, S0, R, R]" in strategy_name_list + assert "get_attr [S1, R, R, R]" in strategy_name_list + assert "get_attr [R, S1, R, R]" in strategy_name_list + assert "get_attr [R, R, R, R]" in strategy_name_list -if __name__ == '__main__': +if __name__ == "__main__": test_getattr_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py index a2e0968b18bb..cf802a228034 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py @@ -1,5 +1,3 @@ -from functools import partial - import pytest import torch import torch.nn as nn @@ -21,7 +19,6 @@ class GetItemFromTensorModel(nn.Module): - def __init__(self, getitem_index): super().__init__() self.getitem_index = getitem_index @@ -34,12 +31,12 @@ def forward(self, input, other): def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = GetItemFromTensorModel(getitem_index=getitem_index) - input = torch.rand(8, 16, 64, 32).to('cuda') - other = torch.rand(64, 32).to('cuda') + input = torch.rand(8, 16, 64, 32).to("cuda") + other = torch.rand(64, 32).to("cuda") # index of linear node in computation graph node_index = 2 # total number of linear strategies @@ -49,18 +46,20 @@ def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port): mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input, other], - meta_arg_names=['input', 'other'], - node_type='following') + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=["input", "other"], + node_type="following", + ) tracer = ColoTracer(bias_addition_split=True) meta_args = { - "input": torch.rand(8, 16, 64, 32).to('meta'), - "other": torch.rand(64, 32).to('meta'), + "input": torch.rand(8, 16, 64, 32).to("meta"), + "other": torch.rand(64, 32).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) @@ -72,14 +71,14 @@ def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port): linear_strategies_vector = StrategiesVector(linear_mod_node) # build handler - linear_handler = LinearFunctionHandler(node=linear_mod_node, - device_mesh=device_mesh, - strategies_vector=linear_strategies_vector) + linear_handler = LinearFunctionHandler( + node=linear_mod_node, device_mesh=device_mesh, strategies_vector=linear_strategies_vector + ) linear_handler.register_strategy(compute_resharding_cost=False) - setattr(linear_mod_node, 'strategies_vector', linear_strategies_vector) - getitem_handler = GetItemHandler(node=getitem_mod_node, - device_mesh=device_mesh, - 
strategies_vector=getitem_strategies_vector) + setattr(linear_mod_node, "strategies_vector", linear_strategies_vector) + getitem_handler = GetItemHandler( + node=getitem_mod_node, device_mesh=device_mesh, strategies_vector=getitem_strategies_vector + ) getitem_handler.register_strategy(compute_resharding_cost=False) # check operation data mapping @@ -94,17 +93,16 @@ def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port): assert len(getitem_strategies_vector) == len(linear_strategies_vector) -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() # @parameterize('getitem_index', [slice(0, 2), (slice(None), slice(None))]) -@parameterize('getitem_index', [1, (1, 4), slice(0, 2), (slice(None), slice(None))]) +@parameterize("getitem_index", [1, (1, 4), slice(0, 2), (slice(None), slice(None))]) def test_getitem_from_tensor_handler(getitem_index): spawn(check_getitem_from_tensor_handler, 4) class GetItemFromTupleModel(nn.Module): - def __init__(self): super().__init__() @@ -114,7 +112,7 @@ def forward(self, input): return x -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @clear_cache_before_run() def test_getitem_from_tuple_handler(): model = GetItemFromTupleModel() @@ -125,7 +123,7 @@ def test_getitem_from_tuple_handler(): # %getitem : [#users=1] = call_function[target=operator.getitem](args = (%split, 1), kwargs = {}) # return getitem meta_args = { - "input": torch.rand(4, 4, 64, 64).to('meta'), + "input": torch.rand(4, 4, 64, 64).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) @@ -146,20 +144,20 @@ def test_getitem_from_tuple_handler(): node=input_node, device_mesh=device_mesh, strategies_vector=input_strategies_vector, - placeholder_option='replicated', + placeholder_option="replicated", ) input_handler.register_strategy(compute_resharding_cost=False) - setattr(input_node, 'strategies_vector', input_strategies_vector) - split_handler = DefaultReshapeHandler(node=split_node, - device_mesh=device_mesh, - strategies_vector=split_strategies_vector) + setattr(input_node, "strategies_vector", input_strategies_vector) + split_handler = DefaultReshapeHandler( + node=split_node, device_mesh=device_mesh, strategies_vector=split_strategies_vector + ) split_handler.register_strategy(compute_resharding_cost=False) - setattr(split_node, 'strategies_vector', split_strategies_vector) - getitem_handler = GetItemHandler(node=getitem_node, - device_mesh=device_mesh, - strategies_vector=getitem_strategies_vector) + setattr(split_node, "strategies_vector", split_strategies_vector) + getitem_handler = GetItemHandler( + node=getitem_node, device_mesh=device_mesh, strategies_vector=getitem_strategies_vector + ) getitem_handler.register_strategy(compute_resharding_cost=False) - setattr(getitem_node, 'strategies_vector', getitem_strategies_vector) + setattr(getitem_node, "strategies_vector", getitem_strategies_vector) # check operation data mapping mapping = getitem_handler.get_operation_data_mapping() @@ -169,23 +167,23 @@ def test_getitem_from_tuple_handler(): # make sure they have valid values assert op_data.data is not None - assert mapping['input'].name == "split" - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == (torch.Size([2, 4, 64, 64]), torch.Size([2, 4, 64, 64])) + assert mapping["input"].name == "split" + assert mapping["input"].type == 
OperationDataType.ARG + assert mapping["input"].logical_shape == (torch.Size([2, 4, 64, 64]), torch.Size([2, 4, 64, 64])) - assert mapping['index'].name == "index" - assert isinstance(mapping['index'].data, int) - assert mapping['index'].type == OperationDataType.ARG + assert mapping["index"].name == "index" + assert isinstance(mapping["index"].data, int) + assert mapping["index"].type == OperationDataType.ARG - assert mapping['output'].name == "getitem" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([2, 4, 64, 64]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "getitem" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([2, 4, 64, 64]) + assert mapping["output"].type == OperationDataType.OUTPUT # getitem is a following strategy handler, so the number of strategies is equal to the predecessor node. assert len(getitem_strategies_vector) == len(split_strategies_vector) -if __name__ == '__main__': +if __name__ == "__main__": test_getitem_from_tensor_handler() test_getitem_from_tuple_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py index ad72c2026b9a..59a66bc6a5d6 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py @@ -17,7 +17,7 @@ def check_ln_module_handler(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.LayerNorm(16)).cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -30,19 +30,21 @@ def check_ln_module_handler(rank, world_size, port): # construct input args input_args = [input] # construct meta arg names - meta_arg_names = ['input'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names) + meta_arg_names = ["input"] + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + ) tracer = ColoTracer(bias_addition_split=True) # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {}) # return _0 - meta_args = {"input": torch.rand(4, 16).to('meta')} + meta_args = {"input": torch.rand(4, 16).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -62,45 +64,45 @@ def check_ln_module_handler(rank, world_size, port): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "input_1" - assert mapping['input'].data.shape == torch.Size([4, 16]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 16]) + assert mapping["input"].name == "input_1" + assert mapping["input"].data.shape == torch.Size([4, 16]) + assert mapping["input"].type == OperationDataType.ARG + assert 
mapping["input"].logical_shape == torch.Size([4, 16]) - assert mapping['other'].name == "weight" - assert mapping['other'].data.shape == torch.Size([16]) - assert mapping['other'].type == OperationDataType.PARAM - assert mapping['other'].logical_shape == torch.Size([16]) + assert mapping["other"].name == "weight" + assert mapping["other"].data.shape == torch.Size([16]) + assert mapping["other"].type == OperationDataType.PARAM + assert mapping["other"].logical_shape == torch.Size([16]) - assert mapping['bias'].name == "bias" - assert mapping['bias'].data.shape == torch.Size([16]) - assert mapping['bias'].type == OperationDataType.PARAM - assert mapping['bias'].logical_shape == torch.Size([16]) + assert mapping["bias"].name == "bias" + assert mapping["bias"].data.shape == torch.Size([16]) + assert mapping["bias"].type == OperationDataType.PARAM + assert mapping["bias"].logical_shape == torch.Size([16]) - assert mapping['output'].name == "_0" - assert mapping['output'].data.shape == torch.Size([4, 16]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "_0" + assert mapping["output"].data.shape == torch.Size([4, 16]) + assert mapping["output"].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # SR = SR x R - assert '[S0, R] = [S0, R] x [R]' in strategy_name_list - assert '[S1, R] = [S1, R] x [R]' in strategy_name_list + assert "[S0, R] = [S0, R] x [R]" in strategy_name_list + assert "[S1, R] = [S1, R] x [R]" in strategy_name_list # RR = RR x R - assert 'RR = RR x R' in strategy_name_list + assert "RR = RR x R" in strategy_name_list # S01R = S01R x R - assert '[S01, R] = [S01, R] x [R]' in strategy_name_list + assert "[S01, R] = [S01, R] x [R]" in strategy_name_list -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_ln_module_handler(): spawn(check_ln_module_handler, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_ln_module_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py index ec695cd8f7b9..da88b735f7c1 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py @@ -23,7 +23,7 @@ def check_linear_module_handler(rank, world_size, port, bias, input_shape): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.Linear(16, 32, bias=bias)).cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -39,13 +39,15 @@ def check_linear_module_handler(rank, world_size, port, bias, input_shape): # construct input args input_args = [input] # construct meta arg names - meta_arg_names = ['input'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names) + meta_arg_names = ["input"] + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + 
input_args=input_args, + meta_arg_names=meta_arg_names, + ) tracer = ColoTracer(bias_addition_split=True) meta_args = {"input": torch.rand(input_shape).cuda()} @@ -68,86 +70,86 @@ def check_linear_module_handler(rank, world_size, port, bias, input_shape): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "input_1" - assert mapping['input'].data.shape == torch.Size(input_shape) - assert mapping['input'].type == OperationDataType.ARG - input_logical_shape = mapping['input'].data.view(-1, 16).shape - assert mapping['input'].logical_shape == input_logical_shape + assert mapping["input"].name == "input_1" + assert mapping["input"].data.shape == torch.Size(input_shape) + assert mapping["input"].type == OperationDataType.ARG + input_logical_shape = mapping["input"].data.view(-1, 16).shape + assert mapping["input"].logical_shape == input_logical_shape - assert mapping['other'].name == "weight" - assert mapping['other'].data.shape == torch.Size([32, 16]) - assert mapping['other'].type == OperationDataType.PARAM - assert mapping['other'].logical_shape == torch.Size([16, 32]) + assert mapping["other"].name == "weight" + assert mapping["other"].data.shape == torch.Size([32, 16]) + assert mapping["other"].type == OperationDataType.PARAM + assert mapping["other"].logical_shape == torch.Size([16, 32]) if bias: - assert mapping['bias'].name == "bias" - assert mapping['bias'].data.shape == torch.Size([32]) - assert mapping['bias'].type == OperationDataType.PARAM - assert mapping['bias'].logical_shape == torch.Size([32]) + assert mapping["bias"].name == "bias" + assert mapping["bias"].data.shape == torch.Size([32]) + assert mapping["bias"].type == OperationDataType.PARAM + assert mapping["bias"].logical_shape == torch.Size([32]) - assert mapping['output'].name == "_0" + assert mapping["output"].name == "_0" output_shape = input_shape[:-1] + (32,) - assert mapping['output'].data.shape == torch.Size(output_shape) - assert mapping['output'].type == OperationDataType.OUTPUT - output_logical_shape = mapping['output'].data.view(-1, 32).shape - assert mapping['output'].logical_shape == torch.Size(output_logical_shape) + assert mapping["output"].data.shape == torch.Size(output_shape) + assert mapping["output"].type == OperationDataType.OUTPUT + output_logical_shape = mapping["output"].data.view(-1, 32).shape + assert mapping["output"].logical_shape == torch.Size(output_logical_shape) strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # First dimension cannot be shard if input shape is (1, 4, 4, 16) if input_shape != (1, 4, 4, 16): - assert 'S1S0 = S1R x RS0_0' in strategy_name_list - assert 'S0S1 = S0R x RS1_0' in strategy_name_list - assert 'S1R = S1S0 x S0R_0' in strategy_name_list - assert 'S0R = S0S1 x S1R_0' in strategy_name_list - assert 'S01R = S01R x RR_0' in strategy_name_list + assert "S1S0 = S1R x RS0_0" in strategy_name_list + assert "S0S1 = S0R x RS1_0" in strategy_name_list + assert "S1R = S1S0 x S0R_0" in strategy_name_list + assert "S0R = S0S1 x S1R_0" in strategy_name_list + assert "S01R = S01R x RR_0" in strategy_name_list # SS = SR x RS - assert 'S0S1 = S0R x RS1_1' in strategy_name_list - assert 'S0S1 = S0R x RS1_2' in strategy_name_list - assert 'S1S0 = S1R x RS0_1' in strategy_name_list - assert 'S1S0 = S1R x RS0_2' in strategy_name_list + assert "S0S1 = S0R x RS1_1" in strategy_name_list + assert "S0S1 = S0R x RS1_2" in strategy_name_list + assert 
"S1S0 = S1R x RS0_1" in strategy_name_list + assert "S1S0 = S1R x RS0_2" in strategy_name_list # SR = SS x SR - assert 'S0R = S0S1 x S1R_1' in strategy_name_list - assert 'S0R = S0S1 x S1R_2' in strategy_name_list - assert 'S1R = S1S0 x S0R_1' in strategy_name_list - assert 'S1R = S1S0 x S0R_2' in strategy_name_list + assert "S0R = S0S1 x S1R_1" in strategy_name_list + assert "S0R = S0S1 x S1R_2" in strategy_name_list + assert "S1R = S1S0 x S0R_1" in strategy_name_list + assert "S1R = S1S0 x S0R_2" in strategy_name_list # RS = RS x SS - assert 'RS0 = RS1 x S1S0' in strategy_name_list - assert 'RS1 = RS0 x S0S1' in strategy_name_list + assert "RS0 = RS1 x S1S0" in strategy_name_list + assert "RS1 = RS0 x S0S1" in strategy_name_list # RR = RS x SR - assert 'RR = RS0 x S0R' in strategy_name_list - assert 'RR = RS1 x S1R' in strategy_name_list + assert "RR = RS0 x S0R" in strategy_name_list + assert "RR = RS1 x S1R" in strategy_name_list # RS= RR x RS - assert 'RS0 = RR x RS0' in strategy_name_list - assert 'RS1 = RR x RS1' in strategy_name_list + assert "RS0 = RR x RS0" in strategy_name_list + assert "RS1 = RR x RS1" in strategy_name_list # S01R = S01R x RR - assert 'S01R = S01R x RR_1' in strategy_name_list - assert 'S01R = S01R x RR_2' in strategy_name_list + assert "S01R = S01R x RR_1" in strategy_name_list + assert "S01R = S01R x RR_2" in strategy_name_list # RR = RS01 x S01R - assert 'RR = RS01 x S01R' in strategy_name_list + assert "RR = RS01 x S01R" in strategy_name_list # RS01 = RR x RS01 - assert 'RS01 = RR x RS01' in strategy_name_list + assert "RS01 = RR x RS01" in strategy_name_list # RR = RR x RR - assert 'RR = RR x RR' in strategy_name_list + assert "RR = RR x RR" in strategy_name_list for strategy in strategies_vector: strategy: ShardingStrategy - input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') - weight_sharding_spec = strategy.get_sharding_spec_by_name('weight') - output_sharding_spec = strategy.get_sharding_spec_by_name('_0') + input_sharding_spec = strategy.get_sharding_spec_by_name("input_1") + weight_sharding_spec = strategy.get_sharding_spec_by_name("weight") + output_sharding_spec = strategy.get_sharding_spec_by_name("_0") if bias: - bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') + bias_sharding_spec = strategy.get_sharding_spec_by_name("bias") # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] @@ -159,7 +161,6 @@ def check_linear_module_handler(rank, world_size, port, bias, input_shape): class LinearModel(nn.Module): - def __init__(self): super().__init__() @@ -170,7 +171,7 @@ def forward(self, input, others, bias=None): def check_linear_function_handler(rank, world_size, port, bias, input_shape): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = LinearModel().cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -188,16 +189,18 @@ def check_linear_function_handler(rank, world_size, port, bias, input_shape): # construct input args input_args = [input, other] # construct meta arg names - meta_arg_names = ['input', 'others'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names) + meta_arg_names = 
["input", "others"] + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + ) tracer = ColoTracer(bias_addition_split=True) - meta_args = {'input': torch.rand(input_shape).to('meta'), 'others': torch.rand(32, 16).to('meta')} + meta_args = {"input": torch.rand(input_shape).to("meta"), "others": torch.rand(32, 16).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -214,86 +217,86 @@ def check_linear_function_handler(rank, world_size, port, bias, input_shape): # # check operation data mapping mapping = handler.get_operation_data_mapping() - assert mapping['input'].name == "input_1" - assert mapping['input'].data.shape == torch.Size(input_shape) - assert mapping['input'].type == OperationDataType.ARG - input_logical_shape = mapping['input'].data.view(-1, 16).shape - assert mapping['input'].logical_shape == torch.Size(input_logical_shape) + assert mapping["input"].name == "input_1" + assert mapping["input"].data.shape == torch.Size(input_shape) + assert mapping["input"].type == OperationDataType.ARG + input_logical_shape = mapping["input"].data.view(-1, 16).shape + assert mapping["input"].logical_shape == torch.Size(input_logical_shape) - assert mapping['other'].name == "others" - assert mapping['other'].data.shape == torch.Size([32, 16]) - assert mapping['other'].type == OperationDataType.ARG - assert mapping['other'].logical_shape == torch.Size([16, 32]) + assert mapping["other"].name == "others" + assert mapping["other"].data.shape == torch.Size([32, 16]) + assert mapping["other"].type == OperationDataType.ARG + assert mapping["other"].logical_shape == torch.Size([16, 32]) if bias: - assert mapping['bias'].name == "bias" - assert mapping['bias'].data.shape == torch.Size([32]) - assert mapping['bias'].type == OperationDataType.ARG - assert mapping['other'].logical_shape == torch.Size([16, 32]) + assert mapping["bias"].name == "bias" + assert mapping["bias"].data.shape == torch.Size([32]) + assert mapping["bias"].type == OperationDataType.ARG + assert mapping["other"].logical_shape == torch.Size([16, 32]) - assert mapping['output'].name == "linear" + assert mapping["output"].name == "linear" output_shape = input_shape[:-1] + (32,) - assert mapping['output'].data.shape == torch.Size(output_shape) - assert mapping['output'].type == OperationDataType.OUTPUT - output_logical_shape = mapping['output'].data.view(-1, 32).shape - assert mapping['output'].logical_shape == torch.Size(output_logical_shape) + assert mapping["output"].data.shape == torch.Size(output_shape) + assert mapping["output"].type == OperationDataType.OUTPUT + output_logical_shape = mapping["output"].data.view(-1, 32).shape + assert mapping["output"].logical_shape == torch.Size(output_logical_shape) strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # First dimension cannot be shard if input shape is (1, 4, 4, 16) if input_shape != (1, 4, 4, 16): - assert 'S1S0 = S1R x RS0_0' in strategy_name_list - assert 'S0S1 = S0R x RS1_0' in strategy_name_list - assert 'S1R = S1S0 x S0R_0' in strategy_name_list - assert 'S0R = S0S1 x S1R_0' in strategy_name_list - assert 'S01R = S01R x RR_0' in strategy_name_list + assert "S1S0 = S1R x RS0_0" in strategy_name_list + assert "S0S1 = S0R x RS1_0" in strategy_name_list + assert "S1R = 
S1S0 x S0R_0" in strategy_name_list + assert "S0R = S0S1 x S1R_0" in strategy_name_list + assert "S01R = S01R x RR_0" in strategy_name_list # SS = SR x RS - assert 'S0S1 = S0R x RS1_1' in strategy_name_list - assert 'S0S1 = S0R x RS1_2' in strategy_name_list - assert 'S1S0 = S1R x RS0_1' in strategy_name_list - assert 'S1S0 = S1R x RS0_2' in strategy_name_list + assert "S0S1 = S0R x RS1_1" in strategy_name_list + assert "S0S1 = S0R x RS1_2" in strategy_name_list + assert "S1S0 = S1R x RS0_1" in strategy_name_list + assert "S1S0 = S1R x RS0_2" in strategy_name_list # SR = SS x SR - assert 'S0R = S0S1 x S1R_1' in strategy_name_list - assert 'S0R = S0S1 x S1R_2' in strategy_name_list - assert 'S1R = S1S0 x S0R_1' in strategy_name_list - assert 'S1R = S1S0 x S0R_2' in strategy_name_list + assert "S0R = S0S1 x S1R_1" in strategy_name_list + assert "S0R = S0S1 x S1R_2" in strategy_name_list + assert "S1R = S1S0 x S0R_1" in strategy_name_list + assert "S1R = S1S0 x S0R_2" in strategy_name_list # RS = RS x SS - assert 'RS0 = RS1 x S1S0' in strategy_name_list - assert 'RS1 = RS0 x S0S1' in strategy_name_list + assert "RS0 = RS1 x S1S0" in strategy_name_list + assert "RS1 = RS0 x S0S1" in strategy_name_list # RR = RS x SR - assert 'RR = RS0 x S0R' in strategy_name_list - assert 'RR = RS1 x S1R' in strategy_name_list + assert "RR = RS0 x S0R" in strategy_name_list + assert "RR = RS1 x S1R" in strategy_name_list # RS= RR x RS - assert 'RS0 = RR x RS0' in strategy_name_list - assert 'RS1 = RR x RS1' in strategy_name_list + assert "RS0 = RR x RS0" in strategy_name_list + assert "RS1 = RR x RS1" in strategy_name_list # S01R = S01R x RR - assert 'S01R = S01R x RR_1' in strategy_name_list - assert 'S01R = S01R x RR_2' in strategy_name_list + assert "S01R = S01R x RR_1" in strategy_name_list + assert "S01R = S01R x RR_2" in strategy_name_list # RR = RS01 x S01R - assert 'RR = RS01 x S01R' in strategy_name_list + assert "RR = RS01 x S01R" in strategy_name_list # RS01 = RR x RS01 - assert 'RS01 = RR x RS01' in strategy_name_list + assert "RS01 = RR x RS01" in strategy_name_list # RR = RR x RR - assert 'RR = RR x RR' in strategy_name_list + assert "RR = RR x RR" in strategy_name_list for strategy in strategies_vector: strategy: ShardingStrategy - input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') - weight_sharding_spec = strategy.get_sharding_spec_by_name('others') - output_sharding_spec = strategy.get_sharding_spec_by_name('linear') + input_sharding_spec = strategy.get_sharding_spec_by_name("input_1") + weight_sharding_spec = strategy.get_sharding_spec_by_name("others") + output_sharding_spec = strategy.get_sharding_spec_by_name("linear") if bias: - bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') + bias_sharding_spec = strategy.get_sharding_spec_by_name("bias") # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] @@ -304,8 +307,8 @@ def check_linear_function_handler(rank, world_size, port, bias, input_shape): assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] -@run_on_environment_flag(name='AUTO_PARALLEL') -@parameterize('input_shape', [(1, 4, 4, 16), (4, 4, 4, 16)]) +@run_on_environment_flag(name="AUTO_PARALLEL") +@parameterize("input_shape", [(1, 4, 4, 16), (4, 4, 4, 16)]) @pytest.mark.dist @rerun_if_address_is_in_use() def test_linear_handler(input_shape, bias=False): @@ -323,5 +326,5 @@ def test_linear_handler(input_shape, 
bias=False): ) -if __name__ == '__main__': +if __name__ == "__main__": test_linear_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py index 938acd3d1eea..5fb4985e2f3c 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py @@ -22,31 +22,31 @@ class MatMulModule(nn.Module): - def forward(self, x1, x2): return torch.matmul(x1, x2) -@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="need pytorch 1.12.0 or higher for aten level operations") @clear_cache_before_run() @parameterize( - 'tensor_shapes', + "tensor_shapes", [ - [[8], [8]], # dot product - [[4, 8], [8]], # mat-vec product - [[4, 8], [8, 16]], # mat-mat product - [[8], [8, 16]], # mat-mat product - [[8], [4, 8, 16]], # batched mat-mat product with padding + broadcasting - [[4, 8, 16], [16]], # batched mat-mat product with padding + broadcasting - [[4, 8, 16], [16, 32]], # batched mat-mat product with broadcasting - [[4, 8, 16], [1, 16, 32]], # batched mat-mat product with broadcasting - [[8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting - [[4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting - [[1, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting - [[1, 4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting - [[2, 1, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting - [[2, 4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product without broadcasting - ]) + [[8], [8]], # dot product + [[4, 8], [8]], # mat-vec product + [[4, 8], [8, 16]], # mat-mat product + [[8], [8, 16]], # mat-mat product + [[8], [4, 8, 16]], # batched mat-mat product with padding + broadcasting + [[4, 8, 16], [16]], # batched mat-mat product with padding + broadcasting + [[4, 8, 16], [16, 32]], # batched mat-mat product with broadcasting + [[4, 8, 16], [1, 16, 32]], # batched mat-mat product with broadcasting + [[8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[1, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[1, 4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[2, 1, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[2, 4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product without broadcasting + ], +) def test_matmul_node_handler(tensor_shapes): input_shape, other_shape = tensor_shapes @@ -61,7 +61,7 @@ def test_matmul_node_handler(tensor_shapes): model = MatMulModule() tracer = ColoTracer(bias_addition_split=True) - meta_args = {"x1": x1.to('meta'), 'x2': x2.to('meta')} + meta_args = {"x1": x1.to("meta"), "x2": x2.to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -92,30 +92,31 @@ def test_matmul_node_handler(tensor_shapes): logical_input_shape = [1] + input_shape elif matmul_type == MatMulType.BMM: logical_input_shape, logical_other_shape, logical_output_shape = _get_bmm_logical_shape( - input_shape, other_shape, handler.transforms) + input_shape, other_shape, handler.transforms + ) else: logical_input_shape = input_shape # 
check input operation data - assert mapping['input'].name == "x1" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size(input_shape) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size(logical_input_shape) + assert mapping["input"].name == "x1" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size(input_shape) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size(logical_input_shape) # check other operation data - assert mapping['other'].name == "x2" - assert mapping['other'].data.is_meta - assert mapping['other'].data.shape == torch.Size(other_shape) - assert mapping['other'].type == OperationDataType.ARG - assert mapping['other'].logical_shape == torch.Size(logical_other_shape) + assert mapping["other"].name == "x2" + assert mapping["other"].data.is_meta + assert mapping["other"].data.shape == torch.Size(other_shape) + assert mapping["other"].type == OperationDataType.ARG + assert mapping["other"].logical_shape == torch.Size(logical_other_shape) # check output - assert mapping['output'].name == "matmul" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size(output_shape) - assert mapping['output'].type == OperationDataType.OUTPUT - assert mapping['output'].logical_shape == torch.Size(logical_output_shape) + assert mapping["output"].name == "matmul" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size(output_shape) + assert mapping["output"].type == OperationDataType.OUTPUT + assert mapping["output"].logical_shape == torch.Size(logical_output_shape) strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] @@ -126,9 +127,9 @@ def test_matmul_node_handler(tensor_shapes): for strategy in strategies_vector: strategy: ShardingStrategy - input_sharding_spec = strategy.get_sharding_spec_by_name('x1') - other_sharding_spec = strategy.get_sharding_spec_by_name('x2') - output_sharding_spec = strategy.get_sharding_spec_by_name('matmul') + input_sharding_spec = strategy.get_sharding_spec_by_name("x1") + other_sharding_spec = strategy.get_sharding_spec_by_name("x2") + output_sharding_spec = strategy.get_sharding_spec_by_name("matmul") if matmul_type == MatMulType.DOT: # dot product will produce a scaler # results should fulfill: @@ -171,5 +172,5 @@ def test_matmul_node_handler(tensor_shapes): assert other_sharding_spec.sharding_sequence[-2] == input_sharding_spec.sharding_sequence[-1] -if __name__ == '__main__': +if __name__ == "__main__": test_matmul_node_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py index 6bff9f9648e2..6b7ac766ff18 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py @@ -10,16 +10,16 @@ from colossalai.testing import clear_cache_before_run, run_on_environment_flag -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @clear_cache_before_run() def test_norm_pool_handler(): - model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta')) + model = nn.Sequential(nn.MaxPool2d(4, padding=1).to("meta")) tracer = 
ColoTracer(bias_addition_split=True) # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {}) # return _0 - meta_args = {"input": torch.rand(4, 4, 64, 64).to('meta')} + meta_args = {"input": torch.rand(4, 4, 64, 64).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) @@ -41,21 +41,21 @@ def test_norm_pool_handler(): # make sure they have valid values assert op_data.data is not None - assert mapping['input'].name == "input_1" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 4, 64, 64]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 4, 64, 64]) + assert mapping["input"].name == "input_1" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 4, 64, 64]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 4, 64, 64]) - assert mapping['output'].name == "_0" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 4, 16, 16]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "_0" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 4, 16, 16]) + assert mapping["output"].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] assert len(strategy_name_list) == 9 -if __name__ == '__main__': +if __name__ == "__main__": test_norm_pool_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py index 1703d5ded2f2..4da986181f89 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py @@ -12,7 +12,6 @@ class OutputModel(nn.Module): - def __init__(self): super().__init__() @@ -21,8 +20,8 @@ def forward(self, x): return x, y -@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') -@parameterize('output_option', ['distributed', 'replicated']) +@pytest.mark.skip("ShapeProp is not compatible with PyTorch 1.11.0") +@parameterize("output_option", ["distributed", "replicated"]) @clear_cache_before_run() def test_output_handler(output_option): model = OutputModel() @@ -31,7 +30,7 @@ def test_output_handler(output_option): # %x : torch.Tensor [#users=2] = placeholder[target=x] # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {}) # return (x, mul) - meta_args = {'x': torch.rand(4, 4, 64, 64).to('meta')} + meta_args = {"x": torch.rand(4, 4, 64, 64).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -43,10 +42,12 @@ def test_output_handler(output_option): output_strategies_vector = StrategiesVector(output_node) # build handler - output_handler = OutputHandler(node=output_node, - device_mesh=device_mesh, - strategies_vector=output_strategies_vector, - output_option=output_option) + output_handler = OutputHandler( + node=output_node, + device_mesh=device_mesh, + strategies_vector=output_strategies_vector, + output_option=output_option, + ) 
output_handler.register_strategy(compute_resharding_cost=False) # check operation data mapping @@ -57,14 +58,14 @@ def test_output_handler(output_option): # make sure they have valid values assert op_data.data is not None - assert mapping['output'].name == "output" - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "output" + assert mapping["output"].type == OperationDataType.OUTPUT strategy_name_list = [val.name for val in output_handler.strategies_vector] - if output_option == 'distributed': + if output_option == "distributed": assert "Distributed Output" in strategy_name_list else: assert "Replica Output" in strategy_name_list -if __name__ == '__main__': +if __name__ == "__main__": test_output_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py index f071cd120fb7..958dc288fa16 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py @@ -1,5 +1,3 @@ -from functools import partial - import pytest import torch import torch.nn as nn @@ -20,7 +18,6 @@ class ConvReshapeModel(nn.Module): - def __init__(self, reshape_dims, call_function): super().__init__() self.reshape_dims = reshape_dims @@ -37,7 +34,6 @@ def forward(self, input, other): class LinearReshapeModel(nn.Module): - def __init__(self, reshape_dims, call_function): super().__init__() self.reshape_dims = reshape_dims @@ -55,23 +51,23 @@ def forward(self, input, other): def check_view_handler(rank, world_size, port, call_function, reshape_dims, model_cls): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") if call_function == torch.permute: reshape_dims = reshape_dims[0] elif call_function == torch.transpose: reshape_dims = reshape_dims[1] model = model_cls(reshape_dims, call_function).cuda() - if model_cls.__name__ == 'ConvReshapeModel': - input = torch.rand(8, 8, 66, 66).to('cuda') - other = torch.rand(16, 8, 3, 3).to('cuda') + if model_cls.__name__ == "ConvReshapeModel": + input = torch.rand(8, 8, 66, 66).to("cuda") + other = torch.rand(16, 8, 3, 3).to("cuda") # index of conv node in computation graph node_index = 2 # total number of conv strategies strategy_number = 16 - if model_cls.__name__ == 'LinearReshapeModel': - input = torch.rand(8, 16, 64, 32).to('cuda') - other = torch.rand(64, 32).to('cuda') + if model_cls.__name__ == "LinearReshapeModel": + input = torch.rand(8, 16, 64, 32).to("cuda") + other = torch.rand(64, 32).to("cuda") # index of linear node in computation graph node_index = 2 # total number of linear strategies @@ -81,15 +77,17 @@ def check_view_handler(rank, world_size, port, call_function, reshape_dims, mode mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input, other], - meta_arg_names=['input', 'other'], - node_type='following') + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + 
input_args=[input, other], + meta_arg_names=["input", "other"], + node_type="following", + ) tracer = ColoTracer(bias_addition_split=True) - if model_cls.__name__ == 'ConvReshapeModel': + if model_cls.__name__ == "ConvReshapeModel": # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %other : torch.Tensor [#users=1] = placeholder[target=other] @@ -97,12 +95,12 @@ def check_view_handler(rank, world_size, port, call_function, reshape_dims, mode # %permute : [#users=1] = call_function[target=torch.permute](args = (%conv2d, (0, 2, 1, 3)), kwargs = {}) # return permute meta_args = { - 'input': torch.rand(8, 8, 66, 66).to('meta'), - 'other': torch.rand(16, 8, 3, 3).to('meta'), + "input": torch.rand(8, 8, 66, 66).to("meta"), + "other": torch.rand(16, 8, 3, 3).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) - if model_cls.__name__ == 'LinearReshapeModel': + if model_cls.__name__ == "LinearReshapeModel": # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %other : torch.Tensor [#users=1] = placeholder[target=other] @@ -110,8 +108,8 @@ def check_view_handler(rank, world_size, port, call_function, reshape_dims, mode # %permute : [#users=1] = call_method[target=view](args = (%linear, 32, 4, 32, 32, 4), kwargs = {}) # return permute meta_args = { - 'input': torch.rand(8, 16, 64, 32).to('meta'), - 'other': torch.rand(64, 32).to('meta'), + "input": torch.rand(8, 16, 64, 32).to("meta"), + "other": torch.rand(64, 32).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) @@ -124,30 +122,29 @@ def check_view_handler(rank, world_size, port, call_function, reshape_dims, mode previous_strategies_vector = StrategiesVector(previous_mod_node) # build handler - if model_cls.__name__ == 'ConvReshapeModel': - - conv_handler = ConvFunctionHandler(node=previous_mod_node, - device_mesh=device_mesh, - strategies_vector=previous_strategies_vector) + if model_cls.__name__ == "ConvReshapeModel": + conv_handler = ConvFunctionHandler( + node=previous_mod_node, device_mesh=device_mesh, strategies_vector=previous_strategies_vector + ) conv_handler.register_strategy(compute_resharding_cost=False) - setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + setattr(previous_mod_node, "strategies_vector", previous_strategies_vector) - if model_cls.__name__ == 'LinearReshapeModel': + if model_cls.__name__ == "LinearReshapeModel": assert len(previous_strategies_vector) == 0 - linear_handler = LinearFunctionHandler(node=previous_mod_node, - device_mesh=device_mesh, - strategies_vector=previous_strategies_vector) + linear_handler = LinearFunctionHandler( + node=previous_mod_node, device_mesh=device_mesh, strategies_vector=previous_strategies_vector + ) linear_handler.register_strategy(compute_resharding_cost=False) - setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + setattr(previous_mod_node, "strategies_vector", previous_strategies_vector) if call_function == torch.permute: - reshape_handler = PermuteHandler(node=reshape_node, - device_mesh=device_mesh, - strategies_vector=view_strategies_vector) + reshape_handler = PermuteHandler( + node=reshape_node, device_mesh=device_mesh, strategies_vector=view_strategies_vector + ) else: - reshape_handler = TransposeHandler(node=reshape_node, - device_mesh=device_mesh, - strategies_vector=view_strategies_vector) + reshape_handler = TransposeHandler( + node=reshape_node, device_mesh=device_mesh, strategies_vector=view_strategies_vector + ) 
reshape_handler.register_strategy(compute_resharding_cost=False) @@ -159,25 +156,25 @@ def check_view_handler(rank, world_size, port, call_function, reshape_dims, mode # make sure they have valid values assert op_data.data is not None - if model_cls.__name__ == 'ConvReshapeModel': - assert mapping['input'].name == "conv2d" + if model_cls.__name__ == "ConvReshapeModel": + assert mapping["input"].name == "conv2d" else: - assert mapping['input'].name == "linear" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + assert mapping["input"].name == "linear" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([8, 16, 64, 64]) if call_function == torch.permute: - assert mapping['output'].name == "permute" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.permute(torch.rand(8, 16, 64, 64), reshape_dims).shape - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "permute" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.permute(torch.rand(8, 16, 64, 64), reshape_dims).shape + assert mapping["output"].type == OperationDataType.OUTPUT else: - assert mapping['output'].name == "transpose" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.transpose(torch.rand(8, 16, 64, 64), *reshape_dims).shape - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "transpose" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.transpose(torch.rand(8, 16, 64, 64), *reshape_dims).shape + assert mapping["output"].type == OperationDataType.OUTPUT # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. 
assert len(view_strategies_vector) == len(previous_strategies_vector) @@ -185,146 +182,144 @@ def check_view_handler(rank, world_size, port, call_function, reshape_dims, mode if rank == 0: for name in strategy_name_list: print(name) - if model_cls.__name__ == 'ConvReshapeModel': - + if model_cls.__name__ == "ConvReshapeModel": if reshape_dims in ((0, 2, 1, 3), (1, 2)): - assert '[S0, S1, R, R] -> [S0, R, S1, R]_0' in strategy_name_list - assert '[S1, S0, R, R] -> [S1, R, S0, R]_1' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R]_2' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R]_3' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R]_4' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R]_5' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, S1, R]_6' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, S0, R]_10' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, S1, R]_11' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R]_13' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list - assert '[R, S01, R, R] -> [R, R, S01, R]_15' in strategy_name_list + assert "[S0, S1, R, R] -> [S0, R, S1, R]_0" in strategy_name_list + assert "[S1, S0, R, R] -> [S1, R, S0, R]_1" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R]_2" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R]_3" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R]_4" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R]_5" in strategy_name_list + assert "[R, S1, R, R] -> [R, R, S1, R]_6" in strategy_name_list + assert "[R, S0, R, R] -> [R, R, S0, R]_7" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_9" in strategy_name_list + assert "[R, S0, R, R] -> [R, R, S0, R]_10" in strategy_name_list + assert "[R, S1, R, R] -> [R, R, S1, R]_11" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_12" in strategy_name_list + assert "[S01, R, R, R] -> [S01, R, R, R]_13" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_14" in strategy_name_list + assert "[R, S01, R, R] -> [R, R, S01, R]_15" in strategy_name_list if reshape_dims == (2, 0, 1, 3): - assert '[S0, S1, R, R] -> [R, S0, S1, R]_0' in strategy_name_list - assert '[S1, S0, R, R] -> [R, S1, S0, R]_1' in strategy_name_list - assert '[S0, R, R, R] -> [R, S0, R, R]_2' in strategy_name_list - assert '[S1, R, R, R] -> [R, S1, R, R]_3' in strategy_name_list - assert '[S0, R, R, R] -> [R, S0, R, R]_4' in strategy_name_list - assert '[S1, R, R, R] -> [R, S1, R, R]_5' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, S1, R]_6' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, S0, R]_10' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, S1, R]_11' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list - assert '[S01, R, R, R] -> [R, S01, R, R]_13' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list - assert '[R, S01, R, R] -> [R, 
R, S01, R]_15' in strategy_name_list + assert "[S0, S1, R, R] -> [R, S0, S1, R]_0" in strategy_name_list + assert "[S1, S0, R, R] -> [R, S1, S0, R]_1" in strategy_name_list + assert "[S0, R, R, R] -> [R, S0, R, R]_2" in strategy_name_list + assert "[S1, R, R, R] -> [R, S1, R, R]_3" in strategy_name_list + assert "[S0, R, R, R] -> [R, S0, R, R]_4" in strategy_name_list + assert "[S1, R, R, R] -> [R, S1, R, R]_5" in strategy_name_list + assert "[R, S1, R, R] -> [R, R, S1, R]_6" in strategy_name_list + assert "[R, S0, R, R] -> [R, R, S0, R]_7" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_9" in strategy_name_list + assert "[R, S0, R, R] -> [R, R, S0, R]_10" in strategy_name_list + assert "[R, S1, R, R] -> [R, R, S1, R]_11" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_12" in strategy_name_list + assert "[S01, R, R, R] -> [R, S01, R, R]_13" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_14" in strategy_name_list + assert "[R, S01, R, R] -> [R, R, S01, R]_15" in strategy_name_list if reshape_dims == (1, 3): - assert '[S0, S1, R, R] -> [S0, R, R, S1]_0' in strategy_name_list - assert '[S1, S0, R, R] -> [S1, R, R, S0]_1' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R]_2' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R]_3' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R]_4' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R]_5' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, R, S1]_6' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, R, S0]_7' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, R, S0]_10' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, R, S1]_11' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R]_13' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list - assert '[R, S01, R, R] -> [R, R, R, S01]_15' in strategy_name_list - - if model_cls.__name__ == 'LinearReshapeModel': - + assert "[S0, S1, R, R] -> [S0, R, R, S1]_0" in strategy_name_list + assert "[S1, S0, R, R] -> [S1, R, R, S0]_1" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R]_2" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R]_3" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R]_4" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R]_5" in strategy_name_list + assert "[R, S1, R, R] -> [R, R, R, S1]_6" in strategy_name_list + assert "[R, S0, R, R] -> [R, R, R, S0]_7" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_9" in strategy_name_list + assert "[R, S0, R, R] -> [R, R, R, S0]_10" in strategy_name_list + assert "[R, S1, R, R] -> [R, R, R, S1]_11" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_12" in strategy_name_list + assert "[S01, R, R, R] -> [S01, R, R, R]_13" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_14" in strategy_name_list + assert "[R, S01, R, R] -> [R, R, R, S01]_15" in strategy_name_list + + if model_cls.__name__ == "LinearReshapeModel": if reshape_dims == ((0, 2, 1, 3), (1, 2)): - assert '[S0, R, R, S1] -> [S0, R, R, S1]_11' in strategy_name_list - assert '[R, S0, R, S1] -> [R, R, S0, S1]_12' in strategy_name_list - assert '[R, R, 
S0, S1] -> [R, S0, R, S1]_13' in strategy_name_list - assert '[S1, R, R, S0] -> [S1, R, R, S0]_14' in strategy_name_list - assert '[R, S1, R, S0] -> [R, R, S1, S0]_15' in strategy_name_list - assert '[R, R, S1, S0] -> [R, S1, R, S0]_16' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, S0, R]_18' in strategy_name_list - assert '[R, R, S0, R] -> [R, S0, R, R]_19' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, S1, R]_21' in strategy_name_list - assert '[R, R, S1, R] -> [R, S1, R, R]_22' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list - assert '[R, S01, R, R] -> [R, R, S01, R]_1' in strategy_name_list - assert '[R, R, S01, R] -> [R, S01, R, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list + assert "[S0, R, R, S1] -> [S0, R, R, S1]_11" in strategy_name_list + assert "[R, S0, R, S1] -> [R, R, S0, S1]_12" in strategy_name_list + assert "[R, R, S0, S1] -> [R, S0, R, S1]_13" in strategy_name_list + assert "[S1, R, R, S0] -> [S1, R, R, S0]_14" in strategy_name_list + assert "[R, S1, R, S0] -> [R, R, S1, S0]_15" in strategy_name_list + assert "[R, R, S1, S0] -> [R, S1, R, S0]_16" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R]_17" in strategy_name_list + assert "[R, S0, R, R] -> [R, R, S0, R]_18" in strategy_name_list + assert "[R, R, S0, R] -> [R, S0, R, R]_19" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R]_20" in strategy_name_list + assert "[R, S1, R, R] -> [R, R, S1, R]_21" in strategy_name_list + assert "[R, R, S1, R] -> [R, S1, R, R]_22" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_10" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_9" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_6" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_5" in strategy_name_list + assert "[S01, R, R, R] -> [S01, R, R, R]_0" in strategy_name_list + assert "[R, S01, R, R] -> [R, R, S01, R]_1" in strategy_name_list + assert "[R, R, S01, R] -> [R, S01, R, R]_2" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, R, R, S01]_4" in strategy_name_list if reshape_dims == (2, 0, 1, 3): - assert '[S0, R, R, S1] -> [R, S0, R, S1]_11' in strategy_name_list - assert '[R, S0, R, S1] -> [R, R, S0, S1]_12' in strategy_name_list - assert '[R, R, S0, S1] -> [S0, R, R, S1]_13' in strategy_name_list - assert '[S1, R, R, S0] -> [R, S1, R, S0]_14' in strategy_name_list - assert '[R, S1, R, S0] -> [R, R, S1, S0]_15' in strategy_name_list - assert '[R, R, S1, S0] -> [S1, R, R, S0]_16' in strategy_name_list - assert '[S0, R, R, R] -> [R, S0, R, R]_17' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, S0, R]_18' in strategy_name_list - assert '[R, R, S0, R] -> [S0, R, R, 
R]_19' in strategy_name_list - assert '[S1, R, R, R] -> [R, S1, R, R]_20' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, S1, R]_21' in strategy_name_list - assert '[R, R, S1, R] -> [S1, R, R, R]_22' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list - assert '[S01, R, R, R] -> [R, S01, R, R]_0' in strategy_name_list - assert '[R, S01, R, R] -> [R, R, S01, R]_1' in strategy_name_list - assert '[R, R, S01, R] -> [S01, R, R, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list + assert "[S0, R, R, S1] -> [R, S0, R, S1]_11" in strategy_name_list + assert "[R, S0, R, S1] -> [R, R, S0, S1]_12" in strategy_name_list + assert "[R, R, S0, S1] -> [S0, R, R, S1]_13" in strategy_name_list + assert "[S1, R, R, S0] -> [R, S1, R, S0]_14" in strategy_name_list + assert "[R, S1, R, S0] -> [R, R, S1, S0]_15" in strategy_name_list + assert "[R, R, S1, S0] -> [S1, R, R, S0]_16" in strategy_name_list + assert "[S0, R, R, R] -> [R, S0, R, R]_17" in strategy_name_list + assert "[R, S0, R, R] -> [R, R, S0, R]_18" in strategy_name_list + assert "[R, R, S0, R] -> [S0, R, R, R]_19" in strategy_name_list + assert "[S1, R, R, R] -> [R, S1, R, R]_20" in strategy_name_list + assert "[R, S1, R, R] -> [R, R, S1, R]_21" in strategy_name_list + assert "[R, R, S1, R] -> [S1, R, R, R]_22" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_10" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_9" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_6" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_5" in strategy_name_list + assert "[S01, R, R, R] -> [R, S01, R, R]_0" in strategy_name_list + assert "[R, S01, R, R] -> [R, R, S01, R]_1" in strategy_name_list + assert "[R, R, S01, R] -> [S01, R, R, R]_2" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, R, R, S01]_4" in strategy_name_list if reshape_dims == (1, 3): - assert '[S0, R, R, S1] -> [S0, S1, R, R]_11' in strategy_name_list - assert '[R, S0, R, S1] -> [R, S1, R, S0]_12' in strategy_name_list - assert '[R, R, S0, S1] -> [R, S1, S0, R]_13' in strategy_name_list - assert '[S1, R, R, S0] -> [S1, S0, R, R]_14' in strategy_name_list - assert '[R, S1, R, S0] -> [R, S0, R, S1]_15' in strategy_name_list - assert '[R, R, S1, S0] -> [R, S0, S1, R]_16' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, R, S0]_18' in strategy_name_list - assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, R, S1]_21' in strategy_name_list - assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list - assert '[R, R, R, S1] -> [R, S1, R, R]_10' in strategy_name_list - assert '[R, R, R, S0] -> [R, S0, R, R]_9' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list - assert 
'[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, S0] -> [R, S0, R, R]_6' in strategy_name_list - assert '[R, R, R, S1] -> [R, S1, R, R]_5' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list - assert '[R, S01, R, R] -> [R, R, R, S01]_1' in strategy_name_list - assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, S01, R, R]_4' in strategy_name_list - - -@run_on_environment_flag(name='AUTO_PARALLEL') + assert "[S0, R, R, S1] -> [S0, S1, R, R]_11" in strategy_name_list + assert "[R, S0, R, S1] -> [R, S1, R, S0]_12" in strategy_name_list + assert "[R, R, S0, S1] -> [R, S1, S0, R]_13" in strategy_name_list + assert "[S1, R, R, S0] -> [S1, S0, R, R]_14" in strategy_name_list + assert "[R, S1, R, S0] -> [R, S0, R, S1]_15" in strategy_name_list + assert "[R, R, S1, S0] -> [R, S0, S1, R]_16" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R]_17" in strategy_name_list + assert "[R, S0, R, R] -> [R, R, R, S0]_18" in strategy_name_list + assert "[R, R, S0, R] -> [R, R, S0, R]_19" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R]_20" in strategy_name_list + assert "[R, S1, R, R] -> [R, R, R, S1]_21" in strategy_name_list + assert "[R, R, S1, R] -> [R, R, S1, R]_22" in strategy_name_list + assert "[R, R, R, S1] -> [R, S1, R, R]_10" in strategy_name_list + assert "[R, R, R, S0] -> [R, S0, R, R]_9" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, S0] -> [R, S0, R, R]_6" in strategy_name_list + assert "[R, R, R, S1] -> [R, S1, R, R]_5" in strategy_name_list + assert "[S01, R, R, R] -> [S01, R, R, R]_0" in strategy_name_list + assert "[R, S01, R, R] -> [R, R, R, S01]_1" in strategy_name_list + assert "[R, R, S01, R] -> [R, R, S01, R]_2" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, S01, R, R]_4" in strategy_name_list + + +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() -@parameterize('call_function', [torch.permute, torch.transpose]) -@parameterize('reshape_dims', [((0, 2, 1, 3), (1, 2)), ((2, 0, 1, 3), (1, 3))]) -@parameterize('model_cls', [ConvReshapeModel, LinearReshapeModel]) +@parameterize("call_function", [torch.permute, torch.transpose]) +@parameterize("reshape_dims", [((0, 2, 1, 3), (1, 2)), ((2, 0, 1, 3), (1, 3))]) +@parameterize("model_cls", [ConvReshapeModel, LinearReshapeModel]) def test_view_handler(call_function, reshape_dims, model_cls): spawn( check_view_handler, @@ -335,5 +330,5 @@ def test_view_handler(call_function, reshape_dims, model_cls): ) -if __name__ == '__main__': +if __name__ == "__main__": test_view_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py index 6d02b0e0ba74..60c090429c6c 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py @@ -12,7 +12,6 @@ class PlaceholderModel(nn.Module): - def __init__(self): super().__init__() @@ -20,8 +19,8 @@ def forward(self, input): return input -@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') 
-@parameterize('placeholder_option', ['distributed', 'replicated']) +@pytest.mark.skip("ShapeProp is not compatible with PyTorch 1.11.0") +@parameterize("placeholder_option", ["distributed", "replicated"]) @clear_cache_before_run() def test_placeholder_handler(placeholder_option): model = PlaceholderModel() @@ -30,7 +29,7 @@ def test_placeholder_handler(placeholder_option): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # return input_1 meta_args = { - "input": torch.rand(4, 4, 64, 64).to('meta'), + "input": torch.rand(4, 4, 64, 64).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) @@ -42,10 +41,12 @@ def test_placeholder_handler(placeholder_option): placeholder_node = list(graph.nodes)[0] placeholder_strategies_vector = StrategiesVector(placeholder_node) # build handler - placeholder_handler = PlaceholderHandler(node=placeholder_node, - device_mesh=device_mesh, - strategies_vector=placeholder_strategies_vector, - placeholder_option=placeholder_option) + placeholder_handler = PlaceholderHandler( + node=placeholder_node, + device_mesh=device_mesh, + strategies_vector=placeholder_strategies_vector, + placeholder_option=placeholder_option, + ) placeholder_handler.register_strategy(compute_resharding_cost=False) @@ -53,28 +54,28 @@ def test_placeholder_handler(placeholder_option): mapping = placeholder_handler.get_operation_data_mapping() strategy = placeholder_strategies_vector[0] - strategy_sharding_spec = strategy.get_sharding_spec_by_name(mapping['output'].name) + strategy_sharding_spec = strategy.get_sharding_spec_by_name(mapping["output"].name) - if placeholder_option == 'distributed': - assert str(strategy_sharding_spec.sharding_sequence) == '[S01, R, R, R]' + if placeholder_option == "distributed": + assert str(strategy_sharding_spec.sharding_sequence) == "[S01, R, R, R]" else: - assert str(strategy_sharding_spec.sharding_sequence) == '[R, R, R, R]' + assert str(strategy_sharding_spec.sharding_sequence) == "[R, R, R, R]" for name, op_data in mapping.items(): op_data: OperationData # make sure they have valid values assert op_data.data is not None - assert mapping['output'].name == "input_1" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size((4, 4, 64, 64)) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "input_1" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size((4, 4, 64, 64)) + assert mapping["output"].type == OperationDataType.OUTPUT strategy_name_list = [val.name for val in placeholder_handler.strategies_vector] - if placeholder_option == 'replicated': + if placeholder_option == "replicated": assert "Replica Placeholder" in strategy_name_list else: assert "Distributed Placeholder" in strategy_name_list -if __name__ == '__main__': +if __name__ == "__main__": test_placeholder_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py index 14c364c45fc4..6836a882242f 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py @@ -12,7 +12,6 @@ class LinearModel(nn.Module): - def __init__(self): super().__init__() @@ -28,7 +27,7 @@ def check_shard_option(shard_option): device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) tracer = 
ColoTracer(bias_addition_split=True) - meta_args = {'input': torch.rand(4, 4, 4, 16).to('meta'), 'others': torch.rand(32, 16).to('meta')} + meta_args = {"input": torch.rand(4, 4, 4, 16).to("meta"), "others": torch.rand(32, 16).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -36,77 +35,76 @@ def check_shard_option(shard_option): strategies_vector = StrategiesVector(linear_func_node) # build handler - handler = LinearFunctionHandler(node=linear_func_node, - device_mesh=device_mesh, - strategies_vector=strategies_vector, - shard_option=shard_option) + handler = LinearFunctionHandler( + node=linear_func_node, device_mesh=device_mesh, strategies_vector=strategies_vector, shard_option=shard_option + ) strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] if shard_option == ShardOption.SHARD_LAST_AXIS: # RR = RS x SR - assert 'RR = RS1 x S1R' in strategy_name_list + assert "RR = RS1 x S1R" in strategy_name_list # RS= RR x RS - assert 'RS1 = RR x RS1' in strategy_name_list + assert "RS1 = RR x RS1" in strategy_name_list return # SS = SR x RS - assert 'S1S0 = S1R x RS0_0' in strategy_name_list - assert 'S0S1 = S0R x RS1_1' in strategy_name_list - assert 'S0S1 = S0R x RS1_2' in strategy_name_list - assert 'S0S1 = S0R x RS1_0' in strategy_name_list - assert 'S1S0 = S1R x RS0_1' in strategy_name_list - assert 'S1S0 = S1R x RS0_2' in strategy_name_list + assert "S1S0 = S1R x RS0_0" in strategy_name_list + assert "S0S1 = S0R x RS1_1" in strategy_name_list + assert "S0S1 = S0R x RS1_2" in strategy_name_list + assert "S0S1 = S0R x RS1_0" in strategy_name_list + assert "S1S0 = S1R x RS0_1" in strategy_name_list + assert "S1S0 = S1R x RS0_2" in strategy_name_list # SR = SS x SR - assert 'S0R = S0S1 x S1R_1' in strategy_name_list - assert 'S0R = S0S1 x S1R_2' in strategy_name_list - assert 'S1R = S1S0 x S0R_0' in strategy_name_list - assert 'S0R = S0S1 x S1R_0' in strategy_name_list - assert 'S1R = S1S0 x S0R_1' in strategy_name_list - assert 'S1R = S1S0 x S0R_2' in strategy_name_list + assert "S0R = S0S1 x S1R_1" in strategy_name_list + assert "S0R = S0S1 x S1R_2" in strategy_name_list + assert "S1R = S1S0 x S0R_0" in strategy_name_list + assert "S0R = S0S1 x S1R_0" in strategy_name_list + assert "S1R = S1S0 x S0R_1" in strategy_name_list + assert "S1R = S1S0 x S0R_2" in strategy_name_list # RS = RS x SS - assert 'RS0 = RS1 x S1S0' in strategy_name_list - assert 'RS1 = RS0 x S0S1' in strategy_name_list + assert "RS0 = RS1 x S1S0" in strategy_name_list + assert "RS1 = RS0 x S0S1" in strategy_name_list # S01R = S01R x RR - assert 'S01R = S01R x RR_0' in strategy_name_list - assert 'S01R = S01R x RR_1' in strategy_name_list - assert 'S01R = S01R x RR_2' in strategy_name_list + assert "S01R = S01R x RR_0" in strategy_name_list + assert "S01R = S01R x RR_1" in strategy_name_list + assert "S01R = S01R x RR_2" in strategy_name_list # RR = RS01 x S01R - assert 'RR = RS01 x S01R' in strategy_name_list + assert "RR = RS01 x S01R" in strategy_name_list # RS01 = RR x RS01 - assert 'RS01 = RR x RS01' in strategy_name_list + assert "RS01 = RR x RS01" in strategy_name_list if shard_option == ShardOption.SHARD: # RR = RS x SR - assert 'RR = RS0 x S0R' in strategy_name_list - assert 'RR = RS1 x S1R' in strategy_name_list + assert "RR = RS0 x S0R" in strategy_name_list + assert "RR = RS1 x S1R" in strategy_name_list # RS= RR x RS - assert 'RS0 = RR x 
RS0' in strategy_name_list - assert 'RS1 = RR x RS1' in strategy_name_list + assert "RS0 = RR x RS0" in strategy_name_list + assert "RS1 = RR x RS1" in strategy_name_list if shard_option == ShardOption.STANDARD: # RR = RS x SR - assert 'RR = RS0 x S0R' in strategy_name_list - assert 'RR = RS1 x S1R' in strategy_name_list + assert "RR = RS0 x S0R" in strategy_name_list + assert "RR = RS1 x S1R" in strategy_name_list # RS= RR x RS - assert 'RS0 = RR x RS0' in strategy_name_list - assert 'RS1 = RR x RS1' in strategy_name_list + assert "RS0 = RR x RS0" in strategy_name_list + assert "RS1 = RR x RS1" in strategy_name_list # RR = RR x RR - assert 'RR = RR x RR' in strategy_name_list + assert "RR = RR x RR" in strategy_name_list -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @clear_cache_before_run() def test_shard_option(): # for shard_option in [ShardOption.STANDARD, ShardOption.SHARD, ShardOption.FULL_SHARD, ShardOption.SHARD_LAST_AXIS]: @@ -114,5 +112,5 @@ def test_shard_option(): check_shard_option(shard_option) -if __name__ == '__main__': +if __name__ == "__main__": test_shard_option() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py index 75ae0416ef98..1a99c32ebcb9 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py @@ -17,7 +17,6 @@ class LinearSplitModel(nn.Module): - def __init__(self, softmax_dim): super().__init__() self.softmax_dim = softmax_dim @@ -30,11 +29,11 @@ def forward(self, input, other): def check_split_handler(rank, world_size, port, softmax_dim, model_cls): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = model_cls(softmax_dim=softmax_dim).cuda() - input = torch.rand(8, 16, 64, 32).to('cuda') - other = torch.rand(64, 32).to('cuda') + input = torch.rand(8, 16, 64, 32).to("cuda") + other = torch.rand(64, 32).to("cuda") # index of linear node in computation graph node_index = 2 # total number of linear strategies @@ -44,13 +43,15 @@ def check_split_handler(rank, world_size, port, softmax_dim, model_cls): mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input, other], - meta_arg_names=['input', 'other'], - node_type='following') + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=["input", "other"], + node_type="following", + ) tracer = ColoTracer(bias_addition_split=True) # graph(): @@ -60,8 +61,8 @@ def check_split_handler(rank, world_size, port, softmax_dim, model_cls): # %softmax : [#users=1] = call_method[target=split](args = (%linear,), kwargs = {}) # return split meta_args = { - 'input': torch.rand(8, 16, 64, 32).to('meta'), - 'other': torch.rand(64, 32).to('meta'), + "input": torch.rand(8, 16, 64, 32).to("meta"), + "other": torch.rand(64, 32).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) @@ -75,15 +76,15 @@ def 
check_split_handler(rank, world_size, port, softmax_dim, model_cls): # build handler assert len(previous_strategies_vector) == 0 - linear_handler = LinearFunctionHandler(node=previous_mod_node, - device_mesh=device_mesh, - strategies_vector=previous_strategies_vector) + linear_handler = LinearFunctionHandler( + node=previous_mod_node, device_mesh=device_mesh, strategies_vector=previous_strategies_vector + ) linear_handler.register_strategy(compute_resharding_cost=False) - setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + setattr(previous_mod_node, "strategies_vector", previous_strategies_vector) - softmax_handler = SoftmaxHandler(node=split_node, - device_mesh=device_mesh, - strategies_vector=split_strategies_vector) + softmax_handler = SoftmaxHandler( + node=split_node, device_mesh=device_mesh, strategies_vector=split_strategies_vector + ) softmax_handler.register_strategy(compute_resharding_cost=False) @@ -95,84 +96,84 @@ def check_split_handler(rank, world_size, port, softmax_dim, model_cls): # make sure they have valid values assert op_data.data is not None - assert mapping['input'].name == "linear" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + assert mapping["input"].name == "linear" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([8, 16, 64, 64]) - assert mapping['softmax_dim'].name == "softmax_dim" - assert mapping['softmax_dim'].data == softmax_dim - assert mapping['softmax_dim'].type == OperationDataType.ARG + assert mapping["softmax_dim"].name == "softmax_dim" + assert mapping["softmax_dim"].data == softmax_dim + assert mapping["softmax_dim"].type == OperationDataType.ARG - assert mapping['output'].name == "softmax" - assert mapping['output'].data.shape == torch.Size([8, 16, 64, 64]) - assert mapping['output'].logical_shape == torch.Size([8, 16, 64, 64]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "softmax" + assert mapping["output"].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping["output"].logical_shape == torch.Size([8, 16, 64, 64]) + assert mapping["output"].type == OperationDataType.OUTPUT # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. 
assert len(split_strategies_vector) == len(previous_strategies_vector) strategy_name_list = [strategy.name for strategy in split_strategies_vector] if softmax_dim == 0: - assert '[R, R, R, S1] -> [R, R, R, S1]_11' in strategy_name_list - assert '[R, S0, R, S1] -> [R, S0, R, S1]_12' in strategy_name_list - assert '[R, R, S0, S1] -> [R, R, S0, S1]_13' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_14' in strategy_name_list - assert '[R, S1, R, S0] -> [R, S1, R, S0]_15' in strategy_name_list - assert '[R, R, S1, S0] -> [R, R, S1, S0]_16' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_17' in strategy_name_list - assert '[R, S0, R, R] -> [R, S0, R, R]_18' in strategy_name_list - assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_20' in strategy_name_list - assert '[R, S1, R, R] -> [R, S1, R, R]_21' in strategy_name_list - assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_0' in strategy_name_list - assert '[R, S01, R, R] -> [R, S01, R, R]_1' in strategy_name_list - assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_11" in strategy_name_list + assert "[R, S0, R, S1] -> [R, S0, R, S1]_12" in strategy_name_list + assert "[R, R, S0, S1] -> [R, R, S0, S1]_13" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_14" in strategy_name_list + assert "[R, S1, R, S0] -> [R, S1, R, S0]_15" in strategy_name_list + assert "[R, R, S1, S0] -> [R, R, S1, S0]_16" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_17" in strategy_name_list + assert "[R, S0, R, R] -> [R, S0, R, R]_18" in strategy_name_list + assert "[R, R, S0, R] -> [R, R, S0, R]_19" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_20" in strategy_name_list + assert "[R, S1, R, R] -> [R, S1, R, R]_21" in strategy_name_list + assert "[R, R, S1, R] -> [R, R, S1, R]_22" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_10" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_9" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_6" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_5" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_0" in strategy_name_list + assert "[R, S01, R, R] -> [R, S01, R, R]_1" in strategy_name_list + assert "[R, R, S01, R] -> [R, R, S01, R]_2" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, R, R, S01]_4" in strategy_name_list if softmax_dim == 1: - assert '[S0, R, R, S1] -> [S0, R, R, S1]_11' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list - assert '[R, R, S0, S1] -> [R, R, S0, S1]_13' in strategy_name_list - assert '[S1, R, R, S0] -> [S1, R, R, S0]_14' in strategy_name_list - assert '[R, R, R, S0] 
-> [R, R, R, S0]_15' in strategy_name_list - assert '[R, R, S1, S0] -> [R, R, S1, S0]_16' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list - assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list - assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_1' in strategy_name_list - assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list - - -@run_on_environment_flag(name='AUTO_PARALLEL') + assert "[S0, R, R, S1] -> [S0, R, R, S1]_11" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_12" in strategy_name_list + assert "[R, R, S0, S1] -> [R, R, S0, S1]_13" in strategy_name_list + assert "[S1, R, R, S0] -> [S1, R, R, S0]_14" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_15" in strategy_name_list + assert "[R, R, S1, S0] -> [R, R, S1, S0]_16" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R]_17" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_18" in strategy_name_list + assert "[R, R, S0, R] -> [R, R, S0, R]_19" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R]_20" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_21" in strategy_name_list + assert "[R, R, S1, R] -> [R, R, S1, R]_22" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_10" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_9" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_6" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_5" in strategy_name_list + assert "[S01, R, R, R] -> [S01, R, R, R]_0" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_1" in strategy_name_list + assert "[R, R, S01, R] -> [R, R, S01, R]_2" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, R, R, S01]_4" in strategy_name_list + + +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() -@parameterize('softmax_dim', [0, 1, 2, 3]) -@parameterize('model_cls', [LinearSplitModel]) +@parameterize("softmax_dim", [0, 1, 2, 3]) +@parameterize("model_cls", [LinearSplitModel]) def test_split_handler(softmax_dim, model_cls): spawn(check_split_handler, 4, softmax_dim=softmax_dim, model_cls=model_cls) -if __name__ == '__main__': +if __name__ == "__main__": test_split_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py index f860c629b0a0..0318023c858d 100644 --- 
a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py @@ -17,7 +17,6 @@ class ConvSplitModel(nn.Module): - def __init__(self, split_size, split_dim): super().__init__() self.split_size = split_size @@ -30,7 +29,6 @@ def forward(self, input, other): class LinearSplitModel(nn.Module): - def __init__(self, split_size, split_dim): super().__init__() self.split_size = split_size @@ -44,19 +42,19 @@ def forward(self, input, other): def check_split_handler(rank, world_size, port, split_size, split_dim, model_cls): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = model_cls(split_size=split_size, split_dim=split_dim).cuda() - if model_cls.__name__ == 'ConvSplitModel': - input = torch.rand(8, 8, 66, 66).to('cuda') - other = torch.rand(16, 8, 3, 3).to('cuda') + if model_cls.__name__ == "ConvSplitModel": + input = torch.rand(8, 8, 66, 66).to("cuda") + other = torch.rand(16, 8, 3, 3).to("cuda") # index of conv node in computation graph node_index = 2 # total number of conv strategies strategy_number = 16 - if model_cls.__name__ == 'LinearSplitModel': - input = torch.rand(8, 16, 64, 32).to('cuda') - other = torch.rand(64, 32).to('cuda') + if model_cls.__name__ == "LinearSplitModel": + input = torch.rand(8, 16, 64, 32).to("cuda") + other = torch.rand(64, 32).to("cuda") # index of linear node in computation graph node_index = 2 # total number of linear strategies @@ -66,15 +64,17 @@ def check_split_handler(rank, world_size, port, split_size, split_dim, model_cls mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input, other], - meta_arg_names=['input', 'other'], - node_type='following') + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=["input", "other"], + node_type="following", + ) tracer = ColoTracer(bias_addition_split=True) - if model_cls.__name__ == 'ConvSplitModel': + if model_cls.__name__ == "ConvSplitModel": # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %other : torch.Tensor [#users=1] = placeholder[target=other] @@ -82,12 +82,12 @@ def check_split_handler(rank, world_size, port, split_size, split_dim, model_cls # %split : [#users=1] = call_method[target=split](args = (%conv2d,), kwargs = {}) # return split meta_args = { - 'input': torch.rand(8, 8, 66, 66).to('meta'), - 'other': torch.rand(16, 8, 3, 3).to('meta'), + "input": torch.rand(8, 8, 66, 66).to("meta"), + "other": torch.rand(16, 8, 3, 3).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) - if model_cls.__name__ == 'LinearSplitModel': + if model_cls.__name__ == "LinearSplitModel": # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %other : torch.Tensor [#users=1] = placeholder[target=other] @@ -95,8 +95,8 @@ def check_split_handler(rank, world_size, port, split_size, split_dim, model_cls # %split : [#users=1] = call_method[target=split](args = (%linear,), kwargs = {}) # return split meta_args = { - 'input': torch.rand(8, 16, 64, 32).to('meta'), 
- 'other': torch.rand(64, 32).to('meta'), + "input": torch.rand(8, 16, 64, 32).to("meta"), + "other": torch.rand(64, 32).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) @@ -109,21 +109,20 @@ def check_split_handler(rank, world_size, port, split_size, split_dim, model_cls previous_strategies_vector = StrategiesVector(previous_mod_node) # build handler - if model_cls.__name__ == 'ConvSplitModel': - - conv_handler = ConvFunctionHandler(node=previous_mod_node, - device_mesh=device_mesh, - strategies_vector=previous_strategies_vector) + if model_cls.__name__ == "ConvSplitModel": + conv_handler = ConvFunctionHandler( + node=previous_mod_node, device_mesh=device_mesh, strategies_vector=previous_strategies_vector + ) conv_handler.register_strategy(compute_resharding_cost=False) - setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + setattr(previous_mod_node, "strategies_vector", previous_strategies_vector) - if model_cls.__name__ == 'LinearSplitModel': + if model_cls.__name__ == "LinearSplitModel": assert len(previous_strategies_vector) == 0 - linear_handler = LinearFunctionHandler(node=previous_mod_node, - device_mesh=device_mesh, - strategies_vector=previous_strategies_vector) + linear_handler = LinearFunctionHandler( + node=previous_mod_node, device_mesh=device_mesh, strategies_vector=previous_strategies_vector + ) linear_handler.register_strategy(compute_resharding_cost=False) - setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + setattr(previous_mod_node, "strategies_vector", previous_strategies_vector) split_handler = SplitHandler(node=split_node, device_mesh=device_mesh, strategies_vector=split_strategies_vector) @@ -137,124 +136,122 @@ def check_split_handler(rank, world_size, port, split_size, split_dim, model_cls # make sure they have valid values assert op_data.data is not None - if model_cls.__name__ == 'ConvSplitModel': - assert mapping['input'].name == "conv2d" + if model_cls.__name__ == "ConvSplitModel": + assert mapping["input"].name == "conv2d" else: - assert mapping['input'].name == "linear" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + assert mapping["input"].name == "linear" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([8, 16, 64, 64]) - assert mapping['output'].name == "split" + assert mapping["output"].name == "split" split_items = torch.empty([8, 16, 64, 64]).split(split_size, split_dim) - assert mapping['output'].logical_shape == tuple([item.shape for item in split_items]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].logical_shape == tuple([item.shape for item in split_items]) + assert mapping["output"].type == OperationDataType.OUTPUT # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. 
assert len(split_strategies_vector) == len(previous_strategies_vector) strategy_name_list = [strategy.name for strategy in split_strategies_vector] - if model_cls.__name__ == 'ConvSplitModel': - + if model_cls.__name__ == "ConvSplitModel": if split_dim == 0: - assert '[R, S1, R, R]_0' in strategy_name_list - assert '[R, S0, R, R]_1' in strategy_name_list - assert '[R, R, R, R]_2' in strategy_name_list - assert '[R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, R]_4' in strategy_name_list - assert '[R, R, R, R]_5' in strategy_name_list - assert '[R, S1, R, R]_6' in strategy_name_list - assert '[R, S0, R, R]_7' in strategy_name_list - assert '[R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R]_9' in strategy_name_list - assert '[R, S0, R, R]_10' in strategy_name_list - assert '[R, S1, R, R]_11' in strategy_name_list - assert '[R, R, R, R]_12' in strategy_name_list - assert '[R, R, R, R]_13' in strategy_name_list - assert '[R, R, R, R]_14' in strategy_name_list - assert '[R, S01, R, R]_15' in strategy_name_list + assert "[R, S1, R, R]_0" in strategy_name_list + assert "[R, S0, R, R]_1" in strategy_name_list + assert "[R, R, R, R]_2" in strategy_name_list + assert "[R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, R]_4" in strategy_name_list + assert "[R, R, R, R]_5" in strategy_name_list + assert "[R, S1, R, R]_6" in strategy_name_list + assert "[R, S0, R, R]_7" in strategy_name_list + assert "[R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R]_9" in strategy_name_list + assert "[R, S0, R, R]_10" in strategy_name_list + assert "[R, S1, R, R]_11" in strategy_name_list + assert "[R, R, R, R]_12" in strategy_name_list + assert "[R, R, R, R]_13" in strategy_name_list + assert "[R, R, R, R]_14" in strategy_name_list + assert "[R, S01, R, R]_15" in strategy_name_list if split_dim == 1: - assert '[S0, R, R, R]_0' in strategy_name_list - assert '[S1, R, R, R]_1' in strategy_name_list - assert '[S0, R, R, R]_2' in strategy_name_list - assert '[S1, R, R, R]_3' in strategy_name_list - assert '[S0, R, R, R]_4' in strategy_name_list - assert '[S1, R, R, R]_5' in strategy_name_list - assert '[R, R, R, R]_6' in strategy_name_list - assert '[R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R]_9' in strategy_name_list - assert '[R, R, R, R]_10' in strategy_name_list - assert '[R, R, R, R]_11' in strategy_name_list - assert '[R, R, R, R]_12' in strategy_name_list - assert '[S01, R, R, R]_13' in strategy_name_list - assert '[R, R, R, R]_14' in strategy_name_list - assert '[R, R, R, R]_15' in strategy_name_list - - if model_cls.__name__ == 'LinearSplitModel': - + assert "[S0, R, R, R]_0" in strategy_name_list + assert "[S1, R, R, R]_1" in strategy_name_list + assert "[S0, R, R, R]_2" in strategy_name_list + assert "[S1, R, R, R]_3" in strategy_name_list + assert "[S0, R, R, R]_4" in strategy_name_list + assert "[S1, R, R, R]_5" in strategy_name_list + assert "[R, R, R, R]_6" in strategy_name_list + assert "[R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R]_9" in strategy_name_list + assert "[R, R, R, R]_10" in strategy_name_list + assert "[R, R, R, R]_11" in strategy_name_list + assert "[R, R, R, R]_12" in strategy_name_list + assert "[S01, R, R, R]_13" in strategy_name_list + assert "[R, R, R, R]_14" in strategy_name_list + assert "[R, R, R, R]_15" in strategy_name_list + + if model_cls.__name__ == "LinearSplitModel": if split_dim == 0: - assert '[R, R, R, S1]_11' 
in strategy_name_list - assert '[R, S0, R, S1]_12' in strategy_name_list - assert '[R, R, S0, S1]_13' in strategy_name_list - assert '[R, R, R, S0]_14' in strategy_name_list - assert '[R, S1, R, S0]_15' in strategy_name_list - assert '[R, R, S1, S0]_16' in strategy_name_list - assert '[R, R, R, R]_17' in strategy_name_list - assert '[R, S0, R, R]_18' in strategy_name_list - assert '[R, R, S0, R]_19' in strategy_name_list - assert '[R, R, R, R]_20' in strategy_name_list - assert '[R, S1, R, R]_21' in strategy_name_list - assert '[R, R, S1, R]_22' in strategy_name_list - assert '[R, R, R, S1]_10' in strategy_name_list - assert '[R, R, R, S0]_9' in strategy_name_list - assert '[R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, S0]_6' in strategy_name_list - assert '[R, R, R, S1]_5' in strategy_name_list - assert '[R, R, R, R]_0' in strategy_name_list - assert '[R, S01, R, R]_1' in strategy_name_list - assert '[R, R, S01, R]_2' in strategy_name_list - assert '[R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01]_4' in strategy_name_list + assert "[R, R, R, S1]_11" in strategy_name_list + assert "[R, S0, R, S1]_12" in strategy_name_list + assert "[R, R, S0, S1]_13" in strategy_name_list + assert "[R, R, R, S0]_14" in strategy_name_list + assert "[R, S1, R, S0]_15" in strategy_name_list + assert "[R, R, S1, S0]_16" in strategy_name_list + assert "[R, R, R, R]_17" in strategy_name_list + assert "[R, S0, R, R]_18" in strategy_name_list + assert "[R, R, S0, R]_19" in strategy_name_list + assert "[R, R, R, R]_20" in strategy_name_list + assert "[R, S1, R, R]_21" in strategy_name_list + assert "[R, R, S1, R]_22" in strategy_name_list + assert "[R, R, R, S1]_10" in strategy_name_list + assert "[R, R, R, S0]_9" in strategy_name_list + assert "[R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, S0]_6" in strategy_name_list + assert "[R, R, R, S1]_5" in strategy_name_list + assert "[R, R, R, R]_0" in strategy_name_list + assert "[R, S01, R, R]_1" in strategy_name_list + assert "[R, R, S01, R]_2" in strategy_name_list + assert "[R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01]_4" in strategy_name_list if split_dim == 1: - assert '[S0, R, R, S1]_11' in strategy_name_list - assert '[R, R, R, S1]_12' in strategy_name_list - assert '[R, R, S0, S1]_13' in strategy_name_list - assert '[S1, R, R, S0]_14' in strategy_name_list - assert '[R, R, R, S0]_15' in strategy_name_list - assert '[R, R, S1, S0]_16' in strategy_name_list - assert '[S0, R, R, R]_17' in strategy_name_list - assert '[R, R, R, R]_18' in strategy_name_list - assert '[R, R, S0, R]_19' in strategy_name_list - assert '[S1, R, R, R]_20' in strategy_name_list - assert '[R, R, R, R]_21' in strategy_name_list - assert '[R, R, S1, R]_22' in strategy_name_list - assert '[R, R, R, S1]_10' in strategy_name_list - assert '[R, R, R, S0]_9' in strategy_name_list - assert '[R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, S0]_6' in strategy_name_list - assert '[R, R, R, S1]_5' in strategy_name_list - assert '[S01, R, R, R]_0' in strategy_name_list - assert '[R, R, R, R]_1' in strategy_name_list - assert '[R, R, S01, R]_2' in strategy_name_list - assert '[R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01]_4' in strategy_name_list - - -@run_on_environment_flag(name='AUTO_PARALLEL') + assert "[S0, R, R, S1]_11" in strategy_name_list + assert "[R, R, R, S1]_12" in 
strategy_name_list + assert "[R, R, S0, S1]_13" in strategy_name_list + assert "[S1, R, R, S0]_14" in strategy_name_list + assert "[R, R, R, S0]_15" in strategy_name_list + assert "[R, R, S1, S0]_16" in strategy_name_list + assert "[S0, R, R, R]_17" in strategy_name_list + assert "[R, R, R, R]_18" in strategy_name_list + assert "[R, R, S0, R]_19" in strategy_name_list + assert "[S1, R, R, R]_20" in strategy_name_list + assert "[R, R, R, R]_21" in strategy_name_list + assert "[R, R, S1, R]_22" in strategy_name_list + assert "[R, R, R, S1]_10" in strategy_name_list + assert "[R, R, R, S0]_9" in strategy_name_list + assert "[R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, S0]_6" in strategy_name_list + assert "[R, R, R, S1]_5" in strategy_name_list + assert "[S01, R, R, R]_0" in strategy_name_list + assert "[R, R, R, R]_1" in strategy_name_list + assert "[R, R, S01, R]_2" in strategy_name_list + assert "[R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01]_4" in strategy_name_list + + +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() -@parameterize('split_size', [2]) -@parameterize('split_dim', [0, 1, 2]) -@parameterize('model_cls', [ConvSplitModel, LinearSplitModel]) +@parameterize("split_size", [2]) +@parameterize("split_dim", [0, 1, 2]) +@parameterize("model_cls", [ConvSplitModel, LinearSplitModel]) def test_split_handler(split_size, split_dim, model_cls): spawn(check_split_handler, 4, split_size=split_size, split_dim=split_dim, model_cls=model_cls) -if __name__ == '__main__': +if __name__ == "__main__": test_split_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py index c11291ecac96..cbd3e47044b3 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py @@ -16,7 +16,6 @@ class LinearSumModel(nn.Module): - def __init__(self, sum_dims, keepdim): super().__init__() self.sum_dims = sum_dims @@ -33,26 +32,28 @@ def forward(self, input, other): def check_sum_handler(rank, world_size, port, sum_dims, keepdim): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = LinearSumModel(sum_dims=sum_dims, keepdim=keepdim).cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - input = torch.rand(8, 16, 64, 32).to('cuda') - other = torch.rand(64, 32).to('cuda') + input = torch.rand(8, 16, 64, 32).to("cuda") + other = torch.rand(64, 32).to("cuda") # index of linear node in computation graph node_index = 2 # total number of linear strategies strategy_number = 24 - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input, other], - meta_arg_names=['input', 'other'], - node_type='following') + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=["input", "other"], + node_type="following", + ) tracer = ColoTracer(bias_addition_split=True) @@ -63,8 +64,8 @@ def 
check_sum_handler(rank, world_size, port, sum_dims, keepdim): # %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%linear,), kwargs = {}) # return sum_1 meta_args = { - "input": torch.rand(8, 16, 64, 32).to('meta'), - "other": torch.rand(64, 32).to('meta'), + "input": torch.rand(8, 16, 64, 32).to("meta"), + "other": torch.rand(64, 32).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) @@ -78,11 +79,11 @@ def check_sum_handler(rank, world_size, port, sum_dims, keepdim): # build handler assert len(previous_strategies_vector) == 0 - linear_handler = LinearFunctionHandler(node=previous_mod_node, - device_mesh=device_mesh, - strategies_vector=previous_strategies_vector) + linear_handler = LinearFunctionHandler( + node=previous_mod_node, device_mesh=device_mesh, strategies_vector=previous_strategies_vector + ) linear_handler.register_strategy(compute_resharding_cost=False) - setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + setattr(previous_mod_node, "strategies_vector", previous_strategies_vector) sum_handler = SumHandler(node=sum_node, device_mesh=device_mesh, strategies_vector=sum_strategies_vector) @@ -100,131 +101,131 @@ def check_sum_handler(rank, world_size, port, sum_dims, keepdim): # make sure they have valid values assert op_data.data is not None - assert mapping['input'].name == "linear" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + assert mapping["input"].name == "linear" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([8, 16, 64, 64]) - assert mapping['output'].name == "sum_1" + assert mapping["output"].name == "sum_1" sum_node_shape = torch.empty([8, 16, 64, 64]).sum(sum_dims, keepdim=keepdim).shape - assert mapping['output'].logical_shape == sum_node_shape - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].logical_shape == sum_node_shape + assert mapping["output"].type == OperationDataType.OUTPUT # check strategy name if sum_dims == (0, 2) and keepdim == False: - assert '[R, R, R, R] -> [R, R]_0' in strategy_name_list - assert '[R, S01, R, R] -> [S01, R]_1' in strategy_name_list - assert '[R, R, R, R] -> [R, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, S01]_4' in strategy_name_list - assert '[R, R, R, S1] -> [R, S1]_5' in strategy_name_list - assert '[R, R, R, S0] -> [R, S0]_6' in strategy_name_list - assert '[R, R, R, R] -> [R, R]_7' in strategy_name_list - assert '[R, R, R, R] -> [R, R]_8' in strategy_name_list - assert '[R, R, R, S0] -> [R, S0]_9' in strategy_name_list - assert '[R, R, R, S1] -> [R, S1]_10' in strategy_name_list - assert '[R, R, R, S1] -> [R, S1]_11' in strategy_name_list - assert '[R, S0, R, S1] -> [S0, S1]_12' in strategy_name_list - assert '[R, R, R, S1] -> [R, S1]_13' in strategy_name_list - assert '[R, R, R, S0] -> [R, S0]_14' in strategy_name_list - assert '[R, S1, R, S0] -> [S1, S0]_15' in strategy_name_list - assert '[R, R, R, S0] -> [R, S0]_16' in strategy_name_list - assert '[R, R, R, R] -> [R, R]_17' in strategy_name_list - assert '[R, S0, R, R] -> [S0, R]_18' in strategy_name_list - assert '[R, R, R, R] 
-> [R, R]_19' in strategy_name_list - assert '[R, R, R, R] -> [R, R]_20' in strategy_name_list - assert '[R, S1, R, R] -> [S1, R]_21' in strategy_name_list - assert '[R, R, R, R] -> [R, R]_22' in strategy_name_list - assert '[R, R, R, R] -> [R, R]_23' in strategy_name_list + assert "[R, R, R, R] -> [R, R]_0" in strategy_name_list + assert "[R, S01, R, R] -> [S01, R]_1" in strategy_name_list + assert "[R, R, R, R] -> [R, R]_2" in strategy_name_list + assert "[R, R, R, R] -> [R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, S01]_4" in strategy_name_list + assert "[R, R, R, S1] -> [R, S1]_5" in strategy_name_list + assert "[R, R, R, S0] -> [R, S0]_6" in strategy_name_list + assert "[R, R, R, R] -> [R, R]_7" in strategy_name_list + assert "[R, R, R, R] -> [R, R]_8" in strategy_name_list + assert "[R, R, R, S0] -> [R, S0]_9" in strategy_name_list + assert "[R, R, R, S1] -> [R, S1]_10" in strategy_name_list + assert "[R, R, R, S1] -> [R, S1]_11" in strategy_name_list + assert "[R, S0, R, S1] -> [S0, S1]_12" in strategy_name_list + assert "[R, R, R, S1] -> [R, S1]_13" in strategy_name_list + assert "[R, R, R, S0] -> [R, S0]_14" in strategy_name_list + assert "[R, S1, R, S0] -> [S1, S0]_15" in strategy_name_list + assert "[R, R, R, S0] -> [R, S0]_16" in strategy_name_list + assert "[R, R, R, R] -> [R, R]_17" in strategy_name_list + assert "[R, S0, R, R] -> [S0, R]_18" in strategy_name_list + assert "[R, R, R, R] -> [R, R]_19" in strategy_name_list + assert "[R, R, R, R] -> [R, R]_20" in strategy_name_list + assert "[R, S1, R, R] -> [S1, R]_21" in strategy_name_list + assert "[R, R, R, R] -> [R, R]_22" in strategy_name_list + assert "[R, R, R, R] -> [R, R]_23" in strategy_name_list if sum_dims == (0, 2) and keepdim == True: - assert '[R, R, R, R] -> [R, R, R, R]_0' in strategy_name_list - assert '[R, S01, R, R] -> [R, S01, R, R]_1' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_11' in strategy_name_list - assert '[R, S0, R, S1] -> [R, S0, R, S1]_12' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_13' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_14' in strategy_name_list - assert '[R, S1, R, S0] -> [R, S1, R, S0]_15' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_17' in strategy_name_list - assert '[R, S0, R, R] -> [R, S0, R, R]_18' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_19' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_20' in strategy_name_list - assert '[R, S1, R, R] -> [R, S1, R, R]_21' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_22' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_23' in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_0" in strategy_name_list + assert "[R, S01, R, R] -> [R, S01, R, R]_1" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_2" in strategy_name_list + assert "[R, 
R, R, R] -> [R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, R, R, S01]_4" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_5" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_6" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_9" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_10" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_11" in strategy_name_list + assert "[R, S0, R, S1] -> [R, S0, R, S1]_12" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_13" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_14" in strategy_name_list + assert "[R, S1, R, S0] -> [R, S1, R, S0]_15" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_16" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_17" in strategy_name_list + assert "[R, S0, R, R] -> [R, S0, R, R]_18" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_19" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_20" in strategy_name_list + assert "[R, S1, R, R] -> [R, S1, R, R]_21" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_22" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_23" in strategy_name_list if sum_dims == 1 and keepdim == False: - assert '[S01, R, R, R] -> [S01, R, R]_0' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R]_1' in strategy_name_list - assert '[R, R, S01, R] -> [R, S01, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, S01]_4' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, S1]_5' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, S0]_6' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R]_7' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R]_8' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, S0]_9' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, S1]_10' in strategy_name_list - assert '[S0, R, R, S1] -> [S0, R, S1]_11' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, S1]_12' in strategy_name_list - assert '[R, R, S0, S1] -> [R, S0, S1]_13' in strategy_name_list - assert '[S1, R, R, S0] -> [S1, R, S0]_14' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, S0]_15' in strategy_name_list - assert '[R, R, S1, S0] -> [R, S1, S0]_16' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R]_17' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R]_18' in strategy_name_list - assert '[R, R, S0, R] -> [R, S0, R]_19' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R]_20' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R]_21' in strategy_name_list - assert '[R, R, S1, R] -> [R, S1, R]_22' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R]_23' in strategy_name_list + assert "[S01, R, R, R] -> [S01, R, R]_0" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R]_1" in strategy_name_list + assert "[R, R, S01, R] -> [R, S01, R]_2" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, R, S01]_4" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, S1]_5" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, S0]_6" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R]_7" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R]_8" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, S0]_9" in 
strategy_name_list + assert "[R, R, R, S1] -> [R, R, S1]_10" in strategy_name_list + assert "[S0, R, R, S1] -> [S0, R, S1]_11" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, S1]_12" in strategy_name_list + assert "[R, R, S0, S1] -> [R, S0, S1]_13" in strategy_name_list + assert "[S1, R, R, S0] -> [S1, R, S0]_14" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, S0]_15" in strategy_name_list + assert "[R, R, S1, S0] -> [R, S1, S0]_16" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R]_17" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R]_18" in strategy_name_list + assert "[R, R, S0, R] -> [R, S0, R]_19" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R]_20" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R]_21" in strategy_name_list + assert "[R, R, S1, R] -> [R, S1, R]_22" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R]_23" in strategy_name_list if sum_dims == 1 and keepdim == True: - assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_1' in strategy_name_list - assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list - assert '[S0, R, R, S1] -> [S0, R, R, S1]_11' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list - assert '[R, R, S0, S1] -> [R, R, S0, S1]_13' in strategy_name_list - assert '[S1, R, R, S0] -> [S1, R, R, S0]_14' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_15' in strategy_name_list - assert '[R, R, S1, S0] -> [R, R, S1, S0]_16' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list - assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list - assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_23' in strategy_name_list - - -@run_on_environment_flag(name='AUTO_PARALLEL') + assert "[S01, R, R, R] -> [S01, R, R, R]_0" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_1" in strategy_name_list + assert "[R, R, S01, R] -> [R, R, S01, R]_2" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, R, R, S01]_4" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_5" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_6" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_9" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_10" in strategy_name_list + assert "[S0, R, R, S1] -> [S0, R, R, S1]_11" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_12" in strategy_name_list + assert "[R, R, S0, S1] -> [R, R, S0, S1]_13" in strategy_name_list + assert "[S1, R, 
R, S0] -> [S1, R, R, S0]_14" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_15" in strategy_name_list + assert "[R, R, S1, S0] -> [R, R, S1, S0]_16" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R]_17" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_18" in strategy_name_list + assert "[R, R, S0, R] -> [R, R, S0, R]_19" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R]_20" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_21" in strategy_name_list + assert "[R, R, S1, R] -> [R, R, S1, R]_22" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_23" in strategy_name_list + + +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() -@parameterize('sum_dims', [(0, 2), 1]) -@parameterize('keepdim', [False, True]) +@parameterize("sum_dims", [(0, 2), 1]) +@parameterize("keepdim", [False, True]) def test_sum_handler(sum_dims, keepdim): spawn(check_sum_handler, 4, sum_dims=sum_dims, keepdim=keepdim) -if __name__ == '__main__': +if __name__ == "__main__": test_sum_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py index 5b6ac051a8ef..29089183165d 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py @@ -11,7 +11,6 @@ class TensorConstructorModel(nn.Module): - def __init__(self): super().__init__() @@ -21,7 +20,7 @@ def forward(self, x): return x -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @clear_cache_before_run() def test_where_handler(): model = TensorConstructorModel() @@ -33,7 +32,7 @@ def test_where_handler(): # %arange : [#users=1] = call_function[target=torch.arange](args = (%getitem,), kwargs = {}) # %add : [#users=1] = call_function[target=operator.add](args = (%x, %arange), kwargs = {}) # return add - meta_args = {'x': torch.rand(10).to('meta')} + meta_args = {"x": torch.rand(10).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -56,16 +55,16 @@ def test_where_handler(): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['output'].name == "arange" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([10]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "arange" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([10]) + assert mapping["output"].type == OperationDataType.OUTPUT handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] - assert 'Replica Tensor Constructor' in strategy_name_list + assert "Replica Tensor Constructor" in strategy_name_list -if __name__ == '__main__': +if __name__ == "__main__": test_where_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py index f4e6dafdfd69..271d55ae917a 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py +++ 
b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py @@ -12,7 +12,6 @@ class ReLuModel(nn.Module): - def __init__(self): super().__init__() self.act = torch.nn.ReLU() @@ -23,7 +22,7 @@ def forward(self, input, other): return relu_node -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @clear_cache_before_run() def test_elementwise_handler(): model = ReLuModel() @@ -35,8 +34,8 @@ def test_elementwise_handler(): # %act : [#users=1] = call_module[target=act](args = (%conv2d,), kwargs = {}) # return act meta_args = { - 'input': torch.rand(4, 4, 64, 64).to('meta'), - 'other': torch.rand(16, 4, 3, 3).to('meta'), + "input": torch.rand(4, 4, 64, 64).to("meta"), + "other": torch.rand(16, 4, 3, 3).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) @@ -51,14 +50,14 @@ def test_elementwise_handler(): conv_strategies_vector = StrategiesVector(conv_mod_node) # build handler - conv_handler = ConvFunctionHandler(node=conv_mod_node, - device_mesh=device_mesh, - strategies_vector=conv_strategies_vector) + conv_handler = ConvFunctionHandler( + node=conv_mod_node, device_mesh=device_mesh, strategies_vector=conv_strategies_vector + ) conv_handler.register_strategy(compute_resharding_cost=False) - setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector) - relu_handler = UnaryElementwiseHandler(node=relu_mod_node, - device_mesh=device_mesh, - strategies_vector=relu_strategies_vector) + setattr(conv_mod_node, "strategies_vector", conv_strategies_vector) + relu_handler = UnaryElementwiseHandler( + node=relu_mod_node, device_mesh=device_mesh, strategies_vector=relu_strategies_vector + ) relu_handler.register_strategy(compute_resharding_cost=False) @@ -70,20 +69,20 @@ def test_elementwise_handler(): # make sure they have valid values assert op_data.data is not None - assert mapping['input'].name == "conv2d" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 16, 62, 62]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 16, 62, 62]) + assert mapping["input"].name == "conv2d" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 16, 62, 62]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 16, 62, 62]) - assert mapping['output'].name == "act" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 16, 62, 62]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "act" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 16, 62, 62]) + assert mapping["output"].type == OperationDataType.OUTPUT # getitem is a following strategy handler, so the number of strategies is equal to the predecessor node. 
assert len(relu_strategies_vector) == len(conv_strategies_vector) -if __name__ == '__main__': +if __name__ == "__main__": test_elementwise_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py index fbb194d8e0b8..466168c79a0b 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py @@ -18,7 +18,6 @@ class ConvViewModel(nn.Module): - def __init__(self, tgt_shape): super().__init__() self.tgt_shape = tgt_shape @@ -30,7 +29,6 @@ def forward(self, input, other): class LinearViewModel(nn.Module): - def __init__(self, tgt_shape): super().__init__() self.tgt_shape = tgt_shape @@ -43,19 +41,19 @@ def forward(self, input, other): def check_view_handler(rank, tgt_shape, model_cls, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = model_cls(tgt_shape).cuda() - if model_cls.__name__ == 'ConvViewModel': - input = torch.rand(8, 8, 66, 66).to('cuda') - other = torch.rand(16, 8, 3, 3).to('cuda') + if model_cls.__name__ == "ConvViewModel": + input = torch.rand(8, 8, 66, 66).to("cuda") + other = torch.rand(16, 8, 3, 3).to("cuda") # index of conv node in computation graph node_index = 2 # total number of conv strategies strategy_number = 16 - if model_cls.__name__ == 'LinearViewModel': - input = torch.rand(8, 16, 64, 32).to('cuda') - other = torch.rand(64, 32).to('cuda') + if model_cls.__name__ == "LinearViewModel": + input = torch.rand(8, 16, 64, 32).to("cuda") + other = torch.rand(64, 32).to("cuda") # index of linear node in computation graph node_index = 2 # total number of linear strategies @@ -65,25 +63,27 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port): mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input, other], - meta_arg_names=['input', 'other'], - node_type='following') + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=["input", "other"], + node_type="following", + ) tracer = ColoTracer(bias_addition_split=True) - if model_cls.__name__ == 'ConvViewModel': + if model_cls.__name__ == "ConvViewModel": # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %other : torch.Tensor [#users=1] = placeholder[target=other] # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {}) # %view : [#users=1] = call_method[target=view](args = (%conv2d, 2, -1), kwargs = {}) # return view - meta_args = {'input': torch.rand(8, 8, 66, 66).to('meta'), 'other': torch.rand(16, 8, 3, 3).to('meta')} + meta_args = {"input": torch.rand(8, 8, 66, 66).to("meta"), "other": torch.rand(16, 8, 3, 3).to("meta")} graph = tracer.trace(model, meta_args=meta_args) - if model_cls.__name__ == 'LinearViewModel': + if model_cls.__name__ == "LinearViewModel": # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %other : torch.Tensor [#users=1] = 
placeholder[target=other] @@ -91,8 +91,8 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port): # %view : [#users=1] = call_method[target=view](args = (%linear, 32, 4, 32, 32, 4), kwargs = {}) # return view meta_args = { - 'input': torch.rand(8, 16, 64, 32).to('meta'), - 'other': torch.rand(64, 32).to('meta'), + "input": torch.rand(8, 16, 64, 32).to("meta"), + "other": torch.rand(64, 32).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) @@ -105,21 +105,20 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port): previous_strategies_vector = StrategiesVector(previous_mod_node) # build handler - if model_cls.__name__ == 'ConvViewModel': - - conv_handler = ConvFunctionHandler(node=previous_mod_node, - device_mesh=device_mesh, - strategies_vector=previous_strategies_vector) + if model_cls.__name__ == "ConvViewModel": + conv_handler = ConvFunctionHandler( + node=previous_mod_node, device_mesh=device_mesh, strategies_vector=previous_strategies_vector + ) conv_handler.register_strategy(compute_resharding_cost=False) - setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + setattr(previous_mod_node, "strategies_vector", previous_strategies_vector) - if model_cls.__name__ == 'LinearViewModel': + if model_cls.__name__ == "LinearViewModel": assert len(previous_strategies_vector) == 0 - linear_handler = LinearFunctionHandler(node=previous_mod_node, - device_mesh=device_mesh, - strategies_vector=previous_strategies_vector) + linear_handler = LinearFunctionHandler( + node=previous_mod_node, device_mesh=device_mesh, strategies_vector=previous_strategies_vector + ) linear_handler.register_strategy(compute_resharding_cost=False) - setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + setattr(previous_mod_node, "strategies_vector", previous_strategies_vector) view_handler = ViewHandler(node=view_node, device_mesh=device_mesh, strategies_vector=view_strategies_vector) @@ -133,126 +132,124 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port): # make sure they have valid values assert op_data.data is not None - if model_cls.__name__ == 'ConvViewModel': - assert mapping['input'].name == "conv2d" + if model_cls.__name__ == "ConvViewModel": + assert mapping["input"].name == "conv2d" else: - assert mapping['input'].name == "linear" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + assert mapping["input"].name == "linear" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([8, 16, 64, 64]) - assert mapping['output'].name == "view" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size(tgt_shape) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "view" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size(tgt_shape) + assert mapping["output"].type == OperationDataType.OUTPUT # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. 
assert len(view_strategies_vector) == len(previous_strategies_vector) strategy_name_list = [strategy.name for strategy in view_strategies_vector] - if model_cls.__name__ == 'ConvViewModel': - + if model_cls.__name__ == "ConvViewModel": if tgt_shape == (32, 4, 64, 16, 4): - assert '[S0, S1, R, R] -> FULLY REPLICATED_0' in strategy_name_list - assert '[S1, S0, R, R] -> FULLY REPLICATED_1' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R, R]_2' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R, R]_3' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R, R]_4' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R, R]_5' in strategy_name_list - assert '[R, S1, R, R] -> FULLY REPLICATED_6' in strategy_name_list - assert '[R, S0, R, R] -> FULLY REPLICATED_7' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R]_9' in strategy_name_list - assert '[R, S0, R, R] -> FULLY REPLICATED_10' in strategy_name_list - assert '[R, S1, R, R] -> FULLY REPLICATED_11' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R]_12' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R, R]_13' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R]_14' in strategy_name_list - assert '[R, S01, R, R] -> FULLY REPLICATED_15' in strategy_name_list + assert "[S0, S1, R, R] -> FULLY REPLICATED_0" in strategy_name_list + assert "[S1, S0, R, R] -> FULLY REPLICATED_1" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R, R]_2" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R, R]_3" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R, R]_4" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R, R]_5" in strategy_name_list + assert "[R, S1, R, R] -> FULLY REPLICATED_6" in strategy_name_list + assert "[R, S0, R, R] -> FULLY REPLICATED_7" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R]_9" in strategy_name_list + assert "[R, S0, R, R] -> FULLY REPLICATED_10" in strategy_name_list + assert "[R, S1, R, R] -> FULLY REPLICATED_11" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R]_12" in strategy_name_list + assert "[S01, R, R, R] -> [S01, R, R, R, R]_13" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R]_14" in strategy_name_list + assert "[R, S01, R, R] -> FULLY REPLICATED_15" in strategy_name_list if tgt_shape == (8, 4, 4, 64, 16, 4): - assert '[S0, S1, R, R] -> [S0, S1, R, R, R, R]_0' in strategy_name_list - assert '[S1, S0, R, R] -> [S1, S0, R, R, R, R]_1' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_2' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_3' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_4' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_5' in strategy_name_list - assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_6' in strategy_name_list - assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R, R]_9' in strategy_name_list - assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_10' in strategy_name_list - assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_11' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R, R]_12' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R, R, R]_13' in strategy_name_list - assert '[R, R, 
R, R] -> [R, R, R, R, R, R]_14' in strategy_name_list - assert '[R, S01, R, R] -> [R, S01, R, R, R, R]_15' in strategy_name_list - - if model_cls.__name__ == 'LinearViewModel': - + assert "[S0, S1, R, R] -> [S0, S1, R, R, R, R]_0" in strategy_name_list + assert "[S1, S0, R, R] -> [S1, S0, R, R, R, R]_1" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R, R, R]_2" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R, R, R]_3" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R, R, R]_4" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R, R, R]_5" in strategy_name_list + assert "[R, S1, R, R] -> [R, S1, R, R, R, R]_6" in strategy_name_list + assert "[R, S0, R, R] -> [R, S0, R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R, R]_9" in strategy_name_list + assert "[R, S0, R, R] -> [R, S0, R, R, R, R]_10" in strategy_name_list + assert "[R, S1, R, R] -> [R, S1, R, R, R, R]_11" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R, R]_12" in strategy_name_list + assert "[S01, R, R, R] -> [S01, R, R, R, R, R]_13" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R, R]_14" in strategy_name_list + assert "[R, S01, R, R] -> [R, S01, R, R, R, R]_15" in strategy_name_list + + if model_cls.__name__ == "LinearViewModel": if tgt_shape == (32, 4, 64, 16, 4): for strategy in strategy_name_list: print(strategy) # print(strategy_name_list) - assert '[S0, R, R, S1] -> [S0, R, R, S1, R]_11' in strategy_name_list - assert '[R, S0, R, S1] -> FULLY REPLICATED_12' in strategy_name_list - assert '[R, R, S0, S1] -> [R, R, S0, S1, R]_13' in strategy_name_list - assert '[S1, R, R, S0] -> [S1, R, R, S0, R]_14' in strategy_name_list - assert '[R, S1, R, S0] -> FULLY REPLICATED_15' in strategy_name_list - assert '[R, R, S1, S0] -> [R, R, S1, S0, R]_16' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R, R]_17' in strategy_name_list - assert '[R, S0, R, R] -> FULLY REPLICATED_18' in strategy_name_list - assert '[R, R, S0, R] -> [R, R, S0, R, R]_19' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R, R]_20' in strategy_name_list - assert '[R, S1, R, R] -> FULLY REPLICATED_21' in strategy_name_list - assert '[R, R, S1, R] -> [R, R, S1, R, R]_22' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1, R]_10' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0, R]_9' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0, R]_6' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1, R]_5' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R, R]_0' in strategy_name_list - assert '[R, S01, R, R] -> FULLY REPLICATED_1' in strategy_name_list - assert '[R, R, S01, R] -> [R, R, S01, R, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, S01, R]_4' in strategy_name_list + assert "[S0, R, R, S1] -> [S0, R, R, S1, R]_11" in strategy_name_list + assert "[R, S0, R, S1] -> FULLY REPLICATED_12" in strategy_name_list + assert "[R, R, S0, S1] -> [R, R, S0, S1, R]_13" in strategy_name_list + assert "[S1, R, R, S0] -> [S1, R, R, S0, R]_14" in strategy_name_list + assert "[R, S1, R, S0] -> FULLY REPLICATED_15" in strategy_name_list + assert "[R, R, S1, S0] -> [R, R, S1, S0, R]_16" in strategy_name_list + assert 
"[S0, R, R, R] -> [S0, R, R, R, R]_17" in strategy_name_list + assert "[R, S0, R, R] -> FULLY REPLICATED_18" in strategy_name_list + assert "[R, R, S0, R] -> [R, R, S0, R, R]_19" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R, R]_20" in strategy_name_list + assert "[R, S1, R, R] -> FULLY REPLICATED_21" in strategy_name_list + assert "[R, R, S1, R] -> [R, R, S1, R, R]_22" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1, R]_10" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0, R]_9" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0, R]_6" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1, R]_5" in strategy_name_list + assert "[S01, R, R, R] -> [S01, R, R, R, R]_0" in strategy_name_list + assert "[R, S01, R, R] -> FULLY REPLICATED_1" in strategy_name_list + assert "[R, R, S01, R] -> [R, R, S01, R, R]_2" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, R, R, S01, R]_4" in strategy_name_list if tgt_shape == (8, 4, 4, 64, 16, 4): - assert '[S0, R, R, S1] -> [S0, R, R, R, S1, R]_11' in strategy_name_list - assert '[R, S0, R, S1] -> [R, S0, R, R, S1, R]_12' in strategy_name_list - assert '[R, R, S0, S1] -> [R, R, R, S0, S1, R]_13' in strategy_name_list - assert '[S1, R, R, S0] -> [S1, R, R, R, S0, R]_14' in strategy_name_list - assert '[R, S1, R, S0] -> [R, S1, R, R, S0, R]_15' in strategy_name_list - assert '[R, R, S1, S0] -> [R, R, R, S1, S0, R]_16' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_17' in strategy_name_list - assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_18' in strategy_name_list - assert '[R, R, S0, R] -> [R, R, R, S0, R, R]_19' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_20' in strategy_name_list - assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_21' in strategy_name_list - assert '[R, R, S1, R] -> [R, R, R, S1, R, R]_22' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_10' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_9' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_6' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_5' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R, R, R]_0' in strategy_name_list - assert '[R, S01, R, R] -> [R, S01, R, R, R, R]_1' in strategy_name_list - assert '[R, R, S01, R] -> [R, R, R, S01, R, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, R, S01, R]_4' in strategy_name_list - - -@run_on_environment_flag(name='AUTO_PARALLEL') + assert "[S0, R, R, S1] -> [S0, R, R, R, S1, R]_11" in strategy_name_list + assert "[R, S0, R, S1] -> [R, S0, R, R, S1, R]_12" in strategy_name_list + assert "[R, R, S0, S1] -> [R, R, R, S0, S1, R]_13" in strategy_name_list + assert "[S1, R, R, S0] -> [S1, R, R, R, S0, R]_14" in strategy_name_list + assert "[R, S1, R, S0] -> [R, S1, R, R, S0, R]_15" in strategy_name_list + assert "[R, R, S1, S0] -> [R, R, R, S1, S0, R]_16" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R, R, R]_17" in strategy_name_list + assert "[R, S0, R, R] -> [R, S0, R, R, R, R]_18" in strategy_name_list + assert "[R, 
R, S0, R] -> [R, R, R, S0, R, R]_19" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R, R, R]_20" in strategy_name_list + assert "[R, S1, R, R] -> [R, S1, R, R, R, R]_21" in strategy_name_list + assert "[R, R, S1, R] -> [R, R, R, S1, R, R]_22" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, R, S1, R]_10" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, R, S0, R]_9" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, R, S0, R]_6" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, R, S1, R]_5" in strategy_name_list + assert "[S01, R, R, R] -> [S01, R, R, R, R, R]_0" in strategy_name_list + assert "[R, S01, R, R] -> [R, S01, R, R, R, R]_1" in strategy_name_list + assert "[R, R, S01, R] -> [R, R, R, S01, R, R]_2" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, R, R, R, S01, R]_4" in strategy_name_list + + +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() -@parameterize('tgt_shape', [(32, 4, 64, 16, 4), (8, 4, 4, 64, 16, 4)]) -@parameterize('model_cls', [ConvViewModel, LinearViewModel]) +@parameterize("tgt_shape", [(32, 4, 64, 16, 4), (8, 4, 4, 64, 16, 4)]) +@parameterize("model_cls", [ConvViewModel, LinearViewModel]) def test_view_handler(tgt_shape, model_cls): spawn(check_view_handler, 4, tgt_shape=tgt_shape, model_cls=model_cls) -if __name__ == '__main__': +if __name__ == "__main__": test_view_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py index bd7635ac1737..10ca644cddc2 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py @@ -12,7 +12,6 @@ class ConvModel(nn.Module): - def __init__(self): super().__init__() @@ -21,7 +20,7 @@ def forward(self, condition, x, y): return output -@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') +@pytest.mark.skip("ShapeProp is not compatible with PyTorch 1.11.0") @clear_cache_before_run() def test_where_handler(): model = ConvModel() @@ -33,9 +32,9 @@ def test_where_handler(): # %where : [#users=1] = call_function[target=torch.where](args = (%condition, %x, %y), kwargs = {}) # return where meta_args = { - 'condition': torch.rand(4, 4, 64, 64).to('meta'), - 'x': torch.rand(4, 1, 64, 64).to('meta'), - 'y': torch.rand(1, 4, 64, 64).to('meta') + "condition": torch.rand(4, 4, 64, 64).to("meta"), + "x": torch.rand(4, 1, 64, 64).to("meta"), + "y": torch.rand(1, 4, 64, 64).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) @@ -59,28 +58,28 @@ def test_where_handler(): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['condition'].name == "condition" - assert mapping['condition'].data.is_meta - assert mapping['condition'].data.shape == torch.Size([4, 4, 64, 64]) - assert mapping['condition'].type == OperationDataType.ARG - assert mapping['condition'].logical_shape == torch.Size([4, 4, 64, 64]) - - assert mapping['x'].name == "x" - assert mapping['x'].data.is_meta - assert mapping['x'].data.shape == torch.Size([4, 1, 64, 64]) - assert mapping['x'].type == OperationDataType.ARG - assert 
mapping['x'].logical_shape == torch.Size([4, 4, 64, 64]) - - assert mapping['y'].name == "y" - assert mapping['y'].data.is_meta - assert mapping['y'].data.shape == torch.Size([1, 4, 64, 64]) - assert mapping['y'].type == OperationDataType.ARG - assert mapping['y'].logical_shape == torch.Size([4, 4, 64, 64]) - - assert mapping['output'].name == "where" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 4, 64, 64]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["condition"].name == "condition" + assert mapping["condition"].data.is_meta + assert mapping["condition"].data.shape == torch.Size([4, 4, 64, 64]) + assert mapping["condition"].type == OperationDataType.ARG + assert mapping["condition"].logical_shape == torch.Size([4, 4, 64, 64]) + + assert mapping["x"].name == "x" + assert mapping["x"].data.is_meta + assert mapping["x"].data.shape == torch.Size([4, 1, 64, 64]) + assert mapping["x"].type == OperationDataType.ARG + assert mapping["x"].logical_shape == torch.Size([4, 4, 64, 64]) + + assert mapping["y"].name == "y" + assert mapping["y"].data.is_meta + assert mapping["y"].data.shape == torch.Size([1, 4, 64, 64]) + assert mapping["y"].type == OperationDataType.ARG + assert mapping["y"].logical_shape == torch.Size([4, 4, 64, 64]) + + assert mapping["output"].name == "where" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 4, 64, 64]) + assert mapping["output"].type == OperationDataType.OUTPUT handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] @@ -88,5 +87,5 @@ def test_where_handler(): assert len(strategy_name_list) == 25 -if __name__ == '__main__': +if __name__ == "__main__": test_where_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py index 28a8bbd9a4c1..3591c663897c 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py @@ -2,7 +2,6 @@ from typing import Dict, List import torch -from torch.fx import GraphModule from colossalai._analyzer.fx.graph_module import ColoGraphModule from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass @@ -18,16 +17,18 @@ from colossalai.testing.comparison import assert_close -def _build_model_to_compare(model: torch.nn.Module, input_args: List[torch.Tensor], - input_kwargs: Dict[str, torch.Tensor], grad_dict: Dict[any, torch.Tensor]): - +def _build_model_to_compare( + model: torch.nn.Module, + input_args: List[torch.Tensor], + input_kwargs: Dict[str, torch.Tensor], + grad_dict: Dict[any, torch.Tensor], +): model_to_compare = copy.deepcopy(model) args_to_compare = [] kwargs_to_compare = {} for arg_index, input_tensor in enumerate(input_args): def wrapper(param, index): - def hook_fn(grad): grad_dict[index] = grad @@ -45,7 +46,6 @@ def hook_fn(grad): for name, input_kwarg in input_kwargs.items(): def wrapper(param, name): - def hook_fn(grad): grad_dict[name] = grad @@ -63,30 +63,34 @@ def hook_fn(grad): return model_to_compare, args_to_compare, kwargs_to_compare -def numerical_test_for_node_strategy(model: torch.nn.Module, - device_mesh: DeviceMesh, - node_index: int, - strategy_number: int, - input_args: List[torch.Tensor], - meta_arg_names: List[str], - input_kwargs: Dict[str, torch.Tensor] = {}, - node_type: str = 'normal'): +def 
numerical_test_for_node_strategy( + model: torch.nn.Module, + device_mesh: DeviceMesh, + node_index: int, + strategy_number: int, + input_args: List[torch.Tensor], + meta_arg_names: List[str], + input_kwargs: Dict[str, torch.Tensor] = {}, + node_type: str = "normal", +): for strategy_index in range(strategy_number): - print(f'#strategy_index: {strategy_index}') + print(f"#strategy_index: {strategy_index}") # We need to copy the model to avoid do backward more than once in same graph grad_to_compare_dict = {} grad_to_shard_dict = {} model_to_compare, args_to_compare, kwargs_to_compare = _build_model_to_compare( - model, input_args, input_kwargs, grad_to_compare_dict) - model_to_shard, args_to_shard, kwargs_to_shard = _build_model_to_compare(model, input_args, input_kwargs, - grad_to_shard_dict) + model, input_args, input_kwargs, grad_to_compare_dict + ) + model_to_shard, args_to_shard, kwargs_to_shard = _build_model_to_compare( + model, input_args, input_kwargs, grad_to_shard_dict + ) tracer = ColoTracer(bias_addition_split=True) input_sample = {} for input_arg, meta_arg_name in zip(input_args, meta_arg_names): - input_sample[meta_arg_name] = torch.empty(input_arg.shape, dtype=input_arg.dtype).to('meta') + input_sample[meta_arg_name] = torch.empty(input_arg.shape, dtype=input_arg.dtype).to("meta") for meta_kwarg_name, input_kwarg in input_kwargs.items(): - input_sample[meta_kwarg_name] = torch.empty(input_kwarg.shape, dtype=input_kwarg.dtype).to('meta') + input_sample[meta_kwarg_name] = torch.empty(input_kwarg.shape, dtype=input_kwarg.dtype).to("meta") graph = tracer.trace(root=model_to_shard, meta_args=input_sample) gm = ColoGraphModule(model_to_shard, graph, model_to_shard.__class__.__name__) shape_prop_pass(gm, *input_sample.values()) @@ -94,13 +98,14 @@ def numerical_test_for_node_strategy(model: torch.nn.Module, solver_options = SolverOptions() strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) strategies_constructor.build_strategies_and_cost() - target_node = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies - ][node_index] - if node_type == 'normal': + target_node = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies][ + node_index + ] + if node_type == "normal": solution_len = len(strategies_constructor.leaf_strategies) solution = [0] * solution_len solution[node_index] = strategy_index - elif node_type == 'following': + elif node_type == "following": solution_len = len(strategies_constructor.leaf_strategies) solution = [0] * solution_len solution[node_index] = strategy_index @@ -116,18 +121,21 @@ def numerical_test_for_node_strategy(model: torch.nn.Module, ret = solver.call_solver_serialized_args() solution = list(ret[0]) gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass( - gm, solution, device_mesh, strategies_constructor) + gm, solution, device_mesh, strategies_constructor + ) gm = runtime_apply_pass(gm) gm.recompile() # forward result compare - output = gm(*args_to_shard, - sharding_spec_convert_dict=sharding_spec_dict, - origin_node_sharding_spec_dict=origin_spec_dict, - comm_actions_dict=comm_actions_dict, - **kwargs_to_shard) + output = gm( + *args_to_shard, + sharding_spec_convert_dict=sharding_spec_dict, + origin_node_sharding_spec_dict=origin_spec_dict, + comm_actions_dict=comm_actions_dict, + **kwargs_to_shard, + ) output_to_compare = model_to_compare(*args_to_compare, **kwargs_to_compare) - assert_close_helper(output, 
output_to_compare, strategy_index=strategy_index, type='forward output') + assert_close_helper(output, output_to_compare, strategy_index=strategy_index, type="forward output") # backward result compare if isinstance(output, (tuple, list)): @@ -142,43 +150,45 @@ def numerical_test_for_node_strategy(model: torch.nn.Module, for key in grad_to_shard_dict.keys(): grad_to_shard = grad_to_shard_dict[key] grad_to_compare = grad_to_compare_dict[key] - assert_close_helper(grad_to_shard, grad_to_compare, strategy_index=strategy_index, type='input grad') + assert_close_helper(grad_to_shard, grad_to_compare, strategy_index=strategy_index, type="input grad") # extract the strategy used in this iter strategy_in_use = target_node.strategies_vector[strategy_index] param_to_shard_dict = dict(gm.named_parameters()) param_to_compare_dict = dict(model_to_compare.named_parameters()) for name in param_to_shard_dict.keys(): - param_name = name.split('.')[-1] - if node_type == 'normal': + param_name = name.split(".")[-1] + if node_type == "normal": param_sharding_spec = strategy_in_use.get_sharding_spec_by_name(param_name) else: - if 'weight' in name: + if "weight" in name: param_sharding_spec = None for node in list(graph.nodes): - if 'weight' in node.name: + if "weight" in node.name: param_sharding_spec = node.sharding_spec - elif 'bias' in name: + elif "bias" in name: param_sharding_spec = None for node in list(graph.nodes): - if 'bias' in node.name: + if "bias" in node.name: param_sharding_spec = node.sharding_spec assert param_sharding_spec is not None grad_sharded = param_to_shard_dict[name].grad grad_to_compare = param_to_compare_dict[name].grad global_grad = to_global(grad_sharded, param_sharding_spec) - assert_close_helper(global_grad, grad_to_compare, strategy_index=strategy_index, type='param grad') + assert_close_helper(global_grad, grad_to_compare, strategy_index=strategy_index, type="param grad") -def assert_close_helper(first: torch.Tensor, - second: torch.Tensor, - rtol: float = 1e-2, - atol: float = 1e-2, - strategy_index: int = -1, - type: str = 'not defined'): +def assert_close_helper( + first: torch.Tensor, + second: torch.Tensor, + rtol: float = 1e-2, + atol: float = 1e-2, + strategy_index: int = -1, + type: str = "not defined", +): """ This method is used to check whether the average difference between two tensors is as close as expected. 
""" @@ -189,4 +199,4 @@ def assert_close_helper(first: torch.Tensor, else: assert_close(first, second, rtol=rtol, atol=atol) except: - print(f'strategy index {strategy_index} encounter assert_close error on {type}') + print(f"strategy index {strategy_index} encounter assert_close error on {type}") diff --git a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py index 0d93e4e40527..e7b8c696e62e 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py @@ -3,17 +3,18 @@ from torchvision.models import resnet50 from colossalai._analyzer.fx.passes import shape_prop_pass + # from colossalai.fx.tracer.tracer import ColoTracer from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP from colossalai.auto_parallel.tensor_shard.options import SolverOptions -from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor +from colossalai.auto_parallel.tensor_shard.solver import CostGraph, Solver, StrategiesConstructor from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.testing import clear_cache_before_run, run_on_environment_flag -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @clear_cache_before_run() def test_cost_graph(): physical_mesh_id = torch.arange(0, 8) @@ -21,11 +22,11 @@ def test_cost_graph(): # [[0, 1] # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - shape_consistency_manager = ShapeConsistencyManager() + ShapeConsistencyManager() tracer = ColoTracer(bias_addition_split=True) model = resnet50(num_classes=100000) - input_sample = {'x': torch.rand(128, 3, 224, 224).to('meta')} + input_sample = {"x": torch.rand(128, 3, 224, 224).to("meta")} graph = tracer.trace(root=model, meta_args=input_sample) # graph(): @@ -74,7 +75,7 @@ def test_cost_graph(): communication_cost_bn = 0 memory_cost = 0 for index, node in enumerate(graph.nodes): - if node.op == 'call_module': + if node.op == "call_module": submod = node.graph.owning_module.get_submodule(node.target) if type(submod) in BATCHNORM_MODULE_OP: communication_cost_bn += node.strategies_vector[strategies_list[index]].communication_cost.total @@ -86,11 +87,11 @@ def test_cost_graph(): node_memory_cost = node_memory_cost[0] memory_cost += node_memory_cost.activation + node_memory_cost.parameter - print(f'computation cost is {computation_cost}') - print(f'communication cost is {communication_cost}') - print(f'memory cost is {memory_cost}') - print(f'bn communication cost is {communication_cost_bn}') + print(f"computation cost is {computation_cost}") + print(f"communication cost is {communication_cost}") + print(f"memory cost is {memory_cost}") + print(f"bn communication cost is {communication_cost_bn}") -if __name__ == '__main__': +if __name__ == "__main__": test_cost_graph() diff --git a/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py b/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py index d07145e48e1f..07fd0ad582e9 100644 --- a/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py +++ b/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py @@ -1,5 +1,5 @@ 
import time -from typing import Any, Dict, List +from typing import Any import torch import torch.fx @@ -111,13 +111,14 @@ def _benchmark_speed(model, inputs, loop=5): def benchmark_evoformer_stack(data_args): from test_autochunk_evoformer_stack import get_data, get_model + print("\nmsa len: %d, pair len: %d" % (data_args[0], data_args[1])) max_mem = _benchmark_evoformer_stack_origin(data_args, get_model, get_data) for ratio in [0.5, 0.4, 0.3, 0.2, 0.1]: try: _benchmark_evoformer_stack_gm(data_args, max_mem * ratio, get_model, get_data) except RuntimeError as e: - if e.args[0] == 'Search failed. Try a larger memory threshold.': + if e.args[0] == "Search failed. Try a larger memory threshold.": break except Exception as e: raise e diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py index 593658fd1368..3d3f212a68d0 100644 --- a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py +++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py @@ -8,7 +8,6 @@ from colossalai.autochunk.utils import flat_list from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.legacy.core import global_context as gpc from colossalai.testing import free_port if AUTOCHUNK_AVAILABLE: @@ -80,9 +79,9 @@ def assert_codegen_run( out_gm = flat_list(out_gm) out_model = flat_list(out_model) for out_gm_i, out_model_i in zip(out_gm, out_model): - assert torch.allclose(out_gm_i, out_model_i, - atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( - torch.abs(out_gm_i - out_model_i)) + assert torch.allclose( + out_gm_i, out_model_i, atol=1e-4 + ), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(torch.abs(out_gm_i - out_model_i)) return chunks diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py index 9e4cb7ee9f95..1a4ababda30d 100644 --- a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py +++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py @@ -6,6 +6,7 @@ try: from fastfold.model.nn.evoformer import EvoformerBlock + HAS_REPO = True except: HAS_REPO = False @@ -17,22 +18,26 @@ def get_model(): - model = EvoformerBlock( - c_m=256, - c_z=128, - c_hidden_msa_att=32, - c_hidden_opm=32, - c_hidden_mul=128, - c_hidden_pair_att=32, - no_heads_msa=8, - no_heads_pair=4, - transition_n=4, - msa_dropout=0.15, - pair_dropout=0.15, - inf=1e4, - eps=1e-4, - is_multimer=False, - ).eval().cuda() + model = ( + EvoformerBlock( + c_m=256, + c_z=128, + c_hidden_msa_att=32, + c_hidden_opm=32, + c_hidden_mul=128, + c_hidden_pair_att=32, + no_heads_msa=8, + no_heads_pair=4, + transition_n=4, + msa_dropout=0.15, + pair_dropout=0.15, + inf=1e4, + eps=1e-4, + is_multimer=False, + ) + .eval() + .cuda() + ) return model @@ -54,8 +59,20 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]: def get_chunk_target() -> Dict: return { - None: [(120, 126), (225, 244), (270, 289), (306, 311), (70, 106), (23, 46), (146, 152), (187, 193), (181, 184), - (140, 145), (162, 163), (203, 204)], + None: [ + (120, 126), + (225, 244), + (270, 289), + (306, 311), + (70, 106), + (23, 46), + (146, 152), + (187, 193), + (181, 184), + (140, 145), + (162, 163), + (203, 204), + ], 
20: [(120, 123), (232, 237), (277, 282), (305, 306)], 24: [(122, 123)], } diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py index 6b47033e199f..0b04ba5257b6 100644 --- a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py +++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py @@ -6,6 +6,7 @@ try: from fastfold.model.nn.evoformer import EvoformerStack + HAS_REPO = True except: HAS_REPO = False @@ -17,26 +18,30 @@ def get_model(): - model = EvoformerStack( - c_m=256, - c_z=128, - c_hidden_msa_att=32, - c_hidden_opm=32, - c_hidden_mul=128, - c_hidden_pair_att=32, - c_s=384, - no_heads_msa=8, - no_heads_pair=4, - no_blocks=2, # 48 - transition_n=4, - msa_dropout=0.15, - pair_dropout=0.25, - blocks_per_ckpt=None, - inf=1000000000.0, - eps=1e-08, - clear_cache_between_blocks=False, - is_multimer=False, - ).eval().cuda() + model = ( + EvoformerStack( + c_m=256, + c_z=128, + c_hidden_msa_att=32, + c_hidden_opm=32, + c_hidden_mul=128, + c_hidden_pair_att=32, + c_s=384, + no_heads_msa=8, + no_heads_pair=4, + no_blocks=2, # 48 + transition_n=4, + msa_dropout=0.15, + pair_dropout=0.25, + blocks_per_ckpt=None, + inf=1000000000.0, + eps=1e-08, + clear_cache_between_blocks=False, + is_multimer=False, + ) + .eval() + .cuda() + ) return model @@ -62,7 +67,7 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]: ) @clear_cache_before_run() @parameterize("max_memory", [None, 20, 24]) -@parameterize("data_args", [(32, 64)]) # (msa_len, pair_len) +@parameterize("data_args", [(32, 64)]) # (msa_len, pair_len) def test_evoformer_stack(data_args, max_memory): spawn( run_test, diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py index b4c577c18ee6..585a9e3381c4 100644 --- a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py +++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Tuple +from typing import List, Tuple import pytest import torch @@ -6,6 +6,7 @@ try: from fastfold.model.nn.evoformer import ExtraMSABlock + HAS_REPO = True except: HAS_REPO = False @@ -16,23 +17,27 @@ def get_model(): - model = ExtraMSABlock( - c_m=256, - c_z=128, - c_hidden_msa_att=32, - c_hidden_opm=32, - c_hidden_mul=128, - c_hidden_pair_att=32, - no_heads_msa=8, - no_heads_pair=4, - transition_n=4, - msa_dropout=0.15, - pair_dropout=0.15, - inf=1e4, - eps=1e-4, - ckpt=False, - is_multimer=False, - ).eval().cuda() + model = ( + ExtraMSABlock( + c_m=256, + c_z=128, + c_hidden_msa_att=32, + c_hidden_opm=32, + c_hidden_mul=128, + c_hidden_pair_att=32, + no_heads_msa=8, + no_heads_pair=4, + transition_n=4, + msa_dropout=0.15, + pair_dropout=0.15, + inf=1e4, + eps=1e-4, + ckpt=False, + is_multimer=False, + ) + .eval() + .cuda() + ) return model @@ -58,7 +63,7 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]: ) @clear_cache_before_run() @parameterize("max_memory", [None, 20, 24]) -@parameterize("data_args", [(32, 64)]) # (msa_len, pair_len) +@parameterize("data_args", [(32, 64)]) # (msa_len, pair_len) def test_extramsa_block(data_args, max_memory): spawn( run_test, diff --git a/tests/test_autochunk/test_autochunk_diffuser/benchmark_autochunk_diffuser.py 
b/tests/test_autochunk/test_autochunk_diffuser/benchmark_autochunk_diffuser.py index 6fb7efa7a8fc..b75cbe67590c 100644 --- a/tests/test_autochunk/test_autochunk_diffuser/benchmark_autochunk_diffuser.py +++ b/tests/test_autochunk/test_autochunk_diffuser/benchmark_autochunk_diffuser.py @@ -1,5 +1,5 @@ import time -from typing import Any, Dict, List +from typing import Any import torch import torch.fx @@ -64,8 +64,10 @@ def _benchmark_autochunk_unet_gm( para_mem = float(parameter_size(model)) / 1024**2 act_mem = _benchmark_memory(gm, inputs) speed = _benchmark_speed(gm, inputs) - print("unet autochunk, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" % - (speed, act_mem, para_mem, act_mem + para_mem)) + print( + "unet autochunk, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" + % (speed, act_mem, para_mem, act_mem + para_mem) + ) def _benchmark_autochunk_unet_origin( @@ -86,8 +88,10 @@ def _benchmark_autochunk_unet_origin( para_mem = float(parameter_size(model)) / 1024**2 act_mem = _benchmark_memory(model, inputs) speed = _benchmark_speed(model, inputs) - print("unet origin, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" % - (speed, act_mem, para_mem, act_mem + para_mem)) + print( + "unet origin, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" + % (speed, act_mem, para_mem, act_mem + para_mem) + ) return act_mem @@ -115,6 +119,7 @@ def _benchmark_speed(model, inputs, loop=5): def benchmark_autochunk_unet(batch=1, height=448, width=448): from test_autochunk_unet import UNet2DModel, get_data + model = UNet2DModel() latent_shape = (batch, 3, height // 7, width // 7) @@ -124,7 +129,7 @@ def benchmark_autochunk_unet(batch=1, height=448, width=448): try: _benchmark_autochunk_unet_gm(model, get_data(latent_shape), max_mem * ratio) except RuntimeError as e: - if e.args[0] == 'Search failed. Try a larger memory threshold.': + if e.args[0] == "Search failed. 
Try a larger memory threshold.": break except Exception as e: raise e diff --git a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py index 264331a5fef0..32034992090f 100644 --- a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py +++ b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py @@ -83,9 +83,11 @@ def assert_codegen_run( max_mem_ori = torch.cuda.max_memory_allocated() / 1024**2 print("origin mem: %.2fMB, autochunk mem: %.2fMB" % (max_mem_ori - now_mem_ori, max_mem_gm - now_mem_gm)) - assert torch.allclose(out_gm["sample"], out_model["sample"], - atol=1e-3), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( - torch.abs(out_gm["sample"] - out_model["sample"])) + assert torch.allclose( + out_gm["sample"], out_model["sample"], atol=1e-3 + ), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( + torch.abs(out_gm["sample"] - out_model["sample"]) + ) return chunks @@ -129,7 +131,7 @@ def run_test( if get_chunk_target is not None: chunk_found = [i["region"] for i in chunks] chunk_target = get_chunk_target()[max_memory] - assert (chunk_found == chunk_target), "found regions %s doesn't equal target regions %s" % ( + assert chunk_found == chunk_target, "found regions %s doesn't equal target regions %s" % ( str(chunk_found), str(chunk_target), ) diff --git a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py index f0cf2a5fcbca..ad50874c92a3 100644 --- a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py +++ b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py @@ -5,9 +5,11 @@ try: import diffusers + MODELS = [diffusers.UNet2DModel] HAS_REPO = True from packaging import version + SKIP_UNET_TEST = version.parse(diffusers.__version__) > version.parse("0.10.2") except: MODELS = [] diff --git a/tests/test_autochunk/test_autochunk_transformer/benchmark_autochunk_transformer.py b/tests/test_autochunk/test_autochunk_transformer/benchmark_autochunk_transformer.py index 63490aaee7ff..e70e50175032 100644 --- a/tests/test_autochunk/test_autochunk_transformer/benchmark_autochunk_transformer.py +++ b/tests/test_autochunk/test_autochunk_transformer/benchmark_autochunk_transformer.py @@ -1,5 +1,5 @@ import time -from typing import Any, Dict, List +from typing import Any import torch import torch.fx @@ -64,8 +64,10 @@ def _benchmark_autochunk_gpt_gm( para_mem = float(parameter_size(model)) / 1024**2 * 6 act_mem = _benchmark_memory(gm, inputs) speed = _benchmark_speed(gm, inputs) - print("gpt autochunk, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" % - (speed, act_mem, para_mem, act_mem + para_mem)) + print( + "gpt autochunk, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" + % (speed, act_mem, para_mem, act_mem + para_mem) + ) def _benchmark_autochunk_gpt_origin( @@ -86,8 +88,10 @@ def _benchmark_autochunk_gpt_origin( para_mem = float(parameter_size(model)) / 1024**2 * 6 act_mem = _benchmark_memory(model, inputs) speed = _benchmark_speed(model, inputs) - print("gpt origin, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" % - (speed, act_mem, para_mem, act_mem + para_mem)) + print( + "gpt origin, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" + % (speed, act_mem, para_mem, act_mem + para_mem) + ) return act_mem @@ -115,6 
+119,7 @@ def _benchmark_speed(model, inputs, loop=5): def benchmark_autochunk_gpt(batch=1, seq=512, n_embd=768, n_head=12): from test_autochunk_gpt import GPT2Config, GPT2Model, get_data + model = GPT2Model config = GPT2Config(n_embd=n_embd, n_positions=seq, n_layer=2, n_head=n_head) model = model(config=config) @@ -125,7 +130,7 @@ def benchmark_autochunk_gpt(batch=1, seq=512, n_embd=768, n_head=12): try: _benchmark_autochunk_gpt_gm(model, get_data(shape), max_mem * ratio) except RuntimeError as e: - if e.args[0] == 'Search failed. Try a larger memory threshold.': + if e.args[0] == "Search failed. Try a larger memory threshold.": break except Exception as e: raise e diff --git a/tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py index 82af6c05c6ef..b2d842ee6a7b 100644 --- a/tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py +++ b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py @@ -5,6 +5,7 @@ try: from transformers import GPT2Config, GPT2Model + MODELS = [GPT2Model] HAS_REPO = True except: @@ -52,13 +53,15 @@ def test_autochunk_gpt(model, shape, max_memory): if __name__ == "__main__": - run_test(rank=0, - data=get_data((BATCH_SIZE, SEQ_LENGTH)), - max_memory=None, - model=GPT2Model, - config=GPT2Config(n_embd=96, n_position=SEQ_LENGTH, n_layer=2, n_head=4), - print_code=False, - print_est_mem=False, - print_mem=False, - print_progress=False, - eval_mem=False) + run_test( + rank=0, + data=get_data((BATCH_SIZE, SEQ_LENGTH)), + max_memory=None, + model=GPT2Model, + config=GPT2Config(n_embd=96, n_position=SEQ_LENGTH, n_layer=2, n_head=4), + print_code=False, + print_est_mem=False, + print_mem=False, + print_progress=False, + eval_mem=False, + ) diff --git a/tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py index 5c863b0df47f..77c11db71a5c 100644 --- a/tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py +++ b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py @@ -38,11 +38,9 @@ def assert_codegen_run( meta_tensors = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence] meta_tensors = [MetaTensor(i, fake_device="cuda:0") if isinstance(i, torch.Tensor) else i for i in meta_tensors] interp.propagate(*meta_tensors) - codegen = AutoChunkCodeGen(meta_graph, - max_memory=max_memory, - print_mem=print_est_mem, - print_progress=print_progress, - eval_mem=eval_mem) + codegen = AutoChunkCodeGen( + meta_graph, max_memory=max_memory, print_mem=print_est_mem, print_progress=print_progress, eval_mem=eval_mem + ) chunks = codegen.chunk_infos # trace and recompile @@ -85,9 +83,9 @@ def assert_allclose(out_model: Any, out_gm: Any) -> None: assert allclose for out """ if isinstance(out_model, torch.Tensor): - assert torch.allclose(out_model, out_gm, - atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( - torch.abs(out_model - out_gm)) + assert torch.allclose( + out_model, out_gm, atol=1e-4 + ), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(torch.abs(out_model - out_gm)) elif isinstance(out_model, dict): for k in out_model.keys(): assert_allclose(out_model[k], out_gm[k]) @@ -123,19 +121,21 @@ def run_test( ) # build model and input - chunks = assert_codegen_run(model, - data=data, - max_memory=max_memory, - 
print_code=print_code, - print_est_mem=print_est_mem, - print_mem=print_mem, - print_progress=print_progress, - eval_mem=eval_mem) + chunks = assert_codegen_run( + model, + data=data, + max_memory=max_memory, + print_code=print_code, + print_est_mem=print_est_mem, + print_mem=print_mem, + print_progress=print_progress, + eval_mem=eval_mem, + ) if get_chunk_target is not None: chunk_found = [i["region"] for i in chunks] chunk_target = get_chunk_target()[max_memory] - assert (chunk_found == chunk_target), "found regions %s doesn't equal target regions %s" % ( + assert chunk_found == chunk_target, "found regions %s doesn't equal target regions %s" % ( str(chunk_found), str(chunk_target), ) diff --git a/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py index a98aa0e03954..aa868d683f06 100644 --- a/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py +++ b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py @@ -5,6 +5,7 @@ try: from timm.models.vision_transformer import vit_large_patch16_384 as vit + MODELS = [vit] HAS_REPO = True except: @@ -19,7 +20,7 @@ def get_data() -> Tuple[List, List]: data = torch.rand(1, 3, 384, 384) - meta_args = {'x': data} + meta_args = {"x": data} return data, meta_args diff --git a/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py index 65d1e9c4d090..ca919fb7e4fe 100644 --- a/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py +++ b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py @@ -75,9 +75,9 @@ def assert_codegen_run( max_mem_ori = torch.cuda.max_memory_allocated() / 1024**2 print("origin mem: %.2fMB, autochunk mem: %.2fMB" % (max_mem_ori - now_mem_ori, max_mem_gm - now_mem_gm)) - assert torch.allclose(out_gm, out_model, - atol=1e-3), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( - torch.abs(out_gm - out_model)) + assert torch.allclose( + out_gm, out_model, atol=1e-3 + ), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(torch.abs(out_gm - out_model)) return chunks @@ -121,7 +121,7 @@ def run_test( if get_chunk_target is not None: chunk_found = [i["region"] for i in chunks] chunk_target = get_chunk_target()[max_memory] - assert (chunk_found == chunk_target), "found regions %s doesn't equal target regions %s" % ( + assert chunk_found == chunk_target, "found regions %s doesn't equal target regions %s" % ( str(chunk_found), str(chunk_target), ) diff --git a/tests/test_booster/test_accelerator.py b/tests/test_booster/test_accelerator.py index 6f3f66ed41b8..777589299d13 100644 --- a/tests/test_booster/test_accelerator.py +++ b/tests/test_booster/test_accelerator.py @@ -5,7 +5,7 @@ @clear_cache_before_run() -@parameterize('device', ['cpu', 'cuda']) +@parameterize("device", ["cpu", "cuda"]) def test_accelerator(device): accelerator = Accelerator(device) model = nn.Linear(8, 8) diff --git a/tests/test_booster/test_mixed_precision/test_fp16_torch.py b/tests/test_booster/test_mixed_precision/test_fp16_torch.py index 26ce00e94869..3aefb37974f0 100644 --- a/tests/test_booster/test_mixed_precision/test_fp16_torch.py +++ b/tests/test_booster/test_mixed_precision/test_fp16_torch.py @@ -9,11 +9,11 @@ def run_torch_amp(rank, world_size, port): # init dist env - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') - sub_model_zoo = model_zoo.get_sub_registry('timm') + 
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") + sub_model_zoo = model_zoo.get_sub_registry("timm") for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in sub_model_zoo.items(): # dlrm_interactionarch has not parameters, so skip - if name == 'dlrm_interactionarch': + if name == "dlrm_interactionarch": continue model = model_fn().cuda() @@ -21,7 +21,7 @@ def run_torch_amp(rank, world_size, port): criterion = lambda x: x.mean() data = data_gen_fn() data = { - k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items() + k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items() } mixed_precision = FP16TorchMixedPrecision() model, optimizer, criterion = mixed_precision.configure(model, optimizer, criterion) diff --git a/tests/test_booster/test_plugin/test_3d_plugin.py b/tests/test_booster/test_plugin/test_3d_plugin.py index a58afac810d7..ad878fb0c86a 100644 --- a/tests/test_booster/test_plugin/test_3d_plugin.py +++ b/tests/test_booster/test_plugin/test_3d_plugin.py @@ -16,11 +16,11 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: try: - if init_method == 'lazy': + if init_method == "lazy": ctx = LazyInitContext() else: ctx = nullcontext() - plugin = HybridParallelPlugin(tp_size=2, pp_size=2, num_microbatches=4, precision='bf16') + plugin = HybridParallelPlugin(tp_size=2, pp_size=2, num_microbatches=4, precision="bf16") booster = Booster(plugin=plugin) with ctx: model = model_fn() @@ -29,7 +29,7 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[ data = data_gen_fn() data = { - k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v + k: v.to("cuda").repeat(4, 1) if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items() } @@ -50,23 +50,24 @@ def _criterion(outputs, inputs): return repr(e) -@parameterize('init_method', ['none', 'lazy']) -def check_3d_plugin(init_method: str = 'none', early_stop: bool = True): +@parameterize("init_method", ["none", "lazy"]) +def check_3d_plugin(init_method: str = "none", early_stop: bool = True): """check gemini plugin over model zoo Args: early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True. 
""" is_support_meta = is_compatible_with_meta() - if not is_support_meta and init_method == 'lazy': + if not is_support_meta and init_method == "lazy": return passed_models = [] - failed_info = {} # (model_name, error) pair + failed_info = {} # (model_name, error) pair # TODO(ver217): add more models - for name, (model_fn, data_gen_fn, output_transform_fn, _, - _) in model_zoo.get_sub_registry('transformers_llama_for_casual_lm').items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.get_sub_registry( + "transformers_llama_for_casual_lm" + ).items(): err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) torch.cuda.empty_cache() @@ -78,15 +79,15 @@ def check_3d_plugin(init_method: str = 'none', early_stop: bool = True): break if dist.get_rank() == 0: - print(f'Init method: {init_method}') - print(f'Passed models({len(passed_models)}): {passed_models}\n\n') - print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n') - assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()]) + print(f"Init method: {init_method}") + print(f"Passed models({len(passed_models)}): {passed_models}\n\n") + print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n") + assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()]) def run_dist(rank, world_size, port, early_stop: bool = True): # init dist env - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") check_3d_plugin(early_stop=early_stop) @@ -95,5 +96,5 @@ def test_gemini_plugin(early_stop: bool = True): spawn(run_dist, 4, early_stop=early_stop) -if __name__ == '__main__': +if __name__ == "__main__": test_gemini_plugin(early_stop=False) diff --git a/tests/test_booster/test_plugin/test_dp_plugin_base.py b/tests/test_booster/test_plugin/test_dp_plugin_base.py index 689b334cae50..0ac9d0f6d409 100644 --- a/tests/test_booster/test_plugin/test_dp_plugin_base.py +++ b/tests/test_booster/test_plugin/test_dp_plugin_base.py @@ -15,8 +15,7 @@ class DPPluginWrapper(DPPluginBase): - """This is a wrapper class for testing DP plugin initialization and dataloader creation. 
- """ + """This is a wrapper class for testing DP plugin initialization and dataloader creation.""" def configure( self, @@ -73,13 +72,14 @@ def check_dataloader_sharding(): # compare on rank 0 if is_rank_0: - assert not torch.equal(batch, - batch_to_compare), 'Same number was found across ranks but expected it to be different' + assert not torch.equal( + batch, batch_to_compare + ), "Same number was found across ranks but expected it to be different" def run_dist(rank, world_size, port): # init dist env - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") check_dataloader_sharding() diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index 18be68bf6e48..00ff6cb37d2a 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -17,7 +17,7 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: try: - if init_method == 'lazy': + if init_method == "lazy": ctx = LazyInitContext() else: ctx = nullcontext() @@ -30,13 +30,13 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[ data = data_gen_fn() data = { - k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items() + k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items() } model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) for n, p in model.named_parameters(): - assert isinstance(p, ColoParameter), f'{n} is not a ColoParameter' + assert isinstance(p, ColoParameter), f"{n} is not a ColoParameter" output = model(**data) output = output_transform_fn(output) @@ -55,47 +55,65 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[ # @parameterize('init_method', ['lazy', 'none', 'colo']) -@parameterize('subset', ['torchvision', 'transformers', 'diffusers']) -@parameterize('init_method', ['none']) -def check_gemini_plugin(subset: str, init_method: str = 'none', early_stop: bool = True): +@parameterize("subset", ["torchvision", "transformers", "diffusers"]) +@parameterize("init_method", ["none"]) +def check_gemini_plugin(subset: str, init_method: str = "none", early_stop: bool = True): """check gemini plugin over model zoo Args: early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True. 
""" is_support_meta = is_compatible_with_meta() - if not is_support_meta and init_method == 'lazy': + if not is_support_meta and init_method == "lazy": return passed_models = [] - failed_info = {} # (model_name, error) pair + failed_info = {} # (model_name, error) pair for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.get_sub_registry(subset).items(): # These models lead to CUDA error - if name in ('diffusers_auto_encoder_kl', 'diffusers_vq_model', 'diffusers_unet2d_model', 'timm_resmlp', - 'timm_gmixer_12_224', 'timm_gmlp_b16_224', 'timm_mixer_b16_224', 'timm_convnext', - 'torchvision_convnext_base'): + if name in ( + "diffusers_auto_encoder_kl", + "diffusers_vq_model", + "diffusers_unet2d_model", + "timm_resmlp", + "timm_gmixer_12_224", + "timm_gmlp_b16_224", + "timm_mixer_b16_224", + "timm_convnext", + "torchvision_convnext_base", + ): continue # These models are not compatible with gemini if name in [ - 'timm_convit', - 'timm_dm_nfnet', - 'torchvision_vit_b_16', - 'transformers_t5', - 'transformers_t5_for_conditional_generation', - 'transformers_t5_encoder_model', # does not support apex rmsnorm - 'transformers_chatglm', - 'transformers_sam', - 'transformers_vit', - 'transformers_gpt_double_heads', # TODO check why does the model fail to run using Gemini + "timm_convit", + "timm_dm_nfnet", + "torchvision_vit_b_16", + "transformers_t5", + "transformers_t5_for_conditional_generation", + "transformers_t5_encoder_model", # does not support apex rmsnorm + "transformers_chatglm", + "transformers_sam", + "transformers_vit", + "transformers_gpt_double_heads", # TODO check why does the model fail to run using Gemini ]: continue - if init_method == 'lazy' and name in [ - 'timm_convmixer', 'timm_vision_transformer', 'timm_deit', 'timm_deit3', 'timm_inception_v3', - 'timm_tnt_b_patch16_224', 'timm_rexnet', 'torchvision_densenet121', 'torchvision_efficientnet_b0', - 'torchvision_mobilenet_v2', 'torchvision_mnasnet0_5', 'torchvision_regnet_x_16gf', - 'torchvision_shufflenet_v2_x0_5', 'torchvision_efficientnet_v2_s' + if init_method == "lazy" and name in [ + "timm_convmixer", + "timm_vision_transformer", + "timm_deit", + "timm_deit3", + "timm_inception_v3", + "timm_tnt_b_patch16_224", + "timm_rexnet", + "torchvision_densenet121", + "torchvision_efficientnet_b0", + "torchvision_mobilenet_v2", + "torchvision_mnasnet0_5", + "torchvision_regnet_x_16gf", + "torchvision_shufflenet_v2_x0_5", + "torchvision_efficientnet_v2_s", ]: continue err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) @@ -108,15 +126,15 @@ def check_gemini_plugin(subset: str, init_method: str = 'none', early_stop: bool break if dist.get_rank() == 0: - print(f'Init method: {init_method}') - print(f'Passed models({len(passed_models)}): {passed_models}\n\n') - print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n') - assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()]) + print(f"Init method: {init_method}") + print(f"Passed models({len(passed_models)}): {passed_models}\n\n") + print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n") + assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()]) def run_dist(rank, world_size, port, early_stop: bool = True): # init dist env - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") 
check_gemini_plugin(early_stop=early_stop) @@ -125,5 +143,5 @@ def test_gemini_plugin(early_stop: bool = True): spawn(run_dist, 4, early_stop=early_stop) -if __name__ == '__main__': +if __name__ == "__main__": test_gemini_plugin(early_stop=False) diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index 79f98a4c95d0..9cc12f96bd4d 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -11,9 +11,9 @@ from tests.kit.model_zoo import model_zoo # These models are not compatible with AMP -_AMP_ERR_MODELS = ['timm_convit', 'deepfm_interactionarch'] +_AMP_ERR_MODELS = ["timm_convit", "deepfm_interactionarch"] # These models have no parameters -_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch'] +_LOW_LEVEL_ZERO_ERR_MODELS = ["dlrm_interactionarch"] def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: @@ -26,7 +26,7 @@ def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: data = data_gen_fn() data = { - k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items() + k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items() } model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) @@ -43,7 +43,7 @@ def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: return repr(e) -@parameterize('stage', [2]) +@parameterize("stage", [2]) def check_low_level_zero_plugin(stage: int, early_stop: bool = True): """check low level zero plugin over model zoo @@ -52,7 +52,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True): early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True. 
""" passed_models = [] - failed_info = {} # (model_name, error) pair + failed_info = {} # (model_name, error) pair ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS skipped_models = [] @@ -73,15 +73,15 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True): break if dist.get_rank() == 0: - print(f'Passed models({len(passed_models)}): {passed_models}\n\n') - print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n') - print(f'Skipped models({len(skipped_models)}): {skipped_models}\n\n') - assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()]) + print(f"Passed models({len(passed_models)}): {passed_models}\n\n") + print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n") + print(f"Skipped models({len(skipped_models)}): {skipped_models}\n\n") + assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()]) def run_dist(rank, world_size, port, early_stop: bool = True): # init dist env - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") check_low_level_zero_plugin(early_stop=early_stop) @@ -90,5 +90,5 @@ def test_low_level_zero_plugin(early_stop: bool = True): spawn(run_dist, 4, early_stop=early_stop) -if __name__ == '__main__': +if __name__ == "__main__": test_low_level_zero_plugin(early_stop=False) diff --git a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py index 23d743c924aa..1a7ca6f2a30c 100644 --- a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py +++ b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py @@ -22,7 +22,7 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn): criterion = lambda x: x.mean() data = data_gen_fn() - data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()} + data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()} model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) @@ -41,14 +41,13 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn): def check_torch_ddp_plugin(): for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items(): - if name == 'dlrm_interactionarch': + if name == "dlrm_interactionarch": continue run_fn(model_fn, data_gen_fn, output_transform_fn) torch.cuda.empty_cache() class DummyModel(nn.Module): - def __init__(self): super().__init__() self.weight = nn.Parameter(torch.rand(1)) @@ -67,10 +66,9 @@ def check_torch_ddp_no_sync(): # create a custom dataset with 0 to 10 dataset = torch.arange(0, 10) train_dataloader = plugin.prepare_dataloader(dataset, batch_size=2) - model, optimizer, criterion, train_dataloader, _ = booster.boost(model, - optimizer, - criterion, - dataloader=train_dataloader) + model, optimizer, criterion, train_dataloader, _ = booster.boost( + model, optimizer, criterion, dataloader=train_dataloader + ) def fwd_bwd(): output = model(batch.cuda()) @@ -105,7 +103,7 @@ def get_grad_set_over_all_ranks(): def run_dist(rank, world_size, port): # init dist env - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") check_torch_ddp_plugin() check_torch_ddp_no_sync() diff --git 
a/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py b/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py index e09ad766bb32..8bcbffdd06fe 100644 --- a/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py +++ b/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py @@ -6,7 +6,7 @@ import colossalai from colossalai.booster import Booster -if version.parse(torch.__version__) >= version.parse('1.12.0'): +if version.parse(torch.__version__) >= version.parse("1.12.0"): from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from colossalai.booster.plugin import TorchFSDPPlugin @@ -24,7 +24,7 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn): criterion = lambda x: x.mean() data = data_gen_fn() - data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()} + data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()} model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) @@ -43,10 +43,16 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn): def check_torch_fsdp_plugin(): for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items(): - if any(element in name for element in [ - 'diffusers', 'deepfm_sparsearch', 'dlrm_interactionarch', 'torchvision_googlenet', - 'torchvision_inception_v3' - ]): + if any( + element in name + for element in [ + "diffusers", + "deepfm_sparsearch", + "dlrm_interactionarch", + "torchvision_googlenet", + "torchvision_inception_v3", + ] + ): continue run_fn(model_fn, data_gen_fn, output_transform_fn) torch.cuda.empty_cache() @@ -54,11 +60,11 @@ def check_torch_fsdp_plugin(): def run_dist(rank, world_size, port): # init dist env - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") check_torch_fsdp_plugin() -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason="requires torch1.12 or higher") +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="requires torch1.12 or higher") @rerun_if_address_is_in_use() def test_torch_fsdp_plugin(): spawn(run_dist, 2) diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index 6720be58490b..d66dec113017 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -19,50 +19,30 @@ from tests.kit.model_zoo import model_zoo MODEL_PLACEMENT_CONFIGS = [ - { - 'placement_policy': 'static', - 'shard_param_frac': 0.0 - }, # zero2 - { - 'placement_policy': 'static', - 'shard_param_frac': 1.0 - }, # zero3 - { - 'placement_policy': 'static', - 'shard_param_frac': 0.5 - }, # zero3-half + {"placement_policy": "static", "shard_param_frac": 0.0}, # zero2 + {"placement_policy": "static", "shard_param_frac": 1.0}, # zero3 + {"placement_policy": "static", "shard_param_frac": 0.5}, # zero3-half ] OPTIM_PLACEMENT_CONFIGS = [ - { - 'placement_policy': 'static', - 'shard_param_frac': 0.0, - 'offload_optim_frac': 0.0 - }, # zero2 - { - 'placement_policy': 'static', - 'shard_param_frac': 0.0, - 'offload_optim_frac': 1.0 - }, # zero2-offload - { - 'placement_policy': 'static', - 'shard_param_frac': 0.0, - 'offload_optim_frac': 0.5 - }, # zero2-offload-half + {"placement_policy": "static", "shard_param_frac": 0.0, 
"offload_optim_frac": 0.0}, # zero2 + {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 1.0}, # zero2-offload + {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.5}, # zero2-offload-half ] @clear_cache_before_run() -@parameterize('placement_config', MODEL_PLACEMENT_CONFIGS) -@parameterize('model_name', ['transformers_bert_for_sequence_classification']) -@parameterize('use_safetensors', [False, True]) +@parameterize("placement_config", MODEL_PLACEMENT_CONFIGS) +@parameterize("model_name", ["transformers_bert_for_sequence_classification"]) +@parameterize("use_safetensors", [False, True]) def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool): from transformers import BertForSequenceClassification + (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) bert_model = model_fn() with shared_tempdir() as tempdir: - pretrained_path = os.path.join(tempdir, 'pretrained') + pretrained_path = os.path.join(tempdir, "pretrained") bert_model.config.save_pretrained(save_directory=pretrained_path) plugin = GeminiPlugin(**placement_config) @@ -70,24 +50,22 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b bert_model, _, _, _, _ = booster.boost(bert_model) model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2 - booster.save_model(bert_model, - pretrained_path, - True, - True, - '', (model_size / 3), - use_safetensors=use_safetensors) + booster.save_model( + bert_model, pretrained_path, True, True, "", (model_size / 3), use_safetensors=use_safetensors + ) dist.barrier() new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path) - check_state_dict_equal(bert_model.state_dict(only_rank_0=False, dtype=torch.float32), - new_bert_model.state_dict(), False) + check_state_dict_equal( + bert_model.state_dict(only_rank_0=False, dtype=torch.float32), new_bert_model.state_dict(), False + ) @clear_cache_before_run() -@parameterize('placement_config', OPTIM_PLACEMENT_CONFIGS) -@parameterize('shard', [False, True]) -@parameterize('model_name', ['transformers_gpt']) -@parameterize('size_per_shard', [32]) +@parameterize("placement_config", OPTIM_PLACEMENT_CONFIGS) +@parameterize("shard", [False, True]) +@parameterize("model_name", ["transformers_gpt"]) +@parameterize("size_per_shard", [32]) def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int): (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) criterion = lambda x: x.mean() @@ -102,7 +80,7 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) data = data_gen_fn() - data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()} + data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()} output = model(**data) output = output_transform_fn(output) output_key = list(output.keys())[0] @@ -123,13 +101,14 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha check_state_dict_equal(model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False) booster.load_optimizer(new_optimizer, optimizer_ckpt_path) - check_state_dict_equal(optimizer.state_dict(only_rank_0=False), 
new_optimizer.state_dict(only_rank_0=False), - False) + check_state_dict_equal( + optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False), False + ) # Check the new model/optimizer can successfully run. data = data_gen_fn() data = { - k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items() + k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items() } output = new_model(**data) output = output_transform_fn(output) @@ -143,13 +122,13 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha def run_dist(rank, world_size, port): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_state_dict() exam_state_dict_with_origin() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [2]) +@pytest.mark.parametrize("world_size", [2]) @rerun_if_address_is_in_use() def test_gemini_ckpIO(world_size): spawn(run_dist, world_size) diff --git a/tests/test_checkpoint_io/test_gemini_torch_compability.py b/tests/test_checkpoint_io/test_gemini_torch_compability.py index 4569ea12d82d..d46e5380d944 100644 --- a/tests/test_checkpoint_io/test_gemini_torch_compability.py +++ b/tests/test_checkpoint_io/test_gemini_torch_compability.py @@ -19,10 +19,9 @@ @clear_cache_before_run() -@parameterize('shard', [False, True]) -@parameterize('model_name', ['transformers_gpt']) +@parameterize("shard", [False, True]) +@parameterize("model_name", ["transformers_gpt"]) def exam_torch_load_from_gemini(shard: bool, model_name: str): - (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) criterion = lambda x: x.mean() plugin = GeminiPlugin(precision="fp16", initial_scale=(2**14)) @@ -33,7 +32,7 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str): model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) data = data_gen_fn() - data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()} + data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()} output = model(**data) output = output_transform_fn(output) output_key = list(output.keys())[0] @@ -60,8 +59,11 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str): new_booster.load_model(new_model, model_ckpt_path, strict=True) # Add prefix to get aligned with pytorch parameter names. - check_state_dict_equal(model.state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32), - new_model.state_dict(), False) + check_state_dict_equal( + model.state_dict(only_rank_0=False, prefix="module.module.", dtype=torch.float32), + new_model.state_dict(), + False, + ) new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path) check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(), False) @@ -69,7 +71,7 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str): # Check the new model/optimizer can successfully run. 
data = data_gen_fn() data = { - k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items() + k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items() } output = new_model(**data) output = output_transform_fn(output) @@ -82,10 +84,9 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str): @clear_cache_before_run() -@parameterize('shard', [False, True]) -@parameterize('model_name', ['transformers_gpt']) +@parameterize("shard", [False, True]) +@parameterize("model_name", ["transformers_gpt"]) def exam_gemini_load_from_torch(shard: bool, model_name: str): - (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) criterion = lambda x: x.mean() plugin = TorchDDPPlugin() @@ -96,7 +97,7 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str): model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) data = data_gen_fn() - data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()} + data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()} output = model(**data) output = output_transform_fn(output) output_key = list(output.keys())[0] @@ -123,8 +124,11 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str): new_booster.load_model(new_model, model_ckpt_path, strict=True) # Add prefix to get aligned with pytorch parameter names. - check_state_dict_equal(new_model.state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32), - model.state_dict(), False) + check_state_dict_equal( + new_model.state_dict(only_rank_0=False, prefix="module.module.", dtype=torch.float32), + model.state_dict(), + False, + ) new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path) old_state_dict = optimizer.state_dict() @@ -132,18 +136,19 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str): # Comparison of param_groups needs special care here, # since not all hyperparameters in Adam are used by HybridAdam - hyperparameters_to_examine = ['params', 'lr', 'betas', 'eps', 'weight_decay'] - for old_group, new_group in zip(old_state_dict['param_groups'], new_state_dict['param_groups']): + hyperparameters_to_examine = ["params", "lr", "betas", "eps", "weight_decay"] + for old_group, new_group in zip(old_state_dict["param_groups"], new_state_dict["param_groups"]): for k in hyperparameters_to_examine: - assert k in old_group and k in new_group, \ - f"Old group's keys: {list(old_group.keys())}, New group's keys: {list(new_group.keys())}" + assert ( + k in old_group and k in new_group + ), f"Old group's keys: {list(old_group.keys())}, New group's keys: {list(new_group.keys())}" assert old_group[k] == new_group[k] - check_state_dict_equal(old_state_dict['state'], new_state_dict['state'], False) + check_state_dict_equal(old_state_dict["state"], new_state_dict["state"], False) # Check the new model/optimizer can successfully run. 
data = data_gen_fn() data = { - k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items() + k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items() } output = new_model(**data) output = output_transform_fn(output) @@ -157,13 +162,13 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str): def run_dist(rank, world_size, port): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_torch_load_from_gemini() exam_gemini_load_from_torch() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [2]) +@pytest.mark.parametrize("world_size", [2]) @rerun_if_address_is_in_use() def test_gemini_ckpIO(world_size): spawn(run_dist, world_size) diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py index 0976d4503a61..2a046a298dd7 100644 --- a/tests/test_checkpoint_io/test_general_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -5,7 +5,6 @@ from torch.optim import Adam from torchvision.models import resnet18 -from colossalai.booster.plugin.gemini_plugin import GeminiCheckpointIO from colossalai.checkpoint_io import GeneralCheckpointIO from colossalai.testing import check_state_dict_equal, clear_cache_before_run, parameterize @@ -18,7 +17,7 @@ @clear_cache_before_run() -@parameterize('use_safetensors', [True, False]) +@parameterize("use_safetensors", [True, False]) def test_unsharded_checkpoint(use_safetensors: bool): # create a model and optimizer model = resnet18() @@ -59,7 +58,7 @@ def test_unsharded_checkpoint(use_safetensors: bool): check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict()) -@pytest.mark.parametrize('use_safetensors', [True, False]) +@pytest.mark.parametrize("use_safetensors", [True, False]) def test_sharded_model_checkpoint(use_safetensors: bool): # create a model and optimizer model = resnet18() @@ -75,11 +74,9 @@ def test_sharded_model_checkpoint(use_safetensors: bool): # create a temp file for checkpoint if use_safetensors: - suffix = ".safetensors" - SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json" + pass else: - suffix = ".bin" - WEIGHTS_INDEX_NAME = "model.bin.index.json" + pass model_ckpt_dir = tempfile.TemporaryDirectory() optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile() @@ -103,7 +100,6 @@ def test_sharded_model_checkpoint(use_safetensors: bool): def test_sharded_optimizer_checkpoint(): - # create a model and optimizer model = resnet18() optimizer = Adam(model.parameters(), lr=0.001) @@ -162,16 +158,11 @@ def test_sharded_optimizer_checkpoint(): def test_sharded_optimizer_multiple_param_groups(): - # create a model and optimizer model = resnet18() - optimizer = Adam([{ - 'params': model.layer1.parameters() - }, { - 'params': model.layer2.parameters(), - 'lr': 0.002 - }], - lr=0.001) + optimizer = Adam( + [{"params": model.layer1.parameters()}, {"params": model.layer2.parameters(), "lr": 0.002}], lr=0.001 + ) # create test data sample x = torch.randn(1, 3, 224, 224) @@ -194,13 +185,9 @@ def test_sharded_optimizer_multiple_param_groups(): # create new model new_model = resnet18() - new_optimizer = Adam([{ - 'params': new_model.layer1.parameters() - }, { - 'params': new_model.layer2.parameters(), - 'lr': 0.002 - }], - lr=0.001) + new_optimizer = Adam( + 
[{"params": new_model.layer1.parameters()}, {"params": new_model.layer2.parameters(), "lr": 0.002}], lr=0.001 + ) ckpt_io.load_model(new_model, str(model_ckpt_dir.name), strict=True) ckpt_io.load_optimizer(new_optimizer, str(optimizer_ckpt_dir.name)) diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index e43908e0c651..e8bb8f9e3475 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -22,37 +22,26 @@ # TODO (Baizhou): Add test cases for shard=False @clear_cache_before_run() -@parameterize('shard', [True]) -@parameterize('model_name', ['transformers_gpt']) -@parameterize('size_per_shard', [32]) -@parameterize('test_config', [{ - 'tp_size': 4, - 'pp_size': 1, - 'precision': 'fp32', -}, { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'precision': 'fp16', - 'initial_scale': 1 -}, { - 'tp_size': 2, - 'pp_size': 1, - 'zero_stage': 2, - 'precision': 'fp16', - 'initial_scale': 1 -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 4, - 'zero_stage': 1, - 'precision': 'fp16', - 'initial_scale': 1 -}]) +@parameterize("shard", [True]) +@parameterize("model_name", ["transformers_gpt"]) +@parameterize("size_per_shard", [32]) +@parameterize( + "test_config", + [ + { + "tp_size": 4, + "pp_size": 1, + "precision": "fp32", + }, + {"tp_size": 2, "pp_size": 2, "num_microbatches": 4, "precision": "fp16", "initial_scale": 1}, + {"tp_size": 2, "pp_size": 1, "zero_stage": 2, "precision": "fp16", "initial_scale": 1}, + {"tp_size": 1, "pp_size": 2, "num_microbatches": 4, "zero_stage": 1, "precision": "fp16", "initial_scale": 1}, + ], +) def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict): - - (model_fn, data_gen_fn, output_transform_fn, loss_fn, - _) = next(iter(model_zoo.get_sub_registry(model_name).values())) + (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next( + iter(model_zoo.get_sub_registry(model_name).values()) + ) criterion = loss_fn plugin = HybridParallelPlugin(**test_config) booster = Booster(plugin=plugin) @@ -65,10 +54,10 @@ def _criterion(outputs, inputs): def _preprocess_data(data): if booster.plugin.stage_manager is not None: for k, v in data.items(): - if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: + if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: new_shape = [1] * v.dim() new_shape[0] = 4 - data[k] = v.to('cuda').repeat(*new_shape) + data[k] = v.to("cuda").repeat(*new_shape) return iter([data]) else: return {k: v.cuda() for k, v in data.items()} @@ -80,12 +69,9 @@ def _preprocess_data(data): data = data_gen_fn() model.train() if booster.plugin.stage_manager is not None: - booster.execute_pipeline(_preprocess_data(data), - model, - _criterion, - optimizer, - return_loss=True, - return_outputs=False) + booster.execute_pipeline( + _preprocess_data(data), model, _criterion, optimizer, return_loss=True, return_outputs=False + ) else: output = model(**_preprocess_data(data)) loss = criterion(output) @@ -94,7 +80,6 @@ def _preprocess_data(data): optimizer.step() with shared_tempdir() as tempdir: - model_ckpt_path = f"{tempdir}/model" optimizer_ckpt_path = f"{tempdir}/optimizer" booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard) @@ -115,18 +100,12 @@ def _preprocess_data(data): model.train() new_model.train() if booster.plugin.stage_manager is not None: - 
booster.execute_pipeline(_preprocess_data(data), - model, - _criterion, - optimizer, - return_loss=True, - return_outputs=False) - booster.execute_pipeline(_preprocess_data(data), - new_model, - _criterion, - new_optimizer, - return_loss=True, - return_outputs=False) + booster.execute_pipeline( + _preprocess_data(data), model, _criterion, optimizer, return_loss=True, return_outputs=False + ) + booster.execute_pipeline( + _preprocess_data(data), new_model, _criterion, new_optimizer, return_loss=True, return_outputs=False + ) else: old_model_loss = criterion(model(**_preprocess_data(data))) optimizer.backward(old_model_loss) @@ -141,10 +120,9 @@ def _preprocess_data(data): if stage_manager is None or stage_manager.is_first_stage(): assert_close_loose(model.unwrap().wte.weight.data, new_model.unwrap().wte.weight.data, atol=5e-3, rtol=5e-3) - assert_close_loose(model.unwrap().h[0].mlp.c_fc.weight.data, - new_model.unwrap().h[0].mlp.c_fc.weight.data, - atol=5e-3, - rtol=5e-3) + assert_close_loose( + model.unwrap().h[0].mlp.c_fc.weight.data, new_model.unwrap().h[0].mlp.c_fc.weight.data, atol=5e-3, rtol=5e-3 + ) dist.barrier() Randomizer.reset_index() @@ -153,12 +131,12 @@ def _preprocess_data(data): def run_dist(rank, world_size, port): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_state_dict() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [4]) +@pytest.mark.parametrize("world_size", [4]) @rerun_if_address_is_in_use() def test_hybrid_ckpIO(world_size): spawn(run_dist, world_size) diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index 7ee733b26b3f..8a4724c8a82c 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -20,9 +20,9 @@ # stage 1 and 2 process the optimizer/mode the same way # only test 2 is fine @clear_cache_before_run() -@parameterize('stage', [2]) -@parameterize('shard', [True, False]) -@parameterize('offload', [False, True]) +@parameterize("stage", [2]) +@parameterize("shard", [True, False]) +@parameterize("offload", [False, True]) def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool): plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32, cpu_offload=offload) booster = Booster(plugin=plugin) @@ -31,7 +31,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool): optimizer = HybridAdam((model.parameters()), lr=0.001) model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) - x = torch.randn(1, 3, 224, 224, device='cuda') + x = torch.randn(1, 3, 224, 224, device="cuda") output = model(x) loss = criterion(output) booster.backward(loss, optimizer) @@ -60,15 +60,16 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool): padding = new_optimizer._param_store.get_param_padding_size(working_param) padded_param = torch.nn.functional.pad(working_param.data.view(-1), (0, padding)) working_shard = padded_param.chunk(dist.get_world_size())[dist.get_rank()] - assert torch.equal(working_shard, - master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device)) + assert torch.equal( + working_shard, master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device) 
+ ) booster.load_optimizer(new_optimizer, optimizer_ckpt_path) check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False) def run_dist(rank, world_size, port): - colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host="localhost") check_low_level_zero_checkpointIO() torch.cuda.empty_cache() diff --git a/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py index bd041a5e2fd3..c3c30e666b10 100644 --- a/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py +++ b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py @@ -1,5 +1,3 @@ -import os - import pytest import torch import torch.distributed as dist @@ -20,18 +18,19 @@ @clear_cache_before_run() -@parameterize('model_name', ['transformers_gpt']) -@parameterize('plugin_type', ['ddp', 'zero', 'gemini']) +@parameterize("model_name", ["transformers_gpt"]) +@parameterize("plugin_type", ["ddp", "zero", "gemini"]) def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per_shard=32): - (model_fn, data_gen_fn, output_transform_fn, loss_fn, - _) = next(iter(model_zoo.get_sub_registry(model_name).values())) + (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next( + iter(model_zoo.get_sub_registry(model_name).values()) + ) criterion = loss_fn - if plugin_type == 'ddp': + if plugin_type == "ddp": plugin = TorchDDPPlugin() - elif plugin_type == 'zero': + elif plugin_type == "zero": plugin = LowLevelZeroPlugin(stage=2, max_norm=1.0, initial_scale=32) - elif plugin_type == 'gemini': + elif plugin_type == "gemini": plugin = GeminiPlugin(precision="fp16", initial_scale=32) else: raise ValueError(f"Plugin with type {plugin_type} is invalid, please check your argument.") @@ -44,7 +43,7 @@ def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) data = data_gen_fn() - data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()} + data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()} output = model(**data) loss = criterion(output) @@ -52,7 +51,6 @@ def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per optimizer.step() with shared_tempdir() as tempdir: - model_ckpt_path = f"{tempdir}/model" booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard) dist.barrier() @@ -62,9 +60,10 @@ def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per new_optimizer = HybridAdam(new_model.parameters(), lr=0.001) new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) - if plugin_type == 'gemini': - check_state_dict_equal(model.unwrap().state_dict(only_rank_0=False), - new_model.unwrap().state_dict(only_rank_0=False), False) + if plugin_type == "gemini": + check_state_dict_equal( + model.unwrap().state_dict(only_rank_0=False), new_model.unwrap().state_dict(only_rank_0=False), False + ) else: check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False) dist.barrier() @@ -72,12 +71,12 @@ def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per def run_dist(rank, world_size, port): config = {} - 
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_from_pretrained() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [2]) +@pytest.mark.parametrize("world_size", [2]) @rerun_if_address_is_in_use() def test_huggingface_compatibility(world_size): spawn(run_dist, world_size) diff --git a/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py b/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py index 14332b5b3fca..eeb04df0f42d 100644 --- a/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py @@ -12,8 +12,8 @@ from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_address_is_in_use, spawn -@parameterize('shard', [True, False]) -@parameterize('size_per_shard', [16, 128]) +@parameterize("shard", [True, False]) +@parameterize("size_per_shard", [16, 128]) def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int): plugin = TorchDDPPlugin() booster = Booster(plugin=plugin) @@ -27,7 +27,7 @@ def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int): assert isinstance(optimizer, OptimizerWrapper) x = torch.randn(4, 3, 224, 224) - x = x.to('cuda') + x = x.to("cuda") output = model(x) loss = criterion(output) booster.backward(loss, optimizer) @@ -47,9 +47,9 @@ def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int): new_model = resnet18() new_optimizer = SGD((new_model.parameters()), lr=0.001) new_scheduler = torch.optim.lr_scheduler.StepLR(new_optimizer, step_size=1, gamma=0.1) - new_model, new_optimizer, _, _, new_scheduler = booster.boost(new_model, - new_optimizer, - lr_scheduler=new_scheduler) + new_model, new_optimizer, _, _, new_scheduler = booster.boost( + new_model, new_optimizer, lr_scheduler=new_scheduler + ) booster.load_model(new_model, model_ckpt_path) check_state_dict_equal(model.state_dict(), new_model.state_dict(), False) @@ -61,7 +61,7 @@ def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int): def run_dist(rank, world_size, port): - colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host="localhost") check_torch_ddp_checkpointIO() diff --git a/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py b/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py index 2b6090bb1e29..dd41f8185c2b 100644 --- a/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py @@ -1,7 +1,6 @@ import pytest import torch from packaging import version -from torch import nn from torch.optim import SGD from torchvision.models import resnet18 from utils import shared_tempdir @@ -9,11 +8,10 @@ import colossalai from colossalai.booster import Booster -if version.parse(torch.__version__) >= version.parse('1.12.0'): - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +if version.parse(torch.__version__) >= version.parse("1.12.0"): from colossalai.booster.plugin import TorchFSDPPlugin -from colossalai.testing import rerun_if_address_is_in_use, spawn, check_state_dict_equal +from colossalai.testing import rerun_if_address_is_in_use, spawn def compare_nested_dict(dict1, dict2): @@ -72,15 +70,16 @@ def run_model(): booster.save_optimizer(optimizer, optim_ckpt_path, shard=False) full_msd 
= fsdp_model.state_dict() - #full_osd = FSDP.full_optim_state_dict(fsdp_model, optimizer) + # full_osd = FSDP.full_optim_state_dict(fsdp_model, optimizer) sharded_osd = optimizer.state_dict() import copy + sharded_osd = copy.deepcopy(sharded_osd) run_model() full_msd_updated = fsdp_model.state_dict() - #full_osd_updated = FSDP.full_optim_state_dict(fsdp_model, optimizer, rank0_only=True) + # full_osd_updated = FSDP.full_optim_state_dict(fsdp_model, optimizer, rank0_only=True) sharded_osd_updated = optimizer.state_dict() assert not compare_nested_dict(sharded_osd, sharded_osd_updated) @@ -92,9 +91,9 @@ def run_model(): booster.load_optimizer(optimizer, optim_ckpt_path) full_msd_restore = fsdp_model.state_dict() - #full_osd_restore = FSDP.full_optim_state_dict(fsdp_model, optimizer, rank0_only=True) + # full_osd_restore = FSDP.full_optim_state_dict(fsdp_model, optimizer, rank0_only=True) sharded_osd_restore = optimizer.state_dict() - + assert compare_nested_dict(sharded_osd, sharded_osd_restore) assert compare_nested_dict(full_msd_restore, full_msd) outputs_sec = fsdp_model(inputs) @@ -103,11 +102,11 @@ def run_model(): def run_dist(rank, world_size, port): # init dist env - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") check_torch_fsdp_ckpt() -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason="requires torch1.12 or higher") +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="requires torch1.12 or higher") @rerun_if_address_is_in_use() def test_torch_fsdp_ckpt(): spawn(run_dist, 2) diff --git a/tests/test_checkpoint_io/utils.py b/tests/test_checkpoint_io/utils.py index 2d35e157f446..d14fc944267c 100644 --- a/tests/test_checkpoint_io/utils.py +++ b/tests/test_checkpoint_io/utils.py @@ -15,7 +15,7 @@ def shared_tempdir() -> Iterator[str]: try: obj = [tempdir] dist.broadcast_object_list(obj, src=0) - tempdir = obj[0] # use the same directory on all ranks + tempdir = obj[0] # use the same directory on all ranks yield tempdir finally: dist.barrier() diff --git a/tests/test_cluster/test_device_mesh_manager.py b/tests/test_cluster/test_device_mesh_manager.py index bb818a275879..ab61cdae5bb0 100644 --- a/tests/test_cluster/test_device_mesh_manager.py +++ b/tests/test_cluster/test_device_mesh_manager.py @@ -1,5 +1,3 @@ -import torch - from colossalai.cluster.device_mesh_manager import DeviceMeshInfo, DeviceMeshManager from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers @@ -8,7 +6,7 @@ def check_device_mesh_manager(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") device_mesh_manager = DeviceMeshManager() # TODO(ver217): this test is strictly relies on hardware, temporary skip it # device_mesh_info_auto = DeviceMeshInfo(physical_ids=[0, 1, 2, 3],) @@ -20,7 +18,7 @@ def check_device_mesh_manager(rank, world_size, port): physical_ids=[0, 1, 2, 3], mesh_shape=(2, 2), ) - device_mesh_with_shape = device_mesh_manager.create_device_mesh('1', device_mesh_info_with_shape) + device_mesh_with_shape = device_mesh_manager.create_device_mesh("1", device_mesh_info_with_shape) assert device_mesh_with_shape.shape == (2, 2) assert 
device_mesh_with_shape._logical_mesh_id.tolist() == [[0, 1], [2, 3]] @@ -30,5 +28,5 @@ def test_device_mesh_manager(): spawn(check_device_mesh_manager, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_device_mesh_manager() diff --git a/tests/test_cluster/test_process_group_mesh.py b/tests/test_cluster/test_process_group_mesh.py index 2304203d1e04..08542d1f64fa 100644 --- a/tests/test_cluster/test_process_group_mesh.py +++ b/tests/test_cluster/test_process_group_mesh.py @@ -15,13 +15,15 @@ def check_process_group_mesh_with_gpc(): # check world size assert gpc.get_world_size(ParallelMode.TENSOR) == pg_mesh.size( - TP_DIM), f'{gpc.get_world_size(ParallelMode.TENSOR)} != {pg_mesh.size(TP_DIM)}' + TP_DIM + ), f"{gpc.get_world_size(ParallelMode.TENSOR)} != {pg_mesh.size(TP_DIM)}" assert gpc.get_world_size(ParallelMode.PIPELINE) == pg_mesh.size(PP_DIM) assert gpc.get_world_size(ParallelMode.DATA) == pg_mesh.size(DP_DIM) # check locak rank (coordinate) assert gpc.get_local_rank(ParallelMode.TENSOR) == pg_mesh.coordinate( - TP_DIM), f'{gpc.get_local_rank(ParallelMode.TENSOR)} != {pg_mesh.coordinate(TP_DIM)}' + TP_DIM + ), f"{gpc.get_local_rank(ParallelMode.TENSOR)} != {pg_mesh.coordinate(TP_DIM)}" assert gpc.get_local_rank(ParallelMode.PIPELINE) == pg_mesh.coordinate(PP_DIM) assert gpc.get_local_rank(ParallelMode.DATA) == pg_mesh.coordinate(DP_DIM) @@ -37,21 +39,21 @@ def check_process_group_mesh_with_gpc(): coord = pg_mesh.coordinate() if not gpc.is_first_rank(ParallelMode.TENSOR): assert coord[TP_DIM] != 0 - prev_coord = coord[:TP_DIM] + (coord[TP_DIM] - 1,) + coord[TP_DIM + 1:] + prev_coord = coord[:TP_DIM] + (coord[TP_DIM] - 1,) + coord[TP_DIM + 1 :] assert gpc.get_prev_global_rank(ParallelMode.TENSOR) == pg_mesh.ravel(prev_coord, pg_mesh.shape) if not gpc.is_first_rank(ParallelMode.PIPELINE): assert coord[PP_DIM] != 0 - prev_coord = coord[:PP_DIM] + (coord[PP_DIM] - 1,) + coord[PP_DIM + 1:] + prev_coord = coord[:PP_DIM] + (coord[PP_DIM] - 1,) + coord[PP_DIM + 1 :] assert gpc.get_prev_global_rank(ParallelMode.PIPELINE) == pg_mesh.ravel(prev_coord, pg_mesh.shape) # check next rank if not gpc.is_last_rank(ParallelMode.TENSOR): assert coord[TP_DIM] != pg_mesh.size(TP_DIM) - 1 - next_coord = coord[:TP_DIM] + (coord[TP_DIM] + 1,) + coord[TP_DIM + 1:] + next_coord = coord[:TP_DIM] + (coord[TP_DIM] + 1,) + coord[TP_DIM + 1 :] assert gpc.get_next_global_rank(ParallelMode.TENSOR) == pg_mesh.ravel(next_coord, pg_mesh.shape) if not gpc.is_last_rank(ParallelMode.PIPELINE): assert coord[PP_DIM] != pg_mesh.size(PP_DIM) - 1 - next_coord = coord[:PP_DIM] + (coord[PP_DIM] + 1,) + coord[PP_DIM + 1:] + next_coord = coord[:PP_DIM] + (coord[PP_DIM] + 1,) + coord[PP_DIM + 1 :] assert gpc.get_next_global_rank(ParallelMode.PIPELINE) == pg_mesh.ravel(next_coord, pg_mesh.shape) @@ -108,35 +110,49 @@ def check_process_group_mesh_with_cases(): # check prev rank if RANK_TO_COORDINATE[rank][TP_DIM] != 0: - prev_coord = RANK_TO_COORDINATE[rank][:TP_DIM] + (RANK_TO_COORDINATE[rank][TP_DIM] - 1,) + \ - RANK_TO_COORDINATE[rank][TP_DIM + 1:] + prev_coord = ( + RANK_TO_COORDINATE[rank][:TP_DIM] + + (RANK_TO_COORDINATE[rank][TP_DIM] - 1,) + + RANK_TO_COORDINATE[rank][TP_DIM + 1 :] + ) prev_rank = TP_RANKS_IN_GROUP[rank][TP_RANKS_IN_GROUP[rank].index(rank) - 1] assert pg_mesh.ravel(prev_coord, pg_mesh.shape) == prev_rank if RANK_TO_COORDINATE[rank][PP_DIM] != 0: - prev_coord = RANK_TO_COORDINATE[rank][:PP_DIM] + (RANK_TO_COORDINATE[rank][PP_DIM] - 1,) + \ - RANK_TO_COORDINATE[rank][PP_DIM + 1:] + prev_coord = ( + 
RANK_TO_COORDINATE[rank][:PP_DIM] + + (RANK_TO_COORDINATE[rank][PP_DIM] - 1,) + + RANK_TO_COORDINATE[rank][PP_DIM + 1 :] + ) prev_rank = PP_RANKS_IN_GROUP[rank][PP_RANKS_IN_GROUP[rank].index(rank) - 1] assert pg_mesh.ravel(prev_coord, pg_mesh.shape) == prev_rank # check next rank if RANK_TO_COORDINATE[rank][TP_DIM] != TP_SIZE - 1: - next_coord = RANK_TO_COORDINATE[rank][:TP_DIM] + (RANK_TO_COORDINATE[rank][TP_DIM] + 1,) + \ - RANK_TO_COORDINATE[rank][TP_DIM + 1:] + next_coord = ( + RANK_TO_COORDINATE[rank][:TP_DIM] + + (RANK_TO_COORDINATE[rank][TP_DIM] + 1,) + + RANK_TO_COORDINATE[rank][TP_DIM + 1 :] + ) next_rank = TP_RANKS_IN_GROUP[rank][TP_RANKS_IN_GROUP[rank].index(rank) + 1] assert pg_mesh.ravel(next_coord, pg_mesh.shape) == next_rank if RANK_TO_COORDINATE[rank][PP_DIM] != PP_SIZE - 1: - next_coord = RANK_TO_COORDINATE[rank][:PP_DIM] + (RANK_TO_COORDINATE[rank][PP_DIM] + 1,) + \ - RANK_TO_COORDINATE[rank][PP_DIM + 1:] + next_coord = ( + RANK_TO_COORDINATE[rank][:PP_DIM] + + (RANK_TO_COORDINATE[rank][PP_DIM] + 1,) + + RANK_TO_COORDINATE[rank][PP_DIM + 1 :] + ) next_rank = PP_RANKS_IN_GROUP[rank][PP_RANKS_IN_GROUP[rank].index(rank) + 1] assert pg_mesh.ravel(next_coord, pg_mesh.shape) == next_rank def run_dist(rank, world_size, port): - colossalai.launch(config=dict(parallel=dict(data=1, pipeline=2, tensor=dict(mode='1d', size=2))), - rank=rank, - world_size=world_size, - port=port, - host='localhost') + colossalai.launch( + config=dict(parallel=dict(data=1, pipeline=2, tensor=dict(mode="1d", size=2))), + rank=rank, + world_size=world_size, + port=port, + host="localhost", + ) # TODO(ver217): this function should be removed when gpc is removed # check_process_group_mesh_with_gpc() check_process_group_mesh_with_cases() @@ -147,5 +163,5 @@ def test_process_group_mesh(): spawn(run_dist, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_process_group_mesh() diff --git a/tests/test_config/sample_config.py b/tests/test_config/sample_config.py index 08ca108281b9..b9af7ab41a55 100644 --- a/tests/test_config/sample_config.py +++ b/tests/test_config/sample_config.py @@ -3,23 +3,23 @@ train_data = dict( dataset=dict( - type='CIFAR10Dataset', - root='/path/to/data', + type="CIFAR10Dataset", + root="/path/to/data", download=True, transform_pipeline=[ - dict(type='RandomResizedCrop', size=224), - dict(type='RandomHorizontalFlip'), - dict(type='ToTensor'), - dict(type='Normalize', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) - ] + dict(type="RandomResizedCrop", size=224), + dict(type="RandomHorizontalFlip"), + dict(type="ToTensor"), + dict(type="Normalize", mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + ], ), dataloader=dict( batch_size=64, pin_memory=True, num_workers=4, sampler=dict( - type='DataParallelSampler', + type="DataParallelSampler", shuffle=True, - ) - ) + ), + ), ) diff --git a/tests/test_config/test_load_config.py b/tests/test_config/test_load_config.py index 38b5e3f5f4fc..66e473459445 100644 --- a/tests/test_config/test_load_config.py +++ b/tests/test_config/test_load_config.py @@ -3,16 +3,15 @@ from pathlib import Path -import pytest - from colossalai.context.config import Config def test_load_config(): - filename = Path(__file__).parent.joinpath('sample_config.py') + filename = Path(__file__).parent.joinpath("sample_config.py") config = Config.from_file(filename) - assert config.train_data, 'cannot access train data as attribute' - assert config.train_data.dataset, 'cannot access grandchild attribute' - assert isinstance(config.train_data.dataset.transform_pipeline[0], 
dict), \ - f'expected attribute transform_pipeline elements to be a dict, but found {type(config.train_data.dataset.transform_pipeline)}' + assert config.train_data, "cannot access train data as attribute" + assert config.train_data.dataset, "cannot access grandchild attribute" + assert isinstance( + config.train_data.dataset.transform_pipeline[0], dict + ), f"expected attribute transform_pipeline elements to be a dict, but found {type(config.train_data.dataset.transform_pipeline)}" diff --git a/tests/test_device/test_alpha_beta.py b/tests/test_device/test_alpha_beta.py index ab933ed57d0d..f4a88f79c37b 100644 --- a/tests/test_device/test_alpha_beta.py +++ b/tests/test_device/test_alpha_beta.py @@ -8,7 +8,7 @@ def check_alpha_beta(rank, world_size, port, physical_devices): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") profiler = AlphaBetaProfiler(physical_devices) ab_dict = profiler.profile_ab() for _, (alpha, beta) in ab_dict.items(): @@ -17,11 +17,11 @@ def check_alpha_beta(rank, world_size, port, physical_devices): @pytest.mark.skip(reason="Skip because assertion fails for CI devices") @pytest.mark.dist -@parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]]) +@parameterize("physical_devices", [[0, 1, 2, 3], [0, 3]]) @rerun_if_address_is_in_use() def test_profile_alpha_beta(physical_devices): spawn(check_alpha_beta, 4, physical_devices=physical_devices) -if __name__ == '__main__': +if __name__ == "__main__": test_profile_alpha_beta() diff --git a/tests/test_device/test_device_mesh.py b/tests/test_device/test_device_mesh.py index 590d6966bff6..af44af5d9097 100644 --- a/tests/test_device/test_device_mesh.py +++ b/tests/test_device/test_device_mesh.py @@ -27,8 +27,8 @@ def check_1d_device_mesh(): # checks assert device_mesh.shape == [4] - assert len(device_mesh.get_process_group_for_all_axes().keys()) == 1, 'Expected 1 axis for the process group dict' - assert device_mesh.get_process_group(axis=0) == process_group, 'Expected world process group' + assert len(device_mesh.get_process_group_for_all_axes().keys()) == 1, "Expected 1 axis for the process group dict" + assert device_mesh.get_process_group(axis=0) == process_group, "Expected world process group" assert device_mesh.is_initialized assert device_mesh.num_devices == 4 assert device_mesh.is_initialized @@ -43,10 +43,10 @@ def check_2d_device_mesh(): first_col_ranks = [0, 2] second_col_ranks = [1, 3] - first_row_pg = dist.new_group(first_row_ranks, backend='nccl') - second_row_pg = dist.new_group(second_row_ranks, backend='nccl') - first_col_pg = dist.new_group(first_col_ranks, backend='nccl') - second_col_pg = dist.new_group(second_col_ranks, backend='nccl') + first_row_pg = dist.new_group(first_row_ranks, backend="nccl") + second_row_pg = dist.new_group(second_row_ranks, backend="nccl") + first_col_pg = dist.new_group(first_col_ranks, backend="nccl") + second_col_pg = dist.new_group(second_col_ranks, backend="nccl") # check for current_rank = dist.get_rank() @@ -65,9 +65,9 @@ def check_2d_device_mesh(): # checks assert device_mesh.shape == [2, 2] - assert len(device_mesh.get_process_group_for_all_axes().keys()) == 2, 'Expected 2 axes for the process group dict' - assert device_mesh.get_process_group(axis=0) == col_pg, 'Expected column process group' - assert device_mesh.get_process_group(axis=1) == row_pg, 'Expected row process group' + assert 
len(device_mesh.get_process_group_for_all_axes().keys()) == 2, "Expected 2 axes for the process group dict" + assert device_mesh.get_process_group(axis=0) == col_pg, "Expected column process group" + assert device_mesh.get_process_group(axis=1) == row_pg, "Expected row process group" assert device_mesh.num_devices == 4 assert device_mesh.is_initialized assert device_mesh.logical_mesh_id is None @@ -75,7 +75,7 @@ def check_2d_device_mesh(): def check_init_from_process_group(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") @pytest.mark.dist @@ -84,6 +84,6 @@ def test_device_mesh_from_process_group(): spawn(check_init_from_process_group, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_device_mesh() test_device_mesh_from_process_group() diff --git a/tests/test_device/test_extract_alpha_beta.py b/tests/test_device/test_extract_alpha_beta.py index 52604b9c6a49..34f2aacc18b2 100644 --- a/tests/test_device/test_extract_alpha_beta.py +++ b/tests/test_device/test_extract_alpha_beta.py @@ -8,7 +8,7 @@ def check_extract_alpha_beta(rank, world_size, port, physical_devices): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") profiler = AlphaBetaProfiler(physical_devices) mesh_alpha, mesh_beta = profiler.extract_alpha_beta_for_device_mesh() @@ -20,11 +20,11 @@ def check_extract_alpha_beta(rank, world_size, port, physical_devices): @pytest.mark.skip(reason="Skip because assertion may fail for CI devices") @pytest.mark.dist -@parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]]) +@parameterize("physical_devices", [[0, 1, 2, 3], [0, 3]]) @rerun_if_address_is_in_use() def test_profile_alpha_beta(physical_devices): spawn(check_extract_alpha_beta, 4, physical_devices=physical_devices) -if __name__ == '__main__': +if __name__ == "__main__": test_profile_alpha_beta() diff --git a/tests/test_device/test_init_logical_pg.py b/tests/test_device/test_init_logical_pg.py index c18bf56752fb..3b398a917182 100644 --- a/tests/test_device/test_init_logical_pg.py +++ b/tests/test_device/test_init_logical_pg.py @@ -9,7 +9,7 @@ def check_layer(rank, world_size, port): - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") physical_mesh_id = torch.arange(0, 4) assert rank == dist.get_rank() @@ -33,5 +33,5 @@ def test_logical_pg(): spawn(check_layer, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_logical_pg() diff --git a/tests/test_device/test_search_logical_device_mesh.py b/tests/test_device/test_search_logical_device_mesh.py index b22a76eabc2f..d9d4e79c1f57 100644 --- a/tests/test_device/test_search_logical_device_mesh.py +++ b/tests/test_device/test_search_logical_device_mesh.py @@ -8,7 +8,7 @@ def check_alpha_beta(rank, world_size, port, physical_devices): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") profiler = AlphaBetaProfiler(physical_devices) best_logical_mesh = profiler.search_best_logical_mesh() @@ 
-20,11 +20,11 @@ def check_alpha_beta(rank, world_size, port, physical_devices): @pytest.mark.skip(reason="Skip because assertion may fail for CI devices") @pytest.mark.dist -@parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]]) +@parameterize("physical_devices", [[0, 1, 2, 3], [0, 3]]) @rerun_if_address_is_in_use() def test_profile_alpha_beta(physical_devices): spawn(check_alpha_beta, 4, physical_devices=physical_devices) -if __name__ == '__main__': +if __name__ == "__main__": test_profile_alpha_beta() diff --git a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py index 6a12f5bc848e..10fe9815541c 100644 --- a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py +++ b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py @@ -11,15 +11,16 @@ try: from colossalai.fx.codegen import ActivationCheckpointCodeGen + with_codegen = True except: # fall back to older pytorch version from colossalai.fx.codegen import python_code_with_activation_checkpoint + with_codegen = False class MLP(torch.nn.Module): - def __init__(self): super().__init__() self.linear1 = torch.nn.Linear(4, 4) @@ -30,7 +31,6 @@ def forward(self, x): class relu(torch.nn.Module): - def __init__(self) -> None: super().__init__() self.relu = torch.nn.ReLU(inplace=True) @@ -40,7 +40,6 @@ def forward(self, x): class MyModule(torch.nn.Module): - def __init__(self): super().__init__() self.mlp1 = MLP() @@ -65,7 +64,7 @@ def forward(self, x, y): def _run_act_ckpt_codegen(rank, world_size, port): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currently - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # build model and run forward model = MyModule() @@ -87,26 +86,31 @@ def _run_act_ckpt_codegen(rank, world_size, port): # check ops are annotated with ckpt # also annotate the selected node for offloading - ckpt_nodes = ['mlp1_linear1', 'mlp1_linear2', 'relu_relu', 'relu'] - offload_starts = ['mlp1_linear1'] + ckpt_nodes = ["mlp1_linear1", "mlp1_linear2", "relu_relu", "relu"] + offload_starts = ["mlp1_linear1"] for node in graph.nodes: if node.name in ckpt_nodes: - assert 'activation_checkpoint' in node.meta + assert "activation_checkpoint" in node.meta # annotate the selected node for offload if node.name in offload_starts: - node.meta['activation_offload'] = True + node.meta["activation_offload"] = True gm = ColoGraphModule(model, graph) gm.recompile() # assert checkpoint function will be generated and # the offload option is correct - code = graph.python_code('self').src - assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, y, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_3, False, y, relu, use_reentrant=True)' in code + code = graph.python_code("self").src + assert ( + "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=False)" in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)" + in code + and 
"colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, y, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_3, False, y, relu, use_reentrant=True)" + in code + ) # recompile and verify the outputs are consistent fx_out = gm(data1, data2) @@ -115,7 +119,7 @@ def _run_act_ckpt_codegen(rank, world_size, port): gpc.destroy() -@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0") @rerun_if_address_is_in_use() def test_act_ckpt_codegen(): spawn(_run_act_ckpt_codegen, 1) @@ -123,7 +127,7 @@ def test_act_ckpt_codegen(): def _run_act_ckpt_python_code_torch11(rank, world_size, port): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currently - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # build model and run forward model = MyModule() @@ -144,25 +148,30 @@ def _run_act_ckpt_python_code_torch11(rank, world_size, port): graph._python_code = python_code_with_activation_checkpoint.__get__(graph) # check ops are annotated with ckpt - ckpt_nodes = ['mlp1_linear1', 'mlp1_linear2', 'relu_relu', 'relu'] - offload_starts = ['mlp1_linear1'] + ckpt_nodes = ["mlp1_linear1", "mlp1_linear2", "relu_relu", "relu"] + offload_starts = ["mlp1_linear1"] for node in graph.nodes: if node.name in ckpt_nodes: - assert 'activation_checkpoint' in node.meta + assert "activation_checkpoint" in node.meta # annotate the selected node for offload if node.name in offload_starts: - node.meta['activation_offload'] = True + node.meta["activation_offload"] = True gm = ColoGraphModule(model, graph) gm.recompile() # assert checkpoint function will be generated and # the offload option is correct - code = graph.python_code('self').src - assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, y, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_3, False, y, relu, use_reentrant=True)' in code + code = graph.python_code("self").src + assert ( + "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=False)" in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, y, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_3, False, y, relu, use_reentrant=True)" + in code + ) # recompile and verify the outputs are consistent fx_out = gm(data1, data2) @@ -171,12 +180,12 @@ def _run_act_ckpt_python_code_torch11(rank, world_size, port): gpc.destroy() -@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0') +@pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0") @pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done") @rerun_if_address_is_in_use() def test_act_ckpt_python_code_torch11(): spawn(_run_act_ckpt_python_code_torch11, 
1) -if __name__ == '__main__': +if __name__ == "__main__": _run_act_ckpt_codegen(rank=0) diff --git a/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py index ebcfb4d7b633..f1e87e5ed140 100644 --- a/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py +++ b/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py @@ -9,15 +9,14 @@ try: from colossalai.fx.codegen import ActivationCheckpointCodeGen + with_codegen = True except: # fall back to older pytorch version - from colossalai.fx.codegen import python_code_with_activation_checkpoint with_codegen = False class MyModule(torch.nn.Module): - def __init__(self): super().__init__() self.linear1 = torch.nn.Linear(4, 4) @@ -33,7 +32,7 @@ def forward(self, x): def _run_act_ckpt_codegen(rank, world_size, port): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currently - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # build model and run forward model = MyModule() @@ -54,27 +53,34 @@ def _run_act_ckpt_codegen(rank, world_size, port): # annotate nested checkpoint for node in graph.nodes: if node.name == "linear1": - node.meta['activation_checkpoint'] = [0, 0, 0] + node.meta["activation_checkpoint"] = [0, 0, 0] continue if node.name == "linear2": - node.meta['activation_checkpoint'] = [0, 0, None] + node.meta["activation_checkpoint"] = [0, 0, None] if node.name == "linear3": - node.meta['activation_checkpoint'] = [0, 0, 1] + node.meta["activation_checkpoint"] = [0, 0, 1] if node.name == "linear4": - node.meta['activation_checkpoint'] = [0, 1, None] + node.meta["activation_checkpoint"] = [0, 1, None] if node.name == "linear5": - node.meta['activation_checkpoint'] = 1 + node.meta["activation_checkpoint"] = 1 gm = ColoGraphModule(model, graph) gm.recompile() # assert checkpoint function will be generated and - code = graph.python_code('self').src - assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)' in code + code = graph.python_code("self").src + assert ( + "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)" in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)" + in code + and 
"colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)" + in code + ) # recompile and verify the outputs are consistent fx_out = gm(data1) @@ -83,14 +89,14 @@ def _run_act_ckpt_codegen(rank, world_size, port): gpc.destroy() -@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0") def test_act_ckpt_codegen(): spawn(_run_act_ckpt_codegen, 1) def _run_act_ckpt_python_code_torch11(rank, world_size, port): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currently - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # build model and run forward model = MyModule() @@ -111,27 +117,34 @@ def _run_act_ckpt_python_code_torch11(rank, world_size, port): # annotate nested checkpoint for node in graph.nodes: if node.name == "linear1": - node.meta['activation_checkpoint'] = [0, 0, 0] + node.meta["activation_checkpoint"] = [0, 0, 0] continue if node.name == "linear2": - node.meta['activation_checkpoint'] = [0, 0, None] + node.meta["activation_checkpoint"] = [0, 0, None] if node.name == "linear3": - node.meta['activation_checkpoint'] = [0, 0, 1] + node.meta["activation_checkpoint"] = [0, 0, 1] if node.name == "linear4": - node.meta['activation_checkpoint'] = [0, 1, None] + node.meta["activation_checkpoint"] = [0, 1, None] if node.name == "linear5": - node.meta['activation_checkpoint'] = 1 + node.meta["activation_checkpoint"] = 1 gm = ColoGraphModule(model, graph) gm.recompile() # assert checkpoint function will be generated and - code = graph.python_code('self').src - assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)' in code + code = graph.python_code("self").src + assert ( + "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)" in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)" + in code + ) # recompile and verify the outputs are consistent fx_out = gm(data1) @@ -140,12 +153,12 @@ def _run_act_ckpt_python_code_torch11(rank, world_size, port): gpc.destroy() 
-@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0') +@pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0") @pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done") @rerun_if_address_is_in_use() def test_act_ckpt_python_code_torch11(): spawn(_run_act_ckpt_python_code_torch11, 1) -if __name__ == '__main__': +if __name__ == "__main__": _run_act_ckpt_codegen(rank=0) diff --git a/tests/test_fx/test_codegen/test_offload_codegen.py b/tests/test_fx/test_codegen/test_offload_codegen.py index dac59c23655e..da1e73ec3dfe 100644 --- a/tests/test_fx/test_codegen/test_offload_codegen.py +++ b/tests/test_fx/test_codegen/test_offload_codegen.py @@ -12,15 +12,16 @@ try: from colossalai.fx.codegen import ActivationCheckpointCodeGen + with_codegen = True except: # fall back to older pytorch version from colossalai.fx.codegen import python_code_with_activation_checkpoint + with_codegen = False class MyNet(torch.nn.Module): - def __init__(self) -> None: super().__init__() self.linear0 = torch.nn.Linear(4, 4) @@ -50,7 +51,6 @@ def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule) -> bool: def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.Tensor): - # test forward non_fx_out = model(data) fx_out = gm(data) @@ -66,7 +66,7 @@ def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.T def _run_offload_codegen(rank, world_size, port): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currently - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # build model and input model = MyNet().cuda() @@ -83,37 +83,40 @@ def _run_offload_codegen(rank, world_size, port): # of input offload for node in graph.nodes: if node.name == "linear0": - node.meta['activation_offload'] = [0, True, False] + node.meta["activation_offload"] = [0, True, False] if node.name == "linear1": - node.meta['activation_offload'] = [0, True, False] + node.meta["activation_offload"] = [0, True, False] if node.name == "linear2": - node.meta['activation_offload'] = [1, True, True] + node.meta["activation_offload"] = [1, True, True] if node.name == "linear4": - node.meta['activation_offload'] = [2, False, True] + node.meta["activation_offload"] = [2, False, True] if node.name == "linear5": - node.meta['activation_checkpoint'] = [0] - node.meta['activation_offload'] = True + node.meta["activation_checkpoint"] = [0] + node.meta["activation_offload"] = True gm = ColoGraphModule(copy.deepcopy(model), graph) gm.recompile() # assert we have all the components code = graph.python_code("self").src - assert "def pack_hook_input(self, x):" in code and \ - "def unpack_hook(self, packed):" in code and \ - "def pack_hook_no_input(self, x):" in code and \ - "setattr(x, 'offload', True)" in code and \ - "setattr(linear3, 'offload', False)" in code and \ - "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \ - "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \ - "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \ - "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code + assert ( + "def pack_hook_input(self, x):" in code + and "def 
unpack_hook(self, packed):" in code + and "def pack_hook_no_input(self, x):" in code + and "setattr(x, 'offload', True)" in code + and "setattr(linear3, 'offload', False)" in code + and "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code + and "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code + and "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" + in code + ) _test_fwd_and_bwd(model, gm, data) gpc.destroy() -@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0") @rerun_if_address_is_in_use() def test_act_ckpt_codegen(): spawn(_run_offload_codegen, 1) @@ -121,7 +124,7 @@ def test_act_ckpt_codegen(): def _run_offload_codegen_torch11(rank, world_size, port): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currently - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # build model and input model = MyNet().cuda() @@ -139,31 +142,34 @@ def _run_offload_codegen_torch11(rank, world_size, port): # of input offload for node in graph.nodes: if node.name == "linear0": - node.meta['activation_offload'] = [0, True, False] + node.meta["activation_offload"] = [0, True, False] if node.name == "linear1": - node.meta['activation_offload'] = [0, True, False] + node.meta["activation_offload"] = [0, True, False] if node.name == "linear2": - node.meta['activation_offload'] = [1, True, True] + node.meta["activation_offload"] = [1, True, True] if node.name == "linear4": - node.meta['activation_offload'] = [2, False, True] + node.meta["activation_offload"] = [2, False, True] if node.name == "linear5": - node.meta['activation_checkpoint'] = [0] - node.meta['activation_offload'] = True + node.meta["activation_checkpoint"] = [0] + node.meta["activation_offload"] = True gm = ColoGraphModule(copy.deepcopy(model), graph) gm.recompile() # assert we have all the components code = graph.python_code("self").src - assert "def pack_hook_input(self, x):" in code and \ - "def unpack_hook(self, packed):" in code and \ - "def pack_hook_no_input(self, x):" in code and \ - "setattr(x, 'offload', True)" in code and \ - "setattr(linear3, 'offload', False)" in code and \ - "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \ - "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \ - "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \ - "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code + assert ( + "def pack_hook_input(self, x):" in code + and "def unpack_hook(self, packed):" in code + and "def pack_hook_no_input(self, x):" in code + and "setattr(x, 'offload', True)" in code + and "setattr(linear3, 'offload', False)" in code + and "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code + and "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code + and "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code + and 
"colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" + in code + ) _test_fwd_and_bwd(model, gm, data) gpc.destroy() diff --git a/tests/test_fx/test_coloproxy.py b/tests/test_fx/test_coloproxy.py index 96cf5198da10..efef368bdd45 100644 --- a/tests/test_fx/test_coloproxy.py +++ b/tests/test_fx/test_coloproxy.py @@ -1,4 +1,3 @@ -import pytest import torch import torch.nn as nn from torch.fx import GraphModule @@ -9,7 +8,6 @@ class Conv1D(nn.Module): - def __init__(self, nf, nx): super().__init__() self.nf = nf @@ -27,10 +25,9 @@ def forward(self, x): @clear_cache_before_run() def test_coloproxy(): - tracer = ColoTracer() model = Conv1D(3, 3) - input_sample = {'x': torch.rand(3, 3).to('meta')} + input_sample = {"x": torch.rand(3, 3).to("meta")} graph = tracer.trace(root=model, meta_args=input_sample) gm = GraphModule(model, graph, model.__class__.__name__) @@ -38,7 +35,7 @@ def test_coloproxy(): node = list(gm.graph.nodes)[0] proxy = ColoProxy(node=node, tracer=tracer) - proxy.meta_data = torch.empty(4, 2, device='meta') + proxy.meta_data = torch.empty(4, 2, device="meta") assert len(proxy) == 4 assert proxy.shape[0] == 4 and proxy.shape[1] == 2 @@ -47,5 +44,5 @@ def test_coloproxy(): assert proxy.size(0) == 4 -if __name__ == '__main__': +if __name__ == "__main__": test_coloproxy() diff --git a/tests/test_fx/test_comm_size_compute.py b/tests/test_fx/test_comm_size_compute.py index d3daadd71406..00721ca86ade 100644 --- a/tests/test_fx/test_comm_size_compute.py +++ b/tests/test_fx/test_comm_size_compute.py @@ -17,7 +17,6 @@ class MLP(torch.nn.Module): - def __init__(self, dim: int): super().__init__() self.linear1 = torch.nn.Linear(dim, dim) @@ -36,7 +35,7 @@ def forward(self, x): @clear_cache_before_run() def test_comm_size_compute(): model = MLP(MODEL_DIM) - input_sample = torch.rand(BATCH_SIZE, MODEL_DIM, device='meta') + input_sample = torch.rand(BATCH_SIZE, MODEL_DIM, device="meta") gm = symbolic_trace(model) if is_compatible: input_sample = MetaTensor(input_sample, fake_device=next(gm.parameters()).device) @@ -49,5 +48,5 @@ def test_comm_size_compute(): assert comm_size == 128 -if __name__ == '__main__': +if __name__ == "__main__": test_comm_size_compute() diff --git a/tests/test_fx/test_graph_manipulation.py b/tests/test_fx/test_graph_manipulation.py index 175b69dd96fe..eece451a706f 100644 --- a/tests/test_fx/test_graph_manipulation.py +++ b/tests/test_fx/test_graph_manipulation.py @@ -1,15 +1,11 @@ import torch -from torch.fx import GraphModule -import colossalai from colossalai.fx import ColoTracer -from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata from colossalai.fx.passes.utils import assign_bfs_level_to_nodes, get_leaf, get_top from colossalai.testing import clear_cache_before_run class MLP(torch.nn.Module): - def __init__(self, dim: int): super().__init__() self.linear1 = torch.nn.Linear(dim, dim) @@ -43,11 +39,11 @@ def test_graph_manipulation(): assert leaf_nodes == set([l4, l5]) assert top_nodes == set([l1, l2]) for node in graph.nodes: - if node.op in ('placeholder', 'output'): - assert not hasattr(node, 'bfs_level') + if node.op in ("placeholder", "output"): + assert not hasattr(node, "bfs_level") else: assert node.bfs_level == compare_dict[node] -if __name__ == '__main__': +if __name__ == "__main__": test_graph_manipulation() diff --git a/tests/test_fx/test_meta/test_aten.py b/tests/test_fx/test_meta/test_aten.py index e490522dbf15..7fc7eb4df64b 100644 --- a/tests/test_fx/test_meta/test_aten.py 
+++ b/tests/test_fx/test_meta/test_aten.py @@ -13,35 +13,41 @@ aten = torch.ops.aten registered_meta = { - ('aten.convolution.default', True): [ # (aten ops, requires_backward) + ("aten.convolution.default", True): [ # (aten ops, requires_backward) (nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4)), (nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4, 4)), (nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4, 4, 4)), (nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4)), - (nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, - dilation=2), torch.rand(2, 3, 4, 4)), - (nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, - dilation=2), torch.rand(2, 3, 4, 4, 4)), + ( + nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), + torch.rand(2, 3, 4, 4), + ), + ( + nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), + torch.rand(2, 3, 4, 4, 4), + ), ], - ('aten.native_batch_norm.default', True): [ + ("aten.native_batch_norm.default", True): [ (nn.BatchNorm1d(4), torch.rand(2, 4)), (nn.BatchNorm2d(4), torch.rand(1, 4, 4, 4)), (nn.BatchNorm3d(4), torch.rand(1, 4, 4, 4, 4)), ], - ('aten.native_layer_norm.default', True): [(nn.LayerNorm(4), torch.rand(1, 2, 3, 4)),], - ('aten.avg_pool1d.default', True): [ + ("aten.native_layer_norm.default", True): [ + (nn.LayerNorm(4), torch.rand(1, 2, 3, 4)), + ], + ("aten.avg_pool1d.default", True): [ (nn.MaxPool1d(3, stride=2), torch.rand(4, 5, 5)), (nn.AvgPool1d(3, stride=2), torch.rand(4, 5, 5)), (nn.AdaptiveMaxPool1d(3), torch.rand(4, 5, 5)), (nn.AdaptiveAvgPool1d(3), torch.rand(4, 5, 5)), ], - ('aten.avg_pool2d.default', True): [ + ("aten.avg_pool2d.default", True): [ (nn.MaxPool2d((3, 2), stride=(2, 1)), torch.rand(2, 4, 5, 5)), (nn.AvgPool2d((3, 2), stride=(2, 1)), torch.rand(2, 4, 5, 5)), (nn.AdaptiveMaxPool2d((3, 2)), torch.rand(2, 4, 5, 5)), (nn.AdaptiveAvgPool2d((3, 2)), torch.rand(2, 4, 5, 5)), ], - ('aten.relu.default', True): [ + ("aten.relu.default", True): [ (nn.ReLU(), torch.rand(4, 3, 1, 2)), (nn.LeakyReLU(), torch.rand(4, 3, 1, 2)), (nn.SiLU(), torch.rand(4, 3, 1, 2)), @@ -50,15 +56,20 @@ (nn.Sigmoid(), torch.rand(4, 3, 1, 2)), (nn.Tanh(), torch.rand(4, 3, 1, 2)), (nn.Hardswish(), torch.rand(4, 3, 1, 2)), - ] + ], } def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor) -> Any: - assert tensor.shape == meta_tensor.shape, f'the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match.' - assert tensor.dtype == meta_tensor.dtype, f'the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match.' - assert tensor.stride() == meta_tensor.stride( - ), f'the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match.' + assert ( + tensor.shape == meta_tensor.shape + ), f"the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match." + assert ( + tensor.dtype == meta_tensor.dtype + ), f"the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match." + assert ( + tensor.stride() == meta_tensor.stride() + ), f"the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match." 
def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_backward=False) -> Any: @@ -72,7 +83,7 @@ def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_bac compare_all(x.grad, meta_x.grad) -@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not is_compatible_with_meta(), reason="torch version is lower than 1.12.0") @clear_cache_before_run() def test_meta_aten(): for (aten_op, requires_backward), v in registered_meta.items(): @@ -80,5 +91,5 @@ def test_meta_aten(): run_and_compare(f, x, requires_backward) -if __name__ == '__main__': +if __name__ == "__main__": test_meta_aten() diff --git a/tests/test_fx/test_meta/test_backward.py b/tests/test_fx/test_meta/test_backward.py index 7aed6fd4597b..6091c4b6be2f 100644 --- a/tests/test_fx/test_meta/test_backward.py +++ b/tests/test_fx/test_meta/test_backward.py @@ -23,31 +23,40 @@ ] tmm_models = [ - tmm.resnest.resnest50d, tmm.beit.beit_base_patch16_224, tmm.cait.cait_s24_224, tmm.efficientnet.efficientnetv2_m, - tmm.resmlp_12_224, tmm.vision_transformer.vit_base_patch16_224, tmm.deit_base_distilled_patch16_224, - tmm.convnext.convnext_base, tmm.vgg.vgg11, tmm.dpn.dpn68, tmm.densenet.densenet121, tmm.rexnet.rexnet_100, - tmm.swin_transformer.swin_base_patch4_window7_224 + tmm.resnest.resnest50d, + tmm.beit.beit_base_patch16_224, + tmm.cait.cait_s24_224, + tmm.efficientnet.efficientnetv2_m, + tmm.resmlp_12_224, + tmm.vision_transformer.vit_base_patch16_224, + tmm.deit_base_distilled_patch16_224, + tmm.convnext.convnext_base, + tmm.vgg.vgg11, + tmm.dpn.dpn68, + tmm.densenet.densenet121, + tmm.rexnet.rexnet_100, + tmm.swin_transformer.swin_base_patch4_window7_224, ] -@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not is_compatible_with_meta(), reason="torch version is lower than 1.12.0") @clear_cache_before_run() def test_torchvision_models(): for m in tm_models: model = m() - data = torch.rand(100000, 3, 224, 224, device='meta') - model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward() + data = torch.rand(100000, 3, 224, 224, device="meta") + model(MetaTensor(data, fake_device=torch.device("cpu"))).sum().backward() -@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not is_compatible_with_meta(), reason="torch version is lower than 1.12.0") @clear_cache_before_run() def test_timm_models(): for m in tmm_models: model = m() - data = torch.rand(100000, 3, 224, 224, device='meta') - model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward() + data = torch.rand(100000, 3, 224, 224, device="meta") + model(MetaTensor(data, fake_device=torch.device("cpu"))).sum().backward() -if __name__ == '__main__': +if __name__ == "__main__": test_torchvision_models() test_timm_models() diff --git a/tests/test_fx/test_meta/test_meta_trace.py b/tests/test_fx/test_meta/test_meta_trace.py index 61614f8a6623..ba9617a38380 100644 --- a/tests/test_fx/test_meta/test_meta_trace.py +++ b/tests/test_fx/test_meta/test_meta_trace.py @@ -23,31 +23,40 @@ ] tmm_models = [ - tmm.resnest.resnest50d, tmm.beit.beit_base_patch16_224, tmm.cait.cait_s24_224, tmm.efficientnet.efficientnetv2_m, - tmm.resmlp_12_224, tmm.vision_transformer.vit_base_patch16_224, tmm.deit_base_distilled_patch16_224, - tmm.convnext.convnext_base, tmm.vgg.vgg11, tmm.dpn.dpn68, tmm.densenet.densenet121, tmm.rexnet.rexnet_100, - 
tmm.swin_transformer.swin_base_patch4_window7_224 + tmm.resnest.resnest50d, + tmm.beit.beit_base_patch16_224, + tmm.cait.cait_s24_224, + tmm.efficientnet.efficientnetv2_m, + tmm.resmlp_12_224, + tmm.vision_transformer.vit_base_patch16_224, + tmm.deit_base_distilled_patch16_224, + tmm.convnext.convnext_base, + tmm.vgg.vgg11, + tmm.dpn.dpn68, + tmm.densenet.densenet121, + tmm.rexnet.rexnet_100, + tmm.swin_transformer.swin_base_patch4_window7_224, ] -@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not is_compatible_with_meta(), reason="torch version is lower than 1.12.0") @clear_cache_before_run() def test_torchvision_models_trace(): for m in tm_models: model = m() - data = torch.rand(1000, 3, 224, 224, device='meta') - graph = meta_trace(model, torch.device('cpu'), data) + data = torch.rand(1000, 3, 224, 224, device="meta") + meta_trace(model, torch.device("cpu"), data) -@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not is_compatible_with_meta(), reason="torch version is lower than 1.12.0") @clear_cache_before_run() def test_timm_models_trace(): for m in tmm_models: model = m() - data = torch.rand(1000, 3, 224, 224, device='meta') - graph = meta_trace(model, torch.device('cpu'), data) + data = torch.rand(1000, 3, 224, 224, device="meta") + meta_trace(model, torch.device("cpu"), data) -if __name__ == '__main__': +if __name__ == "__main__": test_torchvision_models_trace() test_timm_models_trace() diff --git a/tests/test_fx/test_meta_info_prop.py b/tests/test_fx/test_meta_info_prop.py index a12512696a73..659949e87002 100644 --- a/tests/test_fx/test_meta_info_prop.py +++ b/tests/test_fx/test_meta_info_prop.py @@ -23,18 +23,18 @@ def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor): @clear_cache_before_run() def test_meta_info_prop(): model = torch.nn.Linear(DIM_IN, DIM_OUT) - input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta') + input_sample = torch.rand(BATCH_SIZE, DIM_IN, device="meta") if is_compatible_with_meta(): - input_sample = MetaTensor(input_sample, fake_device='cpu') + input_sample = MetaTensor(input_sample, fake_device="cpu") orig_output = model(input_sample) gm = symbolic_trace(model) MetaInfoProp(gm).run(input_sample) for node in gm.graph.nodes: - if node.op == 'placeholder': - meta_check(node.meta['tensor_meta'], input_sample) - if node.op == 'output': - meta_check(node.meta['tensor_meta'], orig_output) + if node.op == "placeholder": + meta_check(node.meta["tensor_meta"], input_sample) + if node.op == "output": + meta_check(node.meta["tensor_meta"], orig_output) -if __name__ == '__main__': +if __name__ == "__main__": test_meta_info_prop() diff --git a/tests/test_fx/test_parallel_1d.py b/tests/test_fx/test_parallel_1d.py index 29135b45f997..6d890f59d5c5 100644 --- a/tests/test_fx/test_parallel_1d.py +++ b/tests/test_fx/test_parallel_1d.py @@ -13,7 +13,6 @@ class MLP(torch.nn.Module): - def __init__(self, dim: int): super().__init__() self.linear1 = torch.nn.Linear(dim, dim) @@ -29,12 +28,12 @@ def forward(self, x): return x -CONFIG = dict(parallel=dict(tensor=dict(mode='1d', size=2))) +CONFIG = dict(parallel=dict(tensor=dict(mode="1d", size=2))) def check_layer(rank, world_size, port): disable_existing_loggers() - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") 
input_tensor = torch.rand(2, 16).cuda() model = MLP(16).cuda() symbolic_traced = symbolic_trace(model) @@ -55,5 +54,5 @@ def test_1d(): spawn(check_layer, 2) -if __name__ == '__main__': +if __name__ == "__main__": test_1d() diff --git a/tests/test_fx/test_pipeline/test_hf_model/hf_utils.py b/tests/test_fx/test_pipeline/test_hf_model/hf_utils.py index 3afc6c97e2bb..b86c71db85c2 100644 --- a/tests/test_fx/test_pipeline/test_hf_model/hf_utils.py +++ b/tests/test_fx/test_pipeline/test_hf_model/hf_utils.py @@ -1,11 +1,12 @@ -import torch -from torch.fx import symbolic_trace -from torch.fx import GraphModule -from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass -from colossalai.fx import ColoTracer import inspect import random + import numpy as np +import torch +from torch.fx import GraphModule + +from colossalai.fx import ColoTracer +from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass MANUAL_SEED = 0 random.seed(MANUAL_SEED) @@ -26,7 +27,7 @@ def split_model_and_compare_output(model, data_gen): # tracing model tracer = ColoTracer() try: - meta_args = {k: v.to('meta') for k, v in kwargs.items()} + meta_args = {k: v.to("meta") for k, v in kwargs.items()} graph = tracer.trace(root=model, meta_args=meta_args) except Exception as e: raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}") @@ -49,16 +50,16 @@ def split_model_and_compare_output(model, data_gen): output_part1 = model_part1(output_part0) else: if len(output_part0) > len(sig.parameters): - output_part0 = output_part0[:len(sig.parameters)] + output_part0 = output_part0[: len(sig.parameters)] output_part1 = model_part1(*output_part0) # get output tensor from HFOutput datastructure - if 'logits' in output: - output_to_compare = output['logits'] - elif 'prediction_logits' in output: - output_to_compare = output['prediction_logits'] + if "logits" in output: + output_to_compare = output["logits"] + elif "prediction_logits" in output: + output_to_compare = output["prediction_logits"] else: - output_to_compare = output['last_hidden_state'] + output_to_compare = output["last_hidden_state"] # compare output if isinstance(output_part1, torch.Tensor): diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_albert.py b/tests/test_fx/test_pipeline/test_hf_model/test_albert.py index 6ef861bdefbe..d15081b0b3ad 100644 --- a/tests/test_fx/test_pipeline/test_hf_model/test_albert.py +++ b/tests/test_fx/test_pipeline/test_hf_model/test_albert.py @@ -7,7 +7,7 @@ SEQ_LENGHT = 16 -@pytest.mark.skip('balance split v2 is not ready') +@pytest.mark.skip("balance split v2 is not ready") def test_single_sentence_albert(): MODEL_LIST = [ transformers.AlbertModel, @@ -17,12 +17,14 @@ def test_single_sentence_albert(): transformers.AlbertForTokenClassification, ] - config = transformers.AlbertConfig(vocab_size=100, - embedding_size=128, - hidden_size=128, - num_hidden_layers=2, - num_attention_heads=4, - intermediate_size=256) + config = transformers.AlbertConfig( + vocab_size=100, + embedding_size=128, + hidden_size=128, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=256, + ) def data_gen(): input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) @@ -36,5 +38,5 @@ def data_gen(): split_model_and_compare_output(model, data_gen) -if __name__ == '__main__': +if __name__ == "__main__": test_single_sentence_albert() diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_bert.py 
b/tests/test_fx/test_pipeline/test_hf_model/test_bert.py index a7550413fac8..3588033d1ecd 100644 --- a/tests/test_fx/test_pipeline/test_hf_model/test_bert.py +++ b/tests/test_fx/test_pipeline/test_hf_model/test_bert.py @@ -7,7 +7,7 @@ SEQ_LENGHT = 16 -@pytest.mark.skip('balance split v2 is not ready') +@pytest.mark.skip("balance split v2 is not ready") def test_single_sentence_bert(): MODEL_LIST = [ transformers.BertModel, @@ -18,11 +18,9 @@ def test_single_sentence_bert(): transformers.BertForTokenClassification, ] - config = transformers.BertConfig(vocab_size=100, - hidden_size=128, - num_hidden_layers=4, - num_attention_heads=4, - intermediate_size=256) + config = transformers.BertConfig( + vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4, intermediate_size=256 + ) def data_gen(): input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) @@ -36,5 +34,5 @@ def data_gen(): split_model_and_compare_output(model, data_gen) -if __name__ == '__main__': +if __name__ == "__main__": test_single_sentence_bert() diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py b/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py index 6181c5c0706a..d2533aea4003 100644 --- a/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py +++ b/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py @@ -9,14 +9,14 @@ NUM_CHUNKS = 1 -@pytest.mark.skip('balance split v2 is not ready') +@pytest.mark.skip("balance split v2 is not ready") def test_gpt(): MODEL_LIST = [ transformers.GPT2Model, transformers.GPT2LMHeadModel, transformers.GPT2DoubleHeadsModel, transformers.GPT2ForTokenClassification, - # transformers.GPT2ForSequenceClassification, # not supported yet + # transformers.GPT2ForSequenceClassification, # not supported yet ] config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=8) @@ -32,5 +32,5 @@ def data_gen(): split_model_and_compare_output(model, data_gen) -if __name__ == '__main__': +if __name__ == "__main__": test_gpt() diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_opt.py b/tests/test_fx/test_pipeline/test_hf_model/test_opt.py index 1a9b36be82bd..e67628d10364 100644 --- a/tests/test_fx/test_pipeline/test_hf_model/test_opt.py +++ b/tests/test_fx/test_pipeline/test_hf_model/test_opt.py @@ -7,7 +7,7 @@ SEQ_LENGHT = 16 -@pytest.mark.skip('balance split v2 is not ready') +@pytest.mark.skip("balance split v2 is not ready") def test_opt(): MODEL_LIST = [ transformers.OPTModel, @@ -27,5 +27,5 @@ def data_gen(): split_model_and_compare_output(model, data_gen) -if __name__ == '__main__': +if __name__ == "__main__": test_opt() diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_t5.py b/tests/test_fx/test_pipeline/test_hf_model/test_t5.py index 16d0163746b3..dc36fdb13152 100644 --- a/tests/test_fx/test_pipeline/test_hf_model/test_t5.py +++ b/tests/test_fx/test_pipeline/test_hf_model/test_t5.py @@ -7,7 +7,7 @@ SEQ_LENGHT = 16 -@pytest.mark.skip('balance split v2 is not ready') +@pytest.mark.skip("balance split v2 is not ready") def test_t5(): MODEL_LIST = [ transformers.T5Model, @@ -39,5 +39,5 @@ def data_gen_for_encoder_only(): split_model_and_compare_output(model, data_gen_func) -if __name__ == '__main__': +if __name__ == "__main__": test_t5() diff --git a/tests/test_fx/test_pipeline/test_timm_model/test_timm.py b/tests/test_fx/test_pipeline/test_timm_model/test_timm.py index 6fb1f6f4bb23..c4fe5547ed8d 100644 --- a/tests/test_fx/test_pipeline/test_timm_model/test_timm.py +++ b/tests/test_fx/test_pipeline/test_timm_model/test_timm.py @@ 
-4,9 +4,8 @@ from timm_utils import split_model_and_compare_output -@pytest.mark.skip('balance split v2 is not ready') +@pytest.mark.skip("balance split v2 is not ready") def test_timm_models_without_control_flow(): - MODEL_LIST = [ tm.resnest.resnest50d, tm.beit.beit_base_patch16_224, @@ -25,24 +24,28 @@ def test_timm_models_without_control_flow(): split_model_and_compare_output(model, data) -@pytest.mark.skip('balance split v2 is not ready') +@pytest.mark.skip("balance split v2 is not ready") def test_timm_models_with_control_flow(): torch.backends.cudnn.deterministic = True MODEL_LIST_WITH_CONTROL_FLOW = [ - tm.convnext.convnext_base, tm.vgg.vgg11, tm.dpn.dpn68, tm.densenet.densenet121, tm.rexnet.rexnet_100, - tm.swin_transformer.swin_base_patch4_window7_224 + tm.convnext.convnext_base, + tm.vgg.vgg11, + tm.dpn.dpn68, + tm.densenet.densenet121, + tm.rexnet.rexnet_100, + tm.swin_transformer.swin_base_patch4_window7_224, ] data = torch.rand(2, 3, 224, 224) - meta_args = {'x': data.to('meta')} + meta_args = {"x": data.to("meta")} for model_cls in MODEL_LIST_WITH_CONTROL_FLOW: model = model_cls() split_model_and_compare_output(model, data, meta_args) -if __name__ == '__main__': +if __name__ == "__main__": test_timm_models_without_control_flow() test_timm_models_with_control_flow() diff --git a/tests/test_fx/test_pipeline/test_timm_model/timm_utils.py b/tests/test_fx/test_pipeline/test_timm_model/timm_utils.py index aa870e5c7a65..e1182c8d4978 100644 --- a/tests/test_fx/test_pipeline/test_timm_model/timm_utils.py +++ b/tests/test_fx/test_pipeline/test_timm_model/timm_utils.py @@ -1,11 +1,12 @@ -import torch -from torch.fx import symbolic_trace -from torch.fx import GraphModule -from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass -from colossalai.fx import ColoTracer import inspect import random + import numpy as np +import torch +from torch.fx import GraphModule + +from colossalai.fx import ColoTracer +from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass MANUAL_SEED = 0 random.seed(MANUAL_SEED) @@ -46,6 +47,6 @@ def split_model_and_compare_output(model, data, meta_args=None): output_part1 = model_part1(output_part0) else: if len(output_part0) > len(sig.parameters): - output_part0 = output_part0[:len(sig.parameters)] + output_part0 = output_part0[: len(sig.parameters)] output_part1 = model_part1(*output_part0) assert output.equal(output_part1) diff --git a/tests/test_fx/test_pipeline/test_topo/test_topo.py b/tests/test_fx/test_pipeline/test_topo/test_topo.py index 16da56250dc3..7c420ef2385a 100644 --- a/tests/test_fx/test_pipeline/test_topo/test_topo.py +++ b/tests/test_fx/test_pipeline/test_topo/test_topo.py @@ -7,7 +7,7 @@ SEQ_LENGHT = 16 -@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') +@pytest.mark.skip("ShapeProp is not compatible with PyTorch 1.11.0") def test_opt(): MODEL_LIST = [ MLP, @@ -15,10 +15,7 @@ def test_opt(): ] CONFIGS = [ - { - 'dim': 10, - 'layers': 12 - }, + {"dim": 10, "layers": 12}, transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4), ] @@ -45,5 +42,5 @@ def data_gen_OPT(): check_topo(top_mod, topo) -if __name__ == '__main__': +if __name__ == "__main__": test_opt() diff --git a/tests/test_fx/test_pipeline/test_topo/topo_utils.py b/tests/test_fx/test_pipeline/test_topo/topo_utils.py index db6cadfc544c..6a69181a6d26 100644 --- a/tests/test_fx/test_pipeline/test_topo/topo_utils.py +++ 
b/tests/test_fx/test_pipeline/test_topo/topo_utils.py @@ -6,7 +6,7 @@ from colossalai.fx import ColoTracer from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass -from colossalai.legacy.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo +from colossalai.legacy.pipeline.middleware import Partition, Topo from colossalai.legacy.pipeline.middleware.adaptor import get_fx_topology MANUAL_SEED = 0 @@ -16,11 +16,10 @@ class MLP(torch.nn.Module): - def __init__(self, config={}): super().__init__() - dim = config['dim'] - layers = config['layers'] + dim = config["dim"] + layers = config["layers"] self.layers = torch.nn.ModuleList() for _ in range(layers): @@ -41,7 +40,7 @@ def split_model_and_get_DAG(model, data_gen): # tracing model tracer = ColoTracer() try: - meta_args = {k: v.to('meta') for k, v in kwargs.items()} + meta_args = {k: v.to("meta") for k, v in kwargs.items()} graph = tracer.trace(root=model, meta_args=meta_args) except Exception as e: raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}") @@ -55,7 +54,7 @@ def split_model_and_get_DAG(model, data_gen): topo = get_fx_topology(top_module) for submodule in split_submodules: if isinstance(submodule, torch.fx.GraphModule): - setattr(submodule, '_topo', topo) + setattr(submodule, "_topo", topo) return top_module, split_submodules[0]._topo @@ -64,7 +63,7 @@ def check_input(top_module, input_partition: Partition): partition_output = input_partition.get_output_vals() arg_pos = 0 for node in top_module.graph.nodes: - if node.op == 'placeholder': + if node.op == "placeholder": cur_checkee = partition_output[arg_pos] to_partition_and_offset = cur_checkee.get() assert len(to_partition_and_offset) == len(node.users.keys()) @@ -80,7 +79,7 @@ def check_submod(top_module, part_id, mid_partition: Partition): cnt = 1 cur_node = None for node in top_module.graph.nodes: - if node.name.startswith('submod'): + if node.name.startswith("submod"): cnt += 1 if cnt == part_id: cur_node = node diff --git a/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py b/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py index 5d47be2c7bea..063e51309503 100644 --- a/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py +++ b/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py @@ -19,14 +19,21 @@ torch.backends.cudnn.deterministic = True -@pytest.mark.skip('balance split v2 is not ready') +@pytest.mark.skip("balance split v2 is not ready") def test_torchvision_models(): MODEL_LIST = [ - tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2, - tm.regnet_x_16gf, tm.efficientnet_b0, tm.mnasnet0_5 + tm.vgg11, + tm.resnet18, + tm.densenet121, + tm.mobilenet_v3_small, + tm.resnext50_32x4d, + tm.wide_resnet50_2, + tm.regnet_x_16gf, + tm.efficientnet_b0, + tm.mnasnet0_5, ] - if version.parse(torchvision.__version__) >= version.parse('0.12.0'): + if version.parse(torchvision.__version__) >= version.parse("0.12.0"): MODEL_LIST.extend([tm.vit_b_16, tm.convnext_small]) tracer = ColoTracer() @@ -57,10 +64,10 @@ def test_torchvision_models(): output_part1 = model_part1(output_part0) else: if len(output_part0) > len(sig.parameters): - output_part0 = output_part0[:len(sig.parameters)] + output_part0 = output_part0[: len(sig.parameters)] output_part1 = model_part1(*output_part0) assert output.equal(output_part1) -if __name__ == '__main__': +if __name__ == "__main__": test_torchvision_models() 
diff --git a/tests/test_fx/test_pipeline_passes.py b/tests/test_fx/test_pipeline_passes.py index 1078dac9db7c..7a5a397500bb 100644 --- a/tests/test_fx/test_pipeline_passes.py +++ b/tests/test_fx/test_pipeline_passes.py @@ -1,10 +1,6 @@ -import pytest import torch -import torch.nn as nn from torch.fx import symbolic_trace -import colossalai -import colossalai.nn as col_nn from colossalai.fx.passes.adding_split_node_pass import ( balanced_split_pass, balanced_split_pass_v2, @@ -19,7 +15,6 @@ class MLP(torch.nn.Module): - def __init__(self, dim: int): super().__init__() self.linear1 = torch.nn.Linear(dim, dim) @@ -53,5 +48,5 @@ def test_pipeline_passes(): pipeline_pass_test_helper(model, data, uniform_split_pass) -if __name__ == '__main__': +if __name__ == "__main__": test_pipeline_passes() diff --git a/tests/test_fx/test_profiler/gpt_utils.py b/tests/test_fx/test_profiler/gpt_utils.py index aec32268484f..9e4214876ba7 100644 --- a/tests/test_fx/test_profiler/gpt_utils.py +++ b/tests/test_fx/test_profiler/gpt_utils.py @@ -1,26 +1,29 @@ -import torch import torch.nn as nn from transformers import GPT2Config, GPT2LMHeadModel class GPTLMModel(nn.Module): - - def __init__(self, - hidden_size=768, - num_layers=12, - num_attention_heads=12, - max_seq_len=1024, - vocab_size=50257, - checkpoint=False): + def __init__( + self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50257, + checkpoint=False, + ): super().__init__() self.checkpoint = checkpoint self.model = GPT2LMHeadModel( - GPT2Config(n_embd=hidden_size, - n_layer=num_layers, - n_head=num_attention_heads, - n_positions=max_seq_len, - n_ctx=max_seq_len, - vocab_size=vocab_size)) + GPT2Config( + n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size, + ) + ) if checkpoint: self.model.gradient_checkpointing_enable() @@ -30,7 +33,6 @@ def forward(self, input_ids, attention_mask): class GPTLMLoss(nn.Module): - def __init__(self): super().__init__() self.loss_fn = nn.CrossEntropyLoss() diff --git a/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py b/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py index b5a6bbe8bf18..28409696ca55 100644 --- a/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py +++ b/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py @@ -1,9 +1,9 @@ -from typing import Optional, Tuple, Union +from typing import Tuple import torch import torch.fx import torchvision.models as tm -from gpt_utils import gpt2_medium, gpt2_xl +from gpt_utils import gpt2_medium from torch.fx import symbolic_trace from colossalai.fx.passes.meta_info_prop import MetaInfoProp @@ -33,18 +33,18 @@ def extract_forward_flops(gm: torch.fx.GraphModule): fwd_flop = 0 bwd_flop = 0 for node in gm.graph.nodes: - fwd_flop += node.meta.get('fwd_flop', 0) - bwd_flop += node.meta.get('bwd_flop', 0) + fwd_flop += node.meta.get("fwd_flop", 0) + bwd_flop += node.meta.get("bwd_flop", 0) return fwd_flop, bwd_flop -def gen_tm_data(batch_size: int, shape: Tuple[int, int, int], device='cuda'): +def gen_tm_data(batch_size: int, shape: Tuple[int, int, int], device="cuda"): data = torch.rand(batch_size, *shape, device=device) label = torch.empty(batch_size, dtype=torch.long, device=device).random_(1000) return data, label -def gen_gpt_data(batch_size, seq_len, vocab_size, device='cpu'): +def gen_gpt_data(batch_size, seq_len, vocab_size, device="cpu"): input_ids = torch.randint(0, vocab_size, (batch_size, 
seq_len), device=device) attention_mask = torch.ones_like(input_ids, device=device) return input_ids, attention_mask @@ -96,7 +96,7 @@ def run_gpt_forward(gm: torch.fx.GraphModule): param_mem += torch.cuda.memory_allocated(device="cuda:0") / 1024**2 for n in range(NUM_STEPS): torch.cuda.reset_peak_memory_stats() - data, mask = gen_gpt_data(GPT_BATCH_SIZE, 1024, 50257, device='cuda:0') + data, mask = gen_gpt_data(GPT_BATCH_SIZE, 1024, 50257, device="cuda:0") # If we need to dive deep into the memory usage by # inspecting `saved_tensor_hooks` @@ -125,21 +125,56 @@ def run_gpt_forward(gm: torch.fx.GraphModule): return forward_mem, param_mem -@run_on_environment_flag(name='FX_PROFILER') +@run_on_environment_flag(name="FX_PROFILER") @clear_cache_before_run() def test_meta_info_prop(): for m in [ - tm.alexnet, tm.resnet18, tm.resnet34, tm.resnet50, tm.resnet101, tm.resnet152, tm.densenet121, - tm.densenet161, tm.densenet169, tm.densenet201, tm.convnext_tiny, tm.convnext_small, tm.convnext_base, - tm.convnext_large, tm.wide_resnet50_2, tm.wide_resnet101_2, tm.regnet_x_16gf, tm.mnasnet0_5, - tm.efficientnet_b0, tm.shufflenet_v2_x0_5, tm.shufflenet_v2_x1_0, tm.shufflenet_v2_x1_5, - tm.shufflenet_v2_x2_0, tm.mobilenet_v2, tm.mobilenet_v3_small, tm.mobilenet_v3_large, tm.resnext50_32x4d, - tm.resnext101_32x8d, tm.resnext101_64x4d, tm.vit_b_16, tm.vit_b_32, tm.vit_h_14, tm.vit_l_16, tm.vit_l_32, - tm.vgg11, tm.vgg11_bn, tm.vgg13, tm.vgg13_bn, tm.vgg16, tm.vgg16_bn, tm.vgg19, tm.vgg19_bn + tm.alexnet, + tm.resnet18, + tm.resnet34, + tm.resnet50, + tm.resnet101, + tm.resnet152, + tm.densenet121, + tm.densenet161, + tm.densenet169, + tm.densenet201, + tm.convnext_tiny, + tm.convnext_small, + tm.convnext_base, + tm.convnext_large, + tm.wide_resnet50_2, + tm.wide_resnet101_2, + tm.regnet_x_16gf, + tm.mnasnet0_5, + tm.efficientnet_b0, + tm.shufflenet_v2_x0_5, + tm.shufflenet_v2_x1_0, + tm.shufflenet_v2_x1_5, + tm.shufflenet_v2_x2_0, + tm.mobilenet_v2, + tm.mobilenet_v3_small, + tm.mobilenet_v3_large, + tm.resnext50_32x4d, + tm.resnext101_32x8d, + tm.resnext101_64x4d, + tm.vit_b_16, + tm.vit_b_32, + tm.vit_h_14, + tm.vit_l_16, + tm.vit_l_32, + tm.vgg11, + tm.vgg11_bn, + tm.vgg13, + tm.vgg13_bn, + tm.vgg16, + tm.vgg16_bn, + tm.vgg19, + tm.vgg19_bn, ]: model = m().cuda() model.train() - data = MetaTensor(torch.rand(int(TM_BATCH_SIZE), 3, 224, 224, device='meta'), fake_device='cuda:0') + data = MetaTensor(torch.rand(int(TM_BATCH_SIZE), 3, 224, 224, device="meta"), fake_device="cuda:0") gm = symbolic_trace(model) interp = MetaInfoProp(gm) interp.propagate(data) @@ -150,22 +185,22 @@ def test_meta_info_prop(): concrete_forward_mem, concrete_param_mem = run_tm_forward(gm) print( - f'|{m.__name__}|{meta_forward_mem:.3f} MB|{meta_param_mem:.3f} MB|{concrete_forward_mem:.3f} MB|{concrete_param_mem:.3f} MB|fwd_flop={fwd_flop / 1e9:.3f}GFLOPs|bwd_flop={bwd_flop / 1e9:.3f}GFLOPs|' + f"|{m.__name__}|{meta_forward_mem:.3f} MB|{meta_param_mem:.3f} MB|{concrete_forward_mem:.3f} MB|{concrete_param_mem:.3f} MB|fwd_flop={fwd_flop / 1e9:.3f}GFLOPs|bwd_flop={bwd_flop / 1e9:.3f}GFLOPs|" ) del model, gm -@run_on_environment_flag(name='FX_PROFILER') +@run_on_environment_flag(name="FX_PROFILER") @clear_cache_before_run() def test_gpt_meta_info_prop(): for m in [gpt2_medium]: model = m().cuda() model.train() - data, mask = gen_gpt_data(GPT_BATCH_SIZE, 1024, 50257, device='meta') - graph = ColoTracer().trace(model, meta_args={'input_ids': data, 'attention_mask': mask}) + data, mask = gen_gpt_data(GPT_BATCH_SIZE, 1024, 50257, 
device="meta") + graph = ColoTracer().trace(model, meta_args={"input_ids": data, "attention_mask": mask}) gm = torch.fx.GraphModule(model, graph) interp = MetaInfoProp(gm) - interp.propagate(MetaTensor(data, fake_device='cuda:0'), MetaTensor(mask, fake_device='cuda:0')) + interp.propagate(MetaTensor(data, fake_device="cuda:0"), MetaTensor(mask, fake_device="cuda:0")) model.cpu() fwd_flop, bwd_flop = extract_forward_flops(gm) @@ -174,11 +209,11 @@ def test_gpt_meta_info_prop(): meta_forward_mem, meta_param_mem = extract_forward_mem(gm) print( - f'|{m.__name__}|{meta_forward_mem:.3f} MB|{meta_param_mem:.3f} MB|{concrete_forward_mem:.3f} MB|{concrete_param_mem:.3f} MB|fwd_flop={fwd_flop / 1e9:.3f}GFLOPs|bwd_flop={bwd_flop / 1e9:.3f}GFLOPs|' + f"|{m.__name__}|{meta_forward_mem:.3f} MB|{meta_param_mem:.3f} MB|{concrete_forward_mem:.3f} MB|{concrete_param_mem:.3f} MB|fwd_flop={fwd_flop / 1e9:.3f}GFLOPs|bwd_flop={bwd_flop / 1e9:.3f}GFLOPs|" ) del model, gm -if __name__ == '__main__': +if __name__ == "__main__": test_meta_info_prop() test_gpt_meta_info_prop() diff --git a/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py b/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py index 632ab8c09750..e7dcf07aafb4 100644 --- a/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py +++ b/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py @@ -1,5 +1,4 @@ import torch -import torch.nn as nn from torch.fx import GraphModule from torch.utils.checkpoint import checkpoint @@ -8,7 +7,6 @@ class MLP(torch.nn.Module): - def __init__(self): super().__init__() self.linear1 = torch.nn.Linear(4, 4) @@ -22,7 +20,6 @@ def forward(self, x): # Simple module for demonstration class MyModule(torch.nn.Module): - def __init__(self): super().__init__() self.mlp_1 = MLP() @@ -46,20 +43,20 @@ def test_activation_checkpoint_annotation(): gm = GraphModule(module, graph) for node in gm.graph.nodes: - if node.name in ['mlp_1_linear1', 'mlp_1_linear2']: - assert node.meta.get('activation_checkpoint', -1) == 0 + if node.name in ["mlp_1_linear1", "mlp_1_linear2"]: + assert node.meta.get("activation_checkpoint", -1) == 0 for node in gm.graph.nodes: - if node.name in ['mlp_2_linear1', 'mlp_2_linear2']: - assert node.meta.get('activation_checkpoint', -1) == 1 + if node.name in ["mlp_2_linear1", "mlp_2_linear2"]: + assert node.meta.get("activation_checkpoint", -1) == 1 tracer = ColoTracer(trace_act_ckpt=False) graph = tracer.trace(module) gm = GraphModule(module, graph) for node in gm.graph.nodes: - assert not hasattr(node, 'activation_checkpoint') + assert not hasattr(node, "activation_checkpoint") -if __name__ == '__main__': +if __name__ == "__main__": test_activation_checkpoint_annotation() diff --git a/tests/test_fx/test_tracer/test_bias_addition_module.py b/tests/test_fx/test_tracer/test_bias_addition_module.py index 2f88d8c784e8..e53894bdfd71 100644 --- a/tests/test_fx/test_tracer/test_bias_addition_module.py +++ b/tests/test_fx/test_tracer/test_bias_addition_module.py @@ -5,7 +5,6 @@ class LinearModel(torch.nn.Module): - def __init__(self, in_features, out_features): super().__init__() self.linear = torch.nn.Linear(in_features, out_features) @@ -18,13 +17,11 @@ def forward(self, x): class ConvModel(torch.nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, bias=True): super().__init__() - self.conv = torch.nn.Conv2d(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - bias=bias) + self.conv = torch.nn.Conv2d( + 
in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias + ) def forward(self, x): x = self.conv(x) @@ -45,7 +42,7 @@ def test_linear_module(): # %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {}) # %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {}) # return mul - graph = tracer.trace(root=model, meta_args={'x': torch.rand(3, 3).to('meta')}) + graph = tracer.trace(root=model, meta_args={"x": torch.rand(3, 3).to("meta")}) # def forward(self, x : torch.Tensor): # linear_weight = self.linear.weight # linear_bias = self.linear.bias @@ -57,9 +54,9 @@ def test_linear_module(): gm.recompile() node_list = list(graph.nodes) for node in node_list: - if node.op == 'output': + if node.op == "output": continue - assert hasattr(node, '_meta_data') + assert hasattr(node, "_meta_data") weight_node = node_list[1] bias_node = node_list[2] linear_node = node_list[3] @@ -83,7 +80,7 @@ def test_conv_module(): # %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {}) # %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {}) # return mul - graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 3, 64, 64).to('meta')}) + graph = tracer.trace(root=model, meta_args={"x": torch.rand(4, 3, 64, 64).to("meta")}) # def forward(self, x : torch.Tensor): # conv_weight = self.conv.weight # conv_bias = self.conv.bias @@ -97,9 +94,9 @@ def test_conv_module(): gm.recompile() node_list = list(graph.nodes) for node in node_list: - if node.op == 'output': + if node.op == "output": continue - assert hasattr(node, '_meta_data') + assert hasattr(node, "_meta_data") weight_node = node_list[1] bias_node = node_list[2] conv_node = node_list[3] @@ -112,6 +109,6 @@ def test_conv_module(): assert add_node._meta_data.shape == (4, 6, 63, 63) -if __name__ == '__main__': +if __name__ == "__main__": test_linear_module() test_conv_module() diff --git a/tests/test_fx/test_tracer/test_control_flow.py b/tests/test_fx/test_tracer/test_control_flow.py index 820729dadb3e..f0c261c39db5 100644 --- a/tests/test_fx/test_tracer/test_control_flow.py +++ b/tests/test_fx/test_tracer/test_control_flow.py @@ -7,7 +7,6 @@ class ControlFlowModel(nn.Module): - def __init__(self): super().__init__() self.linear1 = nn.Linear(10, 10) @@ -27,16 +26,12 @@ def forward(self, x, y): def test_control_flow(): model = ControlFlowModel() tracer = Tracer() - graph_branch_true = tracer.trace(model, - meta_args={ - 'x': torch.rand(4, 10, device='meta'), - 'y': torch.rand(4, 10, device='meta') - }) - graph_branch_false = tracer.trace(model, - meta_args={ - 'x': torch.rand(10, device='meta'), - 'y': torch.rand(4, 10, device='meta') - }) + graph_branch_true = tracer.trace( + model, meta_args={"x": torch.rand(4, 10, device="meta"), "y": torch.rand(4, 10, device="meta")} + ) + graph_branch_false = tracer.trace( + model, meta_args={"x": torch.rand(10, device="meta"), "y": torch.rand(4, 10, device="meta")} + ) gm_branch_true = GraphModule(model, graph_branch_true, model.__class__.__name__) gm_branch_false = GraphModule(model, graph_branch_false, model.__class__.__name__) @@ -56,5 +51,5 @@ def test_control_flow(): assert torch.all(gm_branch_false(x, y) != gm_branch_true(x, y)) -if __name__ == '__main__': +if __name__ == "__main__": test_control_flow() diff --git a/tests/test_fx/test_tracer/test_functional_conv.py b/tests/test_fx/test_tracer/test_functional_conv.py index a552e905223d..63f9721e2a65 
100644 --- a/tests/test_fx/test_tracer/test_functional_conv.py +++ b/tests/test_fx/test_tracer/test_functional_conv.py @@ -47,5 +47,5 @@ def test_conv(): assert out_transpose_3d.shape == patched_out_transpose_3d.shape -if __name__ == '__main__': +if __name__ == "__main__": test_conv() diff --git a/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py b/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py index e6f8df2e0af7..4828bb0302c8 100644 --- a/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py +++ b/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py @@ -1,9 +1,6 @@ from typing import List import torch -from numpy import isin -from torch.fx import GraphModule -from torch.utils._pytree import tree_flatten # from colossalai.fx import symbolic_trace from colossalai._analyzer.fx import symbolic_trace @@ -20,7 +17,7 @@ def trace_model_and_compare_output(model, data_gen, ignore_data: List[str] = Non inputs = {k: v for k, v in inputs.items() if k not in ignore_data} try: - meta_args = {k: v.to('meta') for k, v in inputs.items()} + meta_args = {k: v.to("meta") for k, v in inputs.items()} gm = symbolic_trace(model, meta_args=meta_args) except Exception as e: @@ -35,4 +32,4 @@ def trace_model_and_compare_output(model, data_gen, ignore_data: List[str] = Non if torch.is_tensor(fx_out[k]): assert torch.equal( fx_out[k], non_fx_out[k] - ), f'{model.__class__.__name__} has incorrect output {k}, expect {non_fx_out[k]}, but got {fx_out[k]}' + ), f"{model.__class__.__name__} has incorrect output {k}, expect {non_fx_out[k]}, but got {fx_out[k]}" diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py index a1470400ad82..fb093821e488 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py @@ -10,15 +10,15 @@ SEQ_LENGTH = 16 -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() def test_albert(): - sub_registry = model_zoo.get_sub_registry('transformers_albert') + sub_registry = model_zoo.get_sub_registry("transformers_albert") for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): model = model_fn() trace_model_and_compare_output(model, data_gen_fn) -if __name__ == '__main__': +if __name__ == "__main__": test_albert() diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py index 7773de480302..91f7b9764e6e 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py @@ -7,17 +7,17 @@ from tests.kit.model_zoo import model_zoo -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() def test_bert(): - sub_registry = model_zoo.get_sub_registry('transformers_bert') + sub_registry = model_zoo.get_sub_registry("transformers_bert") for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): model = model_fn() if model.__class__.__name__ == "BertForQuestionAnswering": continue - trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels', 'next_sentence_label']) + 
trace_model_and_compare_output(model, data_gen_fn, ignore_data=["labels", "next_sentence_label"]) -if __name__ == '__main__': +if __name__ == "__main__": test_bert() diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py index ac87a7fcb13b..95a464fa0534 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py @@ -22,7 +22,7 @@ def trace_and_compare(model_cls, data, output_fn): model.eval() concrete_args = {k: v for k, v in data.items() if not torch.is_tensor(v)} - meta_args = {k: v.to('meta') for k, v in data.items() if torch.is_tensor(v)} + meta_args = {k: v.to("meta") for k, v in data.items() if torch.is_tensor(v)} gm = symbolic_trace(model, concrete_args=concrete_args, meta_args=meta_args) # run forward @@ -40,12 +40,12 @@ def assert_fn(ta, tb): assert_dict(transformed_fx_out, transformed_non_fx_out, assert_fn) -@pytest.mark.skip(reason='cannot pass this test yet') +@pytest.mark.skip(reason="cannot pass this test yet") @clear_cache_before_run() def test_diffusers(): seed_all(9091, cuda_deterministic=True) - sub_model_zoo = model_zoo.get_sub_registry('diffusers') + sub_model_zoo = model_zoo.get_sub_registry("diffusers") for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items(): data = data_gen_fn() @@ -58,12 +58,12 @@ def test_diffusers(): def test_torch_diffusers(): seed_all(65535, cuda_deterministic=True) - sub_model_zoo = model_zoo.get_sub_registry('diffusers') + sub_model_zoo = model_zoo.get_sub_registry("diffusers") for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items(): data = data_gen_fn() model = model_fn() - output = model(**data) + model(**data) torch.cuda.synchronize() print(f"{name:40s} √") diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py index 1cd3b90db917..7bd8a726f1ac 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py @@ -7,10 +7,10 @@ from tests.kit.model_zoo import model_zoo -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() def test_gpt(): - sub_registry = model_zoo.get_sub_registry('transformers_gpt') + sub_registry = model_zoo.get_sub_registry("transformers_gpt") for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): model = model_fn() @@ -18,11 +18,11 @@ def test_gpt(): # TODO(ver217): support the following models # 1. 
GPT2DoubleHeadsModel # as they are not supported, let's skip them - if model.__class__.__name__ in ['GPT2DoubleHeadsModel', 'GPT2ForQuestionAnswering']: + if model.__class__.__name__ in ["GPT2DoubleHeadsModel", "GPT2ForQuestionAnswering"]: continue - trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels']) + trace_model_and_compare_output(model, data_gen_fn, ignore_data=["labels"]) -if __name__ == '__main__': +if __name__ == "__main__": test_gpt() diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py index c68b89e82fbe..5f7525d5707b 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py @@ -7,14 +7,14 @@ from tests.kit.model_zoo import model_zoo -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() def test_opt(): - sub_registry = model_zoo.get_sub_registry('transformers_opt') + sub_registry = model_zoo.get_sub_registry("transformers_opt") for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): model = model_fn() - trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels', 'start_positions', 'end_positions']) + trace_model_and_compare_output(model, data_gen_fn, ignore_data=["labels", "start_positions", "end_positions"]) -if __name__ == '__main__': +if __name__ == "__main__": test_opt() diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py index 45e06bc2bbb0..6ccbb14e3d96 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py @@ -7,20 +7,20 @@ from tests.kit.model_zoo import model_zoo -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() def test_t5(): - sub_registry = model_zoo.get_sub_registry('transformers_t5') + sub_registry = model_zoo.get_sub_registry("transformers_t5") for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): if name == "transformers_t5_for_conditional_generation": # cannot trace for loss function yet # so we use a data gen which does not produce labels - data_gen_fn = sub_registry.get('transformers_t5')[1] + data_gen_fn = sub_registry.get("transformers_t5")[1] model = model_fn() - trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels']) + trace_model_and_compare_output(model, data_gen_fn, ignore_data=["labels"]) -if __name__ == '__main__': +if __name__ == "__main__": test_t5() diff --git a/tests/test_fx/test_tracer/test_patched_module.py b/tests/test_fx/test_tracer/test_patched_module.py index ef778e21801a..fe66cbd0ffcc 100644 --- a/tests/test_fx/test_tracer/test_patched_module.py +++ b/tests/test_fx/test_tracer/test_patched_module.py @@ -36,12 +36,12 @@ def _assert_output_shape(data, module, patch_fn, expect_exception, output_shape) @clear_cache_before_run() def test_linear(): # test linear patch can produce the meta output with correct shape - data = torch.rand(2, 4, device='meta') + data = torch.rand(2, 4, device="meta") module = torch.nn.Linear(4, 2) _assert_output_shape(data, module, patched_module.torch_nn_linear, False, 
torch.Size([2, 2])) # test if the linear patch can catch exception when dimension does not match - data = torch.rand(2, 2, device='meta') + data = torch.rand(2, 2, device="meta") _assert_output_shape(data, module, patched_module.torch_nn_linear, True, None) @@ -51,20 +51,20 @@ def test_rnn(): data = (torch.randn(5, 3, 10), torch.randn(2, 3, 20)) module = torch.nn.RNN(10, 20, 2) output, hn = module(*data) - meta_data = (torch.randn(5, 3, 10).to('meta'), torch.randn(2, 3, 20).to('meta')) + meta_data = (torch.randn(5, 3, 10).to("meta"), torch.randn(2, 3, 20).to("meta")) _assert_output_shape(meta_data, module, patched_module.torch_nn_rnn, False, (output.shape, hn.shape)) # test if the rnn patch can catch exception when dimension does not match data = (torch.randn(5, 3, 10), torch.randn(2, 3, 20)) module = torch.nn.RNN(10, 20, 2) output, hn = module(*data) - meta_data = (torch.randn(5, 3, 1).to('meta'), torch.randn(2, 3, 20).to('meta')) + meta_data = (torch.randn(5, 3, 1).to("meta"), torch.randn(2, 3, 20).to("meta")) _assert_output_shape(meta_data, module, patched_module.torch_nn_rnn, True, None) @clear_cache_before_run() def test_embedding(): - data = torch.rand(2, 4, device='meta') + data = torch.rand(2, 4, device="meta") # test layernorm ln = torch.nn.LayerNorm(4) @@ -76,67 +76,71 @@ def test_embedding(): # test batch norm 1d bn1d = torch.nn.BatchNorm1d(4) - data = torch.rand(2, 4, device='meta') - _assert_output_shape(data=data, - module=bn1d, - patch_fn=patched_module.torch_nn_normalize, - expect_exception=False, - output_shape=data.shape) - - data = torch.rand(2, 4, device='meta') - _assert_output_shape(data=data, - module=bn1d, - patch_fn=patched_module.torch_nn_normalize, - expect_exception=False, - output_shape=data.shape) - - data = torch.rand(2, 3, 4, device='meta') - _assert_output_shape(data=data, - module=bn1d, - patch_fn=patched_module.torch_nn_normalize, - expect_exception=False, - output_shape=data.shape) - - data = torch.rand(1, 2, 3, 4, device='meta') - _assert_output_shape(data=data, - module=bn1d, - patch_fn=patched_module.torch_nn_normalize, - expect_exception=True, - output_shape=None) + data = torch.rand(2, 4, device="meta") + _assert_output_shape( + data=data, + module=bn1d, + patch_fn=patched_module.torch_nn_normalize, + expect_exception=False, + output_shape=data.shape, + ) + + data = torch.rand(2, 4, device="meta") + _assert_output_shape( + data=data, + module=bn1d, + patch_fn=patched_module.torch_nn_normalize, + expect_exception=False, + output_shape=data.shape, + ) + + data = torch.rand(2, 3, 4, device="meta") + _assert_output_shape( + data=data, + module=bn1d, + patch_fn=patched_module.torch_nn_normalize, + expect_exception=False, + output_shape=data.shape, + ) + + data = torch.rand(1, 2, 3, 4, device="meta") + _assert_output_shape( + data=data, module=bn1d, patch_fn=patched_module.torch_nn_normalize, expect_exception=True, output_shape=None + ) # test batch norm 2d bn2d = torch.nn.BatchNorm2d(4) - data = torch.rand(1, 2, 3, 4, device='meta') - _assert_output_shape(data=data, - module=bn2d, - patch_fn=patched_module.torch_nn_normalize, - expect_exception=False, - output_shape=data.shape) + data = torch.rand(1, 2, 3, 4, device="meta") + _assert_output_shape( + data=data, + module=bn2d, + patch_fn=patched_module.torch_nn_normalize, + expect_exception=False, + output_shape=data.shape, + ) - data = torch.rand(2, 3, 4, device='meta') - _assert_output_shape(data=data, - module=bn2d, - patch_fn=patched_module.torch_nn_normalize, - expect_exception=True, - 
output_shape=None) + data = torch.rand(2, 3, 4, device="meta") + _assert_output_shape( + data=data, module=bn2d, patch_fn=patched_module.torch_nn_normalize, expect_exception=True, output_shape=None + ) # # test batch size 3d bn3d = torch.nn.BatchNorm3d(4) - data = torch.rand(1, 1, 2, 3, 4, device='meta') - _assert_output_shape(data=data, - module=bn3d, - patch_fn=patched_module.torch_nn_normalize, - expect_exception=False, - output_shape=data.shape) + data = torch.rand(1, 1, 2, 3, 4, device="meta") + _assert_output_shape( + data=data, + module=bn3d, + patch_fn=patched_module.torch_nn_normalize, + expect_exception=False, + output_shape=data.shape, + ) - data = torch.rand(1, 2, 3, 4, device='meta') - _assert_output_shape(data=data, - module=bn3d, - patch_fn=patched_module.torch_nn_normalize, - expect_exception=True, - output_shape=None) + data = torch.rand(1, 2, 3, 4, device="meta") + _assert_output_shape( + data=data, module=bn3d, patch_fn=patched_module.torch_nn_normalize, expect_exception=True, output_shape=None + ) @clear_cache_before_run() @@ -146,35 +150,38 @@ def test_conv1d(): conv1d = torch.nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2) materialized_output = conv1d(data) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - module=conv1d, - patch_fn=patched_module.torch_nn_conv1d, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, + module=conv1d, + patch_fn=patched_module.torch_nn_conv1d, + expect_exception=False, + output_shape=materialized_output.shape, + ) conv1d = torch.nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2, padding=1) materialized_output = conv1d(data) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - module=conv1d, - patch_fn=patched_module.torch_nn_conv1d, - expect_exception=False, - output_shape=materialized_output.shape) - - conv1d = torch.nn.Conv1d(in_channels=3, - out_channels=4, - kernel_size=2, - padding=1, - dilation=2, - padding_mode='reflect') + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, + module=conv1d, + patch_fn=patched_module.torch_nn_conv1d, + expect_exception=False, + output_shape=materialized_output.shape, + ) + + conv1d = torch.nn.Conv1d( + in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2, padding_mode="reflect" + ) materialized_output = conv1d(data) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - module=conv1d, - patch_fn=patched_module.torch_nn_conv1d, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, + module=conv1d, + patch_fn=patched_module.torch_nn_conv1d, + expect_exception=False, + output_shape=materialized_output.shape, + ) def test_conv2d(): @@ -182,40 +189,45 @@ def test_conv2d(): data = torch.rand(2, 3, 4, 4) conv2d = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2) materialized_output = conv2d(data) - _assert_output_shape(data=data, - module=conv2d, - patch_fn=patched_module.torch_nn_conv2d, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=conv2d, + patch_fn=patched_module.torch_nn_conv2d, + expect_exception=False, + output_shape=materialized_output.shape, + ) conv2d = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding=1) materialized_output = conv2d(data) - _assert_output_shape(data=data, - module=conv2d, - 
patch_fn=patched_module.torch_nn_conv2d, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=conv2d, + patch_fn=patched_module.torch_nn_conv2d, + expect_exception=False, + output_shape=materialized_output.shape, + ) conv2d = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2) materialized_output = conv2d(data) - _assert_output_shape(data=data, - module=conv2d, - patch_fn=patched_module.torch_nn_conv2d, - expect_exception=False, - output_shape=materialized_output.shape) - - conv2d = torch.nn.Conv2d(in_channels=3, - out_channels=4, - kernel_size=2, - padding=1, - dilation=2, - padding_mode='reflect') + _assert_output_shape( + data=data, + module=conv2d, + patch_fn=patched_module.torch_nn_conv2d, + expect_exception=False, + output_shape=materialized_output.shape, + ) + + conv2d = torch.nn.Conv2d( + in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2, padding_mode="reflect" + ) materialized_output = conv2d(data) - _assert_output_shape(data=data, - module=conv2d, - patch_fn=patched_module.torch_nn_conv2d, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=conv2d, + patch_fn=patched_module.torch_nn_conv2d, + expect_exception=False, + output_shape=materialized_output.shape, + ) @clear_cache_before_run() @@ -224,40 +236,45 @@ def test_conv3d(): data = torch.rand(2, 3, 4, 4, 4) conv3d = torch.nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2) materialized_output = conv3d(data) - _assert_output_shape(data=data, - module=conv3d, - patch_fn=patched_module.torch_nn_conv3d, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=conv3d, + patch_fn=patched_module.torch_nn_conv3d, + expect_exception=False, + output_shape=materialized_output.shape, + ) conv3d = torch.nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, padding=1) materialized_output = conv3d(data) - _assert_output_shape(data=data, - module=conv3d, - patch_fn=patched_module.torch_nn_conv3d, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=conv3d, + patch_fn=patched_module.torch_nn_conv3d, + expect_exception=False, + output_shape=materialized_output.shape, + ) conv3d = torch.nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2) materialized_output = conv3d(data) - _assert_output_shape(data=data, - module=conv3d, - patch_fn=patched_module.torch_nn_conv3d, - expect_exception=False, - output_shape=materialized_output.shape) - - conv3d = torch.nn.Conv3d(in_channels=3, - out_channels=4, - kernel_size=2, - padding=1, - dilation=2, - padding_mode='reflect') + _assert_output_shape( + data=data, + module=conv3d, + patch_fn=patched_module.torch_nn_conv3d, + expect_exception=False, + output_shape=materialized_output.shape, + ) + + conv3d = torch.nn.Conv3d( + in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2, padding_mode="reflect" + ) materialized_output = conv3d(data) - _assert_output_shape(data=data, - module=conv3d, - patch_fn=patched_module.torch_nn_conv3d, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=conv3d, + patch_fn=patched_module.torch_nn_conv3d, + expect_exception=False, + output_shape=materialized_output.shape, + ) @clear_cache_before_run() @@ -267,21 +284,25 @@ def test_conv_transpose1d(): convtrans1d = 
torch.nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2) materialized_output = convtrans1d(data) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - module=convtrans1d, - patch_fn=patched_module.torch_nn_convtranspose1d, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, + module=convtrans1d, + patch_fn=patched_module.torch_nn_convtranspose1d, + expect_exception=False, + output_shape=materialized_output.shape, + ) convtrans1d = torch.nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2, padding=1) materialized_output = convtrans1d(data) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - module=convtrans1d, - patch_fn=patched_module.torch_nn_convtranspose1d, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, + module=convtrans1d, + patch_fn=patched_module.torch_nn_convtranspose1d, + expect_exception=False, + output_shape=materialized_output.shape, + ) @clear_cache_before_run() @@ -291,21 +312,25 @@ def test_conv_transpose2d(): convtrans2d = torch.nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2) materialized_output = convtrans2d(data) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - module=convtrans2d, - patch_fn=patched_module.torch_nn_convtranspose2d, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, + module=convtrans2d, + patch_fn=patched_module.torch_nn_convtranspose2d, + expect_exception=False, + output_shape=materialized_output.shape, + ) convtrans2d = torch.nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=1) materialized_output = convtrans2d(data) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - module=convtrans2d, - patch_fn=patched_module.torch_nn_convtranspose2d, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, + module=convtrans2d, + patch_fn=patched_module.torch_nn_convtranspose2d, + expect_exception=False, + output_shape=materialized_output.shape, + ) @clear_cache_before_run() @@ -315,46 +340,56 @@ def test_conv_transpose3d(): convtrans3d = torch.nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2) materialized_output = convtrans3d(data) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - module=convtrans3d, - patch_fn=patched_module.torch_nn_convtranspose3d, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, + module=convtrans3d, + patch_fn=patched_module.torch_nn_convtranspose3d, + expect_exception=False, + output_shape=materialized_output.shape, + ) convtrans3d = torch.nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, padding=1) materialized_output = convtrans3d(data) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - module=convtrans3d, - patch_fn=patched_module.torch_nn_convtranspose3d, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, + module=convtrans3d, + patch_fn=patched_module.torch_nn_convtranspose3d, + expect_exception=False, + output_shape=materialized_output.shape, + ) @clear_cache_before_run() def test_pool1d(): - 
combinations = [[torch.nn.MaxPool1d, patched_module.torch_nn_maxpool1d], - [torch.nn.AvgPool1d, patched_module.torch_nn_avgpool1d]] + combinations = [ + [torch.nn.MaxPool1d, patched_module.torch_nn_maxpool1d], + [torch.nn.AvgPool1d, patched_module.torch_nn_avgpool1d], + ] - for (layer_cls, patch_func) in combinations: + for layer_cls, patch_func in combinations: pooler = layer_cls(kernel_size=3) data = torch.rand(2, 3, 4) materialized_output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=materialized_output.shape, + ) data = torch.rand(2, 4) materialized_output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=materialized_output.shape, + ) data = torch.rand(2, 3, 4, 4) _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None) @@ -362,29 +397,35 @@ def test_pool1d(): @clear_cache_before_run() def test_pool2d(): - combinations = [[torch.nn.MaxPool2d, patched_module.torch_nn_maxpool2d], - [torch.nn.AvgPool2d, patched_module.torch_nn_avgpool2d]] + combinations = [ + [torch.nn.MaxPool2d, patched_module.torch_nn_maxpool2d], + [torch.nn.AvgPool2d, patched_module.torch_nn_avgpool2d], + ] - for (layer_cls, patch_func) in combinations: + for layer_cls, patch_func in combinations: pooler = layer_cls(kernel_size=3) # test max pool 3d data = torch.rand(2, 3, 4, 4) materialized_output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=materialized_output.shape, + ) # test max pool 3d data = torch.rand(2, 4, 4) materialized_output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=materialized_output.shape, + ) # test max pool 3d data = torch.rand(2, 3, 4, 4, 4) @@ -393,29 +434,35 @@ def test_pool2d(): @clear_cache_before_run() def test_pool3d(): - combinations = [[torch.nn.MaxPool3d, patched_module.torch_nn_maxpool3d], - [torch.nn.AvgPool3d, patched_module.torch_nn_avgpool3d]] + combinations = [ + [torch.nn.MaxPool3d, patched_module.torch_nn_maxpool3d], + [torch.nn.AvgPool3d, patched_module.torch_nn_avgpool3d], + ] - for (layer_cls, patch_func) in combinations: + for layer_cls, patch_func in combinations: pooler = layer_cls(kernel_size=3) # test max pool 3d data = torch.rand(2, 3, 4, 4, 4) materialized_output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=materialized_output.shape, + ) # test max pool 3d data = torch.rand(2, 4, 4, 4) materialized_output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - 
expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=materialized_output.shape, + ) # test max pool 3d data = torch.rand(2, 3, 4) @@ -430,19 +477,15 @@ def test_adaptive_pooling_1d(): data = torch.rand(3, 4) output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=output.shape) + _assert_output_shape( + data=data, module=pooler, patch_fn=patch_func, expect_exception=False, output_shape=output.shape + ) data = torch.rand(2, 3, 4) output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=output.shape) + _assert_output_shape( + data=data, module=pooler, patch_fn=patch_func, expect_exception=False, output_shape=output.shape + ) data = torch.rand(2, 3, 4, 5) _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None) @@ -458,19 +501,15 @@ def test_adaptive_pooling_2d(): data = torch.rand(2, 3, 4) output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=output.shape) + _assert_output_shape( + data=data, module=pooler, patch_fn=patch_func, expect_exception=False, output_shape=output.shape + ) data = torch.rand(2, 3, 4, 5) output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=output.shape) + _assert_output_shape( + data=data, module=pooler, patch_fn=patch_func, expect_exception=False, output_shape=output.shape + ) @clear_cache_before_run() @@ -483,16 +522,12 @@ def test_adaptive_pooling_3d(): data = torch.rand(2, 3, 4, 5) output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=output.shape) + _assert_output_shape( + data=data, module=pooler, patch_fn=patch_func, expect_exception=False, output_shape=output.shape + ) data = torch.rand(2, 3, 4, 5, 6) output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=output.shape) + _assert_output_shape( + data=data, module=pooler, patch_fn=patch_func, expect_exception=False, output_shape=output.shape + ) diff --git a/tests/test_fx/test_tracer/test_patched_op.py b/tests/test_fx/test_tracer/test_patched_op.py index e0c5f560c49e..37c2333c0982 100644 --- a/tests/test_fx/test_tracer/test_patched_op.py +++ b/tests/test_fx/test_tracer/test_patched_op.py @@ -33,38 +33,34 @@ def test_repeat_interleave(): data = torch.tensor([1, 2, 3]) materialized_output = torch.repeat_interleave(data, repeats=2) repeat_interleave = partial(patch_fn, repeats=2) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - patch_fn=repeat_interleave, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, patch_fn=repeat_interleave, expect_exception=False, output_shape=materialized_output.shape + ) data = torch.tensor([[1, 2], [3, 4]]) materialized_output = torch.repeat_interleave(data, repeats=3, dim=1) repeat_interleave = partial(patch_fn, repeats=3, dim=1) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - patch_fn=repeat_interleave, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = 
data.to("meta") + _assert_output_shape( + data=meta_data, patch_fn=repeat_interleave, expect_exception=False, output_shape=materialized_output.shape + ) data = torch.tensor([[1, 2], [3, 4]]) materialized_output = torch.repeat_interleave(data, repeats=torch.tensor([1, 2]), dim=-1) repeat_interleave = partial(patch_fn, repeats=torch.tensor([1, 2]), dim=-1) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - patch_fn=repeat_interleave, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, patch_fn=repeat_interleave, expect_exception=False, output_shape=materialized_output.shape + ) data = torch.tensor([[1, 2], [3, 4]]) materialized_output = torch.repeat_interleave(data, repeats=torch.tensor([1, 2]), dim=0) repeat_interleave = partial(patch_fn, repeats=[1, 2], dim=0) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - patch_fn=repeat_interleave, - expect_exception=True, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, patch_fn=repeat_interleave, expect_exception=True, output_shape=materialized_output.shape + ) @clear_cache_before_run() diff --git a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py index 98433b8f7c3b..2b3f3e039baf 100644 --- a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py +++ b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py @@ -20,7 +20,7 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None): # 1. ConViT # 2. NormFreeNet # as they are not supported, let's skip them - if model.__class__.__name__ in ['ConViT', 'NormFreeNet']: + if model.__class__.__name__ in ["ConViT", "NormFreeNet"]: return gm = symbolic_trace(model, meta_args=meta_args) @@ -39,8 +39,9 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None): for key in transformed_fx_out.keys(): fx_output_val = transformed_fx_out[key] non_fx_output_val = transformed_non_fx_out[key] - assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \ - f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}' + assert torch.allclose( + fx_output_val, non_fx_output_val, atol=1e-5 + ), f"{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}" # FIXME(ver217): timm/models/convit.py:71: in forward @@ -49,22 +50,22 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None): # return self.tracer.to_bool(self) # torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow @pytest.mark.skip("convit is not supported yet") -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() def test_timm_models(): torch.backends.cudnn.deterministic = True - sub_model_zoo = model_zoo.get_sub_registry('timm') + sub_model_zoo = model_zoo.get_sub_registry("timm") for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items(): data = data_gen_fn() if attribute is not None and attribute.has_control_flow: - meta_args = {k: v.to('meta') for k, v in data.items()} + meta_args = {k: v.to("meta") for k, v in data.items()} else: meta_args = None trace_and_compare(model_fn, data, 
output_transform_fn, meta_args) -if __name__ == '__main__': +if __name__ == "__main__": test_timm_models() diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py index 2b7def5bef85..dd94a2546955 100644 --- a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py +++ b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py @@ -1,6 +1,5 @@ import pytest import torch -from packaging import version from torchaudio_utils import trace_and_compare from colossalai.testing import clear_cache_before_run @@ -14,11 +13,10 @@ def test_torchaudio_models(): torch.backends.cudnn.deterministic = True - sub_model_zoo = model_zoo.get_sub_registry('torchaudio') + sub_model_zoo = model_zoo.get_sub_registry("torchaudio") for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items(): model = model_fn() - trace_and_compare(model, - data_gen_fn, - output_transform_fn, - need_meta=(attribute is not None and attribute.has_control_flow)) + trace_and_compare( + model, data_gen_fn, output_transform_fn, need_meta=(attribute is not None and attribute.has_control_flow) + ) diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py b/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py index 239f38680cec..2379372bc3f9 100644 --- a/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py +++ b/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py @@ -6,7 +6,7 @@ def trace_and_compare(model, data_gen, output_transform_fn, need_meta=False, need_concrete=False): data = data_gen() concrete_args = data if need_concrete else {} - meta_args = {k: v.to('meta') for k, v in data.items()} if need_meta else {} + meta_args = {k: v.to("meta") for k, v in data.items()} if need_meta else {} model.eval() @@ -24,5 +24,6 @@ def trace_and_compare(model, data_gen, output_transform_fn, need_meta=False, nee for key, fx_output_val in transformed_fx_out.items(): non_fx_output_val = transformed_non_fx_out[key] - assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \ - f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}' + assert torch.allclose( + fx_output_val, non_fx_output_val, atol=1e-5 + ), f"{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}" diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py index f969c8e6c3da..30c1910855e6 100644 --- a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py +++ b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py @@ -1,4 +1,3 @@ -import pytest import torch from colossalai._analyzer.fx import symbolic_trace @@ -32,31 +31,34 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None): assert len(transformed_fx_out) == len(transformed_non_fx_out) if torch.is_tensor(fx_out): assert torch.allclose( - fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + fx_out, non_fx_out + ), f"{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}" else: assert torch.allclose( - fx_out.values(), - non_fx_out.values()), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + fx_out.values(), non_fx_out.values() + ), f"{model.__class__.__name__} has 
inconsistent outputs, {fx_out} vs {non_fx_out}" for key in transformed_fx_out.keys(): fx_output_val = transformed_fx_out[key] non_fx_output_val = transformed_non_fx_out[key] if torch.is_tensor(fx_output_val): - assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \ - f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}' + assert torch.allclose( + fx_output_val, non_fx_output_val, atol=1e-5 + ), f"{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}" else: - assert torch.allclose(fx_output_val.values(), non_fx_output_val.values() - ), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + assert torch.allclose( + fx_output_val.values(), non_fx_output_val.values() + ), f"{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}" @clear_cache_before_run() def test_torchrec_deepfm_models(): - deepfm_models = model_zoo.get_sub_registry('deepfm') + deepfm_models = model_zoo.get_sub_registry("deepfm") torch.backends.cudnn.deterministic = True for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in deepfm_models.items(): data = data_gen_fn() if attribute is not None and attribute.has_control_flow: - meta_args = {k: v.to('meta') for k, v in data.items()} + meta_args = {k: v.to("meta") for k, v in data.items()} else: meta_args = None diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py index 94fb24f33376..71b73236474f 100644 --- a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py +++ b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py @@ -1,4 +1,3 @@ -import pytest import torch from colossalai._analyzer.fx import symbolic_trace @@ -32,37 +31,40 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None): assert len(transformed_fx_out) == len(transformed_non_fx_out) if torch.is_tensor(fx_out): assert torch.allclose( - fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + fx_out, non_fx_out + ), f"{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}" else: assert torch.allclose( - fx_out.values(), - non_fx_out.values()), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + fx_out.values(), non_fx_out.values() + ), f"{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}" for key in transformed_fx_out.keys(): fx_output_val = transformed_fx_out[key] non_fx_output_val = transformed_non_fx_out[key] if torch.is_tensor(fx_output_val): - assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \ - f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}' + assert torch.allclose( + fx_output_val, non_fx_output_val, atol=1e-5 + ), f"{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}" else: - assert torch.allclose(fx_output_val.values(), non_fx_output_val.values() - ), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + assert torch.allclose( + fx_output_val.values(), non_fx_output_val.values() + ), f"{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}" @clear_cache_before_run() def test_torchrec_dlrm_models(): torch.backends.cudnn.deterministic = True - dlrm_models = model_zoo.get_sub_registry('dlrm') + dlrm_models = 
model_zoo.get_sub_registry("dlrm") for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in dlrm_models.items(): data = data_gen_fn() # dlrm_interactionarch is not supported # TODO(FrankLeeeee): support this model - if name == 'dlrm_interactionarch': + if name == "dlrm_interactionarch": continue if attribute is not None and attribute.has_control_flow: - meta_args = {k: v.to('meta') for k, v in data.items()} + meta_args = {k: v.to("meta") for k, v in data.items()} else: meta_args = None diff --git a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py index 74cb753e2937..47c6b1186c8e 100644 --- a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py +++ b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py @@ -8,7 +8,7 @@ @clear_cache_before_run() def test_torchvision_models(): torch.backends.cudnn.deterministic = True - tv_sub_registry = model_zoo.get_sub_registry('torchvision') + tv_sub_registry = model_zoo.get_sub_registry("torchvision") for name, (model_fn, data_gen_fn, output_transform_fn, _, model_attribute) in tv_sub_registry.items(): data = data_gen_fn() @@ -36,11 +36,11 @@ def test_torchvision_models(): fx_val = transformed_out[key] non_fx_val = transformed_non_fx_out[key] assert torch.allclose( - fx_val, - non_fx_val), f'{model.__class__.__name__} has inconsistent outputs, {fx_val} vs {non_fx_val}' + fx_val, non_fx_val + ), f"{model.__class__.__name__} has inconsistent outputs, {fx_val} vs {non_fx_val}" except Exception as e: print(name, e) -if __name__ == '__main__': +if __name__ == "__main__": test_torchvision_models() diff --git a/tests/test_infer/_utils.py b/tests/test_infer/_utils.py index 3d56cc3484a6..2ddc8b6e68e4 100644 --- a/tests/test_infer/_utils.py +++ b/tests/test_infer/_utils.py @@ -1,20 +1,6 @@ import copy -import torch -import torch.distributed as dist -from torch import Tensor -from torch import distributed as dist -from torch.distributed import ProcessGroup -from torch.nn import Module -from torch.optim import Adam, Optimizer - -from colossalai.booster import Booster -from colossalai.booster.plugin import HybridParallelPlugin -from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule from colossalai.shardformer import ShardConfig, ShardFormer -from colossalai.shardformer._utils import getattr_ -from colossalai.shardformer.policies.auto_policy import Policy -from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor def build_model( @@ -28,11 +14,13 @@ def build_model( org_model = model_fn() # shard model - shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, - enable_tensor_parallelism=enable_tensor_parallelism, - enable_flash_attention=enable_flash_attention, - enable_jit_fused=enable_jit_fused, - inference_only=True) + shard_config = ShardConfig( + enable_fused_normalization=enable_fused_normalization, + enable_tensor_parallelism=enable_tensor_parallelism, + enable_flash_attention=enable_flash_attention, + enable_jit_fused=enable_jit_fused, + inference_only=True, + ) model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) sharded_model, shared_params = shard_former.optimize(model_copy) diff --git a/tests/test_infer/test_bloom_infer.py b/tests/test_infer/test_bloom_infer.py index 8ecabf69ecf3..5a5d341fc6ba 100644 --- a/tests/test_infer/test_bloom_infer.py +++ 
b/tests/test_infer/test_bloom_infer.py @@ -1,5 +1,3 @@ -import os - import pytest import torch from packaging import version @@ -16,22 +14,27 @@ MAX_INPUT_LEN = 16 MAX_OUTPUT_LEN = 32 -CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5') +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") -@parameterize('test_config', [{ - 'tp_size': TP_SIZE, -}]) +@parameterize( + "test_config", + [ + { + "tp_size": TP_SIZE, + } + ], +) def run(test_config): - - sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom_for_causal_lm') + sub_model_zoo = model_zoo.get_sub_registry("transformers_bloom_for_causal_lm") for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items(): orig_model = model_fn() orig_model = orig_model.half() data = data_gen_fn() - shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False, - inference_only=True) + shard_config = ShardConfig( + enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True + ) infer_engine = TPInferEngine(orig_model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) generate_kwargs = dict(do_sample=False) @@ -42,7 +45,7 @@ def run(test_config): def check_bloom(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run() @@ -54,5 +57,5 @@ def test_bloom_infer(): spawn(check_bloom, TP_SIZE) -if __name__ == '__main__': +if __name__ == "__main__": test_bloom_infer() diff --git a/tests/test_infer/test_infer_engine.py b/tests/test_infer/test_infer_engine.py index cc3cdd2b501b..f24160820e71 100644 --- a/tests/test_infer/test_infer_engine.py +++ b/tests/test_infer/test_infer_engine.py @@ -2,14 +2,12 @@ import pytest import torch -import torch.nn as nn from packaging import version -from transformers import BloomConfig, BloomForCausalLM, LlamaConfig, LlamaForCausalLM +from transformers import BloomConfig, BloomForCausalLM from transformers.tokenization_utils_base import BatchEncoding import colossalai from colossalai.inference.tensor_parallel import TPInferEngine -from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState from colossalai.logging import disable_existing_loggers from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn @@ -19,12 +17,17 @@ MAX_INPUT_LEN = 16 MAX_OUTPUT_LEN = 8 -CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5') +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") -@parameterize('test_config', [{ - 'tp_size': TP_SIZE, -}]) +@parameterize( + "test_config", + [ + { + "tp_size": TP_SIZE, + } + ], +) def run(test_config): model_config = BloomConfig(num_hidden_layers=4, hidden_size=128, intermediate_size=256, num_attention_heads=4) model = BloomForCausalLM(model_config) @@ -32,8 +35,9 @@ def run(test_config): model.to(torch.cuda.current_device()) # 1. 
check TPInferEngine init and model optimization - shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False, - inference_only=True) + shard_config = ShardConfig( + enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True + ) infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) assert infer_engine.cache_manager is not None @@ -41,13 +45,17 @@ def run(test_config): assert infer_engine.head_num == model_config.num_attention_heads // TP_SIZE # 2. check data preparation - input_ids_list = [[80540, 15473, 3331, 11970, 90472, 361, 61335], [80540, 15473, 3331, 11970], - [80540, 15473, 3331, 11970], [80540, 15473]] + input_ids_list = [ + [80540, 15473, 3331, 11970, 90472, 361, 61335], + [80540, 15473, 3331, 11970], + [80540, 15473, 3331, 11970], + [80540, 15473], + ] batch_size = len(input_ids_list) max_seq_len = max(len(li) for li in input_ids_list) attention_mask = [[0] * max_seq_len for _ in range(batch_size)] for i, li in enumerate(input_ids_list): - attention_mask[i][max_seq_len - len(li):] = [1 for _ in range(len(li))] + attention_mask[i][max_seq_len - len(li) :] = [1 for _ in range(len(li))] data = dict(input_ids=input_ids_list, attention_mask=attention_mask) inputs_batch_encoding = BatchEncoding(data=data) seq_lengths = [len(li) for li in input_ids_list] @@ -78,7 +86,7 @@ def run(test_config): def check_engine(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run() @@ -90,5 +98,5 @@ def test_engine(): spawn(check_engine, TP_SIZE) -if __name__ == '__main__': +if __name__ == "__main__": test_engine() diff --git a/tests/test_infer/test_kvcache_manager.py b/tests/test_infer/test_kvcache_manager.py index f57c6956f817..f3e2cdf1e18f 100644 --- a/tests/test_infer/test_kvcache_manager.py +++ b/tests/test_infer/test_kvcache_manager.py @@ -1,7 +1,8 @@ import os -from packaging import version + import pytest import torch +from packaging import version from colossalai.inference.tensor_parallel import MemoryManager from colossalai.logging import disable_existing_loggers @@ -14,14 +15,15 @@ HEAD_NUM = 32 HEAD_DIM = 128 -CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5') +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") + def create_cache_manager(rank, world_size, port, batch_size, input_len, output_len, layer_num, head_num, head_dim): - os.environ['RANK'] = str(rank) - os.environ['LOCAL_RANK'] = str(rank) - os.environ['WORLD_SIZE'] = str(world_size) - os.environ['MASTER_ADDR'] = 'localhost' - os.environ['MASTER_PORT'] = str(port) + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) disable_existing_loggers() size = batch_size * (input_len + output_len) @@ -41,21 +43,24 @@ def create_cache_manager(rank, world_size, port, batch_size, input_len, output_l assert torch.equal(prefill_locs, prefill_locs_contiguous) assert torch.sum(kvcache_manager.mem_state).item() == size - total_token_prefill kvcache_manager.alloc_contiguous(batch_size) - assert torch.all(kvcache_manager.mem_state[:total_token_prefill + batch_size] == False) + assert torch.all(kvcache_manager.mem_state[: 
total_token_prefill + batch_size] == False) + @pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") @pytest.mark.dist @rerun_if_address_is_in_use() def test_cache_manager_dist(): - spawn(create_cache_manager, - 4, - batch_size=BATCH_SIZE, - input_len=INPUT_LEN, - output_len=OUTPUT_LEN, - layer_num=LAYER_NUM, - head_num=HEAD_NUM, - head_dim=HEAD_DIM) + spawn( + create_cache_manager, + 4, + batch_size=BATCH_SIZE, + input_len=INPUT_LEN, + output_len=OUTPUT_LEN, + layer_num=LAYER_NUM, + head_num=HEAD_NUM, + head_dim=HEAD_DIM, + ) -if __name__ == '__main__': +if __name__ == "__main__": test_cache_manager_dist() diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index aa8874ea4cb0..0e5efe68508a 100644 --- a/tests/test_infer/test_llama_infer.py +++ b/tests/test_infer/test_llama_infer.py @@ -1,5 +1,4 @@ import os -import warnings import pytest import torch @@ -12,13 +11,13 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo -os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" TPSIZE = 2 BATCH_SIZE = 8 MAX_INPUT_LEN = 12 MAX_OUTPUT_LEN = 100 -CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5') +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") def init_to_get_rotary(self, base=10000): @@ -34,8 +33,9 @@ def init_to_get_rotary(self, base=10000): else: max_seq_len = 2048 * rope_scaling_factor base = float(base) - inv_freq = 1.0 / (base**(torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / - self.config.head_dim_)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_) + ) t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor freqs = torch.outer(t, inv_freq) @@ -44,20 +44,25 @@ def init_to_get_rotary(self, base=10000): return -@parameterize('test_config', [{ - 'tp_size': TPSIZE, -}]) +@parameterize( + "test_config", + [ + { + "tp_size": TPSIZE, + } + ], +) def run_llama_test(test_config): - - sub_model_zoo = model_zoo.get_sub_registry('transformers_llama_for_casual_lm') + sub_model_zoo = model_zoo.get_sub_registry("transformers_llama_for_casual_lm") for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items(): orig_model = model_fn() init_to_get_rotary(orig_model.model, base=10000) orig_model = orig_model.half() data = data_gen_fn() - shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False, - inference_only=True) + shard_config = ShardConfig( + enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True + ) infer_engine = TPInferEngine(orig_model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) generate_kwargs = dict(do_sample=False) @@ -68,7 +73,7 @@ def run_llama_test(test_config): def check_llama(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_llama_test() diff --git a/tests/test_infer_ops/cuda/test_vllm_rmsnorm.py b/tests/test_infer_ops/cuda/test_vllm_rmsnorm.py index cb12faf6276c..a4d893f8e830 100644 --- 
a/tests/test_infer_ops/cuda/test_vllm_rmsnorm.py +++ b/tests/test_infer_ops/cuda/test_vllm_rmsnorm.py @@ -1,16 +1,12 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -import os import pytest -import numpy as np -from packaging import version - import torch from torch import nn -from torch.nn import functional as F -try: +try: from vllm import layernorm_ops + rms_norm = layernorm_ops.rms_norm HAS_VLLM_KERNERL = True except: @@ -18,6 +14,7 @@ print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") HAS_VLLM_KERNERL = False + class LlamaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -34,6 +31,7 @@ def forward(self, hidden_states): hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) + def cuda_rmsnorm_forward(hidden_states, weight, variance_epsilon): x = hidden_states out = torch.empty_like(x) @@ -45,6 +43,7 @@ def cuda_rmsnorm_forward(hidden_states, weight, variance_epsilon): ) return out + @pytest.mark.skipif(not HAS_VLLM_KERNERL, reason="You need to install llama supported cuda kernels to run this test") def test_rmsnorm(): data = torch.randn((1024, 64), dtype=torch.float16, device="cuda") @@ -56,5 +55,6 @@ def test_rmsnorm(): check = torch.allclose(out_torch.cpu(), out_cuda.cpu(), rtol=1e-3, atol=1e-5) assert check is True, "cuda rmsnorm forward is not matched with torch rmsnorm forward" + if __name__ == "__main__": - test_rmsnorm() \ No newline at end of file + test_rmsnorm() diff --git a/tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py b/tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py index 2a85566c65c6..40451ef6636d 100644 --- a/tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py +++ b/tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py @@ -1,8 +1,8 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -import pytest from typing import Tuple +import pytest import torch import torch.nn as nn import torch.nn.functional as F @@ -10,17 +10,18 @@ try: from vllm import pos_encoding_ops + rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox HAS_VLLM_KERNERL = True -except: +except: print("fall back to original rotary_embedding_neox of huggingface") print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") HAS_VLLM_KERNERL = False def rotate_half(x: torch.Tensor) -> torch.Tensor: - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) @@ -49,7 +50,7 @@ def __init__( self.max_position_embeddings = max_position_embeddings # Create cos and sin embeddings. 
- inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim)) + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim)) t = torch.arange(max_position_embeddings).float() freqs = torch.einsum("i,j->ij", t, inv_freq.float()) emb = torch.cat((freqs, freqs), dim=-1) @@ -64,11 +65,10 @@ def forward( query: torch.Tensor, # [num_tokens, num_heads, head_size] key: torch.Tensor, # [num_tokens, num_heads, head_size] ) -> Tuple[torch.Tensor, torch.Tensor]: - - query_rot = query[..., :self.rotary_dim] - query_pass = query[..., self.rotary_dim:] - key_rot = key[..., :self.rotary_dim] - key_pass = key[..., self.rotary_dim:] + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] query_rot = query_rot.transpose(0, 1) key_rot = key_rot.transpose(0, 1) @@ -84,6 +84,7 @@ def forward( # Output query/key shape: [num_tokens, num_tokens, head_size] return query, key + def run_rotary_embedding_neox( num_tokens: int, num_heads: int, @@ -93,24 +94,18 @@ def run_rotary_embedding_neox( dtype: torch.dtype, base: int = 10000, ) -> None: - positions = torch.randint(0, max_position, (num_tokens, ), device='cuda') - query = torch.randn(num_tokens, - num_heads * head_size, - dtype=dtype, - device='cuda') - key = torch.randn(num_tokens, - num_heads * head_size, - dtype=dtype, - device='cuda') + positions = torch.randint(0, max_position, (num_tokens,), device="cuda") + query = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device="cuda") + key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device="cuda") # Create the rotary embedding. - inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim)) + inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim)) t = torch.arange(max_position).float() - freqs = torch.einsum('i,j -> ij', t, inv_freq.float()) + freqs = torch.einsum("i,j -> ij", t, inv_freq.float()) cos = freqs.cos() sin = freqs.sin() cos_sin_cache = torch.cat((cos, sin), dim=-1) - cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda') + cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda") # Run the kernel. The kernel is in-place, so we need to clone the inputs. 
out_query = query.clone() @@ -128,7 +123,7 @@ def run_rotary_embedding_neox( dim=rotary_dim, max_position_embeddings=max_position, base=base, - ).to(dtype=dtype, device='cuda') + ).to(dtype=dtype, device="cuda") ref_query, ref_key = ref_rotary_embedding( positions, query.view(num_tokens, num_heads, head_size), @@ -141,6 +136,7 @@ def run_rotary_embedding_neox( assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5) assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5) + @pytest.mark.skipif(not HAS_VLLM_KERNERL, reason="You need to install llama supported cuda kernels to run this test") def test_rotary_embedding(): run_rotary_embedding_neox( @@ -149,8 +145,9 @@ def test_rotary_embedding(): head_size=64, max_position=8192, rotary_dim=64, - dtype=torch.float16, + dtype=torch.float16, ) + if __name__ == "__main__": - test_rotary_embedding() \ No newline at end of file + test_rotary_embedding() diff --git a/tests/test_infer_ops/triton/kernel_utils.py b/tests/test_infer_ops/triton/kernel_utils.py index b081b32b9ad3..0732ace1e04b 100644 --- a/tests/test_infer_ops/triton/kernel_utils.py +++ b/tests/test_infer_ops/triton/kernel_utils.py @@ -1,19 +1,18 @@ import math -import numpy as np import torch from torch.nn import functional as F def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim): - ''' - adepted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253 - ''' + """ + adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253 + """ xq = xq.view(bs, seqlen, num_head, head_dim) xk = xk.view(bs, seqlen, num_head, head_dim) xv = xv.view(bs, seqlen, num_head, head_dim) mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda() - mask[mask == 0.] 
= -100000000.0 + mask[mask == 0.0] = -100000000.0 mask = mask.repeat(bs, num_head, 1, 1) keys = xk values = xv diff --git a/tests/test_infer_ops/triton/test_bloom_context_attention.py b/tests/test_infer_ops/triton/test_bloom_context_attention.py index 344ad078e2e2..7a6c218a6691 100644 --- a/tests/test_infer_ops/triton/test_bloom_context_attention.py +++ b/tests/test_infer_ops/triton/test_bloom_context_attention.py @@ -1,27 +1,24 @@ -import math - import pytest import torch from packaging import version -from torch import nn -from torch.nn import functional as F try: - import triton - import triton.language as tl + pass from colossalai.kernel.triton import bloom_context_attn_fwd from tests.test_infer_ops.triton.kernel_utils import torch_context_attention + HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, - reason="triton requires cuda version to be higher than 11.4") +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) def test_bloom_context_attention(): bs = 4 head_num = 8 @@ -46,8 +43,9 @@ def test_bloom_context_attention(): torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim) - assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, - atol=1e-2), "outputs from triton and torch are not matched" + assert torch.allclose( + torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-2 + ), "outputs from triton and torch are not matched" if __name__ == "__main__": diff --git a/tests/test_infer_ops/triton/test_copy_kv_dest.py b/tests/test_infer_ops/triton/test_copy_kv_dest.py index c656f81d2790..34e453f7840e 100644 --- a/tests/test_infer_ops/triton/test_copy_kv_dest.py +++ b/tests/test_infer_ops/triton/test_copy_kv_dest.py @@ -1,25 +1,24 @@ import pytest import torch from packaging import version -from torch import nn try: - import triton - import triton.language as tl + pass from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest + HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, - reason="triton requires cuda version to be higher than 11.4") +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) def test_kv_cache_copy_op(): - B_NTX = 32 * 2048 head_num = 8 head_dim = 64 @@ -31,8 +30,9 @@ def test_kv_cache_copy_op(): copy_kv_cache_to_dest(cache, dest_index, dest_data) - assert torch.allclose(cache.cpu(), dest_data.cpu(), rtol=1e-3, - atol=1e-3), "copy_kv_cache_to_dest outputs from triton and torch are not matched" + assert torch.allclose( + cache.cpu(), dest_data.cpu(), rtol=1e-3, atol=1e-3 + ), "copy_kv_cache_to_dest outputs from triton and torch are not matched" if __name__ == "__main__": diff --git a/tests/test_infer_ops/triton/test_layernorm_triton.py b/tests/test_infer_ops/triton/test_layernorm_triton.py index 94cd704ffeba..7f814e8c9a9f 100644 --- 
a/tests/test_infer_ops/triton/test_layernorm_triton.py +++ b/tests/test_infer_ops/triton/test_layernorm_triton.py @@ -6,30 +6,29 @@ from colossalai.testing.utils import parameterize try: - import triton - import triton.language as tl + pass - from colossalai.kernel.triton.fused_layernorm import _layer_norm_fwd_fused HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, - reason="triton requires cuda version to be higher than 11.4") -@parameterize('M', [2, 4, 8, 16]) -@parameterize('N', [64, 128]) +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) +@parameterize("M", [2, 4, 8, 16]) +@parameterize("N", [64, 128]) def test_layer_norm(M, N): dtype = torch.float16 eps = 1e-5 x_shape = (M, N) w_shape = (x_shape[-1],) - weight = torch.rand(w_shape, dtype=dtype, device='cuda') - bias = torch.rand(w_shape, dtype=dtype, device='cuda') - x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda') + weight = torch.rand(w_shape, dtype=dtype, device="cuda") + bias = torch.rand(w_shape, dtype=dtype, device="cuda") + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") y_triton = layer_norm(x, weight, bias, eps) y_torch = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype) diff --git a/tests/test_infer_ops/triton/test_llama_context_attention.py b/tests/test_infer_ops/triton/test_llama_context_attention.py index 4ea6095d4109..be6de6db2471 100644 --- a/tests/test_infer_ops/triton/test_llama_context_attention.py +++ b/tests/test_infer_ops/triton/test_llama_context_attention.py @@ -1,27 +1,24 @@ -import math - import pytest import torch from packaging import version -from torch import nn -from torch.nn import functional as F try: - import triton - import triton.language as tl + pass from colossalai.kernel.triton import llama_context_attn_fwd from tests.test_infer_ops.triton.kernel_utils import torch_context_attention + HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, - reason="triton requires cuda version to be higher than 11.4") +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) def test_llama_context_attention(): bs = 4 head_num = 8 @@ -45,8 +42,9 @@ def test_llama_context_attention(): torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim) - assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, - atol=1e-3), "outputs from triton and torch are not matched" + assert torch.allclose( + torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-3 + ), "outputs from triton and torch are not matched" if __name__ == "__main__": diff --git a/tests/test_infer_ops/triton/test_rotary_embedding.py b/tests/test_infer_ops/triton/test_rotary_embedding.py index d5ecdf684538..7e05ccafbfc4 100644 --- a/tests/test_infer_ops/triton/test_rotary_embedding.py +++ 
b/tests/test_infer_ops/triton/test_rotary_embedding.py @@ -1,14 +1,12 @@ # Adapted from ModelTC https://github.com/ModelTC/lightllm -import time import pytest import torch from packaging import version try: - import triton - import triton.language as tl + pass from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd @@ -17,13 +15,13 @@ HAS_TRITON = False print("please install triton from https://github.com/openai/triton") -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") def torch_rotary_emb(x, cos, sin): seq_len, h, dim = x.shape - x0 = x[:, :, 0:dim // 2] - x1 = x[:, :, dim // 2:dim] + x0 = x[:, :, 0 : dim // 2] + x1 = x[:, :, dim // 2 : dim] cos = cos.view((seq_len, 1, dim // 2)) sin = sin.view((seq_len, 1, dim // 2)) o0 = x0 * cos - x1 * sin @@ -31,8 +29,9 @@ def torch_rotary_emb(x, cos, sin): return torch.cat((o0, o1), dim=-1) -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, - reason="triton requires cuda version to be higher than 11.4") +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) def test_rotary_emb(): SEQ_LEN = 1 HEAD_NUM = 32 @@ -40,10 +39,10 @@ def test_rotary_emb(): dtype = torch.half # create data x_shape = (SEQ_LEN, HEAD_NUM, HEAD_DIM) - x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda') + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") cos_shape = (SEQ_LEN, HEAD_DIM // 2) - cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device='cuda') - sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device='cuda') + cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") + sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") # forward pass y_torch = torch_rotary_emb(x, cos, sin) rotary_embedding_fwd(x, cos, sin) diff --git a/tests/test_infer_ops/triton/test_self_attention_nonfusion.py b/tests/test_infer_ops/triton/test_self_attention_nonfusion.py index 9692737a05a0..9bdec86645b2 100644 --- a/tests/test_infer_ops/triton/test_self_attention_nonfusion.py +++ b/tests/test_infer_ops/triton/test_self_attention_nonfusion.py @@ -1,24 +1,27 @@ import pytest -from packaging import version import torch -from torch import nn import torch.nn.functional as F +from packaging import version try: import triton - import triton.language as tl - from colossalai.kernel.triton.self_attention_nofusion import self_attention_compute_using_triton + from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel + from colossalai.kernel.triton.self_attention_nofusion import self_attention_compute_using_triton + HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4") +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) def test_qkv_matmul(): - qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16) + qkv = torch.randn((4, 24, 64 * 3), device="cuda", dtype=torch.float16) scale = 1.2 head_size = 32 batches = qkv.shape[0] @@ -26,7 +29,7 @@ def 
test_qkv_matmul(): num_of_heads = d_model // head_size q = qkv[:, :, :d_model] - k = qkv[:, :, d_model:d_model * 2] + k = qkv[:, :, d_model : d_model * 2] q = q.view(batches, -1, num_of_heads, head_size) k = k.view(batches, -1, num_of_heads, head_size) @@ -36,29 +39,40 @@ def test_qkv_matmul(): k = torch.transpose(k, 1, 2).contiguous() k = torch.transpose(k, 2, 3).contiguous() - torch_ouput = torch.einsum('bnij,bnjk->bnik', q, k) + torch_ouput = torch.einsum("bnij,bnjk->bnik", q, k) torch_ouput *= 1.2 q, k = q_copy, k_copy batches, M, H, K = q.shape N = k.shape[1] - score_output = torch.empty( - (batches, H, M, N), device=q.device, dtype=q.dtype) + score_output = torch.empty((batches, H, M, N), device=q.device, dtype=q.dtype) grid = lambda meta: ( batches, H, - triton.cdiv(M, meta["BLOCK_SIZE_M"]) * - triton.cdiv(N, meta["BLOCK_SIZE_N"]), + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), ) K = q.shape[3] qkv_gemm_4d_kernel[grid]( - q, k, score_output, - M, N, K, - q.stride(0), q.stride(2), q.stride(1), q.stride(3), - k.stride(0), k.stride(2), k.stride(3), k.stride(1), - score_output.stride(0), score_output.stride(1), score_output.stride(2), score_output.stride(3), + q, + k, + score_output, + M, + N, + K, + q.stride(0), + q.stride(2), + q.stride(1), + q.stride(3), + k.stride(0), + k.stride(2), + k.stride(3), + k.stride(1), + score_output.stride(0), + score_output.stride(1), + score_output.stride(2), + score_output.stride(3), scale=scale, # currently manually setting, later on we can use auto-tune config to match best setting BLOCK_SIZE_M=64, @@ -69,21 +83,16 @@ def test_qkv_matmul(): check = torch.allclose(torch_ouput.cpu(), score_output.cpu(), rtol=1e-3, atol=1e-5) assert check is True, "the outputs of triton and torch are not matched" - -def self_attention_compute_using_torch(qkv, - input_mask, - scale, - head_size - ): +def self_attention_compute_using_torch(qkv, input_mask, scale, head_size): batches = qkv.shape[0] d_model = qkv.shape[-1] // 3 num_of_heads = d_model // head_size - + q = qkv[:, :, :d_model] - k = qkv[:, :, d_model:d_model * 2] - v = qkv[:, :, d_model * 2:] + k = qkv[:, :, d_model : d_model * 2] + v = qkv[:, :, d_model * 2 :] q = q.view(batches, -1, num_of_heads, head_size) k = k.view(batches, -1, num_of_heads, head_size) v = v.view(batches, -1, num_of_heads, head_size) @@ -94,37 +103,36 @@ def self_attention_compute_using_torch(qkv, k = torch.transpose(k, -1, -2).contiguous() - score_output = torch.einsum('bnij,bnjk->bnik', q, k) + score_output = torch.einsum("bnij,bnjk->bnik", q, k) score_output *= scale - softmax_output = F.softmax(score_output, dim = -1) - res = torch.einsum('bnij,bnjk->bnik', softmax_output, v) + softmax_output = F.softmax(score_output, dim=-1) + res = torch.einsum("bnij,bnjk->bnik", softmax_output, v) res = torch.transpose(res, 1, 2) res = res.contiguous() - return res.view(batches, -1, d_model), score_output, softmax_output -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4") -def test_self_atttention_test(): - qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16) +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) +def test_self_atttention_test(): + qkv = torch.randn((4, 24, 64 * 3), device="cuda", dtype=torch.float16) data_output_torch, score_output_torch, softmax_output_torch = self_attention_compute_using_torch( - qkv.clone(), - input_mask = None, - 
scale = 1.2, - head_size = 32 - ) + qkv.clone(), input_mask=None, scale=1.2, head_size=32 + ) data_output_triton = self_attention_compute_using_triton( - qkv.clone(), - alibi=None, - head_size=32, - scale=1.2, - input_mask=None, - layer_past=None, - use_flash=False, - triangular=True) + qkv.clone(), + alibi=None, + head_size=32, + scale=1.2, + input_mask=None, + layer_past=None, + use_flash=False, + triangular=True, + ) check = torch.allclose(data_output_triton.cpu(), data_output_torch.cpu(), rtol=1e-4, atol=1e-2) assert check is True, "the triton output is not matched with torch output" @@ -132,4 +140,4 @@ def test_self_atttention_test(): if __name__ == "__main__": test_qkv_matmul() - test_self_atttention_test() \ No newline at end of file + test_self_atttention_test() diff --git a/tests/test_infer_ops/triton/test_softmax.py b/tests/test_infer_ops/triton/test_softmax.py index 6a244608c43f..43b9c0929c4a 100644 --- a/tests/test_infer_ops/triton/test_softmax.py +++ b/tests/test_infer_ops/triton/test_softmax.py @@ -1,30 +1,31 @@ import pytest -from packaging import version import torch +from packaging import version from torch import nn - try: - import triton - import triton.language as tl from colossalai.kernel.triton.softmax import softmax + HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4") +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) def test_softmax_op(): data_samples = [ - torch.randn((3, 4, 5, 32), device = "cuda", dtype = torch.float32), - torch.randn((320, 320, 78), device = "cuda", dtype = torch.float32), - torch.randn((2345, 4, 5, 64), device = "cuda", dtype = torch.float16) - ] + torch.randn((3, 4, 5, 32), device="cuda", dtype=torch.float32), + torch.randn((320, 320, 78), device="cuda", dtype=torch.float32), + torch.randn((2345, 4, 5, 64), device="cuda", dtype=torch.float16), + ] for data in data_samples: - module = nn.Softmax(dim = -1) + module = nn.Softmax(dim=-1) data_torch_out = module(data) data_triton_out = softmax(data) check = torch.allclose(data_torch_out.cpu(), data_triton_out.cpu(), rtol=1e-3, atol=1e-3) @@ -32,4 +33,4 @@ def test_softmax_op(): if __name__ == "__main__": - test_softmax_op() \ No newline at end of file + test_softmax_op() diff --git a/tests/test_infer_ops/triton/test_token_attn_1.py b/tests/test_infer_ops/triton/test_token_attn_1.py index aee7944597dc..fc5f8cd6c9dc 100644 --- a/tests/test_infer_ops/triton/test_token_attn_1.py +++ b/tests/test_infer_ops/triton/test_token_attn_1.py @@ -5,16 +5,16 @@ from packaging import version try: - import triton - import triton.language as tl + pass from colossalai.kernel.triton.token_attention_kernel import token_attn_fwd_1 + HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") def torch_attn(xq, xk, bs, seqlen, num_head, head_dim): @@ -23,8 +23,9 @@ def torch_attn(xq, xk, bs, seqlen, num_head, head_dim): keys = xk xq = xq.transpose(1, 2) keys = 
keys.transpose(1, 2) - scores = (torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)).squeeze().transpose(0, 1).reshape( - num_head, -1) + scores = ( + (torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)).squeeze().transpose(0, 1).reshape(num_head, -1) + ) return scores @@ -37,10 +38,11 @@ def torch_attn_1(xq, xk, seqlen, num_head, head_dim): return logics -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, - reason="triton requires cuda version to be higher than 11.4") +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) def test_attn_1(): - import time + pass batch_size, seq_len, head_num, head_dim = 17, 1025, 12, 128 diff --git a/tests/test_infer_ops/triton/test_token_attn_2.py b/tests/test_infer_ops/triton/test_token_attn_2.py index f834fedbb0f1..2dd756f2ba91 100644 --- a/tests/test_infer_ops/triton/test_token_attn_2.py +++ b/tests/test_infer_ops/triton/test_token_attn_2.py @@ -1,20 +1,18 @@ -import math - import pytest import torch from packaging import version try: - import triton - import triton.language as tl + pass from colossalai.kernel.triton.token_attention_kernel import token_attn_fwd_2 + HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") def torch_attn(V, P, bs, seqlen, num_head, head_dim): @@ -25,19 +23,23 @@ def torch_attn(V, P, bs, seqlen, num_head, head_dim): return attn_out -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, - reason="triton requires cuda version to be higher than 11.4") +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) def test_token_attn_2(): - import time + pass batch_size, seq_len, head_num, head_dim = 17, 1025, 12, 128 dtype = torch.float16 V = torch.empty((batch_size * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=10) - Prob = torch.empty( - (head_num, batch_size * seq_len), dtype=dtype, - device="cuda").normal_(mean=0.4, std=0.2).reshape(head_num, batch_size, - seq_len).softmax(-1).reshape(head_num, batch_size * seq_len) + Prob = ( + torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda") + .normal_(mean=0.4, std=0.2) + .reshape(head_num, batch_size, seq_len) + .softmax(-1) + .reshape(head_num, batch_size * seq_len) + ) attn_out = torch.empty((batch_size, head_num, head_dim), dtype=dtype, device="cuda") kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") diff --git a/tests/test_infer_ops/triton/test_token_attn_fwd.py b/tests/test_infer_ops/triton/test_token_attn_fwd.py index e82318965e05..9c7a53798317 100644 --- a/tests/test_infer_ops/triton/test_token_attn_fwd.py +++ b/tests/test_infer_ops/triton/test_token_attn_fwd.py @@ -1,20 +1,18 @@ -import time - import pytest import torch from packaging import version try: - import triton - import triton.language as tl + pass from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd + HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") def 
torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim): @@ -29,10 +27,10 @@ def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim): return torch.sum(prob * xv, dim=1, keepdim=False) -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, - reason="triton requires cuda version to be higher than 11.4") +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) def test(): - Z, head_num, seq_len, head_dim = 22, 112 // 8, 2048, 128 dtype = torch.float16 q = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) diff --git a/tests/test_infer_ops/triton/test_token_softmax.py b/tests/test_infer_ops/triton/test_token_softmax.py index 08ffe1ca8323..1f97f1674818 100644 --- a/tests/test_infer_ops/triton/test_token_softmax.py +++ b/tests/test_infer_ops/triton/test_token_softmax.py @@ -3,22 +3,22 @@ from packaging import version try: - import triton - import triton.language as tl + pass from colossalai.kernel.triton.token_attention_kernel import token_attn_softmax_fwd + HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, - reason="triton requires cuda version to be higher than 11.4") +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) def test_softmax(): - import torch batch_size, seq_len, head_num, head_dim = 4, 1025, 12, 128 diff --git a/tests/test_lazy/lazy_init_utils.py b/tests/test_lazy/lazy_init_utils.py index 9d9e9a3a5c76..ea6b16b94785 100644 --- a/tests/test_lazy/lazy_init_utils.py +++ b/tests/test_lazy/lazy_init_utils.py @@ -12,7 +12,7 @@ from colossalai.tensor.d_tensor.layout import Layout from tests.kit.model_zoo.registry import ModelAttribute -SUPPORT_LAZY = version.parse(torch.__version__) >= version.parse('1.12.0') +SUPPORT_LAZY = version.parse(torch.__version__) >= version.parse("1.12.0") # model_fn, data_gen_fn, output_transform_fn, model_attr TestingEntry = Tuple[Callable[[], torch.nn.Module], Callable[[], dict], Callable[[], dict], Optional[ModelAttribute]] @@ -28,18 +28,22 @@ def assert_model_equal(m1: torch.nn.Module, m2: torch.nn.Module) -> None: s1 = m1.state_dict() s2 = m2.state_dict() - assert len(s1) == len(s2), f'len {len(s1)} vs {len(s2)}' + assert len(s1) == len(s2), f"len {len(s1)} vs {len(s2)}" for (n1, t1), (n2, t2) in zip(s1.items(), s2.items()): assert n1 == n2 - assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}' + assert torch.equal(t1, t2), f"{n1} {t1} vs {t2}" for p1, p2 in zip(m1.parameters(), m2.parameters()): assert p1.requires_grad == p2.requires_grad -def assert_forward_equal(m1: torch.nn.Module, m2: torch.nn.Module, data_gen_fn: Callable[[], dict], - output_transform_fn: Callable[[Any], dict]) -> None: +def assert_forward_equal( + m1: torch.nn.Module, + m2: torch.nn.Module, + data_gen_fn: Callable[[], dict], + output_transform_fn: Callable[[Any], dict], +) -> None: data = data_gen_fn() m1.eval() @@ -57,15 +61,14 @@ def assert_forward_equal(m1: torch.nn.Module, m2: torch.nn.Module, data_gen_fn: for key, out1 in transformed_out1.items(): out2 = transformed_out2[key] - assert torch.allclose(out1, out2, atol=1e-5), \ - f'{m1.__class__.__name__} has inconsistent outputs, {out1} 
vs {out2}' + assert torch.allclose( + out1, out2, atol=1e-5 + ), f"{m1.__class__.__name__} has inconsistent outputs, {out1} vs {out2}" -def check_lazy_init(entry: TestingEntry, - seed: int = 42, - verbose: bool = False, - check_forward: bool = False, - default_device: str = 'cpu') -> None: +def check_lazy_init( + entry: TestingEntry, seed: int = 42, verbose: bool = False, check_forward: bool = False, default_device: str = "cpu" +) -> None: model_fn, data_gen_fn, output_transform_fn, _, model_attr = entry _MyTensor._pre_op_fn = lambda *args: set_seed(seed) LazyTensor._pre_op_fn = lambda *args: set_seed(seed) @@ -84,15 +87,16 @@ def check_lazy_init(entry: TestingEntry, assert_forward_equal(model, deferred_model, data_gen_fn, output_transform_fn) assert_forward_equal(deferred_model, copied_deferred_model, data_gen_fn, output_transform_fn) if verbose: - print(f'{model.__class__.__name__} pass') + print(f"{model.__class__.__name__} pass") -def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, device_mesh: DeviceMesh, - sharding_spec_dict: dict) -> None: +def assert_dist_model_equal( + model: torch.nn.Module, distributed_model: torch.nn.Module, device_mesh: DeviceMesh, sharding_spec_dict: dict +) -> None: state = model.state_dict() distributed_state = distributed_model.state_dict() - assert len(state) == len(distributed_state), f'len {len(state)} vs {len(distributed_state)}' + assert len(state) == len(distributed_state), f"len {len(state)} vs {len(distributed_state)}" for (n1, t1), (n2, t2) in zip(state.items(), distributed_state.items()): assert n1 == n2 @@ -102,4 +106,4 @@ def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn. layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_dict[n2], global_shape=t1.shape) t2.dist_layout = layout t2 = to_global(t2) - assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}' + assert torch.equal(t1, t2), f"{n1} {t1} vs {t2}" diff --git a/tests/test_lazy/test_models.py b/tests/test_lazy/test_models.py index 18a737fcec85..978cf06b55a0 100644 --- a/tests/test_lazy/test_models.py +++ b/tests/test_lazy/test_models.py @@ -4,19 +4,21 @@ from tests.kit.model_zoo import model_zoo -@pytest.mark.skipif(not SUPPORT_LAZY, reason='requires torch >= 1.12.0') -@pytest.mark.parametrize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm']) -@pytest.mark.parametrize('default_device', ['cpu', 'cuda']) +@pytest.mark.skipif(not SUPPORT_LAZY, reason="requires torch >= 1.12.0") +@pytest.mark.parametrize("subset", ["torchvision", "diffusers", "timm", "transformers", "torchaudio", "deepfm", "dlrm"]) +@pytest.mark.parametrize("default_device", ["cpu", "cuda"]) def test_torchvision_models_lazy_init(subset, default_device): sub_model_zoo = model_zoo.get_sub_registry(subset) for name, entry in sub_model_zoo.items(): # TODO(ver217): lazy init does not support weight norm, skip these models - if name in ('torchaudio_wav2vec2_base', - 'torchaudio_hubert_base') or name.startswith('transformers_llama') or name.startswith( - ('transformers_vit', 'transformers_blip2')): + if ( + name in ("torchaudio_wav2vec2_base", "torchaudio_hubert_base") + or name.startswith("transformers_llama") + or name.startswith(("transformers_vit", "transformers_blip2")) + ): continue check_lazy_init(entry, verbose=True, default_device=default_device) -if __name__ == '__main__': - test_torchvision_models_lazy_init('torchvision') +if __name__ == "__main__": + test_torchvision_models_lazy_init("torchvision") 
diff --git a/tests/test_legacy/test_amp/test_naive_fp16.py b/tests/test_legacy/test_amp/test_naive_fp16.py index 54bf6498549c..76f9ff07407f 100644 --- a/tests/test_legacy/test_amp/test_naive_fp16.py +++ b/tests/test_legacy/test_amp/test_naive_fp16.py @@ -13,7 +13,7 @@ def check_equal(a, b): """ This function checks if two tensors are equal within tolerance """ - assert torch.allclose(a.float(), b.float(), rtol=1e-4, atol=1e-3), f'a = {a}, b = {b}' + assert torch.allclose(a.float(), b.float(), rtol=1e-4, atol=1e-3), f"a = {a}, b = {b}" def run_naive_amp(): @@ -25,7 +25,7 @@ def run_naive_amp(): torch.backends.cudnn.deterministic = True # create layer - test_models = ['repeated_computed_layers', 'nested_model', 'resnet18'] + test_models = ["repeated_computed_layers", "nested_model", "resnet18"] for test_name in test_models: get_component_func = non_distributed_component_funcs.get_callable(test_name) model_builder, train_dataloader, _, optim_class, _ = get_component_func() @@ -41,9 +41,10 @@ def run_naive_amp(): # inject naive and apex amp naive_amp_config = dict(initial_scale=128, clip_grad_norm=1.0) - naive_amp_model, naive_amp_optimizer = convert_to_naive_amp(naive_amp_model, naive_amp_optimizer, - naive_amp_config) - apex_amp_config = dict(opt_level='O2', loss_scale=128, keep_batchnorm_fp32=False) + naive_amp_model, naive_amp_optimizer = convert_to_naive_amp( + naive_amp_model, naive_amp_optimizer, naive_amp_config + ) + apex_amp_config = dict(opt_level="O2", loss_scale=128, keep_batchnorm_fp32=False) apex_amp_model, apex_amp_optimizer = convert_to_apex_amp(apex_amp_model, apex_amp_optimizer, apex_amp_config) # create data @@ -78,7 +79,7 @@ def run_naive_amp(): def run_dist(rank, world_size, port): - colossalai.legacy.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.legacy.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") run_naive_amp() @@ -89,5 +90,5 @@ def test_naive_amp(): spawn(run_dist, 1) -if __name__ == '__main__': +if __name__ == "__main__": test_naive_amp() diff --git a/tests/test_legacy/test_amp/test_torch_fp16.py b/tests/test_legacy/test_amp/test_torch_fp16.py index 89810b5d0351..47b303745e4e 100644 --- a/tests/test_legacy/test_amp/test_torch_fp16.py +++ b/tests/test_legacy/test_amp/test_torch_fp16.py @@ -18,7 +18,7 @@ def run_torch_amp(): torch.backends.cudnn.deterministic = True # create layer - test_models = ['resnet18', 'simple_net'] + test_models = ["resnet18", "simple_net"] for test_name in test_models: get_component_func = non_distributed_component_funcs.get_callable(test_name) model_builder, train_dataloader, _, optim_class, _ = get_component_func() @@ -34,10 +34,10 @@ def run_torch_amp(): # inject torch and apex amp torch_amp_config = dict(init_scale=128, enabled=True) - torch_amp_model, torch_amp_optimizer, _ = convert_to_torch_amp(torch_amp_model, - torch_amp_optimizer, - amp_config=torch_amp_config) - apex_amp_config = dict(opt_level='O1', loss_scale=128) + torch_amp_model, torch_amp_optimizer, _ = convert_to_torch_amp( + torch_amp_model, torch_amp_optimizer, amp_config=torch_amp_config + ) + apex_amp_config = dict(opt_level="O1", loss_scale=128) apex_amp_model, apex_amp_optimizer = convert_to_apex_amp(apex_amp_model, apex_amp_optimizer, apex_amp_config) # create data @@ -61,7 +61,7 @@ def run_torch_amp(): # check grad # In apex amp, grad is not scaled before backward, but torch amp does for torch_amp_param, apex_amp_param in zip(torch_amp_model.parameters(), 
apex_amp_model.parameters()): - assert_close_loose(torch_amp_param.grad, apex_amp_param.grad * apex_amp_config['loss_scale']) + assert_close_loose(torch_amp_param.grad, apex_amp_param.grad * apex_amp_config["loss_scale"]) # clip gradient apex_amp_optimizer.clip_grad_norm(model=apex_amp_model, max_norm=1.0) @@ -78,7 +78,7 @@ def run_torch_amp(): def run_dist(rank, world_size, port): - colossalai.legacy.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.legacy.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") run_torch_amp() @@ -89,5 +89,5 @@ def test_torch_amp(): spawn(run_dist, 1) -if __name__ == '__main__': +if __name__ == "__main__": test_torch_amp() diff --git a/tests/test_legacy/test_comm/test_boardcast_send_recv_v2.py b/tests/test_legacy/test_comm/test_boardcast_send_recv_v2.py index 4851b3e36bbc..bc243631a6c5 100644 --- a/tests/test_legacy/test_comm/test_boardcast_send_recv_v2.py +++ b/tests/test_legacy/test_comm/test_boardcast_send_recv_v2.py @@ -16,11 +16,15 @@ def check_layer(rank, world_size, port): disable_existing_loggers() - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl', verbose=False) + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl", verbose=False) rank = gpc.get_local_rank(ParallelMode.PIPELINE) if rank == 0: - obj = [torch.randn(3,)] + obj = [ + torch.randn( + 3, + ) + ] _send_object(obj, 1) if rank == 1: @@ -30,7 +34,11 @@ def check_layer(rank, world_size, port): _recv_object(3) if rank == 3: - obj = [torch.randn(3,)] + obj = [ + torch.randn( + 3, + ) + ] _send_object(obj, 2) gpc.destroy() @@ -43,5 +51,5 @@ def test_object_list_p2p(): spawn(check_layer, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_object_list_p2p() diff --git a/tests/test_legacy/test_comm/test_comm.py b/tests/test_legacy/test_comm/test_comm.py index fccfcd973000..7d2c81972e5a 100644 --- a/tests/test_legacy/test_comm/test_comm.py +++ b/tests/test_legacy/test_comm/test_comm.py @@ -17,41 +17,41 @@ def check_all_gather(): tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)]) tensor = tensor.to(get_current_device()) - print('Before: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + print("Before: Rank {0} - {1}".format(dist.get_rank(), tensor)) tensor, op = all_gather(tensor, 0, ParallelMode.GLOBAL, async_op=True) - print('After: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + print("After: Rank {0} - {1}".format(dist.get_rank(), tensor)) op.wait() - print('Complete: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + print("Complete: Rank {0} - {1}".format(dist.get_rank(), tensor)) torch.cuda.synchronize() def check_reduce_scatter(): tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)]) tensor = tensor.to(get_current_device()) - print('Before: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + print("Before: Rank {0} - {1}".format(dist.get_rank(), tensor)) tensor, op = reduce_scatter(tensor, 0, ParallelMode.GLOBAL, async_op=True) - print('After: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + print("After: Rank {0} - {1}".format(dist.get_rank(), tensor)) op.wait() - print('Complete: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + print("Complete: Rank {0} - {1}".format(dist.get_rank(), tensor)) torch.cuda.synchronize() def check_all_reduce(): tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)]) tensor = 
tensor.to(get_current_device()) - print('Before: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + print("Before: Rank {0} - {1}".format(dist.get_rank(), tensor)) tensor, op = all_reduce(tensor, ParallelMode.GLOBAL, async_op=True) - print('After: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + print("After: Rank {0} - {1}".format(dist.get_rank(), tensor)) op.wait() - print('Complete: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + print("Complete: Rank {0} - {1}".format(dist.get_rank(), tensor)) torch.cuda.synchronize() def check_layer(rank, world_size, port): - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") assert dist.get_rank() == gpc.get_global_rank() - print('Rank {} / {}'.format(dist.get_rank(), dist.get_world_size())) + print("Rank {} / {}".format(dist.get_rank(), dist.get_world_size())) check_all_gather() check_reduce_scatter() @@ -67,5 +67,5 @@ def test_comm(): spawn(check_layer, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_comm() diff --git a/tests/test_legacy/test_comm/test_object_list_p2p.py b/tests/test_legacy/test_comm/test_object_list_p2p.py index a1322e6f28db..69c68c7159e4 100644 --- a/tests/test_legacy/test_comm/test_object_list_p2p.py +++ b/tests/test_legacy/test_comm/test_object_list_p2p.py @@ -27,7 +27,7 @@ def check_send_recv_forward(): if gpc.get_local_rank(ParallelMode.PIPELINE) == 0: - device = torch.device('cuda:0') + device = torch.device("cuda:0") data_to_send = data.to(device) data_list_to_send = [] for data_in_list in data_list: @@ -35,7 +35,7 @@ def check_send_recv_forward(): send_forward(data_to_send) send_forward(data_list_to_send) else: - device = torch.device('cuda:1') + device = torch.device("cuda:1") data_recv = recv_forward(TENSOR_SIZE) data_list_recv = recv_forward(TENSOR_SIZE_LIST) data_to_check = data.to(device) @@ -47,7 +47,7 @@ def check_send_recv_forward(): def check_send_recv_backward(): if gpc.get_local_rank(ParallelMode.PIPELINE) == 0: - device = torch.device('cuda:0') + device = torch.device("cuda:0") grad_recv = recv_backward(TENSOR_SIZE) grad_list_recv = recv_backward(TENSOR_SIZE_LIST) grad_to_check = grad.to(device) @@ -56,7 +56,7 @@ def check_send_recv_backward(): grad_to_check = grad_send.to(device) assert grad_recv.equal(grad_to_check) else: - device = torch.device('cuda:1') + device = torch.device("cuda:1") grad_to_send = grad.to(device) grad_list_to_send = [] for grad_in_list in grad_list: @@ -67,7 +67,7 @@ def check_send_recv_backward(): def check_send_recv_forward_backward(): if gpc.get_local_rank(ParallelMode.PIPELINE) == 0: - device = torch.device('cuda:0') + device = torch.device("cuda:0") data_list_to_send = [] for data_in_list in data_list: data_list_to_send.append(data_in_list.to(device)) @@ -77,7 +77,7 @@ def check_send_recv_forward_backward(): grad_to_check = grad_send.to(device) assert grad_recv.equal(grad_to_check) else: - device = torch.device('cuda:1') + device = torch.device("cuda:1") grad_list_to_send = [] for grad_in_list in grad_list: grad_list_to_send.append(grad_in_list.to(device)) @@ -88,7 +88,7 @@ def check_send_recv_forward_backward(): def check_layer(rank, world_size, port): - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") check_send_recv_forward() check_send_recv_backward() 
check_send_recv_forward_backward() @@ -102,5 +102,5 @@ def test_object_list_p2p(): spawn(check_layer, 2) -if __name__ == '__main__': +if __name__ == "__main__": test_object_list_p2p() diff --git a/tests/test_legacy/test_comm/test_object_list_p2p_v2.py b/tests/test_legacy/test_comm/test_object_list_p2p_v2.py index f805bd19d7e8..eb05ea4839c6 100644 --- a/tests/test_legacy/test_comm/test_object_list_p2p_v2.py +++ b/tests/test_legacy/test_comm/test_object_list_p2p_v2.py @@ -32,7 +32,7 @@ def check_send_recv_forward(): local_rank = gpc.get_local_rank(ParallelMode.PIPELINE) if local_rank == 0: - device = torch.device('cuda:0') + device = torch.device("cuda:0") data_to_send = data.to(device) data_list_to_send = [] for data_in_list in data_list: @@ -42,7 +42,7 @@ def check_send_recv_forward(): send_forward(data_list_to_send, scatter_gather_tensors=use_scatter_gather_tensors) elif local_rank == 1: - device = torch.device('cuda:1') + device = torch.device("cuda:1") data_recv = recv_forward(TENSOR_SIZE, scatter_gather_tensors=use_scatter_gather_tensors) data_list_recv = recv_forward(TENSOR_SIZE_LIST, scatter_gather_tensors=use_scatter_gather_tensors) @@ -60,7 +60,7 @@ def check_send_recv_forward(): def check_send_recv_backward(): disable_existing_loggers() if gpc.get_local_rank(ParallelMode.PIPELINE) == 0: - device = torch.device('cuda:0') + device = torch.device("cuda:0") grad_recv = recv_backward(TENSOR_SIZE) grad_list_recv = recv_backward(TENSOR_SIZE_LIST) @@ -73,7 +73,7 @@ def check_send_recv_backward(): grad_to_check = grad_send.to(device) assert grad_recv.equal(grad_to_check) else: - device = torch.device('cuda:1') + device = torch.device("cuda:1") grad_to_send = grad.to(device) grad_list_to_send = [] for grad_in_list in grad_list: @@ -104,7 +104,7 @@ def check_small_pipeline(): def check_layer(rank, world_size, port): disable_existing_loggers() - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") disable_existing_loggers() # check_send_recv_forward() @@ -120,6 +120,6 @@ def test_object_list_p2p(): spawn(check_layer, world_size) -if __name__ == '__main__': +if __name__ == "__main__": disable_existing_loggers() test_object_list_p2p() diff --git a/tests/test_legacy/test_context/configs/parallel_2d_init.py b/tests/test_legacy/test_context/configs/parallel_2d_init.py index 6cf816942fdd..d1203fcdc436 100644 --- a/tests/test_legacy/test_context/configs/parallel_2d_init.py +++ b/tests/test_legacy/test_context/configs/parallel_2d_init.py @@ -1,4 +1,4 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -parallel = dict(pipeline=dict(size=2), tensor=dict(size=4, mode='2d')) +parallel = dict(pipeline=dict(size=2), tensor=dict(size=4, mode="2d")) diff --git a/tests/test_legacy/test_context/configs/parallel_2p5d_init.py b/tests/test_legacy/test_context/configs/parallel_2p5d_init.py index b946d45b3a91..89e8cd6039f7 100644 --- a/tests/test_legacy/test_context/configs/parallel_2p5d_init.py +++ b/tests/test_legacy/test_context/configs/parallel_2p5d_init.py @@ -1,4 +1,4 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -parallel = dict(pipeline=dict(size=2), tensor=dict(size=8, depth=2, mode='2.5d')) +parallel = dict(pipeline=dict(size=2), tensor=dict(size=8, depth=2, mode="2.5d")) diff --git a/tests/test_legacy/test_context/configs/parallel_3d_init.py b/tests/test_legacy/test_context/configs/parallel_3d_init.py index a1564bbb2d51..f9aa52fa4199 100644 --- 
a/tests/test_legacy/test_context/configs/parallel_3d_init.py +++ b/tests/test_legacy/test_context/configs/parallel_3d_init.py @@ -1,4 +1,4 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -parallel = dict(pipeline=dict(size=2), tensor=dict(size=8, mode='3d')) +parallel = dict(pipeline=dict(size=2), tensor=dict(size=8, mode="3d")) diff --git a/tests/test_legacy/test_context/test_hybrid_parallel.py b/tests/test_legacy/test_context/test_hybrid_parallel.py index 05cd1d294dcd..b9e44bb34362 100644 --- a/tests/test_legacy/test_context/test_hybrid_parallel.py +++ b/tests/test_legacy/test_context/test_hybrid_parallel.py @@ -3,7 +3,6 @@ from pathlib import Path -import pytest import torch from colossalai.legacy import launch @@ -13,7 +12,7 @@ from colossalai.legacy.global_variables import tensor_parallel_env as tp_env from colossalai.testing import free_port, rerun_if_address_is_in_use, spawn -CONFIG_PATH_LIST = list(Path(__file__).parent.glob('configs/*.py')) +CONFIG_PATH_LIST = list(Path(__file__).parent.glob("configs/*.py")) def check_data_parallel_rank(rank): @@ -50,11 +49,11 @@ def check_model_parallel_rank(rank): def check_tensor_parallel_rank(rank): - if tp_env.mode == '2d': + if tp_env.mode == "2d": check_2d_tensor_parallel_rank(rank) - elif tp_env == '2.5d': + elif tp_env == "2.5d": check_2p5d_tensor_parallel_rank(rank) - elif tp_env == '3d': + elif tp_env == "3d": check_3d_tensor_parallel_rank(rank) @@ -115,13 +114,9 @@ def check_3d_tensor_parallel_rank(rank): def init_context(config_path, rank, world_size, backend, port, host): - dist_args = dict(config=config_path, - rank=rank, - world_size=world_size, - backend=backend, - port=port, - host=host, - verbose=True) + dist_args = dict( + config=config_path, rank=rank, world_size=world_size, backend=backend, port=port, host=host, verbose=True + ) launch(**dist_args) check_tensor_parallel_rank(rank) @@ -134,12 +129,9 @@ def init_context(config_path, rank, world_size, backend, port, host): def run_dist(rank, world_size, port, backend, port_list, host): for config_path, current_port in zip(CONFIG_PATH_LIST, port_list): - init_context(config_path=config_path, - rank=rank, - world_size=world_size, - backend=backend, - port=current_port, - host=host) + init_context( + config_path=config_path, rank=rank, world_size=world_size, backend=backend, port=current_port, host=host + ) reset_seeds() @@ -158,8 +150,8 @@ def test_context(): port_list.append(port) break - spawn(run_dist, world_size, backend='gloo', port_list=port_list, host='localhost') + spawn(run_dist, world_size, backend="gloo", port_list=port_list, host="localhost") -if __name__ == '__main__': +if __name__ == "__main__": test_context() diff --git a/tests/test_legacy/test_data/test_cifar10_dataset.py b/tests/test_legacy/test_data/test_cifar10_dataset.py index dfa9fa211ef0..4851f1b85817 100644 --- a/tests/test_legacy/test_data/test_cifar10_dataset.py +++ b/tests/test_legacy/test_data/test_cifar10_dataset.py @@ -4,7 +4,6 @@ import os from pathlib import Path -import pytest from torch.utils.data import DataLoader from torchvision import datasets, transforms @@ -15,7 +14,7 @@ def test_cifar10_dataset(): transform_pipeline = transforms.Compose(transform_pipeline) # build dataset - dataset = datasets.CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_pipeline) + dataset = datasets.CIFAR10(root=Path(os.environ["DATA"]), train=True, download=True, transform=transform_pipeline) # build dataloader dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True, 
num_workers=2) @@ -23,5 +22,5 @@ def test_cifar10_dataset(): img, label = data_iter.next() -if __name__ == '__main__': +if __name__ == "__main__": test_cifar10_dataset() diff --git a/tests/test_legacy/test_data/test_data_parallel_sampler.py b/tests/test_legacy/test_data/test_data_parallel_sampler.py index cf10fe9dfa3c..1786b4a77a8b 100644 --- a/tests/test_legacy/test_data/test_data_parallel_sampler.py +++ b/tests/test_legacy/test_data/test_data_parallel_sampler.py @@ -4,7 +4,6 @@ import os from pathlib import Path -import pytest import torch import torch.distributed as dist from torchvision import datasets, transforms @@ -16,24 +15,26 @@ from colossalai.legacy.utils import get_dataloader from colossalai.testing import rerun_if_address_is_in_use, spawn -CONFIG = Config(dict( - parallel=dict( - pipeline=dict(size=1), - tensor=dict(size=1, mode=None), - ), - seed=1024, -)) +CONFIG = Config( + dict( + parallel=dict( + pipeline=dict(size=1), + tensor=dict(size=1, mode=None), + ), + seed=1024, + ) +) def run_data_sampler(rank, world_size, port): - dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port=port, host='localhost') + dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend="gloo", port=port, host="localhost") colossalai.legacy.launch(**dist_args) - print('finished initialization') + print("finished initialization") # build dataset transform_pipeline = [transforms.ToTensor()] transform_pipeline = transforms.Compose(transform_pipeline) - dataset = datasets.CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_pipeline) + dataset = datasets.CIFAR10(root=Path(os.environ["DATA"]), train=True, download=True, transform=transform_pipeline) # build dataloader dataloader = get_dataloader(dataset, batch_size=8, add_sampler=True) @@ -50,7 +51,8 @@ def run_data_sampler(rank, world_size, port): if gpc.get_local_rank(ParallelMode.DATA) != 0: assert not torch.equal( - img, img_to_compare), 'Same image was distributed across ranks but expected it to be different' + img, img_to_compare + ), "Same image was distributed across ranks but expected it to be different" torch.cuda.empty_cache() @@ -59,5 +61,5 @@ def test_data_sampler(): spawn(run_data_sampler, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_data_sampler() diff --git a/tests/test_legacy/test_data/test_deterministic_dataloader.py b/tests/test_legacy/test_data/test_deterministic_dataloader.py index 421b8d255318..abb442f48203 100644 --- a/tests/test_legacy/test_data/test_deterministic_dataloader.py +++ b/tests/test_legacy/test_data/test_deterministic_dataloader.py @@ -4,7 +4,6 @@ import os from pathlib import Path -import pytest import torch import torch.distributed as dist from torchvision import datasets, transforms @@ -20,8 +19,8 @@ dict( train_data=dict( dataset=dict( - type='CIFAR10', - root=Path(os.environ['DATA']), + type="CIFAR10", + root=Path(os.environ["DATA"]), train=True, download=True, ), @@ -32,17 +31,18 @@ tensor=dict(size=1, mode=None), ), seed=1024, - )) + ) +) def run_data_sampler(rank, world_size, port): - dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port=port, host='localhost') + dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend="gloo", port=port, host="localhost") colossalai.legacy.launch(**dist_args) # build dataset transform_pipeline = [transforms.ToTensor(), transforms.RandomCrop(size=32, padding=4)] transform_pipeline = transforms.Compose(transform_pipeline) - dataset 
= datasets.CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_pipeline) + dataset = datasets.CIFAR10(root=Path(os.environ["DATA"]), train=True, download=True, transform=transform_pipeline) # build dataloader dataloader = get_dataloader(dataset, batch_size=8, add_sampler=False) @@ -60,8 +60,9 @@ def run_data_sampler(rank, world_size, port): if gpc.get_local_rank(ParallelMode.DATA) != 0: # this is without sampler # this should be false if data parallel sampler to given to the dataloader - assert torch.equal(img, - img_to_compare), 'Same image was distributed across ranks and expected it to be the same' + assert torch.equal( + img, img_to_compare + ), "Same image was distributed across ranks and expected it to be the same" torch.cuda.empty_cache() @@ -70,5 +71,5 @@ def test_data_sampler(): spawn(run_data_sampler, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_data_sampler() diff --git a/tests/test_legacy/test_engine/test_engine.py b/tests/test_legacy/test_engine/test_engine.py index 8499784038d2..b07fe8abe86e 100644 --- a/tests/test_legacy/test_engine/test_engine.py +++ b/tests/test_legacy/test_engine/test_engine.py @@ -6,25 +6,26 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from tests.components_to_test.registry import non_distributed_component_funcs -CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)), - fp16=dict(mode=None), - clip_grad_norm=1.0) +CONFIG = dict( + parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)), fp16=dict(mode=None), clip_grad_norm=1.0 +) -@parameterize('model_name', ['repeated_computed_layers', 'resnet18', 'repeated_computed_layers']) -@parameterize('amp_mode', [AMP_TYPE.APEX, AMP_TYPE.TORCH, AMP_TYPE.NAIVE, None]) +@parameterize("model_name", ["repeated_computed_layers", "resnet18", "repeated_computed_layers"]) +@parameterize("amp_mode", [AMP_TYPE.APEX, AMP_TYPE.TORCH, AMP_TYPE.NAIVE, None]) def run_train(model_name, amp_mode): # FIXME: test bert get_components_func = non_distributed_component_funcs.get_callable(model_name) - gpc.config.fp16['mode'] = amp_mode + gpc.config.fp16["mode"] = amp_mode model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() model = model_builder(checkpoint=False) - engine, train_dataloader, *args = colossalai.legacy.initialize(model=model, - optimizer=optimizer_class(model.parameters(), - lr=1e-3), - criterion=criterion, - train_dataloader=train_dataloader) + engine, train_dataloader, *args = colossalai.legacy.initialize( + model=model, + optimizer=optimizer_class(model.parameters(), lr=1e-3), + criterion=criterion, + train_dataloader=train_dataloader, + ) try: engine.train() @@ -49,12 +50,9 @@ def run_train(model_name, amp_mode): def run_engine(rank, world_size, port): # init dist env - colossalai.legacy.launch(config=CONFIG, - rank=rank, - world_size=world_size, - host='localhost', - port=port, - backend='nccl') + colossalai.legacy.launch( + config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl" + ) run_train() @@ -64,5 +62,5 @@ def test_engine(): spawn(run_engine, 2) -if __name__ == '__main__': +if __name__ == "__main__": test_engine() diff --git a/tests/test_legacy/test_engine/test_gradient_accumluation.py b/tests/test_legacy/test_engine/test_gradient_accumluation.py index 168c93c1a572..262876e0ba42 100644 --- a/tests/test_legacy/test_engine/test_gradient_accumluation.py +++ 
b/tests/test_legacy/test_engine/test_gradient_accumluation.py @@ -19,46 +19,40 @@ BATCH_SIZE = 2 NUM_CLASSES = 10 -CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)), - clip_grad_norm=1.0, - gradient_accumulation=4) +CONFIG = dict( + parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)), clip_grad_norm=1.0, gradient_accumulation=4 +) def run_no_pipeline(rank, world_size, port): - # init dist env - colossalai.legacy.launch(config=CONFIG, - rank=rank, - world_size=world_size, - host='localhost', - port=port, - backend='nccl') + colossalai.legacy.launch( + config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl" + ) # build model model = resnet18(num_classes=10) # build dataloaders - train_dataset = CIFAR10(root=Path(os.environ['DATA']), - download=True, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) - ])) - train_dataloader = get_dataloader(dataset=train_dataset, - shuffle=True, - batch_size=BATCH_SIZE, - pin_memory=True, - drop_last=True) + train_dataset = CIFAR10( + root=Path(os.environ["DATA"]), + download=True, + transform=transforms.Compose( + [transforms.ToTensor(), transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))] + ), + ) + train_dataloader = get_dataloader( + dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True, drop_last=True + ) # build optimizer optimizer = Adam(model.parameters(), lr=0.001) criterion = nn.CrossEntropyLoss() - engine, train_dataloader, *args = colossalai.legacy.initialize(model=model, - optimizer=optimizer, - criterion=criterion, - train_dataloader=train_dataloader) - logger = get_dist_logger() + engine, train_dataloader, *args = colossalai.legacy.initialize( + model=model, optimizer=optimizer, criterion=criterion, train_dataloader=train_dataloader + ) + get_dist_logger() rank = torch.distributed.get_rank() param_track = [] grad_track = [] @@ -79,12 +73,13 @@ def run_no_pipeline(rank, world_size, port): param_track.append(next(model.parameters())[0].clone()) grad_track.append(next(model.parameters()).grad[0].clone()) step += 1 - if step == CONFIG['gradient_accumulation']: + if step == CONFIG["gradient_accumulation"]: break - assert not torch.all(grad_track[0] == grad_track[-1]), 'grad should be different in different iterations' - assert torch.all(param_track[0] == param_track[1]) and not torch.all(param_track[0] == param_track[-1]), \ - 'param should be the same in the first few iterations and only changed in the last iteration' + assert not torch.all(grad_track[0] == grad_track[-1]), "grad should be different in different iterations" + assert torch.all(param_track[0] == param_track[1]) and not torch.all( + param_track[0] == param_track[-1] + ), "param should be the same in the first few iterations and only changed in the last iteration" gpc.destroy() torch.cuda.empty_cache() @@ -96,5 +91,5 @@ def test_engine(): spawn(run_no_pipeline, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_engine() diff --git a/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py b/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py index 859707e6129d..8a9a73d65f38 100644 --- a/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py +++ b/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py @@ -44,7 +44,7 @@ def check_linear_col(): W = W.clone() W.requires_grad = True - B_shape = (OUTPUT_SIZE) + B_shape = OUTPUT_SIZE B_master = 
torch.randn(B_shape, dtype=dtype, device=device) dist.broadcast(B_master, src=0) B = torch.chunk(B_master, DEPTH, dim=0)[i] @@ -65,7 +65,7 @@ def check_linear_col(): C = torch.chunk(C_master, DEPTH, dim=-1)[i] check_equal(out, C) - print_rank_0('linear_col forward: pass') + print_rank_0("linear_col forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) @@ -87,7 +87,7 @@ def check_linear_col(): B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] check_equal(B_grad, layer.bias.grad) - print_rank_0('linear_col backward: pass') + print_rank_0("linear_col backward: pass") def check_linear_row(): @@ -114,7 +114,7 @@ def check_linear_row(): W = W.clone() W.requires_grad = True - B_shape = (INPUT_SIZE) + B_shape = INPUT_SIZE B_master = torch.randn(B_shape, dtype=dtype, device=device) dist.broadcast(B_master, src=0) B = B_master.clone() @@ -134,7 +134,7 @@ def check_linear_row(): C = C_master.clone() check_equal(out, C) - print_rank_0('linear_row forward: pass') + print_rank_0("linear_row forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) @@ -155,7 +155,7 @@ def check_linear_row(): B_grad = B_master.grad check_equal(B_grad, layer.bias.grad) - print_rank_0('linear_row backward: pass') + print_rank_0("linear_row backward: pass") def check_embed(): @@ -184,7 +184,7 @@ def check_embed(): C_master = embed_master(A_master) C = C_master.clone() check_equal(out, C) - print_rank_0('embed forward: pass') + print_rank_0("embed forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -197,7 +197,7 @@ def check_embed(): B_grad = embed_master.weight.grad B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i] check_equal(B_grad, embed.weight.grad) - print_rank_0('embed backward: pass') + print_rank_0("embed backward: pass") def check_vocab_parallel_embed(): @@ -226,7 +226,7 @@ def check_vocab_parallel_embed(): C_master = embed_master(A_master) C = C_master.clone() check_equal(out, C) - print_rank_0('vocab parallel embed forward: pass') + print_rank_0("vocab parallel embed forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -239,7 +239,7 @@ def check_vocab_parallel_embed(): B_grad = embed_master.weight.grad B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] check_equal(B_grad, embed.weight.grad) - print_rank_0('vocab parallel embed backward: pass') + print_rank_0("vocab parallel embed backward: pass") def check_classifier_no_given_weight(): @@ -283,7 +283,7 @@ def check_classifier_no_given_weight(): C = C_master.clone() check_equal(out, C) - print_rank_0('classifier (no given weight) forward: pass') + print_rank_0("classifier (no given weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -305,7 +305,7 @@ def check_classifier_no_given_weight(): B_grad = layer_master.bias.grad check_equal(B_grad, layer.bias.grad) - print_rank_0('classifier (no given weight) backward: pass') + print_rank_0("classifier (no given weight) backward: pass") def check_vocab_parallel_classifier_no_given_weight(): @@ -343,7 +343,7 @@ def check_vocab_parallel_classifier_no_given_weight(): C = torch.chunk(C_master, DEPTH, dim=-1)[i] check_equal(out, C) - print_rank_0('vocab parallel classifier (no given weight) forward: pass') + print_rank_0("vocab parallel classifier (no given weight) forward: pass") grad_shape = C_master.shape grad_master 
= torch.randn(grad_shape, dtype=dtype, device=device) @@ -365,7 +365,7 @@ def check_vocab_parallel_classifier_no_given_weight(): B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] check_equal(B_grad, layer.bias.grad) - print_rank_0('vocab parallel classifier (no given weight) backward: pass') + print_rank_0("vocab parallel classifier (no given weight) backward: pass") def check_classifier_given_embed_weight(): @@ -401,7 +401,7 @@ def check_classifier_given_embed_weight(): C_master = layer_master(embed_master(A_master)) C = C_master.clone() check_equal(out, C) - print_rank_0('classifier (given embed weight) forward: pass') + print_rank_0("classifier (given embed weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -416,7 +416,7 @@ def check_classifier_given_embed_weight(): W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] check_equal(W_grad, embed.weight.grad) - print_rank_0('classifier (given embed weight) backward: pass') + print_rank_0("classifier (given embed weight) backward: pass") def check_vocab_parallel_classifier_given_embed_weight(): @@ -452,7 +452,7 @@ def check_vocab_parallel_classifier_given_embed_weight(): C_master = layer_master(embed_master(A_master)) C = torch.chunk(C_master, DEPTH, dim=-1)[i] check_equal(out, C) - print_rank_0('vocab parallel classifier (given embed weight) forward: pass') + print_rank_0("vocab parallel classifier (given embed weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -468,7 +468,7 @@ def check_vocab_parallel_classifier_given_embed_weight(): W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] check_equal(W_grad, embed.weight.grad) - print_rank_0('vocab parallel classifier (given embed weight) backward: pass') + print_rank_0("vocab parallel classifier (given embed weight) backward: pass") def check_vocab_parallel_loss(): @@ -495,7 +495,7 @@ def check_vocab_parallel_loss(): out_master.requires_grad = True loss_master = criterion_master(out_master, target_master) check_equal(loss, loss_master) - print_rank_0('vocab parallel loss forward: pass') + print_rank_0("vocab parallel loss forward: pass") loss.backward() loss_master.backward() @@ -503,7 +503,7 @@ def check_vocab_parallel_loss(): out_grad = out_master.grad out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[i] check_equal(out_grad, out.grad) - print_rank_0('vocab parallel loss backward: pass') + print_rank_0("vocab parallel loss backward: pass") @torch.no_grad() @@ -531,7 +531,7 @@ def check_linear_row_stream_inference(): W = torch.chunk(W_master, DEPTH, dim=-1)[i] W = W.clone() - B_shape = (INPUT_SIZE) + B_shape = INPUT_SIZE B_master = torch.randn(B_shape, dtype=dtype, device=device) dist.broadcast(B_master, src=0) B = B_master.clone() @@ -550,4 +550,4 @@ def check_linear_row_stream_inference(): C = C_master.clone() check_equal(out, C) - print_rank_0('linear_row forward: pass') + print_rank_0("linear_row forward: pass") diff --git a/tests/test_legacy/test_layers/test_1d/test_1d.py b/tests/test_legacy/test_layers/test_1d/test_1d.py index 2a016ed7b33d..cebbedd303ee 100644 --- a/tests/test_legacy/test_layers/test_1d/test_1d.py +++ b/tests/test_legacy/test_layers/test_1d/test_1d.py @@ -10,12 +10,14 @@ from colossalai.logging import disable_existing_loggers from colossalai.testing import rerun_if_address_is_in_use, spawn -CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='1d')),) +CONFIG = dict( + parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, 
mode="1d")), +) def check_layer(rank, world_size, port): disable_existing_loggers() - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") check_linear_col() check_linear_row() @@ -39,5 +41,5 @@ def test_1d(): spawn(check_layer, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_1d() diff --git a/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py b/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py index 494497be33e2..0bbc72eca809 100644 --- a/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py +++ b/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py @@ -48,7 +48,7 @@ def check_linear(): W = W.clone() W.requires_grad = True - B_shape = (OUTPUT_SIZE) + B_shape = OUTPUT_SIZE B_master = torch.randn(B_shape, dtype=dtype, device=device) torch.distributed.broadcast(B_master, src=0) B = torch.chunk(B_master, DEPTH, dim=-1)[j] @@ -71,7 +71,7 @@ def check_linear(): C = torch.chunk(C, DEPTH, dim=-1)[j] check_equal(out, C) - print_rank_0('linear forward: pass') + print_rank_0("linear forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) @@ -99,7 +99,7 @@ def check_linear(): # if i == 0: check_equal(B_grad, layer.bias.grad) - print_rank_0('linear backward: pass') + print_rank_0("linear backward: pass") def check_layernorm(): @@ -136,7 +136,7 @@ def check_layernorm(): C = torch.chunk(C, DEPTH, dim=-1)[j] check_equal(out, C) - print_rank_0('layer norm forward: pass') + print_rank_0("layer norm forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) @@ -150,7 +150,7 @@ def check_layernorm(): A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j] check_equal(A_grad, A.grad) - print_rank_0('layer norm backward: pass') + print_rank_0("layer norm backward: pass") def check_embed(): @@ -181,7 +181,7 @@ def check_embed(): C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[j] check_equal(out, C) - print_rank_0('embed forward: pass') + print_rank_0("embed forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -197,7 +197,7 @@ def check_embed(): B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i] check_equal(B_grad, embed.weight.grad) - print_rank_0('embed backward: pass') + print_rank_0("embed backward: pass") def check_patch_embed(): @@ -238,7 +238,7 @@ def check_patch_embed(): C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[j] check_equal(out, C) - print_rank_0('patch embed forward: pass') + print_rank_0("patch embed forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -270,7 +270,7 @@ def check_patch_embed(): bias_grad = torch.chunk(bias_grad, DEPTH)[j] bias_grad = torch.chunk(bias_grad, DEPTH)[i] check_equal(bias_grad, layer.bias.grad) - print_rank_0('patch embed backward: pass') + print_rank_0("patch embed backward: pass") def check_vocab_parallel_embed(): @@ -301,7 +301,7 @@ def check_vocab_parallel_embed(): C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[j] check_equal(out, C) - print_rank_0('vocab parallel embed forward: pass') + print_rank_0("vocab parallel embed forward: 
pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -317,7 +317,7 @@ def check_vocab_parallel_embed(): B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] check_equal(B_grad, embed.weight.grad) - print_rank_0('vocab parallel embed backward: pass') + print_rank_0("vocab parallel embed backward: pass") def check_classifier_no_given_weight(): @@ -368,7 +368,7 @@ def check_classifier_no_given_weight(): # C = torch.chunk(C, DEPTH, dim=-1)[j] check_equal(out, C) - print_rank_0('classifier (no given weight) forward: pass') + print_rank_0("classifier (no given weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) @@ -395,7 +395,7 @@ def check_classifier_no_given_weight(): # if i == 0: check_equal(B_grad, layer.bias.grad) - print_rank_0('classifier (no given weight) backward: pass') + print_rank_0("classifier (no given weight) backward: pass") def check_vocab_parallel_classifier_no_given_weight(): @@ -437,7 +437,7 @@ def check_vocab_parallel_classifier_no_given_weight(): C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[j] check_equal(out, C) - print_rank_0('vocab parallel classifier (no given weight) forward: pass') + print_rank_0("vocab parallel classifier (no given weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -463,7 +463,7 @@ def check_vocab_parallel_classifier_no_given_weight(): B_grad = torch.chunk(B_grad, DEPTH)[j] B_grad = torch.chunk(B_grad, DEPTH)[i] check_equal(B_grad, layer.bias.grad) - print_rank_0('vocab parallel classifier (no given weight) backward: pass') + print_rank_0("vocab parallel classifier (no given weight) backward: pass") def check_classifier_given_embed_weight(): @@ -499,7 +499,7 @@ def check_classifier_given_embed_weight(): C_master = layer_master(embed_master(A_master)) C = torch.chunk(C_master, DEPTH, dim=0)[i] check_equal(out, C) - print_rank_0('classifier (given embed weight) forward: pass') + print_rank_0("classifier (given embed weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -515,7 +515,7 @@ def check_classifier_given_embed_weight(): W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j] W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] check_equal(W_grad, embed.weight.grad) - print_rank_0('classifier (given embed weight) backward: pass') + print_rank_0("classifier (given embed weight) backward: pass") def check_vocab_parallel_classifier_given_embed_weight(): @@ -552,7 +552,7 @@ def check_vocab_parallel_classifier_given_embed_weight(): C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[j] check_equal(out, C) - print_rank_0('vocab parallel classifier (given embed weight) forward: pass') + print_rank_0("vocab parallel classifier (given embed weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -569,14 +569,14 @@ def check_vocab_parallel_classifier_given_embed_weight(): W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j] W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] check_equal(W_grad, embed.weight.grad) - print_rank_0('vocab parallel classifier (given embed weight) backward: pass') + print_rank_0("vocab parallel classifier (given embed weight) backward: pass") def check_loss(): device = get_current_device() dtype = torch.float32 - j = 
gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) criterion = CrossEntropyLoss2D() @@ -596,7 +596,7 @@ def check_loss(): out_master.requires_grad = True loss_master = criterion_master(out_master, target_master) check_equal(loss, loss_master) - print_rank_0('cross entropy loss forward: pass') + print_rank_0("cross entropy loss forward: pass") loss.backward() loss_master.backward() @@ -604,7 +604,7 @@ def check_loss(): out_grad = out_master.grad out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i] check_equal(out_grad, out.grad) - print_rank_0('cross entropy loss backward: pass') + print_rank_0("cross entropy loss backward: pass") def check_vocab_parallel_loss(): @@ -632,7 +632,7 @@ def check_vocab_parallel_loss(): out_master.requires_grad = True loss_master = criterion_master(out_master, target_master) check_equal(loss, loss_master) - print_rank_0('vocab parallel cross entropy loss forward: pass') + print_rank_0("vocab parallel cross entropy loss forward: pass") loss.backward() loss_master.backward() @@ -641,7 +641,7 @@ def check_vocab_parallel_loss(): out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i] out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[j] check_equal(out_grad, out.grad) - print_rank_0('vocab parallel cross entropy loss backward: pass') + print_rank_0("vocab parallel cross entropy loss backward: pass") # def check_attention(): diff --git a/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py b/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py index 034dbe5ca29c..9c126cefeba8 100644 --- a/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py +++ b/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py @@ -14,10 +14,12 @@ def check_AB(): data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) - pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( - ParallelMode.PIPELINE) - pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( - ParallelMode.PIPELINE) + pipeline_parallel_rank = ( + 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + ) + pipeline_parallel_size = ( + 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(ParallelMode.PIPELINE) + ) tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) dtype = torch.float @@ -42,10 +44,22 @@ def check_AB(): out_shape = (BATCH_SIZE // DEPTH, SEQ_LENGTH, 4 * HIDDEN_SIZE // DEPTH) - out = Matmul_AB_2D.apply(A, B, DEPTH, out_shape, i, j, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, - data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size) - - C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) + out = Matmul_AB_2D.apply( + A, + B, + DEPTH, + out_shape, + i, + j, + ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) + + (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) A_master = A_master.clone() A_master.requires_grad = True B_master = B_master.clone() @@ -55,7 +69,7 @@ def check_AB(): C = torch.chunk(C, DEPTH, dim=-1)[j] # check forward correctness check_equal(out, C) - print_rank_0('AB forward: pass') + print_rank_0("AB forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, 
dtype=dtype, device=get_current_device()) @@ -77,15 +91,17 @@ def check_AB(): B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] # check backward correctness check_equal(B_grad, B.grad) - print_rank_0('AB backward: pass') + print_rank_0("AB backward: pass") def check_ABT(): data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) - pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( - ParallelMode.PIPELINE) - pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( - ParallelMode.PIPELINE) + pipeline_parallel_rank = ( + 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + ) + pipeline_parallel_size = ( + 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(ParallelMode.PIPELINE) + ) tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) dtype = torch.float @@ -110,11 +126,22 @@ def check_ABT(): B = B.clone() B.requires_grad = True - out = Matmul_ABT_2D.apply(C, B, DEPTH, (BATCH_SIZE // DEPTH, SEQ_LENGTH, HIDDEN_SIZE // DEPTH), i, j, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, data_parallel_rank, - pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size) - - A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + out = Matmul_ABT_2D.apply( + C, + B, + DEPTH, + (BATCH_SIZE // DEPTH, SEQ_LENGTH, HIDDEN_SIZE // DEPTH), + i, + j, + ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) + + (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) C_master = C_master.clone() C_master.requires_grad = True B_master = B_master.clone() @@ -123,7 +150,7 @@ def check_ABT(): A = torch.chunk(A_master, DEPTH, dim=0)[i] A = torch.chunk(A, DEPTH, dim=-1)[j] check_equal(out, A) - print_rank_0('ABT forward: pass') + print_rank_0("ABT forward: pass") grad_shape = A_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -144,15 +171,17 @@ def check_ABT(): B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] check_equal(B_grad, B.grad) - print_rank_0('ABT backward: pass') + print_rank_0("ABT backward: pass") def check_ATB(): data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) - pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( - ParallelMode.PIPELINE) - pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( - ParallelMode.PIPELINE) + pipeline_parallel_rank = ( + 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + ) + pipeline_parallel_size = ( + 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(ParallelMode.PIPELINE) + ) tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) device = get_current_device() @@ -177,21 +206,33 @@ def check_ATB(): C = C.clone() C.requires_grad = True - out = Matmul_ATB_2D.apply(A, C, DEPTH, (HIDDEN_SIZE // DEPTH, 4 * HIDDEN_SIZE // DEPTH), i, j, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, data_parallel_rank, - pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size) - - B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) + out = Matmul_ATB_2D.apply( + A, + C, + DEPTH, + (HIDDEN_SIZE // DEPTH, 4 * HIDDEN_SIZE // DEPTH), + i, + j, + 
ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) + + (HIDDEN_SIZE, 4 * HIDDEN_SIZE) A_master = A_master.clone() A_master.requires_grad = True C_master = C_master.clone() C_master.requires_grad = True B_master = torch.matmul( - A_master.view(-1, A_master.shape[-1]).transpose(0, 1), C_master.view(-1, C_master.shape[-1])) + A_master.view(-1, A_master.shape[-1]).transpose(0, 1), C_master.view(-1, C_master.shape[-1]) + ) B = torch.chunk(B_master, DEPTH, dim=0)[i] B = torch.chunk(B, DEPTH, dim=-1)[j] check_equal(out, B) - print_rank_0('ATB forward: pass') + print_rank_0("ATB forward: pass") grad_shape = B_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -211,4 +252,4 @@ def check_ATB(): C_grad = torch.chunk(C_grad, DEPTH, dim=0)[i] C_grad = torch.chunk(C_grad, DEPTH, dim=-1)[j] check_equal(C_grad, C.grad) - print_rank_0('ATB backward: pass') + print_rank_0("ATB backward: pass") diff --git a/tests/test_legacy/test_layers/test_2d/test_2d.py b/tests/test_legacy/test_layers/test_2d/test_2d.py index a4b46793f19d..77a4b281a746 100644 --- a/tests/test_legacy/test_layers/test_2d/test_2d.py +++ b/tests/test_legacy/test_layers/test_2d/test_2d.py @@ -23,7 +23,9 @@ from colossalai.logging import disable_existing_loggers from colossalai.testing import rerun_if_address_is_in_use, spawn -CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='2d')),) +CONFIG = dict( + parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode="2d")), +) def check_operations(): @@ -48,7 +50,7 @@ def check_layer(): def check_layer_and_operation(rank, world_size, port): disable_existing_loggers() - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False @@ -65,5 +67,5 @@ def test_2d(): spawn(check_layer_and_operation, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_2d() diff --git a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py index e7a9a8be45d0..283e7f68374f 100644 --- a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py +++ b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py @@ -30,7 +30,7 @@ def check_linear(): i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) layer = Linear2p5D(INPUT_SIZE, OUTPUT_SIZE, dtype=dtype, skip_bias_add=False) @@ -50,7 +50,7 @@ def check_linear(): W = W.clone() W.requires_grad = True - B_shape = (OUTPUT_SIZE) + B_shape = OUTPUT_SIZE B_master = torch.randn(B_shape, dtype=dtype, device=device) torch.distributed.broadcast(B_master, src=0) B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[j] @@ -60,7 +60,7 @@ def check_linear(): layer.weight = Parameter(W) layer.bias = Parameter(B) out = layer(A) - bias = layer.bias + layer.bias A_master = A_master.clone() A_master.requires_grad = True @@ -73,7 +73,7 @@ def check_linear(): C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] check_equal(out, C) - print_rank_0('linear forward: pass') + print_rank_0("linear forward: pass") grad_shape = 
C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) @@ -100,7 +100,7 @@ def check_linear(): if i == 0: check_equal(B_grad, layer.bias.grad) - print_rank_0('linear backward: pass') + print_rank_0("linear backward: pass") def check_layernorm(): @@ -111,7 +111,7 @@ def check_layernorm(): i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) layernorm = LayerNorm2p5D(INPUT_SIZE, dtype=dtype) @@ -138,7 +138,7 @@ def check_layernorm(): C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] check_equal(out, C) - print_rank_0('layer norm forward: pass') + print_rank_0("layer norm forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) @@ -152,7 +152,7 @@ def check_layernorm(): A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i] A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j] check_equal(A_grad, A.grad) - print_rank_0('layer norm backward: pass') + print_rank_0("layer norm backward: pass") def check_embed(): @@ -160,7 +160,7 @@ def check_embed(): dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) embed = Embedding2p5D(VOCAB_SIZE, HIDDEN_SIZE) embed = embed.to(dtype).to(device) @@ -184,7 +184,7 @@ def check_embed(): C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] check_equal(out, C) - print_rank_0('embed forward: pass') + print_rank_0("embed forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -200,7 +200,7 @@ def check_embed(): B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j] B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[i] check_equal(B_grad, embed.weight.grad) - print_rank_0('embed backward: pass') + print_rank_0("embed backward: pass") def check_patch_embed(): @@ -208,7 +208,7 @@ def check_patch_embed(): dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) layer = PatchEmbedding2p5D(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype) torch.nn.init.ones_(layer.cls_token) @@ -242,7 +242,7 @@ def check_patch_embed(): C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] check_equal(out, C) - print_rank_0('patch embed forward: pass') + print_rank_0("patch embed forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -274,7 +274,7 @@ def check_patch_embed(): bias_grad = torch.chunk(bias_grad, TESSERACT_DIM)[j] bias_grad = torch.chunk(bias_grad, TESSERACT_DIM)[i] check_equal(bias_grad, layer.bias.grad) - print_rank_0('patch embed backward: pass') + print_rank_0("patch embed backward: pass") def check_vocab_parallel_embed(): @@ -282,7 +282,7 @@ def check_vocab_parallel_embed(): dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) embed = 
VocabParallelEmbedding2p5D(VOCAB_SIZE, HIDDEN_SIZE) embed = embed.to(dtype).to(device) @@ -306,7 +306,7 @@ def check_vocab_parallel_embed(): C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] check_equal(out, C) - print_rank_0('vocab parallel embed forward: pass') + print_rank_0("vocab parallel embed forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -322,7 +322,7 @@ def check_vocab_parallel_embed(): B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j] B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i] check_equal(B_grad, embed.weight.grad) - print_rank_0('vocab parallel embed backward: pass') + print_rank_0("vocab parallel embed backward: pass") def check_classifier_no_given_weight(): @@ -374,7 +374,7 @@ def check_classifier_no_given_weight(): # C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] check_equal(out, C) - print_rank_0('classifier (no given weight) forward: pass') + print_rank_0("classifier (no given weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) @@ -401,7 +401,7 @@ def check_classifier_no_given_weight(): # if i == 0: check_equal(B_grad, layer.bias.grad) - print_rank_0('classifier (no given weight) backward: pass') + print_rank_0("classifier (no given weight) backward: pass") def check_vocab_parallel_classifier_no_given_weight(): @@ -409,7 +409,7 @@ def check_vocab_parallel_classifier_no_given_weight(): dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) layer = VocabParallelClassifier2p5D(HIDDEN_SIZE, VOCAB_SIZE, bias=True) layer = layer.to(dtype).to(device) @@ -442,7 +442,7 @@ def check_vocab_parallel_classifier_no_given_weight(): C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] check_equal(out, C) - print_rank_0('vocab parallel classifier (no given weight) forward: pass') + print_rank_0("vocab parallel classifier (no given weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -468,7 +468,7 @@ def check_vocab_parallel_classifier_no_given_weight(): B_grad = torch.chunk(B_grad, TESSERACT_DIM)[j] if i == 0: check_equal(B_grad, layer.bias.grad) - print_rank_0('vocab parallel classifier (no given weight) backward: pass') + print_rank_0("vocab parallel classifier (no given weight) backward: pass") def check_classifier_given_embed_weight(): @@ -476,7 +476,7 @@ def check_classifier_given_embed_weight(): dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) embed = Embedding2p5D(VOCAB_SIZE, HIDDEN_SIZE) embed = embed.to(dtype).to(device) @@ -504,7 +504,7 @@ def check_classifier_given_embed_weight(): C_master = layer_master(embed_master(A_master)) C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] check_equal(out, C) - print_rank_0('classifier (given embed weight) forward: pass') + print_rank_0("classifier (given embed weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -520,7 +520,7 @@ def check_classifier_given_embed_weight(): W_grad = 
torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j] W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[i] check_equal(W_grad, embed.weight.grad) - print_rank_0('classifier (given embed weight) backward: pass') + print_rank_0("classifier (given embed weight) backward: pass") def check_vocab_parallel_classifier_given_embed_weight(): @@ -528,7 +528,7 @@ def check_vocab_parallel_classifier_given_embed_weight(): dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) embed = VocabParallelEmbedding2p5D(VOCAB_SIZE, HIDDEN_SIZE) embed = embed.to(dtype).to(device) @@ -557,7 +557,7 @@ def check_vocab_parallel_classifier_given_embed_weight(): C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] check_equal(out, C) - print_rank_0('vocab parallel classifier (given embed weight) forward: pass') + print_rank_0("vocab parallel classifier (given embed weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -574,15 +574,15 @@ def check_vocab_parallel_classifier_given_embed_weight(): W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j] W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=0)[i] check_equal(W_grad, embed.weight.grad) - print_rank_0('vocab parallel classifier (given embed weight) backward: pass') + print_rank_0("vocab parallel classifier (given embed weight) backward: pass") def check_loss(): device = get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) - j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) criterion = CrossEntropyLoss2p5D() criterion_master = torch.nn.CrossEntropyLoss() @@ -601,7 +601,7 @@ def check_loss(): out_master.requires_grad = True loss_master = criterion_master(out_master, target_master) check_equal(loss, loss_master) - print_rank_0('cross entropy loss forward: pass') + print_rank_0("cross entropy loss forward: pass") loss.backward() loss_master.backward() @@ -609,7 +609,7 @@ def check_loss(): out_grad = out_master.grad out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=0)[i] check_equal(out_grad, out.grad) - print_rank_0('cross entropy loss backward: pass') + print_rank_0("cross entropy loss backward: pass") def check_vocab_parallel_loss(): @@ -617,7 +617,7 @@ def check_vocab_parallel_loss(): dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) criterion = VocabParallelCrossEntropyLoss2p5D() criterion_master = torch.nn.CrossEntropyLoss() @@ -637,7 +637,7 @@ def check_vocab_parallel_loss(): out_master.requires_grad = True loss_master = criterion_master(out_master, target_master) check_equal(loss, loss_master) - print_rank_0('vocab parallel cross entropy loss forward: pass') + print_rank_0("vocab parallel cross entropy loss forward: pass") loss.backward() loss_master.backward() @@ -646,7 +646,7 @@ def check_vocab_parallel_loss(): out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=0)[i] out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=-1)[j] check_equal(out_grad, out.grad) - print_rank_0('vocab 
parallel cross entropy loss backward: pass') + print_rank_0("vocab parallel cross entropy loss backward: pass") # def check_attention(): diff --git a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py index fe78ef669bf0..992bd6107f08 100644 --- a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py +++ b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py @@ -11,10 +11,12 @@ def check_AB(): data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) - pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( - ParallelMode.PIPELINE) - pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( - ParallelMode.PIPELINE) + pipeline_parallel_rank = ( + 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + ) + pipeline_parallel_size = ( + 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(ParallelMode.PIPELINE) + ) tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) dtype = torch.float @@ -39,11 +41,23 @@ def check_AB(): B.requires_grad = True out_shape = (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, 4 * HIDDEN_SIZE // TESSERACT_DIM) - out = Matmul_AB_2p5D.apply(A, B, TESSERACT_DIM, out_shape, i, j, k, ParallelMode.PARALLEL_2P5D_ROW, - ParallelMode.PARALLEL_2P5D_COL, data_parallel_rank, pipeline_parallel_rank, - pipeline_parallel_size, tensor_parallel_size) - - C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) + out = Matmul_AB_2p5D.apply( + A, + B, + TESSERACT_DIM, + out_shape, + i, + j, + k, + ParallelMode.PARALLEL_2P5D_ROW, + ParallelMode.PARALLEL_2P5D_COL, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) + + (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) A_master = A_master.clone() A_master.requires_grad = True B_master = B_master.clone() @@ -53,7 +67,7 @@ def check_AB(): C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] # check forward correctness check_equal(out, C) - print_rank_0('AB forward: pass') + print_rank_0("AB forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) @@ -75,15 +89,17 @@ def check_AB(): B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j] # check backward correctness check_equal(B_grad, B.grad) - print_rank_0('AB backward: pass') + print_rank_0("AB backward: pass") def check_ABT(): data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) - pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( - ParallelMode.PIPELINE) - pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( - ParallelMode.PIPELINE) + pipeline_parallel_rank = ( + 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + ) + pipeline_parallel_size = ( + 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(ParallelMode.PIPELINE) + ) tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) dtype = torch.float @@ -109,12 +125,23 @@ def check_ABT(): B = B.clone() B.requires_grad = True - out = Matmul_ABT_2p5D.apply(C, B, TESSERACT_DIM, - (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, HIDDEN_SIZE // TESSERACT_DIM), i, j, k, - 
ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL, data_parallel_rank, - pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size) - - A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + out = Matmul_ABT_2p5D.apply( + C, + B, + TESSERACT_DIM, + (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, HIDDEN_SIZE // TESSERACT_DIM), + i, + j, + k, + ParallelMode.PARALLEL_2P5D_ROW, + ParallelMode.PARALLEL_2P5D_COL, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) + + (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) C_master = C_master.clone() C_master.requires_grad = True B_master = B_master.clone() @@ -123,7 +150,7 @@ def check_ABT(): A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] check_equal(out, A) - print_rank_0('ABT forward: pass') + print_rank_0("ABT forward: pass") grad_shape = A_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -144,15 +171,17 @@ def check_ABT(): B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i] B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j] check_equal(B_grad, B.grad) - print_rank_0('ABT backward: pass') + print_rank_0("ABT backward: pass") def check_ATB(): data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) - pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( - ParallelMode.PIPELINE) - pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( - ParallelMode.PIPELINE) + pipeline_parallel_rank = ( + 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + ) + pipeline_parallel_size = ( + 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(ParallelMode.PIPELINE) + ) tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) device = get_current_device() @@ -178,22 +207,34 @@ def check_ATB(): C = C.clone() C.requires_grad = True - out = Matmul_ATB_2p5D.apply(A, C, TESSERACT_DIM, (HIDDEN_SIZE // TESSERACT_DIM, 4 * HIDDEN_SIZE // TESSERACT_DIM), - i, j, k, ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL, - data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, - tensor_parallel_size) - - B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) + out = Matmul_ATB_2p5D.apply( + A, + C, + TESSERACT_DIM, + (HIDDEN_SIZE // TESSERACT_DIM, 4 * HIDDEN_SIZE // TESSERACT_DIM), + i, + j, + k, + ParallelMode.PARALLEL_2P5D_ROW, + ParallelMode.PARALLEL_2P5D_COL, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) + + (HIDDEN_SIZE, 4 * HIDDEN_SIZE) A_master = A_master.clone() A_master.requires_grad = True C_master = C_master.clone() C_master.requires_grad = True B_master = torch.matmul( - A_master.view(-1, A_master.shape[-1]).transpose(0, 1), C_master.view(-1, C_master.shape[-1])) + A_master.view(-1, A_master.shape[-1]).transpose(0, 1), C_master.view(-1, C_master.shape[-1]) + ) B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i] B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j] check_equal(out, B) - print_rank_0('ATB forward: pass') + print_rank_0("ATB forward: pass") grad_shape = B_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -213,4 +254,4 @@ def check_ATB(): C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=0)[i] C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=-1)[j] check_equal(C_grad, C.grad) - print_rank_0('ATB backward: pass') + 
print_rank_0("ATB backward: pass") diff --git a/tests/test_legacy/test_layers/test_2p5d/test_2p5d.py b/tests/test_legacy/test_layers/test_2p5d/test_2p5d.py index 38ba3ba78575..437a8f8a7265 100644 --- a/tests/test_legacy/test_layers/test_2p5d/test_2p5d.py +++ b/tests/test_legacy/test_layers/test_2p5d/test_2p5d.py @@ -8,10 +8,12 @@ from colossalai.logging import disable_existing_loggers from colossalai.testing import rerun_if_address_is_in_use, spawn -CONFIG = dict(parallel=dict( - pipeline=dict(size=1), - tensor=dict(size=4, mode='2.5d', depth=1), -),) +CONFIG = dict( + parallel=dict( + pipeline=dict(size=1), + tensor=dict(size=4, mode="2.5d", depth=1), + ), +) def check_operations(): @@ -36,7 +38,7 @@ def check_layer(): def check_layer_and_operation(rank, world_size, port): disable_existing_loggers() - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False @@ -53,5 +55,5 @@ def test_2p5d(): spawn(check_layer_and_operation, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_2p5d() diff --git a/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py b/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py index 2a9dcc3cdc16..a4a4ae9a5ba4 100644 --- a/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py +++ b/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py @@ -73,14 +73,15 @@ def check_linear(): torch.cuda.synchronize() fwd_end = time.time() print_rank_0( - 'linear forward: {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger) + "linear forward: {0} --> {1} | {2:.3f} s".format(tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger + ) A_master = A_master.clone() A_master.requires_grad = True C_master = layer_master(A_master) C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[j] C = torch.chunk(C, DEPTH, dim=0)[k] - logger.info('Rank {} linear forward: {}'.format(rank, check_equal(out, C))) + logger.info("Rank {} linear forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape grad_master = torch.randn(grad_shape, device=get_current_device()) @@ -93,24 +94,24 @@ def check_linear(): out.backward(grad) torch.cuda.synchronize() bwd_end = time.time() - print_rank_0('linear backward: {:.3f} s'.format(bwd_end - bwd_start), logger) + print_rank_0("linear backward: {:.3f} s".format(bwd_end - bwd_start), logger) C_master.backward(grad_master) A_grad = A_master.grad A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k] A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j] - logger.info('Rank {} linear backward (input_grad): {}'.format(rank, check_equal(A_grad, A.grad))) + logger.info("Rank {} linear backward (input_grad): {}".format(rank, check_equal(A_grad, A.grad))) B_grad = layer_master.weight.grad.transpose(0, 1) B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k] B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] - logger.info('Rank {} linear backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad))) + logger.info("Rank {} linear backward (weight_grad): {}".format(rank, check_equal(B_grad, layer.weight.grad))) bias_grad = layer_master.bias.grad bias_grad = torch.chunk(bias_grad, DEPTH)[j] - logger.info('Rank {} linear 
backward (bias_grad): {}'.format(rank, check_equal(bias_grad, layer.bias.grad))) + logger.info("Rank {} linear backward (bias_grad): {}".format(rank, check_equal(bias_grad, layer.bias.grad))) return fwd_end - fwd_start, bwd_end - bwd_start @@ -157,8 +158,11 @@ def check_layernorm(): torch.cuda.synchronize() fwd_end = time.time() print_rank_0( - 'layer norm forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), - fwd_end - fwd_start), logger) + "layer norm forward: pass | {0} --> {1} | {2:.3f} s".format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start + ), + logger, + ) A_master = A_master.clone() A_master.requires_grad = True @@ -166,7 +170,7 @@ def check_layernorm(): C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[k] C = torch.chunk(C, DEPTH, dim=0)[j] - logger.info('Rank {} layernorm forward: {}'.format(rank, check_equal(out, C))) + logger.info("Rank {} layernorm forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape grad_master = torch.randn(grad_shape, device=device) @@ -179,22 +183,22 @@ def check_layernorm(): out.backward(grad) torch.cuda.synchronize() bwd_end = time.time() - print_rank_0('layer norm backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger) + print_rank_0("layer norm backward: pass | {:.3f} s".format(bwd_end - bwd_start), logger) C_master.backward(grad_master) A_grad = A_master.grad A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k] A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j] - logger.info('Rank {} layernorm backward (input_grad): {}'.format(rank, check_equal(A_grad, A.grad))) + logger.info("Rank {} layernorm backward (input_grad): {}".format(rank, check_equal(A_grad, A.grad))) bias_grad = norm_master.weight.grad bias_grad = torch.chunk(bias_grad, DEPTH)[k] - logger.info('Rank {} layernorm backward (weight_grad): {}'.format(rank, check_equal(bias_grad, norm.weight.grad))) + logger.info("Rank {} layernorm backward (weight_grad): {}".format(rank, check_equal(bias_grad, norm.weight.grad))) bias_grad = norm_master.bias.grad bias_grad = torch.chunk(bias_grad, DEPTH)[k] - logger.info('Rank {} layernorm backward (bias_grad): {}'.format(rank, check_equal(bias_grad, norm.bias.grad))) + logger.info("Rank {} layernorm backward (bias_grad): {}".format(rank, check_equal(bias_grad, norm.bias.grad))) return fwd_end - fwd_start, bwd_end - bwd_start @@ -241,14 +245,17 @@ def check_classifier_no_given_weight(): torch.cuda.synchronize() fwd_end = time.time() print_rank_0( - 'classifier (no given weight) forward: pass | {0} --> {1} | {2:.3f} s'.format( - tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger) + "classifier (no given weight) forward: pass | {0} --> {1} | {2:.3f} s".format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start + ), + logger, + ) A_master = A_master.clone() A_master.requires_grad = True C_master = layer_master(A_master) C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=0)[j] - logger.info('Rank {} classifier (no given weight) forward: {}'.format(rank, check_equal(out, C))) + logger.info("Rank {} classifier (no given weight) forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape grad_master = torch.randn(grad_shape, device=get_current_device()) @@ -261,7 +268,7 @@ def check_classifier_no_given_weight(): out.backward(grad) torch.cuda.synchronize() bwd_end = time.time() - print_rank_0('classifier (no given weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), 
logger) + print_rank_0("classifier (no given weight) backward: pass | {:.3f} s".format(bwd_end - bwd_start), logger) grad_master = grad_master.clone() C_master.backward(grad_master) @@ -269,21 +276,29 @@ def check_classifier_no_given_weight(): A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k] A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j] - logger.info('Rank {} classifier (no given weight) backward (input_grad): {}'.format( - rank, check_equal(A_grad, A.grad))) + logger.info( + "Rank {} classifier (no given weight) backward (input_grad): {}".format(rank, check_equal(A_grad, A.grad)) + ) B_grad = layer_master.weight.grad B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] if j == k: - logger.info('Rank {} classifier (no given weight) backward (weight_grad): {}'.format( - rank, check_equal(B_grad, layer.weight.grad))) + logger.info( + "Rank {} classifier (no given weight) backward (weight_grad): {}".format( + rank, check_equal(B_grad, layer.weight.grad) + ) + ) else: - logger.info('Rank {} classifier (no given weight) backward (weight_grad): {}'.format( - rank, layer.weight.grad is None)) + logger.info( + "Rank {} classifier (no given weight) backward (weight_grad): {}".format(rank, layer.weight.grad is None) + ) bias_grad = layer_master.bias.grad - logger.info('Rank {} classifier (no given weight) backward (bias_grad): {}'.format( - rank, check_equal(bias_grad, layer.bias.grad))) + logger.info( + "Rank {} classifier (no given weight) backward (bias_grad): {}".format( + rank, check_equal(bias_grad, layer.bias.grad) + ) + ) return fwd_end - fwd_start, bwd_end - bwd_start @@ -333,15 +348,18 @@ def check_vocab_parallel_classifier_no_given_weight(): torch.cuda.synchronize() fwd_end = time.time() print_rank_0( - 'vocab parallel classifier (no given weight) forward: pass | {0} --> {1} | {2:.3f} s'.format( - tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger) + "vocab parallel classifier (no given weight) forward: pass | {0} --> {1} | {2:.3f} s".format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start + ), + logger, + ) A_master = A_master.clone() A_master.requires_grad = True C_master = layer_master(A_master) C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[j] C = torch.chunk(C, DEPTH, dim=0)[k] - logger.info('Rank {} vocab parallel classifier (no given weight) forward: {}'.format(rank, check_equal(out, C))) + logger.info("Rank {} vocab parallel classifier (no given weight) forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape grad_master = torch.randn(grad_shape, device=device) @@ -355,8 +373,9 @@ def check_vocab_parallel_classifier_no_given_weight(): out.backward(grad) torch.cuda.synchronize() bwd_end = time.time() - print_rank_0('vocab parallel classifier (no given weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), - logger) + print_rank_0( + "vocab parallel classifier (no given weight) backward: pass | {:.3f} s".format(bwd_end - bwd_start), logger + ) grad_master = grad_master.clone() C_master.backward(grad_master) @@ -364,20 +383,29 @@ def check_vocab_parallel_classifier_no_given_weight(): A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k] A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j] - logger.info('Rank {} vocab parallel classifier (no given weight) backward (input_grad): {}'.format( - rank, check_equal(A_grad, A.grad))) + logger.info( + "Rank {} vocab parallel classifier (no given weight) backward (input_grad): {}".format( + rank, 
check_equal(A_grad, A.grad) + ) + ) B_grad = layer_master.weight.grad B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j] B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] - logger.info('Rank {} vocab parallel classifier (no given weight) backward (weight_grad): {}'.format( - rank, check_equal(B_grad, layer.weight.grad))) + logger.info( + "Rank {} vocab parallel classifier (no given weight) backward (weight_grad): {}".format( + rank, check_equal(B_grad, layer.weight.grad) + ) + ) bias_grad = layer_master.bias.grad bias_grad = torch.chunk(bias_grad, DEPTH)[j] - logger.info('Rank {} vocab parallel classifier (no given weight) backward (bias_grad): {}'.format( - rank, check_equal(bias_grad, layer.bias.grad))) + logger.info( + "Rank {} vocab parallel classifier (no given weight) backward (bias_grad): {}".format( + rank, check_equal(bias_grad, layer.bias.grad) + ) + ) return fwd_end - fwd_start, bwd_end - bwd_start @@ -423,13 +451,16 @@ def check_classifier_given_embed_weight(): torch.cuda.synchronize() fwd_end = time.time() print_rank_0( - 'classifier (given embed weight) forward: pass | {0} --> {1} | {2:.3f} s'.format( - tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger) + "classifier (given embed weight) forward: pass | {0} --> {1} | {2:.3f} s".format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start + ), + logger, + ) A_master = A_master.clone() C_master = layer_master(embed_master(A_master)) C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=0)[j] - logger.info('Rank {} classifier (given embed weight) forward: {}'.format(rank, check_equal(out, C))) + logger.info("Rank {} classifier (given embed weight) forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) @@ -442,7 +473,7 @@ def check_classifier_given_embed_weight(): out.backward(grad) torch.cuda.synchronize() bwd_end = time.time() - print_rank_0('classifier (given embed weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger) + print_rank_0("classifier (given embed weight) backward: pass | {:.3f} s".format(bwd_end - bwd_start), logger) grad_master = grad_master.clone() C_master.backward(grad_master) @@ -450,11 +481,15 @@ def check_classifier_given_embed_weight(): B_grad = embed_master.weight.grad B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] if j == k: - logger.info('Rank {} classifier (given embed weight) backward (weight_grad): {}'.format( - rank, check_equal(B_grad, embed.weight.grad))) + logger.info( + "Rank {} classifier (given embed weight) backward (weight_grad): {}".format( + rank, check_equal(B_grad, embed.weight.grad) + ) + ) else: - logger.info('Rank {} classifier (given embed weight) backward (weight_grad): {}'.format( - rank, embed.weight.grad is None)) + logger.info( + "Rank {} classifier (given embed weight) backward (weight_grad): {}".format(rank, embed.weight.grad is None) + ) return fwd_end - fwd_start, bwd_end - bwd_start @@ -501,14 +536,17 @@ def check_vocab_parallel_classifier_given_embed_weight(): torch.cuda.synchronize() fwd_end = time.time() print_rank_0( - 'vocab parallel classifier (given embed weight) forward: pass | {0} --> {1} | {2:.3f} s'.format( - tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger) + "vocab parallel classifier (given embed weight) forward: pass | {0} --> {1} | {2:.3f} s".format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start + ), + logger, + ) A_master = A_master.clone() 
C_master = layer_master(embed_master(A_master)) C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[j] C = torch.chunk(C, DEPTH, dim=0)[k] - logger.info('Rank {} vocab parallel classifier (given embed weight) forward: {}'.format(rank, check_equal(out, C))) + logger.info("Rank {} vocab parallel classifier (given embed weight) forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape grad_master = torch.randn(grad_shape, device=device) @@ -522,8 +560,9 @@ def check_vocab_parallel_classifier_given_embed_weight(): out.backward(grad) torch.cuda.synchronize() bwd_end = time.time() - print_rank_0('vocab parallel classifier (given embed weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), - logger) + print_rank_0( + "vocab parallel classifier (given embed weight) backward: pass | {:.3f} s".format(bwd_end - bwd_start), logger + ) grad_master = grad_master.clone() C_master.backward(grad_master) @@ -532,9 +571,9 @@ def check_vocab_parallel_classifier_given_embed_weight(): B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j] B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] - logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank, - check_equal(B_grad, - embed.weight.grad))) + logger.info( + "Rank {} vocab parallel embed backward (weight_grad): {}".format(rank, check_equal(B_grad, embed.weight.grad)) + ) return fwd_end - fwd_start, bwd_end - bwd_start @@ -543,7 +582,7 @@ def check_patch_embed(): rank = torch.distributed.get_rank() device = get_current_device() logger = get_dist_logger() - dtype = torch.float32 + torch.float32 input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) @@ -582,15 +621,18 @@ def check_patch_embed(): torch.cuda.synchronize() fwd_end = time.time() print_rank_0( - 'patch embed forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), - fwd_end - fwd_start), logger) + "patch embed forward: pass | {0} --> {1} | {2:.3f} s".format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start + ), + logger, + ) A_master = A_master.clone() C_master = layer_master(A_master) C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[k] C = torch.chunk(C, DEPTH, dim=0)[j] - logger.info('Rank {} patch embed forward: {}'.format(rank, check_equal(out, C))) + logger.info("Rank {} patch embed forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape grad_master = torch.randn(grad_shape, device=device) @@ -604,29 +646,32 @@ def check_patch_embed(): out.backward(grad) torch.cuda.synchronize() bwd_end = time.time() - print_rank_0('patch embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger) + print_rank_0("patch embed backward: pass | {:.3f} s".format(bwd_end - bwd_start), logger) grad_master = grad_master.clone() C_master.backward(grad_master) cls_grad_master = layer_master.cls_token.grad cls_grad = torch.chunk(cls_grad_master, DEPTH, dim=-1)[k] - logger.info('Rank {} patch embed backward (cls_grad): {}'.format(rank, check_equal(cls_grad, layer.cls_token.grad))) + logger.info("Rank {} patch embed backward (cls_grad): {}".format(rank, check_equal(cls_grad, layer.cls_token.grad))) pos_grad_master = layer_master.pos_embed.grad pos_grad = torch.chunk(pos_grad_master, DEPTH, dim=-1)[k] - logger.info('Rank {} patch embed backward (pos_embed_grad): {}'.format(rank, - check_equal(pos_grad, layer.pos_embed.grad))) + logger.info( + "Rank {} 
patch embed backward (pos_embed_grad): {}".format(rank, check_equal(pos_grad, layer.pos_embed.grad)) + ) B_grad = layer_master.weight.grad B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k] - logger.info('Rank {} patch embed backward (proj_weight_grad): {}'.format(rank, - check_equal(B_grad, layer.weight.grad))) + logger.info( + "Rank {} patch embed backward (proj_weight_grad): {}".format(rank, check_equal(B_grad, layer.weight.grad)) + ) bias_grad = layer_master.bias.grad bias_grad = torch.chunk(bias_grad, DEPTH)[k] - logger.info('Rank {} patch embed backward (proj_bias_grad): {}'.format(rank, - check_equal(bias_grad, layer.bias.grad))) + logger.info( + "Rank {} patch embed backward (proj_bias_grad): {}".format(rank, check_equal(bias_grad, layer.bias.grad)) + ) return fwd_end - fwd_start, bwd_end - bwd_start @@ -635,7 +680,7 @@ def check_embed(): rank = torch.distributed.get_rank() device = get_current_device() logger = get_dist_logger() - dtype = torch.float32 + torch.float32 input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) @@ -664,16 +709,17 @@ def check_embed(): out = layer(A) torch.cuda.synchronize() fwd_end = time.time() - logger.info('embed forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), - fwd_end - fwd_start), - ranks=[0]) + logger.info( + "embed forward: pass | {0} --> {1} | {2:.3f} s".format(tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), + ranks=[0], + ) A_master = A_master.clone() C_master = layer_master(A_master) C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[k] C = torch.chunk(C, DEPTH, dim=0)[j] - logger.info('Rank {} embed forward: {}'.format(rank, check_equal(out, C))) + logger.info("Rank {} embed forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape grad_master = torch.randn(grad_shape, device=device) @@ -686,14 +732,14 @@ def check_embed(): out.backward(grad) torch.cuda.synchronize() bwd_end = time.time() - logger.info('embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0]) + logger.info("embed backward: pass | {:.3f} s".format(bwd_end - bwd_start), ranks=[0]) grad_master = grad_master.clone() C_master.backward(grad_master) B_grad = layer_master.weight.grad B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] - logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad))) + logger.info("Rank {} embed backward (weight_grad): {}".format(rank, check_equal(B_grad, layer.weight.grad))) return fwd_end - fwd_start, bwd_end - bwd_start @@ -702,7 +748,7 @@ def check_vocab_parallel_embed(): rank = torch.distributed.get_rank() device = get_current_device() logger = get_dist_logger() - dtype = torch.float32 + torch.float32 input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) @@ -733,16 +779,19 @@ def check_vocab_parallel_embed(): out = layer(A) torch.cuda.synchronize() fwd_end = time.time() - logger.info('vocab parallel embed forward: pass | {0} --> {1} | {2:.3f} s'.format( - tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), - ranks=[0]) + logger.info( + "vocab parallel embed forward: pass | {0} --> {1} | {2:.3f} s".format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start + ), + ranks=[0], + ) A_master = A_master.clone() C_master = layer_master(A_master) C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[k] C = torch.chunk(C, DEPTH, 
dim=0)[j] - logger.info('Rank {} vocab parallel embed forward: {}'.format(rank, check_equal(out, C))) + logger.info("Rank {} vocab parallel embed forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape grad_master = torch.randn(grad_shape, device=device) @@ -755,7 +804,7 @@ def check_vocab_parallel_embed(): out.backward(grad) torch.cuda.synchronize() bwd_end = time.time() - logger.info('vocab parallel embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0]) + logger.info("vocab parallel embed backward: pass | {:.3f} s".format(bwd_end - bwd_start), ranks=[0]) grad_master = grad_master.clone() C_master.backward(grad_master) @@ -764,9 +813,9 @@ def check_vocab_parallel_embed(): B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j] B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] - logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank, - check_equal(B_grad, - layer.weight.grad))) + logger.info( + "Rank {} vocab parallel embed backward (weight_grad): {}".format(rank, check_equal(B_grad, layer.weight.grad)) + ) return fwd_end - fwd_start, bwd_end - bwd_start @@ -798,25 +847,28 @@ def check_loss(): fwd_start = time.time() loss = criterion(out, target_master) fwd_end = time.time() - logger.info('cross entropy loss forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(out.shape), tuple(loss.shape), - fwd_end - fwd_start), - ranks=[0]) + logger.info( + "cross entropy loss forward: pass | {0} --> {1} | {2:.3f} s".format( + tuple(out.shape), tuple(loss.shape), fwd_end - fwd_start + ), + ranks=[0], + ) out_master = out_master.clone() out_master.requires_grad = True loss_master = criterion_master(out_master, target_master) - logger.info('Rank {} cross entropy loss forward: {}'.format(rank, check_equal(loss, loss_master))) + logger.info("Rank {} cross entropy loss forward: {}".format(rank, check_equal(loss, loss_master))) bwd_start = time.time() loss.backward() bwd_end = time.time() - logger.info('cross entropy loss backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0]) + logger.info("cross entropy loss backward: pass | {:.3f} s".format(bwd_end - bwd_start), ranks=[0]) loss_master.backward() out_grad = out_master.grad out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i] out_grad = torch.chunk(out_grad, DEPTH, dim=0)[j] - logger.info('Rank {} cross entropy loss backward: {}'.format(rank, check_equal(out_grad, out.grad))) + logger.info("Rank {} cross entropy loss backward: {}".format(rank, check_equal(out_grad, out.grad))) return fwd_end - fwd_start, bwd_end - bwd_start @@ -825,7 +877,7 @@ def check_vocab_parallel_loss(): rank = torch.distributed.get_rank() logger = get_dist_logger() device = get_current_device() - dtype = torch.float32 + torch.float32 input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) @@ -852,25 +904,28 @@ def check_vocab_parallel_loss(): fwd_start = time.time() loss = criterion(out, target_master) fwd_end = time.time() - logger.info('vocab parallel cross entropy loss forward: pass | {0} --> {1} | {2:.3f} s'.format( - tuple(out.shape), tuple(loss.shape), fwd_end - fwd_start), - ranks=[0]) + logger.info( + "vocab parallel cross entropy loss forward: pass | {0} --> {1} | {2:.3f} s".format( + tuple(out.shape), tuple(loss.shape), fwd_end - fwd_start + ), + ranks=[0], + ) out_master = out_master.clone() out_master.requires_grad = True loss_master = criterion_master(out_master, target_master) - 
logger.info('Rank {} vocab parallel cross entropy loss forward: {}'.format(rank, check_equal(loss, loss_master))) + logger.info("Rank {} vocab parallel cross entropy loss forward: {}".format(rank, check_equal(loss, loss_master))) bwd_start = time.time() loss.backward() bwd_end = time.time() - logger.info('vocab parallel cross entropy loss backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0]) + logger.info("vocab parallel cross entropy loss backward: pass | {:.3f} s".format(bwd_end - bwd_start), ranks=[0]) loss_master.backward() out_grad = out_master.grad out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i] out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[k] out_grad = torch.chunk(out_grad, DEPTH, dim=0)[j] - logger.info('Rank {} vocab parallel cross entropy loss backward: {}'.format(rank, check_equal(out_grad, out.grad))) + logger.info("Rank {} vocab parallel cross entropy loss backward: {}".format(rank, check_equal(out_grad, out.grad))) return fwd_end - fwd_start, bwd_end - bwd_start diff --git a/tests/test_legacy/test_layers/test_3d/test_3d.py b/tests/test_legacy/test_layers/test_3d/test_3d.py index 2a32d8935c00..7057e2308b39 100644 --- a/tests/test_legacy/test_layers/test_3d/test_3d.py +++ b/tests/test_legacy/test_layers/test_3d/test_3d.py @@ -23,7 +23,7 @@ CONFIG = dict( parallel=dict( pipeline=1, - tensor=dict(mode='3d', size=8), + tensor=dict(mode="3d", size=8), ), seed=42, ) @@ -44,7 +44,7 @@ def check_layer(): def check_layer_and_operation(rank, world_size, port): disable_existing_loggers() - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False torch.backends.cudnn.deterministic = True @@ -60,5 +60,5 @@ def test_3d(): spawn(check_layer_and_operation, 8) -if __name__ == '__main__': +if __name__ == "__main__": test_3d() diff --git a/tests/test_legacy/test_layers/test_cache_embedding.py b/tests/test_legacy/test_layers/test_cache_embedding.py index c58445a396ec..d64ff56b8a65 100644 --- a/tests/test_legacy/test_layers/test_cache_embedding.py +++ b/tests/test_legacy/test_layers/test_cache_embedding.py @@ -38,10 +38,19 @@ def synthesize_1d_sparse_feature( ): indices_in_batch = batch_size * 2 indices = torch.randint(low=0, high=num_embed, size=(indices_in_batch,), device=device, dtype=torch.long) - offsets = torch.from_numpy( - np.array([ - 0, *np.sort(np.random.randint(low=0, high=indices_in_batch, size=(indices_in_batch - 1,))), indices_in_batch - ])).to(device).long() + offsets = ( + torch.from_numpy( + np.array( + [ + 0, + *np.sort(np.random.randint(low=0, high=indices_in_batch, size=(indices_in_batch - 1,))), + indices_in_batch, + ] + ) + ) + .to(device) + .long() + ) return indices, offsets @@ -89,7 +98,7 @@ def test_reorder_with_freq(): chunkid.append(idx // chunk_size) offset_in_chunk.append(idx % chunk_size) - dev = torch.device('cuda') + dev = torch.device("cuda") chunkid = torch.tensor(chunkid, dtype=torch.long, device=dev) offset_in_chunk = torch.tensor(offset_in_chunk, dtype=torch.long, device=dev) @@ -99,31 +108,31 @@ def test_reorder_with_freq(): mgr.reorder(idx_map) indices = mgr.idx_map.index_select(0, torch.arange(num_embed, dtype=torch.long, device=dev)) - mgr_chunk_id = torch.div(indices, chunk_size, rounding_mode='floor') + mgr_chunk_id = torch.div(indices, chunk_size, rounding_mode="floor") mgr_offsets = 
torch.remainder(indices, chunk_size) assert torch.allclose(chunkid, mgr_chunk_id), f"chunk id: {chunkid}, mgr: {mgr_chunk_id}" - assert torch.allclose(offset_in_chunk, mgr_offsets), \ - f"offset in chunk: {offset_in_chunk}, mgr: {mgr_offsets}" + assert torch.allclose(offset_in_chunk, mgr_offsets), f"offset in chunk: {offset_in_chunk}, mgr: {mgr_offsets}" @clear_cache_before_run() -@parameterize('use_LFU', [True, False]) +@parameterize("use_LFU", [True, False]) def test_freq_aware_embed(use_LFU: bool): - device = torch.device('cuda', 0) + device = torch.device("cuda", 0) evict_strategy = EvictionStrategy.LFU if use_LFU else EvictionStrategy.DATASET - model = CachedEmbeddingBag(NUM_EMBED, - EMBED_DIM, - mode='mean', - include_last_offset=True, - cache_ratio=min(BATCH_SIZE * 2 / NUM_EMBED, 1.0), - ids_freq_mapping=None, - evict_strategy=evict_strategy).to(device) + model = CachedEmbeddingBag( + NUM_EMBED, + EMBED_DIM, + mode="mean", + include_last_offset=True, + cache_ratio=min(BATCH_SIZE * 2 / NUM_EMBED, 1.0), + ids_freq_mapping=None, + evict_strategy=evict_strategy, + ).to(device) assert model.weight.shape[0] == NUM_EMBED - ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device), - mode='mean', - include_last_offset=True, - freeze=False) + ref_model = torch.nn.EmbeddingBag.from_pretrained( + model.weight.detach().to(device), mode="mean", include_last_offset=True, freeze=False + ) assert torch.allclose(ref_model.weight.detach(), model.weight.detach().to(device)) @@ -149,22 +158,25 @@ def test_freq_aware_embed(use_LFU: bool): model.cache_weight_mgr.flush() model_weight = model.weight.detach().to(device) ref_weight = ref_model.weight.detach() - assert torch.allclose(model_weight, ref_weight), \ - f"model weight: {model_weight[10:18, :8]}, reference: {ref_weight[10:18, :8]}" + assert torch.allclose( + model_weight, ref_weight + ), f"model weight: {model_weight[10:18, :8]}, reference: {ref_weight[10:18, :8]}" @clear_cache_before_run() -@parameterize('init_freq', [True, False]) +@parameterize("init_freq", [True, False]) def test_lfu_strategy(init_freq: bool): # minimal test to check behavior - Bag = CachedEmbeddingBag(5, - 5, - cache_ratio=3 / 5, - buffer_size=0, - pin_weight=True, - ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None, - warmup_ratio=1.0, - evict_strategy=EvictionStrategy.LFU) + Bag = CachedEmbeddingBag( + 5, + 5, + cache_ratio=3 / 5, + buffer_size=0, + pin_weight=True, + ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None, + warmup_ratio=1.0, + evict_strategy=EvictionStrategy.LFU, + ) # print('cached_idx_map: ', Bag.cache_weight_mgr.cached_idx_map) offsets = torch.tensor([0], device="cuda:0") @@ -189,14 +201,15 @@ def test_lfu_strategy(init_freq: bool): # check strategy Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets) Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets) - Bag.forward(torch.tensor([3], device="cuda:0"), offsets) # miss, evict 1 - Bag.forward(torch.tensor([2], device="cuda:0"), offsets) # hit - Bag.forward(torch.tensor([4], device="cuda:0"), offsets) # miss, evict 3 - Bag.forward(torch.tensor([2], device="cuda:0"), offsets) # hit - Bag.forward(torch.tensor([0], device="cuda:0"), offsets) # hit + Bag.forward(torch.tensor([3], device="cuda:0"), offsets) # miss, evict 1 + Bag.forward(torch.tensor([2], device="cuda:0"), offsets) # hit + Bag.forward(torch.tensor([4], device="cuda:0"), offsets) # miss, evict 3 + Bag.forward(torch.tensor([2], device="cuda:0"), offsets) # hit + Bag.forward(torch.tensor([0], 
device="cuda:0"), offsets) # hit - assert torch.allclose(torch.Tensor(Bag.cache_weight_mgr.num_hits_history[-6:]), torch.Tensor([3, 0, 1, 0, 1, 1])), \ - "LFU strategy behavior failed" + assert torch.allclose( + torch.Tensor(Bag.cache_weight_mgr.num_hits_history[-6:]), torch.Tensor([3, 0, 1, 0, 1, 1]) + ), "LFU strategy behavior failed" def gather_tensor(tensor, rank, world_size): @@ -211,7 +224,7 @@ def gather_tensor(tensor, rank, world_size): def run_parallel_freq_aware_embed_tablewise(rank, world_size): if world_size != 2: return - device = torch.device('cuda', torch.cuda.current_device()) + device = torch.device("cuda", torch.cuda.current_device()) # initialize weight # 3 feature tables. idx: 0~5, 6~10, 11~17 @@ -221,20 +234,20 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size): weight_table3 = weight_tables[11:18] embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = [] embedding_bag_config_list.append( - TablewiseEmbeddingBagConfig(num_embeddings=6, - cuda_row_num=4, - assigned_rank=0, - initial_weight=weight_table1.clone().detach().cpu())) + TablewiseEmbeddingBagConfig( + num_embeddings=6, cuda_row_num=4, assigned_rank=0, initial_weight=weight_table1.clone().detach().cpu() + ) + ) embedding_bag_config_list.append( - TablewiseEmbeddingBagConfig(num_embeddings=5, - cuda_row_num=4, - assigned_rank=0, - initial_weight=weight_table2.clone().detach().cpu())) + TablewiseEmbeddingBagConfig( + num_embeddings=5, cuda_row_num=4, assigned_rank=0, initial_weight=weight_table2.clone().detach().cpu() + ) + ) embedding_bag_config_list.append( - TablewiseEmbeddingBagConfig(num_embeddings=7, - cuda_row_num=4, - assigned_rank=1, - initial_weight=weight_table3.clone().detach().cpu())) + TablewiseEmbeddingBagConfig( + num_embeddings=7, cuda_row_num=4, assigned_rank=1, initial_weight=weight_table3.clone().detach().cpu() + ) + ) if rank == 0: _weight = torch.cat([weight_table1, weight_table2], 0) else: @@ -249,7 +262,7 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size): evict_strategy=EvictionStrategy.LFU, ) # explain - ''' + """ batch feature 1 feature 2 feature 3 input0 [1,2,3] [6,7] [] input1 [] [9] [13,15] @@ -257,10 +270,12 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size): ↑ ↑ ↑ rank 0 rank 0 rank 1 in KJT format - ''' - res = model(torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device), - torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device), - already_split_along_rank=False) + """ + res = model( + torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device), + torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device), + already_split_along_rank=False, + ) optimizer = torch.optim.SGD(model.parameters(), lr=1e-2) rand_grad = torch.rand(3, 5 * 3, dtype=res.dtype, device=res.device) if rank == 0: @@ -273,13 +288,15 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size): # check correctness if rank == 0: - ref_model = torch.nn.EmbeddingBag.from_pretrained(weight_tables.detach().clone(), - include_last_offset=True, - freeze=False).to(device) + ref_model = torch.nn.EmbeddingBag.from_pretrained( + weight_tables.detach().clone(), include_last_offset=True, freeze=False + ).to(device) ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-2) ref_fake_grad = torch.cat(rand_grad.split(5, 1), 0) - ref_res = ref_model(torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device), - torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device)) + ref_res = ref_model( + torch.tensor([1, 2, 
3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device), + torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device), + ) ref_res.backward(ref_fake_grad) ref_optimizer.step() ref_optimizer.zero_grad() @@ -291,7 +308,7 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size): def run_parallel_freq_aware_embed_columnwise(rank, world_size): - device = torch.device('cuda', torch.cuda.current_device()) + device = torch.device("cuda", torch.cuda.current_device()) num_embed = 100 embed_dim = 16 @@ -313,19 +330,20 @@ def run_parallel_freq_aware_embed_columnwise(rank, world_size): cache_ratio=batch_size * 2 / num_embed, ) - assert model.cache_weight_mgr.weight.device.type == 'cpu' + assert model.cache_weight_mgr.weight.device.type == "cpu" assert model.cache_weight_mgr.cuda_cached_weight.requires_grad weight_in_rank = torch.tensor_split(weight, world_size, -1)[rank] print(f"model weight: {model.cache_weight_mgr.weight.shape}, ref weight: {weight_in_rank.shape}") - assert torch.allclose(weight_in_rank, - model.cache_weight_mgr.weight.detach()), f"{weight_in_rank - model.cache_weight_mgr.weight}" + assert torch.allclose( + weight_in_rank, model.cache_weight_mgr.weight.detach() + ), f"{weight_in_rank - model.cache_weight_mgr.weight}" optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) if rank == 0: - ref_model = torch.nn.EmbeddingBag.from_pretrained(weight.detach().clone(), - include_last_offset=True, - freeze=False).to(device) + ref_model = torch.nn.EmbeddingBag.from_pretrained( + weight.detach().clone(), include_last_offset=True, freeze=False + ).to(device) ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-3) set_seed(4321) @@ -360,19 +378,19 @@ def run_parallel_freq_aware_embed_columnwise(rank, world_size): def run_dist(rank, world_size, port): - colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # run_parallel_freq_aware_embed_columnwise(rank, world_size) run_parallel_freq_aware_embed_tablewise(rank, world_size) @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.parametrize("world_size", [1, 4]) @rerun_if_address_is_in_use() def test_parallel_freq_aware_embed(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": # test_freq_aware_embed(True) test_parallel_freq_aware_embed(2) # test_lfu_strategy(False) diff --git a/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py b/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py index ac9493adab2e..aa4d5d6ceeb3 100644 --- a/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py +++ b/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py @@ -16,6 +16,7 @@ def check_selfattention(): layer = layer.to(get_current_device()) hidden_states = torch.rand(SUB_SEQ_LENGTH, BATCH, HIDDEN_SIZE).to(get_current_device()) - attention_mask = torch.randint(low=0, high=2, - size=(BATCH, 1, 1, 1, SUB_SEQ_LENGTH * WORLD_SIZE)).to(get_current_device()) - out = layer(hidden_states, attention_mask) + attention_mask = torch.randint(low=0, high=2, size=(BATCH, 1, 1, 1, SUB_SEQ_LENGTH * WORLD_SIZE)).to( + get_current_device() + ) + layer(hidden_states, attention_mask) diff --git a/tests/test_legacy/test_layers/test_sequence/test_sequence.py b/tests/test_legacy/test_layers/test_sequence/test_sequence.py index 
85226f9d934a..bdd3e04c6479 100644 --- a/tests/test_legacy/test_layers/test_sequence/test_sequence.py +++ b/tests/test_legacy/test_layers/test_sequence/test_sequence.py @@ -8,7 +8,7 @@ from colossalai.legacy.nn.layer.parallel_sequence import RingAV, RingQK from colossalai.testing import rerun_if_address_is_in_use, spawn -CONFIG = dict(parallel=dict(tensor=dict(size=4, mode='sequence'))) +CONFIG = dict(parallel=dict(tensor=dict(size=4, mode="sequence"))) def check_ring_qk(rank, world_size): @@ -26,8 +26,8 @@ def check_ring_qk(rank, world_size): dist.broadcast(k, src=0, group=gpc.get_group(ParallelMode.SEQUENCE)) # create distributed tensors - sub_q = q.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous() - sub_k = k.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous() + sub_q = q.clone()[:, rank * sub_seq_length : (rank + 1) * sub_seq_length].contiguous() + sub_k = k.clone()[:, rank * sub_seq_length : (rank + 1) * sub_seq_length].contiguous() # set autograd attributes q.requires_grad = True @@ -47,7 +47,7 @@ def check_ring_qk(rank, world_size): sub_a = ring_qk(sub_q, sub_k, batch_size, num_heads, sub_seq_length) # check master and distributed attention scores - sub_master_a = a[:, rank * sub_seq_length:(rank + 1) * sub_seq_length] + sub_master_a = a[:, rank * sub_seq_length : (rank + 1) * sub_seq_length] assert torch.allclose(sub_a, sub_master_a, rtol=1e-5, atol=1e-2) # run master backward @@ -55,13 +55,12 @@ def check_ring_qk(rank, world_size): a.mean().backward() # run distributed backward - partial_master_a_grad = a.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length] + partial_master_a_grad = a.grad[:, rank * sub_seq_length : (rank + 1) * sub_seq_length] torch.autograd.backward(sub_a, partial_master_a_grad) # check master and distributed grads - partial_master_q_grad = q.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length] - assert torch.allclose(sub_q.grad, partial_master_q_grad, rtol=1e-5, atol=1e-2), \ - 'attention score cannot match' + partial_master_q_grad = q.grad[:, rank * sub_seq_length : (rank + 1) * sub_seq_length] + assert torch.allclose(sub_q.grad, partial_master_q_grad, rtol=1e-5, atol=1e-2), "attention score cannot match" def check_ring_av(rank, world_size): @@ -79,8 +78,8 @@ def check_ring_av(rank, world_size): dist.broadcast(v, src=0, group=gpc.get_group(ParallelMode.SEQUENCE)) # create distributed tensors - sub_a = a.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous() - sub_v = v.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous() + sub_a = a.clone()[:, rank * sub_seq_length : (rank + 1) * sub_seq_length].contiguous() + sub_v = v.clone()[:, rank * sub_seq_length : (rank + 1) * sub_seq_length].contiguous() # set autograd attributes a.requires_grad = True @@ -102,7 +101,7 @@ def check_ring_av(rank, world_size): # print(f'master output shape: {out.shape}, partial output shape: {sub_out.shape}') # check master and distributed output - sub_master_out = out[:, rank * sub_seq_length:(rank + 1) * sub_seq_length] + sub_master_out = out[:, rank * sub_seq_length : (rank + 1) * sub_seq_length] assert torch.allclose(sub_out, sub_master_out, rtol=1e-5, atol=1e-2) # # run master backward @@ -110,17 +109,16 @@ def check_ring_av(rank, world_size): out.mean().backward() # # run distributed backward - partial_master_out_grad = out.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length] + partial_master_out_grad = out.grad[:, rank * sub_seq_length : (rank + 1) * sub_seq_length] 
torch.autograd.backward(sub_out, partial_master_out_grad) # # check master and distributed grads - partial_master_a_grad = a.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length] - assert torch.allclose(sub_a.grad, partial_master_a_grad, rtol=1e-5, atol=1e-2), \ - 'attention output cannot match' + partial_master_a_grad = a.grad[:, rank * sub_seq_length : (rank + 1) * sub_seq_length] + assert torch.allclose(sub_a.grad, partial_master_a_grad, rtol=1e-5, atol=1e-2), "attention output cannot match" def run_test(rank, world_size, port): - colossalai.legacy.launch(rank=rank, world_size=world_size, config=CONFIG, host='localhost', port=port) + colossalai.legacy.launch(rank=rank, world_size=world_size, config=CONFIG, host="localhost", port=port) # check_ring_qk(rank, world_size) check_ring_av(rank, world_size) @@ -135,5 +133,5 @@ def test_sequence(): spawn(run_test, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_sequence() diff --git a/tests/test_legacy/test_pipeline/rpc_test_utils.py b/tests/test_legacy/test_pipeline/rpc_test_utils.py index 9a336c4224be..e59f22062cfc 100644 --- a/tests/test_legacy/test_pipeline/rpc_test_utils.py +++ b/tests/test_legacy/test_pipeline/rpc_test_utils.py @@ -3,12 +3,10 @@ import warnings import torch -import torch.distributed as dist import torch.distributed.rpc as rpc import torch.multiprocessing as mp from torch import nn from torch._C._distributed_rpc import _is_current_rpc_agent_set -from torch.optim import SGD, Adam, Optimizer, RMSprop from colossalai.legacy import launch from colossalai.legacy.pipeline.pipeline_process_group import ppg @@ -17,13 +15,12 @@ rpc_is_initialized = _is_current_rpc_agent_set -def color_debug(text, prefix=' ', color='blue'): +def color_debug(text, prefix=" ", color="blue"): color = color.upper() print(getattr(Back, color), prefix, Style.RESET_ALL, text) class MLP(nn.Module): - def __init__(self, dim: int, layers: int): super().__init__() self.layers = torch.nn.ModuleList() @@ -38,7 +35,6 @@ def forward(self, x): class DAG_MLP(nn.Module): - def __init__(self, dim: int, layers: int): super().__init__() self.layers = torch.nn.ModuleList() @@ -55,12 +51,11 @@ def forward(self, x, y): class RpcTestModel(nn.Module): - def __init__(self, stage_id, actual_stage_num, feat_num, h) -> None: super().__init__() self.rank = stage_id self.is_last_rank = stage_id == actual_stage_num - 1 - self.linear_name = f'linear_{stage_id}' + self.linear_name = f"linear_{stage_id}" if stage_id == 0: linear = nn.Linear(feat_num, h) @@ -82,38 +77,38 @@ def forward(self, x) -> torch.Tensor: def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('--epoch', type=int, default=1) - parser.add_argument('--world_size', type=int, default=2) - parser.add_argument('--batch_size', type=int, default=16) - parser.add_argument('--dp_degree', type=int, default=1) - parser.add_argument('--tp_degree', type=int, default=1) - parser.add_argument('--num_microbatches', type=int, default=2) - parser.add_argument('--chunk', type=int, default=1) - parser.add_argument('--use_checkpoint', action='store_true') - parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'RMSprop'], default='SGD') - parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda') - parser.add_argument('--master_addr', type=str, default='localhost') - parser.add_argument('--master_port', type=str, default='29020') - parser.add_argument('--num_worker_threads', type=str, default=128) + parser.add_argument("--epoch", type=int, default=1) + 
parser.add_argument("--world_size", type=int, default=2) + parser.add_argument("--batch_size", type=int, default=16) + parser.add_argument("--dp_degree", type=int, default=1) + parser.add_argument("--tp_degree", type=int, default=1) + parser.add_argument("--num_microbatches", type=int, default=2) + parser.add_argument("--chunk", type=int, default=1) + parser.add_argument("--use_checkpoint", action="store_true") + parser.add_argument("--optimizer", type=str, choices=["SGD", "Adam", "RMSprop"], default="SGD") + parser.add_argument("--device", type=str, choices=["cpu", "cuda"], default="cuda") + parser.add_argument("--master_addr", type=str, default="localhost") + parser.add_argument("--master_port", type=str, default="29020") + parser.add_argument("--num_worker_threads", type=str, default=128) return parser.parse_args() def pg_parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('--world_size', type=int, default=4) - parser.add_argument('--dp_degree', type=int, default=2) - parser.add_argument('--tp_degree', type=int, default=1) - parser.add_argument('--chunk', type=int, default=1) - parser.add_argument('--num_worker_threads', type=str, default=128) - parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda') - parser.add_argument('--master_addr', type=str, default='localhost') - parser.add_argument('--master_port', type=str, default='29020') + parser.add_argument("--world_size", type=int, default=4) + parser.add_argument("--dp_degree", type=int, default=2) + parser.add_argument("--tp_degree", type=int, default=1) + parser.add_argument("--chunk", type=int, default=1) + parser.add_argument("--num_worker_threads", type=str, default=128) + parser.add_argument("--device", type=str, choices=["cpu", "cuda"], default="cuda") + parser.add_argument("--master_addr", type=str, default="localhost") + parser.add_argument("--master_port", type=str, default="29020") return parser.parse_args() def run_worker(rank, args, master_func): - os.environ['MASTER_ADDR'] = args.master_addr - os.environ['MASTER_PORT'] = args.master_port + os.environ["MASTER_ADDR"] = args.master_addr + os.environ["MASTER_PORT"] = args.master_port device = args.device world_size = args.world_size @@ -122,17 +117,19 @@ def run_worker(rank, args, master_func): num_worker_threads = args.num_worker_threads host = args.master_addr port = args.master_port - backend = 'nccl' if device == 'cuda' else 'gloo' + backend = "nccl" if device == "cuda" else "gloo" disable_existing_loggers() launch(dict(), rank, world_size, host, int(port), backend, verbose=False) - ppg.set_global_info(rank=rank, - world_size=world_size, - dp_degree=dp_degree, - tp_degree=tp_degree, - num_worker_threads=num_worker_threads, - device=device) + ppg.set_global_info( + rank=rank, + world_size=world_size, + dp_degree=dp_degree, + tp_degree=tp_degree, + num_worker_threads=num_worker_threads, + device=device, + ) # in rpc mode, only rank 0 is needed to be coded if rank == 0: diff --git a/tests/test_legacy/test_pipeline/test_cuda_rpc_chimera.py b/tests/test_legacy/test_pipeline/test_cuda_rpc_chimera.py index 3bff08318d40..f6c077136607 100644 --- a/tests/test_legacy/test_pipeline/test_cuda_rpc_chimera.py +++ b/tests/test_legacy/test_pipeline/test_cuda_rpc_chimera.py @@ -4,7 +4,6 @@ from torch import nn from colossalai.legacy.pipeline.rpc import ChimeraPipelineEngine -from colossalai.testing import assert_close # global variable for model created feat_num = 100 @@ -20,7 +19,7 @@ def partition(pp_rank: int, chunk: int, stage_num: int): def 
run_master(args): torch.manual_seed(100) - epoch = args.epoch + args.epoch device = args.device stage_num = args.world_size chunk = 1 @@ -32,11 +31,13 @@ def run_master(args): assert sample_num % batch_size == 0 - engine = ChimeraPipelineEngine(partition_fn=partition, - stage_num=stage_num, - num_microbatches=num_microbatches, - device=device, - checkpoint=use_checkpoint) + engine = ChimeraPipelineEngine( + partition_fn=partition, + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + checkpoint=use_checkpoint, + ) engine.initialize_optimizer(torch.optim.Adam, lr=1e-3) input_sample = torch.randn((sample_num, feat_num), device=device) @@ -56,7 +57,8 @@ def run_master(args): # compute forward result and backward grad of parameters just in rank_0 test_model = nn.Sequential( - *[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)]).to(device) + *[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)] + ).to(device) # input_sample = input_sample[len(input_sample) // 2:] input_sample = input_sample.requires_grad_() out_val = test_model(input_sample).sum() diff --git a/tests/test_legacy/test_pipeline/test_cuda_rpc_optimizer.py b/tests/test_legacy/test_pipeline/test_cuda_rpc_optimizer.py index eff031ff8faa..806f24a64511 100644 --- a/tests/test_legacy/test_pipeline/test_cuda_rpc_optimizer.py +++ b/tests/test_legacy/test_pipeline/test_cuda_rpc_optimizer.py @@ -1,9 +1,9 @@ import torch from rpc_test_utils import RpcTestModel, parse_args, rpc_run from torch import autograd, nn -from torch.optim import SGD, Adam, Optimizer, RMSprop +from torch.optim import Optimizer -from colossalai.legacy.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine +from colossalai.legacy.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine from colossalai.testing import assert_close # global variable for model created @@ -36,12 +36,14 @@ def run_master(args): input_sample = torch.randn((sample_num, feat_num), device=device) - engine = OneFOneBPipelineEngine(partition_fn=partition, - stage_num=stage_num, - num_microbatches=num_microbatches, - device=device, - chunk=chunk, - checkpoint=use_checkpoint) + engine = OneFOneBPipelineEngine( + partition_fn=partition, + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + chunk=chunk, + checkpoint=use_checkpoint, + ) engine.initialize_optimizer(optimizer_class, lr=lr) @@ -59,7 +61,8 @@ def run_master(args): # compute forward result and backward grad of parameters just in rank_0 test_model = nn.Sequential( - *[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)]).to(device) + *[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)] + ).to(device) optimizer: Optimizer = optimizer_class(test_model.parameters(), lr=lr) input_sample = input_sample.requires_grad_() out_val = test_model(input_sample).sum() diff --git a/tests/test_legacy/test_pipeline/test_cuda_rpc_pipeline.py b/tests/test_legacy/test_pipeline/test_cuda_rpc_pipeline.py index 1a6077f8d3e9..a5e8fc6e6b51 100644 --- a/tests/test_legacy/test_pipeline/test_cuda_rpc_pipeline.py +++ b/tests/test_legacy/test_pipeline/test_cuda_rpc_pipeline.py @@ -1,8 +1,7 @@ import torch from rpc_test_utils import RpcTestModel, parse_args, rpc_run -from torch import nn -from colossalai.legacy.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine +from colossalai.legacy.pipeline.rpc._pipeline_schedule import 
OneFOneBPipelineEngine # global variable for model created feat_num = 100 @@ -32,12 +31,14 @@ def run_master(args): input_sample = torch.randn((sample_num, feat_num), device=device) - engine = OneFOneBPipelineEngine(partition_fn=partition, - stage_num=stage_num, - num_microbatches=num_microbatches, - device=device, - chunk=chunk, - checkpoint=use_checkpoint) + engine = OneFOneBPipelineEngine( + partition_fn=partition, + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + chunk=chunk, + checkpoint=use_checkpoint, + ) for _ in range(epoch): _ = engine.forward_backward(input_sample, forward_only=False) diff --git a/tests/test_legacy/test_pipeline/test_cuda_rpc_value_correctness.py b/tests/test_legacy/test_pipeline/test_cuda_rpc_value_correctness.py index 43966ce3dbda..09c9b84a9907 100644 --- a/tests/test_legacy/test_pipeline/test_cuda_rpc_value_correctness.py +++ b/tests/test_legacy/test_pipeline/test_cuda_rpc_value_correctness.py @@ -2,7 +2,7 @@ from rpc_test_utils import RpcTestModel, parse_args, rpc_run from torch import autograd, nn -from colossalai.legacy.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine +from colossalai.legacy.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine from colossalai.testing import assert_close feat_num = 100 @@ -32,12 +32,14 @@ def run_master(args): input_sample = torch.randn((sample_num, feat_num), device=device) - engine = OneFOneBPipelineEngine(partition_fn=partition, - stage_num=stage_num, - num_microbatches=num_microbatches, - device=device, - chunk=chunk, - checkpoint=use_checkpoint) + engine = OneFOneBPipelineEngine( + partition_fn=partition, + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + chunk=chunk, + checkpoint=use_checkpoint, + ) forward_result = engine.forward_backward(input_sample) @@ -54,7 +56,8 @@ def run_master(args): # compute forward result and backward grad of parameters just in rank_0 test_model = nn.Sequential( - *[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)]).to(device) + *[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)] + ).to(device) input_sample = input_sample.requires_grad_() out_val = test_model(input_sample).sum() autograd.backward(out_val) diff --git a/tests/test_legacy/test_pipeline/test_middleware_1f1b.py b/tests/test_legacy/test_pipeline/test_middleware_1f1b.py index 4e43d52f8aee..dff04c3ebba1 100644 --- a/tests/test_legacy/test_pipeline/test_middleware_1f1b.py +++ b/tests/test_legacy/test_pipeline/test_middleware_1f1b.py @@ -25,7 +25,7 @@ def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs): model.eval() tracer = ColoTracer() - meta_args = {k: v.to('meta') for k, v in data_kwargs.items()} + meta_args = {k: v.to("meta") for k, v in data_kwargs.items()} graph = tracer.trace(root=model, meta_args=meta_args) gm = torch.fx.GraphModule(model, graph, model.__class__.__name__) annotated_model = balanced_split_pass(gm, stage_num) @@ -33,7 +33,7 @@ def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs): topo = get_fx_topology(top_module) for submodule in split_submodules: if isinstance(submodule, torch.fx.GraphModule): - setattr(submodule, '_topo', topo) + setattr(submodule, "_topo", topo) return split_submodules[pp_rank + 1] @@ -47,11 +47,11 @@ def run_master(model_cls, world_size, forward_only): torch.manual_seed(100) epoch = 3 - device = 'cuda' + device = "cuda" stage_num = world_size chunk = 1 num_microbatches 
= 8 - use_checkpoint = 'store_true' + use_checkpoint = "store_true" if model_cls == MLP: @@ -92,29 +92,26 @@ def data_gen(): checkpoint=use_checkpoint, ) if not forward_only: - engine.initialize_optimizer(getattr(torch.optim, 'SGD'), lr=1e-3) + engine.initialize_optimizer(getattr(torch.optim, "SGD"), lr=1e-3) for _ in range(epoch): input_x = torch.randn((batch_size, dim), device=device) input_y = torch.randn((batch_size, dim), device=device) - logits = engine.forward_backward({'x': input_x, 'y': input_y}, labels=labels, forward_only=forward_only) + logits = engine.forward_backward({"x": input_x, "y": input_y}, labels=labels, forward_only=forward_only) def run_worker(rank, world_size, port, model_cls, forward_only, master_func): - master_addr = 'localhost' + master_addr = "localhost" master_port = 29020 - os.environ['MASTER_ADDR'] = master_addr - os.environ['MASTER_PORT'] = str(master_port) + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = str(master_port) disable_existing_loggers() - launch(dict(), rank, world_size, master_addr, master_port, 'nccl', verbose=False) - ppg.set_global_info(rank=rank, - world_size=world_size, - dp_degree=1, - tp_degree=1, - num_worker_threads=128, - device='cuda') + launch(dict(), rank, world_size, master_addr, master_port, "nccl", verbose=False) + ppg.set_global_info( + rank=rank, world_size=world_size, dp_degree=1, tp_degree=1, num_worker_threads=128, device="cuda" + ) # in rpc mode, only rank 0 is needed to be coded if rank == 0: @@ -125,8 +122,8 @@ def run_worker(rank, world_size, port, model_cls, forward_only, master_func): @pytest.mark.skip("skip due to CI torch version 1.11") -@parameterize('model_cls', [MLP, DAG_MLP]) -@parameterize('forward_only', [True, False]) +@parameterize("model_cls", [MLP, DAG_MLP]) +@parameterize("forward_only", [True, False]) @pytest.mark.dist @rerun_if_address_is_in_use() def test_pp_middleware_fwd(model_cls, forward_only): diff --git a/tests/test_legacy/test_pipeline/test_pipelinable.py b/tests/test_legacy/test_pipeline/test_pipelinable.py index 2ba5d0aa24d8..950cc68036ae 100644 --- a/tests/test_legacy/test_pipeline/test_pipelinable.py +++ b/tests/test_legacy/test_pipeline/test_pipelinable.py @@ -2,14 +2,13 @@ import torch from colossalai.legacy.pipeline.pipelinable import PipelinableContext -from colossalai.testing import rerun_if_address_is_in_use, rerun_on_exception, spawn +from colossalai.testing import rerun_if_address_is_in_use, spawn NUM_CHUNKS = 1 PIPELINE_SIZE = 2 class MLP(torch.nn.Module): - def __init__(self, dim: int = 256): super().__init__() intermediate_dim = dim * 4 @@ -55,5 +54,5 @@ def test_pipelinable(): spawn(run_pipelinable, 1) -if __name__ == '__main__': +if __name__ == "__main__": test_pipelinable() diff --git a/tests/test_legacy/test_pipeline/test_pipeline_process_group.py b/tests/test_legacy/test_pipeline/test_pipeline_process_group.py index e6b95660279b..627aafb18e61 100644 --- a/tests/test_legacy/test_pipeline/test_pipeline_process_group.py +++ b/tests/test_legacy/test_pipeline/test_pipeline_process_group.py @@ -10,8 +10,8 @@ def run_worker(rank, args): - os.environ['MASTER_ADDR'] = args.master_addr - os.environ['MASTER_PORT'] = args.master_port + os.environ["MASTER_ADDR"] = args.master_addr + os.environ["MASTER_PORT"] = args.master_port device = args.device world_size = args.world_size @@ -20,17 +20,19 @@ def run_worker(rank, args): num_worker_threads = args.num_worker_threads host = args.master_addr port = args.master_port - backend = 'nccl' if device == 'cuda' else 'gloo' + 
backend = "nccl" if device == "cuda" else "gloo" disable_existing_loggers() launch(dict(), rank, world_size, host, int(port), backend, verbose=False) - ppg.set_global_info(rank=rank, - world_size=world_size, - dp_degree=dp_degree, - tp_degree=tp_degree, - num_worker_threads=num_worker_threads, - device=device) + ppg.set_global_info( + rank=rank, + world_size=world_size, + dp_degree=dp_degree, + tp_degree=tp_degree, + num_worker_threads=num_worker_threads, + device=device, + ) if rpc_is_initialized(): rpc.shutdown() diff --git a/tests/test_legacy/test_tensor/common_utils/_utils.py b/tests/test_legacy/test_tensor/common_utils/_utils.py index b6fea28e4c8a..78bea6658364 100644 --- a/tests/test_legacy/test_tensor/common_utils/_utils.py +++ b/tests/test_legacy/test_tensor/common_utils/_utils.py @@ -13,7 +13,7 @@ def set_seed(seed): random.seed(seed) - os.environ['PYTHONHASHSEED'] = str(seed) + os.environ["PYTHONHASHSEED"] = str(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) @@ -27,12 +27,12 @@ def check_equal(A, B): def replace_parameter_add_grad(layer, weight=None, bias=None): if weight is not None: - delattr(layer, 'weight') - setattr(layer, 'weight', weight) + delattr(layer, "weight") + setattr(layer, "weight", weight) layer.weight.requires_grad = True if bias is not None: - delattr(layer, 'bias') - setattr(layer, 'bias', bias) + delattr(layer, "bias") + setattr(layer, "bias", bias) layer.bias.requires_grad = True @@ -47,12 +47,9 @@ def tensor_equal(t_a: torch.Tensor, t_b: torch.Tensor, rtol: float = 1e-3, atol: return True -def tensor_shard_equal(tensor: torch.Tensor, - shard: torch.Tensor, - rank: int, - world_size: int, - rtol: float = 1e-3, - atol: float = 1e-1): +def tensor_shard_equal( + tensor: torch.Tensor, shard: torch.Tensor, rank: int, world_size: int, rtol: float = 1e-3, atol: float = 1e-1 +): assert tensor.ndim == shard.ndim if tensor.shape == shard.shape: return tensor_equal(tensor, shard, rtol, atol) diff --git a/tests/test_legacy/test_tensor/core/test_dist_spec_mgr.py b/tests/test_legacy/test_tensor/core/test_dist_spec_mgr.py index b6d6bcee66ce..506244447054 100644 --- a/tests/test_legacy/test_tensor/core/test_dist_spec_mgr.py +++ b/tests/test_legacy/test_tensor/core/test_dist_spec_mgr.py @@ -48,17 +48,17 @@ def check_mem(): def run_dist(rank, world_size, port): - colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") check_mem() run() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.parametrize("world_size", [1, 4]) @rerun_if_address_is_in_use() def test_dist_spec_mgr(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_dist_spec_mgr(4) diff --git a/tests/test_legacy/test_tensor/test_parameter.py b/tests/test_legacy/test_tensor/test_parameter.py index 7a8694ff6789..5217e22cc422 100644 --- a/tests/test_legacy/test_tensor/test_parameter.py +++ b/tests/test_legacy/test_tensor/test_parameter.py @@ -9,26 +9,27 @@ @pytest.mark.skip def test_multiinheritance(): - colossalai.legacy.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl') + colossalai.legacy.launch(config={}, rank=0, world_size=1, host="localhost", port=free_port(), backend="nccl") colo_param = ColoParameter(None, requires_grad=True) - assert colo_param.dist_spec.placement.value == 'r' + assert 
colo_param.dist_spec.placement.value == "r" assert isinstance(colo_param, ColoTensor) assert isinstance(colo_param, torch.nn.Parameter) # __deepcopy__ overload import copy + colo_param2 = copy.deepcopy(colo_param) assert isinstance(colo_param2, ColoParameter) assert tensor_equal(colo_param.data, colo_param2.data) assert colo_param.requires_grad == colo_param2.requires_grad # __repr__ overload - assert 'ColoParameter' in str(colo_param) + assert "ColoParameter" in str(colo_param) # __torch_function__ clone_param = torch.clone(colo_param) assert isinstance(clone_param, ColoTensor) -if __name__ == '__main__': +if __name__ == "__main__": test_multiinheritance() diff --git a/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py b/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py index 84652093a9fd..a5a2d38577dc 100644 --- a/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py +++ b/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py @@ -8,12 +8,10 @@ from colossalai.legacy.communication import ( recv_backward, recv_forward, - recv_obj_meta, send_backward, send_backward_recv_forward, send_forward, send_forward_recv_backward, - send_obj_meta, ) from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc @@ -39,10 +37,10 @@ def check_forward(output_tensor, rank, logger): tensor = output_tensor.clone() else: tensor = recv_forward(output_tensor.shape) - logger.info('Rank {} received forward. Correct tensor: {}'.format(rank, check_equal(tensor, output_tensor))) + logger.info("Rank {} received forward. Correct tensor: {}".format(rank, check_equal(tensor, output_tensor))) if not gpc.is_last_rank(ParallelMode.PIPELINE): send_forward(tensor) - logger.info('Rank {} sent forward.'.format(rank)) + logger.info("Rank {} sent forward.".format(rank)) def check_backward(output_grad, rank, logger): @@ -51,22 +49,26 @@ def check_backward(output_grad, rank, logger): grad = output_grad.clone() else: grad = recv_backward(output_grad.shape) - logger.info('Rank {} received backward. Correct grad: {}'.format(rank, check_equal(grad, output_grad))) + logger.info("Rank {} received backward. Correct grad: {}".format(rank, check_equal(grad, output_grad))) if not gpc.is_first_rank(ParallelMode.PIPELINE): send_backward(grad) - logger.info('Rank {} sent backward.'.format(rank)) + logger.info("Rank {} sent backward.".format(rank)) def check_forward_backward(output_tensor, output_grad, rank, logger): dist.barrier() if not gpc.is_first_rank(ParallelMode.PIPELINE): tensor = send_backward_recv_forward(output_grad, output_tensor.shape) - logger.info('Rank {} sent backward received forward. Correct tensor: {}'.format( - rank, check_equal(tensor, output_tensor))) + logger.info( + "Rank {} sent backward received forward. Correct tensor: {}".format( + rank, check_equal(tensor, output_tensor) + ) + ) if not gpc.is_last_rank(ParallelMode.PIPELINE): grad = send_forward_recv_backward(output_tensor, output_grad.shape) - logger.info('Rank {} sent forward received backward. Correct grad: {}'.format( - rank, check_equal(grad, output_grad))) + logger.info( + "Rank {} sent forward received backward. 
Correct grad: {}".format(rank, check_equal(grad, output_grad)) + ) def check_comm(size, rank, prev_rank, next_rank, logger): @@ -84,13 +86,13 @@ def check_comm(size, rank, prev_rank, next_rank, logger): def run_check(rank, world_size, port): - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") logger = get_dist_logger() rank = gpc.get_global_rank() prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) - logger.info('Rank {0}: prev rank {1}, next rank {2}'.format(rank, prev_rank, next_rank)) - logger.info('Distributed environment is initialized.') + logger.info("Rank {0}: prev rank {1}, next rank {2}".format(rank, prev_rank, next_rank)) + logger.info("Distributed environment is initialized.") check_comm(world_size, rank, prev_rank, next_rank, logger) gpc.destroy() @@ -104,5 +106,5 @@ def test_p2p(): spawn(run_check, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_p2p() diff --git a/tests/test_legacy/test_trainer/test_pipeline/test_pipeline_schedule.py b/tests/test_legacy/test_trainer/test_pipeline/test_pipeline_schedule.py index fd94c279b6fb..cd7fcfe5635d 100644 --- a/tests/test_legacy/test_trainer/test_pipeline/test_pipeline_schedule.py +++ b/tests/test_legacy/test_trainer/test_pipeline/test_pipeline_schedule.py @@ -23,7 +23,7 @@ def run_schedule(rank, world_size, port): - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # build model model = resnet18(num_classes=10) @@ -33,20 +33,23 @@ def run_schedule(rank, world_size, port): elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1: class Flatten(nn.Module): - def forward(self, x): return torch.flatten(x, 1) model = nn.Sequential(model.layer3, model.layer4, model.avgpool, Flatten(), model.fc) - print_rank_0('model is created') - - train_dataset = CIFAR10(root=Path(os.environ['DATA']), - download=True, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]), - ])) + print_rank_0("model is created") + + train_dataset = CIFAR10( + root=Path(os.environ["DATA"]), + download=True, + transform=transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]), + ] + ), + ) train_dataloader = get_dataloader( dataset=train_dataset, @@ -83,5 +86,5 @@ def test_pipeline_schedule(): spawn(run_schedule, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_pipeline_schedule() diff --git a/tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py b/tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py index 4a240533474c..d19b12a5b044 100644 --- a/tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py +++ b/tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py @@ -16,16 +16,15 @@ CONFIG = dict(fp16=dict(mode=AMP_TYPE.TORCH)) -@parameterize('model_name', ['repeated_computed_layers', 'resnet18', 'nested_model']) +@parameterize("model_name", ["repeated_computed_layers", "resnet18", "nested_model"]) def run_trainer(model_name): get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, 
test_dataloader, optimizer_class, criterion = get_components_func() model = model_builder() optimizer = optimizer_class(model.parameters(), lr=1e-3) - engine, train_dataloader, *_ = colossalai.legacy.initialize(model=model, - optimizer=optimizer, - criterion=criterion, - train_dataloader=train_dataloader) + engine, train_dataloader, *_ = colossalai.legacy.initialize( + model=model, optimizer=optimizer, criterion=criterion, train_dataloader=train_dataloader + ) logger = get_dist_logger() logger.info("engine is built", ranks=[0]) @@ -35,22 +34,21 @@ def run_trainer(model_name): logger.info("trainer is built", ranks=[0]) logger.info("start training", ranks=[0]) - trainer.fit(train_dataloader=train_dataloader, - test_dataloader=test_dataloader, - epochs=NUM_EPOCHS, - max_steps=3, - display_progress=True, - test_interval=5) + trainer.fit( + train_dataloader=train_dataloader, + test_dataloader=test_dataloader, + epochs=NUM_EPOCHS, + max_steps=3, + display_progress=True, + test_interval=5, + ) torch.cuda.empty_cache() def run_dist(rank, world_size, port): - colossalai.legacy.launch(config=CONFIG, - rank=rank, - world_size=world_size, - host='localhost', - port=port, - backend='nccl') + colossalai.legacy.launch( + config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl" + ) @pytest.mark.dist @@ -60,5 +58,5 @@ def test_trainer_no_pipeline(): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_trainer_no_pipeline() diff --git a/tests/test_legacy/test_trainer/test_trainer_with_pipe_schedule.py b/tests/test_legacy/test_trainer/test_trainer_with_pipe_schedule.py index 521b2f32f22d..0b34a79f96dd 100644 --- a/tests/test_legacy/test_trainer/test_trainer_with_pipe_schedule.py +++ b/tests/test_legacy/test_trainer/test_trainer_with_pipe_schedule.py @@ -29,12 +29,9 @@ def run_trainer_with_pipeline(rank, world_size, port): - colossalai.legacy.launch(config=CONFIG, - rank=rank, - world_size=world_size, - host='localhost', - port=port, - backend='nccl') + colossalai.legacy.launch( + config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl" + ) # build model model = resnet18(num_classes=10) @@ -44,35 +41,35 @@ def run_trainer_with_pipeline(rank, world_size, port): elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1: class Flatten(nn.Module): - def forward(self, x): return torch.flatten(x, 1) model = nn.Sequential(model.layer3, model.layer4, model.avgpool, Flatten(), model.fc) # build dataloaders - train_dataset = CIFAR10(root=Path(os.environ['DATA']), - download=True, - transform=transforms.Compose([ - transforms.Resize(size=(IMG_SIZE, IMG_SIZE)), - transforms.ToTensor(), - transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) - ])) - - train_dataloader = get_dataloader(dataset=train_dataset, - shuffle=True, - batch_size=BATCH_SIZE, - pin_memory=True, - drop_last=True) + train_dataset = CIFAR10( + root=Path(os.environ["DATA"]), + download=True, + transform=transforms.Compose( + [ + transforms.Resize(size=(IMG_SIZE, IMG_SIZE)), + transforms.ToTensor(), + transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + ] + ), + ) + + train_dataloader = get_dataloader( + dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True, drop_last=True + ) # build optimizer optimizer = Adam(model.parameters(), lr=0.001) criterion = nn.CrossEntropyLoss() - engine, train_dataloader, *args = colossalai.legacy.initialize(model=model, - optimizer=optimizer, - criterion=criterion, - 
train_dataloader=train_dataloader) + engine, train_dataloader, *args = colossalai.legacy.initialize( + model=model, optimizer=optimizer, criterion=criterion, train_dataloader=train_dataloader + ) logger = get_dist_logger() logger.info("engine is built", ranks=[0]) @@ -82,11 +79,9 @@ def forward(self, x): logger.info("start training", ranks=[0]) - trainer.fit(train_dataloader=train_dataloader, - epochs=NUM_EPOCHS, - max_steps=3, - display_progress=True, - test_interval=5) + trainer.fit( + train_dataloader=train_dataloader, epochs=NUM_EPOCHS, max_steps=3, display_progress=True, test_interval=5 + ) gpc.destroy() torch.cuda.empty_cache() @@ -98,5 +93,5 @@ def test_trainer_with_pipeline(): spawn(run_trainer_with_pipeline, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_trainer_with_pipeline() diff --git a/tests/test_legacy/test_utils/test_activation_checkpointing.py b/tests/test_legacy/test_utils/test_activation_checkpointing.py index 19984ae120b5..3303f610ee82 100644 --- a/tests/test_legacy/test_utils/test_activation_checkpointing.py +++ b/tests/test_legacy/test_utils/test_activation_checkpointing.py @@ -1,7 +1,6 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -import pytest import torch import torch.nn.functional as F @@ -44,20 +43,19 @@ def forward_inplace(x, weight): @parameterize("use_reentrant", [True, False]) @parameterize("cpu_offload", [True, False]) def test_activation_checkpointing(cpu_offload, use_reentrant): - # as seed manager is singleton # if we don't reset seeds here, # other tests might affect this test reset_seeds() # We put initialization here to avoid change cuda rng state below - inputs = torch.rand(2, 2, requires_grad=True, device='cuda') - weight = torch.rand(2, 4, requires_grad=True, device='cuda') + inputs = torch.rand(2, 2, requires_grad=True, device="cuda") + weight = torch.rand(2, 4, requires_grad=True, device="cuda") # Get a copy of input tensors - inputs_ = torch.empty(2, 2, requires_grad=True, device='cuda') + inputs_ = torch.empty(2, 2, requires_grad=True, device="cuda") inputs_.data.copy_(inputs.data) - weight_ = torch.empty(2, 4, requires_grad=True, device='cuda') + weight_ = torch.empty(2, 4, requires_grad=True, device="cuda") weight_.data.copy_(weight.data) add_seed(ParallelMode.GLOBAL, 1024) @@ -83,7 +81,7 @@ def test_activation_checkpointing(cpu_offload, use_reentrant): loss = out.sum() loss.backward() - assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match' + assert torch.all(inputs.grad == inputs_.grad), "Gradient of the input does not match" torch.cuda.empty_cache() # Extra test for use_reentrant=False @@ -110,7 +108,7 @@ def test_activation_checkpointing(cpu_offload, use_reentrant): loss = out.sum() loss.backward() - assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match' + assert torch.all(inputs.grad == inputs_.grad), "Gradient of the input does not match" torch.cuda.empty_cache() # as seed manager is singleton diff --git a/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_1d.py b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_1d.py index 88cd89a217fe..c07ff132b79e 100644 --- a/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_1d.py +++ b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_1d.py @@ -38,7 +38,9 @@ def check_equal(A, B): def check_checkpoint_1d(rank, world_size, port): - config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="1d")),) + config = dict( + 
parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="1d")), + ) disable_existing_loggers() launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") diff --git a/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2d.py b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2d.py index 591cd714fc65..2ec1facf21b1 100644 --- a/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2d.py +++ b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2d.py @@ -38,7 +38,9 @@ def check_equal(A, B): def check_checkpoint_2d(rank, world_size, port): - config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="2d")),) + config = dict( + parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="2d")), + ) disable_existing_loggers() launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") diff --git a/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2p5d.py b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2p5d.py index b165b4276f10..a6bf702a8482 100644 --- a/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2p5d.py +++ b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2p5d.py @@ -38,7 +38,9 @@ def check_equal(A, B): def check_checkpoint_2p5d(rank, world_size, port): - config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, depth=1, mode="2.5d")),) + config = dict( + parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, depth=1, mode="2.5d")), + ) disable_existing_loggers() launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") diff --git a/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_3d.py b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_3d.py index 2ce054d33b2d..12d928312969 100644 --- a/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_3d.py +++ b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_3d.py @@ -38,7 +38,9 @@ def check_equal(A, B): def check_checkpoint_3d(rank, world_size, port): - config = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=8, mode="3d")),) + config = dict( + parallel=dict(pipeline=dict(size=1), tensor=dict(size=8, mode="3d")), + ) disable_existing_loggers() launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") diff --git a/tests/test_legacy/test_utils/test_memory.py b/tests/test_legacy/test_utils/test_memory.py index 2e25dc773b68..9416ac86e325 100644 --- a/tests/test_legacy/test_utils/test_memory.py +++ b/tests/test_legacy/test_utils/test_memory.py @@ -14,7 +14,7 @@ def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity(): def run_dist(rank, world_size, port): - colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity() @@ -24,5 +24,5 @@ def test_memory_utils(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_memory_utils(world_size=2) diff --git a/tests/test_legacy/test_utils/test_norm_gradient_clipping.py b/tests/test_legacy/test_utils/test_norm_gradient_clipping.py index 918f174aba76..b5f2be705890 100644 --- a/tests/test_legacy/test_utils/test_norm_gradient_clipping.py +++ 
b/tests/test_legacy/test_utils/test_norm_gradient_clipping.py @@ -28,20 +28,20 @@ def check_grad_equal(p: Parameter, colo_p: ColoParameter) -> None: grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()] else: grad = p.grad - assert torch.allclose(grad, colo_p.grad), f'diff: {torch.abs(grad - colo_p.grad)}' + assert torch.allclose(grad, colo_p.grad), f"diff: {torch.abs(grad - colo_p.grad)}" -@parameterize('dtype', [torch.float]) -@parameterize('device', ['mixed', 'cuda', 'cpu']) -@parameterize('norm_type', [2.0, 3.0, float('inf')]) +@parameterize("dtype", [torch.float]) +@parameterize("device", ["mixed", "cuda", "cpu"]) +@parameterize("norm_type", [2.0, 3.0, float("inf")]) def run_grad_clip_norm(world_size: int, dtype: torch.dtype, device: str, norm_type: float): - print(f'{world_size}, {dtype}, {device}, {norm_type}') + print(f"{world_size}, {dtype}, {device}, {norm_type}") cuda_device = get_current_device() devices = [cuda_device] * 4 - if device == 'cpu': - devices = [torch.device('cpu')] * 4 - elif device == 'mixed': - devices = [cuda_device] * 2 + [torch.device('cpu')] * 2 + if device == "cpu": + devices = [torch.device("cpu")] * 4 + elif device == "mixed": + devices = [cuda_device] * 2 + [torch.device("cpu")] * 2 pg = ProcessGroup(tp_degree=world_size) params = [Parameter(torch.empty(4, 4, dtype=dtype, device=devices[i])) for i in range(4)] colo_params = [ @@ -55,24 +55,24 @@ def run_grad_clip_norm(world_size: int, dtype: torch.dtype, device: str, norm_ty shard_param(colo_params[2]) torch_norm = clip_grad_norm_(params, 1.0, norm_type=norm_type) colo_norm = clip_grad_norm(colo_params, 1.0, norm_type=norm_type) - assert close(torch_norm, colo_norm), f'diff: {abs(torch_norm-colo_norm)}' + assert close(torch_norm, colo_norm), f"diff: {abs(torch_norm-colo_norm)}" for p, colo_p in zip(params, colo_params): check_grad_equal(p, colo_p) def run_dist(rank, world_size, port): disable_existing_loggers() - colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_grad_clip_norm(world_size=world_size) @pytest.mark.skip("this need to be updated") @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 2]) +@pytest.mark.parametrize("world_size", [1, 2]) @rerun_if_address_is_in_use() def test_zero_clip_grad(world_size: int): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_zero_clip_grad(2) diff --git a/tests/test_legacy/test_zero/test_commons.py b/tests/test_legacy/test_zero/test_commons.py index 42a9f1eecb95..741f519e1376 100644 --- a/tests/test_legacy/test_zero/test_commons.py +++ b/tests/test_legacy/test_zero/test_commons.py @@ -7,29 +7,29 @@ def run_tensor_move(rank, world_size, port): - colossalai.legacy.launch(config={}, rank=0, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.legacy.launch(config={}, rank=0, world_size=world_size, host="localhost", port=port, backend="nccl") src_t = torch.ones(2, 3).cuda() tgt_t = torch.zeros(2, 3) colo_model_data_tensor_move(src_t, tgt_t) - assert (torch.sum(tgt_t) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0" + assert torch.sum(tgt_t) == 6.0, f"{torch.sum(tgt_t.payload)} vs. 
6.0" src_t = torch.ones(2, 3) tgt_t = torch.zeros(2, 3).cuda().half() colo_model_data_tensor_move(src_t, tgt_t) # the src_t has been removed - assert (src_t.numel() == 0) - assert (torch.sum(tgt_t) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0" + assert src_t.numel() == 0 + assert torch.sum(tgt_t) == 6.0, f"{torch.sum(tgt_t.payload)} vs. 6.0" src_t = ShardedTensor(torch.ones(2, 3)) tgt_t = ShardedTensor(torch.zeros(2, 3).cuda().half()) colo_model_data_tensor_move(src_t, tgt_t) - assert (torch.sum(tgt_t.payload) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0" + assert torch.sum(tgt_t.payload) == 6.0, f"{torch.sum(tgt_t.payload)} vs. 6.0" - assert (tgt_t.device.type == 'cuda') - colo_model_data_tensor_move_inline(tgt_t, torch.device('cpu')) - assert (tgt_t.device.type == 'cpu') + assert tgt_t.device.type == "cuda" + colo_model_data_tensor_move_inline(tgt_t, torch.device("cpu")) + assert tgt_t.device.type == "cpu" @rerun_if_address_is_in_use() @@ -37,5 +37,5 @@ def test_tensor_move(): spawn(run_tensor_move, 1) -if __name__ == '__main__': +if __name__ == "__main__": test_tensor_move() diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py index 9c84a99cd549..8742e5f41136 100644 --- a/tests/test_moe/test_grad_handler.py +++ b/tests/test_moe/test_grad_handler.py @@ -17,11 +17,11 @@ def run_test(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") expert_module = nn.Linear expert_factor = dict(in_features=DIM, out_features=DIM, device=get_current_device()) - MOE_CONTEXT.setup(42) # MOE initialization + MOE_CONTEXT.setup(42) # MOE initialization noisy_func = UniformNoiseGenerator() router = Top1Router(noisy_func=noisy_func) num_experts_list = [1, 2, 4] @@ -67,5 +67,5 @@ def test_grad_handler(): spawn(run_test, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_grad_handler() diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py index c096b6075005..7a9c551d679d 100644 --- a/tests/test_moe/test_kernel.py +++ b/tests/test_moe/test_kernel.py @@ -23,12 +23,12 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f # Here we do not need TF32, since it brings absolute error on results torch.backends.cuda.matmul.allow_tf32 = False - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") local_rank = gpc.get_local_rank(ParallelMode.GLOBAL) - MOE_CONTEXT.setup(42) # MOE environment initialization + MOE_CONTEXT.setup(42) # MOE environment initialization MOE_CONTEXT.reset_loss() - torch.manual_seed(rs + local_rank) # set each process has different random seed + torch.manual_seed(rs + local_rank) # set each process has different random seed # get randomized data tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True) @@ -46,7 +46,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f old_out, _ = layer(tokens) ech = old_out.shape grad = torch.randn(ech, device=get_current_device()) - old_out.backward(grad) # get gradient + old_out.backward(grad) # get gradient # save all results o_tk_grad = tokens.grad.data.clone() @@ -57,7 +57,7 @@ def run_routing(rank, world_size, 
port, rs=2, hidden_size=128, data_type=torch.f layer.gate_weight.grad.zero_() layer.use_kernel = True - new_out, _ = layer(tokens) # get outputs through colossal kernel + new_out, _ = layer(tokens) # get outputs through colossal kernel if data_type == torch.float32: check_equal(old_out, new_out) @@ -65,7 +65,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f check_equal(old_out, new_out, 1e-2) # forward function passed - new_out.backward(grad) # get new type gradient + new_out.backward(grad) # get new type gradient n_tk_grad = tokens.grad.data.clone() n_gt_grad = layer.gate_weight.grad.data.clone() @@ -92,5 +92,5 @@ def test_moe_kernel(rs, hidden_size, data_type, router): spawn(run_routing, 4, rs=rs, hidden_size=hidden_size, data_type=data_type, router=router) -if __name__ == '__main__': +if __name__ == "__main__": test_moe_kernel(2, 256, torch.float16, Top2Router) diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py index 8a0283ba71fc..b7024f32b1cf 100644 --- a/tests/test_moe/test_moe_checkpoint.py +++ b/tests/test_moe/test_moe_checkpoint.py @@ -17,11 +17,11 @@ def exam_moe_checkpoint(): with ColoInitContext(device=get_current_device()): model = MoeModel(checkpoint=True) - save_moe_model(model, 'temp_path.pth') + save_moe_model(model, "temp_path.pth") with ColoInitContext(device=get_current_device()): other_model = MoeModel(checkpoint=True) - load_moe_model(other_model, 'temp_path.pth') + load_moe_model(other_model, "temp_path.pth") state_0 = model.state_dict() state_1 = other_model.state_dict() @@ -30,11 +30,11 @@ def exam_moe_checkpoint(): assert torch.equal(u.data, v.data) if dist.get_rank() == 0: - os.remove('temp_path.pth') + os.remove("temp_path.pth") def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") MOE_CONTEXT.setup(seed=42) exam_moe_checkpoint() @@ -46,5 +46,5 @@ def test_moe_checkpoint(world_size): spawn(_run_dist) -if __name__ == '__main__': +if __name__ == "__main__": test_moe_checkpoint(world_size=4) diff --git a/tests/test_moe/test_moe_colo_init.py b/tests/test_moe/test_moe_colo_init.py index 555338fcf9fc..488573b733b1 100644 --- a/tests/test_moe/test_moe_colo_init.py +++ b/tests/test_moe/test_moe_colo_init.py @@ -9,17 +9,16 @@ from colossalai.utils import get_current_device from colossalai.zero import ColoInitContext from tests.test_moe.test_moe_zero_init import MoeModel -from tests.test_tensor.common_utils import debug_print from tests.test_zero.test_legacy.common import CONFIG -@parameterize("init_device_type", ['cpu', 'cuda']) +@parameterize("init_device_type", ["cpu", "cuda"]) def exam_moe_colo_init(init_device_type): world_size = dist.get_world_size() - if init_device_type == 'cuda': + if init_device_type == "cuda": init_device = get_current_device() - elif init_device_type == 'cpu': + elif init_device_type == "cpu": init_device = torch.device("cpu") else: raise NotImplementedError("Unknown device found.") @@ -40,7 +39,7 @@ def exam_moe_colo_init(init_device_type): def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") MOE_CONTEXT.setup(seed=42) exam_moe_colo_init() @@ -52,5 
+51,5 @@ def test_moe_colo_init(world_size): spawn(_run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_moe_colo_init(world_size=4) diff --git a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py index 6dc3f5f18b6d..300fb6c99b7b 100644 --- a/tests/test_moe/test_moe_group.py +++ b/tests/test_moe/test_moe_group.py @@ -16,11 +16,11 @@ def run_test(rank, world_size, port): world_size = 4 - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") expert_module = nn.Linear expert_factor = dict(in_features=D_MODEL, out_features=D_FF, device=get_current_device()) - MOE_CONTEXT.setup(42) # MOE environment initialization + MOE_CONTEXT.setup(42) # MOE environment initialization exp0 = Experts(expert_module, 1, **expert_factor) exp1 = Experts(expert_module, 2, **expert_factor) exp2 = Experts(expert_module, 4, **expert_factor) @@ -64,5 +64,5 @@ def test_moe_initialization(): spawn(run_test, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_moe_initialization() diff --git a/tests/test_moe/test_moe_zero_init.py b/tests/test_moe/test_moe_zero_init.py index 79722f9f4056..c48f9a3557ce 100644 --- a/tests/test_moe/test_moe_zero_init.py +++ b/tests/test_moe/test_moe_zero_init.py @@ -15,20 +15,15 @@ class MoeModel(nn.Module): - def __init__(self, checkpoint: bool = False): - class TestSubModule(CheckpointModule): - def __init__(self): super().__init__(checkpoint) expert_cls = nn.Linear expert_args_dict = dict(in_features=16, out_features=16) - self.moe = MoeModule(dim_model=16, - num_experts=8, - use_residual=True, - expert_cls=expert_cls, - **expert_args_dict) + self.moe = MoeModule( + dim_model=16, num_experts=8, use_residual=True, expert_cls=expert_cls, **expert_args_dict + ) self.proj = nn.Linear(16, 4) def _forward(self, x): @@ -50,49 +45,52 @@ def forward(self, x): return x -@parameterize("init_device_type", ['cpu', 'cuda']) +@parameterize("init_device_type", ["cpu", "cuda"]) @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) def run_moe_zero_init(init_device_type, shard_strategy_class): - logger = get_dist_logger("test_moe_zero_init") + get_dist_logger("test_moe_zero_init") - if init_device_type == 'cuda': + if init_device_type == "cuda": init_device = get_current_device() - elif init_device_type == 'cpu': + elif init_device_type == "cpu": init_device = torch.device("cpu") else: raise NotImplementedError("Unknown device found.") model_numel_tensor = torch.zeros(1, dtype=torch.int) - with ZeroInitContext(target_device=init_device, - shard_strategy=shard_strategy_class(), - shard_param=True, - model_numel_tensor=model_numel_tensor): + with ZeroInitContext( + target_device=init_device, + shard_strategy=shard_strategy_class(), + shard_param=True, + model_numel_tensor=model_numel_tensor, + ): model = MoeModel(checkpoint=True) for name, param in model.named_parameters(): - assert hasattr(param, 'colo_attr') + assert hasattr(param, "colo_attr") # the parameters in moe experts and its gate should not be sharded - if ('experts' in name) or ('gate' in name) or ('residual_combine' in name): + if ("experts" in name) or ("gate" in name) or ("residual_combine" in name): assert not param.colo_attr.sharded_data_tensor.is_sharded, "`{}` parameter has problem".format(name) else: assert param.colo_attr.sharded_data_tensor.is_sharded # the parameters in moe 
experts is not replicated - if 'experts' in name: + if "experts" in name: assert not param.colo_attr.is_replicated else: assert param.colo_attr.is_replicated if param.colo_attr.param_is_sharded: - assert param.colo_attr.data_payload.device.type == init_device.type, \ - f'{param.colo_attr.data_payload.device.type} vs. {init_device.type}' + assert ( + param.colo_attr.data_payload.device.type == init_device.type + ), f"{param.colo_attr.data_payload.device.type} vs. {init_device.type}" else: - assert param.colo_attr.data_payload.device.type == 'cuda' + assert param.colo_attr.data_payload.device.type == "cuda" def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") MOE_CONTEXT.setup(seed=42) run_moe_zero_init() @@ -104,5 +102,5 @@ def test_moe_zero_init(world_size): spawn(_run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_moe_zero_init(world_size=2) diff --git a/tests/test_moe/test_moe_zero_model.py b/tests/test_moe/test_moe_zero_model.py index 595d4374df6f..724d70d77bc6 100644 --- a/tests/test_moe/test_moe_zero_model.py +++ b/tests/test_moe/test_moe_zero_model.py @@ -21,13 +21,13 @@ def run_model_test(enable_autocast, shard_strategy_class): shard_strategy = shard_strategy_class() - get_components_func = non_distributed_component_funcs.get_callable('hanging_param_model') + get_components_func = non_distributed_component_funcs.get_callable("hanging_param_model") _, train_dataloader, _, optimizer_class, _ = get_components_func() criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss) - with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()), - shard_strategy=shard_strategy, - shard_param=True): + with ZeroInitContext( + target_device=torch.device("cuda", torch.cuda.current_device()), shard_strategy=shard_strategy, shard_param=True + ): zero_model = MoeModel(checkpoint=True) zero_model = ShardedModelV2(zero_model, shard_strategy) @@ -54,7 +54,7 @@ def run_model_test(enable_autocast, shard_strategy_class): def run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") MOE_CONTEXT.setup(seed=42) run_model_test() @@ -66,5 +66,5 @@ def test_moe_zero_model(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_moe_zero_model(world_size=2) diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py index 35fde6f10f3f..bb9822daee05 100644 --- a/tests/test_moe/test_moe_zero_optim.py +++ b/tests/test_moe/test_moe_zero_optim.py @@ -43,31 +43,33 @@ def _run_step(model, optimizer, data, label, criterion, grad_handler): @parameterize("cpu_offload", [True]) -@parameterize("use_cpuadam", [True]) # We do not use Hybrid Adam right now, since it has a little bug +@parameterize("use_cpuadam", [True]) # We do not use Hybrid Adam right now, since it has a little bug @parameterize("reuse_fp16_shard", [True, False]) @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) -def _run_test_sharded_optim_v2(cpu_offload, - shard_strategy_class, - use_cpuadam, - reuse_fp16_shard, - gpu_margin_mem_ratio=0.0): +def 
_run_test_sharded_optim_v2( + cpu_offload, shard_strategy_class, use_cpuadam, reuse_fp16_shard, gpu_margin_mem_ratio=0.0 +): shard_strategy = shard_strategy_class() if use_cpuadam and cpu_offload is False: return MOE_CONTEXT.reset_loss() - get_components_func = non_distributed_component_funcs.get_callable('hanging_param_model') + get_components_func = non_distributed_component_funcs.get_callable("hanging_param_model") _, train_dataloader, _, optimizer_class, _ = get_components_func() criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss) - with ZeroInitContext(target_device=torch.device('cpu') if cpu_offload else get_current_device(), - shard_strategy=shard_strategy, - shard_param=True): + with ZeroInitContext( + target_device=torch.device("cpu") if cpu_offload else get_current_device(), + shard_strategy=shard_strategy, + shard_param=True, + ): zero_model = MoeModel(checkpoint=True) - zero_model = ShardedModelV2(zero_model, - shard_strategy, - tensor_placement_policy='cpu' if cpu_offload else 'cuda', - reuse_fp16_shard=reuse_fp16_shard) + zero_model = ShardedModelV2( + zero_model, + shard_strategy, + tensor_placement_policy="cpu" if cpu_offload else "cuda", + reuse_fp16_shard=reuse_fp16_shard, + ) # check whether parameters are identical in ddp for name, p in zero_model.named_parameters(): @@ -82,12 +84,11 @@ def _run_test_sharded_optim_v2(cpu_offload, optimizer_class = CPUAdam optim = optimizer_class(model.parameters(), lr=1e-3) sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3) - sharded_optim = ShardedOptimizerV2(zero_model, - sharded_optim, - initial_scale=2**5, - gpu_margin_mem_ratio=gpu_margin_mem_ratio) + sharded_optim = ShardedOptimizerV2( + zero_model, sharded_optim, initial_scale=2**5, gpu_margin_mem_ratio=gpu_margin_mem_ratio + ) - amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False) + amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False) apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config) apex_grad_handler = MoeGradientHandler(model) @@ -103,7 +104,7 @@ def _run_test_sharded_optim_v2(cpu_offload, def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") MOE_CONTEXT.setup(seed=42) _run_test_sharded_optim_v2() @@ -116,5 +117,5 @@ def test_moe_zero_optim(world_size): spawn(_run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_moe_zero_optim(world_size=4) diff --git a/tests/test_optimizer/test_adam_kernel.py b/tests/test_optimizer/test_adam_kernel.py index 2186a421fe00..8131ea3234d8 100644 --- a/tests/test_optimizer/test_adam_kernel.py +++ b/tests/test_optimizer/test_adam_kernel.py @@ -10,16 +10,25 @@ from colossalai.utils import get_current_device, multi_tensor_applier -_FUSED_ALLOWED_P_G_TYPES = [(torch.float, torch.half), (torch.float, torch.float), (torch.half, torch.float), - (torch.half, torch.half), (torch.bfloat16, torch.float), (torch.float, torch.bfloat16), - (torch.bfloat16, torch.bfloat16)] - -_CPU_ALLOWED_P_G_TYPES = [(torch.float, torch.half), (torch.float, torch.float), (torch.half, torch.float), - (torch.half, torch.half)] +_FUSED_ALLOWED_P_G_TYPES = [ + (torch.float, torch.half), + (torch.float, torch.float), + (torch.half, torch.float), + (torch.half, torch.half), + (torch.bfloat16, torch.float), + (torch.float, torch.bfloat16), + (torch.bfloat16, 
torch.bfloat16), +] + +_CPU_ALLOWED_P_G_TYPES = [ + (torch.float, torch.half), + (torch.float, torch.float), + (torch.half, torch.float), + (torch.half, torch.half), +] class AdamKernel: - def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None: self.lr = lr self.beta1 = beta1 @@ -34,7 +43,6 @@ def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_av class TorchAdamKernel(AdamKernel): - def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor): bias_correction1 = 1 - self.beta1**step bias_correction2 = 1 - self.beta2**step @@ -57,36 +65,68 @@ def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_av class FusedAdamKernel(AdamKernel): - def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None: super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw) from colossalai.kernel.op_builder import FusedOptimBuilder + fused_optim = FusedOptimBuilder().load() self.fused_adam = fused_optim.multi_tensor_adam self.dummy_overflow_buf = torch.cuda.IntTensor([0]) def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor): - multi_tensor_applier(self.fused_adam, self.dummy_overflow_buf, [[grad], [param], [exp_avg], [exp_avg_sq]], - self.lr, self.beta1, self.beta2, self.eps, step, self.use_adamw, True, self.weight_decay, - -1) + multi_tensor_applier( + self.fused_adam, + self.dummy_overflow_buf, + [[grad], [param], [exp_avg], [exp_avg_sq]], + self.lr, + self.beta1, + self.beta2, + self.eps, + step, + self.use_adamw, + True, + self.weight_decay, + -1, + ) class CPUAdamKernel(AdamKernel): - def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None: super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw) from colossalai.kernel.op_builder import CPUAdamBuilder + cpu_optim = CPUAdamBuilder().load() self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, use_adamw) def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor): - self.cpu_adam_op.step(step, self.lr, self.beta1, self.beta2, self.eps, self.weight_decay, True, param.view(-1), - grad.view(-1), exp_avg.view(-1), exp_avg_sq.view(-1), -1) - - -def check_adam_kernel(kernel: Type[AdamKernel], adamw: bool, weight_decay: float, p_dtype: torch.dtype, - g_dtype: torch.dtype, device: torch.device, n_steps: int, rtol: float, atol: float): + self.cpu_adam_op.step( + step, + self.lr, + self.beta1, + self.beta2, + self.eps, + self.weight_decay, + True, + param.view(-1), + grad.view(-1), + exp_avg.view(-1), + exp_avg_sq.view(-1), + -1, + ) + + +def check_adam_kernel( + kernel: Type[AdamKernel], + adamw: bool, + weight_decay: float, + p_dtype: torch.dtype, + g_dtype: torch.dtype, + device: torch.device, + n_steps: int, + rtol: float, + atol: float, +): lr = 1e-3 beta1, beta2 = 0.9, 0.999 eps = 1e-8 @@ -109,9 +149,9 @@ def check_adam_kernel(kernel: Type[AdamKernel], adamw: bool, weight_decay: float assert torch.allclose(master_p, p.float(), rtol=rtol, atol=atol) -@pytest.mark.parametrize('adamw', [False, True]) -@pytest.mark.parametrize('weight_decay', [0.0, 0.1]) -@pytest.mark.parametrize('p_dtype, g_dtype', _FUSED_ALLOWED_P_G_TYPES) +@pytest.mark.parametrize("adamw", [False, True]) +@pytest.mark.parametrize("weight_decay", [0.0, 0.1]) +@pytest.mark.parametrize("p_dtype, g_dtype", 
_FUSED_ALLOWED_P_G_TYPES) def test_fused_adam_kernel(adamw, weight_decay, p_dtype, g_dtype): rtol, atol = 1e-5, 1e-8 if p_dtype is torch.float16 or g_dtype is torch.float16: @@ -121,11 +161,11 @@ def test_fused_adam_kernel(adamw, weight_decay, p_dtype, g_dtype): check_adam_kernel(FusedAdamKernel, adamw, weight_decay, p_dtype, g_dtype, get_current_device(), 3, rtol, atol) -@pytest.mark.parametrize('adamw', [False, True]) -@pytest.mark.parametrize('weight_decay', [0.0, 0.1]) -@pytest.mark.parametrize('p_dtype, g_dtype', _CPU_ALLOWED_P_G_TYPES) +@pytest.mark.parametrize("adamw", [False, True]) +@pytest.mark.parametrize("weight_decay", [0.0, 0.1]) +@pytest.mark.parametrize("p_dtype, g_dtype", _CPU_ALLOWED_P_G_TYPES) def test_cpu_adam_kernel(adamw, weight_decay, p_dtype, g_dtype): rtol, atol = 1e-5, 1e-8 if p_dtype is torch.float16 or g_dtype is torch.float16: rtol, atol = 1e-3, 1e-3 - check_adam_kernel(CPUAdamKernel, adamw, weight_decay, p_dtype, g_dtype, torch.device('cpu'), 3, rtol, atol) + check_adam_kernel(CPUAdamKernel, adamw, weight_decay, p_dtype, g_dtype, torch.device("cpu"), 3, rtol, atol) diff --git a/tests/test_optimizer/test_adam_optim.py b/tests/test_optimizer/test_adam_optim.py index 0f72bc134809..59b40a0afa3c 100644 --- a/tests/test_optimizer/test_adam_optim.py +++ b/tests/test_optimizer/test_adam_optim.py @@ -10,17 +10,17 @@ from tests.kit.model_zoo import model_zoo _ALLOWED_OPTIM_DEVICES = [ - (FusedAdam, torch.device('cuda:0')), - (CPUAdam, torch.device('cpu')), - (CPUAdam, torch.device('cuda:0')), - (HybridAdam, torch.device('cpu')), - (HybridAdam, torch.device('cuda:0')), + (FusedAdam, torch.device("cuda:0")), + (CPUAdam, torch.device("cpu")), + (CPUAdam, torch.device("cuda:0")), + (HybridAdam, torch.device("cpu")), + (HybridAdam, torch.device("cuda:0")), ] _ALLOWED_P_G_TYPES = [ - (torch.float, torch.float), # pure fp32 - (torch.float, torch.half), # fp16 amp - (torch.float, torch.bfloat16), # bfloat16 amp + (torch.float, torch.float), # pure fp32 + (torch.float, torch.half), # fp16 amp + (torch.float, torch.bfloat16), # bfloat16 amp # (torch.half, torch.half), # FIXME(ver217): cpu adam kernel does not support pure fp16 # (torch.bfloat16, torch.bfloat16), # FIXME(ver217): cpu adam kernel does not support pure bfloat16 ] @@ -53,12 +53,17 @@ def set_grad(model: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype) -> p.data = orig_p -@pytest.mark.parametrize('optim_cls, device', _ALLOWED_OPTIM_DEVICES) -@pytest.mark.parametrize('adamw', [False, True]) -@pytest.mark.parametrize('p_dtype, g_dtype', _ALLOWED_P_G_TYPES) -def test_adam_optim_on_bert(optim_cls: Union[Type[FusedAdam], Type[CPUAdam], Type[HybridAdam]], device: torch.device, - adamw: bool, p_dtype: torch.dtype, g_dtype: torch.dtype) -> None: - model_fn, *_ = next(iter(model_zoo.get_sub_registry('transformers_bert_for_sequence_classification').values())) +@pytest.mark.parametrize("optim_cls, device", _ALLOWED_OPTIM_DEVICES) +@pytest.mark.parametrize("adamw", [False, True]) +@pytest.mark.parametrize("p_dtype, g_dtype", _ALLOWED_P_G_TYPES) +def test_adam_optim_on_bert( + optim_cls: Union[Type[FusedAdam], Type[CPUAdam], Type[HybridAdam]], + device: torch.device, + adamw: bool, + p_dtype: torch.dtype, + g_dtype: torch.dtype, +) -> None: + model_fn, *_ = next(iter(model_zoo.get_sub_registry("transformers_bert_for_sequence_classification").values())) torch_model = model_fn().to(device) model = deepcopy(torch_model).to(p_dtype) lr = 1e-3 diff --git a/tests/test_optimizer/test_nvme.py b/tests/test_optimizer/test_nvme.py 
index 5d794ac2dd1a..a68a9c51855f 100644 --- a/tests/test_optimizer/test_nvme.py +++ b/tests/test_optimizer/test_nvme.py @@ -1,4 +1,3 @@ -import pytest import torch from colossalai.nn.optimizer import CPUAdam, HybridAdam @@ -15,23 +14,22 @@ def move_some_params_to_cuda(model, torch_model): def check_params_equal(model, torch_model): for p, torch_p in zip(model.parameters(), torch_model.parameters()): - assert torch.allclose(p, torch_p, atol=1e-3), f'diff: {torch.abs(p - torch_p)}' + assert torch.allclose(p, torch_p, atol=1e-3), f"diff: {torch.abs(p - torch_p)}" @clear_cache_before_run() -@parameterize('nvme_offload_fraction', [0.0, 0.5, 1.0]) -@parameterize('nvme_offload_dir', ['./offload', None]) -@parameterize('adam_cls', [CPUAdam, HybridAdam]) +@parameterize("nvme_offload_fraction", [0.0, 0.5, 1.0]) +@parameterize("nvme_offload_dir", ["./offload", None]) +@parameterize("adam_cls", [CPUAdam, HybridAdam]) def test_nvme_adam(nvme_offload_fraction, nvme_offload_dir, adam_cls): - get_components_func = non_distributed_component_funcs.get_callable('simple_net') + get_components_func = non_distributed_component_funcs.get_callable("simple_net") model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() model = model_builder() torch_model = model_builder() move_some_params_to_cuda(model, torch_model) - optimizer = adam_cls(model.parameters(), - lr=0.1, - nvme_offload_fraction=nvme_offload_fraction, - nvme_offload_dir=nvme_offload_dir) + optimizer = adam_cls( + model.parameters(), lr=0.1, nvme_offload_fraction=nvme_offload_fraction, nvme_offload_dir=nvme_offload_dir + ) torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.1) with torch.no_grad(): for p, torch_p in zip(model.parameters(), torch_model.parameters()): @@ -45,5 +43,5 @@ def test_nvme_adam(nvme_offload_fraction, nvme_offload_dir, adam_cls): check_params_equal(model, torch_model) -if __name__ == '__main__': - test_nvme_adam(0.5, './offload', CPUAdam) +if __name__ == "__main__": + test_nvme_adam(0.5, "./offload", CPUAdam) diff --git a/tests/test_pipeline/test_p2p_communication.py b/tests/test_pipeline/test_p2p_communication.py index 71946f6b988a..1665711ceeef 100644 --- a/tests/test_pipeline/test_p2p_communication.py +++ b/tests/test_pipeline/test_p2p_communication.py @@ -22,30 +22,30 @@ def check_p2p_communication(): if rank == 0: p2p.send_forward(tensor) p2p.send_forward([tensor]) - p2p.send_forward({'tensor': tensor}) + p2p.send_forward({"tensor": tensor}) else: obj = p2p.recv_forward() assert torch.equal(obj, tensor) obj = p2p.recv_forward() assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor) obj = p2p.recv_forward() - assert type(obj) == dict and 'tensor' in obj and torch.equal(obj['tensor'], tensor) + assert type(obj) == dict and "tensor" in obj and torch.equal(obj["tensor"], tensor) if rank == 1: p2p.send_backward(tensor) p2p.send_backward([tensor]) - p2p.send_backward({'tensor': tensor}) + p2p.send_backward({"tensor": tensor}) else: obj = p2p.recv_backward() assert torch.equal(obj, tensor) obj = p2p.recv_backward() assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor) obj = p2p.recv_backward() - assert type(obj) == dict and 'tensor' in obj and torch.equal(obj['tensor'], tensor) + assert type(obj) == dict and "tensor" in obj and torch.equal(obj["tensor"], tensor) def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config={}, rank=rank, 
world_size=world_size, port=port, host="localhost") check_p2p_communication() @@ -55,5 +55,5 @@ def test_pipeline_p2p(): spawn(run_dist, 2) -if __name__ == '__main__': +if __name__ == "__main__": test_pipeline_p2p() diff --git a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py index 0cbb852b97a0..3723c9c1014a 100644 --- a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py +++ b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py @@ -4,36 +4,42 @@ def test_t5_pipeline_distribution(): num_test_cases = 8 test_dict = { - 'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5], - 'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22], - 'num_stages': [2, 2, 2, 4, 4, 4, 8, 8], - 'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2] + "num_encoder_layers": [2, 1, 3, 2, 3, 2, 10, 5], + "num_decoder_layers": [2, 8, 0, 2, 1, 5, 6, 22], + "num_stages": [2, 2, 2, 4, 4, 4, 8, 8], + "decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2], } for i in range(num_test_cases): - _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(test_dict['num_encoder_layers'][i], - test_dict['num_decoder_layers'][i], - test_dict['num_stages'][i]) - assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage + _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( + test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], test_dict["num_stages"][i] + ) + assert test_dict["decoder_starting_stage"][i] == decoder_starting_stage def test_t5_pipeline_layers(): num_test_cases = 4 test_dict = { - 'num_encoder_layers': [2, 3, 2, 4], - 'num_decoder_layers': [2, 0, 2, 8], - 'num_stages': [2, 2, 4, 4], - 'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]], - [[0, 4], [0, 3], [3, 6], [6, 8]]] + "num_encoder_layers": [2, 3, 2, 4], + "num_decoder_layers": [2, 0, 2, 8], + "num_stages": [2, 2, 4, 4], + "layers_per_stage": [ + [[0, 2], [0, 2]], + [[0, 1], [1, 3]], + [[0, 1], [1, 2], [0, 1], [1, 2]], + [[0, 4], [0, 3], [3, 6], [6, 8]], + ], } for i in range(num_test_cases): layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( - test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i]) + test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], test_dict["num_stages"][i] + ) - for stage in range(test_dict['num_stages'][i]): - start_idx, end_idx = test_dict['layers_per_stage'][i][stage] - predicted_start, predicted_end = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage, - decoder_starting_stage) + for stage in range(test_dict["num_stages"][i]): + start_idx, end_idx = test_dict["layers_per_stage"][i][stage] + predicted_start, predicted_end = T5BasePolicy.get_t5_stage_index( + layers_per_stage, stage, decoder_starting_stage + ) assert start_idx == predicted_start assert end_idx == predicted_end diff --git a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py index 395519e97898..f6be8f6feac2 100644 --- a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py +++ b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py @@ -4,41 +4,47 @@ def test_whisper_pipeline_distribution(): num_test_cases = 8 test_dict = { - 'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5], - 'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22], - 'num_stages': [2, 2, 2, 4, 4, 4, 8, 8], - 
'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2] + "num_encoder_layers": [2, 1, 3, 2, 3, 2, 10, 5], + "num_decoder_layers": [2, 8, 0, 2, 1, 5, 6, 22], + "num_stages": [2, 2, 2, 4, 4, 4, 8, 8], + "decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2], } for i in range(num_test_cases): - _, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(test_dict['num_encoder_layers'][i], - test_dict['num_decoder_layers'][i], - test_dict['num_stages'][i]) - assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage + _, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( + test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], test_dict["num_stages"][i] + ) + assert test_dict["decoder_starting_stage"][i] == decoder_starting_stage def test_whisper_pipeline_layers(): num_test_cases = 4 test_dict = { - 'num_encoder_layers': [2, 3, 2, 4], - 'num_decoder_layers': [2, 0, 2, 8], - 'num_stages': [2, 2, 4, 4], - 'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]], - [[0, 4], [0, 3], [3, 6], [6, 8]]] + "num_encoder_layers": [2, 3, 2, 4], + "num_decoder_layers": [2, 0, 2, 8], + "num_stages": [2, 2, 4, 4], + "layers_per_stage": [ + [[0, 2], [0, 2]], + [[0, 1], [1, 3]], + [[0, 1], [1, 2], [0, 1], [1, 2]], + [[0, 4], [0, 3], [3, 6], [6, 8]], + ], } for i in range(num_test_cases): layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( - test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i]) - - for stage in range(test_dict['num_stages'][i]): - start_idx, end_idx = test_dict['layers_per_stage'][i][stage] - predicted_start, predicted_end = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage, - decoder_starting_stage) + test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], test_dict["num_stages"][i] + ) + + for stage in range(test_dict["num_stages"][i]): + start_idx, end_idx = test_dict["layers_per_stage"][i][stage] + predicted_start, predicted_end = WhisperPolicy.get_whisper_stage_index( + layers_per_stage, stage, decoder_starting_stage + ) assert start_idx == predicted_start assert end_idx == predicted_end -if __name__ == '__main__': +if __name__ == "__main__": test_whisper_pipeline_distribution() test_whisper_pipeline_layers() diff --git a/tests/test_pipeline/test_schedule/test_interleaved.py b/tests/test_pipeline/test_schedule/test_interleaved.py index a995d17e5da6..f181453eaed5 100644 --- a/tests/test_pipeline/test_schedule/test_interleaved.py +++ b/tests/test_pipeline/test_schedule/test_interleaved.py @@ -16,7 +16,6 @@ class MlpModel(nn.Module): - def __init__(self): super(MlpModel, self).__init__() self.linear1 = nn.Linear(4, 8) @@ -40,19 +39,20 @@ def forward(self, x): return x -def pp_linear_fwd(forward, - data: torch.Tensor = None, - input_obj: torch.Tensor = None, - stage_mgr: PipelineStageManager = None, - num_chunks: int = None, - model_chunk_id: int = None): - +def pp_linear_fwd( + forward, + data: torch.Tensor = None, + input_obj: torch.Tensor = None, + stage_mgr: PipelineStageManager = None, + num_chunks: int = None, + model_chunk_id: int = None, +): if stage_mgr.is_first_stage() and model_chunk_id == 0: - return {'input_obj': forward(data)} + return {"input_obj": forward(data)} elif stage_mgr.is_last_stage() and model_chunk_id == num_chunks - 1: return forward(input_obj) else: - return {'input_obj': forward(input_obj)} + return {"input_obj": forward(input_obj)} @parameterize("num_micro_batches", [4, 8, 12]) @@ 
-84,10 +84,11 @@ def examine_pp(num_micro_batches): if idx % (world_size) == local_rank: sub_model._forward = sub_model.forward sub_model.forward = MethodType( - partial(pp_linear_fwd, - stage_mgr=stage_manager, - num_chunks=NUM_CHUNKS, - model_chunk_id=len(sharded_model)), sub_model._forward) + partial( + pp_linear_fwd, stage_mgr=stage_manager, num_chunks=NUM_CHUNKS, model_chunk_id=len(sharded_model) + ), + sub_model._forward, + ) sharded_model.append(sub_model.cuda()) # create optimizer @@ -109,16 +110,13 @@ def examine_pp(num_micro_batches): torch_loss = criterion(torch_output, _) torch_loss.backward() - pp_ret = schedule.forward_backward_step(sharded_model, - iter(input_list), - criterion, - pp_optimizer, - return_loss=True, - return_outputs=True) + pp_ret = schedule.forward_backward_step( + sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True + ) # check loss if stage_manager.is_last_stage(): - assert torch.allclose(torch_loss, pp_ret['loss']) + assert torch.allclose(torch_loss, pp_ret["loss"]) # check gradients torch_grad = [] @@ -147,7 +145,7 @@ def examine_pp(num_micro_batches): def run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") examine_pp() @@ -157,5 +155,5 @@ def test_pp(): spawn(run_dist, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_pp() diff --git a/tests/test_pipeline/test_schedule/test_oneF_oneB.py b/tests/test_pipeline/test_schedule/test_oneF_oneB.py index 41b535573c39..1d77edc2db11 100644 --- a/tests/test_pipeline/test_schedule/test_oneF_oneB.py +++ b/tests/test_pipeline/test_schedule/test_oneF_oneB.py @@ -16,7 +16,6 @@ class MlpModel(nn.Module): - def __init__(self): super(MlpModel, self).__init__() self.linear1 = nn.Linear(4, 8) @@ -28,17 +27,15 @@ def forward(self, x): return x -def pp_linear_fwd(forward, - data: torch.Tensor = None, - input_obj: torch.Tensor = None, - stage_mgr: PipelineStageManager = None): - +def pp_linear_fwd( + forward, data: torch.Tensor = None, input_obj: torch.Tensor = None, stage_mgr: PipelineStageManager = None +): if stage_mgr.is_first_stage(): - return {'input_obj': forward(data)} + return {"input_obj": forward(data)} elif stage_mgr.is_last_stage(): return forward(input_obj) else: - return {'input_obj': forward(input_obj)} + return {"input_obj": forward(input_obj)} def examine_pp(): @@ -89,16 +86,13 @@ def examine_pp(): torch_loss = criterion(torch_output, _) torch_loss.backward() - pp_ret = schedule.forward_backward_step(sharded_model, - iter(input_list), - criterion, - pp_optimizer, - return_loss=True, - return_outputs=True) + pp_ret = schedule.forward_backward_step( + sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True + ) # check loss if stage_manager.is_last_stage(): - assert torch.allclose(torch_loss, pp_ret['loss']) + assert torch.allclose(torch_loss, pp_ret["loss"]) # check gradients torch_grad = [] @@ -120,7 +114,7 @@ def examine_pp(): def run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") examine_pp() @@ -130,5 +124,5 @@ def test_pp(): spawn(run_dist, 2) -if __name__ == '__main__': +if __name__ == "__main__": test_pp() diff --git 
a/tests/test_pipeline/test_schedule/test_pipeline_schedule_utils.py b/tests/test_pipeline/test_schedule/test_pipeline_schedule_utils.py index 4c23a23ebaba..462355ee470b 100644 --- a/tests/test_pipeline/test_schedule/test_pipeline_schedule_utils.py +++ b/tests/test_pipeline/test_schedule/test_pipeline_schedule_utils.py @@ -8,9 +8,9 @@ def test_get_batch_size(): assert get_batch_size(tensor) == 2 assert get_batch_size([tensor]) == 2 assert get_batch_size((1, tensor)) == 2 - assert get_batch_size({'tensor': tensor}) == 2 - assert get_batch_size({'dummy': [1], 'tensor': tensor}) == 2 - assert get_batch_size({'tensor': [tensor]}) == 2 + assert get_batch_size({"tensor": tensor}) == 2 + assert get_batch_size({"dummy": [1], "tensor": tensor}) == 2 + assert get_batch_size({"tensor": [tensor]}) == 2 def test_get_micro_batch(): @@ -26,12 +26,12 @@ def test_get_micro_batch(): micro_batch = get_micro_batch([x, y], 1, 1) assert torch.equal(micro_batch[0], x[1:2]) assert torch.equal(micro_batch[1], y[1:2]) - micro_batch = get_micro_batch({'x': x, 'y': y}, 0, 1) - assert torch.equal(micro_batch['x'], x[0:1]) - assert torch.equal(micro_batch['y'], y[0:1]) - micro_batch = get_micro_batch({'x': x, 'y': y}, 1, 1) - assert torch.equal(micro_batch['x'], x[1:2]) - assert torch.equal(micro_batch['y'], y[1:2]) + micro_batch = get_micro_batch({"x": x, "y": y}, 0, 1) + assert torch.equal(micro_batch["x"], x[0:1]) + assert torch.equal(micro_batch["y"], y[0:1]) + micro_batch = get_micro_batch({"x": x, "y": y}, 1, 1) + assert torch.equal(micro_batch["x"], x[1:2]) + assert torch.equal(micro_batch["y"], y[1:2]) def test_merge_batch(): @@ -42,6 +42,6 @@ def test_merge_batch(): merged = merge_batch([[x[0:1], y[0:1]], [x[1:2], y[1:2]]]) assert torch.equal(merged[0], x) assert torch.equal(merged[1], y) - merged = merge_batch([{'x': x[0:1], 'y': y[0:1]}, {'x': x[1:2], 'y': y[1:2]}]) - assert torch.equal(merged['x'], x) - assert torch.equal(merged['y'], y) + merged = merge_batch([{"x": x[0:1], "y": y[0:1]}, {"x": x[1:2], "y": y[1:2]}]) + assert torch.equal(merged["x"], x) + assert torch.equal(merged["y"], y) diff --git a/tests/test_pipeline/test_stage_manager.py b/tests/test_pipeline/test_stage_manager.py index 6e0cd1998c11..ed8284b3e64c 100644 --- a/tests/test_pipeline/test_stage_manager.py +++ b/tests/test_pipeline/test_stage_manager.py @@ -64,7 +64,7 @@ def check_stage_manager(): def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost") check_stage_manager() @@ -74,5 +74,5 @@ def test_pipeline_stage_manager(): spawn(run_dist, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_pipeline_stage_manager() diff --git a/tests/test_shardformer/test_layer/test_dist_crossentropy.py b/tests/test_shardformer/test_layer/test_dist_crossentropy.py index 72e6e5cf26ed..277a5b2bb4be 100644 --- a/tests/test_shardformer/test_layer/test_dist_crossentropy.py +++ b/tests/test_shardformer/test_layer/test_dist_crossentropy.py @@ -7,12 +7,14 @@ from colossalai.shardformer.layer import cross_entropy_1d from colossalai.testing import rerun_if_address_is_in_use, spawn -CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),) +CONFIG = dict( + parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode="1d")), +) def check_dist_crossentropy(rank, world_size, port, ignore_index): disable_existing_loggers() - colossalai.launch(config=CONFIG, rank=rank, 
world_size=world_size, port=port, host='localhost', backend='nccl') + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl") # prepare data pred = torch.randn(2, 4, 8, requires_grad=True) @@ -25,10 +27,11 @@ def check_dist_crossentropy(rank, world_size, port, ignore_index): org_loss = F.cross_entropy(org_pred, org_labels) dist_pred = pred.chunk(world_size, -1)[rank] - dist_loss = cross_entropy_1d(dist_pred.to('cuda'), labels.to('cuda'), ignore_index=ignore_index) + dist_loss = cross_entropy_1d(dist_pred.to("cuda"), labels.to("cuda"), ignore_index=ignore_index) - assert torch.allclose(org_loss, dist_loss, - atol=1e-5), f"dist cross entropy loss is not equal to orgin loss\n{org_loss}\n{dist_loss}" + assert torch.allclose( + org_loss, dist_loss, atol=1e-5 + ), f"dist cross entropy loss is not equal to orgin loss\n{org_loss}\n{dist_loss}" @pytest.mark.dist @@ -38,5 +41,5 @@ def test_dist_crossentropy(): spawn(check_dist_crossentropy, 2, ignore_index=ignore_index) -if __name__ == '__main__': +if __name__ == "__main__": test_dist_crossentropy() diff --git a/tests/test_shardformer/test_layer/test_dropout.py b/tests/test_shardformer/test_layer/test_dropout.py index 332e377110a4..576620e6c7f3 100644 --- a/tests/test_shardformer/test_layer/test_dropout.py +++ b/tests/test_shardformer/test_layer/test_dropout.py @@ -56,7 +56,7 @@ def check_dropout_replicated_input(): def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") check_dropout_parallel_input() check_dropout_replicated_input() @@ -66,5 +66,5 @@ def test_dropout(): spawn(run_dist, nprocs=2) -if __name__ == '__main__': +if __name__ == "__main__": test_dropout() diff --git a/tests/test_shardformer/test_layer/test_embedding.py b/tests/test_shardformer/test_layer/test_embedding.py index d62dba7ea92a..3dbbcd766bf4 100644 --- a/tests/test_shardformer/test_layer/test_embedding.py +++ b/tests/test_shardformer/test_layer/test_embedding.py @@ -11,7 +11,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -@parameterize('lazy_init', [False, True]) +@parameterize("lazy_init", [False, True]) def check_embedding_1d(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() @@ -43,7 +43,7 @@ def check_embedding_1d(lazy_init: bool): def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") check_embedding_1d() @@ -52,5 +52,5 @@ def test_embedding_1d(): spawn(run_dist, nprocs=2) -if __name__ == '__main__': +if __name__ == "__main__": test_embedding_1d() diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py index 4c0f884a7ed5..10ffdcd7138c 100644 --- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py @@ -58,12 +58,9 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool) linear = Conv1D(192, 48).cuda() with ctx: linear_copy = Conv1D(192, 48).cuda() - linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear_copy, - process_group=None, - 
gather_output=True, - seq_parallel=seq_parallel, - n_fused=3, - overlap=overlap) + linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module( + linear_copy, process_group=None, gather_output=True, seq_parallel=seq_parallel, n_fused=3, overlap=overlap + ) assert linear.weight.shape == torch.Size([48, 192]) assert linear.bias.shape == torch.Size([192]) @@ -97,10 +94,9 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool): linear = Conv1D(192, 48).cuda() with ctx: linear_copy = Conv1D(192, 48).cuda() - linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, - process_group=None, - parallel_input=False, - seq_parallel=seq_parallel) + linear_row = GPT2FusedLinearConv1D_Row.from_native_module( + linear_copy, process_group=None, parallel_input=False, seq_parallel=seq_parallel + ) assert linear.weight.shape == torch.Size([48, 192]) assert linear_row.weight.shape == torch.Size([24, 192]) @@ -128,16 +124,16 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool): assert_close(target_grad, linear_row.weight.grad) -@parameterize('lazy_init', [False, True]) -@parameterize('seq_parallel', [False, True]) -@parameterize('overlap', [True]) +@parameterize("lazy_init", [False, True]) +@parameterize("seq_parallel", [False, True]) +@parameterize("overlap", [True]) def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool, overlap: bool): check_linear_conv_1d_col(lazy_init, seq_parallel, overlap) check_linear_conv_1d_row(lazy_init, seq_parallel) def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # test for linear conv check_gpt2_qkv_fused_linear_1d() @@ -148,5 +144,5 @@ def test_linearconv(): spawn(run_dist, nprocs=2) -if __name__ == '__main__': +if __name__ == "__main__": test_linearconv() diff --git a/tests/test_shardformer/test_layer/test_layernorm.py b/tests/test_shardformer/test_layer/test_layernorm.py index f9c21b82a282..3eb3bb2e5b8d 100644 --- a/tests/test_shardformer/test_layer/test_layernorm.py +++ b/tests/test_shardformer/test_layer/test_layernorm.py @@ -10,7 +10,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -@parameterize('lazy_init', [False, True]) +@parameterize("lazy_init", [False, True]) def check_layernorm(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() @@ -41,7 +41,7 @@ def check_layernorm(lazy_init: bool): def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") check_layernorm() @@ -50,5 +50,5 @@ def test_layernorm(): spawn(run_dist, nprocs=2) -if __name__ == '__main__': +if __name__ == "__main__": test_layernorm() diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py index e6d86d533ed6..5bacf1865c48 100644 --- a/tests/test_shardformer/test_layer/test_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -17,11 +17,9 @@ def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool): linear = nn.Linear(32, 128).cuda() with ctx: linear_copy = nn.Linear(32, 128).cuda() - linear_col = Linear1D_Col.from_native_module(linear_copy, - process_group=None, - gather_output=True, - 
seq_parallel=seq_parallel, - overlap=overlap) + linear_col = Linear1D_Col.from_native_module( + linear_copy, process_group=None, gather_output=True, seq_parallel=seq_parallel, overlap=overlap + ) # ensure that the parameters are distributed assert is_distributed_tensor(linear_col.weight) @@ -60,8 +58,11 @@ def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool): # check the input gradients assert x_for_shard.grad is not None assert x_for_unshard.grad is not None - target_unshard_gard = x_for_unshard.grad if seq_parallel is False else torch.chunk( - x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()] + target_unshard_gard = ( + x_for_unshard.grad + if seq_parallel is False + else torch.chunk(x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()] + ) assert_close(target_unshard_gard, x_for_shard.grad) @@ -71,10 +72,9 @@ def check_linear_1d_row(lazy_init: bool, seq_parallel: bool): linear = nn.Linear(32, 128).cuda() with ctx: linear_copy = nn.Linear(32, 128).cuda() - linear_row = Linear1D_Row.from_native_module(linear_copy, - process_group=None, - parallel_input=False, - seq_parallel=seq_parallel) + linear_row = Linear1D_Row.from_native_module( + linear_copy, process_group=None, parallel_input=False, seq_parallel=seq_parallel + ) assert linear_row.weight.shape == torch.Size([128, 16]) assert linear_row.bias.shape == torch.Size([128]) @@ -121,15 +121,12 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool with ctx: linear_1_copy = nn.Linear(32, 128).cuda() linear_2_copy = nn.Linear(128, 32).cuda() - linear_col = Linear1D_Col.from_native_module(linear_1_copy, - process_group=None, - gather_output=False, - seq_parallel=seq_parallel, - overlap=overlap) - linear_row = Linear1D_Row.from_native_module(linear_2_copy, - process_group=None, - parallel_input=True, - seq_parallel=seq_parallel) + linear_col = Linear1D_Col.from_native_module( + linear_1_copy, process_group=None, gather_output=False, seq_parallel=seq_parallel, overlap=overlap + ) + linear_row = Linear1D_Row.from_native_module( + linear_2_copy, process_group=None, parallel_input=True, seq_parallel=seq_parallel + ) linear_1.load_state_dict(linear_col.state_dict()) linear_col.load_state_dict(linear_1.state_dict()) @@ -161,14 +158,17 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool # check the input gradients assert x_for_shard.grad is not None assert x_for_unshard.grad is not None - target_unshard_gard = x_for_unshard.grad if seq_parallel is False else torch.chunk( - x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()] + target_unshard_gard = ( + x_for_unshard.grad + if seq_parallel is False + else torch.chunk(x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()] + ) assert_close(target_unshard_gard, x_for_shard.grad) -@parameterize('lazy_init', [False, True]) -@parameterize('seq_parallel', [False, True]) -@parameterize('overlap', [True]) +@parameterize("lazy_init", [False, True]) +@parameterize("seq_parallel", [False, True]) +@parameterize("overlap", [True]) def run_dist_linear_test(lazy_init, seq_parallel, overlap): check_linear_1d_col(lazy_init, seq_parallel, overlap) check_linear_1d_row(lazy_init, seq_parallel) @@ -176,7 +176,7 @@ def run_dist_linear_test(lazy_init, seq_parallel, overlap): def check_dist_linear(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, 
backend="nccl") run_dist_linear_test() @@ -185,5 +185,5 @@ def test_linear(): spawn(check_dist_linear, nprocs=2) -if __name__ == '__main__': +if __name__ == "__main__": test_linear() diff --git a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py index b45cd172c3ca..b02d581810cd 100644 --- a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py @@ -53,16 +53,15 @@ def rearrange(tensor: torch.Tensor, dim: int): return rearanged_tensor -@parameterize('lazy_init', [False, True]) +@parameterize("lazy_init", [False, True]) def check_linear_conv_1d_col(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = Conv1D(192, 48).cuda() with ctx: linear_copy = Conv1D(192, 48).cuda() - linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear_copy, - process_group=None, - gather_output=True, - n_fused=3) + linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module( + linear_copy, process_group=None, gather_output=True, n_fused=3 + ) assert linear.weight.shape == torch.Size([48, 192]) assert linear.bias.shape == torch.Size([192]) @@ -89,7 +88,7 @@ def check_linear_conv_1d_col(lazy_init: bool): assert_close(target_grad, linear_conv_col.weight.grad) -@parameterize('lazy_init', [False, True]) +@parameterize("lazy_init", [False, True]) def check_linear_conv_1d_row(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() @@ -124,7 +123,7 @@ def check_linear_conv_1d_row(lazy_init: bool): def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # test for linear conv check_linear_conv_1d_col() @@ -136,5 +135,5 @@ def test_linearconv(): spawn(run_dist, nprocs=2) -if __name__ == '__main__': +if __name__ == "__main__": test_linearconv() diff --git a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py index 6d2f087302d9..b23a44f2dffa 100644 --- a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py +++ b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py @@ -11,13 +11,13 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -@parameterize('lazy_init', [False, True]) +@parameterize("lazy_init", [False, True]) def check_vocab_embedding_1d(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() - embedding = nn.Embedding(128, 32).to('cuda') + embedding = nn.Embedding(128, 32).to("cuda") with ctx: - embedding_copy = nn.Embedding(128, 32).to('cuda') + embedding_copy = nn.Embedding(128, 32).to("cuda") dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding_copy, process_group=None) assert dist_embedding_1d.weight.shape == torch.Size([64, 32]) @@ -30,7 +30,7 @@ def check_vocab_embedding_1d(lazy_init: bool): dist_embedding_1d.load_state_dict(embedding.state_dict()) # check embedding correctness - x = torch.randint(0, 128, (4, 32)).to('cuda') + x = torch.randint(0, 128, (4, 32)).to("cuda") org_out = embedding(x) dist_out = dist_embedding_1d(x) assert_close(org_out, dist_out) @@ -45,7 +45,7 @@ def check_vocab_embedding_1d(lazy_init: bool): def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, 
world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") check_vocab_embedding_1d() @@ -54,5 +54,5 @@ def test_vocab_embedding(): spawn(run_dist, nprocs=2) -if __name__ == '__main__': +if __name__ == "__main__": test_vocab_embedding() diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index c9c6447a43f0..0a2b151d4274 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -22,13 +22,15 @@ from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor -def build_model(model_fn, - enable_fused_normalization=True, - enable_tensor_parallelism=True, - enable_flash_attention=False, - enable_jit_fused=False, - enable_sequence_parallelism=False, - use_lazy_init: bool = False): +def build_model( + model_fn, + enable_fused_normalization=True, + enable_tensor_parallelism=True, + enable_flash_attention=False, + enable_jit_fused=False, + enable_sequence_parallelism=False, + use_lazy_init: bool = False, +): # create new model ctx = LazyInitContext() if use_lazy_init else nullcontext() with ctx: @@ -38,23 +40,27 @@ def build_model(model_fn, if use_lazy_init: ctx.materialize(org_model) # shard model - shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, - enable_tensor_parallelism=enable_tensor_parallelism, - enable_flash_attention=enable_flash_attention, - enable_jit_fused=enable_jit_fused, - enable_sequence_parallelism=enable_sequence_parallelism) + shard_config = ShardConfig( + enable_fused_normalization=enable_fused_normalization, + enable_tensor_parallelism=enable_tensor_parallelism, + enable_flash_attention=enable_flash_attention, + enable_jit_fused=enable_jit_fused, + enable_sequence_parallelism=enable_sequence_parallelism, + ) model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) sharded_model, shared_params = shard_former.optimize(model_copy) return org_model.cuda(), sharded_model.cuda() -def build_pipeline_model(model_fn, - stage_manager=None, - enable_fused_normalization=False, - enable_tensor_parallelism=False, - use_lazy_init: bool = False, - policy: Optional[Policy] = None): +def build_pipeline_model( + model_fn, + stage_manager=None, + enable_fused_normalization=False, + enable_tensor_parallelism=False, + use_lazy_init: bool = False, + policy: Optional[Policy] = None, +): ctx = LazyInitContext() if use_lazy_init else nullcontext() with ctx: # create new model @@ -64,9 +70,11 @@ def build_pipeline_model(model_fn, ctx.materialize(org_model) # shard model - shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, - enable_tensor_parallelism=enable_tensor_parallelism, - pipeline_stage_manager=stage_manager) + shard_config = ShardConfig( + enable_fused_normalization=enable_fused_normalization, + enable_tensor_parallelism=enable_tensor_parallelism, + pipeline_stage_manager=stage_manager, + ) shard_former = ShardFormer(shard_config=shard_config) sharded_model, shared_params = shard_former.optimize(model_copy, policy=policy) @@ -91,22 +99,21 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, return org_output, org_loss, shard_output, shard_loss -def check_state_dict(org_model: Module, sharded_model: Module, name: str = ''): +def check_state_dict(org_model: Module, sharded_model: Module, name: str = ""): org_sd = 
org_model.state_dict() shard_sd = sharded_model.state_dict() for k, v in org_sd.items(): - assert k in shard_sd, f'{name} {k} not in sharded model' + assert k in shard_sd, f"{name} {k} not in sharded model" shard_v = shard_sd[k] - assert v.shape == shard_v.shape, f'{name} {k} shape mismatch, {v.shape} vs {shard_v.shape}' - assert v.dtype == shard_v.dtype, f'{name} {k} dtype mismatch, {v.dtype} vs {shard_v.dtype}' - assert torch.equal(v, shard_v), f'{name} {k} value mismatch' + assert v.shape == shard_v.shape, f"{name} {k} shape mismatch, {v.shape} vs {shard_v.shape}" + assert v.dtype == shard_v.dtype, f"{name} {k} dtype mismatch, {v.dtype} vs {shard_v.dtype}" + assert torch.equal(v, shard_v), f"{name} {k} value mismatch" def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any]): - use_lazy_init = False - if 'use_lazy_init' in test_config: - use_lazy_init = test_config.pop('use_lazy_init') + if "use_lazy_init" in test_config: + use_lazy_init = test_config.pop("use_lazy_init") ctx = LazyInitContext() if use_lazy_init else nullcontext() with ctx: @@ -127,9 +134,15 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c return org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster -def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Module, sharded_optimizer: Optimizer, - data_gen_fn: Callable, output_transform_fn: Callable, criterion: Callable, - booster: Booster): +def run_forward_backward_with_hybrid_plugin( + org_model: Module, + sharded_model: Module, + sharded_optimizer: Optimizer, + data_gen_fn: Callable, + output_transform_fn: Callable, + criterion: Callable, + booster: Booster, +): org_model.cuda() sharded_model.cuda() @@ -141,10 +154,10 @@ def _criterion(outputs, inputs): data = data_gen_fn() if booster.plugin.enable_sequence_parallelism and booster.plugin.tp_size != 0: - seq_len = data['input_ids'].shape[-1] + seq_len = data["input_ids"].shape[-1] lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len) times = lcm // seq_len - input_shape = data['input_ids'].shape + input_shape = data["input_ids"].shape for k, v in data.items(): if v.shape == input_shape: data[k] = v.repeat(input_shape[:-1] + (input_shape[-1] * times,)) @@ -152,19 +165,16 @@ def _criterion(outputs, inputs): sharded_model.train() if booster.plugin.stage_manager is not None: for k, v in data.items(): - if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: + if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: new_shape = [1] * v.dim() new_shape[0] = 4 - data[k] = v.to('cuda').repeat(*new_shape) + data[k] = v.to("cuda").repeat(*new_shape) data_iter = iter([data]) - sharded_output = booster.execute_pipeline(data_iter, - sharded_model, - _criterion, - sharded_optimizer, - return_loss=True, - return_outputs=True) - sharded_loss = sharded_output['loss'] + sharded_output = booster.execute_pipeline( + data_iter, sharded_model, _criterion, sharded_optimizer, return_loss=True, return_outputs=True + ) + sharded_loss = sharded_output["loss"] else: data = {k: v.cuda() for k, v in data.items()} sharded_output = sharded_model(**data) @@ -182,45 +192,49 @@ def _criterion(outputs, inputs): return org_loss, org_output, sharded_loss, sharded_output -def check_output_hidden_state(org_output: Tensor, - sharded_output: Tensor, - stage_manager: Optional[PipelineStageManager] = None, - atol: float = 1e-5, - rtol: float = 1e-3, - dim: int = 0): - +def check_output_hidden_state( + 
org_output: Tensor, + sharded_output: Tensor, + stage_manager: Optional[PipelineStageManager] = None, + atol: float = 1e-5, + rtol: float = 1e-3, + dim: int = 0, +): org_hidden_state = org_output.last_hidden_state if stage_manager and stage_manager.is_last_stage(): - sharded_hidden_state = sharded_output['outputs']['last_hidden_state'] + sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"] else: sharded_hidden_state = sharded_output.last_hidden_state - assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \ - f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}" + assert torch.allclose( + org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol + ), f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}" def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3): - assert torch.allclose(org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol), \ - f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}" - - -def check_weight(org_model: Module, - sharded_model: Module, - layer_suffix: List[str], - tp_group: Optional[ProcessGroup] = None, - dim: int = 0, - atol: float = 1e-5, - rtol: float = 1e-3, - verbose: bool = False): - + assert torch.allclose( + org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol + ), f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}" + + +def check_weight( + org_model: Module, + sharded_model: Module, + layer_suffix: List[str], + tp_group: Optional[ProcessGroup] = None, + dim: int = 0, + atol: float = 1e-5, + rtol: float = 1e-3, + verbose: bool = False, +): for suffix in layer_suffix: org_weight = getattr_(org_model, suffix).weight sharded_weight = getattr_(sharded_model, suffix).weight if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight): sharded_weight_list = [ - torch.zeros_like(sharded_weight).to('cuda') for _ in range(dist.get_world_size(tp_group)) + torch.zeros_like(sharded_weight).to("cuda") for _ in range(dist.get_world_size(tp_group)) ] dist.all_gather(sharded_weight_list, sharded_weight, tp_group) sharded_weight = torch.cat(sharded_weight_list, dim=dim) @@ -228,33 +242,35 @@ def check_weight(org_model: Module, if verbose and dist.get_rank() == 0: print(f"'{suffix}' weight: {org_weight}, {sharded_weight}") - assert torch.allclose(org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol), \ - f"shard model weight {suffix} is not equal to origin model weight\n{org_weight}\n{sharded_weight}" - - -def get_grad_tensors_for_check(org_model: Module, - sharded_model: Module, - layer_suffix: List[str], - tp_group: ProcessGroup = None, - dim: int = 0, - atol: float = 1e-5, - rtol: float = 1e-3, - verbose: bool = False, - name: str = None): - + assert torch.allclose( + org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol + ), f"shard model weight {suffix} is not equal to origin model weight\n{org_weight}\n{sharded_weight}" + + +def get_grad_tensors_for_check( + org_model: Module, + sharded_model: Module, + layer_suffix: List[str], + tp_group: ProcessGroup = None, + dim: int = 0, + atol: float = 1e-5, + rtol: float = 1e-3, + verbose: bool = False, + name: str = None, +): grad_to_check = {} for suffix in layer_suffix: org_grad = getattr_(org_model, suffix).weight.grad shard_grad 
= getattr_(sharded_model, suffix).weight.grad shard_weight = getattr_(sharded_model, suffix).weight if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))] + shard_grad_list = [torch.zeros_like(shard_grad).to("cuda") for _ in range(dist.get_world_size(tp_group))] dist.all_gather(shard_grad_list, shard_grad, tp_group) shard_grad = torch.cat(shard_grad_list, dim=dim) # embedding may be resized when using tensor parallel if shard_grad.shape[0] > org_grad.shape[0]: - shard_grad = shard_grad[:org_grad.shape[0], :] + shard_grad = shard_grad[: org_grad.shape[0], :] if verbose and dist.get_rank() == 0: print(f"'{suffix}' grad: {org_grad}, {shard_grad}") @@ -262,33 +278,35 @@ def get_grad_tensors_for_check(org_model: Module, "org_grad": org_grad.float(), "shard_grad": shard_grad.float(), "rtol": rtol, - "atol": atol + "atol": atol, } return grad_to_check # used by sam/blip2 -def check_grad(org_model: Module, - sharded_model: Module, - layer_suffix: List[str], - tp_group: ProcessGroup = None, - dim: int = 0, - atol: float = 1e-5, - rtol: float = 1e-3, - verbose: bool = False): +def check_grad( + org_model: Module, + sharded_model: Module, + layer_suffix: List[str], + tp_group: ProcessGroup = None, + dim: int = 0, + atol: float = 1e-5, + rtol: float = 1e-3, + verbose: bool = False, +): for suffix in layer_suffix: org_grad = getattr_(org_model, suffix).weight.grad shard_grad = getattr_(sharded_model, suffix).weight.grad shard_weight = getattr_(sharded_model, suffix).weight if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))] + shard_grad_list = [torch.zeros_like(shard_grad).to("cuda") for _ in range(dist.get_world_size(tp_group))] dist.all_gather(shard_grad_list, shard_grad, tp_group) shard_grad = torch.cat(shard_grad_list, dim=dim) # embedding may be resized when using tensor parallel if shard_grad.shape[0] > org_grad.shape[0]: - shard_grad = shard_grad[:org_grad.shape[0], :] + shard_grad = shard_grad[: org_grad.shape[0], :] if verbose and dist.get_rank() == 0: print(f"'{suffix}' grad: {org_grad}, {shard_grad}") @@ -297,9 +315,9 @@ def check_grad(org_model: Module, ), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}" -def unwrap_model(module: Module, - base_model_class_name: Optional[str] = None, - base_model_attribute_name: Optional[str] = None): +def unwrap_model( + module: Module, base_model_class_name: Optional[str] = None, base_model_attribute_name: Optional[str] = None +): if isinstance(module, HybridParallelModule): module = module.unwrap() if base_model_class_name is None: diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index c779e417052b..31fd58d06f77 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -1,6 +1,5 @@ import pytest import torch -from torch import distributed as dist import colossalai from colossalai.logging import disable_existing_loggers @@ -21,52 +20,36 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, 
test_config + ) - org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ - build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) - - org_loss, org_output, sharded_loss, sharded_output = \ - run_forward_backward_with_hybrid_plugin( - org_model, - sharded_model, - sharded_optimizer, - data_gen_fn, - output_transform_fn, - criterion, - booster) + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group - bert = unwrap_model(org_model, 'BertModel', 'bert') - sharded_bert = unwrap_model(sharded_model, 'BertModel', 'bert') + bert = unwrap_model(org_model, "BertModel", "bert") + sharded_bert = unwrap_model(sharded_model, "BertModel", "bert") - col_layer_for_check = ['encoder.layer[0].output.dense'] - row_layer_for_check = ['embeddings.word_embeddings', 'encoder.layer[0].intermediate.dense'] + col_layer_for_check = ["encoder.layer[0].output.dense"] + row_layer_for_check = ["embeddings.word_embeddings", "encoder.layer[0].intermediate.dense"] # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. grads_to_check = {} - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: - col_layer_grads = get_grad_tensors_for_check(bert, - sharded_bert, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) - row_layer_grads = get_grad_tensors_for_check(bert, - sharded_bert, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0, - verbose=False) + col_layer_grads = get_grad_tensors_for_check( + bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) + row_layer_grads = get_grad_tensors_for_check( + bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + ) grads_to_check.update(col_layer_grads) grads_to_check.update(row_layer_grads) @@ -76,17 +59,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 - if org_model.__class__.__name__ == 'BertModel': + if org_model.__class__.__name__ == "BertModel": check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # check weights - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 5e-3, 1e-3 else: atol, rtol = 5e-3, 5e-3 @@ -99,53 +82,56 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, torch.cuda.empty_cache() -@parameterize('test_config', [{ - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 4, - 'use_lazy_init': True, - 'precision': 'fp32', -}, { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 2, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'precision': 'fp16', - 'initial_scale': 1, -}, { - 'tp_size': 4, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32', -}, { - 'tp_size': 
2, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32' -}, { - 'tp_size': 2, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'zero_stage': 2, - 'precision': 'fp16', - 'initial_scale': 1 -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 2, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'zero_stage': 1, - 'precision': 'fp16', - 'initial_scale': 1 -}]) +@parameterize( + "test_config", + [ + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "use_lazy_init": True, + "precision": "fp32", + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp32", + }, + {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) def run_bert_test(test_config): - - sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') + sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -155,31 +141,33 @@ def run_bert_test(test_config): torch.cuda.empty_cache() -@parameterize('test_config', [ - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32', - }, - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp16', - 'zero_stage': 1, - 'initial_scale': 1, - }, -]) +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, + ], +) def run_bert_3d_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') + sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() @@ -189,13 +177,13 @@ def run_bert_3d_test(test_config): def check_bert(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_bert_test() def check_bert_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, 
host="localhost", port=port, backend="nccl") run_bert_3d_test() diff --git a/tests/test_shardformer/test_model/test_shard_blip2.py b/tests/test_shardformer/test_model/test_shard_blip2.py index cd034d0c139a..02c15460ecb3 100644 --- a/tests/test_shardformer/test_model/test_shard_blip2.py +++ b/tests/test_shardformer/test_model/test_shard_blip2.py @@ -16,16 +16,18 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): # check forward - org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, - output_transform_fn, loss_fn) - assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values']) + org_output, org_loss, shard_output, shard_loss = run_forward( + org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn + ) + assert_hf_output_close(org_output, shard_output, ignore_keys=["past_key_values"]) # do backward org_loss.backward() shard_loss.backward() - assert torch.allclose(org_loss, shard_loss, - atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + assert torch.allclose( + org_loss, shard_loss, atol=1e-5 + ), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" # check grad @@ -34,26 +36,29 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # check grad col_layer_for_check = [ - 'vision_model.encoder.layers[0].self_attn.qkv', 'qformer.encoder.layer[0].attention.attention.query', - 'language_model.model.decoder.layers[0].self_attn.k_proj' + "vision_model.encoder.layers[0].self_attn.qkv", + "qformer.encoder.layer[0].attention.attention.query", + "language_model.model.decoder.layers[0].self_attn.k_proj", ] row_layer_for_check = [ - 'vision_model.encoder.layers[0].self_attn.projection', 'qformer.encoder.layer[0].attention.output.dense', - 'language_model.model.decoder.layers[0].self_attn.out_proj' + "vision_model.encoder.layers[0].self_attn.projection", + "qformer.encoder.layer[0].attention.output.dense", + "language_model.model.decoder.layers[0].self_attn.out_proj", ] check_grad(blip2, sharded_blip2, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False) check_grad(blip2, sharded_blip2, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False) -@parameterize('enable_fused_normalization', [True, False]) -@parameterize('enable_tensor_parallelism', [True, False]) -@parameterize('enable_flash_attention', [True, False]) -@parameterize('enable_jit_fused', [True, False]) +@parameterize("enable_fused_normalization", [True, False]) +@parameterize("enable_tensor_parallelism", [True, False]) +@parameterize("enable_flash_attention", [True, False]) +@parameterize("enable_jit_fused", [True, False]) def run_blip2_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused): - sub_model_zoo = model_zoo.get_sub_registry('transformers_blip2') + sub_model_zoo = model_zoo.get_sub_registry("transformers_blip2") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, - enable_flash_attention, enable_jit_fused) + org_model, sharded_model = build_model( + model_fn, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused + ) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() @@ -61,7 +66,7 @@ def 
run_blip2_test(enable_fused_normalization, enable_tensor_parallelism, enable def check_blip2(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_blip2_test() diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index c9ee690c86dc..7fe791db6d5e 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -20,53 +20,37 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) - org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ - build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) - - org_loss, org_output, sharded_loss, sharded_output = \ - run_forward_backward_with_hybrid_plugin( - org_model, - sharded_model, - sharded_optimizer, - data_gen_fn, - output_transform_fn, - criterion, - booster) + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group # unwrap model - bloom = unwrap_model(org_model, 'BloomModel', 'transformer') - sharded_bloom = unwrap_model(sharded_model, 'BloomModel', 'transformer') + bloom = unwrap_model(org_model, "BloomModel", "transformer") + sharded_bloom = unwrap_model(sharded_model, "BloomModel", "transformer") - row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings'] - col_layer_for_check = ['h[0].self_attention.dense'] + row_layer_for_check = ["h[0].self_attention.query_key_value", "word_embeddings"] + col_layer_for_check = ["h[0].self_attention.dense"] # Save gradient tensors for comparison between the original model and the sharded model. 
grads_to_check = {} if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-6, 1e-5 else: atol, rtol = 5e-3, 5e-3 - row_layer_grads = get_grad_tensors_for_check(bloom, - sharded_bloom, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0, - verbose=False) - col_layer_grads = get_grad_tensors_for_check(bloom, - sharded_bloom, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) + row_layer_grads = get_grad_tensors_for_check( + bloom, sharded_bloom, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + ) + col_layer_grads = get_grad_tensors_for_check( + bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) grads_to_check.update(col_layer_grads) grads_to_check.update(row_layer_grads) @@ -76,17 +60,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 - if org_model.__class__.__name__ == 'BloomModel': + if org_model.__class__.__name__ == "BloomModel": check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) if stage_manager is None or stage_manager.is_first_stage(): - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 @@ -98,54 +82,51 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, torch.cuda.empty_cache() -@parameterize('test_config', [{ - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'precision': 'fp16', - 'initial_scale': 1, -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32', -}, { - 'tp_size': 4, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32' -}, { - 'tp_size': 2, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32' -}, { - 'tp_size': 2, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'zero_stage': 2, - 'precision': 'fp16', - 'initial_scale': 1 -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 2, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'zero_stage': 1, - 'precision': 'fp16', - 'initial_scale': 1 -}]) +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + }, + {"tp_size": 4, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + 
"pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) def run_bloom_test(test_config): - - sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') + sub_model_zoo = model_zoo.get_sub_registry("transformers_bloom") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -155,29 +136,32 @@ def run_bloom_test(test_config): torch.cuda.empty_cache() -@parameterize('test_config', [ - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32', - 'initial_scale': 1, - }, - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp16', - 'zero_stage': 1, - 'initial_scale': 1, - }, -]) +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, + ], +) def run_bloom_3d_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') + sub_model_zoo = model_zoo.get_sub_registry("transformers_bloom") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -189,13 +173,13 @@ def run_bloom_3d_test(test_config): def check_bloom(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_bloom_test() def check_bloom_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_bloom_3d_test() diff --git a/tests/test_shardformer/test_model/test_shard_chatglm2.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py index 48f651c727f4..bdf5b79fc498 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm2.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm2.py @@ -1,6 +1,5 @@ import pytest import torch -from torch import distributed as dist import colossalai from colossalai.logging import disable_existing_loggers @@ -21,54 +20,52 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) - org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ - build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) - - org_loss, org_output, sharded_loss, sharded_output = \ - run_forward_backward_with_hybrid_plugin( - org_model, - sharded_model, - sharded_optimizer, - data_gen_fn, - output_transform_fn, - criterion, - booster) + org_loss, org_output, 
sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group # unwrap model - chatglm_model = unwrap_model(org_model, 'ChatGLMModel', 'transformer') - shard_chatglm_model = unwrap_model(sharded_model, 'ChatGLMModel', 'transformer') + chatglm_model = unwrap_model(org_model, "ChatGLMModel", "transformer") + shard_chatglm_model = unwrap_model(sharded_model, "ChatGLMModel", "transformer") - row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings'] - col_layer_for_check = ['encoder.layers[0].self_attention.dense'] + row_layer_for_check = ["encoder.layers[0].self_attention.query_key_value", "embedding.word_embeddings"] + col_layer_for_check = ["encoder.layers[0].self_attention.dense"] # Save gradient tensors for comparison between the original model and the sharded model. grads_to_check = {} if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-6, 1e-3 else: atol, rtol = 5e-3, 5e-3 - row_layer_grads = get_grad_tensors_for_check(chatglm_model, - shard_chatglm_model, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0, - verbose=False) - - col_layer_grads = get_grad_tensors_for_check(chatglm_model, - shard_chatglm_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) + row_layer_grads = get_grad_tensors_for_check( + chatglm_model, + shard_chatglm_model, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False, + ) + + col_layer_grads = get_grad_tensors_for_check( + chatglm_model, + shard_chatglm_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, + ) grads_to_check.update(col_layer_grads) grads_to_check.update(row_layer_grads) @@ -78,30 +75,32 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 - if org_model.__class__.__name__ == 'ChatGLMModel': + if org_model.__class__.__name__ == "ChatGLMModel": check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # check weights if stage_manager is None or stage_manager.is_first_stage(): - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 - check_weight(chatglm_model, - shard_chatglm_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) + check_weight( + chatglm_model, + shard_chatglm_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, + ) # check grads check_all_grad_tensors(grads_to_check) @@ -110,45 +109,41 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, torch.cuda.empty_cache() -@parameterize('test_config', [{ - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'precision': 'fp16', - 'initial_scale': 1, -}, { - 'tp_size': 1, - 'pp_size': 
2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32', -}, { - 'tp_size': 4, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32' -}, { - 'tp_size': 2, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32' -}, { - 'tp_size': 2, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'zero_stage': 2, - 'precision': 'fp16', - 'initial_scale': 1 -}]) +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + }, + {"tp_size": 4, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) def run_chatglm_test(test_config): - - sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm') + sub_model_zoo = model_zoo.get_sub_registry("transformers_chatglm") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -157,29 +152,32 @@ def run_chatglm_test(test_config): torch.cuda.empty_cache() -@parameterize('test_config', [ - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32', - 'initial_scale': 1, - }, - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp16', - 'zero_stage': 1, - 'initial_scale': 1, - }, -]) +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, + ], +) def run_chatglm_3d_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm') + sub_model_zoo = model_zoo.get_sub_registry("transformers_chatglm") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -190,13 +188,13 @@ def run_chatglm_3d_test(test_config): def check_chatglm(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_chatglm_test() def check_chatglm_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") 
run_chatglm_3d_test() diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index c4cc3812dbfd..69a15166a54c 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -1,6 +1,5 @@ import pytest import torch -from torch import distributed as dist import colossalai from colossalai.logging import disable_existing_loggers @@ -21,53 +20,37 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) - org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ - build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) - - org_loss, org_output, sharded_loss, sharded_output = \ - run_forward_backward_with_hybrid_plugin( - org_model, - sharded_model, - sharded_optimizer, - data_gen_fn, - output_transform_fn, - criterion, - booster) + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group # unwrap model - gpt2 = unwrap_model(org_model, 'GPT2Model', 'transformer') - sharded_gpt2 = unwrap_model(sharded_model, 'GPT2Model', 'transformer') + gpt2 = unwrap_model(org_model, "GPT2Model", "transformer") + sharded_gpt2 = unwrap_model(sharded_model, "GPT2Model", "transformer") - col_layer_for_check = ['h[0].mlp.c_fc'] - row_layer_for_check = ['wte', 'h[0].mlp.c_proj'] + col_layer_for_check = ["h[0].mlp.c_fc"] + row_layer_for_check = ["wte", "h[0].mlp.c_proj"] # Save gradient tensors for comparison between the original model and the sharded model. 
grads_to_check = {} if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 - col_layer_grads = get_grad_tensors_for_check(gpt2, - sharded_gpt2, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) - row_layer_grads = get_grad_tensors_for_check(gpt2, - sharded_gpt2, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0, - verbose=False) + col_layer_grads = get_grad_tensors_for_check( + gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) + row_layer_grads = get_grad_tensors_for_check( + gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + ) grads_to_check.update(col_layer_grads) grads_to_check.update(row_layer_grads) @@ -77,19 +60,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 - if org_model.__class__.__name__ == 'GPT2Model': + if org_model.__class__.__name__ == "GPT2Model": check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # check weights if stage_manager is None or stage_manager.is_first_stage(): - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 5e-3, 1e-3 else: atol, rtol = 5e-3, 5e-3 @@ -102,63 +85,73 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, torch.cuda.empty_cache() -@parameterize('test_config', [{ - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'precision': 'fp16', - 'initial_scale': 1, -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'precision': 'fp16', - 'initial_scale': 1, -}, { - 'tp_size': 4, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32', -}, { - 'tp_size': 2, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32', -}, { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'precision': 'fp32', -}, { - 'tp_size': 2, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'zero_stage': 2, - 'precision': 'fp16', - 'initial_scale': 1 -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 2, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'zero_stage': 1, - 'precision': 'fp16', - 'initial_scale': 1 -}]) +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": 
False, + "precision": "fp32", + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp32", + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) @clear_cache_before_run() def run_gpt2_test(test_config): - - sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') + sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -167,30 +160,33 @@ def run_gpt2_test(test_config): torch.cuda.empty_cache() -@parameterize('test_config', [ - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32', - 'initial_scale': 1, - }, - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp16', - 'zero_stage': 1, - 'initial_scale': 1, - }, -]) +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, + ], +) @clear_cache_before_run() def run_gpt2_3d_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') + sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -201,13 +197,13 @@ def run_gpt2_3d_test(test_config): def check_gpt2(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_gpt2_test() def check_gpt2_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_gpt2_3d_test() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index a60150e3cd72..f8f08e1d0075 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -2,7 +2,6 @@ import pytest import torch -from torch import distributed as dist import colossalai from colossalai.logging import disable_existing_loggers @@ -21,57 +20,41 @@ unwrap_model, ) -os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, 
org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) - org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ - build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) - - org_loss, org_output, sharded_loss, sharded_output = \ - run_forward_backward_with_hybrid_plugin( - org_model, - sharded_model, - sharded_optimizer, - data_gen_fn, - output_transform_fn, - criterion, - booster) + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group # unwrap model - llama_model = unwrap_model(org_model, 'LlamaModel', 'model') - shard_llama_model = unwrap_model(sharded_model, 'LlamaModel', 'model') + llama_model = unwrap_model(org_model, "LlamaModel", "model") + shard_llama_model = unwrap_model(sharded_model, "LlamaModel", "model") - row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens'] - col_layer_for_check = ['layers[0].self_attn.o_proj'] + row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"] + col_layer_for_check = ["layers[0].self_attn.o_proj"] # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. grads_to_check = {} if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-6, 1e-4 else: atol, rtol = 5e-3, 5e-3 - row_layer_grads = get_grad_tensors_for_check(llama_model, - shard_llama_model, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0, - verbose=False) - col_layer_grads = get_grad_tensors_for_check(llama_model, - shard_llama_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) + row_layer_grads = get_grad_tensors_for_check( + llama_model, shard_llama_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + ) + col_layer_grads = get_grad_tensors_for_check( + llama_model, shard_llama_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) grads_to_check.update(col_layer_grads) grads_to_check.update(row_layer_grads) @@ -81,30 +64,25 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 - if org_model.__class__.__name__ == 'LlamaModel': + if org_model.__class__.__name__ == "LlamaModel": check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # check weights if stage_manager is None or stage_manager.is_first_stage(): - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 - check_weight(llama_model, - shard_llama_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) + check_weight( + llama_model, shard_llama_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) # check grads check_all_grad_tensors(grads_to_check) @@ -112,60 +90,64 @@ def 
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, torch.cuda.empty_cache() -@parameterize('test_config', [{ - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 2, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'precision': 'fp16', - 'initial_scale': 1, -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 4, - 'use_lazy_init': False, - 'precision': 'fp32', -}, { - 'tp_size': 4, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32', -}, { - 'tp_size': 1, - 'pp_size': 4, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32' -}, { - 'tp_size': 2, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32' -}, { - 'tp_size': 2, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'zero_stage': 2, - 'precision': 'fp16', - 'initial_scale': 1 -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 2, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'zero_stage': 1, - 'precision': 'fp16', - 'initial_scale': 1 -}]) +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 4, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 1, + "pp_size": 4, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + }, + {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) def run_llama_test(test_config): - - sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') + sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -175,29 +157,32 @@ def run_llama_test(test_config): torch.cuda.empty_cache() -@parameterize('test_config', [ - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32', - 'initial_scale': 1, - }, - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp16', - 'zero_stage': 1, - 'initial_scale': 1, - }, -]) +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, + ], +) def run_llama_3d_test(test_config): - sub_model_zoo = 
model_zoo.get_sub_registry('transformers_llama') + sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -209,13 +194,13 @@ def run_llama_3d_test(test_config): def check_llama(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_llama_test() def check_llama_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_llama_3d_test() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 3e74859ad1a8..d21ab264d8ab 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -2,7 +2,6 @@ import pytest import torch -from torch import distributed as dist import colossalai from colossalai.logging import disable_existing_loggers @@ -21,57 +20,41 @@ unwrap_model, ) -os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) - org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ - build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) - - org_loss, org_output, sharded_loss, sharded_output = \ - run_forward_backward_with_hybrid_plugin( - org_model, - sharded_model, - sharded_optimizer, - data_gen_fn, - output_transform_fn, - criterion, - booster) + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group # unwrap model - opt_model = unwrap_model(org_model, 'OPTModel', 'model') - shard_opt_model = unwrap_model(sharded_model, 'OPTModel', 'model') + opt_model = unwrap_model(org_model, "OPTModel", "model") + shard_opt_model = unwrap_model(sharded_model, "OPTModel", "model") - row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens'] # 'decoder.embed_tokens' - col_layer_for_check = ['decoder.layers[0].self_attn.out_proj'] + row_layer_for_check = ["decoder.layers[0].self_attn.q_proj", "decoder.embed_tokens"] # 'decoder.embed_tokens' + col_layer_for_check = ["decoder.layers[0].self_attn.out_proj"] # Save gradient tensors for comparison between the original model and the sharded model. 
grads_to_check = {} if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-6, 1e-3 else: atol, rtol = 4e-2, 4e-2 - row_layer_grads = get_grad_tensors_for_check(opt_model, - shard_opt_model, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0, - verbose=False) - col_layer_grads = get_grad_tensors_for_check(opt_model, - shard_opt_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) + row_layer_grads = get_grad_tensors_for_check( + opt_model, shard_opt_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + ) + col_layer_grads = get_grad_tensors_for_check( + opt_model, shard_opt_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) grads_to_check.update(col_layer_grads) grads_to_check.update(row_layer_grads) @@ -81,29 +64,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 - if org_model.__class__.__name__ == 'OPTModel': + if org_model.__class__.__name__ == "OPTModel": check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # check weights if stage_manager is None or stage_manager.is_first_stage(): - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-3, 1e-3 else: atol, rtol = 5e-3, 5e-3 - check_weight(opt_model, - shard_opt_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) + check_weight( + opt_model, shard_opt_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) # check grads check_all_grad_tensors(grads_to_check) @@ -112,53 +90,51 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, torch.cuda.empty_cache() -@parameterize('test_config', [{ - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'precision': 'fp16', - 'initial_scale': 1, -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32', -}, { - 'tp_size': 4, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32' -}, { - 'tp_size': 2, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32' -}, { - 'tp_size': 2, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'zero_stage': 2, - 'precision': 'fp16', - 'initial_scale': 1 -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 2, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'zero_stage': 1, - 'precision': 'fp16', - 'initial_scale': 1 -}]) +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + }, + {"tp_size": 4, "pp_size": 1, "enable_all_optimization": True, 
"use_lazy_init": False, "precision": "fp32"}, + {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) def run_opt_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') + sub_model_zoo = model_zoo.get_sub_registry("transformers_opt") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -166,29 +142,32 @@ def run_opt_test(test_config): torch.cuda.empty_cache() -@parameterize('test_config', [ - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32', - 'initial_scale': 1, - }, - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp16', - 'zero_stage': 1, - 'initial_scale': 1, - }, -]) +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, + ], +) def run_opt_3d_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') + sub_model_zoo = model_zoo.get_sub_registry("transformers_opt") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -199,13 +178,13 @@ def run_opt_3d_test(test_config): def check_OPTModel(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_opt_test() def check_opt_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_opt_3d_test() @@ -223,6 +202,6 @@ def test_opt_3d(): spawn(check_opt_3d, 8) -if __name__ == '__main__': +if __name__ == "__main__": test_OPTModel() test_opt_3d() diff --git a/tests/test_shardformer/test_model/test_shard_sam.py b/tests/test_shardformer/test_model/test_shard_sam.py index 616104cd7828..a8d4cb635221 100644 --- a/tests/test_shardformer/test_model/test_shard_sam.py +++ b/tests/test_shardformer/test_model/test_shard_sam.py @@ -16,16 +16,18 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): # check forward - org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, - output_transform_fn, loss_fn) - assert_hf_output_close(org_output, shard_output, ignore_keys=['pred_masks']) + org_output, 
org_loss, shard_output, shard_loss = run_forward( + org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn + ) + assert_hf_output_close(org_output, shard_output, ignore_keys=["pred_masks"]) # do backward org_loss.backward() shard_loss.backward() - assert torch.allclose(org_loss, shard_loss, - atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + assert torch.allclose( + org_loss, shard_loss, atol=1e-5 + ), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" # check grad @@ -33,20 +35,21 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo sharded_sam = sharded_model # check grad - col_layer_for_check = ['mask_decoder.transformer.layers[0].self_attn.q_proj', 'vision_encoder.layers[0].mlp.lin1'] - row_layer_for_check = ['mask_decoder.transformer.layers[0].self_attn.out_proj', 'vision_encoder.layers[0].mlp.lin2'] + col_layer_for_check = ["mask_decoder.transformer.layers[0].self_attn.q_proj", "vision_encoder.layers[0].mlp.lin1"] + row_layer_for_check = ["mask_decoder.transformer.layers[0].self_attn.out_proj", "vision_encoder.layers[0].mlp.lin2"] check_grad(sam, sharded_sam, col_layer_for_check, atol=1e-5, rtol=1e-3, dim=0, verbose=False) check_grad(sam, sharded_sam, row_layer_for_check, atol=1e-3, rtol=1e-3, dim=1, verbose=False) -@parameterize('enable_fused_normalization', [True, False]) -@parameterize('enable_tensor_parallelism', [True, False]) -@parameterize('enable_flash_attention', [True, False]) +@parameterize("enable_fused_normalization", [True, False]) +@parameterize("enable_tensor_parallelism", [True, False]) +@parameterize("enable_flash_attention", [True, False]) def run_sam_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention): - sub_model_zoo = model_zoo.get_sub_registry('transformers_sam') + sub_model_zoo = model_zoo.get_sub_registry("transformers_sam") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, - enable_flash_attention) + org_model, sharded_model = build_model( + model_fn, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention + ) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() @@ -54,7 +57,7 @@ def run_sam_test(enable_fused_normalization, enable_tensor_parallelism, enable_f def check_sam(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_sam_test() diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 768cae0a6734..73f203d1f023 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -1,6 +1,5 @@ import pytest import torch -from torch.nn.parallel import DistributedDataParallel as DDP import colossalai from colossalai.logging import disable_existing_loggers @@ -21,19 +20,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) - org_model, org_optimizer, 
sharded_model, sharded_optimizer, criterion, booster = \ - build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) - - org_loss, org_output, sharded_loss, sharded_output = \ - run_forward_backward_with_hybrid_plugin( - org_model, - sharded_model, - sharded_optimizer, - data_gen_fn, - output_transform_fn, - criterion, - booster) + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group @@ -42,22 +35,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, t5 = unwrap_model(org_model) sharded_t5 = unwrap_model(sharded_model) - row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q'] + row_layer_for_check = ["shared", "encoder.block[0].layer[0].SelfAttention.q"] # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. grads_to_check = {} - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: - row_layer_grads = get_grad_tensors_for_check(t5, - sharded_t5, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0) + row_layer_grads = get_grad_tensors_for_check( + t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0 + ) grads_to_check.update(row_layer_grads) # optimizer executes step @@ -66,18 +55,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 - if org_model.__class__.__name__ != 'T5ForConditionalGeneration': + if org_model.__class__.__name__ != "T5ForConditionalGeneration": check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # check weights - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 5e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 @@ -90,67 +79,70 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, torch.cuda.empty_cache() -@parameterize('test_config', [{ - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 2, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'precision': 'fp16', - 'initial_scale': 1, -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 4, - 'use_lazy_init': False, - 'precision': 'fp16', - 'initial_scale': 1, -}, { - 'tp_size': 4, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32', -}, { - 'tp_size': 1, - 'pp_size': 4, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32' -}, { - 'tp_size': 2, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32' -}, { - 'tp_size': 2, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'zero_stage': 2, - 'precision': 'fp16', - 'initial_scale': 1 -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 2, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'zero_stage': 1, - 
'precision': 'fp16', - 'initial_scale': 1 -}]) +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "use_lazy_init": False, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 1, + "pp_size": 4, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + }, + {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) @clear_cache_before_run() def run_t5_test(test_config): - - sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') + sub_model_zoo = model_zoo.get_sub_registry("transformers_t5") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - # skip 4-stage pp test for t5_encoder - if test_config['pp_size'] > 2 and name == 'transformers_t5_encoder_model': + if test_config["pp_size"] > 2 and name == "transformers_t5_encoder_model": continue check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -160,29 +152,32 @@ def run_t5_test(test_config): torch.cuda.empty_cache() -@parameterize('test_config', [ - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32', - 'initial_scale': 1, - }, - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp16', - 'zero_stage': 1, - 'initial_scale': 1, - }, -]) +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, + ], +) def run_t5_3d_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') + sub_model_zoo = model_zoo.get_sub_registry("transformers_t5") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -193,13 +188,13 @@ def run_t5_3d_test(test_config): def check_t5(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_t5_test() def check_t5_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, 
host="localhost", port=port, backend="nccl") run_t5_3d_test() diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index 15db63bfd9da..1c934bd22340 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -20,54 +20,38 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) - org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ - build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) - - org_loss, org_output, sharded_loss, sharded_output = \ - run_forward_backward_with_hybrid_plugin( - org_model, - sharded_model, - sharded_optimizer, - data_gen_fn, - output_transform_fn, - criterion, - booster) + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group # unwrap model - vit_model = unwrap_model(org_model, 'ViTModel', 'vit') - shard_vit_model = unwrap_model(sharded_model, 'ViTModel', 'vit') + vit_model = unwrap_model(org_model, "ViTModel", "vit") + shard_vit_model = unwrap_model(sharded_model, "ViTModel", "vit") # check grad - row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection'] - col_layer_for_check = ['encoder.layer[0].attention.output.dense'] + row_layer_for_check = ["encoder.layer[0].attention.attention.query", "embeddings.patch_embeddings.projection"] + col_layer_for_check = ["encoder.layer[0].attention.output.dense"] # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. 
grads_to_check = {} if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 - row_layer_grads = get_grad_tensors_for_check(vit_model, - shard_vit_model, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0, - verbose=False) - col_layer_grads = get_grad_tensors_for_check(vit_model, - shard_vit_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) + row_layer_grads = get_grad_tensors_for_check( + vit_model, shard_vit_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + ) + col_layer_grads = get_grad_tensors_for_check( + vit_model, shard_vit_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) grads_to_check.update(col_layer_grads) grads_to_check.update(row_layer_grads) @@ -77,29 +61,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 - if org_model.__class__.__name__ == 'ViTModel': + if org_model.__class__.__name__ == "ViTModel": check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # check weights if stage_manager is None or stage_manager.is_first_stage(): - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 5e-3, 1e-3 else: atol, rtol = 5e-3, 5e-3 - check_weight(vit_model, - shard_vit_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) + check_weight( + vit_model, shard_vit_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) # check grads check_all_grad_tensors(grads_to_check) @@ -107,57 +86,54 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, torch.cuda.empty_cache() -#TODO: num_microbatch size = 2 inf loss -@parameterize('test_config', [{ - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp16', - 'initial_scale': 1, -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32', -}, { - 'tp_size': 4, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32' -}, { - 'tp_size': 2, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32' -}, { - 'tp_size': 2, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'zero_stage': 2, - 'precision': 'fp16', - 'initial_scale': 1 -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'zero_stage': 1, - 'precision': 'fp16', - 'initial_scale': 1 -}]) +# TODO: num_microbatch size = 2 inf loss +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + 
"precision": "fp32", + }, + {"tp_size": 4, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": False, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) def run_vit_test(test_config): - # TODO: fix bug when settign lazy_init for Conv2D Layers in ViT models - sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') + sub_model_zoo = model_zoo.get_sub_registry("transformers_vit") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -166,28 +142,31 @@ def run_vit_test(test_config): torch.cuda.empty_cache() -@parameterize('test_config', [ - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32', - 'initial_scale': 1, - }, - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 2, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32', - 'initial_scale': 1, - }, -]) +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + ], +) def run_vit_3d_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') + sub_model_zoo = model_zoo.get_sub_registry("transformers_vit") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -198,13 +177,13 @@ def run_vit_3d_test(test_config): def check_vit(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_vit_test() def check_vit_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_vit_3d_test() diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py index d0c04c98f80a..f839bd84ab69 100644 --- a/tests/test_shardformer/test_model/test_shard_whisper.py +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -5,13 +5,7 @@ from colossalai.logging import disable_existing_loggers from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter -from colossalai.testing import ( - assert_hf_output_close, - clear_cache_before_run, - parameterize, - rerun_if_address_is_in_use, - spawn, -) +from colossalai.testing import 
clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, @@ -26,24 +20,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): # check forward - org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ - build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) - - org_loss, org_output, sharded_loss, sharded_output = \ - run_forward_backward_with_hybrid_plugin( - org_model, - sharded_model, - sharded_optimizer, - data_gen_fn, - output_transform_fn, - criterion, - booster) + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group # unwarp the model - if org_model.__class__.__name__ == 'WhisperForConditionalGeneration': + if org_model.__class__.__name__ == "WhisperForConditionalGeneration": whisper = org_model.model sharded_whisper = sharded_model.unwrap().model else: @@ -51,41 +40,33 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, sharded_whisper = sharded_model.unwrap() # check grad - if org_model.__class__.__name__ == 'WhisperForAudioClassification': - col_layer_for_check = ['encoder.layers[0].self_attn.q_proj'] - row_layer_for_check = ['encoder.layers[0].self_attn.out_proj'] + if org_model.__class__.__name__ == "WhisperForAudioClassification": + col_layer_for_check = ["encoder.layers[0].self_attn.q_proj"] + row_layer_for_check = ["encoder.layers[0].self_attn.out_proj"] else: col_layer_for_check = [ - 'encoder.layers[0].self_attn.q_proj', - # 'decoder.layers[0].self_attn.q_proj' + "encoder.layers[0].self_attn.q_proj", + # 'decoder.layers[0].self_attn.q_proj' ] row_layer_for_check = [ - 'encoder.layers[0].self_attn.out_proj', - #'decoder.layers[0].self_attn.out_proj' + "encoder.layers[0].self_attn.out_proj", + #'decoder.layers[0].self_attn.out_proj' ] # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. 
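Each of these test modules is driven by the same multi-process harness: a top-level `test_*` function spawns one process per rank, every worker joins a NCCL group through `colossalai.launch`, and the parameterized `run_*_test` function then iterates over each configuration and each model in the relevant model-zoo sub-registry. A minimal sketch of that flow for the Whisper tests, assuming a representative world size of 4 (the actual value is whatever `spawn` is called with in the file):

```python
def check_whisper(rank, world_size, port):
    # each spawned worker silences duplicate loggers and joins the NCCL process group
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_whisper_test()  # @parameterize expands this into one run per test_config


@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_whisper():
    spawn(check_whisper, 4)  # 4 ranks cover the tp_size=4 and tp_size=2 / pp_size=2 configurations
```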
grads_to_check = {} - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 2e-4, 2e-4 else: atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): - row_layer_grads = get_grad_tensors_for_check(whisper, - sharded_whisper, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1) - col_layer_grads = get_grad_tensors_for_check(whisper, - sharded_whisper, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0) + row_layer_grads = get_grad_tensors_for_check( + whisper, sharded_whisper, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1 + ) + col_layer_grads = get_grad_tensors_for_check( + whisper, sharded_whisper, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0 + ) grads_to_check.update(col_layer_grads) grads_to_check.update(row_layer_grads) @@ -95,38 +76,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 2e-4, 2e-4 else: atol, rtol = 5e-3, 5e-3 - if org_model.__class__.__name__ == 'WhisperModel': + if org_model.__class__.__name__ == "WhisperModel": check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # check weights - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-3, 1e-3 else: atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): - check_weight(whisper, - sharded_whisper, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) - check_weight(whisper, - sharded_whisper, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0, - verbose=False) + check_weight( + whisper, sharded_whisper, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) + check_weight( + whisper, sharded_whisper, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + ) # check grads check_all_grad_tensors(grads_to_check) @@ -134,49 +105,49 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, torch.cuda.empty_cache() -#TODO fix WhisperForConditionalGeneration enable jit fused operato +# TODO fix WhisperForConditionalGeneration enable jit fused operato # TODO(jianghai) fix fp16 @parameterize( - 'test_config', + "test_config", [ { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 2, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'precision': 'fp32', - 'initial_scale': 1, + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp32", + "initial_scale": 1, }, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 4, - 'use_lazy_init': False, - 'precision': 'fp32', - 'initial_scale': 1, + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, }, { - 'tp_size': 4, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32', + "tp_size": 4, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp32", }, { - 'tp_size': 1, - 'pp_size': 4, - 'num_microbatches': 4, - 'use_lazy_init': False, - 'precision': 'fp32', + "tp_size": 1, + "pp_size": 4, + "num_microbatches": 4, + 
"use_lazy_init": False, + "precision": "fp32", }, - # whisper is not supported fp16 for now. - ]) + # whisper is not supported fp16 for now. + ], +) def run_whisper_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper') + sub_model_zoo = model_zoo.get_sub_registry("transformers_whisper") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - - if test_config['pp_size'] > 2 and name == 'transformers_whisper_for_audio_classification': + if test_config["pp_size"] > 2 and name == "transformers_whisper_for_audio_classification": continue check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -185,28 +156,31 @@ def run_whisper_test(test_config): torch.cuda.empty_cache() -@parameterize('test_config', [ - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32', - 'initial_scale': 1, - }, - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 2, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32', - 'initial_scale': 1, - }, -]) +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + ], +) def run_whisper_3d_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper') + sub_model_zoo = model_zoo.get_sub_registry("transformers_whisper") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -217,13 +191,13 @@ def run_whisper_3d_test(test_config): def check_whisper(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_whisper_test() def check_whisper_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_whisper_3d_test() diff --git a/tests/test_shardformer/test_shard_utils.py b/tests/test_shardformer/test_shard_utils.py index 220b8291c9c6..9739fad86d39 100644 --- a/tests/test_shardformer/test_shard_utils.py +++ b/tests/test_shardformer/test_shard_utils.py @@ -5,7 +5,6 @@ class Net(nn.Module): - def __init__(self) -> None: super().__init__() self.layers = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3)) diff --git a/tests/test_shardformer/test_with_torch_ddp.py b/tests/test_shardformer/test_with_torch_ddp.py index 2b6933246298..f642a9dcada4 100644 --- a/tests/test_shardformer/test_with_torch_ddp.py +++ b/tests/test_shardformer/test_with_torch_ddp.py @@ -14,10 +14,9 @@ from tests.kit.model_zoo import model_zoo -@parameterize('lazy_init', [True, False]) +@parameterize("lazy_init", [True, False]) def check_shardformer_with_ddp(lazy_init: bool): - - sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') + sub_model_zoo = 
model_zoo.get_sub_registry("transformers_gpt") # create shardformer # ranks: [0, 1, 2, 3] @@ -72,7 +71,7 @@ def check_shardformer_with_ddp(lazy_init: bool): def run_dist(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") check_shardformer_with_ddp() diff --git a/tests/test_tensor/test_comm_spec_apply.py b/tests/test_tensor/test_comm_spec_apply.py index 4a3199c1c53d..5e969b1aaf98 100644 --- a/tests/test_tensor/test_comm_spec_apply.py +++ b/tests/test_tensor/test_comm_spec_apply.py @@ -29,10 +29,9 @@ def check_all_gather(device_mesh, rank): sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1) - comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - sharding_spec, - gather_dim=1, - logical_process_axis=1) + comm_spec = CommSpec( + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=1, logical_process_axis=1 + ) sharded_tensor_to_comm = sharded_tensor_to_comm = comm_spec.covert_spec_to_action(sharded_tensor_to_comm) assert sharded_tensor_to_comm.equal(tensor_to_check) @@ -101,11 +100,9 @@ def check_all_to_all(device_mesh, rank): sharding_spec = ShardingSpec(device_mesh, torch.Size((4, 2)), dim_partition_dict=dim_partition_dict) # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1) - comm_spec = CommSpec(CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD, - sharding_spec, - gather_dim=0, - shard_dim=1, - logical_process_axis=0) + comm_spec = CommSpec( + CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD, sharding_spec, gather_dim=0, shard_dim=1, logical_process_axis=0 + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) assert tensor_to_comm.equal(tensor_to_check) @@ -181,7 +178,7 @@ def check_all_reduce_in_flatten_device_mesh(device_mesh, rank): def check_comm(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") physical_mesh_id = torch.arange(0, 4) assert rank == dist.get_rank() @@ -214,5 +211,5 @@ def test_comm_spec(): spawn(check_comm, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_comm_spec() diff --git a/tests/test_tensor/test_dtensor/test_comm_spec.py b/tests/test_tensor/test_dtensor/test_comm_spec.py index a1ea2946e6e7..6d1640b4f3dc 100644 --- a/tests/test_tensor/test_dtensor/test_comm_spec.py +++ b/tests/test_tensor/test_dtensor/test_comm_spec.py @@ -20,10 +20,9 @@ def check_all_gather(process_groups_dict, rank): tensor_to_check = torch.cat((torch.ones(2, 2), torch.zeros(2, 2)), 1).cuda() # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1) - comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - process_groups_dict, - gather_dim=1, - logical_process_axis=1) + comm_spec = CommSpec( + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, process_groups_dict, gather_dim=1, logical_process_axis=1 + ) sharded_tensor_to_comm = sharded_tensor_to_comm = comm_spec.covert_spec_to_action(sharded_tensor_to_comm) assert sharded_tensor_to_comm.equal(tensor_to_check) @@ -38,10 +37,9 @@ def check_shard(process_groups_dict, rank): tensor_to_shard = 
torch.cat((sharded_tensor_to_comm_0, sharded_tensor_to_comm_1), 1) # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1) - comm_spec = CommSpec(CollectiveCommPattern.SPLIT_FWD_GATHER_BWD, - process_groups_dict, - shard_dim=1, - logical_process_axis=1) + comm_spec = CommSpec( + CollectiveCommPattern.SPLIT_FWD_GATHER_BWD, process_groups_dict, shard_dim=1, logical_process_axis=1 + ) tensor_to_shard = comm_spec.covert_spec_to_action(tensor_to_shard) if rank in (0, 2): @@ -79,11 +77,13 @@ def check_all_to_all(process_groups_dict, rank): tensor_to_check = torch.tensor([[1], [1], [3], [3]], dtype=tensor_to_comm.dtype).cuda() # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1) - comm_spec = CommSpec(CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD, - process_groups_dict, - gather_dim=0, - shard_dim=1, - logical_process_axis=0) + comm_spec = CommSpec( + CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD, + process_groups_dict, + gather_dim=0, + shard_dim=1, + logical_process_axis=0, + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) assert tensor_to_comm.equal(tensor_to_check) @@ -124,7 +124,7 @@ def check_all_reduce_bwd(process_groups_dict, rank): def check_comm(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") physical_mesh_id = torch.arange(0, 4) assert rank == dist.get_rank() @@ -157,5 +157,5 @@ def test_comm_spec(): spawn(check_comm, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_comm_spec() diff --git a/tests/test_tensor/test_dtensor/test_dtensor.py b/tests/test_tensor/test_dtensor/test_dtensor.py index 5a1aef79f332..33ae59d01550 100644 --- a/tests/test_tensor/test_dtensor/test_dtensor.py +++ b/tests/test_tensor/test_dtensor/test_dtensor.py @@ -8,7 +8,6 @@ class TestModel(torch.nn.Module): - def __init__(self, in_features, out_features): super().__init__() self.linear_1 = torch.nn.Linear(in_features, out_features) @@ -22,9 +21,9 @@ def forward(self, x): def check_dtensor(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - test_model = TestModel(8, 8).to('cuda') - original_tensor = torch.rand(4, 8).to('cuda') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + test_model = TestModel(8, 8).to("cuda") + original_tensor = torch.rand(4, 8).to("cuda") compare_output = test_model(original_tensor) device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True) @@ -39,7 +38,7 @@ def check_dtensor(rank, world_size, port): elif rank in (2, 3): assert d_tensor.equal(original_tensor.narrow(0, 2, 2)) else: - raise ValueError(f'rank {rank} is not in the device mesh') + raise ValueError(f"rank {rank} is not in the device mesh") assert to_global(d_tensor).equal(original_tensor) output = test_model(d_tensor) @@ -48,7 +47,7 @@ def check_dtensor(rank, world_size, port): elif rank in (2, 3): assert output.equal(compare_output.narrow(0, 2, 2)) else: - raise ValueError(f'rank {rank} is not in the device mesh') + raise ValueError(f"rank {rank} is not in the device mesh") new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]}) d_tensor = redistribute(d_tensor, device_mesh, new_sharding_spec) @@ -62,7 +61,7 @@ def check_dtensor(rank, world_size, 
port): elif rank == 3: assert d_tensor.equal(original_tensor.narrow(0, 3, 1)) else: - raise ValueError(f'rank {rank} is not in the device mesh') + raise ValueError(f"rank {rank} is not in the device mesh") dtensor_from_local = distribute_tensor(original_tensor, device_mesh, new_sharding_spec) @@ -75,7 +74,7 @@ def check_dtensor(rank, world_size, port): elif rank == 3: assert dtensor_from_local.equal(original_tensor.narrow(0, 3, 1)) else: - raise ValueError(f'rank {rank} is not in the device mesh') + raise ValueError(f"rank {rank} is not in the device mesh") @rerun_if_address_is_in_use() @@ -84,5 +83,5 @@ def test_dtensor(): spawn(check_dtensor, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_dtensor() diff --git a/tests/test_tensor/test_dtensor/test_dtensor_sharding_spec.py b/tests/test_tensor/test_dtensor/test_dtensor_sharding_spec.py index 7fd1c3d90fc4..654a4438479a 100644 --- a/tests/test_tensor/test_dtensor/test_dtensor_sharding_spec.py +++ b/tests/test_tensor/test_dtensor/test_dtensor_sharding_spec.py @@ -26,9 +26,10 @@ def test_dtensor_sharding_spec(): assert dim_spec_list_0[2].dim_diff(dim_spec_list_1[2]) == 0 assert dim_spec_list_0[3].dim_diff(dim_spec_list_1[3]) == 0 - assert sharding_spec_0.spec_diff(sharding_spec_1) == \ - reduce(operator.add, [dim_spec_list_0[i].dim_diff(dim_spec_list_1[i]) for i in range(dims)], 0) + assert sharding_spec_0.spec_diff(sharding_spec_1) == reduce( + operator.add, [dim_spec_list_0[i].dim_diff(dim_spec_list_1[i]) for i in range(dims)], 0 + ) -if __name__ == '__main__': +if __name__ == "__main__": test_dtensor_sharding_spec() diff --git a/tests/test_tensor/test_dtensor/test_layout_converter.py b/tests/test_tensor/test_dtensor/test_layout_converter.py index 5388fd901e09..4e65401bf7b4 100644 --- a/tests/test_tensor/test_dtensor/test_layout_converter.py +++ b/tests/test_tensor/test_dtensor/test_layout_converter.py @@ -20,7 +20,7 @@ def check_one_step_transform(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # [[0, 1], # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) @@ -34,10 +34,10 @@ def check_one_step_transform(rank, world_size, port): rst_dict = layout_converter.all_gather_transform_layouts(layout) - assert '[R, S1, R]' in [ + assert "[R, S1, R]" in [ str(all_gather_layout.sharding_spec.sharding_sequence) for all_gather_layout in rst_dict.keys() ] - assert '[S0, R, R]' in [ + assert "[S0, R, R]" in [ str(all_gather_layout.sharding_spec.sharding_sequence) for all_gather_layout in rst_dict.keys() ] @@ -50,13 +50,13 @@ def check_one_step_transform(rank, world_size, port): rst_dict_all2all = layout_converter.all_to_all_transform_layout(layout_all2all) - assert '[S01, R, R]' in [ + assert "[S01, R, R]" in [ str(all2all_layout.sharding_spec.sharding_sequence) for all2all_layout in rst_dict_all2all.keys() ] - assert '[R, S1, S0]' in [ + assert "[R, S1, S0]" in [ str(all2all_layout.sharding_spec.sharding_sequence) for all2all_layout in rst_dict_all2all.keys() ] - assert '[S0, R, S1]' in [ + assert "[S0, R, S1]" in [ str(all2all_layout.sharding_spec.sharding_sequence) for all2all_layout in rst_dict_all2all.keys() ] @@ -69,20 +69,20 @@ def check_one_step_transform(rank, world_size, port): rst_dict_shard = layout_converter.shard_transform_layout(shard_layout) - assert '[S01, R, R]' in [ + assert 
"[S01, R, R]" in [ str(shard_layout.sharding_spec.sharding_sequence) for shard_layout in rst_dict_shard.keys() ] - assert '[S0, S1, R]' in [ + assert "[S0, S1, R]" in [ str(shard_layout.sharding_spec.sharding_sequence) for shard_layout in rst_dict_shard.keys() ] - assert '[S0, R, S1]' in [ + assert "[S0, R, S1]" in [ str(shard_layout.sharding_spec.sharding_sequence) for shard_layout in rst_dict_shard.keys() ] def check_layout_converting(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") dim_partition_source = {1: [0, 1]} dim_partition_target = {0: [0, 1]} device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) @@ -102,8 +102,8 @@ def check_layout_converting(rank, world_size, port): transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout) # check transform path - transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path]) - assert transform_path_str == '[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]' + transform_path_str = "->".join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path]) + assert transform_path_str == "[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]" # check comm action sequence # all-gather(S01) -> S0 @@ -123,18 +123,18 @@ def check_layout_converting(rank, world_size, port): assert comm_action_sequence[2].logical_process_axis == 1 # checkout chached_spec_pairs_transform_path - assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][0] == transform_path - assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][1] == comm_action_sequence + assert layout_converter.cached_solution[("[R, S01, R]", "[S01, R, R]")][0] == transform_path + assert layout_converter.cached_solution[("[R, S01, R]", "[S01, R, R]")][1] == comm_action_sequence comm_cost = layout_converter.get_total_comm_cost(source_layout, target_layout) - assert comm_cost['forward'] == comm_cost['backward'] - assert math.floor(comm_cost['total']) == math.floor(comm_cost['forward'] + comm_cost['backward']) + assert comm_cost["forward"] == comm_cost["backward"] + assert math.floor(comm_cost["total"]) == math.floor(comm_cost["forward"] + comm_cost["backward"]) def check_layout_converting_apply(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") dim_partition_source = {1: [0, 1]} dim_partition_target = {0: [0, 1]} @@ -173,5 +173,5 @@ def test_layout_converter(): spawn(check_layout_converting_apply, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_layout_converter() diff --git a/tests/test_tensor/test_mix_gather.py b/tests/test_tensor/test_mix_gather.py index bd71bffccc70..7d6f8979dd0b 100644 --- a/tests/test_tensor/test_mix_gather.py +++ b/tests/test_tensor/test_mix_gather.py @@ -17,12 +17,13 @@ def check_mix_gather_S0S1(device_mesh, rank): f_target_pair = (f, [0]) b_target_pair = (b, [1]) gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) - tensor_slice = [4, 2] # (4, 2) + tensor_slice = [4, 2] # (4, 2) rank_slice = 4 f_start = (rank // rank_slice) * tensor_slice[0] b_start = (rank % rank_slice) * 
tensor_slice[1] - tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0], - b_start:b_start + tensor_slice[1]].contiguous().cuda() + tensor_to_comm = ( + tensor_to_check[f_start : f_start + tensor_slice[0], b_start : b_start + tensor_slice[1]].contiguous().cuda() + ) dim_partition_dict = {0: [0], 1: [1]} @@ -31,12 +32,14 @@ def check_mix_gather_S0S1(device_mesh, rank): # device_mesh_shape: (2, 4) source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) - comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, - sharding_spec=source_spec, - gather_dim=gather_dim, - logical_process_axis=logical_process_axes, - forward_only=True, - mix_gather=True) + comm_spec = CommSpec( + CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=True, + mix_gather=True, + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) assert tensor_to_comm.equal(tensor_to_check) @@ -48,12 +51,13 @@ def check_two_all_gather_S0S1(device_mesh, rank): dim_partition_dict = {0: [0], 1: [1]} - tensor_slice = [tensor_width // 2, tensor_width // 4] # (4, 2) + tensor_slice = [tensor_width // 2, tensor_width // 4] # (4, 2) rank_slice = 4 f_start = (rank // rank_slice) * tensor_slice[0] b_start = (rank % rank_slice) * tensor_slice[1] - tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0], - b_start:b_start + tensor_slice[1]].contiguous().cuda() + tensor_to_comm = ( + tensor_to_check[f_start : f_start + tensor_slice[0], b_start : b_start + tensor_slice[1]].contiguous().cuda() + ) # DistSpec: # shard_sequence: S0,S1 @@ -61,10 +65,9 @@ def check_two_all_gather_S0S1(device_mesh, rank): sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) # CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0) - comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - sharding_spec, - gather_dim=0, - logical_process_axis=0) + comm_spec = CommSpec( + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=0, logical_process_axis=0 + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) @@ -75,10 +78,9 @@ def check_two_all_gather_S0S1(device_mesh, rank): sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1) - comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - sharding_spec, - gather_dim=1, - logical_process_axis=1) + comm_spec = CommSpec( + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=1, logical_process_axis=1 + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) @@ -95,8 +97,9 @@ def check_mix_gather_S1S0(device_mesh, rank): rank_slice = 4 f_start = (rank % rank_slice) * tensor_slice[0] b_start = (rank // rank_slice) * tensor_slice[1] - tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0], - b_start:b_start + tensor_slice[1]].contiguous().cuda() + tensor_to_comm = ( + tensor_to_check[f_start : f_start + tensor_slice[0], b_start : b_start + tensor_slice[1]].contiguous().cuda() + ) dim_partition_dict = {0: [1], 1: [0]} @@ -105,12 +108,14 @@ def check_mix_gather_S1S0(device_mesh, rank): # device_mesh_shape: (2, 4) source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) - comm_spec = 
CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, - sharding_spec=source_spec, - gather_dim=gather_dim, - logical_process_axis=logical_process_axes, - forward_only=True, - mix_gather=True) + comm_spec = CommSpec( + CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=True, + mix_gather=True, + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) assert tensor_to_comm.equal(tensor_to_check) @@ -120,12 +125,13 @@ def check_two_all_gather_S1S0(device_mesh, rank): tensor_width = 8 tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda() - tensor_slice = [tensor_width // 4, tensor_width // 2] # (4, 2) + tensor_slice = [tensor_width // 4, tensor_width // 2] # (4, 2) rank_slice = 4 f_start = (rank % rank_slice) * tensor_slice[0] b_start = (rank // rank_slice) * tensor_slice[1] - tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0], - b_start:b_start + tensor_slice[1]].contiguous().cuda() + tensor_to_comm = ( + tensor_to_check[f_start : f_start + tensor_slice[0], b_start : b_start + tensor_slice[1]].contiguous().cuda() + ) dim_partition_dict = {0: [1], 1: [0]} @@ -135,10 +141,9 @@ def check_two_all_gather_S1S0(device_mesh, rank): sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) # CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:1) - comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - sharding_spec, - gather_dim=0, - logical_process_axis=1) + comm_spec = CommSpec( + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=0, logical_process_axis=1 + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) @@ -149,10 +154,9 @@ def check_two_all_gather_S1S0(device_mesh, rank): sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:0) - comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - sharding_spec, - gather_dim=1, - logical_process_axis=0) + comm_spec = CommSpec( + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=1, logical_process_axis=0 + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) @@ -165,7 +169,7 @@ def check_mix_gather_S01R(device_mesh, rank): f_target_pair = (f, [0, 1]) b_target_pair = (b, []) gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) - tensor_to_comm = tensor_to_check[rank:rank + 1, :].contiguous().cuda() + tensor_to_comm = tensor_to_check[rank : rank + 1, :].contiguous().cuda() dim_partition_dict = {0: [0, 1]} # DistSpec: @@ -173,12 +177,14 @@ def check_mix_gather_S01R(device_mesh, rank): # device_mesh_shape: (2, 4) source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) - comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, - sharding_spec=source_spec, - gather_dim=gather_dim, - logical_process_axis=logical_process_axes, - forward_only=True, - mix_gather=True) + comm_spec = CommSpec( + CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=True, + mix_gather=True, + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) assert tensor_to_comm.equal(tensor_to_check) @@ -189,7 +195,7 @@ def 
check_two_all_gather_S01R(device_mesh, rank): tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda() rank_stride = tensor_width // 8 - tensor_to_comm = tensor_to_check[rank:rank + rank_stride, :].contiguous().cuda() + tensor_to_comm = tensor_to_check[rank : rank + rank_stride, :].contiguous().cuda() dim_partition_dict = {0: [0, 1]} @@ -199,10 +205,9 @@ def check_two_all_gather_S01R(device_mesh, rank): sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) # CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0) - comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - sharding_spec, - gather_dim=0, - logical_process_axis=1) + comm_spec = CommSpec( + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=0, logical_process_axis=1 + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) @@ -214,10 +219,9 @@ def check_two_all_gather_S01R(device_mesh, rank): sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) # CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:1) - comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - sharding_spec, - gather_dim=0, - logical_process_axis=0) + comm_spec = CommSpec( + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=0, logical_process_axis=0 + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) @@ -231,7 +235,7 @@ def check_mix_gather_RS01(device_mesh, rank): f_target_pair = (f, []) b_target_pair = (b, [0, 1]) gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) - tensor_to_comm = tensor_to_check[:, rank:rank + 1].contiguous().cuda() + tensor_to_comm = tensor_to_check[:, rank : rank + 1].contiguous().cuda() dim_partition_dict = {1: [0, 1]} # DistSpec: @@ -239,12 +243,14 @@ def check_mix_gather_RS01(device_mesh, rank): # device_mesh_shape: (2, 4) source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) - comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, - sharding_spec=source_spec, - gather_dim=gather_dim, - logical_process_axis=logical_process_axes, - forward_only=True, - mix_gather=True) + comm_spec = CommSpec( + CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=True, + mix_gather=True, + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) assert tensor_to_comm.equal(tensor_to_check) @@ -255,7 +261,7 @@ def check_two_all_gather_RS01(device_mesh, rank): tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda() rank_stride = tensor_width // 8 - tensor_to_comm = tensor_to_check[:, rank:rank + rank_stride].contiguous().cuda() + tensor_to_comm = tensor_to_check[:, rank : rank + rank_stride].contiguous().cuda() dim_partition_dict = {1: [0, 1]} @@ -265,10 +271,9 @@ def check_two_all_gather_RS01(device_mesh, rank): sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:0) - comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - sharding_spec, - gather_dim=1, - logical_process_axis=1) + comm_spec = CommSpec( + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=1, 
logical_process_axis=1 + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) @@ -280,10 +285,9 @@ def check_two_all_gather_RS01(device_mesh, rank): sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1) - comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - sharding_spec, - gather_dim=1, - logical_process_axis=0) + comm_spec = CommSpec( + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=1, logical_process_axis=0 + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) @@ -292,7 +296,7 @@ def check_two_all_gather_RS01(device_mesh, rank): def check_comm(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") physical_mesh_id = torch.arange(0, 8) assert rank == dist.get_rank() @@ -326,5 +330,5 @@ def test_mix_gather(): spawn(check_comm, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_mix_gather() diff --git a/tests/test_tensor/test_shape_consistency.py b/tests/test_tensor/test_shape_consistency.py index 859eef051256..c51797912e6f 100644 --- a/tests/test_tensor/test_shape_consistency.py +++ b/tests/test_tensor/test_shape_consistency.py @@ -2,7 +2,7 @@ from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager -from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec +from colossalai.tensor.sharding_spec import ShardingSpec physical_mesh_id = torch.arange(0, 16) mesh_shape = (4, 4) @@ -16,7 +16,6 @@ def test_one_step_transform(): - dim_partition_dict = {0: [0], 1: [1]} # DistSpec: # shard_sequence: S0,S1,R @@ -28,16 +27,14 @@ def test_one_step_transform(): # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0), 0), DistSpec: # shard_sequence: S0,R,R # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1), 0)} - rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, { - "forward": 0, - "backward": 0, - "total": 0 - }) + rst_dict = shape_consistency_manager.get_all_all_gather_spec( + sharding_spec, {"forward": 0, "backward": 0, "total": 0} + ) - assert '[R, S1, R]' in [ + assert "[R, S1, R]" in [ str(all_gather_sharding_spec.sharding_sequence) for all_gather_sharding_spec in rst_dict.keys() ] - assert '[S0, R, R]' in [ + assert "[S0, R, R]" in [ str(all_gather_sharding_spec.sharding_sequence) for all_gather_sharding_spec in rst_dict.keys() ] @@ -53,19 +50,17 @@ def test_one_step_transform(): # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:0, shard_dim:2, logical_process_axis: 0), 0), DistSpec: # shard_sequence: S0,R,S1 # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:2, logical_process_axis: 1), 0)} - rst_dict_all2all = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec_all2all, { - "forward": 0, - "backward": 0, - "total": 0 - }) + rst_dict_all2all = shape_consistency_manager.get_all_all_to_all_spec( + sharding_spec_all2all, {"forward": 0, "backward": 0, "total": 0} + ) - assert '[S01, R, R]' in [ + assert "[S01, R, R]" in [ str(all2all_sharding_spec.sharding_sequence) for all2all_sharding_spec in 
rst_dict_all2all.keys() ] - assert '[R, S1, S0]' in [ + assert "[R, S1, S0]" in [ str(all2all_sharding_spec.sharding_sequence) for all2all_sharding_spec in rst_dict_all2all.keys() ] - assert '[S0, R, S1]' in [ + assert "[S0, R, S1]" in [ str(all2all_sharding_spec.sharding_sequence) for all2all_sharding_spec in rst_dict_all2all.keys() ] @@ -81,19 +76,17 @@ def test_one_step_transform(): # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1), 0), DistSpec: # shard_sequence: S0,R,S1 # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:2, logical_process_axis:1), 0)} - rst_dict_shard = shape_consistency_manager.get_all_shard_spec(sharding_spec_shard, { - "forward": 0, - "backward": 0, - "total": 0 - }) + rst_dict_shard = shape_consistency_manager.get_all_shard_spec( + sharding_spec_shard, {"forward": 0, "backward": 0, "total": 0} + ) - assert '[S01, R, R]' in [ + assert "[S01, R, R]" in [ str(shard_sharding_spec.sharding_sequence) for shard_sharding_spec in rst_dict_shard.keys() ] - assert '[S0, S1, R]' in [ + assert "[S0, S1, R]" in [ str(shard_sharding_spec.sharding_sequence) for shard_sharding_spec in rst_dict_shard.keys() ] - assert '[S0, R, S1]' in [ + assert "[S0, R, S1]" in [ str(shard_sharding_spec.sharding_sequence) for shard_sharding_spec in rst_dict_shard.keys() ] @@ -113,10 +106,11 @@ def test_shape_consistency(): sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target) transform_path, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency( - sharding_spec_source, sharding_spec_target) + sharding_spec_source, sharding_spec_target + ) - transform_path_str = '->'.join([str(sharding_spec.sharding_sequence) for sharding_spec in transform_path]) - assert transform_path_str == '[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]' + transform_path_str = "->".join([str(sharding_spec.sharding_sequence) for sharding_spec in transform_path]) + assert transform_path_str == "[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]" # all-gather(S01) -> S0 assert comm_action_sequence[0].comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD @@ -134,12 +128,15 @@ def test_shape_consistency(): assert comm_action_sequence[2].shard_dim == 0 assert comm_action_sequence[2].logical_process_axis == 1 - assert shape_consistency_manager.cached_spec_pairs_transform_path[('[R, S01, R]', - '[S01, R, R]')][0] == transform_path - assert shape_consistency_manager.cached_spec_pairs_transform_path[('[R, S01, R]', - '[S01, R, R]')][1] == comm_action_sequence + assert ( + shape_consistency_manager.cached_spec_pairs_transform_path[("[R, S01, R]", "[S01, R, R]")][0] == transform_path + ) + assert ( + shape_consistency_manager.cached_spec_pairs_transform_path[("[R, S01, R]", "[S01, R, R]")][1] + == comm_action_sequence + ) -if __name__ == '__main__': +if __name__ == "__main__": test_one_step_transform() test_shape_consistency() diff --git a/tests/test_tensor/test_shape_consistency_apply.py b/tests/test_tensor/test_shape_consistency_apply.py index b57952df401f..b2bc84edd87f 100644 --- a/tests/test_tensor/test_shape_consistency_apply.py +++ b/tests/test_tensor/test_shape_consistency_apply.py @@ -4,14 +4,14 @@ from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager +from colossalai.tensor.shape_consistency import 
ShapeConsistencyManager from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.testing import rerun_if_address_is_in_use, spawn def check_apply(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -72,5 +72,5 @@ def test_apply(): spawn(check_apply, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_apply() diff --git a/tests/test_tensor/test_sharding_spec.py b/tests/test_tensor/test_sharding_spec.py index 5007c4141849..7730683bf525 100644 --- a/tests/test_tensor/test_sharding_spec.py +++ b/tests/test_tensor/test_sharding_spec.py @@ -1,7 +1,7 @@ import torch from colossalai.device.device_mesh import DeviceMesh -from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec +from colossalai.tensor.sharding_spec import ShardingSpec def test_sharding_spec(): @@ -21,5 +21,5 @@ def test_sharding_spec(): assert str(sharding_spec.sharding_sequence) == "[S01, R, R]" -if __name__ == '__main__': +if __name__ == "__main__": test_sharding_spec() diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py index f775710c40c2..a5c465ba0b07 100644 --- a/tests/test_utils/test_flash_attention.py +++ b/tests/test_utils/test_flash_attention.py @@ -23,30 +23,30 @@ def attention_ref(q, k, v, attn_mask=None, causal=False): seqlen_q, seqlen_k = q.shape[1], k.shape[1] d = q.shape[-1] scale = 1.0 / math.sqrt(d) - scores = torch.einsum('bthd,bshd->bhts', q * scale, k) + scores = torch.einsum("bthd,bshd->bhts", q * scale, k) if attn_mask is not None: - scores.masked_fill_(rearrange(~attn_mask, 'b s -> b 1 1 s'), float('-inf')) + scores.masked_fill_(rearrange(~attn_mask, "b s -> b 1 1 s"), float("-inf")) if causal: causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1) - scores.masked_fill_(causal_mask, float('-inf')) + scores.masked_fill_(causal_mask, float("-inf")) attention = torch.softmax(scores, dim=-1) - output = torch.einsum('bhts,bshd->bthd', attention, v) + output = torch.einsum("bhts,bshd->bthd", attention, v) output = rearrange(output, "b s h d -> b s (h d)") # Modify the data at the positions of the mask to 0 if attn_mask is not None: - output.masked_fill_(rearrange(~attn_mask, 'b s -> b s 1'), 0.0) + output.masked_fill_(rearrange(~attn_mask, "b s -> b s 1"), 0.0) return output.to(dtype=dtype_og) @pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available") @clear_cache_before_run() -@parameterize('proj_shape', [(6, 8, 4, 16)]) -@parameterize('dtype', DTYPE) -@parameterize('dropout', [0.0]) +@parameterize("proj_shape", [(6, 8, 4, 16)]) +@parameterize("dtype", DTYPE) +@parameterize("dropout", [0.0]) def test_attention_gpt(proj_shape, dtype, dropout): (B, S, H, D_HEAD) = proj_shape D = H * D_HEAD @@ -78,9 +78,9 @@ def test_attention_gpt(proj_shape, dtype, dropout): @pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available") @clear_cache_before_run() -@parameterize('proj_shape', [(6, 8, 4, 16)]) -@parameterize('dtype', DTYPE) -@parameterize('dropout', [0.0]) +@parameterize("proj_shape", [(6, 8, 4, 16)]) +@parameterize("dtype", DTYPE) +@parameterize("dropout", [0.0]) def test_attention_bert(proj_shape, dtype, dropout): (B, S, H, D_HEAD) = proj_shape D = H * 
D_HEAD @@ -111,9 +111,9 @@ def test_attention_bert(proj_shape, dtype, dropout): @pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available") @clear_cache_before_run() -@parameterize('proj_shape', [(6, 8, 4, 16)]) -@parameterize('dtype', DTYPE) -@parameterize('dropout', [0.0]) +@parameterize("proj_shape", [(6, 8, 4, 16)]) +@parameterize("dtype", DTYPE) +@parameterize("dropout", [0.0]) def test_attention_no_mask(proj_shape, dtype, dropout): (B, S, H, D_HEAD) = proj_shape D = H * D_HEAD @@ -141,9 +141,9 @@ def test_attention_no_mask(proj_shape, dtype, dropout): @pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available") @clear_cache_before_run() -@parameterize('proj_shape', [(6, 24, 8, 4, 16)]) -@parameterize('dtype', DTYPE) -@parameterize('dropout', [0.0]) +@parameterize("proj_shape", [(6, 24, 8, 4, 16)]) +@parameterize("dtype", DTYPE) +@parameterize("dropout", [0.0]) def test_cross_attention(proj_shape, dtype, dropout): (B, S, T, H, D_HEAD) = proj_shape D = H * D_HEAD diff --git a/tests/test_zero/test_gemini/test_chunk_mgrv2.py b/tests/test_zero/test_gemini/test_chunk_mgrv2.py index f05ccfdbd41b..879eeccde3b4 100644 --- a/tests/test_zero/test_gemini/test_chunk_mgrv2.py +++ b/tests/test_zero/test_gemini/test_chunk_mgrv2.py @@ -12,54 +12,53 @@ CPU_MEM = {True: {True: 0, False: 0}, False: {True: 512, False: 0}} -@parameterize('keep_gathered', [True, False]) -@parameterize('pin_memory', [True, False]) +@parameterize("keep_gathered", [True, False]) +@parameterize("pin_memory", [True, False]) def exam_chunk_memory(keep_gathered, pin_memory): - params = [ColoTensor(torch.rand(8, 8)) for _ in range(3)] config = {2: dict(chunk_size=128, keep_gathered=keep_gathered)} chunk_manager = ChunkManager(config) - assert chunk_manager.total_mem['cpu'] == 0 - assert chunk_manager.total_mem['cuda'] == 0 + assert chunk_manager.total_mem["cpu"] == 0 + assert chunk_manager.total_mem["cuda"] == 0 process_group = _get_default_group() for p in params: - chunk_manager.register_tensor(p, 'param', 2, process_group, pin_memory=pin_memory) + chunk_manager.register_tensor(p, "param", 2, process_group, pin_memory=pin_memory) chunk_manager.close_all_groups() - assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory] - assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered] + assert chunk_manager.total_mem["cpu"] == CPU_MEM[keep_gathered][pin_memory] + assert chunk_manager.total_mem["cuda"] == CUDA_MEM_0[keep_gathered] chunks = chunk_manager.get_chunks(params) for chunk in chunks: chunk_manager.access_chunk(chunk) - assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory] - assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[True] + assert chunk_manager.total_mem["cpu"] == CPU_MEM[keep_gathered][pin_memory] + assert chunk_manager.total_mem["cuda"] == CUDA_MEM_0[True] for chunk in chunks: chunk_manager.release_chunk(chunk) - assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory] - assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered] + assert chunk_manager.total_mem["cpu"] == CPU_MEM[keep_gathered][pin_memory] + assert chunk_manager.total_mem["cuda"] == CUDA_MEM_0[keep_gathered] for chunk in chunks: - chunk_manager.move_chunk(chunk, torch.device('cpu')) - assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][True] - assert chunk_manager.total_mem['cuda'] == CUDA_MEM_1[keep_gathered] + chunk_manager.move_chunk(chunk, torch.device("cpu")) + assert 
chunk_manager.total_mem["cpu"] == CPU_MEM[keep_gathered][True] + assert chunk_manager.total_mem["cuda"] == CUDA_MEM_1[keep_gathered] def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_chunk_memory() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [2]) +@pytest.mark.parametrize("world_size", [2]) @rerun_if_address_is_in_use() def test_chunk_manager(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_chunk_manager(2) diff --git a/tests/test_zero/test_gemini/test_chunkv2.py b/tests/test_zero/test_gemini/test_chunkv2.py index cc598ee60361..a31c888e966d 100644 --- a/tests/test_zero/test_gemini/test_chunkv2.py +++ b/tests/test_zero/test_gemini/test_chunkv2.py @@ -31,26 +31,28 @@ def check_equal(param, param_cp): return torch.equal(temp, param_cp.data) -@parameterize('init_device', [None, torch.device('cpu')]) -@parameterize('keep_gathered', [True, False]) -@parameterize('pin_memory', [True, False]) +@parameterize("init_device", [None, torch.device("cpu")]) +@parameterize("keep_gathered", [True, False]) +@parameterize("pin_memory", [True, False]) def exam_chunk_basic(init_device, keep_gathered, pin_memory): world_size = torch.distributed.get_world_size() pg = _get_default_group() - my_chunk = Chunk(chunk_size=1024, - process_group=pg, - dtype=torch.float32, - init_device=init_device, - cpu_shard_init=True, - keep_gathered=keep_gathered, - pin_memory=pin_memory) + my_chunk = Chunk( + chunk_size=1024, + process_group=pg, + dtype=torch.float32, + init_device=init_device, + cpu_shard_init=True, + keep_gathered=keep_gathered, + pin_memory=pin_memory, + ) param_list = [] param_cp_list = [] - add_param(param_list, param_cp_list, 8, 8, 8, device='cuda') + add_param(param_list, param_cp_list, 8, 8, 8, device="cuda") add_param(param_list, param_cp_list, 4, 4) - add_param(param_list, param_cp_list, 4, 8, 2, device='cuda') + add_param(param_list, param_cp_list, 4, 8, 2, device="cuda") add_param(param_list, param_cp_list, 1, 1, 5) for param in param_list: @@ -62,12 +64,12 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory): if keep_gathered is False: assert my_chunk.cpu_shard.size(0) == 1024 // world_size - assert my_chunk.device_type == 'cpu' + assert my_chunk.device_type == "cpu" assert my_chunk.can_move my_chunk.shard_move(get_current_device()) else: assert my_chunk.cuda_global_chunk.size(0) == 1024 - assert my_chunk.device_type == 'cuda' + assert my_chunk.device_type == "cuda" assert not my_chunk.can_move assert dist_sum(my_chunk.valid_end) == my_chunk.utilized_size @@ -75,7 +77,7 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory): assert not flag, "has_inf_or_nan is {}".format(flag) my_chunk.access_chunk() - assert my_chunk.device_type == 'cuda' + assert my_chunk.device_type == "cuda" for param, param_cp in zip(param_list, param_cp_list): check_equal(param, param_cp) @@ -97,25 +99,25 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory): if keep_gathered is False: assert my_chunk.cuda_shard.size(0) == 1024 // world_size - assert my_chunk.device_type == 'cuda' + assert my_chunk.device_type == "cuda" assert my_chunk.can_move else: assert my_chunk.cuda_global_chunk.size(0) == 1024 - assert my_chunk.device_type == 'cuda' + assert my_chunk.device_type == "cuda" assert not my_chunk.can_move def run_dist(rank, 
world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_chunk_basic() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 2, 4]) +@pytest.mark.parametrize("world_size", [1, 2, 4]) @rerun_if_address_is_in_use() def test_chunk_function(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_chunk_function(4) diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py index fabdd6072c31..94e70040019c 100644 --- a/tests/test_zero/test_gemini/test_fwd_bwd.py +++ b/tests/test_zero/test_gemini/test_fwd_bwd.py @@ -16,21 +16,10 @@ from tests.components_to_test.registry import non_distributed_component_funcs PLACEMENT_CONFIGS = [ - { - 'placement_policy': 'static', - 'shard_param_frac': 0.0 - }, # zero2 - { - 'placement_policy': 'static', - 'shard_param_frac': 1.0 - }, # zero3 - { - 'placement_policy': 'static', - 'shard_param_frac': 0.5 - }, # zero3-half - { - 'placement_policy': 'auto' - } + {"placement_policy": "static", "shard_param_frac": 0.0}, # zero2 + {"placement_policy": "static", "shard_param_frac": 1.0}, # zero3 + {"placement_policy": "static", "shard_param_frac": 0.5}, # zero3-half + {"placement_policy": "auto"}, ] @@ -41,14 +30,14 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module): for chunk in chunk_list: chunk_manager.access_chunk(chunk) - for (p0, p1) in zip(model.parameters(), torch_model.parameters()): + for p0, p1 in zip(model.parameters(), torch_model.parameters()): assert_close(p0, p1.grad, rtol=1e-3, atol=5e-5) -@parameterize('placement_config', PLACEMENT_CONFIGS) -@parameterize('keep_gather', [False, True]) -@parameterize('model_name', ['gpt2', 'bert', 'albert']) -@parameterize('use_grad_checkpoint', [False, True]) +@parameterize("placement_config", PLACEMENT_CONFIGS) +@parameterize("keep_gather", [False, True]) +@parameterize("model_name", ["gpt2", "bert", "albert"]) +@parameterize("use_grad_checkpoint", [False, True]) def exam_gpt_fwd_bwd( placement_config, keep_gather, @@ -69,14 +58,14 @@ def exam_gpt_fwd_bwd( world_size = torch.distributed.get_world_size() config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = keep_gather + config_dict[world_size]["chunk_size"] = 5000 + config_dict[world_size]["keep_gathered"] = keep_gather model = GeminiDDP(model, config_dict, init_device, pin_memory=True, **placement_config) optimizer = HybridAdam(model.parameters(), lr=1e-3) zero_optim = GeminiOptimizer(optimizer, model, initial_scale=1) rank = dist.get_rank() - amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1) + amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=1) torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model = DDP(torch_model, device_ids=[rank]) @@ -105,16 +94,16 @@ def exam_gpt_fwd_bwd( def run_dist(rank, world_size, port): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_gpt_fwd_bwd() @pytest.mark.dist 
-@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.parametrize("world_size", [1, 4]) @rerun_if_address_is_in_use() def test_gpt(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_gpt(4) diff --git a/tests/test_zero/test_gemini/test_gemini_use_rmt.py b/tests/test_zero/test_gemini/test_gemini_use_rmt.py index 614a96ccdbcd..2fa2d50a6caa 100644 --- a/tests/test_zero/test_gemini/test_gemini_use_rmt.py +++ b/tests/test_zero/test_gemini/test_gemini_use_rmt.py @@ -14,10 +14,10 @@ # run gemini use the runtime memory tracer -@parameterize('placement_policy', ['auto']) -@parameterize('keep_gather', [False]) -@parameterize('model_name', ['repeated_computed_layers', 'bert', 'albert', 'gpt2']) -@parameterize('use_grad_checkpoint', [False, True]) +@parameterize("placement_policy", ["auto"]) +@parameterize("keep_gather", [False]) +@parameterize("model_name", ["repeated_computed_layers", "bert", "albert", "gpt2"]) +@parameterize("use_grad_checkpoint", [False, True]) def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_checkpoint: bool = False): set_seed(42) get_components_func = non_distributed_component_funcs.get_callable(model_name) @@ -25,7 +25,7 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_ model = model_builder(use_grad_checkpoint).cuda() - print(f'model_name {model_name}') + print(f"model_name {model_name}") runtime_mem_tracer = RuntimeMemTracer(model) for i, (input_ids, label) in enumerate(train_dataloader): if i > 0: @@ -37,17 +37,17 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_ run_fwd_bwd(runtime_mem_tracer, input_ids, label, criterion, runtime_mem_tracer) memstats = runtime_mem_tracer.memstats() runtime_tracer_non_model_data = runtime_mem_tracer._memstats._non_model_data_cuda_list - print('runtime tracer non model data points: ', len(runtime_tracer_non_model_data)) - print('runtime tracer: ', runtime_tracer_non_model_data) + print("runtime tracer non model data points: ", len(runtime_tracer_non_model_data)) + print("runtime tracer: ", runtime_tracer_non_model_data) print([memstats.param_used_step(p) for p in model.parameters()]) - if model_name == 'repeated_computed_layers': + if model_name == "repeated_computed_layers": for idx, p in enumerate(model.parameters()): step_list = memstats.param_used_step(p) if idx < 4: assert len(step_list) == 4 - if model_name == 'repeated_computed_layers': + if model_name == "repeated_computed_layers": for idx, p in enumerate(model.parameters()): step_list = memstats.param_used_step(p) if idx < 4: @@ -55,13 +55,11 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_ world_size = torch.distributed.get_world_size() config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = keep_gather - model = GeminiDDP(model, - chunk_config_dict=config_dict, - placement_policy=placement_policy, - pin_memory=True, - memstats=memstats) + config_dict[world_size]["chunk_size"] = 5000 + config_dict[world_size]["keep_gathered"] = keep_gather + model = GeminiDDP( + model, chunk_config_dict=config_dict, placement_policy=placement_policy, pin_memory=True, memstats=memstats + ) set_seed(dist.get_rank()) for i, (input_ids, label) in enumerate(train_dataloader): @@ -73,29 +71,30 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_ input_ids, label 
= input_ids.cuda(), label.cuda() set_seed(42) - loss = run_fwd_bwd(model, input_ids, label, criterion, model) + run_fwd_bwd(model, input_ids, label, criterion, model) - gemini_non_model_data = model.gemini_manager._mem_stats_collector._memstats.non_model_data_list('cuda') + gemini_non_model_data = model.gemini_manager._mem_stats_collector._memstats.non_model_data_list("cuda") # print('gemini non model data:', gemini_non_model_data) - assert len(gemini_non_model_data) == len(runtime_tracer_non_model_data), \ - f'model_name {model_name} {len(gemini_non_model_data)} vs {len(runtime_tracer_non_model_data)}' + assert len(gemini_non_model_data) == len( + runtime_tracer_non_model_data + ), f"model_name {model_name} {len(gemini_non_model_data)} vs {len(runtime_tracer_non_model_data)}" def run_dist(rank, world_size, port): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_gemini_use_rmt() @pytest.mark.skip("this is not used") @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.parametrize("world_size", [1, 4]) @rerun_if_address_is_in_use() def test_gemini_use_rmt(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_gemini_use_rmt(1) diff --git a/tests/test_zero/test_gemini/test_grad_clip.py b/tests/test_zero/test_gemini/test_grad_clip.py index 860d6efa899a..d8bcc555a15d 100644 --- a/tests/test_zero/test_gemini/test_grad_clip.py +++ b/tests/test_zero/test_gemini/test_grad_clip.py @@ -16,26 +16,24 @@ PLACEMENT_CONFIGS = [ { - 'placement_policy': 'static', - 'shard_param_frac': 0.0, - 'offload_optim_frac': 0.0, - 'offload_param_frac': 0.0 - }, # zero2 + "placement_policy": "static", + "shard_param_frac": 0.0, + "offload_optim_frac": 0.0, + "offload_param_frac": 0.0, + }, # zero2 { - 'placement_policy': 'static', - 'shard_param_frac': 0.0, - 'offload_optim_frac': 1.0, - 'offload_param_frac': 0.0 - }, # zero2-offload + "placement_policy": "static", + "shard_param_frac": 0.0, + "offload_optim_frac": 1.0, + "offload_param_frac": 0.0, + }, # zero2-offload { - 'placement_policy': 'static', - 'shard_param_frac': 0.0, - 'offload_optim_frac': 0.5, - 'offload_param_frac': 0.0 - }, # zero2-offload-half - { - 'placement_policy': 'auto' - } + "placement_policy": "static", + "shard_param_frac": 0.0, + "offload_optim_frac": 0.5, + "offload_param_frac": 0.0, + }, # zero2-offload-half + {"placement_policy": "auto"}, ] @@ -52,15 +50,15 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module): assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3) -@parameterize('placement_config', PLACEMENT_CONFIGS) -@parameterize('model_name', ['gpt2']) +@parameterize("placement_config", PLACEMENT_CONFIGS) +@parameterize("model_name", ["gpt2"]) def exam_grad_clipping(placement_config, model_name: str): set_seed(1912) get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() torch_model = model_builder().cuda() - amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=32) + amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=32) torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model = 
DDP(torch_model, device_ids=[dist.get_rank()]) @@ -72,18 +70,16 @@ def exam_grad_clipping(placement_config, model_name: str): world_size = torch.distributed.get_world_size() config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = False - if placement_config['placement_policy'] != 'cuda': - init_device = torch.device('cpu') + config_dict[world_size]["chunk_size"] = 5000 + config_dict[world_size]["keep_gathered"] = False + if placement_config["placement_policy"] != "cuda": + init_device = torch.device("cpu") else: init_device = None - model = GeminiDDP(model, - chunk_config_dict=config_dict, - chunk_init_device=init_device, - pin_memory=True, - **placement_config) + model = GeminiDDP( + model, chunk_config_dict=config_dict, chunk_init_device=init_device, pin_memory=True, **placement_config + ) optimizer = HybridAdam(model.parameters(), lr=1e-3) zero_optim = GeminiOptimizer(optimizer, model, initial_scale=32, clipping_norm=1.0) @@ -106,6 +102,7 @@ def exam_grad_clipping(placement_config, model_name: str): assert_close(torch_loss, loss) import apex.amp as apex_amp + torch.nn.utils.clip_grad_norm_(apex_amp.master_params(torch_optim), 1.0) torch_optim.step() zero_optim.step() @@ -115,16 +112,16 @@ def exam_grad_clipping(placement_config, model_name: str): def run_dist(rank, world_size, port): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_grad_clipping() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 2]) +@pytest.mark.parametrize("world_size", [1, 2]) @rerun_if_address_is_in_use() def test_grad_clip(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_grad_clip(2) diff --git a/tests/test_zero/test_gemini/test_inference.py b/tests/test_zero/test_gemini/test_inference.py index 99ee08c1d7e7..2b2b246a9f54 100644 --- a/tests/test_zero/test_gemini/test_inference.py +++ b/tests/test_zero/test_gemini/test_inference.py @@ -18,21 +18,10 @@ from tests.components_to_test.registry import non_distributed_component_funcs PLACEMENT_CONFIGS = [ - { - 'placement_policy': 'static', - 'shard_param_frac': 0.0 - }, # zero2 - { - 'placement_policy': 'static', - 'shard_param_frac': 1.0 - }, # zero3 - { - 'placement_policy': 'static', - 'shard_param_frac': 0.5 - }, # zero3-half - { - 'placement_policy': 'auto' - } + {"placement_policy": "static", "shard_param_frac": 0.0}, # zero2 + {"placement_policy": "static", "shard_param_frac": 1.0}, # zero3 + {"placement_policy": "static", "shard_param_frac": 0.5}, # zero3-half + {"placement_policy": "auto"}, ] @@ -52,8 +41,8 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module): def multi_chunk_init(model: torch.nn.Module, placement_config: dict): world_size = dist.get_world_size() config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = False + config_dict[world_size]["chunk_size"] = 5000 + config_dict[world_size]["keep_gathered"] = False model = GeminiDDP(model, config_dict, pin_memory=True, **placement_config) return model @@ -63,16 +52,16 @@ def single_chunk_init(model: torch.nn.Module, placement_config: dict): return model -@parameterize('placement_config', 
PLACEMENT_CONFIGS) -@parameterize('model_name', ['gpt2']) -@parameterize('model_init_func', [single_chunk_init, multi_chunk_init]) +@parameterize("placement_config", PLACEMENT_CONFIGS) +@parameterize("model_name", ["gpt2"]) +@parameterize("model_init_func", [single_chunk_init, multi_chunk_init]) def exam_inference(placement_config: dict, model_name: str, model_init_func: Callable): set_seed(19360226) get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() torch_model = model_builder().cuda() - amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=128) + amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=128) torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) @@ -121,16 +110,16 @@ def inference_iter(): def run_dist(rank, world_size, port): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_inference() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.parametrize("world_size", [1, 4]) @rerun_if_address_is_in_use() def test_inference(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_inference(1) diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index 3454959199d2..b7c08392600f 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -16,50 +16,30 @@ from tests.components_to_test.registry import non_distributed_component_funcs PLACEMENT_CONFIGS = [ + {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.0}, # zero2 + {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 1.0}, # zero2-offload + {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.5}, # zero2-offload-half + {"placement_policy": "static", "shard_param_frac": 1.0}, # zero3 + {"placement_policy": "static", "shard_param_frac": 0.5}, # zero3-half { - 'placement_policy': 'static', - 'shard_param_frac': 0.0, - 'offload_optim_frac': 0.0 - }, # zero2 - { - 'placement_policy': 'static', - 'shard_param_frac': 0.0, - 'offload_optim_frac': 1.0 - }, # zero2-offload - { - 'placement_policy': 'static', - 'shard_param_frac': 0.0, - 'offload_optim_frac': 0.5 - }, # zero2-offload-half - { - 'placement_policy': 'static', - 'shard_param_frac': 1.0 - }, # zero3 - { - 'placement_policy': 'static', - 'shard_param_frac': 0.5 - }, # zero3-half - { - 'placement_policy': 'static', - 'shard_param_frac': 1.0, - 'offload_optim_frac': 1.0, - 'offload_param_frac': 1.0 - }, # zero3-offload-all - { - 'placement_policy': 'auto' - } + "placement_policy": "static", + "shard_param_frac": 1.0, + "offload_optim_frac": 1.0, + "offload_param_frac": 1.0, + }, # zero3-offload-all + {"placement_policy": "auto"}, ] # this model is large enough to slice to chunks -TEST_MODELS = ['gpt2'] +TEST_MODELS = ["gpt2"] # these models are too small, all parameters in these models are compacted into one chunk -EXAMPLE_MODELS = ['albert', 'beit', 'bert', 'hanging_param_model', 'nested_model', 'repeated_computed_layers'] 
+EXAMPLE_MODELS = ["albert", "beit", "bert", "hanging_param_model", "nested_model", "repeated_computed_layers"] # bfloat16 cannot represent them exactly BF16_IGNORED_KEYS = [ - 'albert.embeddings.word_embeddings.weight', - 'albert.embeddings.position_embeddings.weight', - 'masked_bias', + "albert.embeddings.word_embeddings.weight", + "albert.embeddings.position_embeddings.weight", + "masked_bias", ] @@ -78,23 +58,25 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dty if dtype is torch.bfloat16: rtol, atol = 4e-3, 8e-3 # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) - assert_close(value.float(), - temp_zero_value.float(), - rtol=rtol, - atol=atol, - msg=lambda s: s + f'\n{key}\n{temp_zero_value.dtype}') - - -@parameterize('placement_config', PLACEMENT_CONFIGS) -@parameterize('model_name', TEST_MODELS) -@parameterize('mixed_precision', [torch.half, torch.bfloat16]) + assert_close( + value.float(), + temp_zero_value.float(), + rtol=rtol, + atol=atol, + msg=lambda s: s + f"\n{key}\n{temp_zero_value.dtype}", + ) + + +@parameterize("placement_config", PLACEMENT_CONFIGS) +@parameterize("model_name", TEST_MODELS) +@parameterize("mixed_precision", [torch.half, torch.bfloat16]) def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype): set_seed(42) get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() torch_model = model_builder().cuda() - amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=128) + amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=128) torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) @@ -106,8 +88,8 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt world_size = torch.distributed.get_world_size() config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = False + config_dict[world_size]["chunk_size"] = 5000 + config_dict[world_size]["keep_gathered"] = False model = GeminiDDP(model, config_dict, **placement_config, mixed_precision=mixed_precision) optimizer = HybridAdam(model.parameters(), lr=1e-3) @@ -135,16 +117,16 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt check_param(model, torch_model, mixed_precision) -@parameterize('placement_config', PLACEMENT_CONFIGS) -@parameterize('model_name', EXAMPLE_MODELS) -@parameterize('mixed_precision', [torch.half, torch.bfloat16]) +@parameterize("placement_config", PLACEMENT_CONFIGS) +@parameterize("model_name", EXAMPLE_MODELS) +@parameterize("mixed_precision", [torch.half, torch.bfloat16]) def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch.dtype): set_seed(2008) get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() torch_model = model_builder().cuda() - amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=2) + amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=2) torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) torch_model, 
torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) @@ -154,12 +136,14 @@ def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch. for torch_p, p in zip(torch_model.parameters(), model.parameters()): p.data.copy_(torch_p.data) - model = GeminiDDP(model, - chunk_init_device=get_current_device(), - search_range_m=1, - pin_memory=True, - mixed_precision=mixed_precision, - **placement_config) + model = GeminiDDP( + model, + chunk_init_device=get_current_device(), + search_range_m=1, + pin_memory=True, + mixed_precision=mixed_precision, + **placement_config, + ) optimizer = HybridAdam(model.parameters(), lr=1e-3) zero_optim = GeminiOptimizer(optimizer, model, initial_scale=2) @@ -182,7 +166,7 @@ def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch. torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) - assert_close(torch_loss, loss, rtol=rtol, atol=atol) # atol should be 2e-5 for torch lower than 1.12 + assert_close(torch_loss, loss, rtol=rtol, atol=atol) # atol should be 2e-5 for torch lower than 1.12 zero_optim.step() torch_optim.step() @@ -192,17 +176,17 @@ def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch. def run_dist(rank, world_size, port): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_model_step() exam_tiny_example() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.parametrize("world_size", [1, 4]) @rerun_if_address_is_in_use() def test_optim(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_optim(1) diff --git a/tests/test_zero/test_gemini/test_runtime_mem_tracer.py b/tests/test_zero/test_gemini/test_runtime_mem_tracer.py index 29bd61390523..8e0f6ae36c46 100644 --- a/tests/test_zero/test_gemini/test_runtime_mem_tracer.py +++ b/tests/test_zero/test_gemini/test_runtime_mem_tracer.py @@ -13,7 +13,7 @@ @pytest.mark.skip("this is not used") @clear_cache_before_run() def test_runtime_mem_tracer(): - test_models = ['gpt2', 'bert', 'simple_net', 'repeated_computed_layers', 'nested_model', 'albert'] + test_models = ["gpt2", "bert", "simple_net", "repeated_computed_layers", "nested_model", "albert"] for model_name in test_models: get_components_func = non_distributed_component_funcs.get_callable(model_name) @@ -35,7 +35,7 @@ def test_runtime_mem_tracer(): for p1, p2 in zip(model_bk.parameters(), model.parameters()): torch.allclose(p1.to(torch.half), p2) - non_model_data_list = runtime_mem_tracer._memstats.non_model_data_list('cuda') + non_model_data_list = runtime_mem_tracer._memstats.non_model_data_list("cuda") cuda_non_model_data_list = np.array(non_model_data_list) / 1024**2 print("cuda_non_model_data_list", len(cuda_non_model_data_list)) print(non_model_data_list) @@ -46,9 +46,9 @@ def test_runtime_mem_tracer(): cnt2 = 0 for p in model.parameters(): cnt2 += 1 - assert cnt2 == cnt1, f'visited param number {cnt1} vs real param number {cnt2}' + assert cnt2 == cnt1, f"visited param number {cnt1} vs real param number {cnt2}" del model -if __name__ == '__main__': +if __name__ == "__main__": test_runtime_mem_tracer() diff --git 
a/tests/test_zero/test_gemini/test_search.py b/tests/test_zero/test_gemini/test_search.py index 4c7f2ee6c132..e22e5ece42a5 100644 --- a/tests/test_zero/test_gemini/test_search.py +++ b/tests/test_zero/test_gemini/test_search.py @@ -11,19 +11,17 @@ def exam_search_chunk_size(): world_size = torch.distributed.get_world_size() - get_components_func = non_distributed_component_funcs.get_callable('gpt2') + get_components_func = non_distributed_component_funcs.get_callable("gpt2") model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() # make sure torch_model and model has the same parameter values model = model_builder() - config_dict, *_ = search_chunk_configuration(model, - search_range_m=1, - search_interval=16, - min_chunk_size_m=0, - filter_exlarge_params=True) + config_dict, *_ = search_chunk_configuration( + model, search_range_m=1, search_interval=16, min_chunk_size_m=0, filter_exlarge_params=True + ) for key in config_dict: - chunk_size = config_dict[key]['chunk_size'] + chunk_size = config_dict[key]["chunk_size"] if world_size == 1 or True: assert chunk_size == 31616 else: @@ -33,34 +31,36 @@ def exam_search_chunk_size(): def exam_chunk_manager(): world_size = torch.distributed.get_world_size() - get_components_func = non_distributed_component_funcs.get_callable('gpt2') + get_components_func = non_distributed_component_funcs.get_callable("gpt2") model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() sharded_ddp_model = model_builder() - chunk_manager = init_chunk_manager(sharded_ddp_model, - get_current_device(), - hidden_dim=16, - search_range_m=1, - min_chunk_size_m=0, - filter_exlarge_params=True, - strict_ddp_flag=True) + chunk_manager = init_chunk_manager( + sharded_ddp_model, + get_current_device(), + hidden_dim=16, + search_range_m=1, + min_chunk_size_m=0, + filter_exlarge_params=True, + strict_ddp_flag=True, + ) config_dict = chunk_manager.dp_degree_chunk_size_dict assert len(config_dict) == 1 assert config_dict[world_size] == 31616 def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_search_chunk_size() exam_chunk_manager() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.parametrize("world_size", [1, 4]) @rerun_if_address_is_in_use() def test_search(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_search(4) diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py index 602e3ad3519d..3130440bd925 100644 --- a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py +++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py @@ -10,21 +10,10 @@ from tests.components_to_test.registry import non_distributed_component_funcs PLACEMENT_CONFIGS = [ - { - 'placement_policy': 'static', - 'shard_param_frac': 0.0 - }, # zero2 - { - 'placement_policy': 'static', - 'shard_param_frac': 1.0 - }, # zero3 - { - 'placement_policy': 'static', - 'shard_param_frac': 0.5 - }, # zero3-half - { - 'placement_policy': 'auto' - } + {"placement_policy": "static", "shard_param_frac": 0.0}, # zero2 + {"placement_policy": "static", "shard_param_frac": 1.0}, # zero3 + {"placement_policy": "static", "shard_param_frac": 0.5}, # zero3-half + {"placement_policy": 
"auto"}, ] @@ -35,9 +24,9 @@ def ignore_the_first_parameter(model: torch.nn.Module): return -@parameterize('placement_config', PLACEMENT_CONFIGS) -@parameterize('keep_gathered', [True, False]) -@parameterize('model_name', ['gpt2', 'bert']) +@parameterize("placement_config", PLACEMENT_CONFIGS) +@parameterize("keep_gathered", [True, False]) +@parameterize("model_name", ["gpt2", "bert"]) def exam_state_dict(placement_config, keep_gathered, model_name: str): set_seed(431) get_components_func = non_distributed_component_funcs.get_callable(model_name) @@ -51,8 +40,8 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str): world_size = torch.distributed.get_world_size() config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = keep_gathered + config_dict[world_size]["chunk_size"] = 5000 + config_dict[world_size]["keep_gathered"] = keep_gathered model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True) model.train() @@ -65,9 +54,9 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str): assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5) -@parameterize('placement_config', PLACEMENT_CONFIGS) -@parameterize('keep_gathered', [True, False]) -@parameterize('model_name', ['gpt2', 'bert']) +@parameterize("placement_config", PLACEMENT_CONFIGS) +@parameterize("keep_gathered", [True, False]) +@parameterize("model_name", ["gpt2", "bert"]) def exam_load_state_dict(placement_config, keep_gathered, model_name: str): set_seed(431) get_components_func = non_distributed_component_funcs.get_callable(model_name) @@ -76,12 +65,12 @@ def exam_load_state_dict(placement_config, keep_gathered, model_name: str): model = model_builder() set_seed(451) - torch_model = model_builder() # get a different model + torch_model = model_builder() # get a different model world_size = torch.distributed.get_world_size() config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = keep_gathered + config_dict[world_size]["chunk_size"] = 5000 + config_dict[world_size]["keep_gathered"] = keep_gathered model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True) @@ -95,8 +84,8 @@ def exam_load_state_dict(placement_config, keep_gathered, model_name: str): assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5) -@parameterize('placement_config', PLACEMENT_CONFIGS) -@parameterize('model_name', ['gpt2', 'bert']) +@parameterize("placement_config", PLACEMENT_CONFIGS) +@parameterize("model_name", ["gpt2", "bert"]) def exam_state_dict_shard(placement_config, model_name: str): get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() @@ -122,18 +111,18 @@ def exam_state_dict_shard(placement_config, model_name: str): def run_dist(rank, world_size, port): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_state_dict() exam_load_state_dict() exam_state_dict_shard() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.parametrize("world_size", [1, 4]) @rerun_if_address_is_in_use() def 
test_zero_ddp(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_zero_ddp(1) diff --git a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py index 5f7b51510d58..8aa656b74cf9 100644 --- a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py +++ b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py @@ -11,32 +11,18 @@ from tests.components_to_test.registry import non_distributed_component_funcs PLACEMENT_CONFIGS = [ - { - 'placement_policy': 'static', - 'shard_param_frac': 0.0, - 'offload_optim_frac': 0.0 - }, # zero2 - { - 'placement_policy': 'static', - 'shard_param_frac': 0.0, - 'offload_optim_frac': 1.0 - }, # zero2-offload - { - 'placement_policy': 'static', - 'shard_param_frac': 0.0, - 'offload_optim_frac': 0.5 - }, # zero2-offload-half - { - 'placement_policy': 'auto' - } + {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.0}, # zero2 + {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 1.0}, # zero2-offload + {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.5}, # zero2-offload-half + {"placement_policy": "auto"}, ] -@parameterize('placement_config', PLACEMENT_CONFIGS) -@parameterize('keep_gathered', [True, False]) +@parameterize("placement_config", PLACEMENT_CONFIGS) +@parameterize("keep_gathered", [True, False]) def exam_zero_optim_state_dict(placement_config, keep_gathered): set_seed(431) - get_components_func = non_distributed_component_funcs.get_callable('gpt2') + get_components_func = non_distributed_component_funcs.get_callable("gpt2") model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() model = model_builder() @@ -45,13 +31,13 @@ def exam_zero_optim_state_dict(placement_config, keep_gathered): world_size = torch.distributed.get_world_size() config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = keep_gathered + config_dict[world_size]["chunk_size"] = 5000 + config_dict[world_size]["keep_gathered"] = keep_gathered model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True) optimizer = HybridAdam(model.parameters()) - optim = GeminiOptimizer(optimizer, model, initial_scale=32) # initialize the link between chunk16 and chunk32 + optim = GeminiOptimizer(optimizer, model, initial_scale=32) # initialize the link between chunk16 and chunk32 set_seed(dist.get_rank() * 3 + 128) model.train() @@ -67,8 +53,8 @@ def exam_zero_optim_state_dict(placement_config, keep_gathered): optim_state_dict = optim.state_dict() optim.load_state_dict(optim_state_dict) - new_state = optim.state_dict()['state'] - org_state = optim_state_dict['state'] + new_state = optim.state_dict()["state"] + org_state = optim_state_dict["state"] for k, v in org_state.items(): w = new_state[k] @@ -82,16 +68,16 @@ def exam_zero_optim_state_dict(placement_config, keep_gathered): def run_dist(rank, world_size, port): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_zero_optim_state_dict() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.parametrize("world_size", [1, 4]) @rerun_if_address_is_in_use() def 
test_zero_optim(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_zero_optim(1) diff --git a/tests/test_zero/test_low_level/test_grad_acc.py b/tests/test_zero/test_low_level/test_grad_acc.py index f170f7cb83da..3c5baea138e0 100644 --- a/tests/test_zero/test_low_level/test_grad_acc.py +++ b/tests/test_zero/test_low_level/test_grad_acc.py @@ -14,7 +14,6 @@ class MlpModel(nn.Module): - def __init__(self): super(MlpModel, self).__init__() self.linear1 = nn.Linear(128, 256) @@ -36,16 +35,12 @@ def exam_zero_1_2_grad_acc(): # create optimizer zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1) zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1) - zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer, - overlap_communication=True, - initial_scale=32, - clip_grad_norm=1.0, - verbose=True) - zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer, - overlap_communication=True, - partition_grad=True, - initial_scale=32, - clip_grad_norm=1.0) + zero1_optimizer = LowLevelZeroOptimizer( + zero1_optimizer, overlap_communication=True, initial_scale=32, clip_grad_norm=1.0, verbose=True + ) + zero2_optimizer = LowLevelZeroOptimizer( + zero2_optimizer, overlap_communication=True, partition_grad=True, initial_scale=32, clip_grad_norm=1.0 + ) # create data seed_all(2021 + local_rank) input_data1 = torch.randn(32, 128).cuda() @@ -91,10 +86,9 @@ def exam_zero_1_grad_acc(sync): # we only test stage 1 here # in `check_sharded_param_consistency.py`, we will test whether # level 1 and 2 will produce exactly the same results - zero_optimizer = LowLevelZeroOptimizer(zero_optimizer, - overlap_communication=False, - reduce_bucket_size=262144, - clip_grad_norm=1.0) + zero_optimizer = LowLevelZeroOptimizer( + zero_optimizer, overlap_communication=False, reduce_bucket_size=262144, clip_grad_norm=1.0 + ) torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) @@ -104,7 +98,6 @@ def exam_zero_1_grad_acc(sync): input_data2 = torch.randn(32, 128).cuda() def fwd_bwd_func(no_sync, cur_data, check_flag): - # zero1 fwd and bwd with conditional_context(zero_optimizer.no_sync(), no_sync): zero_output = zero_model(cur_data) @@ -135,7 +128,7 @@ def fwd_bwd_func(no_sync, cur_data, check_flag): def run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") exam_zero_1_grad_acc(sync=True) exam_zero_1_grad_acc(sync=False) @@ -147,5 +140,5 @@ def test_grad_accumulation(): spawn(run_dist, 2) -if __name__ == '__main__': +if __name__ == "__main__": test_grad_accumulation() diff --git a/tests/test_zero/test_low_level/test_zero1_2.py b/tests/test_zero/test_low_level/test_zero1_2.py index 9c4474aff5c3..ebda9f6f25c5 100644 --- a/tests/test_zero/test_low_level/test_zero1_2.py +++ b/tests/test_zero/test_low_level/test_zero1_2.py @@ -2,7 +2,6 @@ import pytest import torch -import torch.distributed as dist import torch.nn as nn from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close @@ -14,7 +13,6 @@ class MlpModel(nn.Module): - def __init__(self): super(MlpModel, self).__init__() self.linear1 = nn.Linear(123, 253) @@ -74,14 +72,12 @@ def exam_zero_1_2(): # create optimizer zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1) zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1) - zero1_optimizer = 
LowLevelZeroOptimizer(zero1_optimizer, - overlap_communication=True, - initial_scale=128, - verbose=True) - zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer, - overlap_communication=True, - partition_grad=True, - initial_scale=128) + zero1_optimizer = LowLevelZeroOptimizer( + zero1_optimizer, overlap_communication=True, initial_scale=128, verbose=True + ) + zero2_optimizer = LowLevelZeroOptimizer( + zero2_optimizer, overlap_communication=True, partition_grad=True, initial_scale=128 + ) # create data seed_all(2001 + local_rank) input_data = torch.randn(32, 123).cuda() @@ -109,7 +105,7 @@ def exam_zero_1_2(): assert torch.equal(z1p.data, z2p.data) -@parameterize('dtype', [torch.float16, torch.bfloat16]) +@parameterize("dtype", [torch.float16, torch.bfloat16]) def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype): """ In this test, two pairs of model and optimizers are created. @@ -134,10 +130,9 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype): # we only test stage 1 here # in `check_sharded_param_consistency.py`, we will test whether # level 1 and 2 will produce exactly the same results - zero_optimizer = LowLevelZeroOptimizer(zero_optimizer, - overlap_communication=True, - initial_scale=1, - reduce_bucket_size=1024 * 1024) + zero_optimizer = LowLevelZeroOptimizer( + zero_optimizer, overlap_communication=True, initial_scale=1, reduce_bucket_size=1024 * 1024 + ) torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) @@ -178,7 +173,7 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype): def run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") exam_zero_1_torch_ddp(world_size=world_size) exam_zero_1_2() @@ -190,5 +185,5 @@ def test_zero_1_2(): spawn(run_dist, 2) -if __name__ == '__main__': +if __name__ == "__main__": test_zero_1_2() diff --git a/tests/test_zero/test_low_level/test_zero_ckpt.py b/tests/test_zero/test_low_level/test_zero_ckpt.py index ab811c6b4d3c..e9fc8598a62d 100644 --- a/tests/test_zero/test_low_level/test_zero_ckpt.py +++ b/tests/test_zero/test_low_level/test_zero_ckpt.py @@ -2,19 +2,17 @@ import pytest import torch -import torch.distributed as dist import torch.nn as nn from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close import colossalai -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.testing.random import seed_all from colossalai.zero import LowLevelZeroOptimizer class MlpModel(nn.Module): - def __init__(self): super(MlpModel, self).__init__() self.linear1 = nn.Linear(12, 24) @@ -61,10 +59,9 @@ def exam_zero_1_torch_ddp_ckpt(): # we only test stage 1 here # the state dicts of stage 1 and stage 2 are the same - zero_optimizer = LowLevelZeroOptimizer(zero_optimizer, - overlap_communication=True, - initial_scale=1, - reduce_bucket_size=262144) + zero_optimizer = LowLevelZeroOptimizer( + zero_optimizer, overlap_communication=True, initial_scale=1, reduce_bucket_size=262144 + ) torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) @@ -88,7 +85,7 @@ def exam_zero_1_torch_ddp_ckpt(): zero_state_dict = zero_optimizer.state_dict() # examine the original state dict - for torch_state, zero_state in zip(torch_state_dict['state'].values(), zero_state_dict['state'].values()): + for torch_state, 
zero_state in zip(torch_state_dict["state"].values(), zero_state_dict["state"].values()): for t_v, z_v in zip(torch_state.values(), zero_state.values()): loose_close(t_v, z_v) @@ -100,13 +97,13 @@ def exam_zero_1_torch_ddp_ckpt(): zero_state_dict = zero_optimizer.state_dict() # examine the loaded state dict - for torch_state, zero_state in zip(torch_state_dict['state'].values(), zero_state_dict['state'].values()): + for torch_state, zero_state in zip(torch_state_dict["state"].values(), zero_state_dict["state"].values()): for t_v, z_v in zip(torch_state.values(), zero_state.values()): loose_close(t_v, z_v) def run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") exam_zero_1_torch_ddp_ckpt() @@ -117,5 +114,5 @@ def test_zero_ckpt(): spawn(run_dist, 2) -if __name__ == '__main__': +if __name__ == "__main__": test_zero_ckpt() From 10513f203c912548439c847060f5b1de569cf15f Mon Sep 17 00:00:00 2001 From: Pengtai Xu Date: Tue, 19 Sep 2023 15:28:01 +0800 Subject: [PATCH 25/58] [doc] explain suitable use case for each plugin --- docs/source/en/basics/booster_plugins.md | 52 +++++++++++-------- docs/source/zh-Hans/basics/booster_plugins.md | 49 +++++++++-------- 2 files changed, 57 insertions(+), 44 deletions(-) diff --git a/docs/source/en/basics/booster_plugins.md b/docs/source/en/basics/booster_plugins.md index d7532b0ce39b..73af15ad2a89 100644 --- a/docs/source/en/basics/booster_plugins.md +++ b/docs/source/en/basics/booster_plugins.md @@ -1,6 +1,6 @@ # Booster Plugins -Author: [Hongxin Liu](https://github.com/ver217), [Baizhou Zhang](https://github.com/Fridge003) +Author: [Hongxin Liu](https://github.com/ver217), [Baizhou Zhang](https://github.com/Fridge003), [Pengtai Xu](https://github.com/ppt0011) **Prerequisite:** - [Booster API](./booster_api.md) @@ -11,16 +11,43 @@ As mentioned in [Booster API](./booster_api.md), we can use booster plugins to c We currently provide the following plugins: -- [Low Level Zero Plugin](#low-level-zero-plugin): It wraps the `colossalai.zero.low_level.LowLevelZeroOptimizer` and can be used to train models with zero-dp. It only supports zero stage-1 and stage-2. -- [Gemini Plugin](#gemini-plugin): It wraps the [Gemini](../features/zero_with_chunk.md) which implements Zero-3 with chunk-based and heterogeneous memory management. - [Torch DDP Plugin](#torch-ddp-plugin): It is a wrapper of `torch.nn.parallel.DistributedDataParallel` and can be used to train models with data parallelism. - [Torch FSDP Plugin](#torch-fsdp-plugin): It is a wrapper of `torch.distributed.fsdp.FullyShardedDataParallel` and can be used to train models with zero-dp. +- [Low Level Zero Plugin](#low-level-zero-plugin): It wraps the `colossalai.zero.low_level.LowLevelZeroOptimizer` and can be used to train models with zero-dp. It only supports zero stage-1 and stage-2. +- [Gemini Plugin](#gemini-plugin): It wraps the [Gemini](../features/zero_with_chunk.md) which implements Zero-3 with chunk-based and heterogeneous memory management. - [Hybrid Pararllel Plugin](#hybrid-parallel-plugin): It provides a tidy interface that integrates the power of Shardformer, pipeline manager, mixied precision training, TorchDDP and Zero stage 1/2 feature. 
With this plugin, transformer models can be easily trained with any combination of tensor parallel, pipeline parallel and data parallel (DDP/Zero) efficiently, along with various kinds of optimization tools for acceleration and memory saving. Detailed information about supported parallel strategies and optimization tools is explained in the section below. More plugins are coming soon. +## Choosing Your Plugin + +Generally only one plugin is used to train a model. Our recommended use case for each plugin is as follows. + +- [Torch DDP Plugin](#torch-ddp-plugin): It is suitable for models with less than 2 billion parameters. +- [Torch FSDP Plugin](#torch-fsdp-plugin) / [Low Level Zero Plugin](#low-level-zero-plugin): It is suitable for models with less than 10 billion parameters. +- [Gemini Plugin](#gemini-plugin): it is suitable for models with more than 10 billion parameters and is ideal for scenarios with high cross-node bandwidth and medium to small-scale clusters (below a thousand cards). +- [Hybrid Pararllel Plugin](#hybrid-parallel-plugin): It is suitable for models with more than 60 billion parameters, exceptionally long sequences, very large vocabularies, and is best suited for scenarios with low cross-node bandwidth and large-scale clusters (a thousand cards or more). + ## Plugins +### Torch DDP Plugin + +More details can be found in [Pytorch Docs](https://pytorch.org/docs/main/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel). + +{{ autodoc:colossalai.booster.plugin.TorchDDPPlugin }} + +### Torch FSDP Plugin + +> ⚠ This plugin is not available when torch version is lower than 1.12.0. + +> ⚠ This plugin does not support save/load sharded model checkpoint now. + +> ⚠ This plugin does not support optimizer that use multi params group. + +More details can be found in [Pytorch Docs](https://pytorch.org/docs/main/fsdp.html). + +{{ autodoc:colossalai.booster.plugin.TorchFSDPPlugin }} + ### Low Level Zero Plugin This plugin implements Zero-1 and Zero-2 (w/wo CPU offload), using `reduce` and `gather` to synchronize gradients and weights. @@ -50,24 +77,6 @@ This plugin implements Zero-3 with chunk-based and heterogeneous memory manageme {{ autodoc:colossalai.booster.plugin.GeminiPlugin }} -### Torch DDP Plugin - -More details can be found in [Pytorch Docs](https://pytorch.org/docs/main/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel). - -{{ autodoc:colossalai.booster.plugin.TorchDDPPlugin }} - -### Torch FSDP Plugin - -> ⚠ This plugin is not available when torch version is lower than 1.12.0. - -> ⚠ This plugin does not support save/load sharded model checkpoint now. - -> ⚠ This plugin does not support optimizer that use multi params group. - -More details can be found in [Pytorch Docs](https://pytorch.org/docs/main/fsdp.html). 
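For reference, the plugin recommendations above map onto a short call sequence. The following is a minimal sketch, assuming the public `colossalai.booster.Booster` API and a `TorchDDPPlugin` with default arguments; the model, optimizer, and shapes are placeholders, and `GeminiPlugin()` or `LowLevelZeroPlugin()` can be swapped in for larger models as suggested above.

```python
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

# Assumes the script is launched with torchrun; the config dict stays empty
# because all parallel behavior comes from the plugin.
colossalai.launch_from_torch(config={})

# A small model (well under 2B parameters) matches the Torch DDP recommendation.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.GELU(), torch.nn.Linear(1024, 1024))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()

plugin = TorchDDPPlugin()  # swap for GeminiPlugin() / LowLevelZeroPlugin() on larger models
booster = Booster(plugin=plugin)
model, optimizer, criterion, *_ = booster.boost(model, optimizer, criterion)

# One training step; booster.backward lets the plugin own gradient handling.
data = torch.rand(8, 1024).cuda()
loss = criterion(model(data), data)
booster.backward(loss, optimizer)
optimizer.step()
optimizer.zero_grad()
```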
- -{{ autodoc:colossalai.booster.plugin.TorchFSDPPlugin }} - ### Hybrid Parallel Plugin @@ -87,5 +96,4 @@ This plugin implements the combination of various parallel training strategies a {{ autodoc:colossalai.booster.plugin.HybridParallelPlugin }} - diff --git a/docs/source/zh-Hans/basics/booster_plugins.md b/docs/source/zh-Hans/basics/booster_plugins.md index 0ad1cacab151..1da645030d26 100644 --- a/docs/source/zh-Hans/basics/booster_plugins.md +++ b/docs/source/zh-Hans/basics/booster_plugins.md @@ -11,16 +11,41 @@ 我们现在提供以下插件: -- [Low Level Zero 插件](#low-level-zero-插件): 它包装了 `colossalai.zero.low_level.LowLevelZeroOptimizer`,可用于使用 Zero-dp 训练模型。它仅支持 Zero 阶段1和阶段2。 -- [Gemini 插件](#gemini-插件): 它包装了 [Gemini](../features/zero_with_chunk.md),Gemini 实现了基于Chunk内存管理和异构内存管理的 Zero-3。 - [Torch DDP 插件](#torch-ddp-插件): 它包装了 `torch.nn.parallel.DistributedDataParallel` 并且可用于使用数据并行训练模型。 - [Torch FSDP 插件](#torch-fsdp-插件): 它包装了 `torch.distributed.fsdp.FullyShardedDataParallel` 并且可用于使用 Zero-dp 训练模型。 +- [Low Level Zero 插件](#low-level-zero-插件): 它包装了 `colossalai.zero.low_level.LowLevelZeroOptimizer`,可用于使用 Zero-dp 训练模型。它仅支持 Zero 阶段1和阶段2。 +- [Gemini 插件](#gemini-插件): 它包装了 [Gemini](../features/zero_with_chunk.md),Gemini 实现了基于Chunk内存管理和异构内存管理的 Zero-3。 - [Hybrid Pararllel 插件](#hybrid-parallel-插件): 它为Shardformer,流水线管理器,混合精度运算,TorchDDP以及Zero-1/Zero-2功能提供了一个统一且简洁的接口。使用该插件可以简单高效地实现transformer模型在张量并行,流水线并行以及数据并行(DDP, Zero)间任意组合并行训练策略,同时支持多种训练速度和内存的优化工具。有关这些训练策略和优化工具的具体信息将在下一章中阐述。 更多插件即将推出。 +## 插件选择 +- [Torch DDP 插件](#torch-ddp-插件): 适用于参数少于 20 亿的模型。 +- [Torch FSDP 插件](#torch-fsdp-插件) / [Low Level Zero 插件](#low-level-zero-插件): 适用于参数少于 100 亿的模型。 +- [Gemini 插件](#gemini-插件): 适合参数超过 100 亿的模型,且跨节点带宽高、中小规模集群(千卡以下)的场景。 +- [Hybrid Pararllel 插件](#hybrid-parallel-插件): 适合参数超过 600 亿的模型、超长序列、超大词表等特殊模型,且跨节点带宽低、大规模集群(千卡以上)的场景。 + ## 插件 +### Torch DDP 插件 + +更多详细信息,请参阅 [Pytorch 文档](https://pytorch.org/docs/main/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel). + +{{ autodoc:colossalai.booster.plugin.TorchDDPPlugin }} + +### Torch FSDP 插件 + +> ⚠ 如果 torch 版本低于 1.12.0,此插件将不可用。 + +> ⚠ 该插件现在还不支持保存/加载分片的模型 checkpoint。 + +> ⚠ 该插件现在还不支持使用了multi params group的optimizer。 + +更多详细信息,请参阅 [Pytorch 文档](https://pytorch.org/docs/main/fsdp.html). + +{{ autodoc:colossalai.booster.plugin.TorchFSDPPlugin }} + + ### Low Level Zero 插件 该插件实现了 Zero-1 和 Zero-2(使用/不使用 CPU 卸载),使用`reduce`和`gather`来同步梯度和权重。 @@ -50,26 +75,6 @@ Zero-2 不支持局部梯度累积。如果您坚持使用,虽然可以积累 {{ autodoc:colossalai.booster.plugin.GeminiPlugin }} - -### Torch DDP 插件 - -更多详细信息,请参阅 [Pytorch 文档](https://pytorch.org/docs/main/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel). - -{{ autodoc:colossalai.booster.plugin.TorchDDPPlugin }} - -### Torch FSDP 插件 - -> ⚠ 如果 torch 版本低于 1.12.0,此插件将不可用。 - -> ⚠ 该插件现在还不支持保存/加载分片的模型 checkpoint。 - -> ⚠ 该插件现在还不支持使用了multi params group的optimizer。 - -更多详细信息,请参阅 [Pytorch 文档](https://pytorch.org/docs/main/fsdp.html). 
- -{{ autodoc:colossalai.booster.plugin.TorchFSDPPlugin }} - - ### Hybrid Parallel 插件 这个插件实现了多种并行训练策略和优化工具的组合。Hybrid Parallel插件支持的功能大致可以被分为以下四个部分: From a04337bfc30a81f3d6bc687d933edb92a124228c Mon Sep 17 00:00:00 2001 From: Pengtai Xu Date: Tue, 19 Sep 2023 16:27:37 +0800 Subject: [PATCH 26/58] [doc] put individual plugin explanation in front --- docs/source/en/basics/booster_plugins.md | 18 +++++++++--------- docs/source/zh-Hans/basics/booster_plugins.md | 12 ++++++------ 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/docs/source/en/basics/booster_plugins.md b/docs/source/en/basics/booster_plugins.md index 73af15ad2a89..7f4f7a859038 100644 --- a/docs/source/en/basics/booster_plugins.md +++ b/docs/source/en/basics/booster_plugins.md @@ -19,15 +19,6 @@ We currently provide the following plugins: More plugins are coming soon. -## Choosing Your Plugin - -Generally only one plugin is used to train a model. Our recommended use case for each plugin is as follows. - -- [Torch DDP Plugin](#torch-ddp-plugin): It is suitable for models with less than 2 billion parameters. -- [Torch FSDP Plugin](#torch-fsdp-plugin) / [Low Level Zero Plugin](#low-level-zero-plugin): It is suitable for models with less than 10 billion parameters. -- [Gemini Plugin](#gemini-plugin): it is suitable for models with more than 10 billion parameters and is ideal for scenarios with high cross-node bandwidth and medium to small-scale clusters (below a thousand cards). -- [Hybrid Pararllel Plugin](#hybrid-parallel-plugin): It is suitable for models with more than 60 billion parameters, exceptionally long sequences, very large vocabularies, and is best suited for scenarios with low cross-node bandwidth and large-scale clusters (a thousand cards or more). - ## Plugins ### Torch DDP Plugin @@ -96,4 +87,13 @@ This plugin implements the combination of various parallel training strategies a {{ autodoc:colossalai.booster.plugin.HybridParallelPlugin }} +## Choosing Your Plugin + +Generally only one plugin is used to train a model. Our recommended use case for each plugin is as follows. + +- [Torch DDP Plugin](#torch-ddp-plugin): It is suitable for models with less than 2 billion parameters. +- [Torch FSDP Plugin](#torch-fsdp-plugin) / [Low Level Zero Plugin](#low-level-zero-plugin): It is suitable for models with less than 10 billion parameters. +- [Gemini Plugin](#gemini-plugin): It is suitable for models with more than 10 billion parameters and is ideal for scenarios with high cross-node bandwidth and medium to small-scale clusters (below a thousand cards). +- [Hybrid Pararllel Plugin](#hybrid-parallel-plugin): It is suitable for models with more than 60 billion parameters, or special models such as those with exceptionally long sequences, very large vocabularies, and is best suited for scenarios with low cross-node bandwidth and large-scale clusters (a thousand cards or more). 
+ diff --git a/docs/source/zh-Hans/basics/booster_plugins.md b/docs/source/zh-Hans/basics/booster_plugins.md index 1da645030d26..b83638c4b801 100644 --- a/docs/source/zh-Hans/basics/booster_plugins.md +++ b/docs/source/zh-Hans/basics/booster_plugins.md @@ -19,12 +19,6 @@ 更多插件即将推出。 -## 插件选择 -- [Torch DDP 插件](#torch-ddp-插件): 适用于参数少于 20 亿的模型。 -- [Torch FSDP 插件](#torch-fsdp-插件) / [Low Level Zero 插件](#low-level-zero-插件): 适用于参数少于 100 亿的模型。 -- [Gemini 插件](#gemini-插件): 适合参数超过 100 亿的模型,且跨节点带宽高、中小规模集群(千卡以下)的场景。 -- [Hybrid Pararllel 插件](#hybrid-parallel-插件): 适合参数超过 600 亿的模型、超长序列、超大词表等特殊模型,且跨节点带宽低、大规模集群(千卡以上)的场景。 - ## 插件 ### Torch DDP 插件 @@ -93,4 +87,10 @@ Zero-2 不支持局部梯度累积。如果您坚持使用,虽然可以积累 {{ autodoc:colossalai.booster.plugin.HybridParallelPlugin }} +## 插件选择 +- [Torch DDP 插件](#torch-ddp-插件): 适用于参数少于 20 亿的模型。 +- [Torch FSDP 插件](#torch-fsdp-插件) / [Low Level Zero 插件](#low-level-zero-插件): 适用于参数少于 100 亿的模型。 +- [Gemini 插件](#gemini-插件): 适合参数超过 100 亿的模型,且跨节点带宽高、中小规模集群(千卡以下)的场景。 +- [Hybrid Pararllel 插件](#hybrid-parallel-插件): 适合参数超过 600 亿的模型、超长序列、超大词表等特殊模型,且跨节点带宽低、大规模集群(千卡以上)的场景。 + From e10d9f087e89c62fea223bd81283f13107b66c3f Mon Sep 17 00:00:00 2001 From: Pengtai Xu Date: Tue, 19 Sep 2023 18:01:23 +0800 Subject: [PATCH 27/58] [doc] add model examples for each plugin --- docs/source/en/basics/booster_plugins.md | 8 ++++---- docs/source/zh-Hans/basics/booster_plugins.md | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/source/en/basics/booster_plugins.md b/docs/source/en/basics/booster_plugins.md index 7f4f7a859038..075b17a1b8d9 100644 --- a/docs/source/en/basics/booster_plugins.md +++ b/docs/source/en/basics/booster_plugins.md @@ -91,9 +91,9 @@ This plugin implements the combination of various parallel training strategies a Generally only one plugin is used to train a model. Our recommended use case for each plugin is as follows. -- [Torch DDP Plugin](#torch-ddp-plugin): It is suitable for models with less than 2 billion parameters. -- [Torch FSDP Plugin](#torch-fsdp-plugin) / [Low Level Zero Plugin](#low-level-zero-plugin): It is suitable for models with less than 10 billion parameters. -- [Gemini Plugin](#gemini-plugin): It is suitable for models with more than 10 billion parameters and is ideal for scenarios with high cross-node bandwidth and medium to small-scale clusters (below a thousand cards). -- [Hybrid Pararllel Plugin](#hybrid-parallel-plugin): It is suitable for models with more than 60 billion parameters, or special models such as those with exceptionally long sequences, very large vocabularies, and is best suited for scenarios with low cross-node bandwidth and large-scale clusters (a thousand cards or more). +- [Torch DDP Plugin](#torch-ddp-plugin): It is suitable for models with less than 2 billion parameters (e.g. Bert-3m, GPT2-1.5b). +- [Torch FSDP Plugin](#torch-fsdp-plugin) / [Low Level Zero Plugin](#low-level-zero-plugin): It is suitable for models with less than 10 billion parameters (e.g. GPTJ-6b, MegatronLM-8b). +- [Gemini Plugin](#gemini-plugin): It is suitable for models with more than 10 billion parameters (e.g. TuringNLG-17b) and is ideal for scenarios with **high cross-node bandwidth and medium to small-scale clusters (below a thousand cards)** (e.g. Llama2-70b). 
+- [Hybrid Pararllel Plugin](#hybrid-parallel-plugin): It is suitable for models with more than 60 billion parameters, or special models such as those with exceptionally long sequences, very large vocabularies, and is best suited for scenarios with **low cross-node bandwidth and large-scale clusters (a thousand cards or more)** (e.g. GPT3-175b, Bloom-176b). diff --git a/docs/source/zh-Hans/basics/booster_plugins.md b/docs/source/zh-Hans/basics/booster_plugins.md index b83638c4b801..0857f44e1b06 100644 --- a/docs/source/zh-Hans/basics/booster_plugins.md +++ b/docs/source/zh-Hans/basics/booster_plugins.md @@ -88,9 +88,9 @@ Zero-2 不支持局部梯度累积。如果您坚持使用,虽然可以积累 {{ autodoc:colossalai.booster.plugin.HybridParallelPlugin }} ## 插件选择 -- [Torch DDP 插件](#torch-ddp-插件): 适用于参数少于 20 亿的模型。 -- [Torch FSDP 插件](#torch-fsdp-插件) / [Low Level Zero 插件](#low-level-zero-插件): 适用于参数少于 100 亿的模型。 -- [Gemini 插件](#gemini-插件): 适合参数超过 100 亿的模型,且跨节点带宽高、中小规模集群(千卡以下)的场景。 -- [Hybrid Pararllel 插件](#hybrid-parallel-插件): 适合参数超过 600 亿的模型、超长序列、超大词表等特殊模型,且跨节点带宽低、大规模集群(千卡以上)的场景。 +- [Torch DDP 插件](#torch-ddp-插件): 适用于参数少于 20 亿的模型(例如 Bert-3m、GPT2-1.5b)。 +- [Torch FSDP 插件](#torch-fsdp-插件) / [Low Level Zero 插件](#low-level-zero-插件): 适用于参数少于 100 亿的模型(例如 GPTJ-6b、MegatronLM-8b)。 +- [Gemini 插件](#gemini-插件): 适合参数超过 100 亿的模型(例如 TuringNLG-17b),且**跨节点带宽高、中小规模集群(千卡以下)**的场景(例如 Llama2-70b)。 +- [Hybrid Pararllel 插件](#hybrid-parallel-插件): 适合参数超过 600 亿的模型、超长序列、超大词表等特殊模型,且**跨节点带宽低、大规模集群(千卡以上)**的场景(例如 GPT3-175b、Bloom-176b)。 From 4d7537ba254dc0b82aaad735d6760065feefe1df Mon Sep 17 00:00:00 2001 From: Pengtai Xu Date: Wed, 20 Sep 2023 09:24:10 +0800 Subject: [PATCH 28/58] [doc] put native colossalai plugins first in description section --- docs/source/en/basics/booster_plugins.md | 44 ++++++++-------- docs/source/zh-Hans/basics/booster_plugins.md | 50 +++++++++---------- 2 files changed, 47 insertions(+), 47 deletions(-) diff --git a/docs/source/en/basics/booster_plugins.md b/docs/source/en/basics/booster_plugins.md index 075b17a1b8d9..57fa813436da 100644 --- a/docs/source/en/basics/booster_plugins.md +++ b/docs/source/en/basics/booster_plugins.md @@ -19,25 +19,16 @@ We currently provide the following plugins: More plugins are coming soon. -## Plugins - -### Torch DDP Plugin - -More details can be found in [Pytorch Docs](https://pytorch.org/docs/main/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel). - -{{ autodoc:colossalai.booster.plugin.TorchDDPPlugin }} - -### Torch FSDP Plugin - -> ⚠ This plugin is not available when torch version is lower than 1.12.0. - -> ⚠ This plugin does not support save/load sharded model checkpoint now. +## Choosing Your Plugin -> ⚠ This plugin does not support optimizer that use multi params group. +Generally only one plugin is used to train a model. Our recommended use case for each plugin is as follows. -More details can be found in [Pytorch Docs](https://pytorch.org/docs/main/fsdp.html). +- [Torch DDP Plugin](#torch-ddp-plugin): It is suitable for models with less than 2 billion parameters (e.g. Bert-3m, GPT2-1.5b). +- [Torch FSDP Plugin](#torch-fsdp-plugin) / [Low Level Zero Plugin](#low-level-zero-plugin): It is suitable for models with less than 10 billion parameters (e.g. GPTJ-6b, MegatronLM-8b). +- [Gemini Plugin](#gemini-plugin): It is suitable for models with more than 10 billion parameters (e.g. TuringNLG-17b) and is ideal for scenarios with **high cross-node bandwidth and medium to small-scale clusters (below a thousand cards)** (e.g. Llama2-70b). 
+- [Hybrid Pararllel Plugin](#hybrid-parallel-plugin): It is suitable for models with more than 60 billion parameters, or special models such as those with exceptionally long sequences, very large vocabularies, and is best suited for scenarios with **low cross-node bandwidth and large-scale clusters (a thousand cards or more)** (e.g. GPT3-175b, Bloom-176b). -{{ autodoc:colossalai.booster.plugin.TorchFSDPPlugin }} +## Plugins ### Low Level Zero Plugin @@ -87,13 +78,22 @@ This plugin implements the combination of various parallel training strategies a {{ autodoc:colossalai.booster.plugin.HybridParallelPlugin }} -## Choosing Your Plugin +### Torch DDP Plugin -Generally only one plugin is used to train a model. Our recommended use case for each plugin is as follows. +More details can be found in [Pytorch Docs](https://pytorch.org/docs/main/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel). -- [Torch DDP Plugin](#torch-ddp-plugin): It is suitable for models with less than 2 billion parameters (e.g. Bert-3m, GPT2-1.5b). -- [Torch FSDP Plugin](#torch-fsdp-plugin) / [Low Level Zero Plugin](#low-level-zero-plugin): It is suitable for models with less than 10 billion parameters (e.g. GPTJ-6b, MegatronLM-8b). -- [Gemini Plugin](#gemini-plugin): It is suitable for models with more than 10 billion parameters (e.g. TuringNLG-17b) and is ideal for scenarios with **high cross-node bandwidth and medium to small-scale clusters (below a thousand cards)** (e.g. Llama2-70b). -- [Hybrid Pararllel Plugin](#hybrid-parallel-plugin): It is suitable for models with more than 60 billion parameters, or special models such as those with exceptionally long sequences, very large vocabularies, and is best suited for scenarios with **low cross-node bandwidth and large-scale clusters (a thousand cards or more)** (e.g. GPT3-175b, Bloom-176b). +{{ autodoc:colossalai.booster.plugin.TorchDDPPlugin }} + +### Torch FSDP Plugin + +> ⚠ This plugin is not available when torch version is lower than 1.12.0. + +> ⚠ This plugin does not support save/load sharded model checkpoint now. + +> ⚠ This plugin does not support optimizer that use multi params group. + +More details can be found in [Pytorch Docs](https://pytorch.org/docs/main/fsdp.html). + +{{ autodoc:colossalai.booster.plugin.TorchFSDPPlugin }} diff --git a/docs/source/zh-Hans/basics/booster_plugins.md b/docs/source/zh-Hans/basics/booster_plugins.md index 0857f44e1b06..d4ef7012ff67 100644 --- a/docs/source/zh-Hans/basics/booster_plugins.md +++ b/docs/source/zh-Hans/basics/booster_plugins.md @@ -1,6 +1,7 @@ # Booster 插件 -作者: [Hongxin Liu](https://github.com/ver217), [Baizhou Zhang](https://github.com/Fridge003) +作者: [Hongxin Liu](https://github.com/ver217), [Baizhou Zhang](https://github.com/Fridge003), [Pengtai Xu](https://github.com/ppt0011) + **前置教程:** - [Booster API](./booster_api.md) @@ -19,26 +20,13 @@ 更多插件即将推出。 -## 插件 - -### Torch DDP 插件 - -更多详细信息,请参阅 [Pytorch 文档](https://pytorch.org/docs/main/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel). - -{{ autodoc:colossalai.booster.plugin.TorchDDPPlugin }} - -### Torch FSDP 插件 - -> ⚠ 如果 torch 版本低于 1.12.0,此插件将不可用。 - -> ⚠ 该插件现在还不支持保存/加载分片的模型 checkpoint。 - -> ⚠ 该插件现在还不支持使用了multi params group的optimizer。 - -更多详细信息,请参阅 [Pytorch 文档](https://pytorch.org/docs/main/fsdp.html). 
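For reference, the Low Level Zero and Gemini descriptions earlier in this file correspond to plugin constructors roughly like the sketch below. The keyword arguments shown (`stage`, `cpu_offload`, `placement_policy`, `precision`) are assumptions about the constructor signatures, so the `{{ autodoc }}` blocks remain the authoritative reference.

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin

# Zero-1/Zero-2 with optional CPU offload of optimizer states, matching the
# Low Level Zero description; keyword names here are assumed, not verbatim.
zero_plugin = LowLevelZeroPlugin(stage=2, cpu_offload=False)

# Chunk-based Zero-3 with heterogeneous (CPU + GPU) memory management,
# matching the Gemini description; placement_policy/precision are assumed names.
gemini_plugin = GeminiPlugin(placement_policy="auto", precision="fp16")

# A Booster takes exactly one plugin at a time.
booster = Booster(plugin=gemini_plugin)
```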
- -{{ autodoc:colossalai.booster.plugin.TorchFSDPPlugin }} +## 插件选择 +- [Torch DDP 插件](#torch-ddp-插件): 适用于参数少于 20 亿的模型(例如 Bert-3m、GPT2-1.5b)。 +- [Torch FSDP 插件](#torch-fsdp-插件) / [Low Level Zero 插件](#low-level-zero-插件): 适用于参数少于 100 亿的模型(例如 GPTJ-6b、MegatronLM-8b)。 +- [Gemini 插件](#gemini-插件): 适合参数超过 100 亿的模型(例如 TuringNLG-17b),且**跨节点带宽高、中小规模集群(千卡以下)**的场景(例如 Llama2-70b)。 +- [Hybrid Pararllel 插件](#hybrid-parallel-插件): 适合参数超过 600 亿的模型、超长序列、超大词表等特殊模型,且**跨节点带宽低、大规模集群(千卡以上)**的场景(例如 GPT3-175b、Bloom-176b)。 +## 插件 ### Low Level Zero 插件 @@ -87,10 +75,22 @@ Zero-2 不支持局部梯度累积。如果您坚持使用,虽然可以积累 {{ autodoc:colossalai.booster.plugin.HybridParallelPlugin }} -## 插件选择 -- [Torch DDP 插件](#torch-ddp-插件): 适用于参数少于 20 亿的模型(例如 Bert-3m、GPT2-1.5b)。 -- [Torch FSDP 插件](#torch-fsdp-插件) / [Low Level Zero 插件](#low-level-zero-插件): 适用于参数少于 100 亿的模型(例如 GPTJ-6b、MegatronLM-8b)。 -- [Gemini 插件](#gemini-插件): 适合参数超过 100 亿的模型(例如 TuringNLG-17b),且**跨节点带宽高、中小规模集群(千卡以下)**的场景(例如 Llama2-70b)。 -- [Hybrid Pararllel 插件](#hybrid-parallel-插件): 适合参数超过 600 亿的模型、超长序列、超大词表等特殊模型,且**跨节点带宽低、大规模集群(千卡以上)**的场景(例如 GPT3-175b、Bloom-176b)。 +### Torch DDP 插件 + +更多详细信息,请参阅 [Pytorch 文档](https://pytorch.org/docs/main/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel). + +{{ autodoc:colossalai.booster.plugin.TorchDDPPlugin }} + +### Torch FSDP 插件 + +> ⚠ 如果 torch 版本低于 1.12.0,此插件将不可用。 + +> ⚠ 该插件现在还不支持保存/加载分片的模型 checkpoint。 + +> ⚠ 该插件现在还不支持使用了multi params group的optimizer。 + +更多详细信息,请参阅 [Pytorch 文档](https://pytorch.org/docs/main/fsdp.html). + +{{ autodoc:colossalai.booster.plugin.TorchFSDPPlugin }} From 7b9b86441fbffdd07021f234ec88d0dbc470fa5c Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Wed, 20 Sep 2023 15:53:58 +0800 Subject: [PATCH 29/58] [chat]: update rm, add wandb and fix bugs (#4471) * feat: modify forward fn of critic and reward model * feat: modify calc_action_log_probs * to: add wandb in sft and rm trainer * feat: update train_sft * feat: update train_rm * style: modify type annotation and add warning * feat: pass tokenizer to ppo trainer * to: modify trainer base and maker base * feat: add wandb in ppo trainer * feat: pass tokenizer to generate * test: update generate fn tests * test: update train tests * fix: remove action_mask * feat: remove unused code * fix: fix wrong ignore_index * fix: fix mock tokenizer * chore: update requirements * revert: modify make_experience * fix: fix inference * fix: add padding side * style: modify _on_learn_batch_end * test: use mock tokenizer * fix: use bf16 to avoid overflow * fix: fix workflow * [chat] fix gemini strategy * [chat] fix * sync: update colossalai strategy * fix: fix args and model dtype * fix: fix checkpoint test * fix: fix requirements * fix: fix missing import and wrong arg * fix: temporarily skip gemini test in stage 3 * style: apply pre-commit * fix: temporarily skip gemini test in stage 1&2 --------- Co-authored-by: Mingyan Jiang <1829166702@qq.com> --- .github/workflows/run_chatgpt_examples.yml | 2 +- .../benchmarks/benchmark_opt_lora_dummy.py | 4 +- .../Chat/coati/dataset/sft_dataset.py | 18 +++- .../Chat/coati/experience_buffer/naive.py | 3 + .../Chat/coati/experience_maker/base.py | 10 +- .../Chat/coati/experience_maker/naive.py | 27 +++-- applications/Chat/coati/models/base/actor.py | 2 +- applications/Chat/coati/models/base/critic.py | 37 ++----- .../Chat/coati/models/base/reward_model.py | 11 +- applications/Chat/coati/models/generation.py | 16 ++- applications/Chat/coati/models/loss.py | 1 + applications/Chat/coati/models/utils.py | 5 +- 
applications/Chat/coati/ray/utils.py | 8 +- applications/Chat/coati/trainer/base.py | 24 ++--- .../Chat/coati/trainer/callbacks/base.py | 2 +- .../callbacks/performance_evaluator.py | 2 +- applications/Chat/coati/trainer/ppo.py | 100 +++++++++++------- applications/Chat/coati/trainer/rm.py | 72 ++++++++----- applications/Chat/coati/trainer/sft.py | 78 +++++++------- .../coati/trainer/strategies/colossalai.py | 31 ++---- applications/Chat/examples/inference.py | 10 +- applications/Chat/examples/requirements.txt | 2 +- applications/Chat/examples/train_prompts.py | 56 ++++++---- .../Chat/examples/train_reward_model.py | 49 +++------ applications/Chat/examples/train_rm.sh | 9 +- applications/Chat/examples/train_sft.py | 29 +++-- applications/Chat/examples/train_sft.sh | 1 - applications/Chat/requirements-test.txt | 2 +- applications/Chat/requirements.txt | 3 +- applications/Chat/tests/test_checkpoint.py | 6 +- applications/Chat/tests/test_dataset.py | 4 +- applications/Chat/tests/test_experience.py | 34 ++++-- applications/Chat/tests/test_models.py | 27 ++--- applications/Chat/tests/test_train.sh | 27 +++-- .../auto_offload/train_gpt_offload.py | 2 +- .../pipeline_parallel/train_gpt_pp.py | 2 +- 36 files changed, 383 insertions(+), 333 deletions(-) diff --git a/.github/workflows/run_chatgpt_examples.yml b/.github/workflows/run_chatgpt_examples.yml index a336526897e2..f9e9f400962e 100644 --- a/.github/workflows/run_chatgpt_examples.yml +++ b/.github/workflows/run_chatgpt_examples.yml @@ -49,5 +49,5 @@ jobs: NCCL_SHM_DISABLE: 1 MAX_JOBS: 8 SFT_DATASET: /data/scratch/github_actions/chat/data.json - PROMPT_PATH: /data/scratch/github_actions/chat/prompts_en.jsonl + PROMPT_DATASET: /data/scratch/github_actions/chat/prompts_en.jsonl PRETRAIN_DATASET: /data/scratch/github_actions/chat/alpaca_data.json diff --git a/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py b/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py index 04f779821405..bee5c8d3faf3 100644 --- a/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py +++ b/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py @@ -138,6 +138,7 @@ def main(args): tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim)) @@ -154,6 +155,7 @@ def main(args): initial_model, actor_optim, critic_optim, + tokenizer=tokenizer, ptx_coef=0, train_batch_size=args.train_batch_size, offload_inference_models=args.offload_inference_models, @@ -162,8 +164,6 @@ def main(args): temperature=1.0, top_k=50, use_cache=True, - pad_token_id=tokenizer.pad_token_id, - eos_token_id=tokenizer.eos_token_id, callbacks=[performance_evaluator], ) diff --git a/applications/Chat/coati/dataset/sft_dataset.py b/applications/Chat/coati/dataset/sft_dataset.py index d6be09ca5cc9..c0e257f54a07 100644 --- a/applications/Chat/coati/dataset/sft_dataset.py +++ b/applications/Chat/coati/dataset/sft_dataset.py @@ -13,7 +13,7 @@ # limitations under the License. 
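The `sft_dataset.py` hunk below changes how labels are masked: prompt tokens and padding tokens are both set to `IGNORE_INDEX` so that only completion tokens contribute to the loss, with the masked slice depending on the tokenizer's padding side. A minimal sketch of the same masking logic, assuming the conventional ignore value of -100 (the default `ignore_index` of `nn.CrossEntropyLoss`) and adding a guard for the zero-padding case:

```python
import torch

IGNORE_INDEX = -100  # matches torch.nn.CrossEntropyLoss(ignore_index=-100)

def mask_labels(input_ids: torch.Tensor, source_len: int, pad_len: int, padding_side: str) -> torch.Tensor:
    """Return labels where prompt and padding positions are ignored by the loss."""
    labels = input_ids.clone()
    if padding_side == "right":
        # |prompt|completion|eos|pad| -> ignore the prompt prefix and trailing pads
        labels[:source_len] = IGNORE_INDEX
        if pad_len > 0:
            labels[-pad_len:] = IGNORE_INDEX
    elif padding_side == "left":
        # |pad|prompt|completion|eos| -> ignore leading pads plus the prompt
        labels[: pad_len + source_len] = IGNORE_INDEX
    else:
        raise ValueError(f"unknown padding side: {padding_side}")
    return labels
```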
import copy -from typing import Dict, Sequence, Tuple +from typing import Dict, Optional, Sequence, Tuple import torch from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer @@ -57,6 +57,7 @@ def _preprocess( sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" ) + assert sequences_token["attention_mask"].dim() == 2, "seq2seq model should be preprocessed differently" labels = copy.deepcopy(sequences_token["input_ids"]) for i in range(labels.shape[0]): source_len = sources_token["attention_mask"][i].sum().item() @@ -64,9 +65,10 @@ def _preprocess( if tokenizer.padding_side == "right": # |prompt|completion|eos|pad| labels[i][:source_len] = IGNORE_INDEX + labels[i][-pad_len:] = IGNORE_INDEX elif tokenizer.padding_side == "left": # |pad|prompt|completion|eos| - labels[i][pad_len : pad_len + source_len] = IGNORE_INDEX + labels[i][: pad_len + source_len] = IGNORE_INDEX else: raise RuntimeError() @@ -126,6 +128,8 @@ def __init__(self, dataset: Dict, tokenizer: PreTrainedTokenizer, max_length: in sources = [data["prompt"] for data in dataset] targets = [data["completion"] + tokenizer.eos_token for data in tqdm(dataset, disable=not is_rank_0())] + + logger.info("Tokenizing inputs... This may take some time...") if isinstance(tokenizer, ChatGLMTokenizer): self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm( sources, targets, tokenizer, max_length @@ -133,6 +137,8 @@ def __init__(self, dataset: Dict, tokenizer: PreTrainedTokenizer, max_length: in else: self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length) + logger.info("Loaded dataset.") + def __len__(self): length = self.input_ids.shape[0] return length @@ -148,7 +154,11 @@ class SupervisedDataset(Dataset): """Dataset for supervised fine-tuning.""" def __init__( - self, data_path: str, tokenizer: PreTrainedTokenizer, max_datasets_size: int = None, max_length: int = 512 + self, + data_path: str, + tokenizer: PreTrainedTokenizer, + max_datasets_size: Optional[int] = None, + max_length: int = 512, ): super().__init__() logger.info("Loading data...") @@ -175,6 +185,8 @@ def __init__( else: self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length) + logger.info("Loaded dataset.") + def __len__(self): length = self.input_ids.shape[0] return length diff --git a/applications/Chat/coati/experience_buffer/naive.py b/applications/Chat/coati/experience_buffer/naive.py index acc0fbe88ab4..d47b67dbe713 100644 --- a/applications/Chat/coati/experience_buffer/naive.py +++ b/applications/Chat/coati/experience_buffer/naive.py @@ -1,4 +1,5 @@ import random +import warnings from typing import List import torch @@ -30,9 +31,11 @@ def append(self, experience: Experience) -> None: experience.to_device(torch.device("cpu")) items = split_experience_batch(experience) self.items.extend(items) + if self.limit > 0: samples_to_remove = len(self.items) - self.limit if samples_to_remove > 0: + warnings.warn(f"Experience buffer is full. 
Removing {samples_to_remove} samples.") self.items = self.items[samples_to_remove:] def clear(self) -> None: diff --git a/applications/Chat/coati/experience_maker/base.py b/applications/Chat/coati/experience_maker/base.py index 727f0a4a52e8..0731f6e0f97f 100644 --- a/applications/Chat/coati/experience_maker/base.py +++ b/applications/Chat/coati/experience_maker/base.py @@ -3,8 +3,7 @@ from typing import Optional import torch -import torch.nn as nn -from coati.models.base import Actor +from coati.models.base import Actor, Critic, RewardModel @dataclass @@ -59,16 +58,13 @@ def pin_memory(self): class ExperienceMaker(ABC): - def __init__( - self, actor: Actor, critic: nn.Module, reward_model: nn.Module, initial_model: Actor, kl_coef: float = 0.1 - ) -> None: + def __init__(self, actor: Actor, critic: Critic, reward_model: RewardModel, initial_model: Actor) -> None: super().__init__() self.actor = actor self.critic = critic self.reward_model = reward_model self.initial_model = initial_model - self.kl_coef = kl_coef @abstractmethod - def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience: + def make_experience(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **generate_kwargs) -> Experience: pass diff --git a/applications/Chat/coati/experience_maker/naive.py b/applications/Chat/coati/experience_maker/naive.py index 30dfd8e0b9bc..941e1994b148 100644 --- a/applications/Chat/coati/experience_maker/naive.py +++ b/applications/Chat/coati/experience_maker/naive.py @@ -1,7 +1,9 @@ import torch import torch.nn.functional as F +from coati.models.base import Actor, Critic, RewardModel from coati.models.generation import generate from coati.models.utils import calc_action_log_probs, compute_reward +from transformers import PreTrainedTokenizer from .base import Experience, ExperienceMaker @@ -11,6 +13,19 @@ class NaiveExperienceMaker(ExperienceMaker): Naive experience maker. 
""" + def __init__( + self, + actor: Actor, + critic: Critic, + reward_model: RewardModel, + initial_model: Actor, + tokenizer: PreTrainedTokenizer, + kl_coef: float = 0.1, + ) -> None: + super().__init__(actor, critic, reward_model, initial_model) + self.tokenizer = tokenizer + self.kl_coef = kl_coef + @torch.no_grad() def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience: self.actor.eval() @@ -19,16 +34,16 @@ def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experie self.reward_model.eval() # generate sequences - sequences = generate(self.actor, input_ids, **generate_kwargs) + sequences = generate(self.actor, input_ids, self.tokenizer, **generate_kwargs) # calculate auxiliary tensors attention_mask = None - pad_token_id = generate_kwargs.get("pad_token_id", None) + pad_token_id = self.tokenizer.pad_token_id if pad_token_id is not None: attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device) input_len = input_ids.size(1) - eos_token_id = generate_kwargs.get("eos_token_id", None) + eos_token_id = self.tokenizer.eos_token_id if eos_token_id is None: action_mask = torch.ones_like(sequences, dtype=torch.bool) else: @@ -40,11 +55,11 @@ def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experie action_mask = action_mask[:, -(sequences.size(1) - input_len) :] num_actions = action_mask.size(1) - actor_output = self.actor(sequences, attention_mask) + actor_output = self.actor(sequences, attention_mask)["logits"] action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions) - base_model_output = self.initial_model(sequences, attention_mask) + base_model_output = self.initial_model(sequences, attention_mask)["logits"] base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions) - value = self.critic(sequences, action_mask, attention_mask) + value = self.critic(sequences, attention_mask) r = self.reward_model(sequences, attention_mask) reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask) diff --git a/applications/Chat/coati/models/base/actor.py b/applications/Chat/coati/models/base/actor.py index 979f9318be50..0634631df7a3 100644 --- a/applications/Chat/coati/models/base/actor.py +++ b/applications/Chat/coati/models/base/actor.py @@ -25,7 +25,7 @@ def forward( self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, - **model_kwargs, # HACK: `generate` method may pass more kwargs + **model_kwargs, ) -> torch.Tensor: """Returns model output.""" output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs) diff --git a/applications/Chat/coati/models/base/critic.py b/applications/Chat/coati/models/base/critic.py index 54ab7fa47d48..8672365f5783 100644 --- a/applications/Chat/coati/models/base/critic.py +++ b/applications/Chat/coati/models/base/critic.py @@ -1,10 +1,7 @@ -from typing import Optional - import torch import torch.nn as nn from ..lora import LoRAModule -from ..utils import masked_mean class Critic(LoRAModule): @@ -19,37 +16,19 @@ class Critic(LoRAModule): """ def __init__( - self, - model: nn.Module, - value_head: nn.Module, - lora_rank: int = 0, - lora_train_bias: str = "none", - use_action_mask: bool = False, + self, model: nn.Module, value_head: nn.Module, lora_rank: int = 0, lora_train_bias: str = "none" ) -> None: super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias) self.model = model self.value_head = value_head - 
self.use_action_mask = use_action_mask self.convert_to_lora() - def forward( - self, - sequences: torch.LongTensor, - action_mask: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + def forward(self, sequences: torch.LongTensor, attention_mask: torch.Tensor) -> torch.Tensor: outputs = self.model(sequences, attention_mask=attention_mask) last_hidden_states = outputs["last_hidden_state"] - - values = self.value_head(last_hidden_states).squeeze(-1) - - if action_mask is not None and self.use_action_mask: - num_actions = action_mask.size(1) - prompt_mask = attention_mask[:, :-num_actions] - values = values[:, :-num_actions] - value = masked_mean(values, prompt_mask, dim=1) - return value - - values = values[:, :-1] - value = values.mean(dim=1) - return value + sequence_lengths = torch.max(attention_mask * torch.arange(sequences.size(1), device=sequences.device), dim=1)[ + 0 + ] + sequence_hidden_states = last_hidden_states[torch.arange(last_hidden_states.size(0)), sequence_lengths] + values = self.value_head(sequence_hidden_states).squeeze(1) # ensure shape is (B, ) + return values diff --git a/applications/Chat/coati/models/base/reward_model.py b/applications/Chat/coati/models/base/reward_model.py index 1a70c6cc12bb..e9545d1cddaf 100644 --- a/applications/Chat/coati/models/base/reward_model.py +++ b/applications/Chat/coati/models/base/reward_model.py @@ -35,9 +35,12 @@ def __init__( else: self.value_head = nn.Linear(model.config.n_embd, 1) - def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, sequences: torch.LongTensor, attention_mask: torch.Tensor) -> torch.Tensor: outputs = self.model(sequences, attention_mask=attention_mask) last_hidden_states = outputs["last_hidden_state"] - values = self.value_head(last_hidden_states)[:, :-1] - value = values.mean(dim=1).squeeze(1) # ensure shape is (B) - return value + sequence_lengths = torch.max(attention_mask * torch.arange(sequences.size(1), device=sequences.device), dim=1)[ + 0 + ] + sequence_hidden_states = last_hidden_states[torch.arange(last_hidden_states.size(0)), sequence_lengths] + values = self.value_head(sequence_hidden_states).squeeze(1) # ensure shape is (B, ) + return values diff --git a/applications/Chat/coati/models/generation.py b/applications/Chat/coati/models/generation.py index e3afac88c7a7..4ab0cdc8a3ea 100644 --- a/applications/Chat/coati/models/generation.py +++ b/applications/Chat/coati/models/generation.py @@ -2,6 +2,7 @@ import torch import torch.distributed as dist +from transformers import PreTrainedTokenizer from .base import Actor @@ -63,8 +64,8 @@ def _sample( ) outputs = model(**model_inputs) + # NOTE: this is correct only in left padding mode next_token_logits = outputs["logits"][:, -1, :] - # pre-process distribution next_token_logits = logits_processor(input_ids, next_token_logits) # sample probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float) @@ -72,8 +73,7 @@ def _sample( # finished sentences should have their next token be a padding token if eos_token_id is not None: - if pad_token_id is None: - raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + assert pad_token_id is not None, "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." 
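The `Critic` and `RewardModel` forwards rewritten above stop averaging values over all positions and instead read the hidden state at the last non-padding token, located with an `attention_mask * arange` trick, before projecting it through the value head. A standalone sketch of that indexing with made-up shapes:

```python
import torch

batch, seq_len, hidden = 2, 6, 4
last_hidden_states = torch.randn(batch, seq_len, hidden)
# 1 for real tokens, 0 for padding. With left padding the arange-max lands on
# the final position, which is also the last real token, so the trick still works.
attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 0]])

# Position index of the last non-pad token in every row.
positions = torch.arange(seq_len)
sequence_lengths = torch.max(attention_mask * positions, dim=1)[0]  # -> [2, 4]

# Gather one hidden vector per sequence and project it to a scalar value.
value_head = torch.nn.Linear(hidden, 1)
picked = last_hidden_states[torch.arange(batch), sequence_lengths]  # (batch, hidden)
values = value_head(picked).squeeze(1)                              # (batch,)
```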
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) # update generated ids, model inputs for next step @@ -96,12 +96,11 @@ def _sample( def generate( model: Actor, input_ids: torch.Tensor, + tokenizer: PreTrainedTokenizer, max_length: int, num_beams: int = 1, do_sample: bool = True, early_stopping: bool = False, - eos_token_id: Optional[int] = None, - pad_token_id: Optional[int] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None, @@ -118,14 +117,13 @@ def generate( num_beams (int, optional): number of beams. Defaults to 1. do_sample (bool, optional): whether to do sample. Defaults to True. early_stopping (bool, optional): if True, the sequence length may be smaller than max_length due to finding eos. Defaults to False. - eos_token_id (Optional[int], optional): end of sequence token id. Defaults to None. - pad_token_id (Optional[int], optional): pad token id. Defaults to None. top_k (Optional[int], optional): the number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None. top_p (Optional[float], optional): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to None. temperature (Optional[float], optional): The value used to module the next token probabilities. Defaults to None. prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None. update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None. """ + assert tokenizer.padding_side == "left", "Current generation only supports left padding." is_greedy_gen_mode = (num_beams == 1) and do_sample is False is_sample_gen_mode = (num_beams == 1) and do_sample is True is_beam_gen_mode = (num_beams > 1) and do_sample is False @@ -139,8 +137,8 @@ def generate( input_ids, max_length, early_stopping=early_stopping, - eos_token_id=eos_token_id, - pad_token_id=pad_token_id, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, top_k=top_k, top_p=top_p, temperature=temperature, diff --git a/applications/Chat/coati/models/loss.py b/applications/Chat/coati/models/loss.py index 4ad4f4dcd275..687bd0f7bfe7 100644 --- a/applications/Chat/coati/models/loss.py +++ b/applications/Chat/coati/models/loss.py @@ -13,6 +13,7 @@ class GPTLMLoss(nn.Module): def __init__(self): super().__init__() + # NOTE: default ignore_index is -100, which is equal to IGNORE_INDEX in sft_dataset.py self.loss = nn.CrossEntropyLoss() def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: diff --git a/applications/Chat/coati/models/utils.py b/applications/Chat/coati/models/utils.py index def6190dd71c..1aaef16620d2 100644 --- a/applications/Chat/coati/models/utils.py +++ b/applications/Chat/coati/models/utils.py @@ -46,18 +46,17 @@ def _log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch. return log_probs_labels.squeeze(-1) -def calc_action_log_probs(output: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor: +def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor: """Calculate action log probs. 
Args: - output (torch.Tensor): Output tensor of Actor.forward. + output (torch.Tensor): Output tensor of Actor.forward.logits. sequences (torch.LongTensor): Input sequences. num_actions (int): Number of actions. Returns: torch.Tensor: Action log probs. """ - logits = output["logits"] log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) return log_probs[:, -num_actions:] diff --git a/applications/Chat/coati/ray/utils.py b/applications/Chat/coati/ray/utils.py index 036dd145dddb..799b2af8f982 100644 --- a/applications/Chat/coati/ray/utils.py +++ b/applications/Chat/coati/ray/utils.py @@ -41,13 +41,13 @@ def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_ra def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_rank=0): if model == "gpt2": - critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True) + critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config) elif model == "bloom": - critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True) + critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config) elif model == "opt": - critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True) + critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config) elif model == "llama": - critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True) + critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config) else: raise ValueError(f'Unsupported reward model "{model}"') return critic diff --git a/applications/Chat/coati/trainer/base.py b/applications/Chat/coati/trainer/base.py index ca450edee0c3..0a41d450d41e 100644 --- a/applications/Chat/coati/trainer/base.py +++ b/applications/Chat/coati/trainer/base.py @@ -7,11 +7,10 @@ from coati.experience_buffer import NaiveExperienceBuffer from coati.experience_maker import Experience from torch.optim import Optimizer -from torch.utils.data import DataLoader from .callbacks import Callback from .strategies import Strategy -from .utils import CycledDataLoader, is_rank_0 +from .utils import is_rank_0 class SLTrainer(ABC): @@ -47,11 +46,11 @@ def _eval(self, epoch): raise NotImplementedError() def _before_fit(self): - self.no_epoch_bar = False + raise NotImplementedError() def fit(self, *args, **kwargs): self._before_fit(*args, **kwargs) - for epoch in tqdm.trange(self.max_epochs, desc="Epochs", disable=not is_rank_0() or self.no_epoch_bar): + for epoch in tqdm.trange(self.max_epochs, desc="Epochs", disable=not is_rank_0()): self._train(epoch) self._eval(epoch) @@ -123,9 +122,9 @@ def _on_learn_batch_start(self) -> None: for callback in self.callbacks: callback.on_learn_batch_start() - def _on_learn_batch_end(self, metrics: dict, experience: Experience) -> None: + def _on_learn_batch_end(self, experience: Experience) -> None: for callback in self.callbacks: - callback.on_learn_batch_end(metrics, experience) + callback.on_learn_batch_end(experience) @abstractmethod def _make_experience(self, collect_step: int): @@ -153,27 +152,26 @@ def _update_phase(self, update_step: int): self._learn(update_step) self._on_learn_epoch_end(update_step) + def _before_fit(self, *args, **kwargs): + raise NotImplementedError() + def fit( self, - prompt_dataloader: DataLoader, - pretrain_dataloader: DataLoader, num_episodes: int, num_collect_steps: int, num_update_steps: int, + *args, + 
**kwargs, ): """ The main training loop of on-policy rl trainers. Args: - prompt_dataloader (DataLoader): the dataloader to use for prompt data - pretrain_dataloader (DataLoader): the dataloader to use for pretrain data num_episodes (int): the number of episodes to train num_collect_steps (int): the number of collect steps per episode num_update_steps (int): the number of update steps per episode """ - self.prompt_dataloader = CycledDataLoader(prompt_dataloader) - self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader) - + self._before_fit(*args, **kwargs) with self._fit_ctx(): for episode in tqdm.trange(num_episodes, desc="Episodes", disable=not is_rank_0()): with self._episode_ctx(episode): diff --git a/applications/Chat/coati/trainer/callbacks/base.py b/applications/Chat/coati/trainer/callbacks/base.py index d5181175b324..c6e30f04885c 100644 --- a/applications/Chat/coati/trainer/callbacks/base.py +++ b/applications/Chat/coati/trainer/callbacks/base.py @@ -35,5 +35,5 @@ def on_learn_epoch_end(self, epoch: int) -> None: def on_learn_batch_start(self) -> None: pass - def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None: + def on_learn_batch_end(self, experience: Experience) -> None: pass diff --git a/applications/Chat/coati/trainer/callbacks/performance_evaluator.py b/applications/Chat/coati/trainer/callbacks/performance_evaluator.py index c2eda92cc165..b286c766c263 100644 --- a/applications/Chat/coati/trainer/callbacks/performance_evaluator.py +++ b/applications/Chat/coati/trainer/callbacks/performance_evaluator.py @@ -137,7 +137,7 @@ def on_learn_batch_start(self) -> None: return self.learn_timer.start() - def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None: + def on_learn_batch_end(self, experience: Experience) -> None: if self.disable: return self.learn_timer.end() diff --git a/applications/Chat/coati/trainer/ppo.py b/applications/Chat/coati/trainer/ppo.py index 6f255a935d91..d6966689885e 100644 --- a/applications/Chat/coati/trainer/ppo.py +++ b/applications/Chat/coati/trainer/ppo.py @@ -1,27 +1,26 @@ -from typing import Dict, List +from typing import Dict, List, Optional -import torch.nn as nn from coati.experience_buffer import NaiveExperienceBuffer from coati.experience_maker import Experience, NaiveExperienceMaker -from coati.models.base import Actor, Critic, get_base_model +from coati.models.base import Actor, Critic, RewardModel, get_base_model from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss from coati.models.utils import calc_action_log_probs -from torch import Tensor from torch.optim import Optimizer -from torch.utils.data import DistributedSampler +from torch.utils.data import DataLoader, DistributedSampler from tqdm import tqdm +from transformers import PreTrainedTokenizerBase from colossalai.utils import get_current_device from .base import OnPolicyTrainer from .callbacks import Callback from .strategies import GeminiStrategy, Strategy -from .utils import is_rank_0, to_device +from .utils import CycledDataLoader, is_rank_0, to_device def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict: - unwrapper_model = strategy.unwrap_model(actor) - hf_model = get_base_model(unwrapper_model) + unwrapped_model = strategy.unwrap_model(actor) + hf_model = get_base_model(unwrapped_model) new_kwargs = {**generate_kwargs} # use huggingface models method directly if "prepare_inputs_fn" not in generate_kwargs and hasattr(hf_model, "prepare_inputs_for_generation"): @@ -41,7 +40,7 @@ 
class PPOTrainer(OnPolicyTrainer): strategy (Strategy): the strategy to use for training actor (Actor): the actor model in ppo algorithm critic (Critic): the critic model in ppo algorithm - reward_model (nn.Module): the reward model in rlhf algorithm to make reward of sentences + reward_model (RewardModel): the reward model in rlhf algorithm to make reward of sentences initial_model (Actor): the initial model in rlhf algorithm to generate reference logics to limit the update of actor actor_optim (Optimizer): the optimizer to use for actor model critic_optim (Optimizer): the optimizer to use for critic model @@ -65,10 +64,11 @@ def __init__( strategy: Strategy, actor: Actor, critic: Critic, - reward_model: nn.Module, + reward_model: RewardModel, initial_model: Actor, actor_optim: Optimizer, critic_optim: Optimizer, + tokenizer: PreTrainedTokenizerBase, kl_coef: float = 0.1, ptx_coef: float = 0.9, train_batch_size: int = 8, @@ -90,11 +90,11 @@ def __init__( super().__init__(strategy, data_buffer, sample_buffer, dataloader_pin_memory, callbacks) self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor) - self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef) - self.offload_inference_models = offload_inference_models + self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, tokenizer, kl_coef) self.actor = actor self.critic = critic + self.tokenizer = tokenizer self.actor_loss_fn = PolicyLoss(eps_clip) self.critic_loss_fn = ValueLoss(value_clip) @@ -104,58 +104,81 @@ def __init__( self.actor_optim = actor_optim self.critic_optim = critic_optim + self.offload_inference_models = offload_inference_models self.device = get_current_device() + def _before_fit( + self, + prompt_dataloader: DataLoader, + pretrain_dataloader: DataLoader, + log_dir: Optional[str] = None, + use_wandb: bool = False, + ): + """ + Args: + prompt_dataloader (DataLoader): the dataloader to use for prompt data + pretrain_dataloader (DataLoader): the dataloader to use for pretrain data + """ + self.prompt_dataloader = CycledDataLoader(prompt_dataloader) + self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader) + + self.writer = None + if use_wandb and is_rank_0(): + assert log_dir is not None, "log_dir must be provided when use_wandb is True" + import wandb + + wandb.init(project="Coati-ppo", sync_tensorboard=True) + if log_dir is not None and is_rank_0(): + import os + import time + + from torch.utils.tensorboard import SummaryWriter + + log_dir = os.path.join(log_dir, "ppo") + log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())) + self.writer = SummaryWriter(log_dir=log_dir) + def _make_experience(self, collect_step: int) -> Experience: prompts = self.prompt_dataloader.next() if self.offload_inference_models: # TODO(ver217): this may be controlled by strategy if they are prepared by strategy self.experience_maker.initial_model.to(self.device) self.experience_maker.reward_model.to(self.device) - if isinstance(prompts, Tensor): - return self.experience_maker.make_experience(prompts, **self.generate_kwargs) - elif isinstance(prompts, dict): - return self.experience_maker.make_experience(**prompts, **self.generate_kwargs) - else: - raise ValueError(f'Unsupported input type "{type(prompts)}"') + assert isinstance(prompts, dict), f'Unsupported input type "{type(prompts)}"' + return self.experience_maker.make_experience(**prompts, **self.generate_kwargs) - def _training_step(self, 
experience: Experience) -> Dict[str, float]: + def _training_step(self, experience: Experience): self.actor.train() self.critic.train() # policy loss - num_actions = experience.action_mask.size(1) - actor_output = self.actor(experience.sequences, attention_mask=experience.attention_mask) - action_log_probs = calc_action_log_probs(actor_output, experience.sequences, num_actions) + num_actions = experience.action_log_probs.size(1) + actor_logits = self.actor(experience.sequences, experience.attention_mask)["logits"] + action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions) actor_loss = self.actor_loss_fn( action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask ) + actor_loss = (1 - self.ptx_coef) * actor_loss + self.strategy.backward(actor_loss, self.actor, self.actor_optim) # ptx loss if self.ptx_coef != 0: batch = self.pretrain_dataloader.next() batch = to_device(batch, self.device) - ptx_log_probs = self.actor(batch["input_ids"], attention_mask=batch["attention_mask"])["logits"] - ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch["labels"]) - actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef) + ptx_log_probs = self.actor(batch["input_ids"], batch["attention_mask"])["logits"] + ptx_loss = self.ptx_coef * self.ptx_loss_fn(ptx_log_probs, batch["labels"]) + self.strategy.backward(ptx_loss, self.actor, self.actor_optim) - self.strategy.backward(actor_loss, self.actor, self.actor_optim) self.strategy.optimizer_step(self.actor_optim) self.actor_optim.zero_grad() # value loss - values = self.critic( - experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask - ) - critic_loss = self.critic_loss_fn( - values, experience.values, experience.reward, action_mask=experience.action_mask - ) + values = self.critic(experience.sequences, attention_mask=experience.attention_mask) + critic_loss = self.critic_loss_fn(values, experience.values, experience.reward) critic_loss = critic_loss * self.vf_coef self.strategy.backward(critic_loss, self.critic, self.critic_optim) self.strategy.optimizer_step(self.critic_optim) self.critic_optim.zero_grad() - return {"reward": experience.reward.mean().item()} - def _learn(self, update_step: int): if self.offload_inference_models: self.experience_maker.initial_model.to("cpu") @@ -166,8 +189,8 @@ def _learn(self, update_step: int): experience = self.data_buffer.sample() self._on_learn_batch_start() experience.to_device(self.device) - metrics = self._training_step(experience) - self._on_learn_batch_end(metrics, experience) + self._training_step(experience) + self._on_learn_batch_end(experience) else: if isinstance(self.dataloader.sampler, DistributedSampler): self.dataloader.sampler.set_epoch(update_step) @@ -175,6 +198,5 @@ def _learn(self, update_step: int): for experience in pbar: self._on_learn_batch_start() experience.to_device(self.device) - metrics = self._training_step(experience) - self._on_learn_batch_end(metrics, experience) - pbar.set_postfix(metrics) + self._training_step(experience) + self._on_learn_batch_end(experience) diff --git a/applications/Chat/coati/trainer/rm.py b/applications/Chat/coati/trainer/rm.py index a5d6974b3238..d7f8c21a5a3d 100644 --- a/applications/Chat/coati/trainer/rm.py +++ b/applications/Chat/coati/trainer/rm.py @@ -1,7 +1,5 @@ -from datetime import datetime -from typing import Callable +from typing import Callable, Optional -import pandas as pd import torch import tqdm from torch.optim import 
Optimizer @@ -40,10 +38,12 @@ def __init__( self.loss_fn = loss_fn self.scheduler = lr_scheduler + self.num_train_step = 0 + def _eval(self, epoch): if self.eval_dataloader is not None: self.model.eval() - dist, on, cnt = 0, 0, 0 + dist, num_correct, num_samples = 0, 0, 0 with torch.no_grad(): for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader: chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device()) @@ -52,27 +52,21 @@ def _eval(self, epoch): r_mask = r_mask.squeeze(1).to(torch.cuda.current_device()) chosen_reward = self.model(chosen_ids, attention_mask=c_mask) reject_reward = self.model(reject_ids, attention_mask=r_mask) - for i in range(len(chosen_reward)): - cnt += 1 - if chosen_reward[i] > reject_reward[i]: - on += 1 + num_samples += chosen_ids.size(0) + num_correct += (chosen_reward > reject_reward).sum().item() dist += (chosen_reward - reject_reward).mean().item() self.dist = dist / len(self.eval_dataloader) - self.acc = on / cnt + self.acc = num_correct / num_samples - if is_rank_0(): - log = pd.DataFrame( - [[(epoch + 1) * len(self.train_dataloader), self.loss.item(), self.dist, self.acc]], - columns=["step", "loss", "dist", "acc"], - ) - log.to_csv("log.csv", mode="a", header=False, index=False) + if self.writer: + self.writer.add_scalar("eval/dist", self.dist, epoch) + self.writer.add_scalar("eval/acc", self.acc, epoch) def _train(self, epoch): self.model.train() step_bar = tqdm.trange( - len(self.train_dataloader), desc="Train step of epoch %d" % epoch, disable=not is_rank_0() + len(self.train_dataloader), desc=f"Epoch {epoch + 1}/{self.max_epochs}", disable=not is_rank_0() ) - cnt = 0 for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader: chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device()) c_mask = c_mask.squeeze(1).to(torch.cuda.current_device()) @@ -80,26 +74,50 @@ def _train(self, epoch): r_mask = r_mask.squeeze(1).to(torch.cuda.current_device()) chosen_reward = self.model(chosen_ids, attention_mask=c_mask) reject_reward = self.model(reject_ids, attention_mask=r_mask) - self.loss = self.loss_fn(chosen_reward, reject_reward) - self.strategy.backward(self.loss, self.model, self.optimizer) + loss = self.loss_fn(chosen_reward, reject_reward) + self.strategy.backward(loss, self.model, self.optimizer) self.strategy.optimizer_step(self.optimizer) self.optimizer.zero_grad() - cnt += 1 - if cnt % 100 == 0: + if self.writer: + self.writer.add_scalar("train/loss", loss.item(), self.num_train_step) + self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step) + self.writer.add_scalar("train/dist", (chosen_reward - reject_reward).mean().item(), self.num_train_step) + self.writer.add_scalar( + "train/acc", (chosen_reward > reject_reward).float().mean().item(), self.num_train_step + ) + self.num_train_step += 1 + if self.num_train_step % 100 == 0: self.scheduler.step() step_bar.update() step_bar.close() - def _before_fit(self, train_dataloader: DataLoader, valid_dataloader: DataLoader, eval_dataloader: DataLoader): + def _before_fit( + self, + train_dataloader: DataLoader, + eval_dataloader: DataLoader, + log_dir: Optional[str] = None, + use_wandb: bool = False, + ): """ Args: train_dataloader (DataLoader): the dataloader to use for training - valid_dataloader (DataLoader): the dataloader to use for validation eval_dataloader (DataLoader): the dataloader to use for evaluation """ - super()._before_fit() - self.datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - self.train_dataloader = 
train_dataloader - self.valid_dataloader = valid_dataloader self.eval_dataloader = eval_dataloader + + self.writer = None + if use_wandb and is_rank_0(): + assert log_dir is not None, "log_dir must be provided when use_wandb is True" + import wandb + + wandb.init(project="Coati-rm", sync_tensorboard=True) + if log_dir is not None and is_rank_0(): + import os + import time + + from torch.utils.tensorboard import SummaryWriter + + log_dir = os.path.join(log_dir, "rm") + log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())) + self.writer = SummaryWriter(log_dir=log_dir) diff --git a/applications/Chat/coati/trainer/sft.py b/applications/Chat/coati/trainer/sft.py index 8deefc2c484e..7d0eeec897e5 100644 --- a/applications/Chat/coati/trainer/sft.py +++ b/applications/Chat/coati/trainer/sft.py @@ -1,10 +1,8 @@ -import time from typing import Optional import torch import torch.distributed as dist import tqdm -import wandb from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader @@ -48,38 +46,34 @@ def __init__( self.accumulation_steps = accumulation_steps self.scheduler = lr_scheduler + self.num_train_step = 0 + self.num_eval_step = 0 + def _train(self, epoch: int): self.model.train() - for batch_id, batch in enumerate(self.train_dataloader): + step_bar = tqdm.trange( + len(self.train_dataloader) // self.accumulation_steps, + desc=f"Epoch {epoch + 1}/{self.max_epochs}", + disable=not is_rank_0(), + ) + for i, batch in enumerate(self.train_dataloader): batch = to_device(batch, torch.cuda.current_device()) - if "attention_mask" in batch: - outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]) - else: - outputs = self.model(batch["input_ids"], labels=batch["labels"]) - - loss = outputs.loss - loss = loss / self.accumulation_steps - - self.strategy.backward(loss, self.model, self.optimizer) - + outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]) + loss = outputs.loss / self.accumulation_steps self.total_loss += loss.item() - + self.strategy.backward(loss, self.model, self.optimizer) # gradient accumulation - if (batch_id + 1) % self.accumulation_steps == 0: + if (i + 1) % self.accumulation_steps == 0: self.strategy.optimizer_step(self.optimizer) self.optimizer.zero_grad() self.scheduler.step() - if is_rank_0() and self.use_wandb: - wandb.log( - { - "loss": self.total_loss / self.accumulation_steps, - "lr": self.scheduler.get_last_lr()[0], - "epoch": epoch, - "batch_id": batch_id, - } - ) + if self.writer: + self.writer.add_scalar("train/loss", self.total_loss, self.num_train_step) + self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step) + self.num_train_step += 1 self.total_loss = 0 - self.step_bar.update() + step_bar.update() + step_bar.close() def _eval(self, epoch: int): if self.eval_dataloader is not None: @@ -91,20 +85,21 @@ def _eval(self, epoch: int): outputs = self.model( batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"] ) - loss = outputs.loss - - loss_sum += loss.item() + loss_sum += outputs.loss.item() num_seen += batch["input_ids"].size(0) - loss_mean = loss_sum / num_seen if dist.get_rank() == 0: self.logger.info(f"Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}") + if self.writer: + self.writer.add_scalar("eval/loss", loss_mean, self.num_eval_step) + self.num_eval_step += 1 def _before_fit( self, train_dataloader: 
DataLoader, eval_dataloader: Optional[DataLoader] = None, logger: Optional[DistributedLogger] = None, + log_dir: Optional[str] = None, use_wandb: bool = False, ): """ @@ -116,15 +111,20 @@ def _before_fit( self.eval_dataloader = eval_dataloader self.logger = logger - self.use_wandb = use_wandb - if use_wandb: - wandb.init(project="Coati", name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) - wandb.watch(self.model) + self.writer = None + if use_wandb and is_rank_0(): + assert log_dir is not None, "log_dir must be provided when use_wandb is True" + import wandb + + wandb.init(project="Coati-sft", sync_tensorboard=True) + if log_dir is not None and is_rank_0(): + import os + import time + + from torch.utils.tensorboard import SummaryWriter + + log_dir = os.path.join(log_dir, "sft") + log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())) + self.writer = SummaryWriter(log_dir=log_dir) self.total_loss = 0 - self.no_epoch_bar = True - self.step_bar = tqdm.trange( - len(self.train_dataloader) // self.accumulation_steps * self.max_epochs, - desc=f"steps", - disable=not is_rank_0(), - ) diff --git a/applications/Chat/coati/trainer/strategies/colossalai.py b/applications/Chat/coati/trainer/strategies/colossalai.py index 4706f9699c91..3018ca43061e 100644 --- a/applications/Chat/coati/trainer/strategies/colossalai.py +++ b/applications/Chat/coati/trainer/strategies/colossalai.py @@ -1,17 +1,13 @@ import warnings from typing import Optional -import torch -import torch.distributed as dist import torch.nn as nn import colossalai from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin -from colossalai.booster.plugin.gemini_plugin import GeminiModel from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel -from colossalai.tensor import ProcessGroup, ShardSpec +from colossalai.lazy.lazy_init import LazyInitContext from colossalai.utils import get_current_device -from colossalai.zero import ColoInitContext from colossalai.zero.gemini.gemini_ddp import GeminiDDP from .ddp import DDPStrategy @@ -65,14 +61,11 @@ def __init__( assert precision in ("fp32", "fp16"), f'Unsupported precision "{precision}"' plugin_initializer = lambda: LowLevelZeroPlugin( - # zero_config stage=stage, precision=precision, - # zero_optim_config reduce_bucket_size_in_m=reduce_bucket_size, overlap_communication=overlap_communication, cpu_offload=(placement_policy == "cpu"), - # optim_config initial_scale=initial_scale, growth_factor=growth_factor, backoff_factor=backoff_factor, @@ -136,7 +129,7 @@ def __init__( self, seed: int = 42, shard_init: bool = False, # only for stage 3 - placement_policy: str = "cuda", + placement_policy: str = "auto", pin_memory: bool = True, # only for stage 3 force_outputs_fp32: bool = False, # only for stage 3 search_range_m: int = 32, # only for stage 3 @@ -153,8 +146,6 @@ def __init__( max_norm: float = 0.0, norm_type: float = 2.0, ) -> None: - assert placement_policy in ("cpu", "cuda"), f'Unsupported placement policy "{placement_policy}"' - # TODO(ver217): support shard_init when using from_pretrained() if shard_init: warnings.warn( @@ -167,8 +158,7 @@ def __init__( # NOTE: dist should be initialized before calling get_current_device() plugin_initializer = lambda: GeminiPlugin( - # gemini_config - device=get_current_device(), + chunk_init_device=get_current_device(), placement_policy=placement_policy, precision="fp16", pin_memory=pin_memory, @@ -177,9 +167,7 @@ def __init__( search_range_m=search_range_m, hidden_dim=hidden_dim, 
min_chunk_size_m=min_chunk_size_m, - # zero_optim_config gpu_margin_mem_ratio=gpu_margin_mem_ratio, - # optim_config initial_scale=initial_scale, growth_factor=growth_factor, backoff_factor=backoff_factor, @@ -200,15 +188,8 @@ def setup_distributed(self) -> None: colossalai.launch_from_torch({}, seed=self.seed) def model_init_context(self): - world_size = dist.get_world_size() - shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None - default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None - return ColoInitContext( - device=get_current_device(), dtype=torch.half, default_pg=shard_pg, default_dist_spec=default_dist_spec - ) + return LazyInitContext(default_device=get_current_device()) def unwrap_model(self, model: nn.Module) -> nn.Module: - assert isinstance(model, GeminiModel) - ddp_model = model.unwrap() - assert isinstance(ddp_model, GeminiDDP) - return ddp_model.module + assert isinstance(model, GeminiDDP) + return model.module diff --git a/applications/Chat/examples/inference.py b/applications/Chat/examples/inference.py index 087c49564e43..62e06bf7b3bb 100644 --- a/applications/Chat/examples/inference.py +++ b/applications/Chat/examples/inference.py @@ -45,9 +45,17 @@ def eval(args): raise ValueError(f'Unsupported model "{args.model}"') actor.eval() + tokenizer.padding_side = "left" input_ids = tokenizer.encode(args.input, return_tensors="pt").to(torch.cuda.current_device()) outputs = generate( - actor, input_ids, max_length=args.max_length, do_sample=True, top_k=50, top_p=0.95, num_return_sequences=1 + actor, + input_ids, + tokenizer=tokenizer, + max_length=args.max_length, + do_sample=True, + top_k=50, + top_p=0.95, + num_return_sequences=1, ) output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True) print(f"[Output]: {''.join(output)}") diff --git a/applications/Chat/examples/requirements.txt b/applications/Chat/examples/requirements.txt index d3ea7b0c8142..a7cfb5da7fe1 100644 --- a/applications/Chat/examples/requirements.txt +++ b/applications/Chat/examples/requirements.txt @@ -1,3 +1,3 @@ pandas>=1.4.1 sentencepiece -colossalai==0.3.1 +colossalai>=0.3.1 diff --git a/applications/Chat/examples/train_prompts.py b/applications/Chat/examples/train_prompts.py index ad688b07a7f2..de2a33263040 100644 --- a/applications/Chat/examples/train_prompts.py +++ b/applications/Chat/examples/train_prompts.py @@ -23,7 +23,7 @@ def main(args): if args.strategy == "ddp": strategy = DDPStrategy() elif args.strategy == "colossalai_gemini": - strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5) + strategy = GeminiStrategy(placement_policy="auto", initial_scale=2**5) elif args.strategy == "colossalai_zero2": strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") else: @@ -65,8 +65,8 @@ def main(args): if args.rm_path is not None: reward_model.load_state_dict(state_dict, strict=False) - initial_model.to(torch.float16).to(torch.cuda.current_device()) - reward_model.to(torch.float16).to(torch.cuda.current_device()) + initial_model.to(torch.bfloat16).to(torch.cuda.current_device()) + reward_model.to(torch.bfloat16).to(torch.cuda.current_device()) if args.model == "gpt2": actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank) @@ -80,13 +80,13 @@ def main(args): raise ValueError(f'Unsupported actor model "{args.model}"') if rm_model_name == "gpt2": - critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) + critic = GPTCritic(pretrained=args.rm_pretrain, 
lora_rank=args.lora_rank) elif rm_model_name == "bloom": - critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) + critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) elif rm_model_name == "opt": - critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) + critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) elif rm_model_name == "llama": - critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) + critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) else: raise ValueError(f'Unsupported reward model "{rm_model_name}"') @@ -94,17 +94,16 @@ def main(args): critic.load_state_dict(state_dict, strict=False) del state_dict - if args.strategy != "colossalai_gemini": - critic.to(torch.float16).to(torch.cuda.current_device()) - actor.to(torch.float16).to(torch.cuda.current_device()) + actor.to(torch.bfloat16).to(torch.cuda.current_device()) + critic.to(torch.bfloat16).to(torch.cuda.current_device()) # configure optimizer if args.strategy.startswith("colossalai"): - actor_optim = HybridAdam(actor.parameters(), lr=1e-7) - critic_optim = HybridAdam(critic.parameters(), lr=1e-7) + actor_optim = HybridAdam(actor.parameters(), lr=args.lr) + critic_optim = HybridAdam(critic.parameters(), lr=args.lr) else: - actor_optim = Adam(actor.parameters(), lr=1e-7) - critic_optim = Adam(critic.parameters(), lr=1e-7) + actor_optim = Adam(actor.parameters(), lr=args.lr) + critic_optim = Adam(critic.parameters(), lr=args.lr) # configure tokenizer if args.model == "gpt2": @@ -126,8 +125,15 @@ def main(args): tokenizer.pad_token = tokenizer.unk_token else: raise ValueError(f'Unsupported model "{args.model}"') - - prompt_dataset = PromptDataset(tokenizer=tokenizer, data_path=args.prompt_dataset, max_datasets_size=16384) + # NOTE: generate() requires padding_side to be "left" + tokenizer.padding_side = "left" + + prompt_dataset = PromptDataset( + tokenizer=tokenizer, + data_path=args.prompt_dataset, + max_datasets_size=args.max_datasets_size, + max_length=args.max_input_len, + ) if dist.is_initialized() and dist.get_world_size() > 1: prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True) else: @@ -137,7 +143,10 @@ def main(args): ) pretrain_dataset = SupervisedDataset( - tokenizer=tokenizer, data_path=args.pretrain_dataset, max_datasets_size=16384, max_length=args.max_input_len + tokenizer=tokenizer, + data_path=args.pretrain_dataset, + max_datasets_size=args.max_datasets_size, + max_length=args.max_input_len, ) if dist.is_initialized() and dist.get_world_size() > 1: pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True) @@ -161,6 +170,7 @@ def main(args): initial_model, actor_optim, critic_optim, + tokenizer=tokenizer, kl_coef=args.kl_coef, ptx_coef=args.ptx_coef, train_batch_size=args.train_batch_size, @@ -169,17 +179,17 @@ def main(args): do_sample=True, temperature=1.0, top_k=50, - pad_token_id=tokenizer.pad_token_id, - eos_token_id=tokenizer.eos_token_id, offload_inference_models=args.strategy != "colossalai_gemini", ) trainer.fit( - prompt_dataloader=prompt_dataloader, - pretrain_dataloader=pretrain_dataloader, num_episodes=args.num_episodes, num_collect_steps=args.num_collect_steps, num_update_steps=args.num_update_steps, + prompt_dataloader=prompt_dataloader, + pretrain_dataloader=pretrain_dataloader, + log_dir=args.log_dir, + use_wandb=args.use_wandb, 
) # save model checkpoint after fitting @@ -195,6 +205,7 @@ def main(args): parser = argparse.ArgumentParser() parser.add_argument("--prompt_dataset", type=str, default=None, help="path to the prompt dataset") parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset") + parser.add_argument("--max_datasets_size", type=int, default=50000) parser.add_argument( "--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], @@ -216,9 +227,12 @@ def main(args): parser.add_argument("--ptx_batch_size", type=int, default=1) parser.add_argument("--experience_batch_size", type=int, default=8) parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument("--lr", type=float, default=1e-7) parser.add_argument("--kl_coef", type=float, default=0.1) parser.add_argument("--ptx_coef", type=float, default=0.9) parser.add_argument("--max_input_len", type=int, default=96) parser.add_argument("--max_seq_len", type=int, default=128) + parser.add_argument("--log_dir", default="logs", type=str) + parser.add_argument("--use_wandb", default=False, action="store_true") args = parser.parse_args() main(args) diff --git a/applications/Chat/examples/train_reward_model.py b/applications/Chat/examples/train_reward_model.py index a07f4b5ca812..c9095b365884 100644 --- a/applications/Chat/examples/train_reward_model.py +++ b/applications/Chat/examples/train_reward_model.py @@ -1,5 +1,4 @@ import argparse -from random import randint import torch import torch.distributed as dist @@ -27,7 +26,7 @@ def train(args): if args.strategy == "ddp": strategy = DDPStrategy() elif args.strategy == "colossalai_gemini": - strategy = GeminiStrategy(placement_policy="cuda") + strategy = GeminiStrategy(placement_policy="auto") elif args.strategy == "colossalai_zero2": strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") else: @@ -46,7 +45,7 @@ def train(args): else: raise ValueError(f'Unsupported model "{args.model}"') - model.to(torch.float16).to(torch.cuda.current_device()) + model.to(torch.bfloat16).to(torch.cuda.current_device()) if args.model_path is not None: state_dict = torch.load(args.model_path) @@ -75,9 +74,9 @@ def train(args): # configure optimizer if args.strategy.startswith("colossalai"): - optim = HybridAdam(model.parameters(), lr=5e-6) + optim = HybridAdam(model.parameters(), lr=args.lr) else: - optim = Adam(model.parameters(), lr=5e-6) + optim = Adam(model.parameters(), lr=args.lr) # configure loss function if args.loss_fn == "log_sig": @@ -93,21 +92,14 @@ def train(args): else: data = load_dataset(args.dataset) - if args.test: - train_data = data["train"].select(range(20)) - eval_data = data["test"].select(range(5)) - else: - train_data = data["train"] - eval_data = data["test"] - valid_data = data["test"].select((randint(0, len(eval_data) - 1) for _ in range(len(eval_data) // 5))) + train_data = data["train"].select(range(min(args.max_datasets_size, len(data["train"])))) + eval_data = data["test"].select(range(min(args.max_datasets_size, len(data["test"])))) if args.dataset == "Dahoas/rm-static": train_dataset = RmStaticDataset(train_data, tokenizer, args.max_len) - valid_dataset = RmStaticDataset(valid_data, tokenizer, args.max_len) eval_dataset = RmStaticDataset(eval_data, tokenizer, args.max_len) elif args.dataset == "Anthropic/hh-rlhf": train_dataset = HhRlhfDataset(train_data, tokenizer, args.max_len) - valid_dataset = HhRlhfDataset(valid_data, tokenizer, args.max_len) eval_dataset = 
HhRlhfDataset(eval_data, tokenizer, args.max_len) else: raise ValueError(f'Unsupported dataset "{args.dataset}"') @@ -121,14 +113,6 @@ def train(args): rank=dist.get_rank(), num_replicas=dist.get_world_size(), ) - valid_sampler = DistributedSampler( - valid_dataset, - shuffle=True, - seed=42, - drop_last=True, - rank=dist.get_rank(), - num_replicas=dist.get_world_size(), - ) eval_sampler = DistributedSampler( eval_dataset, shuffle=True, @@ -139,7 +123,6 @@ def train(args): ) else: train_sampler = None - valid_sampler = None eval_sampler = None train_dataloader = DataLoader( @@ -150,14 +133,6 @@ def train(args): pin_memory=True, ) - valid_dataloader = DataLoader( - valid_dataset, - shuffle=(valid_sampler is None), - sampler=valid_sampler, - batch_size=args.batch_size, - pin_memory=True, - ) - eval_dataloader = DataLoader( eval_dataset, shuffle=(eval_sampler is None), sampler=eval_sampler, batch_size=args.batch_size, pin_memory=True ) @@ -176,7 +151,12 @@ def train(args): max_epochs=args.max_epochs, ) - trainer.fit(train_dataloader=train_dataloader, valid_dataloader=valid_dataloader, eval_dataloader=eval_dataloader) + trainer.fit( + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + log_dir=args.log_dir, + use_wandb=args.use_wandb, + ) # save model checkpoint after fitting on only rank0 strategy.save_model(model, args.save_path, only_rank0=True) # save optimizer checkpoint on all ranks @@ -200,12 +180,15 @@ def train(args): "--dataset", type=str, choices=["Anthropic/hh-rlhf", "Dahoas/rm-static"], default="Dahoas/rm-static" ) parser.add_argument("--subset", type=lambda x: None if x == "None" else x, default=None) + parser.add_argument("--max_datasets_size", type=int, default=1000000) parser.add_argument("--save_path", type=str, default="rm_ckpt") parser.add_argument("--max_epochs", type=int, default=1) parser.add_argument("--batch_size", type=int, default=1) parser.add_argument("--max_len", type=int, default=512) parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument("--lr", type=float, default=9e-6) parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"]) - parser.add_argument("--test", type=bool, default=False) + parser.add_argument("--log_dir", default="logs", type=str) + parser.add_argument("--use_wandb", default=False, action="store_true") args = parser.parse_args() train(args) diff --git a/applications/Chat/examples/train_rm.sh b/applications/Chat/examples/train_rm.sh index cc1b7be2815f..c5ebaf708ddc 100755 --- a/applications/Chat/examples/train_rm.sh +++ b/applications/Chat/examples/train_rm.sh @@ -16,7 +16,10 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() { set_n_least_used_CUDA_VISIBLE_DEVICES 2 torchrun --standalone --nproc_per_node=2 train_reward_model.py \ - --model 'bloom' \ + --pretrain 'gpt2' \ + --model 'gpt2' \ --strategy colossalai_zero2 \ - --loss_fn 'log_sig' \ - --dataset 'Anthropic/hh-rlhf' + --loss_fn 'log_exp' \ + --dataset 'Anthropic/hh-rlhf' \ + --batch_size 16 \ + --max_epochs 10 diff --git a/applications/Chat/examples/train_sft.py b/applications/Chat/examples/train_sft.py index 1729abb86a09..a34661762258 100644 --- a/applications/Chat/examples/train_sft.py +++ b/applications/Chat/examples/train_sft.py @@ -23,7 +23,6 @@ from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import HybridAdam -from colossalai.tensor import ColoParameter def train(args): @@ -31,7 +30,7 @@ def train(args): if args.strategy == "ddp": strategy = 
DDPStrategy() elif args.strategy == "colossalai_gemini": - strategy = GeminiStrategy(placement_policy="cuda") + strategy = GeminiStrategy(placement_policy="auto") elif args.strategy == "colossalai_zero2": strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") elif args.strategy == "colossalai_zero2_cpu": @@ -57,7 +56,7 @@ def train(args): else: raise ValueError(f'Unsupported model "{args.model}"') - model.to(torch.float16).to(torch.cuda.current_device()) + model.to(torch.bfloat16).to(torch.cuda.current_device()) # configure tokenizer if args.model == "gpt2": @@ -84,28 +83,21 @@ def train(args): else: raise ValueError(f'Unsupported model "{args.model}"') - if args.model == "llama" and args.strategy == "colossalai_gemini": - # this is a hack to deal with the resized embedding - # to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatibility - for name, param in model.named_parameters(): - if not isinstance(param, ColoParameter): - sub_module_name = ".".join(name.split(".")[:-1]) - weight_name = name.split(".")[-1] - sub_module = model.get_submodule(sub_module_name) - setattr(sub_module, weight_name, ColoParameter(param)) - # configure optimizer if args.strategy.startswith("colossalai"): optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0) else: optim = Adam(model.parameters(), lr=args.lr) - logger = get_dist_logger() # configure dataset if args.dataset == "yizhongw/self_instruct": train_data = load_dataset(args.dataset, "super_natural_instructions", split="train") eval_data = load_dataset(args.dataset, "super_natural_instructions", split="test") + if args.max_datasets_size is not None: + train_data = train_data.select(range(min(args.max_datasets_size, len(train_data)))) + eval_data = eval_data.select(range(min(args.max_datasets_size, len(eval_data)))) + train_dataset = SFTDataset(train_data, tokenizer, args.max_len) eval_dataset = SFTDataset(eval_data, tokenizer, args.max_len) @@ -176,8 +168,13 @@ def train(args): accumulation_steps=args.accumulation_steps, ) + logger = get_dist_logger() trainer.fit( - train_dataloader=train_dataloader, eval_dataloader=eval_dataloader, logger=logger, use_wandb=args.use_wandb + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + logger=logger, + log_dir=args.log_dir, + use_wandb=args.use_wandb, ) # save model checkpoint after fitting on only rank0 @@ -207,9 +204,9 @@ def train(args): parser.add_argument("--batch_size", type=int, default=4) parser.add_argument("--max_len", type=int, default=512) parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") - parser.add_argument("--log_interval", type=int, default=100, help="how many steps to log") parser.add_argument("--lr", type=float, default=5e-6) parser.add_argument("--accumulation_steps", type=int, default=8) + parser.add_argument("--log_dir", default="logs", type=str) parser.add_argument("--use_wandb", default=False, action="store_true") parser.add_argument("--grad_checkpoint", default=False, action="store_true") args = parser.parse_args() diff --git a/applications/Chat/examples/train_sft.sh b/applications/Chat/examples/train_sft.sh index 1a5cd069011d..0fb4da3d3ce8 100755 --- a/applications/Chat/examples/train_sft.sh +++ b/applications/Chat/examples/train_sft.sh @@ -19,7 +19,6 @@ torchrun --standalone --nproc_per_node=4 train_sft.py \ --pretrain "/path/to/LLaMa-7B/" \ --model 'llama' \ --strategy colossalai_zero2 \ - --log_interval 10 \ --save_path /path/to/Coati-7B \ --dataset /path/to/data.json \ 
--batch_size 4 \ diff --git a/applications/Chat/requirements-test.txt b/applications/Chat/requirements-test.txt index 809fbd4bb86b..adf2cc1bf545 100644 --- a/applications/Chat/requirements-test.txt +++ b/applications/Chat/requirements-test.txt @@ -1,2 +1,2 @@ pytest -colossalai==0.3.1 +colossalai>=0.3.1 diff --git a/applications/Chat/requirements.txt b/applications/Chat/requirements.txt index e5f5ca0932a8..a784ccbe0d3a 100644 --- a/applications/Chat/requirements.txt +++ b/applications/Chat/requirements.txt @@ -2,7 +2,7 @@ transformers>=4.20.1 tqdm datasets loralib -colossalai==0.3.1 +colossalai>=0.3.1 torch<2.0.0, >=1.12.1 langchain tokenizers @@ -11,3 +11,4 @@ sse_starlette wandb sentencepiece gpustat +tensorboard diff --git a/applications/Chat/tests/test_checkpoint.py b/applications/Chat/tests/test_checkpoint.py index e3058be2e67c..9dfaa7c88206 100644 --- a/applications/Chat/tests/test_checkpoint.py +++ b/applications/Chat/tests/test_checkpoint.py @@ -25,8 +25,8 @@ def get_data(batch_size: int, seq_len: int = 10) -> dict: def train_step(strategy: Strategy, actor: GPTActor, actor_optim: HybridAdam, batch_size: int = 8): data = get_data(batch_size) action_mask = torch.ones_like(data["attention_mask"], dtype=torch.bool) - actor_output = actor(data["input_ids"], data["attention_mask"]) - action_log_probs = calc_action_log_probs(actor_output, data["input_ids"], action_mask.size(1)) + actor_logits = actor(data["input_ids"], data["attention_mask"])["logits"] + action_log_probs = calc_action_log_probs(actor_logits, data["input_ids"], action_mask.size(1)) loss = action_log_probs.sum() strategy.backward(loss, actor, actor_optim) strategy.optimizer_step(actor_optim) @@ -36,7 +36,7 @@ def run_test_checkpoint(strategy_name: str, shard: bool): if strategy_name == "ddp": strategy = DDPStrategy() elif strategy_name == "colossalai_gemini": - strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5) + strategy = GeminiStrategy(placement_policy="auto", initial_scale=2**5) elif strategy_name == "colossalai_zero2": strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") else: diff --git a/applications/Chat/tests/test_dataset.py b/applications/Chat/tests/test_dataset.py index 3de2cc528967..ec61bbb13fd7 100644 --- a/applications/Chat/tests/test_dataset.py +++ b/applications/Chat/tests/test_dataset.py @@ -226,7 +226,9 @@ def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size: check_content(input_ids.masked_select(attention_mask), tokenizer, model) assert torch.all(attention_mask) ignore_mask = labels == IGNORE_INDEX - check_content(input_ids.masked_select(ignore_mask), tokenizer, model) + prompt_mask = torch.logical_and(ignore_mask, attention_mask) + check_content(input_ids.masked_select(prompt_mask), tokenizer, model) + assert torch.all(input_ids.masked_select(ignore_mask ^ prompt_mask) == tokenizer.pad_token_id) if __name__ == "__main__": diff --git a/applications/Chat/tests/test_experience.py b/applications/Chat/tests/test_experience.py index d0ea3bbd2ff5..a9591259800d 100644 --- a/applications/Chat/tests/test_experience.py +++ b/applications/Chat/tests/test_experience.py @@ -1,5 +1,5 @@ +import copy import os -from copy import deepcopy import pytest import torch @@ -8,6 +8,7 @@ from coati.experience_maker import NaiveExperienceMaker from coati.models.base import RewardModel from coati.models.gpt import GPTActor, GPTCritic +from coati.trainer.ppo import _set_default_generate_kwargs from coati.trainer.strategies import DDPStrategy, GeminiStrategy from 
coati.trainer.strategies.colossalai import LowLevelZeroStrategy from transformers.models.gpt2.configuration_gpt2 import GPT2Config @@ -42,27 +43,38 @@ def make_and_consume_experience(strategy): elif strategy == "colossalai-zero2": strategy = LowLevelZeroStrategy() elif strategy == "colossalai-gemini": - strategy = GeminiStrategy(placement_policy="cuda") + strategy = GeminiStrategy(placement_policy="static") else: raise ValueError(f'Unsupported strategy "{strategy}"') - actor = GPTActor(config=GPT_CONFIG).cuda() - critic = GPTCritic(config=GPT_CONFIG).cuda() + with strategy.model_init_context(): + actor = GPTActor(config=GPT_CONFIG).cuda() + critic = GPTCritic(config=GPT_CONFIG).cuda() - initial_model = deepcopy(actor) - reward_model = RewardModel(deepcopy(critic.model)).cuda() + initial_model = GPTActor(config=GPT_CONFIG).cuda() + reward_model = RewardModel(model=copy.deepcopy(critic.model)).cuda() - experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model) + actor, critic, initial_model, reward_model = strategy.prepare(actor, critic, initial_model, reward_model) + + class MockTokenizer: + def __init__(self): + self.padding_side = "left" + self.eos_token_id = 0 + self.pad_token_id = 0 + + tokenizer = MockTokenizer() + experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, tokenizer) data_buffer = NaiveExperienceBuffer(SAMPLE_BATCH_SIZE, cpu_offload=False) + generate_kwargs = dict(do_sample=True, max_length=16) + generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor) + # experience of all ranks should be the same for _ in range(2): data = get_data(EXPERIENCE_BATCH_SIZE) assert gather_and_equal(data["input_ids"]) assert gather_and_equal(data["attention_mask"]) - experience = experience_maker.make_experience( - **data, do_sample=True, max_length=16, eos_token_id=50256, pad_token_id=50256 - ) + experience = experience_maker.make_experience(**data, do_sample=True, max_length=16) assert gather_and_equal(experience.sequences) assert gather_and_equal(experience.action_log_probs) assert gather_and_equal(experience.values) @@ -115,4 +127,4 @@ def test_experience(world_size, strategy): if __name__ == "__main__": - test_experience(2, "colossalai") + test_experience(2, "colossalai-zero2") diff --git a/applications/Chat/tests/test_models.py b/applications/Chat/tests/test_models.py index b2551ff5c0de..b2c22ac6a3b9 100644 --- a/applications/Chat/tests/test_models.py +++ b/applications/Chat/tests/test_models.py @@ -14,7 +14,7 @@ from coati.models.lora import LoraLinear, convert_to_lora_module from coati.models.loss import GPTLMLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss from coati.models.opt import OPTRM, OPTActor, OPTCritic -from coati.models.utils import calc_action_log_probs, compute_reward, masked_mean +from coati.models.utils import calc_action_log_probs, masked_mean @pytest.mark.parametrize("batch_size", [4]) @@ -27,7 +27,6 @@ # HACK: skip llama due to long execution time # lambda: LlamaActor(), lambda: OPTActor(), - # lambda: ChatGLMActor(), ], ) @pytest.mark.parametrize( @@ -43,9 +42,16 @@ ], ) def test_generation(actor_maker: Callable[[], Actor], batch_size: int, seq_len: int, generate_kwargs: Dict[str, Any]): + class MockTokenizer: + def __init__(self): + self.padding_side = "left" + self.eos_token_id = 0 + self.pad_token_id = 0 + actor = actor_maker() input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda() - sequences = generate(actor.cuda(), input_ids, **generate_kwargs) + tokenizer = 
MockTokenizer() + sequences = generate(actor.cuda(), input_ids, tokenizer, **generate_kwargs) assert sequences.shape == (batch_size, generate_kwargs["max_length"]) @@ -55,24 +61,12 @@ def test_utils(): assert fn_output.dim() == 0 assert torch.allclose(fn_output, torch.tensor(1.0)) - batch_size = 4 - num_labels = 10 - fn_input = { - "r": torch.ones((batch_size,)), - "kl_coef": 1.0, - "log_probs": torch.randn((batch_size, num_labels)), - "log_probs_base": torch.randn((batch_size, num_labels)), - "action_mask": torch.randint(0, 2, (batch_size, num_labels)), - } - fn_output = compute_reward(**fn_input) - assert fn_output.shape == (batch_size,) - batch_size = 4 seq_len = 32 num_labels = 10 num_actions = 2 fn_input = { - "output": {"logits": torch.randn((batch_size, seq_len, num_labels))}, + "logits": torch.randn((batch_size, seq_len, num_labels)), "sequences": torch.randint(0, num_labels, (batch_size, seq_len)), "num_actions": num_actions, } @@ -135,7 +129,6 @@ def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], b } critic_input = { "sequences": torch.randint(0, 100, (batch_size, seq_len)), - "action_mask": torch.randint(0, 2, (batch_size, seq_len)), "attention_mask": torch.randint(0, 2, (batch_size, seq_len)), } rm_input = { diff --git a/applications/Chat/tests/test_train.sh b/applications/Chat/tests/test_train.sh index c5127c188612..55de269005ed 100755 --- a/applications/Chat/tests/test_train.sh +++ b/applications/Chat/tests/test_train.sh @@ -24,8 +24,8 @@ if [ -z "$SFT_DATASET" ]; then exit 1 fi -if [ -z "$PROMPT_PATH" ]; then - echo "Please set \$PROMPT_PATH to the path to prompts csv." +if [ -z "$PROMPT_DATASET" ]; then + echo "Please set \$PROMPT_DATASET to the path to prompts csv." exit 1 fi @@ -74,11 +74,15 @@ echo "[Test]: testing sft ..." # FIXME: This is a hack to skip tests that are not working # - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation # - llama-*: These tests can be passed locally, skipped for long execution time +# - *-gemini: Gemini plugin does not support `from_pretrained` yet SKIPPED_TESTS=( "gpt2-ddp" "llama-ddp" "llama-colossalai_gemini" "llama-colossalai_zero2" + "gpt2-colossalai_gemini" + "opt-colossalai_gemini" + "bloom-colossalai_gemini" ) GRAD_CKPTS=('' '--grad_checkpoint') @@ -105,7 +109,7 @@ for lora_rank in '0' '4'; do $pretrain_model --tokenizer $MODELS_DIR/$model \ --model $model --strategy $strategy --lora_rank $lora_rank $grad_ckpt \ --dataset $SFT_DATASET --max_datasets_size 8 \ - --max_epochs 1 --batch_size 1 --accumulation_steps 1 \ + --max_epochs 1 --batch_size 1 --accumulation_steps 1 --lr 1e-8 \ --save_path $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank} passed=$? if [ $passed -eq 0 ]; then @@ -125,11 +129,15 @@ echo "[Test]: testing reward model ..." 
# FIXME: This is a hack to skip tests that are not working # - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation # - llama-*: These tests can be passed locally, skipped for long execution time +# - *-gemini: Gemini plugin does not support `from_pretrained` yet SKIPPED_TESTS=( "gpt2-ddp" "llama-ddp" "llama-colossalai_gemini" "llama-colossalai_zero2" + "gpt2-colossalai_gemini" + "opt-colossalai_gemini" + "bloom-colossalai_gemini" ) LOSS_FNS=('log_sig' 'log_exp') @@ -157,8 +165,9 @@ for lora_rank in '0' '4'; do echo "[Test]: $model-$strategy-$lora_rank, attempt $i" torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_reward_model.py \ $pretrain_model --tokenizer $MODELS_DIR/$model \ - --model $model --strategy $strategy --lora_rank $lora_rank --loss_fn $loss_fn \ - --dataset $dataset --subset $subset --test True --batch_size 1 \ + --dataset $dataset --subset $subset --max_datasets_size 8 \ + --model $model --strategy $strategy --lora_rank $lora_rank \ + --loss_fn $loss_fn --batch_size 1 --lr 1e-8 \ --save_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt passed=$? if [ $passed -eq 0 ]; then @@ -178,11 +187,15 @@ echo "[Test]: testing RLHF ..." # FIXME: This is a hack to skip tests that are not working # - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation # - llama-*: These tests can be passed locally, skipped for long execution time +# - *-gemini: Gemini plugin does not support `from_pretrained` yet SKIPPED_TESTS=( "gpt2-ddp" "llama-ddp" "llama-colossalai_gemini" "llama-colossalai_zero2" + "gpt2-colossalai_gemini" + "opt-colossalai_gemini" + "bloom-colossalai_gemini" ) for model in ${MODELS[@]}; do @@ -204,9 +217,9 @@ for model in ${MODELS[@]}; do for i in $(seq $NUM_RETRY); do echo "[Test]: $model-$strategy-$lora_rank, attempt $i" torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_prompts.py \ - --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \ + --prompt_dataset $PROMPT_DATASET --pretrain_dataset $PRETRAIN_DATASET --max_datasets_size 32 \ --strategy $strategy --model $model --tokenizer $MODELS_DIR/$model \ - --num_episodes 1 --num_collect_steps 1 --num_update_steps 1 \ + --num_episodes 1 --num_collect_steps 1 --num_update_steps 1 --lr 1e-8 \ --experience_batch_size 2 --train_batch_size 1 --lora_rank $lora_rank \ --pretrain $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank} \ $rm_pretrain_model --rm_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt \ diff --git a/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py index 521527da51e0..e811e1acbf7e 100644 --- a/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py +++ b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py @@ -3,6 +3,7 @@ import pytest import torch +from model_zoo import GPTLMLoss, get_gpt2_components from torch.utils._pytree import tree_map import colossalai @@ -13,7 +14,6 @@ from colossalai.nn.optimizer import HybridAdam from colossalai.testing import spawn from colossalai.utils import get_current_device -from model_zoo import GPTLMLoss, get_gpt2_components def parse_args(): diff --git a/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py b/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py index 17692e90a03c..09bbae9c5b74 100644 --- 
a/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py +++ b/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py @@ -3,6 +3,7 @@ from functools import partial import torch +from model_zoo import model_builder from torch import nn from colossalai.fx import ColoTracer @@ -12,7 +13,6 @@ from colossalai.legacy.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine from colossalai.legacy.pipeline.rpc.utils import rpc_run from colossalai.logging import disable_existing_loggers, get_dist_logger -from model_zoo import model_builder def parse_args(): From c0a033700c7027286ecd8a7bcbaafc6f794323ad Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Wed, 20 Sep 2023 18:29:37 +0800 Subject: [PATCH 30/58] [shardformer] fix master param sync for hybrid plugin/rewrite unwrapping logic (#4758) * fix master param sync for hybrid plugin * rewrite unwrap for ddp/fsdp * rewrite unwrap for zero/gemini * rewrite unwrap for hybrid plugin * fix geemini unwrap * fix bugs --- .../naive_amp/mixed_precision_optimizer.py | 17 +++- colossalai/booster/booster.py | 2 +- colossalai/booster/plugin/gemini_plugin.py | 23 +++-- .../booster/plugin/hybrid_parallel_plugin.py | 22 +++-- .../booster/plugin/low_level_zero_plugin.py | 46 +++------- colossalai/booster/plugin/torch_ddp_plugin.py | 63 +++++++++++--- .../booster/plugin/torch_fsdp_plugin.py | 16 ++-- .../checkpoint_io/checkpoint_io_base.py | 6 -- .../checkpoint_io/general_checkpoint_io.py | 11 --- .../hybrid_parallel_checkpoint_io.py | 83 +++++-------------- colossalai/checkpoint_io/utils.py | 9 -- colossalai/zero/gemini/gemini_ddp.py | 4 - colossalai/zero/low_level/low_level_optim.py | 6 ++ .../test_plugins_huggingface_compatibility.py | 4 +- 14 files changed, 141 insertions(+), 171 deletions(-) diff --git a/colossalai/amp/naive_amp/mixed_precision_optimizer.py b/colossalai/amp/naive_amp/mixed_precision_optimizer.py index 6a192cc5cb83..501a843f6992 100644 --- a/colossalai/amp/naive_amp/mixed_precision_optimizer.py +++ b/colossalai/amp/naive_amp/mixed_precision_optimizer.py @@ -2,7 +2,7 @@ import torch from torch import Tensor -from torch.nn import Parameter +from torch.nn import Module, Parameter from torch.optim import Optimizer from colossalai.interface import OptimizerWrapper @@ -152,3 +152,18 @@ def step(self, *args, **kwargs): if p is working_param: continue working_param.data.copy_(p.data) + + def update_master_params(self, model: Module): + # Update master params from working params + with torch.no_grad(): + for p in model.parameters(): + if (p is None) or (p not in self.working_to_master_map): + continue + master_param = self.working_to_master_map[p] + master_param.data.copy_(p.data) + + def get_working_to_master_map(self) -> Dict[int, torch.Tensor]: + return {id(working_p): master_p for working_p, master_p in self.working_to_master_map.items()} + + def get_master_to_working_map(self) -> Dict[int, torch.Tensor]: + return {id(master_p): working_p for master_p, working_p in self.master_to_working_map.items()} diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py index 2aee72cbf2f1..8d6b0b42e545 100644 --- a/colossalai/booster/booster.py +++ b/colossalai/booster/booster.py @@ -139,7 +139,7 @@ def boost( if self.plugin and not self.plugin.control_device(): # transform model for accelerator - model = self.accelerator.configure(model) + model = self.accelerator.configure_model(model) if self.mixed_precision and (self.plugin is None or self.plugin and not self.plugin.control_precision()): # transform model for 
mixed precision diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 83a00d4ee229..abf3a907b777 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -44,6 +44,7 @@ def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor The model should be unwrapped in self.load_model via ModelWrapper.unwrap. As there is communication when getting state dict, model.state_dict() must be called on all processes. """ + assert isinstance(model, GeminiDDP), "Please boost the model before saving!" state_dict = model.state_dict(only_rank_0=True) if self.coordinator.is_master(): save_state_dict(state_dict, checkpoint, use_safetensors) @@ -53,24 +54,27 @@ def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = Load model from checkpoint with automatic unwrapping. The model should be unwrapped in self.load_model via ModelWrapper.unwrap. """ + assert isinstance(model, GeminiDDP), "Please boost the model before loading!" super().load_unsharded_model(model, checkpoint, strict=strict) - def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): + def save_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool): """ Save unsharded optimizer state dict to checkpoint. After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank. As there is communication when getting state dict, optimizer.state_dict() must be called on all processes. The saving process will only be executed by master rank. """ + assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!" state_dict = optimizer.state_dict() if self.coordinator.is_master(): save_state_dict(state_dict, checkpoint, use_safetensors=False) - def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str): + def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str): """ Loading unsharded optimizer from checkpoint file. For each process, only loading optimizer states of parameters it controls. """ + assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!" super().load_unsharded_optimizer(optimizer, checkpoint) def save_sharded_model( @@ -86,6 +90,7 @@ def save_sharded_model( Save sharded model. As there is communication when getting state dict, model.state_dict() must be called on all processes. """ + assert isinstance(model, GeminiDDP), "Please boost the model before saving!" if os.path.isfile(checkpoint_path): logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file") return @@ -111,7 +116,7 @@ def save_sharded_model( if self.coordinator.is_master(): index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) - save_config_file(model.module, checkpoint_path) + save_config_file(model.unwrap(), checkpoint_path) logging.info( f"The model is split into checkpoint shards. " f"You can find where each parameters has been saved in the " @@ -124,17 +129,17 @@ def load_sharded_model( """ Load shard model, load model from multiple files. """ + assert isinstance(model, GeminiDDP), "Please boost the model before loading!" 
return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False) def save_sharded_optimizer( - self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int + self, optimizer: GeminiOptimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int ): """ Save sharded optimizer state dict to checkpoint folder. As there is communication when getting state dict, this must be called on all processes. """ - - assert isinstance(optimizer, GeminiOptimizer) + assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!" if os.path.isfile(checkpoint): logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") @@ -176,12 +181,12 @@ def save_sharded_optimizer( f"index located at {save_index_file}." ) - def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Path, prefix: str): + def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str): """ Loading sharded optimizer from checkpoint folder, with index file given. For each process, only loading optimizer states of parameters it controls. """ - + assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!" if not os.path.isfile(checkpoint_index_file): logging.error(f"Provided path ({checkpoint_index_file}) should be a file") @@ -383,7 +388,7 @@ def configure( if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): optimizer = GeminiOptimizer( - optimizer, model.unwrap(), **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose + optimizer, model, **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose ) return model, optimizer, criterion, dataloader, lr_scheduler diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index c1693fa8d3a1..46930887bf9c 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1,6 +1,7 @@ import random from contextlib import nullcontext from functools import partial +from types import MethodType from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union import numpy as np @@ -165,6 +166,15 @@ def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_in init_pipeline_optimizer(optim, model) super().__init__(optim) + def update_master_params(self, model: Module): + pass + + def get_working_to_master_map(self): + return None + + def get_master_to_working_map(self): + return None + class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): def __init__( @@ -466,9 +476,6 @@ def configure( max_norm=self.max_norm, **self.amp_config, ) - self.checkpoint_io.link_master_and_working_param( - optimizer.working_to_master_map, optimizer.master_to_working_map - ) else: optimizer = HybridParallelNaiveOptimizer( optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info @@ -488,10 +495,8 @@ def configure( **self.zero_config, **self.amp_config, ) - self.checkpoint_io.link_master_and_working_param( - optimizer._param_store.working_to_master_param, optimizer._param_store.master_to_working_param - ) - + # inject update_master_params + model.update_master_params = MethodType(optimizer.update_master_params, model) return model, optimizer, criterion, dataloader, lr_scheduler def execute_pipeline( @@ -567,8 +572,7 @@ def seed_worker(worker_id): ) def 
get_checkpoint_io(self) -> CheckpointIO: - self.checkpoint_io = HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) - return self.checkpoint_io + return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) def no_sync(self, model: Module) -> Iterator[None]: raise NotImplementedError diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 86adee7fe226..457c720f6418 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -22,7 +22,6 @@ save_param_groups, save_state_dict, sharded_optimizer_loading_epilogue, - unwrap_optimizer, ) from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper from colossalai.utils import get_current_device @@ -65,10 +64,6 @@ def forward(self, *args, **kwargs): kwargs = tree_map(self.convert_fn, kwargs) return super().forward(*args, **kwargs) - def unwrap(self): - # TODO(ver217): this is a workaround for loading model - return self - class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False): @@ -79,7 +74,7 @@ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, checkpoint (str): Path to save checkpoint gather_dtensor (bool): Whether to gather_dtensor, not used """ - + assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!" # the `state_dict` in LowLevelZeroOptimizer has communication # if only the master rank collect state_dict and save, # the communication on each rank would not match @@ -109,6 +104,7 @@ def save_sharded_optimizer( prefix (str): Perfix of file to save size_per_shard (int): Max file size of each file that store state tensors """ + assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!" if os.path.isfile(checkpoint): logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") return @@ -160,9 +156,8 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s index_file_path (str): Path to the index file prefix (str): Not used. """ - # If optimizer is wrapped, unwrap it. - if isinstance(optimizer, OptimizerWrapper): - optimizer = unwrap_optimizer(optimizer) + assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before Loading!" + optimizer = optimizer.unwrap() # Read checkpoint index file. 
ckpt_index_file = CheckpointIndexFile.from_file(index_file_path) @@ -194,44 +189,23 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s v_list = v.split(v.numel() // self.coordinator.world_size) state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone() load_states_into_optimizer(optimizer, state_dict, id_map) - sharded_optimizer_loading_epilogue(optimizer) - def save_unsharded_model( - self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool, use_safetensors: bool - ): - assert isinstance(model, LowLevelZeroModel) - super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors) - - def save_sharded_model( - self, - model: nn.Module, - checkpoint_path: str, - gather_dtensor: bool = True, - prefix: Optional[str] = None, - max_shard_size: int = 1024, - use_safetensors: bool = False, - ): - assert isinstance(model, LowLevelZeroModel) - super().save_sharded_model( - model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors - ) - - def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True): - assert isinstance(model, LowLevelZeroModel) - super().load_unsharded_model(model.module, checkpoint, strict) + def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True): + assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!" + super().load_unsharded_model(model, checkpoint, strict) model.update_master_params() def load_sharded_model( self, - model: LowLevelZeroModel, + model: ModelWrapper, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False, load_sub_module: bool = True, ): - assert isinstance(model, LowLevelZeroModel) - super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module) + assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!" + super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module) model.update_master_params() diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index 30d34e7dd5e5..41d7c0635bf6 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -20,24 +20,33 @@ def __init__(self) -> None: super().__init__() self.coordinator = DistCoordinator() - def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True): + def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True): """ - Load model from checkpoint with automatic unwrapping. + Load model from checkpoint. """ - # the model should be unwrapped in self.load_model via ModelWrapper.unwrap - return super().load_unsharded_model(model, checkpoint, strict=strict) + assert isinstance(model, ModelWrapper), "Please boost the model before loading!" + super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict) - def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): """ Save model to checkpoint but only on master process. """ + assert isinstance(model, ModelWrapper), "Please boost the model before saving!" 
if self.coordinator.is_master(): - super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors) + super().save_unsharded_model(model.unwrap(), checkpoint, gather_dtensor, use_safetensors) - def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): + def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str): + """ + Load optimizer from checkpoint. + """ + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" + super().load_unsharded_optimizer(optimizer, checkpoint) + + def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): """ Save optimizer to checkpoint but only on master process. """ + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" if self.coordinator.is_master(): super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor) @@ -50,7 +59,7 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): def save_sharded_model( self, - model: nn.Module, + model: ModelWrapper, checkpoint_path: str, gather_dtensor: bool = True, prefix: Optional[str] = None, @@ -60,22 +69,52 @@ def save_sharded_model( """ Save model to checkpoint but only on master process. """ + assert isinstance(model, ModelWrapper), "Please boost the model before saving!" if self.coordinator.is_master(): - super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors) + super().save_sharded_model( + model.unwrap(), checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors + ) + + def load_sharded_model( + self, + model: ModelWrapper, + checkpoint_index_file: str, + strict: bool = False, + use_safetensors: bool = False, + load_sub_module: bool = True, + ): + """ + Load model from sharded checkpoint. + """ + assert isinstance(model, ModelWrapper), "Please boost the model before loading!" + super().load_sharded_model(model.unwrap(), checkpoint_index_file, strict, use_safetensors, load_sub_module) def save_sharded_optimizer( self, - optimizer: Optimizer, + optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = True, prefix: Optional[str] = None, size_per_shard: int = 1024, ): """ - Save optimizer to checkpoint but only on master process. + Save optimizer to sharded checkpoint but only on master process. """ + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" if self.coordinator.is_master(): - super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard) + super().save_sharded_optimizer(optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard) + + def load_sharded_optimizer( + self, + optimizer: Optimizer, + index_file_path: str, + prefix: Optional[str] = None, + ): + """ + Load optimizer from sharded checkpoint. + """ + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" 
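The recurring "Please boost ... before saving/loading!" asserts encode a single calling convention: checkpoint IO now receives the wrapped objects returned by `Booster.boost` and unwraps them itself. A usage sketch under that assumption (must run under a distributed launcher such as `colossalai run` or `torchrun`; the model, optimizer, and checkpoint path are placeholders):

```python
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch(config={})
model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

booster = Booster(plugin=LowLevelZeroPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)  # returns wrapped model/optimizer

# Pass the boosted objects straight back to the booster for checkpointing;
# handing over the raw nn.Module / Optimizer would now trip the new asserts.
booster.save_optimizer(optimizer, "optim_ckpt", shard=True)
booster.load_optimizer(optimizer, "optim_ckpt")
```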
+ super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix) class TorchDDPModel(ModelWrapper): diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index d12b784b4fc1..1e3762b79016 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -39,31 +39,35 @@ def __init__(self) -> None: super().__init__() self.coordinator = DistCoordinator() - def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool): + def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool): + assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!" + model = model.unwrap() checkpoint = utils.load_state_dict(checkpoint) model.load_state_dict(checkpoint) - def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): + def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path): + assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!" checkpoint = utils.load_state_dict(checkpoint) fsdp_model = optimizer.unwrap_model() sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model) optimizer.load_state_dict(sharded_osd) - def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): """ Save model to checkpoint but only on master process. """ - # the model should be unwrapped in self.load_model via ModelWrapper.unwrap + assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!" + model = model.unwrap() cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg): full_model_state = model.state_dict() utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors) - def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): + def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): """ Save optimizer to checkpoint but only on master process. """ - assert isinstance(optimizer, FSDPOptimizerWrapper) + assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!" fsdp_model = optimizer.unwrap_model() full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True) utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False) diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index f8ce8f4e5210..780117598e18 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -87,9 +87,6 @@ def load_model( # return the origin model instead of the unwrapped model origin_model = model - if isinstance(model, ModelWrapper): - model = model.unwrap() - if index_file_exists: self.load_sharded_model(model, index_file_path, strict) else: @@ -134,9 +131,6 @@ def save_model( use_safetensors (bool): whether to use safe tensors. Default: False. 
If set to True, the checkpoint will be saved """ - if isinstance(model, ModelWrapper): - model = model.unwrap() - if shard: self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors) else: diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index b0e593e90d8c..a652d9b4538e 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -8,8 +8,6 @@ import torch.nn as nn from torch.optim import Optimizer -from colossalai.interface import OptimizerWrapper - from .checkpoint_io_base import CheckpointIO from .index_file import CheckpointIndexFile from .utils import ( @@ -28,7 +26,6 @@ shard_model_checkpoint, shard_optimizer_checkpoint, sharded_optimizer_loading_epilogue, - unwrap_optimizer, ) __all__ = ["GeneralCheckpointIO"] @@ -58,10 +55,6 @@ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, pre Load sharded optimizer with the given path to index file. """ - # If optimizer is wrapped, unwrap it. - if isinstance(optimizer, OptimizerWrapper): - optimizer = unwrap_optimizer(optimizer) - # Read checkpoint index file. ckpt_index_file = CheckpointIndexFile.from_file(index_file_path) @@ -98,10 +91,6 @@ def save_sharded_optimizer( - Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way """ - # If optimizer is wrapped, unwrap it. - if isinstance(optimizer, OptimizerWrapper): - optimizer = unwrap_optimizer(optimizer) - if os.path.isfile(checkpoint): logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") return diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 18c59a880dd6..41e53b3b388f 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -3,7 +3,7 @@ import os from pathlib import Path from shutil import rmtree -from typing import Dict, Iterator, Optional, OrderedDict, Tuple, Union +from typing import Dict, Iterator, Optional, OrderedDict, Tuple import torch import torch.distributed as dist @@ -13,7 +13,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from colossalai.cluster import DistCoordinator -from colossalai.interface import OptimizerWrapper +from colossalai.interface import ModelWrapper, OptimizerWrapper from .general_checkpoint_io import GeneralCheckpointIO from .index_file import CheckpointIndexFile @@ -71,8 +71,6 @@ def __init__( self.tp_size = dist.get_world_size(tp_group) self.use_zero = zero_stage > 0 self.verbose = verbose - self.working_to_master_map = None - self.master_to_working_map = None self.coordinator = DistCoordinator() @staticmethod @@ -159,7 +157,7 @@ def _optimizer_sharder( def save_sharded_model( self, - model: nn.Module, + model: ModelWrapper, checkpoint: str, gather_dtensor: bool = True, prefix: Optional[str] = None, @@ -184,6 +182,9 @@ def save_sharded_model( use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. """ + assert isinstance(model, ModelWrapper), "Please boost the model before saving!" + model = model.unwrap() + if os.path.isfile(checkpoint): logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") return @@ -279,7 +280,7 @@ def save_sharded_model( f"index located at {final_index_file_path}." 
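With the unwrapping removed from `CheckpointIO.load_model`/`save_model`, the base class now forwards whatever object it is given, and each plugin-specific implementation performs its own assert-and-unwrap. A condensed sketch of that division of labour, with hypothetical class names standing in for `ModelWrapper` and the plugin checkpoint IO classes:

```python
import torch
import torch.nn as nn

class WrappedModelSketch(nn.Module):
    """Toy stand-in for colossalai.interface.ModelWrapper."""
    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def unwrap(self) -> nn.Module:
        return self.module

class BaseCheckpointIOSketch:
    def save_model(self, model: nn.Module, checkpoint: str) -> None:
        # previously the base class unwrapped ModelWrapper here; now it forwards as-is
        self.save_unsharded_model(model, checkpoint)

    def save_unsharded_model(self, model: nn.Module, checkpoint: str) -> None:
        raise NotImplementedError

class PluginCheckpointIOSketch(BaseCheckpointIOSketch):
    def save_unsharded_model(self, model: nn.Module, checkpoint: str) -> None:
        # each plugin asserts it received a boosted model and unwraps it itself
        assert isinstance(model, WrappedModelSketch), "Please boost the model before saving!"
        torch.save(model.unwrap().state_dict(), checkpoint)
```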
) - def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False): + def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, strict: bool = False): """ Load sharded model with the given path to index file of checkpoint folder. @@ -289,6 +290,9 @@ def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, stri strict (bool, optional): For name matching during loading state_dict. Defaults to False. This argument should be manually set to False since params on same device might be stored in different files. """ + assert isinstance(model, ModelWrapper), "Please boost the model before loading!" + model_before_wrapping = model # backup for model before wrapping + model = model.unwrap() # Check whether the checkpoint uses safetensors. use_safetensors = False @@ -347,23 +351,7 @@ def _load(name: str): _load(extra_state_key) # Update master params if mixed-precision training is enabled. - with torch.no_grad(): - if self.working_to_master_map is not None: - for param in model.parameters(): - if (param is None) or (id(param) not in self.working_to_master_map): - continue - master_param = self.working_to_master_map[id(param)] - if self.use_zero: - # master_param is sharded under Zero setting - padding_size = (self.dp_size - param.numel() % self.dp_size) % self.dp_size - if padding_size > 0: - padded_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size]) - else: - padded_param = param.data.view(-1) - sharded_param = padded_param.split(padded_param.numel() // self.dp_size)[self.dp_rank] - master_param.data.copy_(sharded_param.data) - else: - master_param.data.copy_(param.data) + model_before_wrapping.update_master_params() if self.verbose: logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") @@ -392,6 +380,7 @@ def save_sharded_optimizer( prefix (str): Perfix of file to save size_per_shard (int): Max file size of each file shard that store state tensors """ + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" if os.path.isfile(checkpoint): logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") return @@ -410,7 +399,7 @@ def save_sharded_optimizer( use_zero=self.use_zero, dp_group=self.dp_group, tp_group=self.tp_group, - master_to_working_map=self.master_to_working_map, + master_to_working_map=optimizer.get_master_to_working_map(), size_per_shard=size_per_shard, ) states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) @@ -511,6 +500,7 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_f checkpoint_index_file (str): Path to the index file of checkpointing folder. prefix (str): Not used. """ + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" def _get_param_id_from_optimizer_param( param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None @@ -525,9 +515,10 @@ def _get_param_id_from_optimizer_param( # When Zero is used, the mapped parameter objects should be fp32 master parameters. # IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info. 
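`model_before_wrapping.update_master_params()` replaces the manual copy loop deleted above. Conceptually, after the working (fp16/bf16) weights are loaded from the checkpoint, each rank refreshes its local fp32 master shard. A simplified sketch of that refresh, assuming a plain dict that maps each working parameter to its ZeRO-sharded master copy:

```python
import torch

def update_master_params(working_to_master: dict, world_size: int, local_rank: int) -> None:
    # After loading a checkpoint into the working weights, refresh the fp32 master shards.
    for working_param, master_shard in working_to_master.items():
        flat = working_param.data.view(-1)
        pad = (world_size - flat.numel() % world_size) % world_size
        if pad > 0:
            flat = torch.nn.functional.pad(flat, [0, pad])
        # each rank keeps only its own slice as the master copy (ZeRO-style sharding)
        master_shard.copy_(flat.chunk(world_size)[local_rank])
```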
id_map = {} + master_to_working_map = optimizer.get_master_to_working_map() for pg in optimizer.optim.param_groups: for param in pg["params"]: - param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map) + param_id = _get_param_id_from_optimizer_param(param, master_to_working_map) id_map[param_id] = param # Read checkpoint index file. @@ -560,7 +551,7 @@ def _get_param_id_from_optimizer_param( for param in pg["params"]: if param is None: continue - param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map) + param_id = _get_param_id_from_optimizer_param(param, master_to_working_map) if param_id not in weight_map: continue filename = weight_map[param_id] @@ -577,8 +568,8 @@ def _get_param_id_from_optimizer_param( # Then shard the loaded optimizer states if using tp/zero. for param, state in optimizer.optim.state.items(): device = param.device - if self.master_to_working_map is not None: - working_param = self.master_to_working_map[id(param)] + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] else: working_param = param original_shape = optimizer.param_info["param2shape"][id(working_param)] @@ -614,42 +605,6 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): if self.coordinator.is_master(): super().save_lr_scheduler(lr_scheduler, checkpoint) - def link_master_and_working_param( - self, - working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor], - master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor], - ): - """ - Create mappings between working params (for forward/backward) and master params (for optimizer update) with passed in mappings. - This mapping can only be created when mixied precision is used. - The created mappings should be mappings from integer parameter addresses to parameter objects. - - Args: - working_to_master_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from working parameters objects/addresses to master parameter objects. - master_to_working_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from master parameters objects/addresses to working parameter objects. - """ - self.working_to_master_map = dict() - for k, v in working_to_master_map.items(): - if isinstance(k, torch.Tensor): - self.working_to_master_map[id(k)] = v - elif isinstance(k, int): - self.working_to_master_map[k] = v - else: - raise ValueError( - f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!" - ) - - self.master_to_working_map = dict() - for k, v in master_to_working_map.items(): - if isinstance(k, torch.Tensor): - self.master_to_working_map[id(k)] = v - elif isinstance(k, int): - self.master_to_working_map[k] = v - else: - raise ValueError( - f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!" 
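Instead of caching `working_to_master_map`/`master_to_working_map` on the checkpoint IO (the `link_master_and_working_param` helper deleted above), the optimizer wrapper now exposes the mapping on demand. A small sketch of how a saver/loader resolves the working parameter for a given master parameter under that convention (class names are illustrative, not the real implementations):

```python
from typing import Dict, Optional

import torch

class NaiveOptimizerSketch:
    """No mixed precision: there are no separate master copies."""
    def get_master_to_working_map(self) -> Optional[Dict[int, torch.Tensor]]:
        return None

class ZeroOptimizerSketch:
    """ZeRO case: delegate to the parameter store that already tracks the mapping."""
    def __init__(self, param_store):
        self._param_store = param_store

    def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
        return self._param_store.master_to_working_param

def resolve_working_param(optimizer, master_param: torch.Tensor) -> torch.Tensor:
    # Mirrors the pattern used when sharding optimizer states for saving/loading:
    # fall back to the master parameter itself when no mapping exists.
    m2w = optimizer.get_master_to_working_map()
    return m2w[id(master_param)] if m2w is not None else master_param
```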
- ) - @staticmethod def gather_from_sharded_optimizer_state( state: OrderedDict, diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index c22b76dd46f7..d2f4a0bcacf8 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -11,7 +11,6 @@ import torch.nn as nn from torch.optim import Optimizer -from colossalai.interface import OptimizerWrapper from colossalai.tensor.d_tensor import ( is_customized_distributed_tensor, is_distributed_tensor, @@ -122,14 +121,6 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz # ====================================== # Helper classes and functions for saving shard file # ====================================== -def unwrap_optimizer(optimizer: OptimizerWrapper): - """ - Unwrap a wrapped optimizer. - This method should be used before saving/loading it to/from sharded checkpoints. - """ - - unwrapped_optim = optimizer.optim - return unwrapped_optim class StateDictSharder: diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 580b497ce719..0ba9e53cfcd6 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -186,10 +186,6 @@ def set_params_to_ignore(params_to_ignore: Iterable[torch.Tensor]) -> None: for p in params_to_ignore: p._ddp_to_ignore = True - def unwrap(self): - # as save/load state dict is overwrited, only return self - return self - def _get_non_persistent_buffers_set( self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True ): diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 1bf5302efcfb..72df93ace302 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -648,3 +648,9 @@ def update_master_params(self, model: nn.Module) -> None: if padding_size > 0: working_param = torch.nn.functional.pad(working_param, [0, padding_size]) master_param.copy_(working_param.chunk(self._world_size)[self._local_rank]) + + def get_working_to_master_map(self) -> Dict[int, torch.Tensor]: + return self._param_store.working_to_master_param + + def get_master_to_working_map(self) -> Dict[int, torch.Tensor]: + return self._param_store.master_to_working_param diff --git a/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py index c3c30e666b10..a6f67e0d7729 100644 --- a/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py +++ b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py @@ -61,9 +61,7 @@ def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) if plugin_type == "gemini": - check_state_dict_equal( - model.unwrap().state_dict(only_rank_0=False), new_model.unwrap().state_dict(only_rank_0=False), False - ) + check_state_dict_equal(model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False) else: check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False) dist.barrier() From df66741f77b0ea3740df64bb5f0eafb36538393f Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 21 Sep 2023 10:42:25 +0800 Subject: [PATCH 31/58] [bug] fix get_default_parser in examples (#4764) --- colossalai/legacy/__init__.py | 10 +++++++++- 
examples/community/roberta/pretraining/arguments.py | 4 ++-- examples/images/vit/args.py | 6 +++--- examples/images/vit/run_benchmark.sh | 6 +++--- examples/images/vit/run_demo.sh | 4 ++-- examples/images/vit/test_ci.sh | 4 ++-- examples/language/gpt/titans/train_gpt.py | 3 ++- examples/language/opt/args.py | 6 +++--- examples/language/opt/run_benchmark.sh | 4 ++-- examples/language/opt/run_demo.sh | 4 ++-- examples/language/opt/test_ci.sh | 4 ++-- examples/language/palm/run.sh | 2 +- examples/language/palm/test_ci.sh | 2 +- examples/language/palm/train.py | 3 ++- 14 files changed, 36 insertions(+), 26 deletions(-) diff --git a/colossalai/legacy/__init__.py b/colossalai/legacy/__init__.py index 4d6ad357a2fa..678a5def5c68 100644 --- a/colossalai/legacy/__init__.py +++ b/colossalai/legacy/__init__.py @@ -1,4 +1,11 @@ -from .initialize import initialize, launch, launch_from_openmpi, launch_from_slurm, launch_from_torch +from .initialize import ( + get_default_parser, + initialize, + launch, + launch_from_openmpi, + launch_from_slurm, + launch_from_torch, +) __all__ = [ "launch", @@ -6,4 +13,5 @@ "launch_from_slurm", "launch_from_torch", "initialize", + "get_default_parser", ] diff --git a/examples/community/roberta/pretraining/arguments.py b/examples/community/roberta/pretraining/arguments.py index 35b809d80947..3428db4cb9c5 100644 --- a/examples/community/roberta/pretraining/arguments.py +++ b/examples/community/roberta/pretraining/arguments.py @@ -1,10 +1,10 @@ -import colossalai +import argparse __all__ = ["parse_args"] def parse_args(): - parser = colossalai.get_default_parser() + parser = argparse.ArgumentParser() parser.add_argument( "--distplan", diff --git a/examples/images/vit/args.py b/examples/images/vit/args.py index 7d54020f85c4..9de4743ef94d 100644 --- a/examples/images/vit/args.py +++ b/examples/images/vit/args.py @@ -1,8 +1,8 @@ -from colossalai import get_default_parser +import argparse def parse_demo_args(): - parser = get_default_parser() + parser = argparse.ArgumentParser() parser.add_argument( "--model_name_or_path", type=str, @@ -52,7 +52,7 @@ def parse_demo_args(): def parse_benchmark_args(): - parser = get_default_parser() + parser = argparse.ArgumentParser() parser.add_argument( "--model_name_or_path", diff --git a/examples/images/vit/run_benchmark.sh b/examples/images/vit/run_benchmark.sh index 41eab9c5a188..ad41a283711c 100644 --- a/examples/images/vit/run_benchmark.sh +++ b/examples/images/vit/run_benchmark.sh @@ -11,9 +11,9 @@ for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini" "hybrid_par do MODEL_PATH="google/vit-base-patch16-224" -torchrun \ - --standalone \ - --nproc_per_node 4 \ +colossalai run \ + --nproc_per_node ${GPUNUM} \ + --master_port 29505 \ vit_benchmark.py \ --model_name_or_path ${MODEL_PATH} \ --mem_cap ${MEMCAP} \ diff --git a/examples/images/vit/run_demo.sh b/examples/images/vit/run_demo.sh index 9efe1475956d..8eead0661454 100644 --- a/examples/images/vit/run_demo.sh +++ b/examples/images/vit/run_demo.sh @@ -35,9 +35,9 @@ WEIGHT_DECAY=0.05 WARMUP_RATIO=0.3 # run the script for demo -torchrun \ - --standalone \ +colossalai run \ --nproc_per_node ${GPUNUM} \ + --master_port 29505 \ vit_train_demo.py \ --model_name_or_path ${MODEL} \ --output_path ${OUTPUT_PATH} \ diff --git a/examples/images/vit/test_ci.sh b/examples/images/vit/test_ci.sh index 570147606636..fc1f2b7a2ee0 100644 --- a/examples/images/vit/test_ci.sh +++ b/examples/images/vit/test_ci.sh @@ -5,9 +5,9 @@ BS=8 for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" 
"gemini" "hybrid_parallel" do -torchrun \ - --standalone \ +colossalai run \ --nproc_per_node 4 \ + --master_port 29505 \ vit_benchmark.py \ --model_name_or_path "google/vit-base-patch16-224" \ --plugin ${PLUGIN} \ diff --git a/examples/language/gpt/titans/train_gpt.py b/examples/language/gpt/titans/train_gpt.py index b9d802f01cc9..565cf1e016cc 100644 --- a/examples/language/gpt/titans/train_gpt.py +++ b/examples/language/gpt/titans/train_gpt.py @@ -1,3 +1,4 @@ +import argparse import contextlib import os @@ -29,7 +30,7 @@ def calc_local_model_size(model: torch.nn.Module): def main(): - parser = colossalai.get_default_parser() + parser = argparse.ArgumentParser() parser.add_argument("--from_torch", default=False, action="store_true") parser.add_argument("--use_dummy_dataset", default=False, action="store_true") args = parser.parse_args() diff --git a/examples/language/opt/args.py b/examples/language/opt/args.py index 1ec19094e19e..fc3d42fae220 100644 --- a/examples/language/opt/args.py +++ b/examples/language/opt/args.py @@ -1,8 +1,8 @@ -from colossalai import get_default_parser +import argparse def parse_demo_args(): - parser = get_default_parser() + parser = argparse.ArgumentParser() parser.add_argument( "--model_name_or_path", type=str, @@ -39,7 +39,7 @@ def parse_demo_args(): def parse_benchmark_args(): - parser = get_default_parser() + parser = argparse.ArgumentParser() parser.add_argument( "--model_name_or_path", type=str, diff --git a/examples/language/opt/run_benchmark.sh b/examples/language/opt/run_benchmark.sh index b94ee61f277c..b79d6c13465e 100644 --- a/examples/language/opt/run_benchmark.sh +++ b/examples/language/opt/run_benchmark.sh @@ -16,9 +16,9 @@ for GPUNUM in 1 4 do MODLE_PATH="facebook/opt-${MODEL}" -torchrun \ - --standalone \ +colossalai run \ --nproc_per_node ${GPUNUM} \ + --master_port 29505 \ opt_benchmark.py \ --model_name_or_path ${MODLE_PATH} \ --mem_cap ${MEMCAP} \ diff --git a/examples/language/opt/run_demo.sh b/examples/language/opt/run_demo.sh index 07b429cecf1e..fe49d794f4b0 100644 --- a/examples/language/opt/run_demo.sh +++ b/examples/language/opt/run_demo.sh @@ -30,9 +30,9 @@ WEIGHT_DECAY=0.01 WARMUP_RATIO=0.1 # run the script for demo -torchrun \ - --standalone \ +colossalai run \ --nproc_per_node ${GPUNUM} \ + --master_port 29505 \ opt_train_demo.py \ --model_name_or_path ${MODEL} \ --output_path ${OUTPUT_PATH} \ diff --git a/examples/language/opt/test_ci.sh b/examples/language/opt/test_ci.sh index fa14f52b70d2..2e3a645caf06 100644 --- a/examples/language/opt/test_ci.sh +++ b/examples/language/opt/test_ci.sh @@ -7,9 +7,9 @@ do for GPUNUM in 1 4 do -torchrun \ - --standalone \ +colossalai run \ --nproc_per_node ${GPUNUM} \ + --master_port 29505 \ opt_benchmark.py \ --model_name_or_path "facebook/opt-125m" \ --plugin ${PLUGIN} \ diff --git a/examples/language/palm/run.sh b/examples/language/palm/run.sh index 2a846e81a9a7..0b9871c77723 100644 --- a/examples/language/palm/run.sh +++ b/examples/language/palm/run.sh @@ -8,6 +8,6 @@ export PLACEMENT='cpu' export USE_SHARD_INIT=False export BATCH_SIZE=1 -env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train.py \ +env OMP_NUM_THREADS=12 colossalai run --nproc_per_node ${GPUNUM} --master_port 29505 train.py \ --dummy_data=True --tp_degree=${TPDEGREE} --batch_size=${BATCH_SIZE} --plugin='gemini' \ --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee run.log diff --git a/examples/language/palm/test_ci.sh 
b/examples/language/palm/test_ci.sh index 4de6a44e5bf7..6bcd140fe7fd 100644 --- a/examples/language/palm/test_ci.sh +++ b/examples/language/palm/test_ci.sh @@ -4,6 +4,6 @@ for BATCH_SIZE in 2 do for GPUNUM in 1 4 do -env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --standalone train.py --dummy_data=True --batch_size=${BATCH_SIZE} --plugin='gemini' 2>&1 | tee run.log +env OMP_NUM_THREADS=12 colossalai run --nproc_per_node ${GPUNUM} --master_port 29505 train.py --dummy_data=True --batch_size=${BATCH_SIZE} --plugin='gemini' 2>&1 | tee run.log done done diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py index e7af88c55121..7af02e24e6cf 100644 --- a/examples/language/palm/train.py +++ b/examples/language/palm/train.py @@ -1,3 +1,4 @@ +import argparse import gzip from contextlib import nullcontext from functools import partial @@ -33,7 +34,7 @@ def parse_args(): - parser = colossalai.get_default_parser() + parser = argparse.ArgumentParser() parser.add_argument( "--distplan", type=str, From 66f3926019e4c52aac94864b2b6c684c878534dd Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Thu, 21 Sep 2023 11:36:20 +0800 Subject: [PATCH 32/58] [doc] clean up outdated docs (#4765) * [doc] clean up outdated docs * [doc] fix linking * [doc] fix linking --- docs/sidebars.json | 15 +- .../advanced_tutorials/add_your_parallel.md | 125 ------ .../define_your_own_parallel_model.md | 36 -- ...parallelize_your_training_like_Megatron.md | 194 --------- docs/source/en/basics/colotensor_concept.md | 98 ----- .../en/basics/configure_parallelization.md | 158 ------- docs/source/en/basics/define_your_config.md | 85 ---- docs/source/en/basics/engine_trainer.md | 390 ------------------ docs/source/en/basics/initialize_features.md | 51 --- docs/source/en/basics/model_checkpoint.md | 64 --- docs/source/en/features/1D_tensor_parallel.md | 4 - docs/source/en/features/2D_tensor_parallel.md | 2 - .../en/features/2p5D_tensor_parallel.md | 2 - docs/source/en/features/3D_tensor_parallel.md | 2 - .../en/features/gradient_accumulation.md | 47 --- .../gradient_accumulation_with_booster.md | 3 +- docs/source/en/features/gradient_clipping.md | 64 --- .../gradient_clipping_with_booster.md | 3 +- docs/source/en/features/gradient_handler.md | 64 --- .../en/features/mixed_precision_training.md | 368 ----------------- .../mixed_precision_training_with_booster.md | 5 +- docs/source/en/features/zero_with_chunk.md | 2 +- .../advanced_tutorials/add_your_parallel.md | 113 ----- .../define_your_own_parallel_model.md | 31 -- ...parallelize_your_training_like_Megatron.md | 179 -------- .../zh-Hans/basics/colotensor_concept.md | 99 ----- .../basics/configure_parallelization.md | 138 ------- .../zh-Hans/basics/define_your_config.md | 73 ---- docs/source/zh-Hans/basics/engine_trainer.md | 387 ----------------- .../zh-Hans/basics/initialize_features.md | 48 --- .../source/zh-Hans/basics/model_checkpoint.md | 64 --- .../zh-Hans/features/1D_tensor_parallel.md | 5 +- .../zh-Hans/features/2D_tensor_parallel.md | 2 - .../zh-Hans/features/2p5D_tensor_parallel.md | 2 - .../zh-Hans/features/3D_tensor_parallel.md | 2 - .../zh-Hans/features/gradient_accumulation.md | 41 -- .../gradient_accumulation_with_booster.md | 3 +- .../zh-Hans/features/gradient_clipping.md | 53 --- .../gradient_clipping_with_booster.md | 3 +- .../zh-Hans/features/gradient_handler.md | 60 --- .../features/mixed_precision_training.md | 345 ---------------- .../mixed_precision_training_with_booster.md | 5 +- 
.../zh-Hans/features/zero_with_chunk.md | 2 +- 43 files changed, 12 insertions(+), 3425 deletions(-) delete mode 100644 docs/source/en/advanced_tutorials/add_your_parallel.md delete mode 100644 docs/source/en/advanced_tutorials/define_your_own_parallel_model.md delete mode 100644 docs/source/en/advanced_tutorials/parallelize_your_training_like_Megatron.md delete mode 100644 docs/source/en/basics/colotensor_concept.md delete mode 100644 docs/source/en/basics/configure_parallelization.md delete mode 100644 docs/source/en/basics/define_your_config.md delete mode 100644 docs/source/en/basics/engine_trainer.md delete mode 100644 docs/source/en/basics/initialize_features.md delete mode 100644 docs/source/en/basics/model_checkpoint.md delete mode 100644 docs/source/en/features/gradient_accumulation.md delete mode 100644 docs/source/en/features/gradient_clipping.md delete mode 100644 docs/source/en/features/gradient_handler.md delete mode 100644 docs/source/en/features/mixed_precision_training.md delete mode 100644 docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md delete mode 100644 docs/source/zh-Hans/advanced_tutorials/define_your_own_parallel_model.md delete mode 100644 docs/source/zh-Hans/advanced_tutorials/parallelize_your_training_like_Megatron.md delete mode 100644 docs/source/zh-Hans/basics/colotensor_concept.md delete mode 100644 docs/source/zh-Hans/basics/configure_parallelization.md delete mode 100644 docs/source/zh-Hans/basics/define_your_config.md delete mode 100644 docs/source/zh-Hans/basics/engine_trainer.md delete mode 100644 docs/source/zh-Hans/basics/initialize_features.md delete mode 100644 docs/source/zh-Hans/basics/model_checkpoint.md delete mode 100644 docs/source/zh-Hans/features/gradient_accumulation.md delete mode 100644 docs/source/zh-Hans/features/gradient_clipping.md delete mode 100644 docs/source/zh-Hans/features/gradient_handler.md delete mode 100644 docs/source/zh-Hans/features/mixed_precision_training.md diff --git a/docs/sidebars.json b/docs/sidebars.json index 8be40e4512f9..bf92e9755f4a 100644 --- a/docs/sidebars.json +++ b/docs/sidebars.json @@ -29,13 +29,7 @@ "basics/launch_colossalai", "basics/booster_api", "basics/booster_plugins", - "basics/booster_checkpoint", - "basics/define_your_config", - "basics/initialize_features", - "basics/engine_trainer", - "basics/configure_parallelization", - "basics/model_checkpoint", - "basics/colotensor_concept" + "basics/booster_checkpoint" ] }, { @@ -44,12 +38,8 @@ "collapsed": true, "items": [ "features/mixed_precision_training_with_booster", - "features/mixed_precision_training", "features/gradient_accumulation_with_booster", - "features/gradient_accumulation", "features/gradient_clipping_with_booster", - "features/gradient_clipping", - "features/gradient_handler", "features/zero_with_chunk", { "type": "category", @@ -75,10 +65,7 @@ "advanced_tutorials/train_vit_using_pipeline_parallelism", "advanced_tutorials/train_vit_with_hybrid_parallelism", "advanced_tutorials/train_gpt_using_hybrid_parallelism", - "advanced_tutorials/define_your_own_parallel_model", - "advanced_tutorials/add_your_parallel", "advanced_tutorials/meet_gemini", - "advanced_tutorials/parallelize_your_training_like_Megatron", "advanced_tutorials/integrate_mixture_of_experts_into_your_model", "advanced_tutorials/opt_service" ] diff --git a/docs/source/en/advanced_tutorials/add_your_parallel.md b/docs/source/en/advanced_tutorials/add_your_parallel.md deleted file mode 100644 index 63434a526228..000000000000 --- 
a/docs/source/en/advanced_tutorials/add_your_parallel.md +++ /dev/null @@ -1,125 +0,0 @@ -# Add Your Own Parallel Mode - -Author: Shenggui Li, Yongbin Li - -**Prerequisite:** -- [Define Your Configuration](../basics/define_your_config.md) -- [Configure Parallelization](../basics/configure_parallelization.md) - -## Introduction - -To enable researchers and engineers to extend our system to other novel large-scale distributed training algorithm -with less effort, we have decoupled various components in the training lifecycle. You can implement your own -parallelism by simply inheriting from the base class. - -The main components are: - -1. `ProcessGroupInitializer` -2. `GradientHandler` -3. `Schedule` - -**This currently requires some code to the source code, thus we recommend that you install from source with the `-e` flag. -`-e` flag makes the installation editable, thus, your code change will be reflected in your Python runtime. -We will work on this to avoid change to source code in future releases.** - - -## Process Group Initializer - -Parallelism is often managed by process groups where processes involved in the same parallel algorithm are placed in the same -process group. For different parallel algorithms, different process groups need to be created. Colossal-AI provides a -global context for users to easily manage their process groups. If you wish to add new process group, you can easily -define a new class and set it in your configuration file. To define your own way of creating process groups, you can -follow the steps below to create a new distributed initialization. - -1. Add your parallel mode in `colossalai.legacy.context.parallel_mode.ParallelMode`. - ```python - class ParallelMode(Enum): - GLOBAL = 'global' - DATA = 'data' - PIPELINE = 'pipe' - ... - - NEW_MODE = 'new_mode' # define your mode here - ``` - -2. Create a `ProcessGroupInitializer`. You can refer to examples given in `colossalai.context.dist_group_initializer`. The - first six arguments are fixed. `ParallelContext` will pass in these arguments for you. If you need to set other - arguments, you can add it behind like the `arg1, arg2` in the example below. Lastly, register your initializer to the - registry by adding the decorator `@DIST_GROUP_INITIALIZER.register_module`. - ```python - # sample initializer class - @DIST_GROUP_INITIALIZER.register_module - class MyParallelInitializer(ProcessGroupInitializer): - - def __init__(self, - rank: int, - world_size: int, - config: Config, - data_parallel_size: int, - pipeline_parallel_size: int, - tensor_parallel_size: int, - arg1, - arg2): - super().__init__(rank, world_size, config) - self.arg1 = arg1 - self.arg2 = arg2 - # ... your variable init - - def init_parallel_groups(self): - # initialize your process groups - pass - - ``` - - Then, you can insert your new initializer to the current mode-to-initialize mapping - in `colossalai.constants.INITIALIZER_MAPPING`. You can modify the file or insert new key-value pair dynamically. - - ```python - colossalai.constants.INITIALIZER_MAPPING['new_mode'] = 'MyParallelInitializer' - ``` - -3. Set your initializer in your config file. You can pass in your own arguments if there is any. This allows - the `ParallelContext` to create your initializer and initialize your desired process groups. 
- - ```python - parallel = dict( - pipeline=dict(size=1), - tensor=dict(size=x, mode='new_mode') # this is where you enable your new parallel mode - ) - ``` - -## Gradient Handler - -Gradient handlers are objects which execute the all-reduce operations on parameters' gradients. As different all-reduce -strategies may be executed for different kinds of parallelism, users can -inherit `colossalai.legacy.engine.gradient_handler.BaseGradientHandler` to implement their strategies. Currently, the library -uses the normal data parallel gradient handler which all-reduces the gradients across data parallel ranks. The data -parallel gradient handler is added to the engine automatically if data parallel is detected. You can add your own -gradient handler like below: - -```python -from colossalai.legacy.registry import GRADIENT_HANDLER -from colossalai.legacy.engine import BaseGradientHandler - -@GRADIENT_HANDLER.register_module -class YourGradientHandler(BaseGradientHandler): - - def handle_gradient(self): - do_something() - -``` - -Afterwards, you can specify the gradient handler you want to use in your configuration file. - -```python -gradient_handlers = [ - dict(type='YourGradientHandler'), -] -``` - -## Schedule - -Schedule entails how to execute a forward and backward pass. Currently, Colossal-AI provides pipeline and non-pipeline -schedules. If you want to modify how the forward and backward passes are executed, you can -inherit `colossalai.legacy.engine.schedule.BaseSchedule` and implement the `forward_back_step` function. - diff --git a/docs/source/en/advanced_tutorials/define_your_own_parallel_model.md b/docs/source/en/advanced_tutorials/define_your_own_parallel_model.md deleted file mode 100644 index 8e48737d2f64..000000000000 --- a/docs/source/en/advanced_tutorials/define_your_own_parallel_model.md +++ /dev/null @@ -1,36 +0,0 @@ -# Define your own parallel model - -Author: Zhengda Bian, Yongbin Li - -> ⚠️ We are working on this documentation to make it more detailed. We will introduce the mechanism of different parallelism -> and how to use them to write a model. - -Let's say that you have a huge MLP model with billions of parameters and its extremely large hidden layer size makes it -impossible to fit into a single GPU directly. Don't worry, Colossal-AI is here to help you sort things out. With the help of Colossal-AI, -you can write your model in the familiar way in which you used to write models for a single GPU, while Colossal-AI automatically -splits your model weights and fit them perfectly into a set of GPUs. We give a simple example showing how to write a simple -2D parallel model in the Colossal-AI context. - -## Write a simple 2D parallel model - -```python -from colossalai.nn import Linear2D -import torch.nn as nn - -class MLP_2D(nn.Module): - - def __init__(self): - super().__init__() - self.linear_1 = Linear2D(in_features=1024, out_features=16384) - self.linear_2 = Linear2D(in_features=16384, out_features=1024) - - def forward(self, x): - x = self.linear_1(x) - x = self.linear_2(x) - return x -``` - -## Use pre-defined model - -For the sake of your convenience, we kindly provide you in our Model Zoo with some prevalent models such as *BERT*, *ViT*, *MoE*, -and *GPT*. Feel free to customize them into different sizes to fit into your special needs. 
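The tutorials removed above describe the legacy, config-file way of building parallel models. For reference, a minimal sketch of the current booster-based equivalent (assuming `HybridParallelPlugin` with 2-way tensor parallelism and a HuggingFace GPT-2; run under a distributed launcher such as `colossalai run` with at least 2 GPUs):

```python
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from transformers import GPT2Config, GPT2LMHeadModel

colossalai.launch_from_torch(config={})

model = GPT2LMHeadModel(GPT2Config())
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Shardformer-based tensor parallelism replaces the old 1D/2D/2.5D/3D config entries.
plugin = HybridParallelPlugin(tp_size=2, pp_size=1)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)
```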
diff --git a/docs/source/en/advanced_tutorials/parallelize_your_training_like_Megatron.md b/docs/source/en/advanced_tutorials/parallelize_your_training_like_Megatron.md deleted file mode 100644 index 0a94a7f5d691..000000000000 --- a/docs/source/en/advanced_tutorials/parallelize_your_training_like_Megatron.md +++ /dev/null @@ -1,194 +0,0 @@ -# Parallelize Your Training like Megatron-LM via ColoTensor - -Author: [Haichen Huang](https://github.com/1SAA) and [Jiarui Fang](https://github.com/feifeibear) - -**Prerequisite:** -- [ColoTensor Concepts](../basics/colotensor_concept.md) - -## Introduction - -Thanks to the convenience given by ColoTensor, users can apply parallelism with the least edition to their serial code. -In this tutorial, we will illustrate how to modify the training model to automatically adapt the code to parallel training like Megatron-LM. -We take the GPT-2 model offered by HuggingFace as an example and provide a way for you to pre-train the GPT-2 model on a single GPU. - -Megatron-LM provided a profound paradigm to parallelize large transformer language models. -However, in order to train large transformer language models at scale, users have to build their models with those modules provided by Megatron. -It imposes several difficult jobs on users, such as loading the weights from the pre-trained models and constructing the parallelized models. -To mitigate users' trouble, we offer ColoTensor to enable the tensor model parallelism automatically. - -## Definitions of the model and the loss function - -First we use the GPTModel and GPTLoss directly from the HuggingFace library. - -```python -import torch -import torch.nn as nn -from transformers import GPT2Config, GPT2LMHeadModel - -class GPTLMModel(nn.Module): - def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257, checkpoint=False): - super().__init__() - self.checkpoint = checkpoint - self.model = GPT2LMHeadModel(GPT2Config(n_embd=hidden_size, n_layer=num_layers, - n_head=num_attention_heads, n_positions=max_seq_len, n_ctx=max_seq_len, vocab_size=vocab_size)) - if checkpoint: - self.model.gradient_checkpointing_enable() - - def forward(self, input_ids, attention_mask): - # Only return lm_logits - return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0] - - -class GPTLMLoss(nn.Module): - def __init__(self): - super().__init__() - self.loss_fn = nn.CrossEntropyLoss() - - def forward(self, logits, labels): - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) -``` - -## Brief Review of GPT-2 - -Now, we recall the structure of each GPT-2 model. -Every GPT-2 model can be represented as a DAG. -As shown in the below pictures, each circle represents an operator and each square represents a weight. -An arrow indicates the flow of the input data, and the notation alongside the arrow demonstrates the shape of the input data. - -Then, let's take an insight into this GPT-2 model. It consists of three parts. -They are the **embedding module**, **transformer layers**, and the **classification head**. - -The embedding module contains two weights, token embedding weight and position embedding weight. -After the forward operation of the embedding module, each word in all sequences of the raw input data will be embedded into a hidden state. - -
-[figure: The embedding module]
- -Each transformer layer contains two blocks. The self-attention operation is called in the first block and a two-layer perception is located in the second block. - -
-[figure: The transformer layer]
- -In the end, the classification head is just a linear module without bias, which only has a weight inside. - -## Applied with ColoTensor - -Two steps make your serial code adapted to Megatron-LM tensor parallel style. -1. Initialize the model in the context of ColoInitContext. -2. Setting ColoTensorSpec for each parameter. - -### Initialize with ColoInitContext - -We should build the model in the ColoInitContext. -In this context, any parameter initialized would be transformed to ColoParameter and moved to the corresponded device automatically. - -```python -from colossalai.utils.model.colo_init_context import ColoInitContext - -with ColoInitContext(device=torch.device('cpu')): - model = GPTLMModel() -``` - -### Setting ColoTensorSpec for each parameter - -After the creation of the model, we establish the distributed environment through ProcessGroup. -Here, we specify the degree of the tensor parallelism as the same as the number of all GPUs, which means the degree of data parallelism is 1. - -```python -import torch.distributed as dist -from colossalai.tensor import ProcessGroup - -pg = ProcessGroup(tp_degree=dist.get_world_size()) -``` - -Now, some auxiliary functions are necessary for the next step. We define two functions to split a parameter. -Megatron-LM-like tensor parallelism requires splitting a parameter tensor along its first dimension or its last dimension. - -```python -from colossalai.tensor import ShardSpec, ComputeSpec, ComputePattern, ColoParameter, ProcessGroup - -def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup): - spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - if param.process_group.tp_world_size() == 1: - param.set_process_group(pg) - param.set_tensor_spec(*spec) - - -def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup): - split_param_single_dim_tp1d(0, param, pg) - - -def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup): - split_param_single_dim_tp1d(-1, param, pg) -``` - -Then we adapt the model to the tensor parallelism. -According to the tensor parallelism applied in Megatron, it is supposed to shard along the last dimension of tensors, including the weights of token embedding, position embedding, all linear weights and biases in self-attention blocks, the first weight linear and bias in each MLP. -And it shards the second linear weight along its first dimension. - -```python -for mn, module in model.named_modules(): - for pn, param in module.named_parameters(recurse=False): - # set process group for all parameters - param.set_process_group(pg) - - if 'mlp.c_fc' in mn: - if 'weight' in pn or 'bias' in pn: - split_param_col_tp1d(param, pg) # column slice - # keep the shape of the output from c_fc - param.compute_spec.set_output_replicate(False) - elif 'mlp.c_proj' in mn: - if 'weight' in pn: - split_param_row_tp1d(param, pg) # row slice - elif 'wte' in mn or 'wpe' in mn: - split_param_col_tp1d(param, pg) # column slice - elif 'c_attn' in mn or 'c_proj' in mn: - split_param_col_tp1d(param, pg) # column slice -``` - -The modified model is illustrated below. - -The embedding module: - -
-[figure: The modified embedding module]
- -The transformer layers: - -
-[figure: The modified transformer layer]
- -Once users have specified the distributed pattern of each parameter, ColoTensor is capable of inferring the computation patterns of all operators, including matrix multiplication, the linear function, other elementwise functions in torch.nn.functional, etc. -In this way, users can train their models as usual. - -In our latest example, a Gemini + ZeRO DDP model is also defined to reduce overhead and improve efficiency.For the details of this part, please refer to [ZeRO](../features/zero_with_chunk.md). You can combine these two parts to understand our entire training process: - -```python -def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"): - from colossalai.zero import GeminiDDP - model = GeminiDDP(model, - device=get_current_device(), - placement_policy=placement_policy, - pin_memory=True, - search_range_m=32) - return model -``` - -## Pretrain GPT-2 On Single GPU - -The above optimization we made allows us to pretrain the GPT-2 model on a single GPU. We only need to set the parameter `GPUNUM`=1 in `run.sh`, and then we can complete the model training on a single GPU when running the file. - -The GPT-2 example is accessible at [Train GPT with Colossal-AI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt). - - diff --git a/docs/source/en/basics/colotensor_concept.md b/docs/source/en/basics/colotensor_concept.md deleted file mode 100644 index abe470fe0794..000000000000 --- a/docs/source/en/basics/colotensor_concept.md +++ /dev/null @@ -1,98 +0,0 @@ -# ColoTensor Concepts - -Author: [Jiarui Fang](https://github.com/feifeibear), [Hongxin Liu](https://github.com/ver217) and [Haichen Huang](https://github.com/1SAA) - -> ⚠️ The information on this page is outdated and will be deprecated. - -**Prerequisite:** -- [Colossal-AI Overview](../concepts/colossalai_overview.md) -- [Distributed Training](../concepts/distributed_training.md) -- [Paradigms of Parallelism](../concepts/paradigms_of_parallelism.md) - -## Introduction - -After ColossalAI version 0.1.8, [ColoTensor](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.html#colossalai.tensor.ColoTensor) becomes the basic data structure for tensors in ColossalAI. It is a subclass of torch.Tensor and can be used as a PyTorch Tensor. Additionally, some unique features make it possible to represent a Global Tensor with a payload distributed across multiple GPU devices. With the help of ColoTensor, the users can write distributed DNN training program similar to a serial one.support the following features. - -ColoTensor contains extra attributes capsuled in a [ColoTensorSpec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.tensor_spec.html#colossalai.tensor.tensor_spec.ColoTensorSpec) instance to describe the tensor's payload distribution and computing pattern. - -- ProcessGroup: how processes are organized as communication groups. -- Distributed Spec: how tensor is distributed among process groups. -- Compute Spec: how the tensor is used during computation. - -We elaborate on them one by one. - -## ProcessGroup - -An instance of class [ProcessGroup](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.html#colossalai.tensor.ProcessGroup) describes how processes are organized in process groups. Processes in a process group can participate in the same collective communication operations together, such as allgather, allreduce, etc. The way the process group is organized is dominated by the Tensor's parallelism strategy. 
For example, if the user defines the tensor parallel (TP) and data parallel (DP) modes of a tensor, then the process organization of the process group will be automatically deduced. The process group settings can vary among different tensors. Therefore, it enables us to support more complicated hybrid parallel. The pipeline parallel (PP) definition is not in the ProcessGroup, it needs another set of mechanisms . We will supplement the related content of ColoTensor applied to PP in the future. - -Currently, a process group of ColoTensor is defined by two configurations, i.e. tp_degree and dp_degree. In the case of DP+TP hybrid parallelism, the device can be viewed as a 2D mesh. We place TP communication groups on the leading low dimension of the device mesh and then place the data parallel groups along the high dimension of the device mesh. The reason is that tensor parallelism has a larger communication overhead than data parallelism. Neighboring devices are placed inside a TP process group and are often placed in the same node. - -Considering that 8 processes are configured as tp_degree=4, and dp_degree=2, the layout is shown below. Process group tp0 contains gpu 0,1,2,3. Process dp1 contains gpu 1 and 5. - -
-[figure: Process Group using tp_degree=4, dp_degree=2]
- -## Distributed Spec - -An instance of [Distributed Spec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.distspec.html) describes how a ColoTensor is distributed among the ProcessGroup. - -How tensors are distributed among DP process groups is automatically derived and does not need to be manually specified by the user. If this tensor is a model parameter, it is replicated within the DP process group. If it is an activation tensor, it is split along the process with the highest dimension and evenly distributed the tensor payload among processes in the DP process group. - -Therefore, when using Distributed Spec, we only need to describe the way that the tensor is distributed among TP process groups. There are currently two ways to distribute among TP process group, i.e. [ShardSpec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.distspec.html#colossalai.tensor.distspec.ShardSpec) and [ReplicaSpec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.distspec.html#colossalai.tensor.distspec.ReplicaSpec). ShardSpec needs to specify the dimension index dim of the partition and the number of partitions num_partitions. Currently, we only support the split on a single dim. Different dist specs on the TP process groups can be converted to each other through the set_dist_spec() interface. The spec conversions are recorded by the autograd mechanism and it will trigger corresponding reverse operations during backward propagation. - -## Compute Spec - -An instance of class [ComputeSpec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.compute_spec.html#colossalai.tensor.compute_spec.ComputeSpec) describes how a Colotensor be used in DNN training. Currently, we will set the correct Compute Pattern for the ColoTensor as the parameters of the module. The specific application scenarios will be shown in the next document. - -## ColoParameter - -[ColoParameter](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.colo_parameter.html#colossalai.tensor.colo_parameter.ColoParameter) is a subclass of ColoTensor. Used to define a Global Parameter tensor. Its relationship with ColoTensor is consistent with Torch.Tensor and torch.Parameter. The latter allows the tensor to appear in the return values of the module's parameters() and name_parameters() methods. - -## Example - -Let's see an example. A ColoTensor is initialized and sharded on 8 GPUs using tp_degree=4, dp_degree=2. And then the tensor is sharded along the last dim among the TP process groups. Finally, we reshard it along the first dim (0 dim) among the TP process groups. We encourage users to run the code and observe the shape of each tensor. 
- - -```python -import torch -import torch.multiprocessing as mp -from colossalai.utils import print_rank_0 -from functools import partial - -import colossalai -from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec, ShardSpec, ComputeSpec, ComputePattern -from colossalai.testing import spawn - -import torch - -def run_dist_tests(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - pg = ProcessGroup(tp_degree=2, dp_degree=2) - - torch.manual_seed(0) - local_tensor = torch.randn(2, 3, 1).cuda() - print_rank_0(f"shape {local_tensor.shape}, {local_tensor.data}") - - spec = ColoTensorSpec(pg, ShardSpec(dims=[-1], num_partitions=[pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - t1 = ColoTensor.from_torch_tensor(local_tensor, spec) - t1 = t1.to_replicate() - print_rank_0(f"shape {t1.shape}, {t1.data}") - - spec2 = ShardSpec([0], [pg.tp_world_size()]) - t1.set_dist_spec(spec2) - print_rank_0(f"shape {t1.shape}, {t1.data}") - -def test_dist_cases(world_size): - spawn(run_dist_tests, world_size) - -if __name__ == '__main__': - test_dist_cases(4) -``` - -:::caution - -The ColoTensor is an experimental feature and may be updated. - -::: diff --git a/docs/source/en/basics/configure_parallelization.md b/docs/source/en/basics/configure_parallelization.md deleted file mode 100644 index fd1e72ccd45a..000000000000 --- a/docs/source/en/basics/configure_parallelization.md +++ /dev/null @@ -1,158 +0,0 @@ -# Configure Parallelization - -Author: Shenggui Li, Siqi Mai - -> ⚠️ The information on this page is outdated and will be deprecated. Please check [Booster Plugins](../basics/booster_plugins.md) for more information. - -**Prerequisite:** -- [Distributed Training](../concepts/distributed_training.md) -- [Paradigms of Parallelism](../concepts/paradigms_of_parallelism.md) -- [Define Your Configuration](./define_your_config.md) - - -## Introduction - -We support multiple parallelization in Colossal-AI. Hybrid parallelism in our codebase refers to namely the combination -of data parallelism, pipeline parallelism and tensor parallelism (1D, 2D, 2.5D, 3D). - -Each parallelism requires different network topology and thus initialize different process groups. -You can initialize the corresponding process group by setting `parallel` in the config file. -The configuration for `parallel` must obey the following format. Data parallel size will be -inferred automatically based on your inputs to pipeline parallelism and tensor parallelism. -`colossalai.launch` will initialize these distributed process groups automatically based on your configuration. - -Some sample configurations are shown below: - -```python -# sampler format -parallel = dict( - pipeline=dict("size": int), - tensor=dict("size": int, "mode": '1d' or '2d' or '2.5d' or '3d', "kwargs": Any) -) - -# this is ok -parallel = dict( - pipeline=dict(size=2), - tensor=dict(size=4, mode='2d') -) - -# this is ok -parallel = dict( - pipeline=2, - tensor=dict(size=4, mode='2d') -) - -# this is not ok -# as you need to specify the mode for tensor parallelism -parallel = dict( - pipeline=2, - tensor=4 -) - -# this is ok as well as tensor will be default to size 1 -# and mode None -parallel = dict( - pipeline=2 -) - -# this is ok as well as pipeline will default to size 1 -parallel = dict( - tensor=dict(size=4, mode='2d') -) - -``` - -The key name `size` refers to the parallel size of the parallelism dimension. 
For example, pipeline size 2 means there -will be 2 pipeline stages. The key name `mode` in tensor parallel config means the corresponding tensor parallelism -will be initialized. - -**You can choose to not have 'parallel' in your configuration and both pipeline and tensor will default to size 1.** - -**Total number of GPUs must be equal to `data parallel size * tensor parallel size * pipeline parallel size`** - -## Data Parallel - -Data parallel is the most common way to distribute your training task by splitting data into several shards and train on -a single shard on each device. The configuration for data parallel is detected automatically and set for you. You do not -have to explicitly set them in your configurations. There are two ways to handle the all-reduce in data parallel in Colossal-AI. - -1. If you specify gradient handlers, gradients will be all-reduced according to the gradient handlers -2. Otherwise, PyTorch DistributedDataParallel will be used - -In most cases, you will be using the second mode unless you have complex handling of the gradients. - -## 1D, 2D, 2.5D and 3D Parallel - -To enable hybrid parallelism, we provide an array of tensor parallelism. We provide the list of papers which match each -tensor parallel method. These parallel modes need to work with the distributed layers provided by Colossal-AI. - -- 1D: [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) - -- 2D: [An Efficient 2D Method for Training Super-Large Deep Learning Models](https://arxiv.org/abs/2104.05343) - 2D parallel relies on the SUMMA matrix multiplication algorithm and splits the input data, model weights and layer - outputs along two different dimensions. The tensor chunks are distributed over a 2D mesh of `P = N^2` devices where - `N` is the number of tensor chunks in a single dimension. - -- 2.5D: [2.5-dimensional distributed model training](https://arxiv.org/abs/2105.14500) - Inspired by the 2.5D matrix multiplication algorithm, 2.5D parallel introduces a novel tensor parallelism which - further parallelizes 2D tensor parallelism. An amount of `P = N^2 ∗ d` processors are arranged into `d` layers, where - each layer performs matrix multiplication operations independently with a dimension `N`. - -- 3D: [Maximizing Parallelism in Distributed Training for Huge Neural Networks](https://arxiv.org/abs/2105.14450) - We also introduce a 3D tensor parallelism that parallelizes neural networks on a 3D processor cube. This method - achieves the optimal, `O(P^{1/3})` communication overhead on $P$ processors, while both computation and memory usage - are evenly distributed through optimized load balancing of parameters as well as activations. - -```python -# 1D parallel -parallel = dict( - tensor=dict(size=4, mode='1d') -) - -# 2D parallel -parallel = dict( - tensor=dict(size=4, mode='2d') -) - -# 2.5D parallel -parallel = dict( - tensor=dict(size=8, mode='2.5d', depth=2) -) - -# 3D parallel -parallel = dict( - tensor=dict(size=8, mode='3d') -) -``` - -Once you specify the tensor parallel mode in your configuration, you can proceed to use its corresponding distributed -operator. For example, if you mode is '2d', you can use `colossalai.nn.Linear2D` in you model construction. - - -## Pipeline Parallel - -Pipeline parallelism is to split the model into several partitions by layer. For example, let's assume we have a simple -model which consists of two linear layer. 
We have two GPUs, and we can allocate the first linear layer to the first GPU -and the second layer to the second GPU. - -You can set the number of pipeline stages in your configuration file. When pipeline size is larger than 1, Colossal-AI -will automatically creates the pipeline schedule which defines the forward and backward step. - -```python -parallel = dict( - pipeline=dict(size=4), # number of pipeline stages -) -``` - -## Sequence Parallel - -Sequence parallel is to support long-sequence modelling such as document-level text understanding and medical imaging. -This method is proposed in [Sequence Parallelism: Making 4D Parallelism Possible](https://arxiv.org/abs/2105.13120). -You can use specify the mode to be `sequence` to initialize its process group. - - -```python -parallel = dict( - tensor=dict(size=4, mode='sequence') -) -``` diff --git a/docs/source/en/basics/define_your_config.md b/docs/source/en/basics/define_your_config.md deleted file mode 100644 index 048ffcacbb8f..000000000000 --- a/docs/source/en/basics/define_your_config.md +++ /dev/null @@ -1,85 +0,0 @@ -# Define Your Configuration - -Author: Guangyang Lu, Shenggui Li, Siqi Mai - -> ⚠️ The information on this page is outdated and will be deprecated. Please check [Booster API](../basics/booster_api.md) for more information. - - -**Prerequisite:** -- [Distributed Training](../concepts/distributed_training.md) -- [Colossal-AI Overview](../concepts/colossalai_overview.md) - - -## Introduction - -In Colossal-AI, a configuration file is required to specify the features the system will inject into the training process. -In this tutorial, we will introduce you how to construct your configuration file and how this config file will be used. -Using configuration file has several advantages: - -1. You can store your feature configuration and training hyper-parameters in different configuration files -2. New features released in the future can be specified in the configuration without code change in the training script - -In this tutorial, we will cover how to define your configuration file. - -## Configuration Definition - -In a configuration file, there are two types of variables. One serves as feature specification and the other serves -as hyper-parameters. All feature-related variables are reserved keywords. For example, if you want to use mixed precision -training, you need to use the variable name `fp16` in the config file and follow a pre-defined format. - -### Feature Specification - -There is an array of features Colossal-AI provides to speed up training. Each feature is defined by a corresponding field -in the config file. In this tutorial, we are not giving the config details for all the features, but rather we are providing -an illustration of how to specify a feature. **The details of each feature can be found in its respective tutorial.** - -To illustrate the use of config file, we use mixed precision training as an example here. In order to do so, you need to -follow the steps below. - -1. create a configuration file (e.g. `config.py`, the file name can be anything) -2. define the mixed precision configuration in the config file. For example, in order to use mixed precision training -natively provided by PyTorch, you can just write these lines of code below into your config file. - - ```python - from colossalai.amp import AMP_TYPE - - fp16 = dict( - mode=AMP_TYPE.TORCH - ) - ``` - -3. Tell Colossal-AI where your config file is when launch the distributed environment. 
For example, the config file is in -the current directory. - - ```python - import colossalai - - colossalai.launch(config='./config.py', ...) - ``` - -In this way, Colossal-AI knows what features you want to use and will inject this feature during `colossalai.initialize`. - -### Global Hyper-parameters - -Besides feature specification, the config file can also serve as a place to define your training hyper-parameters. This -comes handy when you want to perform multiple experiments, each experiment details can be put into a single config file -to avoid confusion. These parameters will be stored in the global parallel context and can be accessed in the training script. - -For example, you can specify the batch size in your config file. - -```python -BATCH_SIZE = 32 -``` - -After launch, you are able to access your hyper-parameters through global parallel context. - -```python -import colossalai -from colossalai.core import global_context as gpc - -colossalai.launch(config='./config.py', ...) - -# access your parameter -print(gpc.config.BATCH_SIZE) - -``` diff --git a/docs/source/en/basics/engine_trainer.md b/docs/source/en/basics/engine_trainer.md deleted file mode 100644 index e17c37e24a55..000000000000 --- a/docs/source/en/basics/engine_trainer.md +++ /dev/null @@ -1,390 +0,0 @@ -# Use Engine and Trainer in Training - -Author: Shenggui Li, Siqi Mai - -> ⚠️ The information on this page is outdated and will be deprecated. Please check [Booster API](../basics/booster_api.md) for more information. - -**Prerequisite:** -- [Initialize Features](./initialize_features.md) - -## Introduction - -In this tutorial, you will learn how to use the engine and trainer provided in Colossal-AI to train your model. -Before we delve into the details, we would like to first explain the concept of engine and trainer. - -### Engine - -Engine is essentially a wrapper class for model, optimizer and loss function. -When we call `colossalai.initialize`, an engine object will be returned, and it has already been equipped with -functionalities such as gradient clipping, gradient accumulation and zero optimizer as specified in your configuration file. -An engine object will use similar APIs to those of PyTorch training components such that the user has minimum change -to their code. - -Below is a table which shows the commonly used APIs for the engine object. - -| Component | Function | PyTorch | Colossal-AI | -| ------------------------------------- | --------------------------------------------- | ------------------------------- | -------------------------------------- | -| optimizer | Set all gradients to zero before an iteration | optimizer.zero_grad() | engine.zero_grad() | -| optimizer | Update the parameters | optimizer.step() | engine.step() | -| model | Run a forward pass | outputs = model(inputs) | outputs = engine(inputs) | -| criterion | Calculate the loss value | loss = criterion(output, label) | loss = engine.criterion(output, label) | -| criterion | Execute back-propagation on the model | loss.backward() | engine.backward(loss) | - -The reason why we need such an engine class is that we can add more functionalities while hiding the implementations in -the `colossalai.initialize` function. -Imaging we are gonna add a new feature, we can manipulate the model, optimizer, dataloader and loss function in the -`colossalai.initialize` function and only expose an engine object to the user. 
-The user only needs to modify their code to the minimum extent by adapting the normal PyTorch APIs to the Colossal-AI -engine APIs. In this way, they can enjoy more features for efficient training. - -A normal training iteration using engine can be: - -```python -import colossalai - -# build your model, optimizer, criterion, dataloaders -... - -engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model, - optimizer, - criterion, - train_dataloader, - test_dataloader) -for img, label in train_dataloader: - engine.zero_grad() - output = engine(img) - loss = engine.criterion(output, label) - engine.backward(loss) - engine.step() -``` - -### Trainer - -Trainer is a more high-level wrapper for the user to execute training with fewer lines of code. However, in pursuit of more abstraction, it loses some flexibility compared to engine. The trainer is designed to execute a forward and backward step to perform model weight update. It is easy to create a trainer object by passing the engine object. The trainer has a default value `None` for the argument `schedule`. In most cases, we leave this value to `None` unless we want to use pipeline parallelism. If you wish to explore more about this parameter, you can go to the tutorial on pipeline parallelism. - -```python -from colossalai.logging import get_dist_logger -from colossalai.legacy.trainer import Trainer, hooks - -# build components and initialize with colossalai.initialize -... - -# create a logger so that trainer can log on the console -logger = get_dist_logger() - -# create a trainer object -trainer = Trainer( - engine=engine, - logger=logger -) -``` - - - -In trainer, the user can customize some hooks and attach these hooks to the trainer object. A hook object will execute life-cycle methods periodically based on the training scheme. For example, The `LRSchedulerHook` will execute `lr_scheduler.step()` to update the learning rate of the model during either `after_train_iter` or `after_train_epoch` stages depending on whether the user wants to update the learning rate after each training iteration or only after the entire training epoch. You can store the hook objects in a list and pass it to `trainer.fit` method. `trainer.fit` method will execute training and testing based on your parameters. If `display_process` is True, a progress bar will be displayed on your console to show the training process. - -```python -# define the hooks to attach to the trainer -hook_list = [ - hooks.LossHook(), - hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True), - hooks.AccuracyHook(accuracy_func=Accuracy()), - hooks.LogMetricByEpochHook(logger), -] - -# start training -trainer.fit( - train_dataloader=train_dataloader, - epochs=NUM_EPOCHS, - test_dataloader=test_dataloader, - test_interval=1, - hooks=hook_list, - display_progress=True -) -``` - -If you want to customize your own hook class, you can inherit `hooks.BaseHook` and override the life-cycle methods of your interest. A dummy example to demonstrate how to create a simple log message hook is provided below for your reference. - -```python -from colossalai.logging import get_dist_logger -from colossalai.legacy.trainer import hooks - -class LogMessageHook(hooks.BaseHook): - - def __init__(self, priority=10): - self._logger = get_dist_logger() - - def before_train(self, trainer): - self._logger.info('training starts') - - def after_train(self, trainer): - self._logger.info('training finished') - - -... 
- -# then in your training script -hook_list.append(LogMessageHook()) -``` - - - -In the sections below, I will guide you through the steps required to train a ResNet model with both engine and trainer. - - - -## Explain with ResNet - -### Overview - -In this section we will cover: - -1. Use an engine object to train a ResNet34 model on CIFAR10 dataset -2. Use a trainer object to train a ResNet34 model on CIFAR10 dataset - -The project structure will be like: - -```bash --- config.py --- run_resnet_cifar10_with_engine.py --- run_resnet_cifar10_with_trainer.py -``` - -Steps 1-4 below are commonly used regardless of using engine or trainer. Thus, steps 1-4 + step 5 will be your `run_resnet_cifar10_with_engine.py` and steps 1-4 + step 6 will form `run_resnet_cifar10_with_trainer.py`. - -### Hands-on Practice - -#### Step 1. Create a Config File - -In your project folder, create a `config.py`. This file is to specify some features you may want to use to train your model. A sample config file is as below: - -```python -from colossalai.amp import AMP_TYPE - -BATCH_SIZE = 128 -NUM_EPOCHS = 200 - -fp16=dict( - mode=AMP_TYPE.TORCH -) -``` - -In this config file, we specify that we want to use batch size 128 per GPU and run for 200 epochs. These two parameters are exposed by `gpc.config`. For example, you can use `gpc.config.BATCH_SIZE` to access the value you store in your config file. The `fp16` configuration tells `colossalai.initialize` to use mixed precision training provided by PyTorch to train the model with better speed and lower memory consumption. - -#### Step 2. Initialize Distributed Environment - -We need to initialize the distributed training environment. This has been introduced in the tutorial on how to -[launch Colossal-AI](./launch_colossalai.md). For this demonstration, we use `launch_from_torch` and PyTorch launch utility. - -```python -import colossalai - -# ./config.py refers to the config file we just created in step 1 -colossalai.launch_from_torch(config='./config.py') -``` - -#### Step 3. Create all the training components - -In this step, we can create all the components used for training. These components include: - -1. Model -2. Optimizer -3. Criterion/loss function -4. Training/Testing dataloaders -5. Learning rate Scheduler -6. Logger - - - -To build these components, you need to import the following modules: - -```python -from pathlib import Path -from colossalai.logging import get_dist_logger -import torch -import os -from colossalai.core import global_context as gpc -from colossalai.utils import get_dataloader -from torchvision import transforms -from colossalai.nn.lr_scheduler import CosineAnnealingLR -from torchvision.datasets import CIFAR10 -from torchvision.models import resnet34 -``` - - - -Then build your components in the same way as how to normally build them in your PyTorch scripts. In the script below, we set the root path for CIFAR10 dataset as an environment variable `DATA`. You can change it to any path you like, for example, you can change `root=Path(os.environ['DATA'])` to `root='./data'` so that there is no need to set the environment variable. 
- -```python -# build logger -logger = get_dist_logger() - -# build resnet -model = resnet34(num_classes=10) - -# build datasets -train_dataset = CIFAR10( - root='./data', - download=True, - transform=transforms.Compose( - [ - transforms.RandomCrop(size=32, padding=4), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[ - 0.2023, 0.1994, 0.2010]), - ] - ) -) - -test_dataset = CIFAR10( - root='./data', - train=False, - transform=transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[ - 0.2023, 0.1994, 0.2010]), - ] - ) -) - -# build dataloaders -train_dataloader = get_dataloader(dataset=train_dataset, - shuffle=True, - batch_size=gpc.config.BATCH_SIZE, - num_workers=1, - pin_memory=True, - ) - -test_dataloader = get_dataloader(dataset=test_dataset, - add_sampler=False, - batch_size=gpc.config.BATCH_SIZE, - num_workers=1, - pin_memory=True, - ) - -# build criterion -criterion = torch.nn.CrossEntropyLoss() - -# optimizer -optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4) - -# lr_scheduler -lr_scheduler = CosineAnnealingLR(optimizer, total_steps=gpc.config.NUM_EPOCHS) -``` - -#### Step 4. Initialize with Colossal-AI - -Next, the essential step is to obtain the engine class by calling `colossalai.initialize`. As stated in `config.py`, we will be using mixed precision training for training ResNet34 model. `colossalai.initialize` will automatically check your config file and assign relevant features to your training components. In this way, our engine object has already been able to train with mixed precision, but you do not have to explicitly take care of it. - -```python -engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model, - optimizer, - criterion, - train_dataloader, - test_dataloader, - ) -``` - - - -#### Step 5. Train with engine - -With all the training components ready, we can train ResNet34 just like how to normally deal with PyTorch training. - -```python -for epoch in range(gpc.config.NUM_EPOCHS): - # execute a training iteration - engine.train() - for img, label in train_dataloader: - img = img.cuda() - label = label.cuda() - - # set gradients to zero - engine.zero_grad() - - # run forward pass - output = engine(img) - - # compute loss value and run backward pass - train_loss = engine.criterion(output, label) - engine.backward(train_loss) - - # update parameters - engine.step() - - # update learning rate - lr_scheduler.step() - - # execute a testing iteration - engine.eval() - correct = 0 - total = 0 - for img, label in test_dataloader: - img = img.cuda() - label = label.cuda() - - # run prediction without back-propagation - with torch.no_grad(): - output = engine(img) - test_loss = engine.criterion(output, label) - - # compute the number of correct prediction - pred = torch.argmax(output, dim=-1) - correct += torch.sum(pred == label) - total += img.size(0) - - logger.info( - f"Epoch {epoch} - train loss: {train_loss:.5}, test loss: {test_loss:.5}, acc: {correct / total:.5}, lr: {lr_scheduler.get_last_lr()[0]:.5g}", ranks=[0]) -``` - -#### Step 6. 
Train with trainer - -If you wish to train with a trainer object, you can follow the code snippet below: - -```python -from colossalai.legacy.nn.metric import Accuracy -from colossalai.legacy.trainer import Trainer, hooks - - -# create a trainer object -trainer = Trainer( - engine=engine, - logger=logger -) - -# define the hooks to attach to the trainer -hook_list = [ - hooks.LossHook(), - hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True), - hooks.AccuracyHook(accuracy_func=Accuracy()), - hooks.LogMetricByEpochHook(logger), - hooks.LogMemoryByEpochHook(logger) -] - -# start training -# run testing every 1 epoch -trainer.fit( - train_dataloader=train_dataloader, - epochs=gpc.config.NUM_EPOCHS, - test_dataloader=test_dataloader, - test_interval=1, - hooks=hook_list, - display_progress=True -) -``` - - - -#### Step 7. Start Distributed Training - -Lastly, we can invoke the scripts using the distributed launcher provided by PyTorch as we used `launch_from_torch` in Step 2. You need to replace `` with the number of GPUs available on your machine. This number can be 1 if you only want to use 1 GPU. If you wish to use other launchers, you can refer to the tutorial on How to Launch Colossal-AI. - -```bash -# with engine -python -m torch.distributed.launch --nproc_per_node --master_addr localhost --master_port 29500 run_resnet_cifar10_with_engine.py -# with trainer -python -m torch.distributed.launch --nproc_per_node --master_addr localhost --master_port 29500 run_resnet_cifar10_with_trainer.py -``` - diff --git a/docs/source/en/basics/initialize_features.md b/docs/source/en/basics/initialize_features.md deleted file mode 100644 index b89017427476..000000000000 --- a/docs/source/en/basics/initialize_features.md +++ /dev/null @@ -1,51 +0,0 @@ -# Initialize Features - -Author: Shenggui Li, Siqi Mai - -> ⚠️ The information on this page is outdated and will be deprecated. Please check [Booster API](../basics/booster_api.md) for more information. - -**Prerequisite:** -- [Distributed Training](../concepts/distributed_training.md) -- [Colossal-AI Overview](../concepts/colossalai_overview.md) - -## Introduction - -In this tutorial, we will cover the use of `colossalai.initialize` which injects features into your training components -(e.g. model, optimizer, dataloader) seamlessly. Calling `colossalai.initialize` is the standard procedure before you run -into your training loops. - -In the section below, I will cover how `colossalai.initialize` works and what we should take note of. - -## Usage - -In a typical workflow, we will launch distributed environment at the beginning of our training script. -Afterwards, we will instantiate our objects such as model, optimizer, loss function, dataloader etc. At this moment, `colossalai.initialize` -can come in to inject features into these objects. A pseudo-code example is like below: - -```python -import colossalai -import torch -... - - -# launch distributed environment -colossalai.launch(config='./config.py', ...) - -# create your objects -model = MyModel() -optimizer = torch.optim.Adam(model.parameters(), lr=0.001) -criterion = torch.nn.CrossEntropyLoss() -train_dataloader = MyTrainDataloader() -test_dataloader = MyTrainDataloader() - -# initialize features -engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model, - optimizer, - criterion, - train_dataloader, - test_dataloader) -``` - -The `colossalai.initialize` function will return an `Engine` object. The engine object is a wrapper -for model, optimizer and loss function. 
**The engine object will run with features specified in the config file.** -More details about the engine can be found in the [Use Engine and Trainer in Training](./engine_trainer.md). diff --git a/docs/source/en/basics/model_checkpoint.md b/docs/source/en/basics/model_checkpoint.md deleted file mode 100644 index c3ba5b04bca2..000000000000 --- a/docs/source/en/basics/model_checkpoint.md +++ /dev/null @@ -1,64 +0,0 @@ -# Model Checkpoint - -Author : Guangyang Lu - -> ⚠️ The information on this page is outdated and will be deprecated. Please check [Booster Checkpoint](../basics/booster_checkpoint.md) for more information. - -**Prerequisite:** -- [Launch Colossal-AI](./launch_colossalai.md) -- [Initialize Colossal-AI](./initialize_features.md) - -**Example Code:** -- [ColossalAI-Examples Model Checkpoint](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/utils/checkpoint) - -**This function is experiential.** - -## Introduction - -In this tutorial, you will learn how to save and load model checkpoints. - -To leverage the power of parallel strategies in Colossal-AI, modifications to models and tensors are needed, for which you cannot directly use `torch.save` or `torch.load` to save or load model checkpoints. Therefore, we have provided you with the API to achieve the same thing. - -Moreover, when loading, you are not demanded to use the same parallel strategy as saving. - -## How to use - -### Save - -There are two ways to train a model in Colossal-AI, by engine or by trainer. -**Be aware that we only save the `state_dict`.** Therefore, when loading the checkpoints, you need to define the model first. - -#### Save when using engine - -```python -from colossalai.utils import save_checkpoint -model = ... -engine, _, _, _ = colossalai.initialize(model=model, ...) -for epoch in range(num_epochs): - ... # do some training - save_checkpoint('xxx.pt', epoch, model) -``` - -#### Save when using trainer -```python -from colossalai.legacy.trainer import Trainer, hooks -model = ... -engine, _, _, _ = colossalai.initialize(model=model, ...) -trainer = Trainer(engine, ...) -hook_list = [ - hooks.SaveCheckpointHook(1, 'xxx.pt', model) - ...] - -trainer.fit(... - hook=hook_list) -``` - -### Load - -```python -from colossalai.utils import load_checkpoint -model = ... -load_checkpoint('xxx.pt', model) -... 
# train or test -``` - diff --git a/docs/source/en/features/1D_tensor_parallel.md b/docs/source/en/features/1D_tensor_parallel.md index 0f01cfd325e5..37c01db31342 100644 --- a/docs/source/en/features/1D_tensor_parallel.md +++ b/docs/source/en/features/1D_tensor_parallel.md @@ -2,10 +2,6 @@ Author: Zhengda Bian, Yongbin Li -**Prerequisite** -- [Define Your Configuration](../basics/define_your_config.md) -- [Configure Parallelization](../basics/configure_parallelization.md) - **Example Code** - [Tensor Parallelism with Shardformer](https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/shardformer/examples) diff --git a/docs/source/en/features/2D_tensor_parallel.md b/docs/source/en/features/2D_tensor_parallel.md index c79e7d196f8b..692e2702edd9 100644 --- a/docs/source/en/features/2D_tensor_parallel.md +++ b/docs/source/en/features/2D_tensor_parallel.md @@ -3,8 +3,6 @@ Author: Zhengda Bian, Yongbin Li **Prerequisite** -- [Define Your Configuration](../basics/define_your_config.md) -- [Configure Parallelization](../basics/configure_parallelization.md) - [1D Tensor Parallelism](./1D_tensor_parallel.md) **Example Code** diff --git a/docs/source/en/features/2p5D_tensor_parallel.md b/docs/source/en/features/2p5D_tensor_parallel.md index b3cbd1c7c727..4a97a39e1eff 100644 --- a/docs/source/en/features/2p5D_tensor_parallel.md +++ b/docs/source/en/features/2p5D_tensor_parallel.md @@ -3,8 +3,6 @@ Author: Zhengda Bian, Yongbin Li **Prerequisite** -- [Define Your Configuration](../basics/define_your_config.md) -- [Configure Parallelization](../basics/configure_parallelization.md) - [1D Tensor Parallelism](./1D_tensor_parallel.md) - [2D Tensor Parallelism](./2D_tensor_parallel.md) diff --git a/docs/source/en/features/3D_tensor_parallel.md b/docs/source/en/features/3D_tensor_parallel.md index 00e6c5fca40c..8f7deb5b6b74 100644 --- a/docs/source/en/features/3D_tensor_parallel.md +++ b/docs/source/en/features/3D_tensor_parallel.md @@ -3,8 +3,6 @@ Author: Zhengda Bian, Yongbin Li **Prerequisite** -- [Define Your Configuration](../basics/define_your_config.md) -- [Configure Parallelization](../basics/configure_parallelization.md) - [1D Tensor Parallelism](./1D_tensor_parallel.md) - [2D Tensor Parallelism](./2D_tensor_parallel.md) diff --git a/docs/source/en/features/gradient_accumulation.md b/docs/source/en/features/gradient_accumulation.md deleted file mode 100644 index 91d89b815bf7..000000000000 --- a/docs/source/en/features/gradient_accumulation.md +++ /dev/null @@ -1,47 +0,0 @@ -# Gradient Accumulation (Outdated) - -Author: Shenggui Li, Yongbin Li - -**Prerequisite** -- [Define Your Configuration](../basics/define_your_config.md) -- [Use Engine and Trainer in Training](../basics/engine_trainer.md) - -**Example Code** -- [ColossalAI-Examples Gradient Accumulation](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_accumulation) - -## Introduction - -Gradient accumulation is a common way to enlarge your batch size for training. -When training large-scale models, memory can easily become the bottleneck and the batch size can be very small, (e.g. 2), -leading to unsatisfactory convergence. Gradient accumulation works by adding up the gradients calculated in multiple iterations, -and only update the parameters in the preset iteration. - -## Usage - -It is simple to use gradient accumulation in Colossal-AI. Just add this following configuration into your config file. -The integer represents the number of iterations to accumulate gradients. 
- -```python -gradient_accumulation = -``` - -## Hands-on Practice - -We provide a [runnable example](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_accumulation) -to demonstrate gradient accumulation. In this example, we set the gradient accumulation size to be 4. You can run the script using this command: - -```shell -python -m torch.distributed.launch --nproc_per_node 1 --master_addr localhost --master_port 29500 run_resnet_cifar10_with_engine.py -``` - -You will see output similar to the text below. This shows gradient is indeed accumulated as the parameter is not updated -in the first 3 steps, but only updated in the last step. - -```text -iteration 0, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0047, 0.0116, -0.0283, 0.0071, -0.0359, -0.0267, -0.0006], device='cuda:0', grad_fn=) -iteration 1, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0047, 0.0116, -0.0283, 0.0071, -0.0359, -0.0267, -0.0006], device='cuda:0', grad_fn=) -iteration 2, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0047, 0.0116, -0.0283, 0.0071, -0.0359, -0.0267, -0.0006], device='cuda:0', grad_fn=) -iteration 3, first 10 elements of param: tensor([-0.0141, 0.0464, 0.0507, 0.0321, 0.0356, -0.0150, 0.0172, -0.0118, 0.0222, 0.0473], device='cuda:0', grad_fn=) -``` - - diff --git a/docs/source/en/features/gradient_accumulation_with_booster.md b/docs/source/en/features/gradient_accumulation_with_booster.md index 7bc4eb47bcd7..347cd6e519bb 100644 --- a/docs/source/en/features/gradient_accumulation_with_booster.md +++ b/docs/source/en/features/gradient_accumulation_with_booster.md @@ -1,9 +1,8 @@ -# Gradient Accumulation (Latest) +# Gradient Accumulation Author: [Mingyan Jiang](https://github.com/jiangmingyan) **Prerequisite** -- [Define Your Configuration](../basics/define_your_config.md) - [Training Booster](../basics/booster_api.md) ## Introduction diff --git a/docs/source/en/features/gradient_clipping.md b/docs/source/en/features/gradient_clipping.md deleted file mode 100644 index 5a23c68e3e27..000000000000 --- a/docs/source/en/features/gradient_clipping.md +++ /dev/null @@ -1,64 +0,0 @@ -# Gradient Clipping (Outdated) - -Author: Boxiang Wang, Haichen Huang, Yongbin Li - -**Prerequisite** -- [Define Your Configuration](../basics/define_your_config.md) -- [Use Engine and Trainer in Training](../basics/engine_trainer.md) - -**Example Code** -- [ColossalAI-Examples Gradient Clipping](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_clipping) - -**Related Paper** -- [On the difficulty of training Recurrent Neural Networks](https://arxiv.org/abs/1211.5063) - -## Introduction - -In order to speed up training process and seek global optimum for better performance, more and more learning -rate schedulers have been proposed. People turn to control learning rate to adjust descent pace during training, -which makes gradient vector better to be uniformed in every step. In that case, the descent pace can be -controlled as expected. As a result, gradient clipping, a technique which can normalize the gradient vector -to circumscribe it in a uniformed length, becomes indispensable for those who desire their better -performance of their models. - -You do not have to worry about implementing gradient clipping when using Colossal-AI, we support gradient -clipping in a powerful and convenient way. All you need is just an additional command in your configuration -file. 
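
For comparison, in a plain PyTorch training loop you would have to call the clipping utility yourself on every step. The snippet below is only a reminder of that manual pattern, not Colossal-AI code:

```python
import torch

model = torch.nn.Linear(10, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

inputs, targets = torch.randn(4, 10), torch.randn(4, 10)
loss = torch.nn.functional.mse_loss(model(inputs), targets)
loss.backward()

# manual clipping: rescale gradients so their global 2-norm is at most 1.0
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```

As the next section explains, this naive call is no longer sufficient once the parameters themselves are sharded across devices.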
- -## Why you should use gradient clipping provided by Colossal-AI - -The reason of why we do not recommend users to write gradient clipping by themselves is that naive gradient clipping -may fail when applying tensor parallelism, pipeline parallelism or MoE. - -According to the illustration below, each GPU only owns a portion of parameters of the weight in a linear layer. -To get correct norm of gradient vector of the weight of the linear layer, the norm of every gradient vector in each GPU -should be summed together. -More complicated thing is that the distribution of bias is different from the distribution of the weight. -The communication group is different in the sum operation. - -(PS: This situation is an old version of 2D parallelism, the implementation in the code is not the same. -But it is a good example about the difficulty to unify all communication in gradient clipping.) - -
- (figure: Layout of parameters)
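
To see why this is tricky, the sketch below shows the extra communication a tensor-parallel-aware norm needs. It is a simplified illustration we add here (the function name and the single `tp_group` argument are our assumptions, and it ignores the bias subtlety described above); it is not the actual Colossal-AI implementation:

```python
import torch
import torch.distributed as dist

def clip_sharded_grad_norm_(params, max_norm: float, tp_group) -> float:
    """Sketch: clip by the *global* 2-norm of gradients that are sharded
    across the ranks of `tp_group` (each rank holds a different shard)."""
    device = params[0].device
    local_sq = torch.zeros(1, device=device)
    for p in params:
        if p.grad is not None:
            local_sq += p.grad.detach().float().norm(2) ** 2
    # sum the squared shard norms so every rank sees the same global norm
    dist.all_reduce(local_sq, op=dist.ReduceOp.SUM, group=tp_group)
    total_norm = local_sq.sqrt().item()
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        for p in params:
            if p.grad is not None:
                p.grad.detach().mul_(clip_coef)
    return total_norm
```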
- -Do not worry about it, since Colossal-AI have handled it for you. - -### Usage -To use gradient clipping, you can just simply add gradient clipping norm in your configuration file. -```python -clip_grad_norm = 1.0 -``` - -### Hands-On Practice - -We provide a [runnable example](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_clipping) -to demonstrate gradient clipping. In this example, we set the gradient clipping vector norm to be 1.0. You can run the script using this command: - -```shell -python -m torch.distributed.launch --nproc_per_node 1 --master_addr localhost --master_port 29500 train_with_engine.py -``` - - diff --git a/docs/source/en/features/gradient_clipping_with_booster.md b/docs/source/en/features/gradient_clipping_with_booster.md index 341a608a5c7b..14eee67bc019 100644 --- a/docs/source/en/features/gradient_clipping_with_booster.md +++ b/docs/source/en/features/gradient_clipping_with_booster.md @@ -1,9 +1,8 @@ -# Gradient Clipping (Latest) +# Gradient Clipping Author: [Mingyan Jiang](https://github.com/jiangmingyan) **Prerequisite** -- [Define Your Configuration](../basics/define_your_config.md) - [Training Booster](../basics/booster_api.md) **Related Paper** diff --git a/docs/source/en/features/gradient_handler.md b/docs/source/en/features/gradient_handler.md deleted file mode 100644 index 66e5e3a9dfbd..000000000000 --- a/docs/source/en/features/gradient_handler.md +++ /dev/null @@ -1,64 +0,0 @@ -# Gradient Handler - -Author: Shenggui Li, Yongbin Li - -**Prerequisite** -- [Define Your Configuration](../basics/define_your_config.md) -- [Use Engine and Trainer in Training](../basics/engine_trainer.md) - -**Example Code** -- [ColossalAI-Examples Gradient Handler](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_handler) - -## Introduction - -In distributed training, gradient synchronization is required at the end of each iteration. This is important because we -need to make sure the parameters are updated with the same gradients in different machines so that the resulting parameters -are the same. This is often seen in data parallel as the model is replicated across data parallel ranks. - -In Colossal-AI, we provide an interface for users to customize how they want to handle the synchronization. This brings -flexibility in cases such as implementing a new parallelism method. - -When gradient handlers are used, PyTorch `DistributedDataParallel` will not be used as it will synchronize automatically. - -## Customize Your Gradient Handlers - -To implement a customized gradient handler, you need to follow these steps. -1. inherit `BaseGradientHandler` in Colossal-AI. -2. register the gradient handler into the `GRADIENT_HANDLER`. -3. implement `handle_gradient` method. - -```python -from colossalai.legacy.registry import GRADIENT_HANDLER -from colossalai.legacy.engine.gradient_handler import BaseGradientHandler - - -@GRADIENT_HANDLER.register_module -class MyGradientHandler(BaseGradientHandler): - - def handle_gradient(self): - do_something() - - -``` - - -## Usage - -To use a gradient handler, you need to specify your gradient handler in the config file. The gradient handler -will be automatically built and attached to the engine. - -```python -gradient_handler = [dict(type='MyGradientHandler')] -``` - - -### Hands-On Practice - -We provide a [runnable example](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_handler) -to demonstrate the use of gradient handler. 
In this example, we used `DataParallelGradientHandler` instead of PyTorch -`DistributedDataParallel` for data parallel training. - -```shell -python -m torch.distributed.launch --nproc_per_node 4 --master_addr localhost --master_port 29500 train_with_engine.py -``` - diff --git a/docs/source/en/features/mixed_precision_training.md b/docs/source/en/features/mixed_precision_training.md deleted file mode 100644 index 164b2a21598c..000000000000 --- a/docs/source/en/features/mixed_precision_training.md +++ /dev/null @@ -1,368 +0,0 @@ -# Auto Mixed Precision Training (Outdated) - -Author: Chuanrui Wang, Shenggui Li, Yongbin Li - -**Prerequisite** -- [Define Your Configuration](../basics/define_your_config.md) -- [Use Engine and Trainer in Training](../basics/engine_trainer.md) - -**Example Code** -- [ColossalAI-Examples AMP](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/amp) - -**Related Paper** -- [Accelerating Scientific Computations with Mixed Precision Algorithms](https://arxiv.org/abs/0808.2794) - - -## Introduction - -AMP stands for automatic mixed precision training. -In Colossal-AI, we have incorporated different implementations of mixed precision training: - -1. torch.cuda.amp -2. apex.amp -3. naive amp - - -| Colossal-AI | support tensor parallel | support pipeline parallel | fp16 extent | -| ----------- | ----------------------- | ------------------------- | ----------- | -| AMP_TYPE.TORCH | ✅ | ❌ | Model parameters, activation, gradients are downcast to fp16 during forward and backward propagation | -| AMP_TYPE.APEX | ❌ | ❌ | More fine-grained, we can choose opt_level O0, O1, O2, O3 | -| AMP_TYPE.NAIVE | ✅ | ✅ | Model parameters, forward and backward operations are all downcast to fp16 | - -The first two rely on the original implementation of PyTorch (version 1.6 and above) and NVIDIA Apex. -The last method is similar to Apex O2 level. -Among these methods, apex AMP is not compatible with tensor parallelism. -This is because that tensors are split across devices in tensor parallelism, thus, it is required to communicate among different processes to check if inf or nan occurs in the whole model weights. -We modified the torch amp implementation so that it is compatible with tensor parallelism now. - -> ❌️ fp16 and zero configuration are not compatible -> -> ⚠️ Pipeline only support naive AMP currently - -We recommend you to use torch AMP as it generally gives better accuracy than naive AMP if no pipeline is used. - -## Table of Contents - -In this tutorial we will cover: - -1. AMP introduction -2. AMP in Colossal-AI -3. Hands-on Practice - -## AMP Introduction - -Automatic Mixed Precision training is a mixture of FP16 and FP32 training. - -Half-precision float point format (FP16) has lower arithmetic complexity and higher compute efficiency. -Besides, fp16 requires half of the storage needed by fp32 and saves memory & network bandwidth, which makes more memory -available for large batch size and model size. - -However, there are other operations, like reductions, which require the dynamic range of fp32 to avoid numeric overflow/underflow. That's the reason why we introduce automatic mixed precision, attempting to match each operation to its appropriate data type, which can reduce the memory footprint and augment training efficiency. - -
- (figure: Illustration of an ordinary AMP, from the PatrickStar paper)
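
As background, the snippet below shows roughly what PyTorch's own `torch.cuda.amp` does in a bare training loop: autocast for the forward pass plus dynamic loss scaling. It requires a GPU and is shown only for intuition; the Colossal-AI configuration in the next section wraps this for you:

```python
import torch

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = torch.cuda.amp.GradScaler()               # dynamic loss scaling

data = torch.randn(16, 1024, device='cuda')
target = torch.randn(16, 1024, device='cuda')

for _ in range(10):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():                # mixed fp16/fp32 forward pass
        loss = torch.nn.functional.mse_loss(model(data), target)
    scaler.scale(loss).backward()                  # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)                         # unscales grads, skips the step on inf/NaN
    scaler.update()                                # adjust the scale factor for the next iteration
```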
- -## AMP in Colossal-AI - -We supported three AMP training methods and allowed the user to train with AMP with no code. You can just simply add `fp16` -configuration in your configuration file to use AMP. - - -```python -from colossalai.amp import AMP_TYPE - -# use Torch AMP -fp16=dict( - mode = AMP_TYPE.TORCH -) - -# use naive AMP -fp16=dict( - mode = AMP_TYPE.NAIVE -) - -# use NVIDIA Apex AMP -fp16=dict( - mode = AMP_TYPE.APEX -) - -``` - -> These are the minimum configuration, full configuration are stated in the section later - -### AMP Modularity - -AMP module is designed to be completely modular and can be used independently. -If you wish to only use AMP in your code base without `colossalai.initialize`, -you can use `colossalai.amp.convert_to_amp`. - -```python -from colossalai.amp import AMP_TYPE - -# example of using torch amp -model, optimizer, criterion = colossalai.amp.convert_to_amp(model, - optimizer, - criterion, - AMP_TYPE.TORCH) -``` - -### Torch AMP Configuration - -```python -from colossalai.amp import AMP_TYPE - -fp16=dict( - mode=AMP_TYPE.TORCH, - - # below are default values for grad scaler - init_scale=2.**16, - growth_factor=2.0, - backoff_factor=0.5, - growth_interval=2000, - enabled=True -) -``` - -With optional arguments: -- init_scale(float, optional, default=2.**16): Initial scale factor -- growth_factor(float, optional, default=2.0): Factor by which the scale is multiplied during `update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations. -- backoff_factor(float, optional, default=0.5): Factor by which the scale is multiplied during `update` if inf/NaN gradients occur in an iteration. -- growth_interval(int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients that must occur for the scale to be multiplied by ``growth_factor``. -- enabled(bool, optional, default=True): If ``False``, disables gradient scaling. `step` simply invokes the underlying ``optimizer.step()``, and other methods become no-ops. - -### Apex AMP Configuration - -For this mode, we rely on the Apex implementation for mixed precision training. -We support this plugin because it allows for finer control on the granularity of mixed precision. -For example, O2 level (optimization level 2) will keep batch normalization in fp32. - -If you look for more details, please refer to [Apex Documentation](https://nvidia.github.io/apex/). - -```python -from colossalai.amp import AMP_TYPE - -fp16 = dict( - mode=AMP_TYPE.APEX, - - # below are the default values - enabled=True, - opt_level='O1', - cast_model_type=None, - patch_torch_functions=None, - keep_batchnorm_fp32=None, - master_weights=None, - loss_scale=None, - cast_model_outputs=None, - num_losses=1, - verbosity=1, - min_loss_scale=None, - max_loss_scale=16777216.0 -) -``` - -Parameters: -- enabled(bool, optional, default=True): If False, renders all AMP calls no-ops, so your script should run as if Amp were not present. - -- opt_level(str, optional, default="O1" ): Pure or mixed precision optimization level. -Accepted values are “O0”, “O1”, “O2”, and “O3”, explained in detail above Apex AMP Documentation. - -- num_losses(int, optional, default=1): Option to tell AMP in advance how many losses/backward passes you plan to use. -When used in conjunction with the loss_id argument to `amp.scale_loss`, enables Amp to use a different loss scale per -loss/backward pass, which can improve stability. 
If num_losses is left to 1, Amp will still support multiple -losses/backward passes, but use a single global loss scale for all of them. - -- verbosity(int, default=1): Set to 0 to suppress Amp-related output. - -- min_loss_scale(float, default=None): Sets a floor for the loss scale values that can be chosen by dynamic loss scaling. -The default value of None means that no floor is imposed. If dynamic loss scaling is not used, min_loss_scale is ignored. - -- max_loss_scale(float, default=2.**24 ): Sets a ceiling for the loss scale values that can be chosen by dynamic loss -scaling. If dynamic loss scaling is not used, max_loss_scale is ignored. - -Currently, the under-the-hood properties that govern pure or mixed precision training are the following: -cast_model_type, patch_torch_functions, keep_batchnorm_fp32, master_weights, loss_scale. -They are optional properties override once opt_level is determined - -- cast_model_type: Casts your model’s parameters and buffers to the desired type. -- patch_torch_functions: Patch all Torch functions and Tensor methods to perform Tensor Core-friendly ops like GEMMs and convolutions in FP16, and any ops that benefit from FP32 precision in FP32. -- keep_batchnorm_fp32: To enhance precision and enable cudnn batchnorm (which improves performance), it’s often beneficial to keep batchnorm weights in FP32 even if the rest of the model is FP16. -- master_weights: Maintain FP32 master weights to accompany any FP16 model weights. FP32 master weights are stepped by the optimizer to enhance precision and capture small gradients. -- loss_scale: If loss_scale is a float value, use this value as the static (fixed) loss scale. If loss_scale is the string "dynamic", adaptively adjust the loss scale over time. Dynamic loss scale adjustments are performed by Amp automatically. - - -### Naive AMP Configuration - -In Naive AMP mode, we achieved mixed precision training while maintaining compatibility with complex tensor and pipeline parallelism. -This AMP mode will cast all operations into fp16. -The following code block shows the `config.py` file for this mode. - -```python -from colossalai.amp import AMP_TYPE - -fp16 = dict( - mode=AMP_TYPE.NAIVE, - - # below are the default values - log_num_zeros_in_grad=False, - initial_scale=2 ** 32, - min_scale=1, - growth_factor=2, - backoff_factor=0.5, - growth_interval=1000, - hysteresis=2 -) -``` - -The default parameters of Naive AMP: -- log_num_zeros_in_grad(bool): return number of zeros in the gradients. -- initial_scale(int): initial scale of gradient scaler -- growth_factor(int): the growth rate of loss scale -- backoff_factor(float): the decrease rate of loss scale -- hysteresis(int): delay shift in dynamic loss scaling -- max_scale(int): maximum loss scale allowed -- verbose(bool): if set to `True`, will print debug info - -When using `colossalai.initialize`, you are required to first instantiate a model, an optimizer and a criterion. -The output model is converted to AMP model of smaller memory consumption. -If your input model is already too large to fit in a GPU, please instantiate your model weights in `dtype=torch.float16`. -Otherwise, try smaller models or checkout more parallelization training techniques! - - -## Hands-on Practice - -We provide a [runnable example](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/amp) which demonstrates -the use of AMP with Colossal-AI. In this practice, we will use Torch AMP as an example, but do note that config files are provided for all AMP modes. 
- -### Step 1. Create a config file - -Create a `config.py` and add the `fp16` configuration. - -```python -# in config.py -from colossalai.amp import AMP_TYPE - -BATCH_SIZE = 128 -DROP_RATE = 0.1 -NUM_EPOCHS = 300 - -fp16 = dict( - mode=AMP_TYPE.TORCH, -) - -clip_grad_norm = 1.0 -``` - -### Step 2. Import libraries in train_with_engine.py - -Create a `train_with_engine.py` and import the necessary dependencies. Remember to install `scipy` and `timm` by running -`pip install timm scipy`. - -```python -import os -import colossalai -import torch -from pathlib import Path -from colossalai.core import global_context as gpc -from colossalai.logging import get_dist_logger -from colossalai.utils import get_dataloader -from colossalai.legacy.trainer import Trainer, hooks -from colossalai.nn.lr_scheduler import LinearWarmupLR -from timm.models import vit_base_patch16_224 -from torchvision import datasets, transforms - -``` - -### Step 3. Initialize Distributed Environment - -We then need to initialize distributed environment. For demo purpose, we uses `launch_from_torch`. You can refer to [Launch Colossal-AI](../basics/launch_colossalai.md) -for other initialization methods. - -```python -# initialize distributed setting -parser = colossalai.get_default_parser() -args = parser.parse_args() - -# launch from torch -colossalai.launch_from_torch(config=args.config) - -``` - -### Step 4. Create training components - -Build your model, optimizer, loss function, lr scheduler and dataloaders. Note that the root path of the dataset is -obtained from the environment variable `DATA`. You may `export DATA=/path/to/data` or change `Path(os.environ['DATA'])` -to a path on your machine. Data will be automatically downloaded to the root path. - -```python -# build model - model = vit_base_patch16_224(drop_rate=0.1) - - # build dataloader - train_dataset = datasets.Caltech101( - root=Path(os.environ['DATA']), - download=True, - transform=transforms.Compose([ - transforms.Resize(256), - transforms.RandomResizedCrop(224), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - Gray2RGB(), - transforms.Normalize([0.5, 0.5, 0.5], - [0.5, 0.5, 0.5]) - ])) - - train_dataloader = get_dataloader(dataset=train_dataset, - shuffle=True, - batch_size=gpc.config.BATCH_SIZE, - num_workers=1, - pin_memory=True, - ) - - # build optimizer - optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, weight_decay=0.1) - - # build loss - criterion = torch.nn.CrossEntropyLoss() - - # lr_scheduler - lr_scheduler = LinearWarmupLR(optimizer, warmup_steps=50, total_steps=gpc.config.NUM_EPOCHS) -``` - -### Step 5. Inject AMP Feature - -Call `colossalai.initialize` to convert the training components to be running with FP16. - -```python -engine, train_dataloader, _, _ = colossalai.initialize( - model, optimizer, criterion, train_dataloader, - ) -``` - -### Step 6. Train with Engine - -Use engine in a normal training loops. - -```python -engine.train() -for epoch in range(gpc.config.NUM_EPOCHS): - for img, label in enumerate(train_dataloader): - img = img.cuda() - label = label.cuda() - engine.zero_grad() - output = engine(img) - loss = engine.criterion(output, label) - engine.backward(loss) - engine.step() - lr_scheduler.step() -``` - -### Step 7. Invoke Training Scripts - -Use the following command to start the training scripts. You can change `--nproc_per_node` to use a different number of GPUs. 
- -```shell -python -m torch.distributed.launch --nproc_per_node 4 --master_addr localhost --master_port 29500 train_with_engine.py --config config/config_AMP_torch.py -``` - diff --git a/docs/source/en/features/mixed_precision_training_with_booster.md b/docs/source/en/features/mixed_precision_training_with_booster.md index 1240b47d5d2e..8e702a578ea4 100644 --- a/docs/source/en/features/mixed_precision_training_with_booster.md +++ b/docs/source/en/features/mixed_precision_training_with_booster.md @@ -1,10 +1,9 @@ -# Auto Mixed Precision Training (Latest) +# Auto Mixed Precision Training Author: [Mingyan Jiang](https://github.com/jiangmingyan) **Prerequisite** -- [Define Your Configuration](../basics/define_your_config.md) - [Training Booster](../basics/booster_api.md) **Related Paper** @@ -61,7 +60,7 @@ However, there are other operations, like reductions, which require the dynamic ## AMP in Colossal-AI -We supported three AMP training methods and allowed the user to train with AMP with no code. If you want to train with amp, just assign `mixed_precision` with `fp16` when you instantiate the `Booster`. Now booster support torch amp, the other two(apex amp, naive amp) are still started by `colossalai.initialize`, if needed, please refer to [this](./mixed_precision_training.md). Next we will support `bf16`, `fp8`. +We supported three AMP training methods and allowed the user to train with AMP with no code. If you want to train with amp, just assign `mixed_precision` with `fp16` when you instantiate the `Booster`. Next we will support `bf16`, `fp8`. ### Start with Booster diff --git a/docs/source/en/features/zero_with_chunk.md b/docs/source/en/features/zero_with_chunk.md index 955559ba2a2b..42305182b8b8 100644 --- a/docs/source/en/features/zero_with_chunk.md +++ b/docs/source/en/features/zero_with_chunk.md @@ -204,7 +204,7 @@ def main(): torch.cuda.synchronize() ``` -> ⚠️ Note: If you want to use the Gemini module, please do not use the [Gradient Accumulation](../features/gradient_accumulation.md) we mentioned before。 +> ⚠️ Note: If you want to use the Gemini module, please do not use the [Gradient Accumulation](../features/gradient_accumulation_with_booster.md) we mentioned before。 The complete example can be found on [Train GPT with Colossal-AI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt). diff --git a/docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md b/docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md deleted file mode 100644 index 812b9c34e4da..000000000000 --- a/docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md +++ /dev/null @@ -1,113 +0,0 @@ -# 添加你自己的并行模式 - -作者: Shenggui Li, Yongbin Li - -**前置教程** -- [定义配置文件](../basics/define_your_config.md) -- [并行配置](../basics/configure_parallelization.md) - -## 引言 - -为了使研究人员和工程师能够以更少的努力将我们的系统扩展到其他新颖的大规模分布式训练算法,我们已经将训练生命周期中的各种组件解耦。你可以通过简单地继承基类来实现你自己的并行模式。 - -主要组件有: - -1. `ProcessGroupInitializer` -2. `GradientHandler` -3. `Schedule` - -**目前这需要对源代码进行一些改动,因此我们建议你用`-e`标志从源代码安装。`-e`标志使得安装是可编辑的,因此,你的代码变化将反映在你的Python运行时中。我们将在这方面努力,以避免在未来的版本中改变源代码。** - - -## 进程组初始化器 - -并行通常由进程组来管理,参与相同并行算法的进程被置于同一进程组。对于不同的并行算法,需要创建不同的进程组。 -Colossal-AI 为用户提供了一个全局 context,使他们能够轻松地管理进程组。如果你想添加新的进程组,你可以很容易地定义一个新的类并在你的配置文件中设置它。为了定义你自己的进程组创建方式,你可以按照下面的步骤来创建一个新的分布式初始化。 - -1. 在 `colossalai.legacy.context.parallel_mode.ParallelMode` 中添加你自己的并行模式。 - ```python - class ParallelMode(Enum): - GLOBAL = 'global' - DATA = 'data' - PIPELINE = 'pipe' - ... - - NEW_MODE = 'new_mode' # define your mode here - ``` - -2. 
创建一个 `ProcessGroupInitializer`。 你可以参考 `colossalai.context.dist_group_initializer` 中给出的例子,前六个参数是固定的。 -`ParallelContext` 将为你传入这些参数。如果你需要设置其他参数,可以像下面的例子中的 `arg1, arg2` 一样,在后面添加它。 -最后,通过添加装饰器 `@DIST_GROUP_INITIALIZER.register_module` 将你的初始化程序注册到注册表。 - ```python - # sample initializer class - @DIST_GROUP_INITIALIZER.register_module - class MyParallelInitializer(ProcessGroupInitializer): - - def __init__(self, - rank: int, - world_size: int, - config: Config, - data_parallel_size: int, - pipeline_parallel_size: int, - tensor_parallel_size: int, - arg1, - arg2): - super().__init__(rank, world_size, config) - self.arg1 = arg1 - self.arg2 = arg2 - # ... your variable init - - def init_parallel_groups(self): - # initialize your process groups - pass - - ``` - 然后,你可以将你的新初始化器插入到 `colossalai.constants.INITIALIZER_MAPPING` 当前的模式与初始化映射中。你可以修改该文件或动态插入新的键值对。 - - ```python - colossalai.constants.INITIALIZER_MAPPING['new_mode'] = 'MyParallelInitializer' - ``` - -3. 在你的配置文件中设置你的初始化器。你可以传入你的自定义参数。这允许 - `ParallelContext` 创建你的初始化器并初始化你期望的进程组。 - - ```python - parallel = dict( - pipeline=dict(size=1), - tensor=dict(size=x, mode='new_mode') # this is where you enable your new parallel mode - ) - ``` - -## 梯度 Handler - -梯度 handler 是对参数的梯度执行 all-reduce 操作的对象。由于不同的 all-reduce 策略或许在不同的并行中被执行,用户可以继承 -`colossalai.legacy.engine.gradient_handler.BaseGradientHandler` 来实现其策略。目前,Colossal-AI 使用普通的数据并行梯度 handler 在数据并行的 rank 间 all-reduce 梯度。 -如果数据并行被检测到,梯度 handler 会被自动添加进 engine。 - -你可以添加你自己的梯度 handler,如下所示: - -```python -from colossalai.legacy.registry import GRADIENT_HANDLER -from colossalai.legacy.engine import BaseGradientHandler - -@GRADIENT_HANDLER.register_module -class YourGradientHandler(BaseGradientHandler): - - def handle_gradient(self): - do_something() - -``` - -之后,你可以在配置文件中指定你要使用的梯度 handler。 - -```python -gradient_handlers = [ - dict(type='YourGradientHandler'), -] -``` - -## Schedule - -Schedule 包含了如何执行前向和后向计算。目前, Colossal-AI 提供了流水和非流水的 schedule。 -如果你想修改前向和后向计算的执行方式,你可以继承 `colossalai.legacy.engine.schedule.BaseSchedule` 并实现 `forward_back_step` 函数。 - diff --git a/docs/source/zh-Hans/advanced_tutorials/define_your_own_parallel_model.md b/docs/source/zh-Hans/advanced_tutorials/define_your_own_parallel_model.md deleted file mode 100644 index 64e8d8bcd14a..000000000000 --- a/docs/source/zh-Hans/advanced_tutorials/define_your_own_parallel_model.md +++ /dev/null @@ -1,31 +0,0 @@ -# 定义你自己的并行模型 - -作者: Zhengda Bian, Yongbin Li - -> ⚠️ 我们正在编写此文档以使其更加详细。 我们将介绍不同并行的机制以及如何使用它们来编写模型。 - -假设您有一个具有数十亿参数的巨大 MLP 模型,其极大的隐藏层大小使其无法直接被单个 GPU 容纳。别担心,Colossal-AI 可以帮你解决这个问题。 -在 Colossal-AI 的帮助下,您可以用所熟悉的为单个 GPU 编写模型的方式编写大模型,而 Colossal-AI 会自动拆分您的模型权重,并将它们完美地分配到一组 GPU 中。我们给出一个简单的示例,展示如何在 Colossal-AI 中编写简单的 2D 并行模型。 - -## 写一个简单的2D并行模型 - -```python -from colossalai.nn import Linear2D -import torch.nn as nn - -class MLP_2D(nn.Module): - - def __init__(self): - super().__init__() - self.linear_1 = Linear2D(in_features=1024, out_features=16384) - self.linear_2 = Linear2D(in_features=16384, out_features=1024) - - def forward(self, x): - x = self.linear_1(x) - x = self.linear_2(x) - return x -``` - -## 使用预定义的模型 - -为了方便您的使用,我们在 Colossal-AI 的 Model Zoo 中提供一些流行的模型,如*BERT*, *ViT*, *MoE* 和 *GPT*,请自由地将它们定制为不同的尺寸,以满足您的特殊需求。 diff --git a/docs/source/zh-Hans/advanced_tutorials/parallelize_your_training_like_Megatron.md b/docs/source/zh-Hans/advanced_tutorials/parallelize_your_training_like_Megatron.md deleted file mode 100644 index dfd1e2910b4e..000000000000 --- a/docs/source/zh-Hans/advanced_tutorials/parallelize_your_training_like_Megatron.md +++ /dev/null 
@@ -1,179 +0,0 @@ -# 使用ColoTensor让串行程序像Megatron-LM一样并行 - -Author: [Haichen Huang](https://github.com/1SAA) and [Jiarui Fang](https://github.com/feifeibear) - -**Prerequisite:** -- [ColoTensor Concepts](../basics/colotensor_concept.md) - -## 介绍 - -在新版本中,我们引入了ColoTensor。ColoTensor为用户使用并行训练提供了极大的便利,使得用户可以在原本的串行代码上,通过较小的修改将训练改为并行。在本教程中,我们将说明如何修改训练模型以自动使代码采取像 Megatron-LM 一样的方式并行训练。我们以 HuggingFace 提供的 GPT-2 模型为例,并提供一种方式让你可以在单个GPU上预训练GPT-2模型。 - -Megatron-LM 提供了一个具有影响力的并行化范式,这个范式主要应用于Transformer大模型的训练。然而,为了大规模训练 Transformer 语言大模型,用户必须使用Megatron-LM提供的特殊模块来构建他们的模型。这给用户带来了一些困难的工作,例如从预先训练的模型中加载权重,或是构建自己的并行训练模型。为了减轻用户的麻烦,我们提供 ColoTensor 类,以完成自动启用张量模型并行。 - -## 定义模型和损失函数 - -首先,我们直接调用 HuggingFace 库中的 GPTModel 和 GPTLoss。 - -```python -import torch -import torch.nn as nn -from transformers import GPT2Config, GPT2LMHeadModel - -class GPTLMModel(nn.Module): - def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257, checkpoint=False): - super().__init__() - self.checkpoint = checkpoint - self.model = GPT2LMHeadModel(GPT2Config(n_embd=hidden_size, n_layer=num_layers, - n_head=num_attention_heads, n_positions=max_seq_len, n_ctx=max_seq_len, vocab_size=vocab_size)) - if checkpoint: - self.model.gradient_checkpointing_enable() - - def forward(self, input_ids, attention_mask): - # Only return lm_logits - return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0] - - -class GPTLMLoss(nn.Module): - def __init__(self): - super().__init__() - self.loss_fn = nn.CrossEntropyLoss() - - def forward(self, logits, labels): - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) -``` - -## 对GPT-2的简短回顾 - -现在,我们回顾一下 GPT-2 模型的结构。每个 GPT-2 模型都可以表示为一个 DAG。如下图所示,每个圆圈代表一个算子,每个方块代表一个权重。每个箭头表示输入数据的流向,而箭头旁边的符号表示输入数据的形状。 - -然后,让我们深入了解一下这个 GPT-2 模型。它由三部分组成,分别是**嵌入模块**、**转换器层**和**分类头**。 - -嵌入模块包含两个权重,符号嵌入权重和位置嵌入权重。在嵌入模块的前向操作之后,原始输入数据的所有序列中的每个单词都会被嵌入到隐藏状态。 - -
-(图:嵌入模块)
- -每个转换器层包含两个块。自注意操作在第一个块中调用,同时一个双层感知器位于第二个块中。 - -
-(图:转换器层)
- -最后,分类头只是一个不加偏差的线性模块,里面只有一个线性权重。 - -## 应用ColoTensor - -两个步骤使您的串行代码采取 Megatron-LM 张量并行风格。 -1. 在ColoInitContext的上下文中初始化模型。 -2. 为每个参数设置 ColoTensorSpec。 - -### 使用 ColoInitContext 初始化 - -我们应该在 ColoInitContext 中构建模型。在该种上下文中,任何初始化的参数都将转换为 ColoParameter 并自动移动到相应的设备上。 - -```python -from colossalai.utils.model.colo_init_context import ColoInitContext - -with ColoInitContext(device=torch.device('cpu')): - model = GPTLMModel() -``` - -### 为每个参数设置 ColoTensorSpec - -模型创建完成后,我们通过ProcessGroup建立分布式环境。这里,我们将张量并行度指定为所有GPU的数量,即数据并行度为一。 - -```python -import torch.distributed as dist -from colossalai.tensor import ProcessGroup - -pg = ProcessGroup(tp_degree=dist.get_world_size()) -``` - -现在,我们需要一些辅助函数为下一步做准备。我们定义了两个函数来切分参数。Megatron-LM张量并行需要沿参数的第一维或最后一维切分参数张量。 - -```python -from colossalai.tensor import ShardSpec, ComputeSpec, ComputePattern, ColoParameter, ProcessGroup - -def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup): - spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - if param.process_group.tp_world_size() == 1: - param.set_process_group(pg) - param.set_tensor_spec(*spec) - - -def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup): - split_param_single_dim_tp1d(0, param, pg) - - -def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup): - split_param_single_dim_tp1d(-1, param, pg) -``` - -然后我们使模型采用张量并行。根据 Megatron 中使用的张量并行,应该沿着张量的最后一个维度进行切片,包括符号嵌入的权重,位置嵌入的权重,自注意力块中的所有线性权重和偏差,以及每个双层感知器中的第一个线性权重和偏差。且需要沿第一个维度切分双层感知器中的第二个线性权重。 - -```python -for mn, module in model.named_modules(): - for pn, param in module.named_parameters(recurse=False): - # set process group for all parameters - param.set_process_group(pg) - - if 'mlp.c_fc' in mn: - if 'weight' in pn or 'bias' in pn: - split_param_col_tp1d(param, pg) # column slice - # keep the shape of the output from c_fc - param.compute_spec.set_output_replicate(False) - elif 'mlp.c_proj' in mn: - if 'weight' in pn: - split_param_row_tp1d(param, pg) # row slice - elif 'wte' in mn or 'wpe' in mn: - split_param_col_tp1d(param, pg) # column slice - elif 'c_attn' in mn or 'c_proj' in mn: - split_param_col_tp1d(param, pg) # column slice -``` - -修改后的模型如下图所示。 - -嵌入模块: - -
-(图:修改后的嵌入模块)
- -转换器层: - -
-(图:修改后的转换器层)
- -一旦用户指定了每个参数的在并行中的分布模式,ColoTensor 就能够推断出所有算子的计算模式,包括矩阵乘法、线性函数、torch.nn.functional 中的其他逐元素函数,以及其他的一些常用函数。这样,用户可以像往常一样训练他们的模型。 - -在我们最新示例中还定义了一个Gemini + ZeRO DDP 的模型从而减小开销,提升效率。这一部分的详细内容可以参考[ZeRO](../features/zero_with_chunk.md),你可以将这两部分内容结合起来看从而理解我们整个训练流程: - -```python -def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"): - from colossalai.zero import GeminiDDP - model = GeminiDDP(model, - device=get_current_device(), - placement_policy=placement_policy, - pin_memory=True, - search_range_m=32) - return model -``` - -## 在单个GPU上预训练GPT-2 - -我们做的上述优化让我们可以在单GPU上训练GPT-2模型,只需要将`run.sh`中设置参数`GPUNUM`=1,再运行文件时就可以在单个GPU上完成模型的训练。 - -GPT-2 示例在[Train GPT with Colossal-AI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt). 获得。 - - - diff --git a/docs/source/zh-Hans/basics/colotensor_concept.md b/docs/source/zh-Hans/basics/colotensor_concept.md deleted file mode 100644 index ab2413e990f7..000000000000 --- a/docs/source/zh-Hans/basics/colotensor_concept.md +++ /dev/null @@ -1,99 +0,0 @@ -# ColoTensor Concepts - -Author: [Jiarui Fang](https://github.com/feifeibear), [Hongxin Liu](https://github.com/ver217) and [Haichen Huang](https://github.com/1SAA) - -> ⚠️ 此页面上的信息已经过时并将被废弃。 - -**Prerequisite:** -- [Colossal-AI Overview](../concepts/colossalai_overview.md) -- [Distributed Training](../concepts/distributed_training.md) -- [Paradigms of Parallelism](../concepts/paradigms_of_parallelism.md) - -## Introduction - -在ColossalAI 0.1.8 版本之后,[ColoTensor](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.html#colossalai.tensor.ColoTensor) 成为 ColossalAI 中张量的基本数据结构。 它是 torch.Tensor 的子类,可以当做 PyTorch Tensor使用。 此外,一些独特的功能使其能够表示一个payload分布在多个 GPU 设备上的Global Tensor,并提供一些列方式操作这个Global Tensor。 在 ColoTensor 的帮助下,用户可以以类似编写串行程序方式,编写的分布式 DNN 训练程序。 - -ColoTensor 包含额外的属性[ColoTensorSpec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.tensor_spec.html#colossalai.tensor.tensor_spec.ColoTensorSpec) -来描述张量的payload分布和计算模式。 - -- ProcessGroup:如何将进程组织为通信组。 -- Distributed Spec:张量如何在进程组之间分布。 -- Compute Spec:计算过程中如何使用张量。 - -我们一一详述。 - -## ProcessGroup - -[ProcessGroup](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.html#colossalai.tensor.ProcessGroup) 类的一个实例描述了如何在进程组中组织进程。进程组内的进程可以一起参与同一个集合通信,比如allgather, allreduce等。进程组组织方式被张量的并行策略支配。比如,如果用户定义了Tensor的张量并行(TP),数据并行(DP)方式,那么进程组的进程组织方式将被自动推导出来。 进程组设置可能因不同的张量而异。 因此,它使我们能够支持更复杂的混合并行。流水线并行(PP)定义不在ProcessGroup中描述,它需要另一套机制,我们将在未来补充ColoTensor应用于PP的相关内容。 - -目前,ColoTensor 的一个进程组由 tp_degree 和 dp_degree 两种配置定义。 在 DP+TP 混合并行的情况下,可以将设备视为 2D 网格。 我们将 TP 通信组放置在设备网格的前导低维上,然后将数据并行组放置在设备网格的高维上。 原因是张量并行比数据并行具有更大的通信开销。 相邻设备放置在一个 TP 进程组内,并且通常放置在同一个节点中。 - -考虑到8个进程配置为tp_degree=4,dp_degree=2,布局如下图。 进程组 tp0 包含 gpu 0,1,2,3。 进程 dp1 包含 gpu 1 和 5。 - -
-(图:Process Group using tp_degree=4, dp_degree=2)
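下面给出一个最小示意(基于上文描述、现已废弃的接口,假设 8 个进程已按 tp_degree=4、dp_degree=2 启动),用来说明该进程组的创建与查询:

```python
# 仅为示意:对应上图的 8 卡布局
from colossalai.tensor import ProcessGroup

pg = ProcessGroup(tp_degree=4, dp_degree=2)
# 张量并行组内的进程数,对应上图中 tp0 = {gpu0, gpu1, gpu2, gpu3}
print(pg.tp_world_size())  # 4
```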
- -## Distributed Spec - -[Distributed Spec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.distspec.html)描述了 ColoTensor 如何在 ProcessGroup 中分布。 - -张量在 DP 进程组之间的分布方式是自动导出的,不需要用户手动指定。 如果这个张量是一个模型参数,它会在 DP 进程组中被复制。 如果是activation张量,则沿tensor最高维度在DP进程组中进行平均分割。 - -因此,在使用 Distributed Spec 时,我们只需要描述张量在 TP 进程组之间的分布方式即可。 TP 进程组目前有两种分布式规范,即 [ShardSpec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.distspec.html#colossalai.tensor.distspec.ShardSpec)和[ReplicaSpec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.distspec.html#colossalai.tensor.distspec.ReplicaSpec)。 ShardSpec 需要指定分区的维度索引 dim 和分区个数 num_partitions。 目前,我们仅支持在单个dim上进行拆分。 TP进程组上不同的dist spec可以通过set_dist_spec()接口相互转换。这些转化操作可以被记录在PyTorch的自动求导机制中,并在反向传播时候触发对应的反向操作。 - -## Compute Spec - -[ComputeSpec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.compute_spec.html#colossalai.tensor.compute_spec.ComputeSpec)类描述Tensor如何参与计算。目前,我们将作为module parameter的ColoTensor设置正确的Compute Pattern。可以触发正取的计算模式。具体应用方式我们会在接下来的文档中展示。 - -## ColoParameter - -[ColoParameter](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.colo_parameter.html#colossalai.tensor.colo_parameter.ColoParameter)是ColoTensor的子类。用来声明Parameter。他和ColoTensor关系和Torch.Tensor和torch.Parameter一致。后者可以让tensor出现在module的parameters()和name_parameters() 的返回值中。 - -## Example - -让我们看一个例子。 使用 tp_degree=4, dp_degree=2 在 8 个 GPU 上初始化并Shard一个ColoTensor。 然后tensor被沿着 TP 进程组中的最后一个维度进行分片。 最后,我们沿着 TP 进程组中的第一个维度(dim 0)对其进行重新Shard。 我们鼓励用户运行代码并观察每个张量的形状。 - - -```python -import torch -import torch.multiprocessing as mp -from colossalai.utils import print_rank_0 -from functools import partial - -import colossalai -from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec, ShardSpec, ComputeSpec, ComputePattern -from colossalai.testing import spawn - -import torch - -def run_dist_tests(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - pg = ProcessGroup(tp_degree=2, dp_degree=2) - - torch.manual_seed(0) - local_tensor = torch.randn(2, 3, 1).cuda() - print_rank_0(f"shape {local_tensor.shape}, {local_tensor.data}") - - spec = ColoTensorSpec(pg, ShardSpec(dims=[-1], num_partitions=[pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - t1 = ColoTensor.from_torch_tensor(local_tensor, spec) - t1 = t1.to_replicate() - print_rank_0(f"shape {t1.shape}, {t1.data}") - - spec2 = ShardSpec([0], [pg.tp_world_size()]) - t1.set_dist_spec(spec2) - print_rank_0(f"shape {t1.shape}, {t1.data}") - -def test_dist_cases(world_size): - spawn(run_dist_tests, world_size) - -if __name__ == '__main__': - test_dist_cases(4) -``` - -:::caution - -The ColoTensor is an experimental feature and may be updated. 
- -::: diff --git a/docs/source/zh-Hans/basics/configure_parallelization.md b/docs/source/zh-Hans/basics/configure_parallelization.md deleted file mode 100644 index 0c2a66572d60..000000000000 --- a/docs/source/zh-Hans/basics/configure_parallelization.md +++ /dev/null @@ -1,138 +0,0 @@ -# 并行配置 - -作者: Shenggui Li, Siqi Mai - -> ⚠️ 此页面上的信息已经过时并将被废弃。请在[Booster插件](../basics/booster_plugins.md)页面查阅更新。 - -**预备知识:** -- [分布式训练](../concepts/distributed_training.md) -- [并行技术](../concepts/paradigms_of_parallelism.md) -- [构建配置文件](./define_your_config.md) - - -## 简介 - -我们在 Colossal-AI 中支持多种并行技术。代码库中的混合并行是指您可以轻松地结合数据并行、流水线并行和张量并行(1D、2D、2.5D、3D)的优势共同来进行并行训练。 - -每种并行方式需要不同的网络拓扑结构,因此要初始化不同的进程组。您可以通过在配置文件中设置 `parallel` 来初始化相应的进程组。 `parallel` 的配置必须遵从以下格式。数据并行度的大小将被根据您对流水线并行和张量并行的输入自动推断。`colossalai.launch` 将根据您的配置自动初始化这些分布式进程组。 - -我们为您提供了一些配置的例子以供参考。 - -```python -# sampler format -parallel = dict( - pipeline=dict("size": int), - tensor=dict("size": int, "mode": '1d' or '2d' or '2.5d' or '3d', "kwargs": Any) -) - -# this is ok -parallel = dict( - pipeline=dict(size=2), - tensor=dict(size=4, mode='2d') -) - -# this is ok -parallel = dict( - pipeline=2, - tensor=dict(size=4, mode='2d') -) - -# this is not ok -# as you need to specify the mode for tensor parallelism -parallel = dict( - pipeline=2, - tensor=4 -) - -# this is ok as well as tensor will be default to size 1 -# and mode None -parallel = dict( - pipeline=2 -) - -# this is ok as well as pipeline will default to size 1 -parallel = dict( - tensor=dict(size=4, mode='2d') -) - -``` - -关键字 `size` 指的是并行维度的并行大小。 例如,流水线大小为2意味着有 -将有2个流水线阶段。张量并行配置中的关键字 `mode` 意味着相应的张量并行技术 -将被初始化,如1D、2D、2.5D、3D。 - -**您也可以选择不在您的配置中使用 "并行",此时流水线和张量的并行度都将默认为大小1。** - -**GPU的总数量必须等于` 数据并行大小 x 张量并行大小 x 流水线并行大小` 。** - -## 数据并行 - -数据并行是最常见的分布式训练方式。它将数据分割成几个碎片分别在每个设备上进行训练。数据并行的配置会自动检测并为您设置。您不需要在您的配置中明确地设置它们。在Colossal-AI 中,有两种方法来处理数据并行的 all-reduce。 - -1. 如果您设置了梯度handler,梯度handler将会all-reduce梯度。 -2. 
若没有指定相应的配置,Colossal-AI 将会使用 PyTorch 的 DistributedDataParallel。 - -在大多数情况下,若您对梯度没有复杂的处理的需求,您将会使用第二种模式。 - -## 1D, 2D, 2.5D 和 3D 并行 - -为了实现混合并行,我们提供了一系列张量并行方法。您可以阅读相应的学术论文进行深入的了解。这些并行模式需要和 Colossal-AI 提供的分布式层一同工作。 - -- 1D: [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) - -- 2D: [An Efficient 2D Method for Training Super-Large Deep Learning Models](https://arxiv.org/abs/2104.05343) - 2D 并行基于 SUMMA 矩阵乘法,它将输入数据、模型权重和层输出切分成两个不同的维度。 这些张量块分布在 `P = N^2` 设备的二维网格上,其中 `N` 是单一维度上张量块的数量。 - -- 2.5D: [2.5-dimensional distributed model training](https://arxiv.org/abs/2105.14500) - 在 2.5D 矩阵乘法的启发下,2.5D 并行引入了一种新的张量并行,进一步将2D张量并行化。其中,`P = N^2 ∗ d` 个处理器被分配到 `d` 层, 每层独立进行矩阵乘法运算,维度为 `N`。 - -- 3D: [Maximizing Parallelism in Distributed Training for Huge Neural Networks](https://arxiv.org/abs/2105.14450) - 我们还介绍了一种 3D 张量并行方法,在三维处理器立方体上并行化神经网络。这种方法在数量为 `P` 的处理器上实现了最佳的 `O(P^{1/3})` 通信开销,而计算和内存的使用都是通过优化的参数和激活的负载平衡来实现的。同时,通过优化参数和 activations 的负载平衡,计算和内存的使用都是均匀分布的。 - -```python -# 1D parallel -parallel = dict( - tensor=dict(size=4, mode='1d') -) - -# 2D parallel -parallel = dict( - tensor=dict(size=4, mode='2d') -) - -# 2.5D parallel -parallel = dict( - tensor=dict(size=8, mode='2.5d', depth=2) -) - -# 3D parallel -parallel = dict( - tensor=dict(size=8, mode='3d') -) -``` - -当您在配置中指定了张量并行模式,您就可以使用其相应的分布式算子。例如,若您设置模式为 `2d`,那么在模型构建中就能使用 `colossalai.nn.Linear2D` 了。 - - -## 流水线并行 - -流水线并行是将模型按层分成几个部分。例如,假设我们有一个简单的模型,它由两个线性层组成。我们有两个 GPU,我们可以将第一个线性层分配给第一个 GPU 而第二层则分配给第二个 GPU。 - -您可以在您的配置文件中设置流水线并行度的大小。当流水线并行度大于1,Colossal-AI 将会自动地创建流水线并行的 schedule,这将会为您定义好模型训练的 `forward` 和 `backward`。 - -```python -parallel = dict( - pipeline=dict(size=4), # number of pipeline stages -) -``` - -## 序列并行 - -针对处理大图片、视频、长文本、长时间医疗监控等数据的需要,Colossal-AI 还提供了序列并行的方法。该方法是在论文[Sequence Parallelism: Making 4D Parallelism Possible](https://arxiv.org/abs/2105.13120)中提出的。您可以指定模式为 `sequence` 来初始化进程组。 - - -```python -parallel = dict( - tensor=dict(size=4, mode='sequence') -) -``` diff --git a/docs/source/zh-Hans/basics/define_your_config.md b/docs/source/zh-Hans/basics/define_your_config.md deleted file mode 100644 index 720e75805e8d..000000000000 --- a/docs/source/zh-Hans/basics/define_your_config.md +++ /dev/null @@ -1,73 +0,0 @@ -# 构建配置文件 - -作者: Guangyang Lu, Shenggui Li, Siqi Mai - -> ⚠️ 此页面上的信息已经过时并将被废弃。请在[Booster API](../basics/booster_api.md)页面查阅更新。 - -**预备知识:** -- [分布式训练](../concepts/distributed_training.md) -- [Colossal-AI 总览](../concepts/colossalai_overview.md) - - -## 简介 - -在 Colossal-AI 中,我们需要一个配置文件来指定系统在训练过程中要注入的特征。在本教程中,我们将向您介绍如何构建您的配置文件以及如何使用这个配置文件。使用配置文件有以下一些好处: - -1. 您可以在不同的配置文件中存储您的特征配置和训练超参数。 -2. 对于我们未来发布的新功能,您亦可以在配置中指定,而无需改变训练脚本的代码。 - -在本教程中,我们将向您介绍如何构建您的配置文件。 - -## 配置定义 - -在一个配置文件中,有两种类型的变量。一种是作为特征说明,另一种是作为超参数。所有与特征相关的变量都是保留关键字。例如,如果您想使用混合精度训练,需要在 config 文件中使用变量名`fp16`,并遵循预先定义的格式。 - -### 功能配置 - -Colossal-AI 提供了一系列的功能来加快训练速度。每个功能都是由配置文件中的相应字段定义的。在本教程中,我们不会给出所有功能的配置细节,而是提供一个如何指定一个功能的说明。**每个功能的细节可以在其各自的教程中找到。** - -为了说明配置文件的使用,我们在这里使用混合精度训练作为例子。您需要遵循以下步骤。 - -1. 创建一个配置文件(例如 `config.py`,您可以指定任意的文件名)。 -2. 在配置文件中定义混合精度的配置。例如,为了使用 PyTorch 提供的原始混合精度训练,您只需将下面这几行代码写入您的配置文件中。 - - ```python - from colossalai.amp import AMP_TYPE - - fp16 = dict( - mode=AMP_TYPE.TORCH - ) - ``` - -3. 当启动分布式环境时,向 Colossal-AI 指定您的配置文件的位置。比如下面的例子是配置文件在当前目录下。 - - ```python - import colossalai - - colossalai.launch(config='./config.py', ...) 
- ``` - -这样,Colossal-AI 便知道您想使用什么功能,并会在 `colossalai.initialize` 期间注入您所需要的功能。 - -### 全局超参数 - -除了功能的配置,您还可以在配置文件中定义训练的超参数。当您想进行多个实验时,这将会变得非常方便。每个实验的细节都可以放在独立的配置文件中,以避免混乱。这些参数将被存储在全局并行环境中,可以在训练脚本中访问。 - -例如,您可以在配置文件中指定批量大小。 - -```python -BATCH_SIZE = 32 -``` - -启动后,您能够通过全局并行上下文访问您的超参数。 - -```python -import colossalai -from colossalai.core import global_context as gpc - -colossalai.launch(config='./config.py', ...) - -# access your parameter -print(gpc.config.BATCH_SIZE) - -``` diff --git a/docs/source/zh-Hans/basics/engine_trainer.md b/docs/source/zh-Hans/basics/engine_trainer.md deleted file mode 100644 index ed5100299212..000000000000 --- a/docs/source/zh-Hans/basics/engine_trainer.md +++ /dev/null @@ -1,387 +0,0 @@ -# 如何在训练中使用 Engine 和 Trainer - -作者: Shenggui Li, Siqi Mai - -> ⚠️ 此页面上的信息已经过时并将被废弃。请在[Booster API](../basics/booster_api.md)页面查阅更新。 - -**预备知识:** -- [初始化功能](./initialize_features.md) - -## 简介 - -在本教程中,您将学习如何使用 Colossal-AI 中提供的 Engine 和 Trainer 来训练您的模型。在深入研究细节之前,我们想先解释一下 Engine 和 Trainer 的概念。 - -### Engine - -Engine 本质上是一个模型、优化器和损失函数的封装类。当我们调用 `colossalai.initialize` 时,一个 Engine 对象将被返回,并且配备了在您的配置文件中指定的梯度剪裁、梯度累计和 ZeRO 优化器等功能。 - -Engine 将使用与 PyTorch 训练组件类似的 API,因此您只需对代码进行微小的修改即可。 - -下表展示了Engine的常用API。 - -| 组件 | 功能 | PyTorch | Colossal-AI | -| ------------------------------------- | --------------------------------------------- | ------------------------------- | -------------------------------------- | -| optimizer | 迭代前将所有梯度设置为零 | optimizer.zero_grad() | engine.zero_grad() | -| optimizer | 更新参数 | optimizer.step() | engine.step() | -| model | 进行一次前向计算 | outputs = model(inputs) | outputs = engine(inputs) | -| criterion | 计算loss值 | loss = criterion(output, label) | loss = engine.criterion(output, label) | -| criterion | 反向计算 | loss.backward() | engine.backward(loss) | - -我们需要这样一个 Engine 类的原因是,我们可以添加更多的功能,同时将实现隐藏在 -`colossalai.initialize` 函数中实现。 -假如我们要添加一个新的功能,我们可以在 `colossalai.initialize` 函数中完成对于模型、优化器、数据加载器和损失函数的功能诠释。不管中间的过程有多复杂,最终我们呈现的以及用户需要使用的只有一个 Engine 类,这将十分便捷。 -用户只需要在最小范围内修改他们的代码,将普通的 PyTorch APIs 调整为 Colossal-AI -Engine 的 API。通过这种方式,他们可以享受更多的功能来进行有效的训练。 - -以下是一个简单的例子: - -```python -import colossalai - -# build your model, optimizer, criterion, dataloaders -... - -engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model, - optimizer, - criterion, - train_dataloader, - test_dataloader) -for img, label in train_dataloader: - engine.zero_grad() - output = engine(img) - loss = engine.criterion(output, label) - engine.backward(loss) - engine.step() -``` - -### Trainer - -Trainer 是一个更高级的封装器,用户可以用更少的代码行来执行训练。 由于 Trainer 的使用会更加简单,相较于 Engine,它会缺少一点灵活性。 Trainer 被设计为进行前向和反向计算来进行模型权重的更新。通过传递 Engine 对象,我们可以很容易地创建一个 Trainer。 -Trainer 的参数 `schedule` 默认值是 `None` 。在大多数情况下,除非我们想使用流水线并行,否则我们把这个值设为 `None`。如果您想探索更多关于这个参数的内容,您可以前往流水线并行的相关教程。 - -```python -from colossalai.logging import get_dist_logger -from colossalai.legacy.trainer import Trainer, hooks - -# build components and initialize with colossalai.initialize -... 
- -# create a logger so that trainer can log on the console -logger = get_dist_logger() - -# create a trainer object -trainer = Trainer( - engine=engine, - logger=logger -) -``` - -在 Trainer 中,用户可以定制一些 hooks,并将这些 hooks 附加到 Trainer 上。hook 将根据训练方案定期地执行生命周期函数。例如,基于用户是想在每次训练迭代后还是只在整个训练周期后更新学习率, -`LRSchedulerHook` 将会在 `after_train_iter` 或 `after_train_epoch` 阶段执行 `lr_scheduler.step()` 去为用户更新学习率。您可以将 hook 存储在一个列表中并将其传递给 `trainer.fit` 方法。`trainer.fit` 方法将根据您的参数执行训练和测试。如果 `display_process` 为 True,将在您的控制台显示一个进度条,以显示训练的过程。 - - -```python -# define the hooks to attach to the trainer -hook_list = [ - hooks.LossHook(), - hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True), - hooks.AccuracyHook(accuracy_func=Accuracy()), - hooks.LogMetricByEpochHook(logger), -] - -# start training -trainer.fit( - train_dataloader=train_dataloader, - epochs=NUM_EPOCHS, - test_dataloader=test_dataloader, - test_interval=1, - hooks=hook_list, - display_progress=True -) -``` - -如果您想定制您的 hook 类,您可以继承 `hooks.BaseHook` 并重写您想要的生命周期方法。下面提供了一个例子来演示如何创建一个简单的关于日志信息的 hook,以供您参考。 - -```python -from colossalai.logging import get_dist_logger -from colossalai.legacy.trainer import hooks - -class LogMessageHook(hooks.BaseHook): - - def __init__(self, priority=10): - self._logger = get_dist_logger() - - def before_train(self, trainer): - self._logger.info('training starts') - - def after_train(self, trainer): - self._logger.info('training finished') - - -... - -# then in your training script -hook_list.append(LogMessageHook()) -``` - - - -在下面的章节中,您将会详细地了解到如何用 Engine 和 Trainer 来训练 ResNet 模型。 - - -## ResNet - -### 总览 - -在本节中,我们将介绍: - -1. 使用一个 Engine 在 CIFAR10 数据集上训练 ResNet34 模型 -2. 使用一个 Trainer 在 CIFAR10 数据集上训练 ResNet34 模型 - -项目结构如下: - -```bash --- config.py --- run_resnet_cifar10_with_engine.py --- run_resnet_cifar10_with_trainer.py -``` - -对于使用 Engine 或 Trainer,步骤 1-4 是通用的。 因此,步骤 1-4 + 步骤 5 将会是对应 `run_resnet_cifar10_with_engine.py` 而 步骤 1-4 + 步骤6 则对应 `run_resnet_cifar10_with_trainer.py`。 - -### 牛刀小试 - -#### 步骤 1. 创建配置文件 - -在你的项目文件夹中,创建一个 `config.py`。这个文件是用来指定一些您可能想用来训练您的模型的特征。下面是一个配置文件的例子。 - -```python -from colossalai.amp import AMP_TYPE - -BATCH_SIZE = 128 -NUM_EPOCHS = 200 - -fp16=dict( - mode=AMP_TYPE.TORCH -) -``` - -在这个配置文件中,我们指定要在每个 GPU 上使用批大小为128,并运行200个 epoch。这两个参数是在 `gpc.config` 中体现的。例如,您可以使用 `gpc.config.BATCH_SIZE` 来访问您存储在配置文件中的批大小值。而 `fp16` 配置则会告诉 `colossalai.initialize` 使用 PyTorch 提供的混合精度训练,以更好的速度和更低的内存消耗来训练模型。 - -#### 步骤 2. 初始化分布式环境 - -我们需要初始化分布式训练环境。这在 [启动 Colossal-AI](./launch_colossalai.md) 中有相应的教程。在当前的演示中,我们使用 `launch_from_torch` 和 PyTorch 启用工具。 - -```python -import colossalai - -# ./config.py refers to the config file we just created in step 1 -colossalai.launch_from_torch(config='./config.py') -``` - -#### 步骤 3. 创建所有的训练组件 - -这时,我们可以创建用于训练的所有组件,包括: - -1. 模型 -2. 优化器 -3. 损失函数 -4. 训练/测试数据加载器 -5. 学习率调度器 -6. 
日志记录器 - - - -为了构建这些组件,您需要导入以下模块。 - -```python -from pathlib import Path -from colossalai.logging import get_dist_logger -import torch -import os -from colossalai.core import global_context as gpc -from colossalai.utils import get_dataloader -from torchvision import transforms -from colossalai.nn.lr_scheduler import CosineAnnealingLR -from torchvision.datasets import CIFAR10 -from torchvision.models import resnet34 -``` - - - -然后按照通常在PyTorch脚本中构建组件的方式来构建组件。在下面的脚本中,我们将CIFAR10数据集的根路径设置为环境变量 `DATA`。您可以把它改为您想要的任何路径,例如,您可以把 `root=Path(os.environ['DATA'])` 改为 `root='./data'` ,这样就不需要设置环境变量。 - -```python -# build logger -logger = get_dist_logger() - -# build resnet -model = resnet34(num_classes=10) - -# build datasets -train_dataset = CIFAR10( - root='./data', - download=True, - transform=transforms.Compose( - [ - transforms.RandomCrop(size=32, padding=4), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[ - 0.2023, 0.1994, 0.2010]), - ] - ) -) - -test_dataset = CIFAR10( - root='./data', - train=False, - transform=transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[ - 0.2023, 0.1994, 0.2010]), - ] - ) -) - -# build dataloaders -train_dataloader = get_dataloader(dataset=train_dataset, - shuffle=True, - batch_size=gpc.config.BATCH_SIZE, - num_workers=1, - pin_memory=True, - ) - -test_dataloader = get_dataloader(dataset=test_dataset, - add_sampler=False, - batch_size=gpc.config.BATCH_SIZE, - num_workers=1, - pin_memory=True, - ) - -# build criterion -criterion = torch.nn.CrossEntropyLoss() - -# optimizer -optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4) - -# lr_scheduler -lr_scheduler = CosineAnnealingLR(optimizer, total_steps=gpc.config.NUM_EPOCHS) -``` - -#### 步骤 4. 用 Colossal-AI 进行初始化 - -接下来,重要的一步是通过调用 `colossalai.initialize` 获得 Engine。正如 `config.py` 中所述,我们将使用混合精度训练来训练 ResNet34 模型。`colossalai.initialize` 将自动检查您的配置文件,并将相关特征分配给您的训练组件。这样一来,我们的 Engine 已经能够进行混合精度训练,而您不需要进行额外的处理。 - -```python -engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model, - optimizer, - criterion, - train_dataloader, - test_dataloader, - ) -``` - - - -#### 步骤 5. 用 Engine 进行训练 - -当所有的训练组件都准备好后,我们就可以像使用 PyTorch 一样训练 ResNet34 了。 - -```python -for epoch in range(gpc.config.NUM_EPOCHS): - # execute a training iteration - engine.train() - for img, label in train_dataloader: - img = img.cuda() - label = label.cuda() - - # set gradients to zero - engine.zero_grad() - - # run forward pass - output = engine(img) - - # compute loss value and run backward pass - train_loss = engine.criterion(output, label) - engine.backward(train_loss) - - # update parameters - engine.step() - - # update learning rate - lr_scheduler.step() - - # execute a testing iteration - engine.eval() - correct = 0 - total = 0 - for img, label in test_dataloader: - img = img.cuda() - label = label.cuda() - - # run prediction without back-propagation - with torch.no_grad(): - output = engine(img) - test_loss = engine.criterion(output, label) - - # compute the number of correct prediction - pred = torch.argmax(output, dim=-1) - correct += torch.sum(pred == label) - total += img.size(0) - - logger.info( - f"Epoch {epoch} - train loss: {train_loss:.5}, test loss: {test_loss:.5}, acc: {correct / total:.5}, lr: {lr_scheduler.get_last_lr()[0]:.5g}", ranks=[0]) -``` - -#### 步骤 6. 
用 Trainer 进行训练 - -如果您想用 Trainer 进行训练,您可以参考下面的代码进行您的实验。 - - -```python -from colossalai.legacy.nn.metric import Accuracy -from colossalai.legacy.trainer import Trainer, hooks - - -# create a trainer object -trainer = Trainer( - engine=engine, - logger=logger -) - -# define the hooks to attach to the trainer -hook_list = [ - hooks.LossHook(), - hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True), - hooks.AccuracyHook(accuracy_func=Accuracy()), - hooks.LogMetricByEpochHook(logger), - hooks.LogMemoryByEpochHook(logger) -] - -# start training -# run testing every 1 epoch -trainer.fit( - train_dataloader=train_dataloader, - epochs=gpc.config.NUM_EPOCHS, - test_dataloader=test_dataloader, - test_interval=1, - hooks=hook_list, - display_progress=True -) -``` - - - -#### 步骤 7. 开始分布式训练 - -最后,我们可以使用 PyTorch 提供的分布式启动器来调用脚本,因为我们在步骤2中使用了 `launch_from_torch`。您需要把`` 替换成您机器上可用的GPU数量。如果您只想使用一个 GPU,您可以把这个数字设为1。如果您想使用其他的启动器,请您参考如何启动 Colossal-AI 的教程。 - - -```bash -# with engine -python -m torch.distributed.launch --nproc_per_node --master_addr localhost --master_port 29500 run_resnet_cifar10_with_engine.py -# with trainer -python -m torch.distributed.launch --nproc_per_node --master_addr localhost --master_port 29500 run_resnet_cifar10_with_trainer.py -``` - diff --git a/docs/source/zh-Hans/basics/initialize_features.md b/docs/source/zh-Hans/basics/initialize_features.md deleted file mode 100644 index 1c28d658e1bc..000000000000 --- a/docs/source/zh-Hans/basics/initialize_features.md +++ /dev/null @@ -1,48 +0,0 @@ -# 初始化功能 - -作者: Shenggui Li, Siqi Mai - -> ⚠️ 此页面上的信息已经过时并将被废弃。请在[Booster API](../basics/booster_api.md)页面查阅更新。 - -**预备知识:** -- [分布式训练](../concepts/distributed_training.md) -- [Colossal-AI 总览](../concepts/colossalai_overview.md) - -## 简介 - -在本教程中,我们将介绍 `colossalai.initialize` 的使用。 它包含了如何将特征(例如,模型、优化器、数据加载器)无缝注入您的训练组件中。 调用 `colossalai.initialize` 是您进入训练循环前的基本操作。 - -在下面一节中,我们将介绍 `colossalai.initialize` 是如何工作的以及使用中我们要注意的细节。 - -## 使用 - -在一个典型的工作流程中,我们将在训练脚本的开始启动分布式环境。 -之后,我们将实例化我们的对象,如模型、优化器、损失函数、数据加载器等。此时,我们可以使用 `colossalai.initialize` 便捷地为这些对象注入特征。 -具体细节请看以下的伪代码例子。 - -```python -import colossalai -import torch -... - - -# launch distributed environment -colossalai.launch(config='./config.py', ...) 
- -# create your objects -model = MyModel() -optimizer = torch.optim.Adam(model.parameters(), lr=0.001) -criterion = torch.nn.CrossEntropyLoss() -train_dataloader = MyTrainDataloader() -test_dataloader = MyTrainDataloader() - -# initialize features -engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model, - optimizer, - criterion, - train_dataloader, - test_dataloader) -``` - - `colossalai.initialize` 将返回一个 `Engine` 对象。 该对象把模型、优化器和损失函数封装起来。 **`Engine` 对象会以配置文件中指定的特征运行。** -关于 `Engine` 的更多使用细节可以在 [在训练中使用Engine和Trainer](./engine_trainer.md) 中获取。 diff --git a/docs/source/zh-Hans/basics/model_checkpoint.md b/docs/source/zh-Hans/basics/model_checkpoint.md deleted file mode 100644 index 4a49d373a2a4..000000000000 --- a/docs/source/zh-Hans/basics/model_checkpoint.md +++ /dev/null @@ -1,64 +0,0 @@ -# 模型Checkpoint - -作者 : Guangyang Lu - -> ⚠️ 此页面上的信息已经过时并将被废弃。请在[Booster Checkpoint](../basics/booster_checkpoint.md)页面查阅更新。 - -**预备知识:** -- [Launch Colossal-AI](./launch_colossalai.md) -- [Initialize Colossal-AI](./initialize_features.md) - -**示例代码:** -- [ColossalAI-Examples Model Checkpoint](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/utils/checkpoint) - -**函数是经验函数.** - -## 简介 - -本教程将介绍如何保存和加载模型Checkpoint。 - -为了充分利用Colossal-AI的强大并行策略,我们需要修改模型和张量,可以直接使用 `torch.save` 或者 `torch.load` 保存或加载模型Checkpoint。在Colossal-AI中,我们提供了应用程序接口实现上述同样的效果。 - -但是,在加载时,你不需要使用与存储相同的保存策略。 - -## 使用方法 - -### 保存 - -有两种方法可以使用Colossal-AI训练模型,即使用engine或使用trainer。 -**注意我们只保存 `state_dict`.** 因此,在加载Checkpoint时,需要首先定义模型。 - -#### 同 engine 保存 - -```python -from colossalai.utils import save_checkpoint -model = ... -engine, _, _, _ = colossalai.initialize(model=model, ...) -for epoch in range(num_epochs): - ... # do some training - save_checkpoint('xxx.pt', epoch, model) -``` - -#### 用 trainer 保存 -```python -from colossalai.legacy.trainer import Trainer, hooks -model = ... -engine, _, _, _ = colossalai.initialize(model=model, ...) -trainer = Trainer(engine, ...) -hook_list = [ - hooks.SaveCheckpointHook(1, 'xxx.pt', model) - ...] - -trainer.fit(... - hook=hook_list) -``` - -### 加载 - -```python -from colossalai.utils import load_checkpoint -model = ... -load_checkpoint('xxx.pt', model) -... 
# train or test -``` - diff --git a/docs/source/zh-Hans/features/1D_tensor_parallel.md b/docs/source/zh-Hans/features/1D_tensor_parallel.md index 93fe9ea99422..fb6fd90ec4c2 100644 --- a/docs/source/zh-Hans/features/1D_tensor_parallel.md +++ b/docs/source/zh-Hans/features/1D_tensor_parallel.md @@ -2,11 +2,8 @@ 作者: Zhengda Bian, Yongbin Li -**前置教程** -- [定义配置文件](../basics/define_your_config.md) -- [并行配置](../basics/configure_parallelization.md) -**示例代码**xw +**示例代码** - [Tensor Parallelism with Shardformer](https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/shardformer/examples) **相关论文** diff --git a/docs/source/zh-Hans/features/2D_tensor_parallel.md b/docs/source/zh-Hans/features/2D_tensor_parallel.md index a8e5cf4bfb47..0cb7968c8103 100644 --- a/docs/source/zh-Hans/features/2D_tensor_parallel.md +++ b/docs/source/zh-Hans/features/2D_tensor_parallel.md @@ -3,8 +3,6 @@ 作者: Zhengda Bian, Yongbin Li **前置教程** -- [定义配置文件](../basics/define_your_config.md) -- [并行配置](../basics/configure_parallelization.md) - [1D 张量并行](./1D_tensor_parallel.md) **示例代码** diff --git a/docs/source/zh-Hans/features/2p5D_tensor_parallel.md b/docs/source/zh-Hans/features/2p5D_tensor_parallel.md index 6b0f1a301804..308638a359f1 100644 --- a/docs/source/zh-Hans/features/2p5D_tensor_parallel.md +++ b/docs/source/zh-Hans/features/2p5D_tensor_parallel.md @@ -3,8 +3,6 @@ 作者: Zhengda Bian, Yongbin Li **前置教程** -- [定义配置文件](../basics/define_your_config.md) -- [并行配置](../basics/configure_parallelization.md) - [1D 张量并行](./1D_tensor_parallel.md) - [2D 张量并行](./2D_tensor_parallel.md) diff --git a/docs/source/zh-Hans/features/3D_tensor_parallel.md b/docs/source/zh-Hans/features/3D_tensor_parallel.md index f6154559ec28..bf403d2d9636 100644 --- a/docs/source/zh-Hans/features/3D_tensor_parallel.md +++ b/docs/source/zh-Hans/features/3D_tensor_parallel.md @@ -3,8 +3,6 @@ 作者: Zhengda Bian, Yongbin Li **前置教程** -- [定义配置文件](../basics/define_your_config.md) -- [并行配置](../basics/configure_parallelization.md) - [1D 张量并行](./1D_tensor_parallel.md) - [2D 张量并行](./2D_tensor_parallel.md) diff --git a/docs/source/zh-Hans/features/gradient_accumulation.md b/docs/source/zh-Hans/features/gradient_accumulation.md deleted file mode 100644 index fc8b29bbe8f1..000000000000 --- a/docs/source/zh-Hans/features/gradient_accumulation.md +++ /dev/null @@ -1,41 +0,0 @@ -# 梯度累积 (旧版本) - -作者: Shenggui Li, Yongbin Li - -**前置教程** -- [定义配置文件](../basics/define_your_config.md) -- [在训练中使用Engine和Trainer](../basics/engine_trainer.md) - -**示例代码** -- [ColossalAI-Examples Gradient Accumulation](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_accumulation) - -## 引言 - -梯度累积是一种常见的增大训练 batch size 的方式。 在训练大模型时,内存经常会成为瓶颈,并且 batch size 通常会很小(如2),这导致收敛性无法保证。梯度累积将多次迭代的梯度累加,并仅在达到预设迭代次数时更新参数。 - -## 使用 - -在 Colossal-AI 中使用梯度累积非常简单,仅需将下列配置添加进 config 文件。其中,整数值代表期望梯度累积的次数。 - -```python -gradient_accumulation = -``` - -## 实例 - -我们提供了一个 [运行实例](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_accumulation) -来展现梯度累积。在这个例子中,梯度累积次数被设置为4,你可以通过一下命令启动脚本 - -```shell -python -m torch.distributed.launch --nproc_per_node 1 --master_addr localhost --master_port 29500 run_resnet_cifar10_with_engine.py -``` - -你将会看到类似下方的文本输出。这展现了梯度虽然在前3个迭代中被计算,但直到最后一次迭代,参数才被更新。 - -```text -iteration 0, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0047, 0.0116, -0.0283, 0.0071, -0.0359, -0.0267, -0.0006], device='cuda:0', grad_fn=) -iteration 1, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0047, 0.0116, -0.0283, 0.0071, -0.0359, -0.0267, 
-0.0006], device='cuda:0', grad_fn=) -iteration 2, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0047, 0.0116, -0.0283, 0.0071, -0.0359, -0.0267, -0.0006], device='cuda:0', grad_fn=) -iteration 3, first 10 elements of param: tensor([-0.0141, 0.0464, 0.0507, 0.0321, 0.0356, -0.0150, 0.0172, -0.0118, 0.0222, 0.0473], device='cuda:0', grad_fn=) -``` - diff --git a/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md b/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md index d121b161b9ff..3ad9b2e07a95 100644 --- a/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md +++ b/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md @@ -1,9 +1,8 @@ -# 梯度累积 (新版本) +# 梯度累积 作者: [Mingyan Jiang](https://github.com/jiangmingyan) **前置教程** -- [定义配置文件](../basics/define_your_config.md) - [训练中使用Booster](../basics/booster_api.md) ## 引言 diff --git a/docs/source/zh-Hans/features/gradient_clipping.md b/docs/source/zh-Hans/features/gradient_clipping.md deleted file mode 100644 index 2f62c31766a6..000000000000 --- a/docs/source/zh-Hans/features/gradient_clipping.md +++ /dev/null @@ -1,53 +0,0 @@ -# 梯度裁剪(旧版本) - -作者: Boxiang Wang, Haichen Huang, Yongbin Li - -**前置教程** -- [定义配置文件](../basics/define_your_config.md) -- [在训练中使用Engine和Trainer](../basics/engine_trainer.md) - -**示例代码** -- [ColossalAI-Examples Gradient Clipping](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_clipping) - -**相关论文** -- [On the difficulty of training Recurrent Neural Networks](https://arxiv.org/abs/1211.5063) - -## 引言 - -为了加快训练过程和寻求全局最优以获得更好的性能,越来越多的学习率调度器被提出。人们通过控制学习率来调整训练中的下降速度。这使得梯度向量在每一步都能更好地统一。在这种情况下,下降速度可以按预期被控制。 -因此,梯度裁剪,一种可以将梯度向量归一化,以将其限制在统一长度的技术,对于那些希望模型性能更好的人来说是不可或缺的。 - -在使用 Colossal-AI 时,你不必担心实现梯度剪裁,我们以一种有效而方便的方式支持梯度剪裁。你所需要的只是在你的配置文件中增加一个命令。 - -## 为什么应该使用 Colossal-AI 中的梯度裁剪 - -我们不建议用户自己编写梯度剪裁,因为朴素的梯度剪裁在应用张量并行、流水线并行、MoE 等功能时可能会失败。 - -根据下图,每个 GPU 只拥有线性层中权重的一部分参数。为了得到线性层权重的梯度向量的正确范数,每个 GPU 中的每个梯度向量的范数应该相加。更复杂的是,偏置的分布不同于权重的分布。通信组在求和运算中有所不同。 - -(注: 这种情况是旧版本的 2D 并行,在代码中的实现是不一样的。但这是一个很好的例子,能够说明在梯度剪裁中统一所有通信的困难。) - -
-(图:参数分布)
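为说明这一点,下面给出一段示意代码(并非 Colossal-AI 的实际实现,`tp_group` 为假设的张量并行通信组):对 2 范数而言,各进程先对本地分片梯度求平方和,再在组内 all-reduce 后开方;权重与偏置若属于不同的通信组,则需分别执行这一步。

```python
import torch
import torch.distributed as dist

def sharded_grad_norm(local_grads, tp_group):
    # 每个进程只持有参数的一个分片,先计算本地平方和
    local_sq = sum(g.detach().pow(2).sum() for g in local_grads)
    # 在张量并行组内求和,得到完整参数的平方范数,再开方
    dist.all_reduce(local_sq, op=dist.ReduceOp.SUM, group=tp_group)
    return local_sq.sqrt()
```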
- -不用担心它,因为 Colossal-AI 已经为你处理好。 - -### 使用 -要使用梯度裁剪,只需在配置文件中添加梯度裁剪范数即可。 - -```python -clip_grad_norm = 1.0 -``` - -### 实例 - -我们提供了一个展现梯度裁剪的[运行实例](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_clipping) -。在本例中,我们将梯度裁剪范数设置为1.0,你可以使用以下命令运行脚本: - -```shell -python -m torch.distributed.launch --nproc_per_node 1 --master_addr localhost --master_port 29500 train_with_engine.py -``` - - diff --git a/docs/source/zh-Hans/features/gradient_clipping_with_booster.md b/docs/source/zh-Hans/features/gradient_clipping_with_booster.md index 3c61356dd0d5..fdec09bf128a 100644 --- a/docs/source/zh-Hans/features/gradient_clipping_with_booster.md +++ b/docs/source/zh-Hans/features/gradient_clipping_with_booster.md @@ -1,9 +1,8 @@ -# 梯度裁剪 (新版本) +# 梯度裁剪 作者: [Mingyan Jiang](https://github.com/jiangmingyan) **前置教程** -- [定义配置文件](../basics/define_your_config.md) - [booster使用](../basics/booster_api.md) **相关论文** diff --git a/docs/source/zh-Hans/features/gradient_handler.md b/docs/source/zh-Hans/features/gradient_handler.md deleted file mode 100644 index 3b1140409ba8..000000000000 --- a/docs/source/zh-Hans/features/gradient_handler.md +++ /dev/null @@ -1,60 +0,0 @@ -# 梯度 Handler - -作者: Shenggui Li, Yongbin Li - -**前置教程** -- [定义配置文件](../basics/define_your_config.md) -- [在训练中使用Engine和Trainer](../basics/engine_trainer.md) - -**示例代码** -- [ColossalAI-Examples Gradient Handler](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_handler) - -## 引言 - -在分布式训练中,每次迭代结束时都需要梯度同步。这很重要,因为我们需要确保在不同的机器中使用相同的梯度更新参数,以便生成的参数都一样。这通常在数据并行中看到,因为在数据并行中的模型是直接复制的。 - -在 Colossal-AI 中,我们为用户提供了一个接口来定制他们想要如何处理同步。这为实现新的并行方法等情况带来了灵活性。 - -当梯度 Handler 被使用时, PyTorch 的 `DistributedDataParallel` 将不再被使用,因为它会自动同步梯度. - -## 定制你的梯度 Handler - -要实现定制的梯度Handler,需要遵循以下步骤。 -1. 继承Colossal-AI中的 `BaseGradientHandler` -2. 将梯度Handler注册进 `GRADIENT_HANDLER` -3. 实现 `handle_gradient` - -```python -from colossalai.legacy.registry import GRADIENT_HANDLER -from colossalai.legacy.engine.gradient_handler import BaseGradientHandler - - -@GRADIENT_HANDLER.register_module -class MyGradientHandler(BaseGradientHandler): - - def handle_gradient(self): - do_something() - - -``` - - -## 使用 - -要使用梯度 Handler,需要在配置文件中指定梯度 Handler。梯度 Handler 将自动构建并连接到 Engine。 - -```python -gradient_handler = [dict(type='MyGradientHandler')] -``` - - -### 实例 - -我们提供了一个 [运行实例](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_handler) -展现梯度 Handler 的使用. 在这个例子中,我们使用 `DataParallelGradientHandler` 而不是 PyTorch 的 -`DistributedDataParallel` 实现数据并行. - -```shell -python -m torch.distributed.launch --nproc_per_node 4 --master_addr localhost --master_port 29500 train_with_engine.py -``` - diff --git a/docs/source/zh-Hans/features/mixed_precision_training.md b/docs/source/zh-Hans/features/mixed_precision_training.md deleted file mode 100644 index 35a73f1adbcd..000000000000 --- a/docs/source/zh-Hans/features/mixed_precision_training.md +++ /dev/null @@ -1,345 +0,0 @@ -# 自动混合精度训练 (旧版本) - -作者: Chuanrui Wang, Shenggui Li, Yongbin Li - -**前置教程** -- [定义配置文件](../basics/define_your_config.md) -- [在训练中使用Engine和Trainer](../basics/engine_trainer.md) - -**示例代码** -- [ColossalAI-Examples AMP](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/amp) - -**相关论文** -- [Accelerating Scientific Computations with Mixed Precision Algorithms](https://arxiv.org/abs/0808.2794) - - -## 引言 - -AMP 代表自动混合精度训练。 -在 Colossal-AI 中, 我们结合了混合精度训练的不同实现: - -1. torch.cuda.amp -2. apex.amp -3. 
naive amp - - -| Colossal-AI | 支持张量并行 | 支持流水并行 | fp16范围 | -| ----------- | ----------------------- | ------------------------- | ----------- | -| AMP_TYPE.TORCH | ✅ | ❌ | 在前向和反向传播期间,模型参数、激活和梯度向下转换至fp16 | -| AMP_TYPE.APEX | ❌ | ❌ | 更细粒度,我们可以选择 opt_level O0, O1, O2, O3 | -| AMP_TYPE.NAIVE | ✅ | ✅ | 模型参数、前向和反向操作,全都向下转换至fp16 | - -前两个依赖于 PyTorch (1.6及以上) 和 NVIDIA Apex 的原始实现。最后一种方法类似 Apex O2。在这些方法中,Apex-AMP 与张量并行不兼容。这是因为张量是以张量并行的方式在设备之间拆分的,因此,需要在不同的进程之间进行通信,以检查整个模型权重中是否出现inf或nan。我们修改了torch amp实现,使其现在与张量并行兼容。 - -> ❌️ fp16与ZeRO配置不兼容 -> -> ⚠️ 流水并行目前仅支持naive amp - -我们建议使用 torch AMP,因为在不使用流水并行时,它通常比 NVIDIA AMP 提供更好的准确性。 - -## 目录 - -在本教程中,我们将介绍: - -1. AMP 介绍 -2. Colossal-AI 中的 AMP -3. 练习实例 - -## AMP 介绍 - -自动混合精度训练是混合 FP16 和 FP32 训练。 - -半精度浮点格式(FP16)具有较低的算法复杂度和较高的计算效率。此外,FP16 仅需要 FP32 所需的一半存储空间,并节省了内存和网络带宽,从而为大 batch size 和大模型提供了更多内存。 - -然而,还有其他操作,如缩减,需要 FP32 的动态范围,以避免数值溢出/下溢。因此,我们引入自动混合精度,尝试将每个操作与其相应的数据类型相匹配,这可以减少内存占用并提高训练效率。 - -
-(图:AMP 示意图,图片来自 PatrickStar 论文)
- -## Colossal-AI 中的 AMP - -我们支持三种 AMP 训练方法,并允许用户在没有改变代码的情况下使用 AMP 进行训练。只需在配置文件中添加'fp16'配置即可使用 AMP。 - -```python -from colossalai.amp import AMP_TYPE - -# 使用 Torch AMP -fp16=dict( - mode = AMP_TYPE.TORCH -) - -# 使用 naive AMP -fp16=dict( - mode = AMP_TYPE.NAIVE -) - -# 使用 Nvidia Apex AMP -fp16=dict( - mode = AMP_TYPE.APEX -) - -``` - -> 这些是最低配置,完整配置将在后面的部分中说明 - -### AMP 模块化 - -AMP 模块设计为完全模块化,可以独立使用。如果你想在你的代码库中只使用 AMP 而不使用`colossalai.initialize`,你可以导入`colossalai.amp.convert_to_amp`。 - -```python -from colossalai.amp import AMP_TYPE - -# 使用torch amp的例子 -model, optimizer, criterion = colossalai.amp.convert_to_amp(model, - optimizer, - criterion, - AMP_TYPE.TORCH) -``` - -### Torch AMP 配置 - -```python -from colossalai.amp import AMP_TYPE - -fp16=dict( - mode=AMP_TYPE.TORCH, - - # 下列是grad scaler的默认值 - init_scale=2.**16, - growth_factor=2.0, - backoff_factor=0.5, - growth_interval=2000, - enabled=True -) -``` - -可选参数: -- init_scale(float, optional, default=2.**16): 初始缩放因子; -- growth_factor(float, optional, default=2.0): 如果在``growth_interval``连续迭代过程中没有出现 inf/NaN 梯度,则在`update`中乘以比例系数; -- backoff_factor(float, optional, default=0.5): 如果在迭代中出现 inf/NaN 梯度,则在`update`中乘以比例系数; -- growth_interval(int, optional, default=2000): 在指定次数的连续迭代中,若没有出现 inf/NaN 梯度,则乘以``growth_factor``. -- enabled(bool, optional, default=True): ``False``则使梯度缩放无效,`step` 仅调用底层的 ``optimizer.step()``, 其他方法成为空操作。 - -### Apex AMP 配置 - -对于这种模式,我们依靠 Apex 实现混合精度训练。我们支持这个插件,因为它允许对混合精度的粒度进行更精细的控制。 -例如, O2 水平 (优化器水平2) 将保持 batch normalization 为 FP32。 - -如果你想了解更多细节,请参考 [Apex Documentation](https://nvidia.github.io/apex/)。 - -```python -from colossalai.amp import AMP_TYPE - -fp16 = dict( - mode=AMP_TYPE.APEX, - - # 下列是默认值 - enabled=True, - opt_level='O1', - cast_model_type=None, - patch_torch_functions=None, - keep_batchnorm_fp32=None, - master_weights=None, - loss_scale=None, - cast_model_outputs=None, - num_losses=1, - verbosity=1, - min_loss_scale=None, - max_loss_scale=16777216.0 -) -``` - -参数: -- enabled(bool, optional, default=True): False 会使所有 AMP 调用成为空操作, 程序将会像没有使用 AMP 一样运行。 - -- opt_level(str, optional, default="O1" ): 纯精度或混合精度优化水平。可选值 “O0”, “O1”, “O2”, and “O3”, 详细解释见上方 Apex AMP 文档。 - -- num_losses(int, optional, default=1): 选择提前告知 AMP 您计划使用多少次损失/反向计算。 -当`amp.scale_loss`与 loss_id 参数一起使用时,使 AMP 在每次损失/反向计算时使用不同的损失比例,这可以提高稳定性。如果 num_losses 被设置为1,AMP 仍支持多次损失/反向计算,但对他们都使用同一个全局损失比例。 - -- verbosity(int, default=1): 设置为0抑制 AMP 相关输出。 - -- min_loss_scale(float, default=None): 为可通过动态损耗比例选择的损耗比例值设置下限。 -默认值“None”意味着不设置任何下限。如果不使用动态损耗比例,则忽略 min_loss_scale 。 - -- max_loss_scale(float, default=2.**24 ): 为可通过动态损耗比例选择的损耗比例值设置上限。如果不使用动态损耗比例,则 max_loss_scale 被忽略. - -目前,管理纯精度或混合精度训练的幕后属性有以下几种: -cast_model_type, patch_torch_functions, keep_batchnorm_fp32, master_weights, loss_scale. 
-一旦 opt_level 被确定,它们是可选的可覆盖属性 - -- cast_model_type: 将模型的参数和缓冲区强制转换为所需的类型。 -- patch_torch_functions: 补全所有的 Torch 函数和张量方法,以便在FP16中执行张量核心友好的操作,如 GEMMs 和卷积,以及在 FP32 中执行任何受益于 FP32 精度的操作。 -- keep_batchnorm_fp32: 为了提高精度并启用 cudnn batchnorm (这会提高性能),在 FP32 中保留 batchnorm 权重通常是有益的,即使模型的其余部分是 FP16。 -- master_weights: 保持 FP32 主权重以配合任何 FP16 模型权重。 FP32 主权重由优化器分级,以提高精度和捕捉小梯度。 -- loss_scale: 如果 loss_scale 是一个浮点数,则使用这个值作为静态(固定)的损失比例。如果 loss_scale 是字符串 "dynamic",则随着时间的推移自适应地调整损失比例。动态损失比例调整由 AMP 自动执行。 - - -### Naive AMP 配置 - -在 Naive AMP 模式中, 我们实现了混合精度训练,同时保持了与复杂张量和流水并行的兼容性。该 AMP 模式将所有操作转为 FP16 。下列代码块展示了该模式的`config.py`。 - -```python -from colossalai.amp import AMP_TYPE - -fp16 = dict( - mode=AMP_TYPE.NAIVE, - - # below are the default values - log_num_zeros_in_grad=False, - initial_scale=2 ** 32, - min_scale=1, - growth_factor=2, - backoff_factor=0.5, - growth_interval=1000, - hysteresis=2 -) -``` - -Naive AMP 的默认参数: -- log_num_zeros_in_grad(bool): 返回0值梯度的个数. -- initial_scale(int): gradient scaler 的初始值 -- growth_factor(int): loss scale 的增长率 -- backoff_factor(float): loss scale 的下降率 -- hysteresis(int): 动态 loss scaling 的延迟偏移 -- max_scale(int): loss scale 的最大允许值 -- verbose(bool): 如果被设为`True`,将打印调试信息 - -当使用`colossalai.initialize`时, 首先需要实例化一个模型、一个优化器和一个标准。将输出模型转换为内存消耗较小的 AMP 模型。如果您的输入模型已经太大,无法放置在 GPU 中,请使用`dtype=torch.float16`实例化你的模型。或者请尝试更小的模型,或尝试更多的并行化训练技术! - -## 实例 - -我们提供了一个 [运行实例](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/amp) -展现如何在 Colossal-AI 使用 AMP。在该例程中,我们使用 Torch AMP, 但提供的配置文件也适用于所有 AMP 模式. - -### 步骤 1. 创建配置文件 - -创建一个`config.py`文件并添加`fp16`配置. - -```python -# in config.py -from colossalai.amp import AMP_TYPE - -BATCH_SIZE = 128 -DROP_RATE = 0.1 -NUM_EPOCHS = 300 - -fp16 = dict( - mode=AMP_TYPE.TORCH, -) - -clip_grad_norm = 1.0 -``` - -### 步骤 2. 在 train_with_engine.py 导入相关库 - -创建`train_with_engine.py`并导入必要依赖. 请记得通过命令`pip install timm scipy`安装`scipy`和`timm`。 - -```python -import os -import colossalai -import torch -from pathlib import Path -from colossalai.core import global_context as gpc -from colossalai.logging import get_dist_logger -from colossalai.utils import get_dataloader -from colossalai.legacy.trainer import Trainer, hooks -from colossalai.nn.lr_scheduler import LinearWarmupLR -from timm.models import vit_base_patch16_224 -from torchvision import datasets, transforms - -``` - -### 步骤 3. 初始化分布式环境 - -我们需要初始化分布式环境。为了快速演示,我们使用`launch_from_torch`。你可以参考 [Launch Colossal-AI](../basics/launch_colossalai.md) -使用其他初始化方法。 - -```python -# 初始化分布式设置 -parser = colossalai.get_default_parser() -args = parser.parse_args() - -# launch from torch -colossalai.launch_from_torch(config=args.config) - -``` - -### 步骤 4. 
创建训练组件 - -构建你的模型、优化器、损失函数、学习率调整器和数据加载器。注意数据集的路径从环境变量`DATA`获得。你可以通过 `export DATA=/path/to/data` 或 `Path(os.environ['DATA'])` -在你的机器上设置路径。数据将会被自动下载到该路径。 - -```python -# build model - model = vit_base_patch16_224(drop_rate=0.1) - - # build dataloader - train_dataset = datasets.Caltech101( - root=Path(os.environ['DATA']), - download=True, - transform=transforms.Compose([ - transforms.Resize(256), - transforms.RandomResizedCrop(224), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - Gray2RGB(), - transforms.Normalize([0.5, 0.5, 0.5], - [0.5, 0.5, 0.5]) - ])) - - train_dataloader = get_dataloader(dataset=train_dataset, - shuffle=True, - batch_size=gpc.config.BATCH_SIZE, - num_workers=1, - pin_memory=True, - ) - - # build optimizer - optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, weight_decay=0.1) - - # build loss - criterion = torch.nn.CrossEntropyLoss() - - # lr_scheduler - lr_scheduler = LinearWarmupLR(optimizer, warmup_steps=50, total_steps=gpc.config.NUM_EPOCHS) -``` - -### 步骤 5. 插入 AMP - -调用 `colossalai.initialize` 将所有训练组件转为为FP16模式. - -```python -engine, train_dataloader, _, _ = colossalai.initialize( - model, optimizer, criterion, train_dataloader, - ) -``` - -### 步骤 6. 使用 Engine 训练 - -使用Engine构建一个普通的训练循环 - -```python -engine.train() -for epoch in range(gpc.config.NUM_EPOCHS): - for img, label in enumerate(train_dataloader): - img = img.cuda() - label = label.cuda() - engine.zero_grad() - output = engine(img) - loss = engine.criterion(output, label) - engine.backward(loss) - engine.step() - lr_scheduler.step() -``` - -### 步骤 7. 启动训练脚本 - -使用下列命令启动训练脚本,你可以改变 `--nproc_per_node` 以使用不同数量的 GPU。 - -```shell -python -m torch.distributed.launch --nproc_per_node 4 --master_addr localhost --master_port 29500 train_with_engine.py --config config/config_AMP_torch.py -``` - diff --git a/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md b/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md index 0354f92ee7ce..8e9f614a25af 100644 --- a/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md +++ b/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md @@ -1,10 +1,9 @@ -# 自动混合精度训练 (新版本) +# 自动混合精度训练 作者: [Mingyan Jiang](https://github.com/jiangmingyan) **前置教程** -- [定义配置文件](../basics/define_your_config.md) - [booster 使用](../basics/booster_api.md) **相关论文** @@ -57,7 +56,7 @@ AMP 代表自动混合精度训练。 ## Colossal-AI 中的 AMP -我们支持三种 AMP 训练方法,并允许用户在没有改变代码的情况下使用 AMP 进行训练。booster 支持 amp 特性注入,如果您要使用混合精度训练,则在创建 booster 实例时指定`mixed_precision`参数,我们现已支持 torch amp,apex amp, naive amp(现已移植 torch amp 至 booster,apex amp, naive amp 仍由`colossalai.initialize`方式启动,如您需使用,请[参考](./mixed_precision_training.md);后续将会拓展`bf16`,`pf8`的混合精度训练. +我们支持三种 AMP 训练方法,并允许用户在没有改变代码的情况下使用 AMP 进行训练。booster 支持 amp 特性注入,如果您要使用混合精度训练,则在创建 booster 实例时指定`mixed_precision`参数;后续将会拓展`bf16`,`pf8`的混合精度训练. #### booster 启动方式 diff --git a/docs/source/zh-Hans/features/zero_with_chunk.md b/docs/source/zh-Hans/features/zero_with_chunk.md index adb3fac3ab08..61290628588b 100644 --- a/docs/source/zh-Hans/features/zero_with_chunk.md +++ b/docs/source/zh-Hans/features/zero_with_chunk.md @@ -204,7 +204,7 @@ def main(): torch.cuda.synchronize() ``` -> ⚠️ 注意:如果你使用Gemini模块的话,请不要使用我们之前提到过的[梯度累加](../features/gradient_accumulation.md)。 +> ⚠️ 注意:如果你使用Gemini模块的话,请不要使用我们之前提到过的[梯度累加](../features/gradient_accumulation_with_booster.md)。 完整的例子代码可以在 [Train GPT with Colossal-AI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt). 
获得。 From 493a5efeab03ec0b9bf23dbf6e653fe2eb9c18ce Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 21 Sep 2023 14:53:16 +0800 Subject: [PATCH 33/58] [doc] add shardformer doc to sidebar (#4768) --- docs/sidebars.json | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/sidebars.json b/docs/sidebars.json index bf92e9755f4a..ce197a31e71b 100644 --- a/docs/sidebars.json +++ b/docs/sidebars.json @@ -37,6 +37,7 @@ "label": "Features", "collapsed": true, "items": [ + "features/shardformer", "features/mixed_precision_training_with_booster", "features/gradient_accumulation_with_booster", "features/gradient_clipping_with_booster", From 901ab1eedd405b09a8fe1a31e055a9c207568bdb Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Thu, 21 Sep 2023 16:23:59 +0800 Subject: [PATCH 34/58] [chat]: add lora merge weights config (#4766) * feat: modify lora merge weights fn * feat: add lora merge weights config --- applications/Chat/coati/models/lora.py | 70 +++++++++++-------- applications/Chat/examples/train_prompts.py | 7 ++ .../Chat/examples/train_reward_model.py | 8 +++ applications/Chat/examples/train_sft.py | 7 ++ 4 files changed, 61 insertions(+), 31 deletions(-) diff --git a/applications/Chat/coati/models/lora.py b/applications/Chat/coati/models/lora.py index 2114913e107b..e9bd7b2ed8f0 100644 --- a/applications/Chat/coati/models/lora.py +++ b/applications/Chat/coati/models/lora.py @@ -1,4 +1,6 @@ +import dataclasses import math +import warnings from typing import Optional import loralib as lora @@ -7,6 +9,14 @@ import torch.nn.functional as F +@dataclasses.dataclass +class LoRAManager: + merge_weights: bool = False + + +LORA_MANAGER = LoRAManager() + + class LoraLinear(lora.LoRALayer, nn.Module): """Replace in-place ops to out-of-place ops to fit gemini. 
Convert a torch.nn.Linear to LoraLinear.""" @@ -17,13 +27,11 @@ def __init__( r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.0, - fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) - merge_weights: bool = True, + # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + fan_in_fan_out: bool = False, ): nn.Module.__init__(self) - lora.LoRALayer.__init__( - self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights - ) + lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False) self.weight = weight self.bias = bias @@ -53,31 +61,31 @@ def train(self, mode: bool = True): def T(w): return w.T if self.fan_in_fan_out else w - nn.Module.train(self, mode) - if self.merge_weights and self.merged: - # Make sure that the weights are not merged - if self.r > 0: - if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"): - # FIXME(csric): temporary fix - self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features))) - self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r))) - self.reset_parameters() - else: - self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling - self.merged = False - - def eval(self): - def T(w): - return w.T if self.fan_in_fan_out else w - - nn.Module.eval(self) - if self.merge_weights and not self.merged: - # Merge the weights and mark it - if self.r > 0: - self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling - delattr(self, "lora_A") - delattr(self, "lora_B") - self.merged = True + self.training = mode + if LORA_MANAGER.merge_weights: + if mode and self.merged: + warnings.warn("Invoke module.train() would unmerge LoRA weights.") + raise NotImplementedError("LoRA unmerge is not tested.") + # Make sure that the weights are not merged + if self.r > 0: + if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"): + # FIXME(csric): temporary fix + self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features))) + self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r))) + self.reset_parameters() + else: + self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling + self.merged = False + elif not mode and not self.merged: + warnings.warn("Invoke module.eval() would merge LoRA weights.") + # Merge the weights and mark it + if self.r > 0: + self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling + delattr(self, "lora_A") + delattr(self, "lora_B") + self.merged = True + + return self def forward(self, x: torch.Tensor): def T(w): @@ -96,7 +104,7 @@ def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear: assert ( lora_rank <= linear.in_features ), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})" - lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False) + lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank) return lora_linear diff --git a/applications/Chat/examples/train_prompts.py b/applications/Chat/examples/train_prompts.py index de2a33263040..a8ab15eebfa5 100644 --- a/applications/Chat/examples/train_prompts.py +++ b/applications/Chat/examples/train_prompts.py @@ -192,6 +192,12 @@ def main(args): use_wandb=args.use_wandb, ) + if args.lora_rank > 0 and args.merge_lora_weights: + from coati.models.lora import LORA_MANAGER + + # NOTE: set model to eval to merge LoRA weights + LORA_MANAGER.merge_weights = True 
+ actor.eval() # save model checkpoint after fitting strategy.save_model(actor, args.save_path, only_rank0=True) # save optimizer checkpoint on all ranks @@ -227,6 +233,7 @@ def main(args): parser.add_argument("--ptx_batch_size", type=int, default=1) parser.add_argument("--experience_batch_size", type=int, default=8) parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument("--merge_lora_weights", type=bool, default=True) parser.add_argument("--lr", type=float, default=1e-7) parser.add_argument("--kl_coef", type=float, default=0.1) parser.add_argument("--ptx_coef", type=float, default=0.9) diff --git a/applications/Chat/examples/train_reward_model.py b/applications/Chat/examples/train_reward_model.py index c9095b365884..c1be51f2f587 100644 --- a/applications/Chat/examples/train_reward_model.py +++ b/applications/Chat/examples/train_reward_model.py @@ -157,6 +157,13 @@ def train(args): log_dir=args.log_dir, use_wandb=args.use_wandb, ) + + if args.lora_rank > 0 and args.merge_lora_weights: + from coati.models.lora import LORA_MANAGER + + # NOTE: set model to eval to merge LoRA weights + LORA_MANAGER.merge_weights = True + model.eval() # save model checkpoint after fitting on only rank0 strategy.save_model(model, args.save_path, only_rank0=True) # save optimizer checkpoint on all ranks @@ -186,6 +193,7 @@ def train(args): parser.add_argument("--batch_size", type=int, default=1) parser.add_argument("--max_len", type=int, default=512) parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument("--merge_lora_weights", type=bool, default=True) parser.add_argument("--lr", type=float, default=9e-6) parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"]) parser.add_argument("--log_dir", default="logs", type=str) diff --git a/applications/Chat/examples/train_sft.py b/applications/Chat/examples/train_sft.py index a34661762258..4f36791be3cf 100644 --- a/applications/Chat/examples/train_sft.py +++ b/applications/Chat/examples/train_sft.py @@ -177,6 +177,12 @@ def train(args): use_wandb=args.use_wandb, ) + if args.lora_rank > 0 and args.merge_lora_weights: + from coati.models.lora import LORA_MANAGER + + # NOTE: set model to eval to merge LoRA weights + LORA_MANAGER.merge_weights = True + model.eval() # save model checkpoint after fitting on only rank0 strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer) # save optimizer checkpoint on all ranks @@ -204,6 +210,7 @@ def train(args): parser.add_argument("--batch_size", type=int, default=4) parser.add_argument("--max_len", type=int, default=512) parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument("--merge_lora_weights", type=bool, default=True) parser.add_argument("--lr", type=float, default=5e-6) parser.add_argument("--accumulation_steps", type=int, default=8) parser.add_argument("--log_dir", default="logs", type=str) From 3e05c07bb8921f2a8f9736b6f6673d4e9f1697d0 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Thu, 21 Sep 2023 16:30:23 +0800 Subject: [PATCH 35/58] [lazy] support torch 2.0 (#4763) * [lazy] support _like methods and clamp * [lazy] pass transformers models * [lazy] fix device move and requires grad * [lazy] fix requires grad and refactor api * [lazy] fix requires grad --- .isort.cfg | 1 + colossalai/lazy/construction.py | 87 ++++++++++++++ colossalai/lazy/lazy_init.py | 207 
++++++++++++++++++-------------- tests/test_lazy/test_models.py | 8 +- tests/test_lazy/test_ops.py | 64 ++++++++++ 5 files changed, 273 insertions(+), 94 deletions(-) create mode 100644 colossalai/lazy/construction.py create mode 100644 tests/test_lazy/test_ops.py diff --git a/.isort.cfg b/.isort.cfg index 4f881c8b3dda..ccbf575fdbfa 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -4,3 +4,4 @@ multi_line_output=3 include_trailing_comma = true ignore_comments = true profile = black +honor_noqa = true diff --git a/colossalai/lazy/construction.py b/colossalai/lazy/construction.py new file mode 100644 index 000000000000..6764eaf774ab --- /dev/null +++ b/colossalai/lazy/construction.py @@ -0,0 +1,87 @@ +from contextlib import contextmanager +from typing import Callable, Dict, Tuple + +import torch + +__all__ = [ + "_LEGACY_TENSOR_CONSTRUCTOR", + "_NO_META_FACTORY", + "_NORMAL_FACTORY", + "ConstructorManager", +] + +# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html +_NORMAL_FACTORY = [ + "arange", + "full", + "empty", + "linspace", + "logspace", + "ones", + "rand", + "randn", + "randint", + "randperm", + "zeros", + "tensor", +] + +# factory function that does not support meta tensor backend +_NO_META_FACTORY = [ + "eye", +] + +_LEGACY_TENSOR_CONSTRUCTOR = { + "FloatTensor": torch.float, + "DoubleTensor": torch.double, + "HalfTensor": torch.half, + "BFloat16Tensor": torch.bfloat16, + "ByteTensor": torch.uint8, + "CharTensor": torch.int8, + "ShortTensor": torch.short, + "IntTensor": torch.int, + "LongTensor": torch.long, + "BoolTensor": torch.bool, +} + + +class ConstructorManager: + # function name: (new, old) + overwrites: Dict[str, Tuple[Callable, Callable]] = {} + changed: bool = False + + @staticmethod + def apply(overwrites: Dict[Callable, Callable]): + ConstructorManager.overwrites.clear() + ConstructorManager.overwrites.update(overwrites) + ConstructorManager.redo() + + @staticmethod + def undo(): + assert ConstructorManager.changed, "No constructor change to undo" + for name, (new, old) in ConstructorManager.overwrites.items(): + setattr(torch, name, old) + ConstructorManager.changed = False + + @staticmethod + def redo(): + assert not ConstructorManager.changed, "Constructor already changed" + for name, (new, old) in ConstructorManager.overwrites.items(): + setattr(torch, name, new) + ConstructorManager.changed = True + + @staticmethod + @contextmanager + def disable(): + enabled = ConstructorManager.changed + if enabled: + ConstructorManager.undo() + yield + if enabled: + ConstructorManager.redo() + + @staticmethod + def clear(): + if ConstructorManager.changed: + ConstructorManager.undo() + ConstructorManager.overwrites.clear() diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py index ebaf2e1600fc..f29e997da495 100644 --- a/colossalai/lazy/lazy_init.py +++ b/colossalai/lazy/lazy_init.py @@ -1,17 +1,18 @@ from types import MethodType -from typing import Callable, Dict, Optional, Union +from typing import Callable, Optional, Union import torch -import torch.distributed as dist import torch.nn as nn +from packaging import version from torch import Tensor from torch.nn import Parameter from torch.utils._pytree import tree_map -from colossalai._analyzer._subclasses import MetaTensor -from colossalai.device.device_mesh import DeviceMesh -from colossalai.tensor.d_tensor import distribute_tensor -from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec +from colossalai.logging import get_dist_logger + +from .construction import ConstructorManager + 
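As a rough illustration of how the new `ConstructorManager` is driven (the real caller is `LazyInitContext` further below; the tracing wrapper here is purely illustrative):

```python
import torch

from colossalai.lazy.construction import ConstructorManager

_orig_empty = torch.empty

def _traced_empty(*args, **kwargs):
    # stand-in for the lazy wrapper that LazyInitContext installs
    print("intercepted torch.empty with", args)
    return _orig_empty(*args, **kwargs)

# each entry maps a torch factory name to a (new, old) pair
ConstructorManager.apply({"empty": (_traced_empty, _orig_empty)})
torch.empty(2, 3)              # routed through _traced_empty

with ConstructorManager.disable():
    torch.empty(2, 3)          # original factory while inner ops run

ConstructorManager.clear()     # restore torch.empty and drop the overrides
```

The `disable()` context manager is the piece the lazy tensor relies on under torch 2.0: while an intercepted op is being executed, internal allocations must go through the original factories, otherwise every inner tensor creation would itself become lazy.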
+import colossalai._analyzer._subclasses._meta_registration # noqa # reference: https://pytorch.org/cppdocs/notes/tensor_creation.html _NORMAL_FACTORY = [ @@ -41,6 +42,9 @@ # These ops cannot be unwrapped using .data _CHANGE_META_OPS = ["_cudnn_rnn_flatten_weight", "requires_grad_", "__get__", "__set__", "numel", "size", "dim"] +# These ops is not related to tensor value and should not be rerun +_NO_RERUN_OPS = ["__get__", "numel", "size", "dim"] + _LEGACY_TENSOR_CONSTRUCTOR = { "FloatTensor": torch.float, "DoubleTensor": torch.double, @@ -54,6 +58,20 @@ "BoolTensor": torch.bool, } +# These ops have at least one lazy tensor argument and maybe a scalar argument +# scalar value should be converted to meta tensor +# this is a hack for torch 2.0 +_EXPAND_SCALAR_OPS = [ + "where", + "clamp", + "clamp_min", + "clamp_max", + "clamp_", + "clamp_min_", + "clamp_max_", +] +_old_tensor_factory = torch.tensor + _EMPTY_DATA = torch.empty(0) @@ -145,34 +163,48 @@ class LazyTensor(torch.Tensor): """ _repr = True - _meta_data: Optional[MetaTensor] = None # shape, dtype, device + _meta_data: Optional[torch.Tensor] = None # shape, dtype, device _pre_op_fn: Callable[["LazyTensor"], None] = lambda *args: None default_device: Optional[torch.device] = None + _device: torch.device # fake device of mate tensor @staticmethod def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs): + # tips for torch 2.0: + # torch 2.0 disables torch dispatch for subclass of tensor + # MetaTensor is cannot be used + # Now lazy tensor contains device injection and meta tensor if concrete_data is not None: # some ops don't support meta backend and should have concrete data elem = concrete_data else: if meta_data is None: - device = kwargs.get("device", "cpu") - elem = func(*args, **{**kwargs, "device": "meta"}) - meta_data = MetaTensor(elem, device=device) - elem = meta_data._tensor + with ConstructorManager.disable(): + # to disable create lazy tensor in inner ops, this is a hack for torch 2.0 + meta_data = func(*args, **{**kwargs, "device": "meta"}) + elem = meta_data # As a meta tensor cannot be modified __class__ to torch.Tensor, we should use an empty real tensor here r = torch.Tensor._make_subclass(cls, _EMPTY_DATA, require_grad=elem.requires_grad) r._meta_data = meta_data + return r def __init__(self, func, *args, meta_data=None, concrete_data=None, **kwargs): + self._device = torch.device(kwargs.get("device", None) or "cpu") if func.__name__ in _NORMAL_FACTORY: kwargs = {**kwargs, "device": LazyTensor.default_device} self._factory_method = (func, args, kwargs) # (func, args, kwargs) self._op_buffer = [] # (func, args, kwargs, replace) self._materialized_data: Optional[torch.Tensor] = concrete_data # materialized data + @property + def device(self) -> torch.device: + return self._materialized_data.device if self._materialized_data is not None else self._device + + def __repr__(self): + return f"LazyTensor(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})" + def materialize(self) -> torch.Tensor: """Materialize the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace). @@ -183,20 +215,6 @@ def materialize(self) -> torch.Tensor: self.clean() return _convert_cls(self, target) - def distribute(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor: - """Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout. - - Args: - layout (Layout): Distribution layout. 
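At its core the reworked `LazyTensor` records a factory call plus the ops applied afterwards and replays them when materialized. A toy version of that record-and-replay idea (deliberately ignoring meta tensors, device injection, and in-place bookkeeping, which the real class handles):

```python
import torch

class TinyLazy:
    """Toy sketch of the record-and-replay pattern used by LazyTensor."""

    def __init__(self, factory, *args, **kwargs):
        self._factory = (factory, args, kwargs)
        self._ops = []                      # (method name, args, kwargs)

    def record(self, name, *args, **kwargs):
        self._ops.append((name, args, kwargs))
        return self

    def materialize(self) -> torch.Tensor:
        factory, args, kwargs = self._factory
        t = factory(*args, **kwargs)        # only now is memory allocated
        for name, a, kw in self._ops:
            t = getattr(t, name)(*a, **kw)  # replay recorded ops in order
        return t

lazy = TinyLazy(torch.empty, 2, 3).record("uniform_")
real = lazy.materialize()
assert real.shape == (2, 3)
```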
- - Returns: - torch.Tensor: The distributed tensor (self). - """ - target = self._materialize_data() - self.clean() - local_tensor = distribute_tensor(target, device_mesh, sharding_spec) - return _convert_cls(self, local_tensor) - def clean(self) -> None: """Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized.""" delattr(self, "_factory_method") @@ -299,45 +317,80 @@ def unwrap(x): # for early materialized tensor, use its materialized data directly return x._materialized_data if is_change_meta_op else x._materialized_data.data t = x if is_inplace else x.clone() - t._op_buffer.append((func, args, kwargs)) + if func.__name__ not in _NO_RERUN_OPS: + t._op_buffer.append((func, args, kwargs)) meta = x._meta_data if is_change_meta_op else x._meta_data.data meta_to_lazy[meta] = t return meta + elif ( + version.parse(torch.__version__) >= version.parse("2.0.0") + and func.__name__ in _EXPAND_SCALAR_OPS + and not isinstance(x, torch.Tensor) + ): + return _old_tensor_factory(x, device="meta") return x def wrap(y, i=None): - if isinstance(y, MetaTensor): - if y in meta_to_lazy: - # inplace op, just return origin lazy tensor - return meta_to_lazy[y] + if isinstance(y, torch.Tensor): + if y.is_meta: + if y in meta_to_lazy: + # inplace op, just return origin lazy tensor + return meta_to_lazy[y] + else: + # out of place op, create new lazy tensor + fn = lambda *a, **kw: func(*a, **kw) if i is None else func(*a, **kw)[i] + fn.__name__ = func.__name__ + lazy_y = LazyTensor(fn, *args, meta_data=y, **kwargs) + return lazy_y else: - # out of place op, create new lazy tensor - fn = lambda *a, **kw: func(*a, **kw) if i is None else func(*a, **kw)[i] - fn.__name__ = func.__name__ - lazy_y = LazyTensor(fn, *args, meta_data=y, **kwargs) - return lazy_y - elif type(y) is Tensor: - # for early materialized tensor - return LazyTensor(lambda: None, concrete_data=y) + # for early materialized tensor + return LazyTensor(lambda: None, concrete_data=y) return y cls._pre_op_fn() - o = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) + with ConstructorManager.disable(): + # to disable create lazy tensor in inner ops, this is a hack for torch 2.0 + o = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) if isinstance(o, (tuple, list)): return type(o)(wrap(y, i=i) for i, y in enumerate(o)) return wrap(o) - @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - pass # skip + def to(self, *args, **kwargs) -> torch.Tensor: + if self._materialized_data is not None: + return LazyTensor(lambda: None, concrete_data=self._materialized_data.to(*args, **kwargs)) + + device = None + + def replace(x): + nonlocal device + if isinstance(x, (str, int, torch.device)) and not isinstance(x, bool): + device = x + return torch.device("meta") + return x + + meta_data = self._meta_data.to(*tree_map(replace, args), **tree_map(replace, kwargs)) + + if meta_data is self._meta_data and device == self.device: + return self + + def factory_fn(t: torch.Tensor, **kw): + return t.to(*args, **kwargs) + + return LazyTensor(factory_fn, self, meta_data=meta_data, device=device) + + def cpu(self, memory_format: torch.memory_format = torch.preserve_format): + return self.to(device=torch.device("cpu"), memory_format=memory_format) + + def cuda(self, device=None, non_blocking=False, memory_format: torch.memory_format = torch.preserve_format): + device = torch.device(device or "cuda") + return self.to(device=device, 
non_blocking=non_blocking, memory_format=memory_format) def clone(self) -> "LazyTensor": - def factory_fn(): + def factory_fn(t: torch.Tensor, **kw): # if self is materialized, return self - new_tensor = self.materialize() if type(self) is LazyTensor else self - return new_tensor.clone() + return t.clone() - target = LazyTensor(factory_fn, meta_data=self._meta_data) + target = LazyTensor(factory_fn, self, meta_data=self._meta_data) return target @@ -353,17 +406,16 @@ def __deepcopy__(self, memo): if id(self) in memo: return memo[id(self)] - def factory_fn(): + def factory_fn(t: torch.Tensor, **kw): # if self is materialized, return self - new_tensor = self.materialize() if type(self) is LazyTensor else self - return _copy_tensor(new_tensor, new_tensor.requires_grad) + return _copy_tensor(t, t.requires_grad) if self._materialized_data is not None: # self is early materialized copied = _copy_tensor(self._materialized_data, self.requires_grad) target = LazyTensor(lambda: None, concrete_data=copied) else: - target = LazyTensor(factory_fn, meta_data=self._meta_data) + target = LazyTensor(factory_fn, self, meta_data=self._meta_data) if isinstance(self, Parameter): # hack isinstance check of parameter @@ -394,14 +446,12 @@ def data(self, other: "LazyTensor"): if other is self: return - self._op_buffer.append(other._factory_method) - def replace(x): if x is other: return self return x - for func, args, kwargs in other._op_buffer: + for func, args, kwargs in [other._factory_method, *other._op_buffer]: self._op_buffer.append((func, tree_map(replace, args), tree_map(replace, kwargs))) def tolist(self) -> list: @@ -455,7 +505,6 @@ def __init__( default_device: Optional[Union[torch.device, str, int]] = None, ): assert tensor_cls is LazyTensor or tensor_cls is _MyTensor - self.overrides = {} self.tensor_cls = tensor_cls self.old_default_device = LazyTensor.default_device self.default_device = default_device @@ -478,7 +527,9 @@ def wrap_factory_like_method(orig_target, target): # factory_like functions (eg. 
torch.empty_like()) def wrapper(*args, **kwargs): orig_t = args[0] - return self.tensor_cls(orig_target, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs) + return self.tensor_cls( + orig_target, *orig_t.shape, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs + ) return wrapper, target @@ -513,13 +564,13 @@ def wrapper(*args, **kwargs): return wrapper, target - self.overrides = { + overrides = { target: wrap_factory_method(getattr(torch, target)) for target in _NORMAL_FACTORY if callable(getattr(torch, target, None)) } - self.overrides.update( + overrides.update( { target + "_like": wrap_factory_like_method(getattr(torch, target), getattr(torch, target + "_like")) for target in _NORMAL_FACTORY @@ -527,7 +578,7 @@ def wrapper(*args, **kwargs): } ) - self.overrides.update( + overrides.update( { target: wrap_legacy_constructor(getattr(torch, target), dtype) for target, dtype in _LEGACY_TENSOR_CONSTRUCTOR.items() @@ -535,7 +586,7 @@ def wrapper(*args, **kwargs): } ) - self.overrides.update( + overrides.update( { target: wrap_no_meta_factory(getattr(torch, target)) for target in _NO_META_FACTORY @@ -543,14 +594,12 @@ def wrapper(*args, **kwargs): } ) - for name, (wrapper, orig) in self.overrides.items(): - setattr(torch, name, wrapper) + ConstructorManager.apply(overrides) def __exit__(self, exc_type, exc_val, exc_tb): self.tensor_cls.default_device = self.old_default_device LazyInitContext._replaced = False - for name, (wrapper, orig) in self.overrides.items(): - setattr(torch, name, orig) + ConstructorManager.clear() @staticmethod def materialize(module: nn.Module, verbose: bool = False) -> nn.Module: @@ -566,23 +615,6 @@ def apply_fn(name: str, p: LazyTensor): return _apply_to_lazy_module(module, apply_fn, verbose) - @staticmethod - def distribute( - module: nn.Module, device_mesh: DeviceMesh, sharding_spec_dict: Dict[str, ShardingSpec], verbose: bool = False - ) -> nn.Module: - """Distribute all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place. - - Args: - module (nn.Module): Target ``nn.Module`` - layout_dict (dict): Dict of layout for each parameter/buffer. The key is the parameter/buffer name, and the value is the layout. - verbose (bool, optional): Whether to print lazy initialization rate. Defaults to False. 
- """ - - def apply_fn(name: str, p: LazyTensor): - p.distribute(device_mesh, sharding_spec_dict[name]) - - return _apply_to_lazy_module(module, apply_fn, verbose) - def _apply_to_lazy_module( module: nn.Module, apply_fn: Callable[[str, torch.Tensor], None], verbose: bool = False @@ -622,20 +654,17 @@ def _apply_to_lazy_module( if verbose: non_lazy_numel_ratio = non_lazy_numel / total_numel * 100 if non_lazy_numel != 0 else 0 - _print_rank_0(f"Param lazy rate: {param_lazy_cnt}/{param_cnt}") - _print_rank_0(f"Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}") - _print_rank_0( - f"Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%" + logger = get_dist_logger() + logger.info(f"Param lazy rate: {param_lazy_cnt}/{param_cnt}", ranks=[0]) + logger.info(f"Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}", ranks=[0]) + logger.info( + f"Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%", + ranks=[0], ) return module -def _print_rank_0(*args, **kwargs): - if not dist.is_initialized() or dist.get_rank() == 0: - print(*args, **kwargs) - - def _is_int_tuple(args) -> bool: if not isinstance(args, tuple): return False diff --git a/tests/test_lazy/test_models.py b/tests/test_lazy/test_models.py index 978cf06b55a0..a1b5763d4cd8 100644 --- a/tests/test_lazy/test_models.py +++ b/tests/test_lazy/test_models.py @@ -11,14 +11,12 @@ def test_torchvision_models_lazy_init(subset, default_device): sub_model_zoo = model_zoo.get_sub_registry(subset) for name, entry in sub_model_zoo.items(): # TODO(ver217): lazy init does not support weight norm, skip these models - if ( - name in ("torchaudio_wav2vec2_base", "torchaudio_hubert_base") - or name.startswith("transformers_llama") - or name.startswith(("transformers_vit", "transformers_blip2")) + if name in ("torchaudio_wav2vec2_base", "torchaudio_hubert_base") or name.startswith( + ("transformers_vit", "transformers_blip2") ): continue check_lazy_init(entry, verbose=True, default_device=default_device) if __name__ == "__main__": - test_torchvision_models_lazy_init("torchvision") + test_torchvision_models_lazy_init("transformers", "cpu") diff --git a/tests/test_lazy/test_ops.py b/tests/test_lazy/test_ops.py new file mode 100644 index 000000000000..e6b936198547 --- /dev/null +++ b/tests/test_lazy/test_ops.py @@ -0,0 +1,64 @@ +import copy + +import pytest +import torch +import torch.nn as nn +from lazy_init_utils import SUPPORT_LAZY +from torch.nn import Parameter + +from colossalai.lazy import LazyInitContext + + +@pytest.mark.skipif(not SUPPORT_LAZY, reason="requires torch >= 1.12.0") +def test_lazy_ops(): + with LazyInitContext(): + x = torch.rand(2, 3) + assert tuple(x.shape) == (2, 3) + assert x.device.type == "cpu" + x.requires_grad is False + y = x.cuda() + assert tuple(y.shape) == (2, 3) + assert y.device.type == "cuda" + assert y.requires_grad is False + assert x.cpu() is x + p = Parameter(torch.empty(2, 3)) + assert tuple(p.shape) == (2, 3) + assert p.device.type == "cpu" + assert p.requires_grad is True + assert isinstance(p, Parameter) + x.materialize() + assert tuple(x.shape) == (2, 3) + assert x.device.type == "cpu" + assert x.requires_grad is False + y.materialize() + assert tuple(y.shape) == (2, 3) + assert y.device.type == "cuda" + assert y.requires_grad is False + p.materialize() + assert tuple(p.shape) == (2, 3) + assert p.device.type == "cpu" + assert p.requires_grad is True + assert isinstance(p, Parameter) + + with LazyInitContext(): + x = torch.empty(2, 3) + x.uniform_() 
+ x.materialize() + assert tuple(x.shape) == (2, 3) + + with LazyInitContext(): + model = nn.Linear(3, 4) + model = model.cuda() + model_copied = copy.deepcopy(model) + LazyInitContext.materialize(model) + assert model.weight.device.type == "cuda" + assert model.bias.device.type == "cuda" + LazyInitContext.materialize(model_copied) + assert model_copied.weight.device.type == "cuda" + assert model_copied.bias.device.type == "cuda" + assert torch.equal(model.weight, model_copied.weight) + assert torch.equal(model.bias, model_copied.bias) + + +if __name__ == "__main__": + test_lazy_ops() From 1e0e080837478e95bc2d835c58ccd025a0013c00 Mon Sep 17 00:00:00 2001 From: littsk <1214689160@qq.com> Date: Fri, 22 Sep 2023 10:50:47 +0800 Subject: [PATCH 36/58] [bug] Fix the version check bug in colossalai run when generating the cmd. (#4713) * Fix the version check bug in colossalai run when generating the cmd. * polish code --- colossalai/cli/launcher/run.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/colossalai/cli/launcher/run.py b/colossalai/cli/launcher/run.py index 7ca8ee90386c..88f70f02ec27 100644 --- a/colossalai/cli/launcher/run.py +++ b/colossalai/cli/launcher/run.py @@ -156,7 +156,8 @@ def _arg_dict_to_list(arg_dict): torch_version = version.parse(torch.__version__) assert torch_version.major >= 1 - if torch_version.minor < 9: + if torch_version.major == 1 and torch_version.minor < 9: + # torch distributed launch cmd with torch < 1.9 cmd = [ sys.executable, "-m", @@ -177,7 +178,8 @@ def _arg_dict_to_list(arg_dict): value = extra_launch_args.pop(key) default_torchrun_rdzv_args[key] = value - if torch_version.minor < 10: + if torch_version.major == 1 and torch_version.minor == 9: + # torch distributed launch cmd with torch == 1.9 cmd = [ sys.executable, "-m", @@ -187,6 +189,7 @@ def _arg_dict_to_list(arg_dict): f"--node_rank={node_rank}", ] else: + # torch distributed launch cmd with torch > 1.9 cmd = [ "torchrun", f"--nproc_per_node={nproc_per_node}", From 946ab56c486480875020252ad65f9a4618fb9a16 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Fri, 22 Sep 2023 11:02:50 +0800 Subject: [PATCH 37/58] [feature] add gptq for inference (#4754) * [gptq] add gptq kernel (#4416) * add gptq * refactor code * fix tests * replace auto-gptq * rname inferance/quant * refactor test * add auto-gptq as an option * reset requirements * change assert and check auto-gptq * add import warnings * change test flash attn version * remove example * change requirements of flash_attn * modify tests * [skip ci] change requirements-test * [gptq] faster gptq cuda kernel (#4494) * [skip ci] add cuda kernels * add license * [skip ci] fix max_input_len * format files & change test size * [skip ci] * [gptq] add gptq tensor parallel (#4538) * add gptq tensor parallel * add gptq tp * delete print * add test gptq check * add test auto gptq check * [gptq] combine gptq and kv cache manager (#4706) * combine gptq and kv cache manager * add init bits * delete useless code * add model path * delete usless print and update test * delete usless import * move option gptq to shard config * change replace linear to shardformer * update bloom policy * delete useless code * fix import bug and delete uselss code * change colossalai/gptq to colossalai/quant/gptq * update import linear for tests * delete useless code and mv gptq_kernel to kernel directory * fix triton kernel * add triton import --- LICENSE | 49 ++ colossalai/inference/quant/gptq/__init__.py | 4 + .../inference/quant/gptq/cai_gptq/__init__.py | 13 + 
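The launcher fix in the patch above matters because the old code compared only the minor version: under torch 2.0, `minor < 9` evaluates to true again and the legacy launcher would be picked. A small sketch of the corrected gate (module names shown are the usual ones for each release line and are meant as illustration):

```python
from packaging import version

def pick_launcher(torch_version_str: str) -> str:
    v = version.parse(torch_version_str)
    if v.major == 1 and v.minor < 9:
        return "python -m torch.distributed.launch"   # legacy launcher
    elif v.major == 1 and v.minor == 9:
        return "python -m torch.distributed.run"      # transitional in 1.9
    else:
        return "torchrun"                             # 1.10+ and 2.x

assert pick_launcher("1.8.1").endswith("launch")
assert pick_launcher("1.9.0").endswith("run")
assert pick_launcher("2.0.1") == "torchrun"
```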
.../quant/gptq/cai_gptq/cai_quant_linear.py | 354 ++++++++++++ .../inference/quant/gptq/cai_gptq/gptq_op.py | 58 ++ .../inference/tensor_parallel/engine.py | 56 ++ .../tensor_parallel/policies/bloom.py | 32 ++ .../tensor_parallel/policies/llama.py | 51 ++ .../cuda_native/csrc/gptq/column_remap.cu | 63 ++ .../cuda_native/csrc/gptq/column_remap.cuh | 19 + .../cuda_native/csrc/gptq/cu_compat.cuh | 58 ++ .../cuda_native/csrc/gptq/cuda_buffers.cu | 75 +++ .../cuda_native/csrc/gptq/cuda_buffers.cuh | 55 ++ .../cuda_native/csrc/gptq/hip_compat.cuh | 49 ++ .../cuda_native/csrc/gptq/linear_gptq.cpp | 254 ++++++++ .../kernel/cuda_native/csrc/gptq/matrix.cuh | 294 ++++++++++ .../kernel/cuda_native/csrc/gptq/q4_matmul.cu | 260 +++++++++ .../cuda_native/csrc/gptq/q4_matmul.cuh | 43 ++ .../kernel/cuda_native/csrc/gptq/q4_matrix.cu | 225 ++++++++ .../cuda_native/csrc/gptq/q4_matrix.cuh | 53 ++ .../kernel/cuda_native/csrc/gptq/tuning.h | 13 + .../kernel/cuda_native/csrc/gptq/util.cuh | 33 ++ colossalai/kernel/triton/__init__.py | 2 + colossalai/kernel/triton/gptq_triton.py | 541 ++++++++++++++++++ colossalai/shardformer/shard/shard_config.py | 7 +- examples/inference/gptq_bloom.py | 123 ++++ examples/inference/gptq_llama.py | 135 +++++ op_builder/gptq.py | 52 ++ requirements/requirements-test.txt | 1 + tests/test_gptq/test_gptq_linear.py | 150 +++++ 30 files changed, 3120 insertions(+), 2 deletions(-) create mode 100644 colossalai/inference/quant/gptq/__init__.py create mode 100644 colossalai/inference/quant/gptq/cai_gptq/__init__.py create mode 100644 colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py create mode 100644 colossalai/inference/quant/gptq/cai_gptq/gptq_op.py create mode 100644 colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu create mode 100644 colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh create mode 100644 colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh create mode 100644 colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu create mode 100644 colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh create mode 100644 colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh create mode 100644 colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp create mode 100644 colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh create mode 100644 colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu create mode 100644 colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh create mode 100644 colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu create mode 100644 colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh create mode 100644 colossalai/kernel/cuda_native/csrc/gptq/tuning.h create mode 100644 colossalai/kernel/cuda_native/csrc/gptq/util.cuh create mode 100644 colossalai/kernel/triton/gptq_triton.py create mode 100644 examples/inference/gptq_bloom.py create mode 100644 examples/inference/gptq_llama.py create mode 100644 op_builder/gptq.py create mode 100644 tests/test_gptq/test_gptq_linear.py diff --git a/LICENSE b/LICENSE index 06629068faa5..59d456c5b8a1 100644 --- a/LICENSE +++ b/LICENSE @@ -428,3 +428,52 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
+ ---------------- LICENSE FOR AutoGPTQ ---------------- + + From AutoGPTQ: + + MIT License + + Copyright (c) 2023 潘其威(William) + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + + ---------------- LICENSE FOR exllama ---------------- + + From exllama: + + MIT License + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. 
diff --git a/colossalai/inference/quant/gptq/__init__.py b/colossalai/inference/quant/gptq/__init__.py new file mode 100644 index 000000000000..c035f397923a --- /dev/null +++ b/colossalai/inference/quant/gptq/__init__.py @@ -0,0 +1,4 @@ +from .cai_gptq import HAS_AUTO_GPTQ + +if HAS_AUTO_GPTQ: + from .cai_gptq import CaiGPTQLinearOp, CaiQuantLinear diff --git a/colossalai/inference/quant/gptq/cai_gptq/__init__.py b/colossalai/inference/quant/gptq/cai_gptq/__init__.py new file mode 100644 index 000000000000..de57f2d8cfee --- /dev/null +++ b/colossalai/inference/quant/gptq/cai_gptq/__init__.py @@ -0,0 +1,13 @@ +import warnings + +HAS_AUTO_GPTQ = False +try: + import auto_gptq + HAS_AUTO_GPTQ = True +except ImportError: + warnings.warn('please install auto-gptq from https://github.com/PanQiWei/AutoGPTQ') + HAS_AUTO_GPTQ = False + +if HAS_AUTO_GPTQ: + from .cai_quant_linear import CaiQuantLinear, ColCaiQuantLinear, RowCaiQuantLinear + from .gptq_op import CaiGPTQLinearOp diff --git a/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py b/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py new file mode 100644 index 000000000000..ca12c34ed958 --- /dev/null +++ b/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py @@ -0,0 +1,354 @@ +# Adapted from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ + +import math +import warnings +from typing import List, Union + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed import ProcessGroup + +from colossalai.lazy import LazyInitContext +from colossalai.shardformer.layer import ParallelModule + +from .gptq_op import CaiGPTQLinearOp + +HAS_GPTQ_CUDA = False +try: + from colossalai.kernel.op_builder.gptq import GPTQBuilder + gptq_cuda = GPTQBuilder().load() + HAS_GPTQ_CUDA = True +except ImportError: + warnings.warn('CUDA gptq is not installed') + HAS_GPTQ_CUDA = False + + +class CaiQuantLinear(nn.Module): + + def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False): + super().__init__() + if bits not in [2, 4, 8]: + raise NotImplementedError("Only 2,4,8 bits are supported.") + self.infeatures = infeatures + self.outfeatures = outfeatures + self.bits = bits + self.maxq = 2**self.bits - 1 + self.groupsize = groupsize if groupsize != -1 else infeatures + + self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)) + self.register_buffer( + 'qzeros', + torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32)) + self.register_buffer('scales', + torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)) + if row_split: + self.register_buffer( + 'g_idx', + torch.tensor([(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)], + dtype=torch.int32)) + else: + self.register_buffer('g_idx', + torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)) + + if bias: + self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16)) + else: + self.bias = None + + self.gptq_linear = CaiGPTQLinearOp(groupsize, bits) + + self.q4 = None + self.empty_tensor = torch.empty((1, 1), device="meta") + self.tp_size = tp_size + self.tp_rank = tp_rank + self.row_split = row_split + + def pack(self, linear, scales, zeros, g_idx=None): + + g_idx = g_idx.clone() if g_idx is not None else torch.tensor( + [i // self.groupsize for i in range(self.infeatures)], 
dtype=torch.int32) + + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + half_scales = scales.clone().half() + # print("scale shape ", scales.shape, scale_zeros.shape, linear.weight.shape) + self.scales = scales.clone().half() + if linear.bias is not None: + self.bias = linear.bias.clone().half() + + wn = 8 + pbits = 32 + ptype = torch.int32 + unsign_type = np.uint32 + sign_type = np.int32 + + intweight = [] + for idx in range(self.infeatures): + intweight.append( + torch.round( + (linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[:, + None]) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.numpy().astype(unsign_type) + qweight = np.zeros((intweight.shape[0] // pbits * self.bits, intweight.shape[1]), dtype=unsign_type) + + i = 0 + row = 0 + + while row < qweight.shape[0]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (pbits // self.bits)): + qweight[row] |= intweight[j] << (self.bits * (j - i)) + i += pbits // self.bits + row += 1 + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + qweight = qweight.astype(sign_type) + qweight1 = torch.from_numpy(qweight) + qweight1 = qweight1.contiguous() #.to("cuda") + self.qweight.data.copy_(qweight1) + + qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type) + zeros -= 1 + zeros = zeros.numpy().astype(unsign_type) + i = 0 + col = 0 + while col < qzeros.shape[1]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (pbits // self.bits)): + qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) + i += pbits // self.bits + col += 1 + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + qzeros = qzeros.astype(sign_type) + qzeros = torch.from_numpy(qzeros) + qzeros = qzeros + self.qzeros.data.copy_(qzeros) + + if torch.equal(self.g_idx.to(g_idx.device), g_idx): + self.g_idx = None + else: + self.g_idx = g_idx + + def init_q4(self): + assert self.qweight.device.type == "cuda" + self.q4_width = self.qweight.shape[1] + if self.g_idx is not None: + if self.row_split and torch.equal( + self.g_idx, + torch.tensor( + [(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)], + dtype=torch.int32, + device=self.g_idx.device)): + self.g_idx = None + elif torch.equal( + self.g_idx, + torch.tensor([i // self.groupsize for i in range(self.infeatures)], + dtype=torch.int32, + device=self.g_idx.device)): + self.g_idx = None + + if self.g_idx is not None: + g_idx = self.g_idx.to("cpu") + else: + g_idx = self.empty_tensor + + self.q4 = gptq_cuda.make_q4(self.qweight, self.qzeros, self.scales, g_idx, torch.cuda.current_device()) + torch.cuda.synchronize() + + def forward(self, x): + outshape = x.shape[:-1] + (self.outfeatures,) + + if HAS_GPTQ_CUDA and self.bits == 4: + + if self.q4 is None: + self.init_q4() + + x = x.view(-1, x.shape[-1]) + output = torch.empty((x.shape[0], self.outfeatures), dtype=torch.float16, device=x.device) + gptq_cuda.q4_matmul(x.half(), self.q4, output) + if self.bias is not None and (not self.row_split or self.tp_size == 1): + output.add_(self.bias) + else: + if self.bias is not None and (not self.row_split or self.tp_size == 1): + bias = self.bias + else: + bias = None + output = self.gptq_linear( + x, + self.qweight, + self.scales, + self.qzeros, + g_idx=self.g_idx, + bias=bias, + ) + return output.view(outshape) + + +def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, 
split_num=1): + + qweights = gptq_linear.qweight.split(gptq_linear.out_features // split_num, dim=-1) + qzeros = gptq_linear.qzeros.split(gptq_linear.out_features // (32 // cai_linear.bits) // split_num, dim=-1) + scales = gptq_linear.scales.split(gptq_linear.out_features // split_num, dim=-1) + g_idx = gptq_linear.g_idx + if gptq_linear.bias is not None: + bias = gptq_linear.bias.split(gptq_linear.out_features // split_num, dim=-1) + + cai_split_out_features = cai_linear.outfeatures // split_num + zero_split_block = cai_linear.outfeatures // (32 // cai_linear.bits) // split_num + + for i in range(split_num): + cai_linear.qweight[:, i * cai_split_out_features:(i + 1) * + cai_split_out_features] = qweights[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) * + cai_split_out_features] + cai_linear.qzeros[:, i * zero_split_block:(i + 1) * + zero_split_block] = qzeros[i][:, tp_rank * zero_split_block:(tp_rank + 1) * zero_split_block] + cai_linear.scales[:, i * cai_split_out_features:(i + 1) * + cai_split_out_features] = scales[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) * + cai_split_out_features] + if cai_linear.bias is not None: + cai_linear.bias[i * cai_split_out_features:(i + 1) * + cai_split_out_features] = bias[i][tp_rank * cai_split_out_features:(tp_rank + 1) * + cai_split_out_features] + + cai_linear.g_idx.copy_(g_idx) + + +def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1): + + qweights = gptq_linear.qweight.split(gptq_linear.in_features // split_num, dim=0) + qzeros = gptq_linear.qzeros.split(gptq_linear.in_features // split_num, dim=0) + scales = gptq_linear.scales.split(gptq_linear.in_features // split_num, dim=0) + g_idxs = gptq_linear.g_idx.split(gptq_linear.in_features // split_num, dim=0) + + cai_split_in_features = cai_linear.infeatures // (32 // cai_linear.bits) // split_num + zero_split_block = cai_linear.infeatures // cai_linear.groupsize // split_num + idx_split_features = cai_linear.infeatures // split_num + + for i in range(split_num): + cai_linear.qweight[i * cai_split_in_features:(i + 1) * + cai_split_in_features, :] = qweights[i][tp_rank * cai_split_in_features:(tp_rank + 1) * + cai_split_in_features, :] + cai_linear.qzeros[i * zero_split_block:(i + 1) * + zero_split_block, :] = qzeros[i][tp_rank * zero_split_block:(tp_rank + 1) * + zero_split_block, :] + cai_linear.scales[i * zero_split_block:(i + 1) * + zero_split_block, :] = scales[i][tp_rank * zero_split_block:(tp_rank + 1) * + zero_split_block, :] + cai_linear.g_idx[i * idx_split_features:(i + 1) * + idx_split_features] = g_idxs[i][tp_rank * idx_split_features:(tp_rank + 1) * + idx_split_features] + if cai_linear.bias is not None: + cai_linear.bias.copy_(gptq_linear.bias) + + +class RowCaiQuantLinear(CaiQuantLinear, ParallelModule): + + def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False): + + super().__init__(bits, + groupsize, + infeatures, + outfeatures, + bias, + tp_size=tp_size, + tp_rank=tp_rank, + row_split=row_split) + self.process_group = None + + @staticmethod + def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, + **kwargs) -> ParallelModule: + LazyInitContext.materialize(module) + # get the attributes + in_features = module.in_features + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' 
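For orientation, `pack()` above quantizes each weight group to `q = round(w / scale + zero)` and stores eight 4-bit codes per int32 word (with the zero-points saved as `zero - 1` in `qzeros`, omitted here). The snippet below replays that arithmetic on a single toy group so the layout is easier to follow; it illustrates the math only, not the act-order handling or the CUDA/Triton kernels.

```python
import torch

bits = 4
maxq = 2 ** bits - 1                                   # 15

w = torch.randn(64)                                    # one weight group
scale = (w.max() - w.min()) / maxq
zero = torch.round(-w.min() / scale)                   # integer zero-point
q = torch.clamp(torch.round(w / scale + zero), 0, maxq).to(torch.int32)

# dequantization, which is what the matmul kernels effectively undo
w_hat = (q - zero) * scale
assert torch.allclose(w, w_hat, atol=float(scale))     # error within one step

# pack eight 4-bit codes into each 32-bit word, mirroring qweight's layout
codes = [int(v) for v in q]
packed = []
for i in range(0, len(codes), 8):
    word = 0
    for j in range(8):
        word |= codes[i + j] << (bits * j)
    packed.append(word)

# unpacking the first word recovers the first eight codes
assert [(packed[0] >> (bits * j)) & maxq for j in range(8)] == codes[:8]
```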
+ process_group = process_group[0] + + tp_size = dist.get_world_size(process_group) + tp_rank = dist.get_rank(process_group) + + if in_features < tp_size: + return module + + if in_features % tp_size != 0: + raise ValueError( + f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!") + linear_1d = RowCaiQuantLinear(module.bits, + module.group_size, + module.in_features // tp_size, + module.out_features, + module.bias is not None, + tp_size=tp_size, + tp_rank=tp_rank, + row_split=True) + linear_1d.process_group = process_group + + split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs) + return linear_1d + + def forward(self, x): + output = super().forward(x) + if self.tp_size > 1: + dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group) + if self.bias is not None: + output.add_(self.bias) + return output + + +class ColCaiQuantLinear(CaiQuantLinear, ParallelModule): + + def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False): + + super().__init__(bits, + groupsize, + infeatures, + outfeatures, + bias, + tp_size=tp_size, + tp_rank=tp_rank, + row_split=row_split) + self.process_group = None + + @staticmethod + def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, + **kwargs) -> ParallelModule: + LazyInitContext.materialize(module) + # get the attributes + in_features = module.in_features + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' + process_group = process_group[0] + + tp_size = dist.get_world_size(process_group) + tp_rank = dist.get_rank(process_group) + + if in_features < tp_size: + return module + + if in_features % tp_size != 0: + raise ValueError( + f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!") + linear_1d = ColCaiQuantLinear(module.bits, + module.group_size, + module.in_features, + module.out_features // tp_size, + module.bias is not None, + tp_size=tp_size, + tp_rank=tp_rank) + linear_1d.process_group = process_group + + split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs) + return linear_1d diff --git a/colossalai/inference/quant/gptq/cai_gptq/gptq_op.py b/colossalai/inference/quant/gptq/cai_gptq/gptq_op.py new file mode 100644 index 000000000000..a8902eb35cd0 --- /dev/null +++ b/colossalai/inference/quant/gptq/cai_gptq/gptq_op.py @@ -0,0 +1,58 @@ +import torch + +from colossalai.kernel.triton import gptq_fused_linear_triton + + +class CaiGPTQLinearOp(torch.nn.Module): + def __init__(self, gptq_group_size, gptq_quant_bits): + super(CaiGPTQLinearOp, self).__init__() + self.group_size = gptq_group_size + self.bits = gptq_quant_bits + self.maxq = 2**self.bits - 1 + self.empty_tensor = torch.zeros(4, device=torch.cuda.current_device()) + + def forward( + self, + input: torch.Tensor, + weight: torch.Tensor, + weight_scales: torch.Tensor, + weight_zeros: torch.Tensor, + g_idx: torch.Tensor = None, + act_type=0, + bias: torch.Tensor = None, + residual: torch.Tensor = None, + qkv_fused=False, + ): + add_bias = True + if bias is None: + bias = self.empty_tensor + add_bias = False + + add_residual = True + if residual is None: + residual = self.empty_tensor + add_residual = False + x = input.view(-1, input.shape[-1]) + + out = gptq_fused_linear_triton( + x, + weight, + weight_scales, + weight_zeros, + bias, + 
residual, + self.bits, + self.maxq, + self.group_size, + qkv_fused, + add_bias, + add_residual, + act_type=act_type, + g_idx=g_idx, + ) + if qkv_fused: + out = out.view(3, input.shape[0], input.shape[1], weight.shape[-1]) + else: + out = out.view(input.shape[0], input.shape[1], weight.shape[-1]) + + return out diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index 1335f13d66b8..29b5d6117ffd 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -1,6 +1,7 @@ from typing import Any, Callable, List, Optional, Union import torch +import torch.distributed as dist import torch.nn as nn from transformers import BloomForCausalLM, LlamaForCausalLM from transformers.generation import GenerationConfig @@ -68,6 +69,13 @@ def __init__( self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config self.cache_manager = None + self.max_dq_buffer_size = 1 + self.max_inner_outer_dim = 1 + self.gptq_temp_state_buffer = None + self.gptq_temp_dq_buffer = None + self.bits = -1 + self.use_act_order = False + self.shard_config = shard_config self.model = None # optimize the original model by sharding with ShardFormer @@ -81,6 +89,50 @@ def _init_manager(self) -> None: self.max_total_token_num, self.dtype, self.head_num, self.head_dim, self.layer_num ) + def _post_init_gptq_buffer(self, model: nn.Module) -> None: + from colossalai.inference.quant.gptq.cai_gptq import CaiQuantLinear + HAS_GPTQ_CUDA = False + try: + from colossalai.kernel.op_builder.gptq import GPTQBuilder + gptq_cuda = GPTQBuilder().load() + HAS_GPTQ_CUDA = True + except ImportError: + warnings.warn('CUDA gptq is not installed') + HAS_GPTQ_CUDA = False + + for name, submodule in model.named_modules(): + if isinstance(submodule, CaiQuantLinear): + self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8) + + if self.use_act_order: + self.max_inner_outer_dim = max(self.max_inner_outer_dim, submodule.infeatures, + submodule.outfeatures) + self.bits = submodule.bits + if not (HAS_GPTQ_CUDA and self.bits == 4): + return + + max_input_len = 1 + if self.use_act_order: + max_input_len = self.max_input_len + # The temp_state buffer is required to reorder X in the act-order case. + # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. + self.gptq_temp_state_buffer = torch.zeros((max_input_len, self.max_inner_outer_dim), + dtype=torch.float16, + device=torch.cuda.current_device()) + self.gptq_temp_dq_buffer = torch.zeros((1, self.max_dq_buffer_size), + dtype=torch.float16, + device=torch.cuda.current_device()) + + gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, + self.gptq_temp_dq_buffer) + # Using the default from exllama repo here. + matmul_recons_thd = 8 + matmul_fused_remap = False + matmul_no_half2 = False + gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) + + torch.cuda.empty_cache() + def _optimize_model(self, model: nn.Module) -> None: """ Optimize the original model by sharding with ShardFormer. @@ -129,6 +181,10 @@ def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None: assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference." 
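To put numbers on the scratch buffers that `_post_init_gptq_buffer` above allocates, here is a back-of-the-envelope estimate for a single 4-bit 4096x4096 projection, assuming a 2048-token prefill budget; the figures are illustrative, not measurements.

```python
in_features, out_features, bits = 4096, 4096, 4

# qweight packs 32 // bits codes per int32, so it has in_features // 8 rows
qweight_numel = (in_features // 32 * bits) * out_features   # 2_097_152 int32 values

# temp_dq holds the dequantized fp16 weight: numel * 8 half-precision elements
max_dq_buffer_size = qweight_numel * 8                      # 16_777_216 halves
temp_dq_bytes = max_dq_buffer_size * 2                      # ~32 MiB

# with act-order, temp_state also holds a column-reordered copy of the input
max_input_len = 2048                                        # assumed prefill budget
max_inner_outer_dim = max(in_features, out_features)
temp_state_bytes = max_input_len * max_inner_outer_dim * 2  # ~16 MiB

print(temp_dq_bytes / 2**20, temp_state_bytes / 2**20)      # 32.0 16.0
```

Since only the maximum sizes across all quantized layers are kept and the buffers are shared, the extra GPU memory stays modest even for large models.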
policy = get_autopolicy(model, inference_only=True) self.model, _ = shardformer.optimize(model, policy) + + if self.shard_config.inference_gptq: + self._post_init_gptq_buffer(model) + self.model = self.model.cuda() @property diff --git a/colossalai/inference/tensor_parallel/policies/bloom.py b/colossalai/inference/tensor_parallel/policies/bloom.py index 2d18a3922c1e..3d6df2097000 100644 --- a/colossalai/inference/tensor_parallel/policies/bloom.py +++ b/colossalai/inference/tensor_parallel/policies/bloom.py @@ -3,6 +3,9 @@ import torch from torch.nn import LayerNorm +import colossalai.shardformer.layer as col_nn +from colossalai.shardformer.modeling.bloom import build_bloom_alibi_tensor_fn +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy from ..modeling.bloom import BloomInferenceForwards @@ -35,6 +38,35 @@ def module_policy(self): from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel policy = super().module_policy() + if self.shard_config.inference_gptq: + from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear + policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={ + "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attention.query_key_value", + target_module=ColCaiQuantLinear, + kwargs={'split_num': 3}), + SubModuleReplacementDescription( + suffix="self_attention.dense", + target_module=RowCaiQuantLinear, + kwargs={'split_num': 1}), + SubModuleReplacementDescription( + suffix="self_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_h_to_4h", + target_module=ColCaiQuantLinear, + kwargs={'split_num': 1}), + SubModuleReplacementDescription( + suffix="mlp.dense_4h_to_h", + target_module=RowCaiQuantLinear, + kwargs={'split_num': 1}), + ]) # NOTE set inference mode to shard config self.shard_config._infer() diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py index 9bbb547dbcae..eaaadadd1f88 100644 --- a/colossalai/inference/tensor_parallel/policies/llama.py +++ b/colossalai/inference/tensor_parallel/policies/llama.py @@ -3,6 +3,8 @@ import torch from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm +from colossalai.shardformer.layer import VocabParallelEmbedding1D +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription # import colossalai from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy @@ -34,6 +36,55 @@ def __init__(self) -> None: def module_policy(self): policy = super().module_policy() + + if self.shard_config.inference_gptq: + from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear + + decoder_attribute_replacement = { + "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.num_attention_heads // 
self.shard_config.tensor_parallel_size, + } + policy[LlamaDecoderLayer] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=ColCaiQuantLinear, + kwargs={'split_num': 1}, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=ColCaiQuantLinear, + kwargs={'split_num': 1}, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=ColCaiQuantLinear, + kwargs={'split_num': 1}, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=RowCaiQuantLinear, + kwargs={'split_num': 1}, + ), + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=ColCaiQuantLinear, + kwargs={'split_num': 1}, + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=ColCaiQuantLinear, + kwargs={'split_num': 1}, + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=RowCaiQuantLinear, + kwargs={'split_num': 1}, + ) + ], + ) + self.shard_config._infer() infer_forward = LlamaInferenceForwards.llama_model_forward diff --git a/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu new file mode 100644 index 000000000000..2b1b366b1c02 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu @@ -0,0 +1,63 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#include "column_remap.cuh" +#include "util.cuh" + +const int SHUF_BLOCKSIZE_X = 256; +const int SHUF_BLOCKSIZE_Y = 16; + +__global__ void column_remap_kernel +( + const half* __restrict__ x, + half* __restrict__ x_new, + const int x_width, + const int x_height, + const uint32_t* x_map +) +{ + int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x; + int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y; + if (x_column >= x_width) return; + //if (x_row >= x_height) return; + + int x_stride = x_width; + int x_idx = x_row * x_stride + x_column; + + int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height); + int x_idx_end = x_row_end * x_stride + x_column; + + int s_column = x_map[x_column]; + int s_idx = x_row * x_stride + s_column; + + while (x_idx < x_idx_end) + { + x_new[x_idx] = x[s_idx]; + x_idx += x_stride; + s_idx += x_stride; + } +} + +// Remap columns in x to correspond to sequential group index before matmul +// +// perform x -> seq_x such that seq_x @ seq_w == x @ w + +void column_remap_cuda +( + const half* x, + half* x_new, + const int x_height, + const int x_width, + const uint32_t* x_map +) +{ + dim3 threads(SHUF_BLOCKSIZE_X, 1, 1); + + dim3 blocks + ( + (x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X, + (x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y, + 1 + ); + + column_remap_kernel<<>>(x, x_new, x_width, x_height, x_map); +} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh new file mode 100644 index 000000000000..6571c17d6fd5 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh @@ -0,0 +1,19 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _column_remap_cuh +#define _column_remap_cuh + +#include +#include +#include + +void column_remap_cuda +( + const half* x, + half* x_new, + const int x_height, + const int x_width, + const uint32_t* x_map +); + +#endif \ No newline at end of file diff --git 
a/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh b/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh new file mode 100644 index 000000000000..c5258813e147 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh @@ -0,0 +1,58 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _cuda_compat_cuh +#define _cuda_compat_cuh + +// atomicAdd for half types, to support CC < 7.x + +__device__ __forceinline__ void atomicAdd_half(half* address, half val) +{ + unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + do + { + assumed = old; + __half_raw hsum; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + half tmpres = __hadd(hsum, val); + hsum = __half_raw(tmpres); + old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } + while (assumed != old); +} + +// atomicAdd for half2 types + +__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) +{ + unsigned int* address_as_ui = (unsigned int*)address; + unsigned int old = *address_as_ui; + unsigned int assumed; + do + { + assumed = old; + half2 old_val = *((half2*)&old); + half2 new_val = __hadd2(old_val, val); + old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); + } + while (assumed != old); +} + +// + +#if defined(__CUDA_ARCH__) || defined(USE_ROCM) +#if __CUDA_ARCH__ < 700 || defined(USE_ROCM) + +__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } + +#if __CUDA_ARCH__ < 600 || defined(USE_ROCM) +__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } +#endif + +#endif +#endif + +#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu new file mode 100644 index 000000000000..4416027c8387 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu @@ -0,0 +1,75 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#define _cuda_buffers_cu +#include "cuda_buffers.cuh" + +CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL}; +// __constant__ half2 q4_table[16][256]; +// half2 q4_table_host[16][256]; +// bool q4_table_init = false; + +CudaBuffers::CudaBuffers +( + int _device, + int _temp_state_size, + half* _temp_state, + half* _temp_dq +) : + device(_device), + temp_state_size(_temp_state_size), + temp_state(_temp_state), + temp_dq(_temp_dq) +{ + cudaSetDevice(_device); + + cudaStreamCreate(&alt_stream_1); + cudaStreamCreate(&alt_stream_2); + cudaStreamCreate(&alt_stream_3); + cudaEventCreate(&alt_stream_1_done); + cudaEventCreate(&alt_stream_2_done); + cudaEventCreate(&alt_stream_3_done); +} + +CudaBuffers::~CudaBuffers() +{ + cudaStreamDestroy(alt_stream_1); + cudaStreamDestroy(alt_stream_2); + cudaStreamDestroy(alt_stream_3); + cudaEventDestroy(alt_stream_1_done); + cudaEventDestroy(alt_stream_2_done); + cudaEventDestroy(alt_stream_3_done); +} + +CudaBuffers* get_buffers(const int device_index) +{ + return g_buffers[device_index]; +} + +void prepare_buffers_cuda +( + int _device, + int _temp_state_size, + half* _temp_state, + half* _temp_dq +) +{ + CudaBuffers* buffers = new CudaBuffers + ( + _device, + _temp_state_size, + _temp_state, + _temp_dq + ); + + g_buffers[_device] = buffers; +} + +void cleanup_buffers_cuda() +{ + for (int i = 0; i < 
CUDA_MAX_DEVICES; i++) + { + if (!g_buffers[i]) continue; + delete g_buffers[i]; + g_buffers[i] = NULL; + } +} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh new file mode 100644 index 000000000000..0bf2057c665c --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh @@ -0,0 +1,55 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _cuda_buffers_cuh +#define _cuda_buffers_cuh + +#include +#include +#include +#include + +const int CUDA_MAX_DEVICES = 16; + +// #ifndef _cuda_buffers_cu +// extern __constant__ half2 q4_table[16][256]; +// #endif + +class CudaBuffers +{ +public: + int device; + + half* temp_state; // [max_hidden_rows * intermediate_size] + int temp_state_size; + half* temp_dq; // size of largest quant tensor * 8 + + cudaStream_t alt_stream_1; + cudaStream_t alt_stream_2; + cudaStream_t alt_stream_3; + cudaEvent_t alt_stream_1_done; + cudaEvent_t alt_stream_2_done; + cudaEvent_t alt_stream_3_done; + + CudaBuffers + ( + int _device, + int _temp_state_size, + half* _temp_state, + half* _temp_dq + ); + ~CudaBuffers(); +}; + +CudaBuffers* get_buffers(const int device_index); + +void prepare_buffers_cuda +( + int _device, + int _temp_state_size, + half* _temp_state, + half* _temp_dq +); + +void cleanup_buffers_cuda(); + +#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh b/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh new file mode 100644 index 000000000000..5cd2e8553ef6 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh @@ -0,0 +1,49 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _hip_compat_cuh +#define _hip_compat_cuh + +// Workaround for a bug in hipamd, backported from upstream. +__device__ __forceinline__ __half __compat_hrcp(__half x) { + return __half_raw{ + static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))}; +} + +__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) { + return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)), + static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))}; +} + +#define hrcp __compat_hrcp +#define h2rcp __compat_h2rcp + +// Workaround for hipify_python using rocblas instead of hipblas. 
+__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle, + hipblasOperation_t transA, + hipblasOperation_t transB, + int m, + int n, + int k, + const half* alpha, + const half* AP, + int lda, + const half* BP, + int ldb, + const half* beta, + half* CP, + int ldc) { + return hipblasHgemm(handle, transA, transB, m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(AP), lda, + reinterpret_cast(BP), ldb, + reinterpret_cast(beta), + reinterpret_cast(CP), ldc); +} + +#define rocblas_handle hipblasHandle_t +#define rocblas_operation_none HIPBLAS_OP_N +#define rocblas_get_stream hipblasGetStream +#define rocblas_set_stream hipblasSetStream +#define rocblas_hgemm __compat_hipblasHgemm + +#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp b/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp new file mode 100644 index 000000000000..bcc0e43901de --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp @@ -0,0 +1,254 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#include +#include +#include +#include +#include +#include +#include +#include "util.cuh" +#include "tuning.h" +#include "cuda_buffers.cuh" +#include "q4_matrix.cuh" +#include "q4_matmul.cuh" +#include "column_remap.cuh" + +// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a +// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of +// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console. + +void check_cuda(cudaError_t ret) +{ + switch (ret) + { + case cudaSuccess: + break; + + case cudaUnspecified: + printf(" **** Unspecified error\n"); + TORCH_CHECK(false, "CUDA error"); + break; + + default: + printf(" **** CUDA error\n"); \ + printf(" **** %s\n", cudaGetErrorString(ret)); \ + TORCH_CHECK(false, "CUDA error"); \ + break; + } +} + +// Some decluttering macros + +#define STRINGIFY_(__x) #__x +#define STRINGIFY(__x) STRINGIFY_(__x) +#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod)) +#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small") + +#define TORCH_CHECK_DEVICE_INDEX(__index) \ +do { \ + TORCH_CHECK(__index >= 0, "no device index"); \ + TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \ +} while(0) + +#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \ +do { \ + TORCH_CHECK_DTYPE(__w, kInt); \ + TORCH_CHECK_DTYPE(__w_scales, kHalf); \ + TORCH_CHECK_DTYPE(__w_zeros, kInt); \ + TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \ + 
TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \ + TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \ + TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \ +} while(0) + +int get_groupsize(torch::Tensor w, torch::Tensor w_zeros) +{ + int groupsize = w.size(0) * 8 / w_zeros.size(0); + TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]") + return groupsize; +} + + +// Tuning parameters + +ExLlamaTuning tuningParams; + +void set_tuning_params +( + int matmul_recons_thd, + bool matmul_fused_remap, + bool matmul_no_half2 +) +{ + tuningParams.matmul_recons_thd = matmul_recons_thd; + tuningParams.matmul_fused_remap = matmul_fused_remap; + tuningParams.matmul_no_half2 = matmul_no_half2; +} + + +// Release all unmanaged objects allocated by the extension + +void cleanup() +{ + cleanup_buffers_cuda(); + g_q4_free_matrices(); +} + + +// Prepare buffers for forward pass + +void prepare_buffers +( + torch::Device device, + torch::Tensor temp_state, + torch::Tensor temp_dq +) +{ + int device_index = device.index(); + TORCH_CHECK_DEVICE_INDEX(device_index); + const at::cuda::OptionalCUDAGuard device_guard(device); + + prepare_buffers_cuda + ( + device_index, + // buffer size used for sanity checks + temp_state.numel(), + (half*) temp_state.data_ptr(), + (half*) temp_dq.data_ptr() + ); +} + + +// Create Q4Matrix, return handle + +uintptr_t make_q4 +( + torch::Tensor qweight, + torch::Tensor qzeros, + torch::Tensor scales, + torch::Tensor g_idx, + int device +) +{ + TORCH_CHECK_DTYPE(qweight, kInt); + TORCH_CHECK_DTYPE(qzeros, kInt); + TORCH_CHECK_DTYPE(scales, kHalf); + TORCH_CHECK_DTYPE_OPT(g_idx, kInt); + TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8); + TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1); + TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1); + + int width = qweight.size(1); + int height = qweight.size(0) * 8; + int groups = qzeros.size(0); + + Q4Matrix* m = new Q4Matrix + ( + height, + width, + groups, + + (uint32_t*) qweight.data_ptr(), + (uint32_t*) qzeros.data_ptr(), + (half*) scales.data_ptr(), + g_idx.device().is_meta() ? 
NULL : (uint32_t*) g_idx.data_ptr(), + + device + ); + + g_q4_keep_matrix(m); + return reinterpret_cast (m); +} + + +// Matmul half @ quant -> half + +void q4_matmul +( + torch::Tensor x, + uintptr_t w, + torch::Tensor out +) +{ + Q4Matrix* wm = reinterpret_cast (w); + + TORCH_CHECK_DTYPE(x, kHalf); + TORCH_CHECK_DTYPE(out, kHalf); + TORCH_CHECK_SHAPES(x, 0, out, 0, 1); + TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes") + + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + + int x_height = x.size(0); + + if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd) + { + q4_matmul_cuda + ( + &tuningParams, + (half*) x.data_ptr(), + x_height, + wm, + (half*) out.data_ptr() + ); + } + else + { + q4_matmul_recons_cuda + ( + &tuningParams, + (half*) x.data_ptr(), + x_height, + wm, + (half*) out.data_ptr(), + at::cuda::getCurrentCUDABlasHandle() + ); + } +} + + +// Remap columns in half tensor + +void column_remap +( + torch::Tensor x, + torch::Tensor x_new, + torch::Tensor x_map +) +{ + TORCH_CHECK_DTYPE(x, kHalf); + TORCH_CHECK_DTYPE(x_new, kHalf); + TORCH_CHECK_DTYPE(x_map, kInt); + TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1); + + int height = x.size(0); + int width = x.size(1); + + TORCH_CHECK_BUFFER_SIZE(x_new, height * width); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + + column_remap_cuda + ( + (half*) x.data_ptr(), + (half*) x_new.data_ptr(), + height, + width, + (uint32_t*) x_map.data_ptr() + ); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("set_tuning_params", &set_tuning_params, "set_tuning_params"); + m.def("prepare_buffers", &prepare_buffers, "prepare_buffers"); + m.def("cleanup", &cleanup, "cleanup"); + m.def("make_q4", &make_q4, "make_q4"); + m.def("q4_matmul", &q4_matmul, "q4_matmul"); +} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh b/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh new file mode 100644 index 000000000000..2fd5ab0b36cd --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh @@ -0,0 +1,294 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _matrix_cuh +#define _matrix_cuh + +#include +#include + +class MatrixView_half +{ +public: + const half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } + __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } +}; + +class MatrixView_half_rw +{ +public: + half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } + __device__ 
__forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } + __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } +}; + +class MatrixView_q4_row +{ +public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ int item(int row, int column) const + { + int shift = (column & 0x07) * 4; + return (data[row * width / 8 + column / 8] >> shift) & 0x0f; + } +}; + +class MatrixView_q4_column +{ +public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ int item(int row, int column) const + { + int shift = (row & 0x07) * 4; + return (data[row / 8 * width + column] >> shift) & 0x0f; + } + + __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; } + __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; } +}; + +// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu + +// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale + +__device__ __forceinline__ half2 dot_product_8 +( + const half2 acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half2 v_scale_2, + const uint32_t v_zero, // + 1 (!!) 
+ const int count +) +{ + const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column); + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half2 result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half2 v_01 = __halves2half2(v_0, v_1); + half2 v_23 = __halves2half2(v_2, v_3); + half2 v_45 = __halves2half2(v_4, v_5); + half2 v_67 = __halves2half2(v_6, v_7); + +// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently) +// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff]; +// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff]; +// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ]; + + half2 tmp = __hmul2(*h_ptr++, v_01); + tmp = __hfma2(*h_ptr++, v_23, tmp); + tmp = __hfma2(*h_ptr++, v_45, tmp); + tmp = __hfma2(*h_ptr++, v_67, tmp); + result = __hfma2(v_scale_2, tmp, result); + } + + return result; +} + +__device__ __forceinline__ half dot_product_8_h +( + const half acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half v_scale, + const uint32_t v_zero, // + 1 (!!) + const int count +) +{ + const half* h_ptr = h_.item_ptr(h_row, h_column); + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half tmp = __hmul(*h_ptr++, v_0); + tmp = __hfma(*h_ptr++, v_1, tmp); + tmp = __hfma(*h_ptr++, v_2, tmp); + tmp = __hfma(*h_ptr++, v_3, tmp); + tmp = __hfma(*h_ptr++, v_4, tmp); + tmp = __hfma(*h_ptr++, v_5, tmp); + tmp = __hfma(*h_ptr++, v_6, tmp); + tmp = __hfma(*h_ptr++, v_7, tmp); + result = __hfma(v_scale, tmp, result); + } + + return result; +} + +// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map + +__device__ __forceinline__ half2 dot_product_8_x_map +( + const half2 acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half2 v_scale_2, + const uint32_t v_zero, // + 1 (!!) 
+ const int count, + const uint32_t* x_map +) +{ + const half* h_ptr = h_.item_ptr(h_row, 0); + const uint32_t* x_map_ptr = x_map + h_column; + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half2 result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half2 v_01 = __halves2half2(v_0, v_1); + half2 v_23 = __halves2half2(v_2, v_3); + half2 v_45 = __halves2half2(v_4, v_5); + half2 v_67 = __halves2half2(v_6, v_7); + + half h_0 = h_ptr[*x_map_ptr++]; + half h_1 = h_ptr[*x_map_ptr++]; + half h_2 = h_ptr[*x_map_ptr++]; + half h_3 = h_ptr[*x_map_ptr++]; + half h_4 = h_ptr[*x_map_ptr++]; + half h_5 = h_ptr[*x_map_ptr++]; + half h_6 = h_ptr[*x_map_ptr++]; + half h_7 = h_ptr[*x_map_ptr++]; + + half2 h_01 = __halves2half2(h_0, h_1); + half2 h_23 = __halves2half2(h_2, h_3); + half2 h_45 = __halves2half2(h_4, h_5); + half2 h_67 = __halves2half2(h_6, h_7); + + half2 tmp = __hmul2(h_01, v_01); + tmp = __hfma2(h_23, v_23, tmp); + tmp = __hfma2(h_45, v_45, tmp); + tmp = __hfma2(h_67, v_67, tmp); + result = __hfma2(v_scale_2, tmp, result); + } + + return result; +} + +__device__ __forceinline__ half dot_product_8_x_map_h +( + const half acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half v_scale, + const uint32_t v_zero, // + 1 (!!) 
+ const int count, + const uint32_t* x_map +) +{ + const half* h_ptr = h_.item_ptr(h_row, 0); + const uint32_t* x_map_ptr = x_map + h_column; + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half tmp = __hmul(h_ptr[*x_map_ptr++], v_0); + tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp); + result = __hfma(v_scale, tmp, result); + } + + return result; +} + +#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu new file mode 100644 index 000000000000..f47daeb0e877 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu @@ -0,0 +1,260 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#include "q4_matmul.cuh" +#include "column_remap.cuh" +#include "util.cuh" +#include "matrix.cuh" +#include "cu_compat.cuh" +#include "cuda_buffers.cuh" +#if defined(USE_ROCM) +#include "hip_compat.cuh" +#endif + +const int THREADS_X = 32; // Block size and thread count along columns in w and out +const int THREADS_Y = 1; // Block size and thread count along rows in x and out + +typedef void (*fp_q4_matmul_kernel) +( + const half*, + const uint32_t*, + half*, + const half*, + const uint32_t*, + const int, + const int, + const int, + const int, + const int, + const uint32_t*, + bool +); + +template +__global__ void q4_matmul_kernel +( + const half* __restrict__ x, + const uint32_t* __restrict__ w, + half* __restrict__ out, + const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, + const int height, + const int dim, + const int width, + const int groupsize, + const int block_size_z, + const uint32_t* __restrict__ x_map, + bool no_zero +) +{ + // Start of block + + int x_column = block_size_z * blockIdx.z; + int x_column_end = min(dim, block_size_z * (blockIdx.z + 1)); + + int w_column = THREADS_X * blockIdx.x + threadIdx.x; + int x_row = THREADS_Y * blockIdx.y + threadIdx.y; + + int iterations = (x_column_end - x_column) / 8; + + // Views + + MatrixView_half x_(x, height, dim); + MatrixView_half w_scales_(w_scales, dim / groupsize, width); + MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width); + MatrixView_q4_column w_(w, dim, width); + MatrixView_half_rw out_(out, height, width); + + // Zero output + + if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0) + { + *((uint32_t*) out_.item_ptr(x_row, w_column)) = 0; + __syncthreads(); + } + + // Loop over part of x row (and w column) + + half2 acc = {}; + half acc_h = {}; + + if constexpr (use_groupsize) + { + // For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a 
group boundary, so this + // could be slightly faster + + for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize) + { + if constexpr (use_half2) + { + half2 w_scale = w_scales_.item_half2half2(group, w_column); + uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + + if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); + else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); + } + else + { + half w_scale = w_scales_.item(group, w_column); + uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + + if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); + else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); + } + } + } + else + { + // Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache + + for (int k = x_column; k < x_column + iterations * 8; k += 8) + { + if constexpr (use_half2) + { + int group = k / groupsize; + half2 w_scale = w_scales_.item_half2half2(group, w_column); + uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + + if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); + else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); + } + else + { + int group = k / groupsize; + half w_scale = w_scales_.item(group, w_column); + uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + + if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); + else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); + } + } + } + + // Add to block result + + if constexpr (use_half2) + { + half result = __hadd(__low2half(acc), __high2half(acc)); + atomicAdd(out_.item_ptr(x_row, w_column), result); + } + else + { + atomicAdd(out_.item_ptr(x_row, w_column), acc_h); + } +} + +fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map) +{ + // + if (tuningParams->matmul_no_half2) { + if (block_size_z % groupsize == 0) { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } else { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } + } else { + if (block_size_z % groupsize == 0) + { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } else { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } + } +}; + +// Compute y = x @ w + +void q4_matmul_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + const Q4Matrix* w, + half* out, + bool no_zero, + cudaStream_t alt_stream +) +{ + int height = x_height; + int dim = w->height; + int width = w->width; + + cudaSetDevice(w->device); + + uint32_t* x_map = w->cuda_x_map; + const half* x_mapped = x; + if (x_map && !tuningParams->matmul_fused_remap && !alt_stream) + { + CudaBuffers* buffers = get_buffers(w->device); + column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); + x_mapped = buffers->temp_state; + x_map = NULL; + } + + int block_size_z; + if (w->width == 4096) block_size_z = 384; // 7B + else if (w->width == 11008) block_size_z = 256; + else if (w->width == 5120) block_size_z = 384; // 13B + else if (w->width == 13824) block_size_z = 256; + else if (w->width == 
6656) block_size_z = 256; // 33B + else if (w->width == 17920) block_size_z = 128; + else block_size_z = 256; + + //if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half)); + + dim3 threads(THREADS_X, THREADS_Y, 1); + + dim3 blocks + ( + (width + threads.x - 1) / threads.x, + (height + threads.y - 1) / threads.y, + (dim + block_size_z - 1) / block_size_z + ); + + fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map); + + kernel<<>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero); +} + +void q4_matmul_recons_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + Q4Matrix* w, + half* out, + const cublasHandle_t handle, + bool no_zero +) +{ + int height = x_height; + int dim = w->height; + int width = w->width; + + cudaSetDevice(w->device); + CudaBuffers* buffers = get_buffers(w->device); + + const half* x_mapped = x; + if (w->cuda_x_map) + { + TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "temp_state buffer is too small"); + column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); + x_mapped = buffers->temp_state; + } + + w->reconstruct(buffers->temp_dq); + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700 + const float alpha = 1.0f; + const float beta = no_zero ? 1.0f : 0.0f; + cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width, + x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width); +#else + const half alpha = __float2half(1.0f); + const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f); + cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width); +#endif +} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh new file mode 100644 index 000000000000..09f3e1a63362 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh @@ -0,0 +1,43 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _q4_matmul_cuh +#define _q4_matmul_cuh + +#include +#include +#include +#include +#include + +#include "q4_matrix.cuh" +#include "tuning.h" + +// Workaround for hipify_python using rocblas instead of hipblas. 
+#if defined(USE_ROCM) +#include +#define rocblas_handle hipblasHandle_t +#endif + +void q4_matmul_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + const Q4Matrix* w, + half* out, + bool no_zero = false, + cudaStream_t alt_stream = NULL +); + +void q4_matmul_recons_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + Q4Matrix* w, + half* out, + const cublasHandle_t handle, + bool no_zero = false +); + +#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu new file mode 100644 index 000000000000..9c61143f565e --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu @@ -0,0 +1,225 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#include "q4_matrix.cuh" +#include +#include "util.cuh" +#include "matrix.cuh" + +using namespace std; + +const int UNSHUF_BLOCKSIZE_X = 64; + +const int RECONS_THREADS_X = 64; // Block size and thread count along columns in out, each thread converts 1 column +const int RECONS_THREADS_Y = 1; // Block size and thread count along rows in x and out, each thread converts 8 rows + +vector g_q4_matrices; + +void g_q4_keep_matrix(Q4Matrix* m) +{ + g_q4_matrices.push_back(m); +} + +void g_q4_free_matrices() +{ + for (const auto& m : g_q4_matrices) delete m; + g_q4_matrices.clear(); +} + +Q4Matrix::Q4Matrix +( + const int _height, + const int _width, + const int _groups, + + uint32_t* _qweight, + uint32_t* _qzeros, + half* _scales, + uint32_t* _g_idx, + + const int _device +) : + height(_height), + width(_width), + groups(_groups), + device(_device) +{ + cudaSetDevice(device); + + cuda_qweight = _qweight; + cuda_qzeros = _qzeros; + cuda_scales = _scales; + + groupsize = height / groups; + + if (_g_idx) make_sequential(_g_idx); +} + +Q4Matrix::~Q4Matrix() +{ +} + +// Make sequential + +__global__ void make_sequential_kernel +( + const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const uint32_t* __restrict__ x_map, + const int w_height, + const int w_width +) +{ + const uint64_t* w2 = (uint64_t*) w; + uint64_t* w_new2 = (uint64_t*) w_new; + int w2_stride = w_width >> 1; + + int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) return; + + int w_new2_row = blockIdx.y; + + int x_map_idx = w_new2_row << 3; + + uint64_t dst = 0; + + #pragma unroll + for (int i = 0; i < 8; i++) + { + int source_row = x_map[x_map_idx++]; + + int w2_row = source_row >> 3; + int w2_subrow = source_row & 0x07; + int w2_row_shift = w2_subrow << 2; + int wnew2_row_shift = i << 2; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x0000000f0000000f; + src <<= wnew2_row_shift; + dst |= src; + } + + w_new2[w_new2_row * w2_stride + w2_column] = dst; +} + +void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx) +{ + uint32_t* cuda_new_qweight = NULL; + cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t)); + cudaMalloc(&cuda_x_map, height * sizeof(uint32_t)); // TODO: Should probably be allocated in PyTorch + + uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t)); + uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t)); + uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t)); + + // Group histogram + + for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++; + + // Group map + + for (int i = 0, acc = 0; i < groups; i++) + { + short tmp = cpu_g_idx_map[i]; + 
cpu_g_idx_map[i] = acc; + acc += tmp; + } + + // X map (inverse) + + for (int row = 0; row < height; row++) + { + uint32_t target_group = cpu_g_idx[row]; + uint32_t target_row = cpu_g_idx_map[target_group]; + cpu_g_idx_map[target_group]++; + cpu_x_map_inv[row] = target_row; + } + + // X map + + for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row; + + // Move to CUDA + + cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice); + + // Rearrange rows in w + + dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1); + dim3 blocks + ( + (width + UNSHUF_BLOCKSIZE_X * 2 - 1) / (UNSHUF_BLOCKSIZE_X * 2), + height / 8, + 1 + ); + + make_sequential_kernel<<>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width); + + // Replace qweights + + cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); + + // Cleanup + + cudaDeviceSynchronize(); + cudaFree(cuda_new_qweight); + free(cpu_g_idx_map); + free(cpu_x_map); + free(cpu_x_map_inv); +} + +__global__ void reconstruct_kernel +( + const uint32_t* __restrict__ w, + half* __restrict__ out, // (y) + const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, + const int height, + const int width, + const int groupsize +) +{ + // Start of block + + int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x; + int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8; + if (column >= width) return; + + // Views + + MatrixView_q4_column w_(w, height, width); + MatrixView_half_rw out_(out, height, width); + MatrixView_half w_scales_(w_scales, height / groupsize, width); + MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width); + + // Groupsize version + + int group = row / groupsize; + + half w_scale = w_scales_.item(group, column); + uint32_t w_zero = w_zeros_.item(group, column) + 1; + + uint32_t w_read = w_.item_uint32_t(row, column); + half* out_ptr = out_.item_ptr(row, column); + + #pragma unroll + for (int s = 0; s < 32; s += 4) + { + half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale); + *out_ptr = w_item; out_ptr += out_.width; + } +} + +void Q4Matrix::reconstruct(half* out) +{ + dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1); + + dim3 blocks + ( + (width + threads.x - 1) / threads.x, + (height / 8 + threads.y - 1) / threads.y, + 1 + ); + + reconstruct_kernel<<>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize); +} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh new file mode 100644 index 000000000000..50cb72a41518 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh @@ -0,0 +1,53 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _q4_matrix_cuh +#define _q4_matrix_cuh + +#include +#include +#include + +class Q4Matrix +{ +public: + + int device; + + int height; + int width; + int groups; + int groupsize; + + uint32_t* cuda_qweight = NULL; + uint32_t* cuda_qzeros = NULL; + half* cuda_scales = NULL; + uint32_t* cuda_x_map = NULL; + + Q4Matrix + ( + const int _height, + const int _width, + const int _groups, + + uint32_t* _qweight, + uint32_t* _qzeros, + half* _scales, + uint32_t* _g_idx, + + const int _device + ); + + ~Q4Matrix(); + + void reconstruct(half* out); + +private: + + void make_sequential(const uint32_t* cpu_g_idx); + +}; + +void g_q4_keep_matrix(Q4Matrix* m); +void g_q4_free_matrices(); + +#endif \ No newline at end of file 
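The Q4Matrix and matmul kernels above all rely on the same GPTQ 4-bit layout: eight quantized values packed per `uint32_t` along the input-feature axis, per-group `scales`, packed `qzeros` stored minus one (hence the `+ 1` in `dot_product_8` and `reconstruct_kernel`), and an optional `g_idx` act-order mapping. The snippet below is a slow PyTorch reference of that dequantization convention, handy for checking kernel output against `Q4Matrix::reconstruct`; the function name is illustrative, the layout follows the sources above.

```python
import torch

def dequantize_q4_reference(qweight, qzeros, scales, g_idx):
    """Slow reference dequantization for the 4-bit GPTQ layout used above.

    qweight: (K // 8, N) int32  - 8 weights per word along the row (K) axis
    qzeros:  (G, N // 8) int32  - 8 zero-points per word along the column (N) axis
    scales:  (G, N) float16
    g_idx:   (K,) int           - group index of every input row
    """
    shifts = torch.arange(0, 32, 4, device=qweight.device)

    # Unpack weights to (K, N): row 8*r + i sits in bits [4*i, 4*i + 4) of word r.
    w = (qweight.unsqueeze(1) >> shifts.view(1, -1, 1)) & 0xF
    w = w.reshape(-1, qweight.shape[1])

    # Unpack zero-points to (G, N) and add 1, exactly as the kernels do.
    z = ((qzeros.unsqueeze(2) >> shifts.view(1, 1, -1)) & 0xF).reshape(qzeros.shape[0], -1) + 1

    # w_dq[k, n] = scales[g_idx[k], n] * (q[k, n] - zero[g_idx[k], n])
    return (w - z[g_idx]).half() * scales[g_idx]
```

`q4_matmul_cuda` consumes this layout directly without materializing the fp16 weight, while `q4_matmul_recons_cuda` reconstructs it into `temp_dq` (see `Q4Matrix::reconstruct`) and hands the GEMM to cuBLAS once the batch size reaches `matmul_recons_thd`.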
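`Q4Matrix::make_sequential` together with `column_remap_cuda` is what handles act-order (`g_idx`) checkpoints: weight rows are permuted so that quantization groups become contiguous, and the same permutation (`x_map`) is applied to the activation columns, so that `seq_x @ seq_w == x @ w` as the comment in `column_remap.cu` states. A minimal sketch of the permutation, assuming a stable sort is equivalent to the counting sort used in `make_sequential`:

```python
import torch

def make_x_map(g_idx):
    # x_map[new_row] = old_row: rows ordered by group, original order kept inside a group.
    return torch.sort(g_idx, stable=True).indices

def column_remap(x, x_map):
    # x_new[:, c] = x[:, x_map[c]], mirroring column_remap_kernel.
    return x[:, x_map]

# Reordering both operands with the same permutation leaves the product unchanged:
#   column_remap(x, x_map) @ w[x_map, :] == x @ w   (w here is the unpacked weight)
```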
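`linear_gptq.cpp` exposes the extension to Python through pybind11 as `set_tuning_params`, `prepare_buffers`, `cleanup`, `make_q4` and `q4_matmul`. The sketch below shows one plausible call sequence; the import path (the extension is built by `op_builder`'s `GPTQBuilder` under the name `cu_gptq`), the placeholder shapes and the scratch-buffer sizes are assumptions for illustration, not something this patch fixes.

```python
import torch
# Assumed import path for the extension built as "cu_gptq" by GPTQBuilder.
from colossalai._C import cu_gptq as gptq_ext

device = torch.device("cuda:0")
in_features, out_features, group_size, max_tokens = 4096, 4096, 128, 16
groups = in_features // group_size

# Quantized parameters in the layout make_q4 expects (random placeholders here).
qweight = torch.randint(0, 2**31 - 1, (in_features // 8, out_features), dtype=torch.int32, device=device)
qzeros = torch.randint(0, 2**31 - 1, (groups, out_features // 8), dtype=torch.int32, device=device)
scales = torch.rand(groups, out_features, dtype=torch.float16, device=device)
# make_sequential reads g_idx on the host, so it stays on the CPU.
g_idx = torch.arange(in_features, dtype=torch.int32) // group_size

# Scratch buffers: temp_state for column-remapped activations, temp_dq for the
# reconstructed fp16 weight; the sizing rule here is an assumption.
temp_state = torch.empty(max_tokens * in_features, dtype=torch.float16, device=device)
temp_dq = torch.empty(in_features * out_features, dtype=torch.float16, device=device)
gptq_ext.prepare_buffers(device, temp_state, temp_dq)
gptq_ext.set_tuning_params(8, False, False)  # matmul_recons_thd, fused_remap, no_half2

q4_handle = gptq_ext.make_q4(qweight, qzeros, scales, g_idx, device.index())
x = torch.randn(max_tokens, in_features, dtype=torch.float16, device=device)
out = torch.empty(max_tokens, out_features, dtype=torch.float16, device=device)
gptq_ext.q4_matmul(x, q4_handle, out)
```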
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/tuning.h b/colossalai/kernel/cuda_native/csrc/gptq/tuning.h new file mode 100644 index 000000000000..770ca46aa7c8 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/tuning.h @@ -0,0 +1,13 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _tuning_h +#define _tuning_h + +struct ExLlamaTuning +{ + int matmul_recons_thd; + bool matmul_fused_remap; + bool matmul_no_half2; +}; + +#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/util.cuh b/colossalai/kernel/cuda_native/csrc/gptq/util.cuh new file mode 100644 index 000000000000..7b397573214b --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/util.cuh @@ -0,0 +1,33 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _util_cuh +#define _util_cuh + +#include +#include +#include +#include + +#if defined(USE_ROCM) +#define cudaUnspecified hipErrorUnknown +#else +#define cudaUnspecified cudaErrorApiFailureBase +#endif + +// React to failure on return code != cudaSuccess + +#define _cuda_check(fn) \ +do { \ + {_cuda_err = fn;} \ + if (_cuda_err != cudaSuccess) goto _cuda_fail; \ +} while(false) + +// React to failure on return code == 0 + +#define _alloc_check(fn) \ +do { \ + if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \ + else _cuda_err = cudaSuccess; \ +} while(false) + +#endif diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py index bc68a07e6fba..87ea9cf6536e 100644 --- a/colossalai/kernel/triton/__init__.py +++ b/colossalai/kernel/triton/__init__.py @@ -6,6 +6,7 @@ from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd from .copy_kv_cache_dest import copy_kv_cache_to_dest from .fused_layernorm import layer_norm + from .gptq_triton import gptq_fused_linear_triton from .rms_norm import rmsnorm_forward from .rotary_embedding_kernel import rotary_embedding_fwd from .softmax import softmax @@ -20,6 +21,7 @@ "copy_kv_cache_to_dest", "rotary_embedding_fwd", "token_attention_fwd", + "gptq_fused_linear_triton", ] except ImportError: diff --git a/colossalai/kernel/triton/gptq_triton.py b/colossalai/kernel/triton/gptq_triton.py new file mode 100644 index 000000000000..cf4ef183a59d --- /dev/null +++ b/colossalai/kernel/triton/gptq_triton.py @@ -0,0 +1,541 @@ +# Adapted from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ + +import torch +import triton +import triton.language as tl +from auto_gptq.nn_modules.triton_utils import custom_autotune + + +@triton.jit +def tanh(x): + # Tanh is just a scaled sigmoid + return 2 * tl.sigmoid(2 * x) - 1 + + +@triton.jit +def cosh(x): + exp_x = tl.exp(x) + return (exp_x + 1.0 / exp_x) * 0.5 + + +# a Triton implementation of the most used activations +# See for instance http://arxiv.org/abs/1606.08415 for an overview + + +# ReLU +@triton.jit +def relu(x): + """ + ReLU_ activation function + + .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html + """ + return tl.where(x >= 0, x, 0.0) + + +@triton.jit +def squared_relu(x): + """ + Squared ReLU activation, as proposed in the Primer_ paper. + + .. _Primer: https://arxiv.org/abs/2109.08668 + """ + x_sq = x * x + return tl.where(x > 0.0, x_sq, 0.0) + + +@triton.jit +def star_relu(x): + """ + Star ReLU activation, as proposed in the "MetaFormer Baselines for Vision"_ paper. + + .. 
_ "MetaFormer Baselines for Vision": https://arxiv.org/pdf/2210.13452.pdf + """ + x_sq = x * x + return 0.8944 * tl.where(x > 0.0, x_sq, 0.0) - 0.4472 + + +# Leaky ReLU +@triton.jit +def leaky_relu(x): + """ + LeakyReLU_ activation + + .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html + """ + return tl.where(x >= 0.0, x, 0.01 * x) + + +@triton.jit +def gelu(x): + """ + GeLU_ activation - Gaussian error linear unit + + .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf + """ + return 0.5 * x * (1 + tanh(_kAlpha * (x + 0.044715 * x * x * x))) + + +@triton.jit +def smelu(x): + """ + SmeLU_ activation - Smooth ReLU with beta=2.0 + + .. _SmeLU: https://arxiv.org/pdf/2202.06499.pdf + """ + beta = 2.0 + + relu = tl.where(x >= beta, x, 0.0) + return tl.where(tl.abs(x) <= beta, (x + beta) * (x + beta) / (4.0 * beta), relu) + + +@triton.jit +def silu(x): + return x * tl.sigmoid(x) + + +@custom_autotune.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=8 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=8 + ), + triton.Config( + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=4 + ), + ], + key=["M", "N", "K"], + nearest_power_of_two=True, + prune_configs_by={ + "early_config_prune": custom_autotune.matmul248_kernel_config_pruner, + "perf_model": None, + "top_k": None, + }, +) +@triton.jit +def cai_gptq_matmul_248_kernel( + a_ptr, + b_ptr, + c_ptr, + scales_ptr, + zeros_ptr, + bias_ptr, + residual_ptr, + M, + N, + K, + bits, + maxq, + gptq_group_size, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_scales, + stride_zeros, + QKV_FUSED: tl.constexpr, + ADD_BIAS: tl.constexpr, + ADD_RESIDUAL: tl.constexpr, + ACT_TYPE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """ + Compute the matrix multiplication C = A x B. 
+ A is of shape (M, K) float16 + B is of shape (K//8, N) int32 + C is of shape (M, N) float16 + scales is of shape (G, N) float16 + zeros is of shape (G, N) float16 + """ + infearure_per_bits = 32 // bits + + pid = tl.program_id(axis=0) + NK = K + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(NK, BLOCK_SIZE_K) + qkv_offset = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + # offs_bk = offs_k + qkv_offset * NK + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + + a_mask = offs_am[:, None] < M + # b_ptrs is set up such that it repeats elements along the K axis 8 times + b_ptrs = ( + b_ptr + + qkv_offset * N * NK // infearure_per_bits + + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) + ) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + # g_ptrs = g_ptr + offs_k + # shifter is used to extract the N bits of each element in the 32-bit word from B + scales_ptrs = scales_ptr + qkv_offset * NK * N // gptq_group_size + offs_bn[None, :] + zeros_ptrs = ( + zeros_ptr + + qkv_offset * NK * N // gptq_group_size // infearure_per_bits + + (offs_bn[None, :] // infearure_per_bits) + ) + + shifter = (offs_k % infearure_per_bits) * bits + zeros_shifter = (offs_bn % infearure_per_bits) * bits + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + g_idx_base = tl.arange(0, BLOCK_SIZE_K) + g_idx_base = g_idx_base // gptq_group_size + g_idx = g_idx_base + # tl.device_print("gidx, ", g_idx) + + scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = zeros + 1 + + for k in range(0, num_pid_k): + # g_idx = tl.load(g_ptrs) + # if (k + 1) * BLOCK_SIZE_K > currend_group_end: + scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = zeros + 1 + # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop + a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + # Now we need to unpack b (which is N-bit values) into 32-bit values + b = (b >> shifter[:, None]) & maxq # Extract the N-bit values + b = (b - zeros).to(tl.float16) * scales # Scale and shift + accumulator += tl.dot(a, b) + + a_ptrs += BLOCK_SIZE_K + b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk + g_idx = g_idx_base + ((k + 1) * BLOCK_SIZE_K) // gptq_group_size + # if (k + 2) * BLOCK_SIZE_K > currend_group_end: + + c_ptrs = c_ptr + qkv_offset * M * N + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) + + if ADD_BIAS: + bias_mask = offs_bn < N + offs_bn += qkv_offset * N + 
bias_ptrs = bias_ptr + stride_cn * offs_bn + bias = tl.load(bias_ptrs, mask=bias_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + accumulator += bias[None, :] + + if ACT_TYPE == 1: + accumulator = relu(accumulator) + elif ACT_TYPE == 2: + accumulator = gelu(accumulator) + elif ACT_TYPE == 3: + accumulator = silu(accumulator) + + if ADD_RESIDUAL: + residual_ptrs = residual_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + res = tl.load(residual_ptrs, mask=c_mask, other=0.0) + accumulator += res + + tl.store(c_ptrs, accumulator, mask=c_mask) + + +@custom_autotune.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=8 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=8 + ), + triton.Config( + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=4 + ), + ], + key=["M", "N", "K"], + nearest_power_of_two=True, + prune_configs_by={ + "early_config_prune": custom_autotune.matmul248_kernel_config_pruner, + "perf_model": None, + "top_k": None, + }, +) +@triton.jit +def cai_gptq_idx_matmul_248_kernel( + a_ptr, + b_ptr, + c_ptr, + scales_ptr, + zeros_ptr, + idx_ptr, + bias_ptr, + residual_ptr, + M, + N, + K, + bits, + maxq, + gptq_group_size, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_scales, + stride_zeros, + QKV_FUSED: tl.constexpr, + ADD_BIAS: tl.constexpr, + ADD_RESIDUAL: tl.constexpr, + ACT_TYPE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """ + Compute the matrix multiplication C = A x B. 
+ A is of shape (M, K) float16 + B is of shape (K//8, N) int32 + C is of shape (M, N) float16 + scales is of shape (G, N) float16 + zeros is of shape (G, N) float16 + """ + infearure_per_bits = 32 // bits + + pid = tl.program_id(axis=0) + NK = K + + # if QKV_FUSED: + # NK = K//3 + # else: + # NK = K + # NK = K + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(NK, BLOCK_SIZE_K) + qkv_offset = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + # offs_bk = offs_k + qkv_offset * NK + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + + a_mask = offs_am[:, None] < M + # b_ptrs is set up such that it repeats elements along the K axis 8 times + b_ptrs = ( + b_ptr + + qkv_offset * N * NK // infearure_per_bits + + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) + ) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + # g_ptrs = g_ptr + offs_k + # shifter is used to extract the N bits of each element in the 32-bit word from B + scales_ptrs = scales_ptr + qkv_offset * NK * N // gptq_group_size + offs_bn[None, :] + zeros_ptrs = ( + zeros_ptr + + qkv_offset * NK * N // gptq_group_size // infearure_per_bits + + (offs_bn[None, :] // infearure_per_bits) + ) + + shifter = (offs_k % infearure_per_bits) * bits + zeros_shifter = (offs_bn % infearure_per_bits) * bits + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + g_ptrs = idx_ptr + offs_k + g_idx = tl.load(g_ptrs) + # tl.device_print("gidx, ", g_idx) + zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits) + + scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + + for k in range(0, num_pid_k): + g_idx = tl.load(g_ptrs) + scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + + zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = zeros + 1 + + # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop + a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + # Now we need to unpack b (which is N-bit values) into 32-bit values + b = (b >> shifter[:, None]) & maxq # Extract the N-bit values + b = (b - zeros).to(tl.float16) * scales # Scale and shift + accumulator += tl.dot(a, b) + + a_ptrs += BLOCK_SIZE_K + b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk + g_ptrs += BLOCK_SIZE_K + + c_ptrs = c_ptr + qkv_offset * M * N + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) + + if ADD_BIAS: + bias_mask = offs_bn < N + offs_bn += qkv_offset * N + bias_ptrs = bias_ptr + stride_cn * offs_bn + bias = tl.load(bias_ptrs, mask=bias_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + accumulator += bias[None, :] + + if ACT_TYPE == 1: + accumulator = relu(accumulator) + elif ACT_TYPE == 
2: + accumulator = gelu(accumulator) + elif ACT_TYPE == 3: + accumulator = silu(accumulator) + + if ADD_RESIDUAL: + residual_ptrs = residual_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + res = tl.load(residual_ptrs, mask=c_mask, other=0.0) + accumulator += res + + tl.store(c_ptrs, accumulator, mask=c_mask) + + +def gptq_fused_linear_triton( + input, + qweight, + scales, + qzeros, + bias, + residual, + bits, + maxq, + gptq_group_size, + qkv_fused, + add_bias, + add_residual, + g_idx=None, + act_type=0, +): + # print("gptq fused ", qkv_fused, add_bias, add_residual) + assert input.is_cuda, "input is not in cuda" + assert qweight.is_cuda, "qweight is not in cuda" + assert scales.is_cuda, "scales is not in cuda" + assert qzeros.is_cuda, "qzeros is not in cuda" + + with torch.cuda.device(input.device): + if qkv_fused: + grid = lambda META: ( + triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) + * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]) + * 3, + ) + output = torch.empty((input.shape[0] * 3, qweight.shape[1]), device=input.device, dtype=torch.float16) + else: + grid = lambda META: ( + triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]), + ) + output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16) + # print("dtype, ", qweight.dtype, output.dtype, scales.dtype, qzeros.dtype, bias.dtype, residual.dtype) + if g_idx is None: + cai_gptq_matmul_248_kernel[grid]( + input, + qweight, + output, + scales, + qzeros, + bias, + residual, + input.shape[0], + qweight.shape[1], + input.shape[1], + bits, + maxq, + gptq_group_size, + input.stride(0), + input.stride(1), + qweight.stride(0), + qweight.stride(1), + output.stride(0), + output.stride(1), + scales.stride(0), + qzeros.stride(0), + QKV_FUSED=qkv_fused, + ADD_BIAS=add_bias, + ADD_RESIDUAL=add_residual, + ACT_TYPE=act_type, + ) + else: + cai_gptq_idx_matmul_248_kernel[grid]( + input, + qweight, + output, + scales, + qzeros, + g_idx, + bias, + residual, + input.shape[0], + qweight.shape[1], + input.shape[1], + bits, + maxq, + gptq_group_size, + input.stride(0), + input.stride(1), + qweight.stride(0), + qweight.stride(1), + output.stride(0), + output.stride(1), + scales.stride(0), + qzeros.stride(0), + QKV_FUSED=qkv_fused, + ADD_BIAS=add_bias, + ADD_RESIDUAL=add_residual, + ACT_TYPE=act_type, + ) + if qkv_fused: + return output.view(3, input.shape[0], qweight.shape[1]) + else: + return output diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 6935288130c9..a285874d218b 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -32,10 +32,13 @@ class ShardConfig: enable_fused_normalization: bool = False enable_flash_attention: bool = False enable_jit_fused: bool = False - enable_sequence_parallelism: bool = False - enable_sequence_overlap: bool = False enable_all_optimization: bool = False inference_only: bool = False + inference_gptq: bool = False + enable_sequence_parallelism: bool = False + enable_sequence_overlap: bool = False + # pipeline_parallel_size: int + # data_parallel_size: int # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] @property diff --git a/examples/inference/gptq_bloom.py b/examples/inference/gptq_bloom.py new file mode 100644 index 000000000000..43e118cc0aa5 --- /dev/null +++ b/examples/inference/gptq_bloom.py @@ -0,0 +1,123 @@ +import argparse +import logging +import os +import time + 
+import torch +from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig +from auto_gptq.nn_modules.qlinear import GeneralQuantLinear +from transformers import AutoTokenizer, BloomForCausalLM, BloomTokenizerFast, LlamaForCausalLM, LlamaTokenizer + +import colossalai +from colossalai.inference.tensor_parallel.engine import TPInferEngine +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + +os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' + + +def print_perf_stats(latency_set, config, bs, warmup=3): + # trim warmup queries + latency_set = list(latency_set) + latency_set = latency_set[warmup:] + count = len(latency_set) + + if count > 0: + latency_set.sort() + avg = sum(latency_set) / count + num_layers = getattr(config, "num_layers", config.num_hidden_layers) + num_parameters = num_layers * config.hidden_size * config.hidden_size * 12 + num_bytes = 2 # float16 + + print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) + print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9)) + print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12)) + print("Avg Throughput: tokens/s: {}".format((1000 / (avg * 1000)) * bs)) + + +def bench_bloom(args): + + pretrained_model_dir = args.path + quantized_model_dir = args.quantized_path + max_batch_size = args.batch_size + max_input_len = args.input_len + max_output_len = args.output_len + + tokenizer = BloomTokenizerFast.from_pretrained(pretrained_model_dir) + tokenizer.pad_token = tokenizer.eos_token + + # load quantized model to the first GPU + model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, + device=torch.cuda.current_device(), + inject_fused_attention=False) + + model = model.half() + + model_config = model.config + shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True) + infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) + generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) + + input_tokens = { + "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device='cuda'), + "attention_mask": torch.ones((max_batch_size, max_input_len), device='cuda') + } + + # init TPInferEngine and shard the original model + # To benchmark torch original, comment out the line of optimizing model + shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, + inference_only=True, + inference_gptq=True) + infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) + + # prepare data for generation + generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) + input_tokens = { + "input_ids": torch.randint(10, 1000, (max_batch_size, max_input_len)), + "attention_mask": torch.ones((max_batch_size, max_input_len)) + } + for t in input_tokens: + if torch.is_tensor(input_tokens[t]): + input_tokens[t] = input_tokens[t].to(torch.cuda.current_device()) + # print(f" input_tokens[{t}].shape: {input_tokens[t].shape}") + + iters = 10 + times = [] + for i in range(iters): + torch.cuda.synchronize() + start = time.time() + outputs = infer_engine.generate(input_tokens, **generate_kwargs) + torch.cuda.synchronize() + end = time.time() + out_len = outputs.shape[1] + print(f" iter {i}: out len {str(out_len)}, generation time {str(end - start)} s") + 
times.append((end - start) / (out_len - max_input_len)) + + print_perf_stats(times, model_config, max_batch_size) + + +def check_bloom(rank, world_size, port, args): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + bench_bloom(args) + + +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_bloom(args): + spawn(check_bloom, args.tp_size, args=args) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('-p', '--path', type=str, help='Model path', required=True) + parser.add_argument('-q', '--quantized_path', type=str, help='Model path', required=True) + parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size') + parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size') + parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length') + parser.add_argument('--output_len', type=int, default=128, help='Maximum output length') + + args = parser.parse_args() + + test_bloom(args) diff --git a/examples/inference/gptq_llama.py b/examples/inference/gptq_llama.py new file mode 100644 index 000000000000..818ae0035e87 --- /dev/null +++ b/examples/inference/gptq_llama.py @@ -0,0 +1,135 @@ +import argparse +import logging +import os +import time + +import torch +from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig +from auto_gptq.nn_modules.qlinear import GeneralQuantLinear +from torch import distributed as dist +from torch.profiler import ProfilerActivity, profile, record_function +from transformers import AutoTokenizer, LlamaForCausalLM, LlamaTokenizer, TextGenerationPipeline + +import colossalai +from colossalai.gptq import CaiQuantLinear +from colossalai.gptq.gptq_tp import replace_autogptq_linear +from colossalai.inference.tensor_parallel.engine import TPInferEngine +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + +os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' + + +def init_to_get_rotary(self, base=10000): + self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads + if not hasattr(self.config, "rope_scaling"): + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0 + if hasattr(self.config, "max_sequence_length"): + max_seq_len = self.config.max_sequence_length + elif hasattr(self.config, "max_position_embeddings"): + max_seq_len = self.config.max_position_embeddings * rope_scaling_factor + else: + max_seq_len = 2048 * rope_scaling_factor + base = float(base) + inv_freq = 1.0 / (base**(torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / + self.config.head_dim_)) + t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor + freqs = torch.outer(t, inv_freq) + + self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() + self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() + return + + +def print_perf_stats(latency_set, config, bs, warmup=3): + # trim warmup queries + latency_set = list(latency_set) + latency_set = latency_set[warmup:] + count = len(latency_set) + + if count > 0: + latency_set.sort() + avg = sum(latency_set) / count + num_layers = getattr(config, "num_layers", config.num_hidden_layers) + num_parameters = 
num_layers * config.hidden_size * config.hidden_size * 12 + num_bytes = 2 + + print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) + print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9)) + print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12)) + print("Avg Throughput: tokens/s: {}".format((1000 / (avg * 1000)) * bs)) + + +def run_llama_test(args): + pretrained_model_dir = args.path + quantized_model_dir = args.quantized_path + max_batch_size = args.batch_size + max_input_len = args.input_len + max_output_len = args.output_len + + tokenizer = LlamaTokenizer.from_pretrained(pretrained_model_dir, use_fast=True) + tokenizer.pad_token_id = tokenizer.eos_token_id + + # load quantized model to the first GPU + model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, + device=torch.cuda.current_device(), + inject_fused_attention=False) + + init_to_get_rotary(model.model.model, base=10000) + + model_config = model.config + shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, + inference_only=True, + inference_gptq=True) + infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) + + generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) + + input_tokens = { + "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device='cuda'), + "attention_mask": torch.ones((max_batch_size, max_input_len), device='cuda') + } + + iters = 10 + times = [] + + for i in range(iters): + torch.cuda.synchronize() + start = time.time() + outputs = infer_engine.generate(input_tokens, **generate_kwargs) + torch.cuda.synchronize() + end = time.time() + out_len = outputs.shape[1] + print(f" iter {i}: out len {str(out_len)}, generation time {str(end - start)} s") + times.append((end - start) / (out_len - max_input_len)) + + print_perf_stats(times, model_config, max_batch_size) + + +def check_llama(rank, world_size, port, args): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_llama_test(args) + + +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama(args): + spawn(check_llama, args.tp_size, args=args) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('-p', '--path', type=str, help='Model path', required=True) + parser.add_argument('-q', '--quantized_path', type=str, help='Model path', required=True) + parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size') + parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size') + parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length') + parser.add_argument('--output_len', type=int, default=128, help='Maximum output length') + + args = parser.parse_args() + + test_llama(args) diff --git a/op_builder/gptq.py b/op_builder/gptq.py new file mode 100644 index 000000000000..012cf0f8a78d --- /dev/null +++ b/op_builder/gptq.py @@ -0,0 +1,52 @@ +import os +import torch +import re + +from .builder import Builder +from .utils import append_nvcc_threads, get_cuda_cc_flag + +class GPTQBuilder(Builder): + + NAME = "cu_gptq" + PREBUILT_IMPORT_PATH = "colossalai._C.cu_gptq" + + def __init__(self): + super().__init__(name=GPTQBuilder.NAME, + prebuilt_import_path=GPTQBuilder.PREBUILT_IMPORT_PATH) + + + def include_dirs(self): + ret = [self.csrc_abs_path("gptq"), 
self.get_cuda_home_include()] + return ret + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) for fname in [ + 'gptq/linear_gptq.cpp', + 'gptq/column_remap.cu', + 'gptq/cuda_buffers.cu', + 'gptq/q4_matmul.cu', + 'gptq/q4_matrix.cu' + ] + ] + return ret + + def cxx_flags(self): + return ['-O3'] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = ['-v', + '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', + '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK', "-lcublas", "-std=c++17" + ] + + + for arch in torch.cuda.get_arch_list(): + res = re.search(r'sm_(\d+)', arch) + if res: + arch_cap = res[1] + if int(arch_cap) >= 80: + extra_cuda_flags.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}']) + + ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags + return append_nvcc_threads(ret) \ No newline at end of file diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 53f0f958e297..467f83610eb0 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -18,3 +18,4 @@ SentencePiece ninja flash_attn==2.0.5 datasets +#auto-gptq now not support torch1.12 diff --git a/tests/test_gptq/test_gptq_linear.py b/tests/test_gptq/test_gptq_linear.py new file mode 100644 index 000000000000..9b650aa78112 --- /dev/null +++ b/tests/test_gptq/test_gptq_linear.py @@ -0,0 +1,150 @@ +import math +import time + +import numpy as np +import pytest +import torch +import torch.nn as nn +import transformers +from packaging import version + +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +try: + from auto_gptq.modeling._utils import autogptq_post_init + from auto_gptq.utils.import_utils import dynamically_import_QuantLinear + from exllama_kernels import prepare_buffers, set_tuning_params + + from colossalai.inference.quant.gptq import CaiQuantLinear + HAS_AUTO_GPTQ = True +except: + HAS_AUTO_GPTQ = False + print("please install AutoGPTQ from https://github.com/PanQiWei/AutoGPTQ") + +import warnings + +HAS_GPTQ_CUDA = False +try: + from colossalai.kernel.op_builder.gptq import GPTQBuilder + gptq_cuda = GPTQBuilder().load() + HAS_GPTQ_CUDA = True +except ImportError: + warnings.warn('CUDA gptq is not installed') + HAS_GPTQ_CUDA = False + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') + +max_inner_outer_dim = 1 +max_input_len = 1 +max_dq_buffer_size = 1 +gptq_temp_dq_buffer = None +gptq_temp_state_buffer = None + + +def init_buffer(cai_linear, use_act_order=False): + global max_dq_buffer_size + global max_input_len + global max_dq_buffer_size + global max_inner_outer_dim + global gptq_temp_dq_buffer + global gptq_temp_state_buffer + + max_dq_buffer_size = max(max_dq_buffer_size, cai_linear.qweight.numel() * 8) + + if use_act_order: + max_inner_outer_dim = max(max_inner_outer_dim, cai_linear.infeatures, cai_linear.outfeatures) + + if use_act_order: + max_input_len = 4096 + # The temp_state buffer is required to reorder X in the act-order case. + # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. 
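+    # Buffer sizes follow the bookkeeping above: temp_state holds up to max_input_len rows of
+    # width max_inner_outer_dim, temp_dq holds one fully dequantized weight (qweight.numel() * 8 fp16 values).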
+ gptq_temp_state_buffer = torch.zeros((max_input_len, max_inner_outer_dim), + dtype=torch.float16, + device=torch.cuda.current_device()) + gptq_temp_dq_buffer = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device()) + + gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), gptq_temp_state_buffer, gptq_temp_dq_buffer) + # Using the default from exllama repo here. + matmul_recons_thd = 8 + matmul_fused_remap = False + matmul_no_half2 = False + gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) + + +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ, + reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq") +def test_gptq_linear(): + + infeature = 1024 + outfeature = 1024 + group_size = 128 + wbits = 4 + + inps = torch.ones(1, 1, infeature).to(torch.float16).to(torch.cuda.current_device()) + batch_inps = torch.randn(1, 16, infeature).to(torch.float16).to(torch.cuda.current_device()) + + device = torch.device("cuda:0") + + linear_class = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=wbits) + + linear = linear_class( + bits=4, + group_size=group_size, + infeatures=infeature, + outfeatures=outfeature, + bias=False, + ) + + torch.manual_seed(42) + + linear.qweight = torch.randint(-100, 100, size=linear.qweight.shape, dtype=torch.int32) + linear.scales = linear.scales + 0.002 + + linear = linear.to(device) + + cai_linear = CaiQuantLinear(wbits, group_size, infeature, outfeature, True) + cai_linear.qweight.data.copy_(linear.qweight) + cai_linear.scales = cai_linear.scales + 0.002 + cai_linear = cai_linear.to(device) + + linear = autogptq_post_init(linear, use_act_order=False) + + max_inner_outer_dim = max(infeature, outfeature) + max_dq_buffer_size = linear.infeatures * linear.outfeatures + max_input_len = 2048 + buffers = { + "temp_state": torch.zeros((max_input_len, max_inner_outer_dim), dtype=torch.float16, device=device), + "temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device) + } + + prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"]) + + # Using the default from exllama repo here. 
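+    # Exllama tuning knobs, kept at the repo defaults: matmul_recons_thd is the row-count threshold
+    # above which the fp16 weight is reconstructed and multiplied with cuBLAS instead of the fused
+    # int4 kernel; the other two flags toggle fused column remapping and half2 arithmetic.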
+ matmul_recons_thd = 8 + matmul_fused_remap = False + matmul_no_half2 = False + set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) + + with torch.no_grad(): + gptq_out = linear(inps) + batch_gptq_out = linear(batch_inps) + torch.cuda.synchronize() + cai_out = cai_linear(inps) + torch.cuda.synchronize() + + batch_cai_out = cai_linear(batch_inps) + torch.cuda.synchronize() + + assert torch.allclose(cai_out, gptq_out, rtol=1e-01, atol=1e-01) + assert torch.allclose(batch_cai_out, batch_gptq_out, rtol=1e-01, atol=1e-01) + + +if __name__ == "__main__": + + test_gptq_linear() From ce7ade3882680ddc18a43375a71adaed194c6da4 Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Fri, 22 Sep 2023 11:12:50 +0800 Subject: [PATCH 38/58] [inference] chatglm2 infer demo (#4724) * add chatglm2 * add * gather needed kernels * fix some bugs * finish context forward * finish context stage * fix * add * pause * add * fix bugs * finish chatglm * fix bug * change some logic * fix bugs * change some logics * add * add * add * fix * fix tests * fix --- .../inference/tensor_parallel/engine.py | 35 +- .../tensor_parallel/modeling/__init__.py | 5 +- .../tensor_parallel/modeling/_utils.py | 10 + .../tensor_parallel/modeling/chatglm2.py | 540 ++++++++++++++++++ .../tensor_parallel/modeling/llama.py | 2 +- .../tensor_parallel/policies/__init__.py | 3 +- .../tensor_parallel/policies/chatglm2.py | 77 +++ colossalai/kernel/triton/context_attention.py | 326 ++++++++++- .../kernel/triton/rotary_embedding_kernel.py | 105 ++++ .../kernel/triton/token_attention_kernel.py | 437 ++++++++++++++ .../modeling/chatglm2_6b/modeling_chatglm.py | 4 - .../shardformer/policies/auto_policy.py | 9 +- tests/kit/model_zoo/transformers/chatglm2.py | 15 + tests/test_infer/test_chatglm2_infer.py | 73 +++ .../triton/test_llama2_token_attn.py | 65 +++ 15 files changed, 1692 insertions(+), 14 deletions(-) create mode 100644 colossalai/inference/tensor_parallel/modeling/_utils.py create mode 100644 colossalai/inference/tensor_parallel/modeling/chatglm2.py create mode 100644 colossalai/inference/tensor_parallel/policies/chatglm2.py create mode 100644 tests/test_infer/test_chatglm2_infer.py create mode 100644 tests/test_infer_ops/triton/test_llama2_token_attn.py diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index 29b5d6117ffd..d5ef37fee420 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -16,7 +16,13 @@ DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 -_supported_models = ["LlamaForCausalLM", "LlamaModel", "BloomForCausalLM"] +_supported_models = [ + "LlamaForCausalLM", + "LlamaModel", + "BloomForCausalLM", + "ChatGLMModel", + "ChatGLMForConditionalGeneration", +] class TPInferEngine: @@ -64,7 +70,13 @@ def __init__( self.head_dim = model.config.hidden_size // model.config.num_attention_heads self.head_num = model.config.num_attention_heads - self.layer_num = model.config.num_hidden_layers + num_hidden_layers = ( + model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") else model.config.num_layers + ) + self.layer_num = num_hidden_layers + self.multi_query_group_num = ( + model.config.multi_query_group_num if hasattr(model.config, "multi_query_group_num") else 0 + ) self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config self.cache_manager = None @@ -85,9 +97,22 @@ def _init_manager(self) -> None: assert self.tp_size >= 1, "TP size 
not initialized without providing a valid ShardConfig" assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}" self.head_num //= self.tp_size # update sharded number of heads - self.cache_manager = MemoryManager( - self.max_total_token_num, self.dtype, self.head_num, self.head_dim, self.layer_num - ) + if self.multi_query_group_num: + # NOTE the logic of MQA tensor parallelism should be specified. + assert ( + self.multi_query_group_num % self.tp_size == 0 + ), f"Cannot shard {self.multi_query_group_num} query groups with tp size {self.tp_size}" + self.cache_manager = MemoryManager( + self.max_total_token_num, + self.dtype, + self.multi_query_group_num // self.tp_size, + self.head_dim, + self.layer_num, + ) + else: + self.cache_manager = MemoryManager( + self.max_total_token_num, self.dtype, self.head_num, self.head_dim, self.layer_num + ) def _post_init_gptq_buffer(self, model: nn.Module) -> None: from colossalai.inference.quant.gptq.cai_gptq import CaiQuantLinear diff --git a/colossalai/inference/tensor_parallel/modeling/__init__.py b/colossalai/inference/tensor_parallel/modeling/__init__.py index 27cec5452ece..279b54065eed 100644 --- a/colossalai/inference/tensor_parallel/modeling/__init__.py +++ b/colossalai/inference/tensor_parallel/modeling/__init__.py @@ -1,4 +1,7 @@ +import _utils + from .bloom import BloomInferenceForwards +from .chatglm2 import ChatGLM2InferenceForwards from .llama import LlamaInferenceForwards -__all__ = ["BloomInferenceForwards", "LlamaInferenceForwards"] +__all__ = ["BloomInferenceForwards", "LlamaInferenceForwards", "ChatGLM2InferenceForwards"] diff --git a/colossalai/inference/tensor_parallel/modeling/_utils.py b/colossalai/inference/tensor_parallel/modeling/_utils.py new file mode 100644 index 000000000000..cee418707617 --- /dev/null +++ b/colossalai/inference/tensor_parallel/modeling/_utils.py @@ -0,0 +1,10 @@ +""" +Utils for model inference +""" +from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest + + +def copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): + copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) + copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) + return diff --git a/colossalai/inference/tensor_parallel/modeling/chatglm2.py b/colossalai/inference/tensor_parallel/modeling/chatglm2.py new file mode 100644 index 000000000000..4b1bc601f436 --- /dev/null +++ b/colossalai/inference/tensor_parallel/modeling/chatglm2.py @@ -0,0 +1,540 @@ +import os +from typing import Optional, Tuple + +import torch +from torch.nn import CrossEntropyLoss +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast + +from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState +from colossalai.kernel.triton.context_attention import llama2_context_attn_fwd +from colossalai.kernel.triton.rotary_embedding_kernel import Llama2Forwards +from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards +from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( + ChatGLMForConditionalGeneration, + ChatGLMModel, + GLMBlock, + GLMTransformer, + SelfAttention, + split_tensor_along_last_dim, +) + +from ._utils import copy_kv_to_mem_cache + + +# This func is same as Llama model init_to_get_rotary, we should move them into _utils.py +def _init_to_get_rotary(self, base=10000): + 
self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads + if not hasattr(self.config, "rope_scaling"): + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0 + if hasattr(self.config, "max_sequence_length"): + max_seq_len = self.config.max_sequence_length + elif hasattr(self.config, "max_position_embeddings"): + max_seq_len = self.config.max_position_embeddings * rope_scaling_factor + else: + max_seq_len = 2048 * rope_scaling_factor + base = float(base) + + # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + try: + ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", 1)) + assert ntk_alpha >= 1 + if ntk_alpha > 1: + print(f"Note: NTK enabled, alpha set to {ntk_alpha}") + max_seq_len *= ntk_alpha + base = base * (ntk_alpha ** (self.head_dim_ / (self.head_dim_ - 2))) # Base change formula + except: + pass + n_elem = self.config.head_dim_ // 2 + inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem)) + t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor + freqs = torch.outer(t, inv_freq) + + self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() + self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() + return + + +def get_masks(self, input_ids, past_length, padding_mask=None): + batch_size, seq_length = input_ids.shape + full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device) + full_attention_mask.tril_() + if past_length: + full_attention_mask = torch.cat( + ( + torch.ones(batch_size, seq_length, past_length, device=input_ids.device), + full_attention_mask, + ), + dim=-1, + ) + + if padding_mask is not None: + full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) + if not past_length and padding_mask is not None: + full_attention_mask -= padding_mask.unsqueeze(-1) - 1 + full_attention_mask = (full_attention_mask < 0.5).bool() + full_attention_mask.unsqueeze_(1) + return full_attention_mask + + +class ChatGLM2InferenceForwards: + """ + This class holds forwards for Chatglm2 inference. + We intend to replace the forward methods for ChatGLMModel, ChatGLMEecoderLayer, and ChatGLMAttention. 
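+
+    Example (illustrative sketch only; the checkpoint id and sizes are placeholders, and the
+    colossalai launch plus imports follow the scripts under examples/inference/):
+
+        model = ChatGLMForConditionalGeneration.from_pretrained("THUDM/chatglm2-6b").half().cuda()
+        shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True)
+        infer_engine = TPInferEngine(model, shard_config, max_batch_size=4, max_input_len=128, max_output_len=64)
+        # input_tokens: dict with "input_ids" / "attention_mask" tensors on the current CUDA device
+        outputs = infer_engine.generate(input_tokens, max_new_tokens=64, do_sample=False)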
+ """ + + @staticmethod + def chatglm_for_conditional_generation_forward( + self: ChatGLMForConditionalGeneration, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_last_logit: Optional[bool] = False, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + infer_state = self.infer_state + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + past_key_values_length = 0 + + # NOT READY FOR PRIME TIME + # dummy but work, revise it + past_key_values_length = infer_state.cache_manager.past_key_values_length + seq_length_with_past = seq_length + past_key_values_length + infer_state.seq_length_with_past = seq_length_with_past + + # prefill stage at first + if use_cache and seq_length != 1: + infer_state.is_context_stage = True + infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) + infer_state.init_block_loc( + infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index + ) + else: + infer_state.is_context_stage = False + alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) + if alloc_mem is not None: + infer_state.decode_is_contiguous = True + infer_state.decode_mem_index = alloc_mem[0] + infer_state.decode_mem_start = alloc_mem[1] + infer_state.decode_mem_end = alloc_mem[2] + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + else: + print(f" *** Encountered allocation non-contiguous") + print( + f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" + ) + infer_state.decode_is_contiguous = False + alloc_mem = infer_state.cache_manager.alloc(batch_size) + infer_state.decode_mem_index = alloc_mem + # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + + # related to rotary embedding + if infer_state.is_context_stage: + infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1 + ) + infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1 + ) + else: + seq_len = infer_state.seq_len + infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) + infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) + infer_state.other_kv_index = 
infer_state.block_loc[0, infer_state.max_len_in_batch - 1].item() + + transformer_outputs = self.transformer( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + infer_state=infer_state, + ) + + hidden_states = transformer_outputs[0] + if return_last_logit: + hidden_states = hidden_states[-1:] + lm_logits = self.transformer.output_layer(hidden_states) + lm_logits = lm_logits.transpose(0, 1).contiguous() + + loss = None + if labels is not None: + lm_logits = lm_logits.to(torch.float32) + + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + lm_logits = lm_logits.to(hidden_states.dtype) + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def chatglm_model_forward( + self: ChatGLMModel, + input_ids, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + full_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + infer_state: BatchInferState = None, + ): + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + inputs_embeds = self.embedding(input_ids) + + if self.pre_seq_len is not None: + if past_key_values is None: + past_key_values = self.get_prompt( + batch_size=batch_size, + device=input_ids.device, + dtype=inputs_embeds.dtype, + ) + if attention_mask is not None: + attention_mask = torch.cat( + [ + attention_mask.new_ones((batch_size, self.pre_seq_len)), + attention_mask, + ], + dim=-1, + ) + if full_attention_mask is None: + if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): + full_attention_mask = get_masks( + self, input_ids, infer_state.cache_manager.past_key_values_length, padding_mask=attention_mask + ) + + # Run encoder. 
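+        # infer_state is threaded through so each layer's SelfAttention can address its own slice of
+        # the shared KV cache (the encoder resets infer_state.decode_layer_id and bumps it per layer).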
+ hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( + inputs_embeds, + full_attention_mask, + kv_caches=past_key_values, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + infer_state=infer_state, + ) + + # update indices + # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.seq_len += 1 + infer_state.max_len_in_batch += 1 + infer_state.cache_manager.past_key_values_length += seq_length + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] + if v is not None + ) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + @staticmethod + def chatglm_encoder_forward( + self: GLMTransformer, + hidden_states, + attention_mask, + kv_caches=None, + use_cache: Optional[bool] = True, + output_hidden_states: Optional[bool] = False, + infer_state: Optional[BatchInferState] = None, + ): + hidden_states = hidden_states.transpose(0, 1).contiguous() + if not kv_caches: + kv_caches = [None for _ in range(self.num_layers)] + presents = () if use_cache else None + all_self_attentions = None + all_hidden_states = () if output_hidden_states else None + + infer_state.decode_layer_id = 0 + for index in range(self.num_layers): + layer = self.layers[index] + + layer_ret = layer( + hidden_states, + attention_mask, + kv_cache=kv_caches[index], + use_cache=use_cache, + infer_state=infer_state, + ) + + infer_state.decode_layer_id += 1 + + hidden_states, kv_cache = layer_ret + if use_cache: + presents = presents + (kv_cache,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # Final layer norm. + hidden_states = hidden_states.transpose(0, 1).contiguous() + + if self.post_layer_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states, presents, all_hidden_states, all_self_attentions + + @staticmethod + def chatglm_glmblock_forward( + self: GLMBlock, + hidden_states, + attention_mask, + kv_cache=None, + use_cache=True, + infer_state: Optional[BatchInferState] = None, + ): + # hidden_states: [s, b, h] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output, kv_cache = self.self_attention( + layernorm_output, + attention_mask, + kv_cache=kv_cache, + use_cache=use_cache, + infer_state=infer_state, + ) + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) + layernorm_input = residual + layernorm_input + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + # MLP. + mlp_output = self.mlp(layernorm_output) + + # Second residual connection. 
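+        # The dropout below is an identity op when the model runs in eval mode for inference;
+        # it is kept to mirror the original GLMBlock forward.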
+ if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training) + output = residual + output + return output, kv_cache + + @staticmethod + def chatglm_flash_attn_kvcache_forward( + self: SelfAttention, + hidden_states, + attention_mask, + kv_cache=None, + use_cache=True, + infer_state: Optional[BatchInferState] = None, + ): + assert use_cache is True, "use_cache should be set to True using this chatglm attention" + # hidden_states: original :[sq, b, h] --> this [b, sq, h] + batch_size = hidden_states.shape[0] + # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + mixed_x_layer = self.query_key_value(hidden_states) + + if self.multi_query_attention: + (query_layer, key_layer, value_layer) = mixed_x_layer.split( + [ + self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + ], + dim=-1, + ) + query_layer = query_layer.view( + query_layer.size()[:-1] + + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + ) + key_layer = key_layer.view( + key_layer.size()[:-1] + + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + ) + ) + value_layer = value_layer.view( + value_layer.size()[:-1] + + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + ) + ) + + else: + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + + cos, sin = infer_state.position_cos, infer_state.position_sin + + Llama2Forwards.rotary_emb_fwd( + query_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos, sin + ) + if self.multi_query_attention: + Llama2Forwards.rotary_emb_fwd( + key_layer.view(-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head), + cos, + sin, + ) + else: + Llama2Forwards.rotary_emb_fwd( + key_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), + cos, + sin, + ) + + # reshape q k v to [bsz*sql, num_heads, head_dim] 2*1 ,32/2 ,128 + query_layer = query_layer.reshape( + -1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head + ) + key_layer = key_layer.reshape( + -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head + ) + value_layer = value_layer.reshape( + -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head + ) + if infer_state.is_context_stage: + # first token generation: + # copy key and value calculated in current step to memory manager + + copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_layer, + value_layer, + infer_state.context_mem_index, + infer_state.cache_manager, + ) + + attn_output = torch.empty_like(query_layer.view(-1, self.projection_size)) + + # NOTE: no bug in context attn fwd (del it ) + llama2_context_attn_fwd( + query_layer, + key_layer, + value_layer, + attn_output.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), + 
infer_state.start_loc, + infer_state.seq_len, + infer_state.seq_length_with_past, + ) + + else: + if infer_state.decode_is_contiguous: + # if decode is contiguous, then we copy to key cache and value cache in cache manager directly + cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_k.copy_(key_layer) + cache_v.copy_(value_layer) + else: + # if decode is not contiguous, use triton kernel to copy key and value cache + # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head + copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_layer, + value_layer, + infer_state.decode_mem_index, + infer_state.cache_manager, + ) + + # second token and follows + attn_output = torch.empty_like(query_layer.view(-1, self.projection_size)) + cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ + : infer_state.decode_mem_end, :, : + ] + cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ + : infer_state.decode_mem_end, :, : + ] + + # ================================== + # core attention computation is replaced by triton kernel + # ================================== + Llama2TokenAttentionForwards.token_attn( + query_layer, + cache_k, + cache_v, + attn_output, + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + infer_state.max_len_in_batch, + infer_state.other_kv_index, + ) + + # print('after attention',torch.isnan(attn_output).any()) + + # ================= + # Output:[b,sq, h] + # ================= + + output = self.dense(attn_output).reshape(batch_size, -1, self.projection_size) + return output, kv_cache diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index 4795162f1980..64d6e947e924 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -100,7 +100,7 @@ def llama_model_forward( # NOTE: differentiate with prefill stage # block_loc require different value-assigning method for two different stage if use_cache and seq_length != 1: - # NOTE assuem prefill stage + # NOTE assume prefill stage # allocate memory block infer_state.is_context_stage = True # set prefill stage, notify attention layer infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) diff --git a/colossalai/inference/tensor_parallel/policies/__init__.py b/colossalai/inference/tensor_parallel/policies/__init__.py index fcb1b6a3bd8f..776c4e850565 100644 --- a/colossalai/inference/tensor_parallel/policies/__init__.py +++ b/colossalai/inference/tensor_parallel/policies/__init__.py @@ -1,4 +1,5 @@ from .bloom import BloomModelInferPolicy +from .chatglm2 import ChatGLM2InferPolicy from .llama import LlamaModelInferPolicy -__all__ = ["BloomModelInferPolicy", "LlamaModelInferPolicy"] +__all__ = ["BloomModelInferPolicy", "LlamaModelInferPolicy", "ChatGLM2InferPolicy"] diff --git a/colossalai/inference/tensor_parallel/policies/chatglm2.py b/colossalai/inference/tensor_parallel/policies/chatglm2.py new file mode 100644 index 000000000000..cb223370a65d --- /dev/null +++ b/colossalai/inference/tensor_parallel/policies/chatglm2.py @@ -0,0 +1,77 @@ +from functools import partial + +import torch + +from 
colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( + ChatGLMForConditionalGeneration, + ChatGLMModel, + GLMBlock, + GLMTransformer, + SelfAttention, +) +# import colossalai +from colossalai.shardformer.policies.chatglm2 import ChatGLMModelPolicy + +from ..modeling.chatglm2 import ChatGLM2InferenceForwards, _init_to_get_rotary + +try: + from colossalai.kernel.triton.rms_norm import rmsnorm_forward + HAS_TRITON_RMSNORM = True +except: + print("you should install triton from https://github.com/openai/triton") + HAS_TRITON_RMSNORM = False + + +class ChatGLM2InferPolicy(ChatGLMModelPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + self.shard_config._infer() + + model_infer_forward = ChatGLM2InferenceForwards.chatglm_model_forward + method_replacement = {'forward': model_infer_forward} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=ChatGLMModel) + + encoder_infer_forward = ChatGLM2InferenceForwards.chatglm_encoder_forward + method_replacement = {'forward': encoder_infer_forward} + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=GLMTransformer) + + encoder_layer_infer_forward = ChatGLM2InferenceForwards.chatglm_glmblock_forward + method_replacement = {'forward': encoder_layer_infer_forward} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=GLMBlock) + + attn_infer_forward = ChatGLM2InferenceForwards.chatglm_flash_attn_kvcache_forward + method_replacement = {'forward': attn_infer_forward} + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=SelfAttention) + + # for rmsnorm and others, we need to check the shape + return policy + + def postprocess(self): + _init_to_get_rotary(self.model) + return self.model + + +class ChatGLM2ForConditionalGenerationInferPolicy(ChatGLM2InferPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + model_infer_forward = ChatGLM2InferenceForwards.chatglm_for_conditional_generation_forward + method_replacement = {'forward': partial(model_infer_forward)} + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=ChatGLMForConditionalGeneration) + return policy + + def postprocess(self): + return super().postprocess() diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py index dac95bfb14ae..01d54566483a 100644 --- a/colossalai/kernel/triton/context_attention.py +++ b/colossalai/kernel/triton/context_attention.py @@ -11,7 +11,6 @@ HAS_TRITON = False print("please install triton from https://github.com/openai/triton") - if HAS_TRITON: """ this function is modified from @@ -240,3 +239,328 @@ def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): num_stages=1, ) return + + @triton.jit + def _fwd_kernel_latest( + Q, + K, + V, + sm_scale, + B_Start_Loc, + B_Seqlen, + Out, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + kv_group_num, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // kv_group_num + + cur_batch_seq_len = 
tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd + ) + off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd + + q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + + k_ptrs = K + off_k + v_ptrs = V + off_v + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load( + k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, + other=0.0, + ) + # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load( + v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, + other=0.0, + ) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_d[None, :] * stride_od + ) + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) + return + + @triton.jit + def _fwd_kernel_old( + Q, + K, + V, + sm_scale, + B_Start_Loc, + B_Seqlen, + TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug + Out, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_tmp_b, + stride_tmp_h, + stride_tmp_s, + kv_group_num, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // kv_group_num + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, 
None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd + ) + off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd + q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + + k_ptrs = K + off_k + v_ptrs = V + off_v + + t_ptrs = TMP + cur_batch * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s + # t_ptrs = TMP + offs_m + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load( + k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, + other=0.0, + ) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + tl.store(t_ptrs, acc_scale) + acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load + acc = acc * acc_scale[:, None] + # update acc + v = tl.load( + v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, + other=0.0, + ) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_d[None, :] * stride_od + ) + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) + + return + + @torch.no_grad() + def llama2_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): + if triton.__version__ >= "2.1.0": + BLOCK = 128 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + sm_scale = 1.0 / (Lq**0.5) # 计算scale系数 + batch, head = b_seq_len.shape[0], q.shape[1] + kv_group_num = q.shape[1] // k.shape[1] + + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, + + num_warps = 4 if Lk <= 64 else 8 + _fwd_kernel_latest[grid]( + q, + k, + v, + sm_scale, + b_start_loc, + b_seq_len, + o, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + kv_group_num=kv_group_num, + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + elif triton.__version__ == "2.0.0": + BLOCK = 128 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + + sm_scale = 1.0 / (Lq**0.5) + batch, head = b_seq_len.shape[0], q.shape[1] + kv_group_num = q.shape[1] // k.shape[1] + + grid 
= (batch, head, triton.cdiv(max_input_len, BLOCK)) + tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) + num_warps = 4 if Lk <= 64 else 8 + # num_warps = 4 + _fwd_kernel_old[grid]( + q, + k, + v, + sm_scale, + b_start_loc, + b_seq_len, + tmp, + o, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + tmp.stride(0), + tmp.stride(1), + tmp.stride(2), + kv_group_num=kv_group_num, + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return diff --git a/colossalai/kernel/triton/rotary_embedding_kernel.py b/colossalai/kernel/triton/rotary_embedding_kernel.py index eb43fab7935c..fd74ba817551 100644 --- a/colossalai/kernel/triton/rotary_embedding_kernel.py +++ b/colossalai/kernel/triton/rotary_embedding_kernel.py @@ -105,3 +105,108 @@ def rotary_embedding_fwd(q, cos, sin): num_stages=1, ) return + + +class Llama2Forwards: + @staticmethod + @triton.jit + def _rotary_kernel( + Q, + Cos, + Sin, + stride_qbs, + stride_qh, + stride_qd, + stride_cosbs, + stride_cosd, + stride_sinbs, + stride_sind, + max_total_len, + H, # N_CTX + BLOCK_HEAD: tl.constexpr, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + ): + cur_head_index = tl.program_id(0) + cur_seq_index = tl.program_id(1) + + cur_head_range = cur_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) + cur_seq_range = cur_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) + + dim_range0 = tl.arange(0, BLOCK_DMODEL // 2) * 2 + dim_range1 = dim_range0 + 1 + off_q0 = ( + cur_seq_range[:, None, None] * stride_qbs + + cur_head_range[None, :, None] * stride_qh + + dim_range0[None, None, :] * stride_qd + ) + off_q1 = ( + cur_seq_range[:, None, None] * stride_qbs + + cur_head_range[None, :, None] * stride_qh + + dim_range1[None, None, :] * stride_qd + ) + + cos_range = tl.arange(0, BLOCK_DMODEL // 2) + off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + cos_range[None, None, :] * stride_cosd + + q0 = tl.load( + Q + off_q0, + mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H), + other=0.0, + ) + q1 = tl.load( + Q + off_q1, + mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H), + other=0.0, + ) + + cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) + sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) + + out0 = q0 * cos - q1 * sin + out1 = q0 * sin + q1 * cos + + tl.store( + Q + off_q0, out0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H) + ) + tl.store( + Q + off_q1, out1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H) + ) + + return + + @staticmethod + @torch.no_grad() + def rotary_emb_fwd(q, cos, sin): + total_len = q.shape[0] + head_num = q.shape[1] + head_dim = q.shape[2] // 2 + assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" + BLOCK_HEAD = 4 + BLOCK_SEQ = 32 + grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) + if head_dim >= 128: + num_warps = 8 + else: + num_warps = 4 + + Llama2Forwards._rotary_kernel[grid]( + q, + cos, + sin, + q.stride(0), + q.stride(1), + q.stride(2), + cos.stride(0), + cos.stride(1), + sin.stride(0), + sin.stride(1), + total_len, + head_num, + BLOCK_HEAD=BLOCK_HEAD, + 
BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=head_dim, + num_warps=num_warps, + num_stages=1, + ) + return diff --git a/colossalai/kernel/triton/token_attention_kernel.py b/colossalai/kernel/triton/token_attention_kernel.py index 7d0f9708516a..c27394f0f9cf 100644 --- a/colossalai/kernel/triton/token_attention_kernel.py +++ b/colossalai/kernel/triton/token_attention_kernel.py @@ -402,3 +402,440 @@ def token_attention_fwd( prob = None return + + +class Llama2TokenAttentionForwards: + @staticmethod + @triton.jit + def _fwd_kernel( + Logics, + V, + Out, + B_Loc, + B_Start_Loc, + B_Seqlen, + max_input_len, + stride_logic_h, + stride_logic_bs, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_b_loc_b, + stride_b_loc_s, + other_kv_index, # avoid nan information + kv_group_num, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_kv_head = cur_head // kv_group_num + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch) + + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + + off_v = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd + off_b_loc = cur_batch * stride_b_loc_b + (max_input_len - cur_batch_seq_len) * stride_b_loc_s + + v_ptrs = V + off_v + + e_max = float("-inf") + e_sum = 0.0 + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, cur_batch_seq_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + v_index = tl.load( + B_Loc + off_b_loc + (start_n + offs_n) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_seq_len, + other=other_kv_index, + ) + + qk = tl.load( + Logics + cur_head * stride_logic_h + (cur_batch_start_loc + start_n + offs_n) * stride_logic_bs, + mask=start_n + offs_n < cur_batch_seq_len, + other=float("-inf"), + ) + + n_e_max = tl.maximum(tl.max(qk, 0), e_max) + old_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max) + e_sum = e_sum * old_scale + tl.sum(p, 0) + v = tl.load(v_ptrs + v_index[:, None] * stride_vbs) + acc = acc * old_scale + tl.sum(p[:, None] * v, 0) + e_max = n_e_max + + acc = acc / e_sum + off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od + out_ptrs = Out + off_o + tl.store(out_ptrs, acc) + return + + @staticmethod + @torch.no_grad() + def token_softmax_reducev_fwd(logics, v, o, b_loc, b_start_loc, b_seq_len, max_input_len, other_kv_index): + BLOCK = 64 + batch, head = b_seq_len.shape[0], logics.shape[0] + grid = (batch, head) + kv_group_num = logics.shape[0] // v.shape[1] + + num_warps = 1 + Llama2TokenAttentionForwards._fwd_kernel[grid]( + logics, + v, + o, + b_loc, + b_start_loc, + b_seq_len, + max_input_len, + logics.stride(0), + logics.stride(1), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + b_loc.stride(0), + b_loc.stride(1), + other_kv_index, + kv_group_num, + BLOCK_DMODEL=v.shape[-1], + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=3, + ) + return + + @staticmethod + @triton.jit + def _fwd_kernel_token_softmax( + Logics, + B_Start_Loc, + B_Seqlen, + Prob_Out, + stride_logic_h, + stride_logic_bs, + stride_prob_h, + stride_prob_bs, + BLOCK_SIZE: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + col_offsets = tl.arange(0, BLOCK_SIZE) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + row = tl.load( + Logics + cur_head * stride_logic_h + 
(cur_batch_in_all_start_index + col_offsets) * stride_logic_bs, + mask=col_offsets < cur_batch_seq_len, + other=-float("inf"), + ).to(tl.float32) + + row_minus_max = row - tl.max(row, axis=0) + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + + tl.store( + Prob_Out + cur_head * stride_prob_h + (cur_batch_in_all_start_index + col_offsets) * stride_prob_bs, + softmax_output, + mask=col_offsets < cur_batch_seq_len, + ) + return + + @staticmethod + @torch.no_grad() + def token_softmax_fwd(Logics, B_Start_Loc, B_Seqlen, Prob_Out, max_input_len): + BLOCK_SIZE = triton.next_power_of_2(max_input_len) + batch, head_num = B_Start_Loc.shape[0], Logics.shape[0] + + num_warps = 4 + if BLOCK_SIZE >= 2048: + num_warps = 8 + if BLOCK_SIZE >= 4096: + num_warps = 16 + + Llama2TokenAttentionForwards._fwd_kernel_token_softmax[(batch, head_num)]( + Logics, + B_Start_Loc, + B_Seqlen, + Prob_Out, + Logics.stride(0), + Logics.stride(1), + Prob_Out.stride(0), + Prob_Out.stride(1), + num_warps=num_warps, + BLOCK_SIZE=BLOCK_SIZE, + ) + return + + @staticmethod + @triton.jit + def _fwd_kernel_token_att1( + Q, + K, + sm_scale, + B_Loc, + B_Start_Loc, + B_Seqlen, + max_input_len, + Att_Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + att_stride_h, + att_stride_bs, + kv_group_num, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_n = tl.program_id(2) + + cur_kv_head = cur_head // kv_group_num + + offs_d = tl.arange(0, BLOCK_DMODEL) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + cur_batch_start_index = max_input_len - cur_batch_seq_len + cur_batch_end_index = max_input_len + + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd + + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + + block_stard_index = start_n * BLOCK_N + block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0) + + for start_mark in range(0, block_mask, 1): + q = tl.load(Q + off_q + start_mark) + offs_n_new = cur_batch_start_index + offs_n + k_loc = tl.load( + B_Loc + stride_b_loc_b * cur_batch + stride_b_loc_s * offs_n_new, + mask=offs_n_new < cur_batch_end_index, + other=0, + ) + off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd + k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) + att_value = tl.sum(q[None, :] * k, 1) + att_value *= sm_scale + off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs + tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index) + return + + @staticmethod + @torch.no_grad() + def token_att_fwd(q, k, att_out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len): + BLOCK = 32 + # shape constraints + Lq, Lk = q.shape[-1], k.shape[-1] + assert Lq == Lk + assert Lk in {16, 32, 64, 128} + sm_scale = 1.0 / (Lk**0.5) + + batch, head_num = B_Loc.shape[0], q.shape[1] + + grid = (batch, head_num, triton.cdiv(max_input_len, BLOCK)) + kv_group_num = q.shape[1] // k.shape[1] + + num_warps = 4 if Lk <= 64 else 8 + num_warps = 2 + + Llama2TokenAttentionForwards._fwd_kernel_token_att1[grid]( + q, + k, + sm_scale, + B_Loc, + B_Start_Loc, + B_Seqlen, + max_input_len, + att_out, + B_Loc.stride(0), + B_Loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + 
k.stride(2), + att_out.stride(0), + att_out.stride(1), + kv_group_num=kv_group_num, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + @staticmethod + @triton.jit + def _fwd_kernel_token_att2( + Prob, + V, + Out, + B_Loc, + B_Start_Loc, + B_Seqlen, + max_input_len, # B_Start_Loc cumsum of input lens if continuous + stride_b_loc_b, + stride_b_loc_s, + stride_ph, + stride_pbs, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + kv_group_num, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_kv_head = cur_head // kv_group_num + + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_start_index = max_input_len - cur_batch_seq_len + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + v_loc_off = cur_batch * stride_b_loc_b + (cur_batch_start_index + offs_n) * stride_b_loc_s + p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs + v_offs = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd + + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + for start_n in range(0, cur_batch_seq_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + p_value = tl.load( + Prob + p_offs + start_n * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0 + ) + v_loc = tl.load( + B_Loc + v_loc_off + start_n * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0 + ) + v_value = tl.load( + V + v_offs + v_loc[:, None] * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, + other=0.0, + ) + acc += tl.sum(p_value[:, None] * v_value, 0) + + acc = acc.to(tl.float16) + off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od + out_ptrs = Out + off_o + tl.store(out_ptrs, acc) + return + + @staticmethod + @torch.no_grad() + def token_att_fwd2(prob, v, out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len): + if triton.__version__ >= "2.1.0": + BLOCK = 128 + else: + BLOCK = 64 + batch, head = B_Loc.shape[0], prob.shape[0] + grid = (batch, head) + num_warps = 4 + dim = v.shape[-1] + + kv_group_num = prob.shape[0] // v.shape[1] + + Llama2TokenAttentionForwards._fwd_kernel_token_att2[grid]( + prob, + v, + out, + B_Loc, + B_Start_Loc, + B_Seqlen, + max_input_len, + B_Loc.stride(0), + B_Loc.stride(1), + prob.stride(0), + prob.stride(1), + v.stride(0), + v.stride(1), + v.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + kv_group_num=kv_group_num, + BLOCK_DMODEL=dim, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + # this is the interface of llama2 attn forward + @staticmethod + @torch.no_grad() + def token_attn( + q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, other_kv_index + ): + total_token_num = k.shape[0] + batch_size, head_num, head_dim = q.shape + calcu_shape1 = (batch_size, head_num, head_dim) + att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda") + + Llama2TokenAttentionForwards.token_att_fwd( + q, + k, + att_m_tensor, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seq_len, + max_len_in_batch, + ) + + if triton.__version__ == "2.0.0": + prob = torch.empty_like(att_m_tensor) + Llama2TokenAttentionForwards.token_softmax_fwd( + att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch + ) + att_m_tensor = None + + 
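+            # Triton 2.0.0 path: the attention scores in `prob` were already normalized by
+            # token_softmax_fwd above, so token_att_fwd2 only accumulates the probability-weighted
+            # sum over the cached V vectors into `attn_out`.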
Llama2TokenAttentionForwards.token_att_fwd2( + prob, + v, + attn_out.view(calcu_shape1), + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seq_len, + max_len_in_batch, + ) + + prob = None + return + + elif triton.__version__ >= "2.1.0": + Llama2TokenAttentionForwards.token_softmax_reducev_fwd( + att_m_tensor, + v, + attn_out.view(calcu_shape1), + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seq_len, + max_len_in_batch, + other_kv_index, + ) + else: + raise Exception("not support triton version") diff --git a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py index 3a8d90ec7328..cbb25b5b1f4c 100644 --- a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py +++ b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py @@ -380,12 +380,10 @@ class SelfAttention(torch.nn.Module): def __init__(self, config: ChatGLMConfig, layer_number, device=None): super(SelfAttention, self).__init__() self.layer_number = max(1, layer_number) - self.projection_size = config.kv_channels * config.num_attention_heads # Per attention head and per partition values. self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads self.num_attention_heads_per_partition = config.num_attention_heads - self.multi_query_attention = config.multi_query_attention self.qkv_hidden_size = 3 * self.projection_size if self.multi_query_attention: @@ -445,7 +443,6 @@ def forward( # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] mixed_x_layer = self.query_key_value(hidden_states) - if self.multi_query_attention: (query_layer, key_layer, value_layer) = mixed_x_layer.split( [ @@ -541,7 +538,6 @@ def forward( # ================= # Output. [sq, b, h] # ================= - output = self.dense(context_layer) return output, kv_cache diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 3bea91ef94dc..f3587de15f86 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -164,6 +164,13 @@ class PolicyLocation: "transformers.models.bloom.modeling_bloom.BloomForCausalLM": PolicyLocation( file_name="bloom", class_name="BloomModelInferPolicy" ), + # ChatGLM2 + "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": PolicyLocation( + file_name="chatglm2", class_name="ChatGLM2InferPolicy" + ), + "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation( + file_name="chatglm2", class_name="ChatGLM2ForConditionalGenerationInferPolicy" + ), } @@ -208,7 +215,7 @@ def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) -> if policy_location is None: raise NotImplementedError( - f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}" + f"Auto policy for {model.__class__.__qualname__} is not implemented\n. 
Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}" ) else: policy = import_policy(policy_location, inference_only) diff --git a/tests/kit/model_zoo/transformers/chatglm2.py b/tests/kit/model_zoo/transformers/chatglm2.py index 22885bec224a..f4369cb7d171 100644 --- a/tests/kit/model_zoo/transformers/chatglm2.py +++ b/tests/kit/model_zoo/transformers/chatglm2.py @@ -39,6 +39,21 @@ def data_gen_for_conditional_generation(): padded_vocab_size=65024, hidden_size=64, num_attention_heads=8, + kv_channels=16, + rmsnorm=True, + original_rope=True, + use_cache=True, + torch_dtype=torch.float32, +) + +infer_config = ChatGLMConfig( + num_layers=2, + padded_vocab_size=65024, + hidden_size=128, + num_attention_heads=8, + multi_query_attention=True, + multi_query_group_num=2, + kv_channels=16, rmsnorm=True, original_rope=True, use_cache=True, diff --git a/tests/test_infer/test_chatglm2_infer.py b/tests/test_infer/test_chatglm2_infer.py new file mode 100644 index 000000000000..699ba7b52fe0 --- /dev/null +++ b/tests/test_infer/test_chatglm2_infer.py @@ -0,0 +1,73 @@ +import os + +import pytest +import torch +import torch.distributed as dist +from packaging import version +from transformers import AutoTokenizer + +import colossalai +from colossalai.inference.tensor_parallel.engine import TPInferEngine +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig +from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo.transformers.chatglm2 import infer_config + +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" +TPSIZE = 1 +BATCH_SIZE = 8 +MAX_INPUT_LEN = 12 +MAX_OUTPUT_LEN = 100 +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") + + +@parameterize( + "test_config", + [ + { + "tp_size": TPSIZE, + } + ], +) +def run_chatglm2_test(test_config): + tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True) + # pad_token_id = 0 + model_fn = lambda: ChatGLMForConditionalGeneration(infer_config, empty_init=False) + orig_model = model_fn() + orig_model = orig_model.half() + text = ["how is the weather today?"] + input_ids = tokenizer.batch_encode_plus(text, return_tensors="pt", padding=True) + shard_config = ShardConfig( + enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True + ) + infer_engine = TPInferEngine(orig_model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + + generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) + outputs = infer_engine.generate(input_ids, **generate_kwargs) + assert outputs is not None + + # print("outputs.shape: ", outputs[0].shape) + # print("outputs: ", outputs[0]) + if not dist.is_initialized() or dist.get_rank() == 0: + for o in outputs: + output_text = tokenizer.decode(o) + print(output_text) + + +def check_chatglm2(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_chatglm2_test() + + +@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_chatglm2(): + spawn(check_chatglm2, TPSIZE) + + +if __name__ == "__main__": + test_chatglm2() diff 
--git a/tests/test_infer_ops/triton/test_llama2_token_attn.py b/tests/test_infer_ops/triton/test_llama2_token_attn.py new file mode 100644 index 000000000000..c22f70211d4f --- /dev/null +++ b/tests/test_infer_ops/triton/test_llama2_token_attn.py @@ -0,0 +1,65 @@ +import pytest +import torch +from packaging import version + +try: + pass + + from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + + +def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim): + xq = xq.view(bs, 1, num_head, head_dim) + xk = xk.view(bs, seqlen, num_head, head_dim) + xv = xv.view(bs, seqlen, num_head, head_dim) + + logics = torch.sum(xq * xk, dim=3, keepdim=False) * 1 / (head_dim**0.5) + prob = torch.softmax(logics, dim=1) + prob = prob.view(bs, seqlen, num_head, 1) + + return torch.sum(prob * xv, dim=1, keepdim=False) + + +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) +def test(): + Z, head_num, seq_len, head_dim = 2, 32, 2048, 128 + dtype = torch.float16 + + # attn out: 2,4096 + q = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) + k = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) + v = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) + o = torch.empty_like() + # o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) + + max_kv_cache_len = seq_len + kv_cache_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda") + kv_cache_loc = torch.zeros((Z, seq_len), dtype=torch.int32, device="cuda") + kv_cache_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") + other_kv_index = 2048 + + kv_cache_seq_len[:] = seq_len + kv_cache_start_loc[0] = 0 + kv_cache_start_loc[1] = seq_len + + for i in range(Z): + kv_cache_loc[i, :] = torch.arange(i * seq_len, (i + 1) * seq_len, dtype=torch.int32, device="cuda") + + Llama2TokenAttentionForwards.token_attn( + q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_kv_cache_len, other_kv_index + ) + torch_out = torch_att(q, k, v, Z, seq_len, head_num, head_dim) + assert torch.allclose(torch_out, o, atol=1e-3, rtol=0) + + +if __name__ == "__main__": + test() From 4146f1c0ceea4de649f8639d11c26901187cb294 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Fri, 22 Sep 2023 18:29:17 +0800 Subject: [PATCH 39/58] [release] update version (#4775) * [release] update version * [doc] revert versions --- version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.txt b/version.txt index d15723fbe8de..1c09c74e221c 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.3.2 +0.3.3 From 74aa7d964a8fbb9a9a4865ecd9ac2bda817c3ef2 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Sun, 24 Sep 2023 23:12:26 +0800 Subject: [PATCH 40/58] initial commit: add colossal llama 2 (#4784) --- applications/Colossal-LLaMA-2/README.md | 377 +++++++++++++++++ .../colossal_llama2/__init__.py | 2 + .../colossal_llama2/dataset/__init__.py | 2 + .../colossal_llama2/dataset/loader.py | 219 ++++++++++ .../dataset/spliced_and_tokenized_dataset.py | 183 +++++++++ .../colossal_llama2/model/init_model.py | 111 +++++ .../tokenizer/init_tokenizer.py | 98 
+++++ .../colossal_llama2/utils/__init__.py | 2 + .../colossal_llama2/utils/ckpt_io.py | 88 ++++ .../utils/flash_attention_patch.py | 216 ++++++++++ .../colossal_llama2/utils/froze.py | 18 + applications/Colossal-LLaMA-2/docs/example.md | 245 +++++++++++ .../Colossal-LLaMA-2/hostfile.example | 2 + .../prepare_pretrain_dataset.py | 153 +++++++ .../Colossal-LLaMA-2/requirements.txt | 15 + .../Colossal-LLaMA-2/train.example.sh | 44 ++ applications/Colossal-LLaMA-2/train.py | 383 ++++++++++++++++++ applications/Colossal-LLaMA-2/version.txt | 1 + applications/README.md | 5 +- 19 files changed, 2162 insertions(+), 2 deletions(-) create mode 100644 applications/Colossal-LLaMA-2/README.md create mode 100644 applications/Colossal-LLaMA-2/colossal_llama2/__init__.py create mode 100644 applications/Colossal-LLaMA-2/colossal_llama2/dataset/__init__.py create mode 100644 applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py create mode 100644 applications/Colossal-LLaMA-2/colossal_llama2/dataset/spliced_and_tokenized_dataset.py create mode 100644 applications/Colossal-LLaMA-2/colossal_llama2/model/init_model.py create mode 100644 applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py create mode 100644 applications/Colossal-LLaMA-2/colossal_llama2/utils/__init__.py create mode 100644 applications/Colossal-LLaMA-2/colossal_llama2/utils/ckpt_io.py create mode 100644 applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py create mode 100644 applications/Colossal-LLaMA-2/colossal_llama2/utils/froze.py create mode 100644 applications/Colossal-LLaMA-2/docs/example.md create mode 100644 applications/Colossal-LLaMA-2/hostfile.example create mode 100644 applications/Colossal-LLaMA-2/prepare_pretrain_dataset.py create mode 100644 applications/Colossal-LLaMA-2/requirements.txt create mode 100644 applications/Colossal-LLaMA-2/train.example.sh create mode 100644 applications/Colossal-LLaMA-2/train.py create mode 100644 applications/Colossal-LLaMA-2/version.txt diff --git a/applications/Colossal-LLaMA-2/README.md b/applications/Colossal-LLaMA-2/README.md new file mode 100644 index 000000000000..7274abbad0f5 --- /dev/null +++ b/applications/Colossal-LLaMA-2/README.md @@ -0,0 +1,377 @@ +
+ +## Table of Contents +- [News](#news) +- [Colossal-LLaMA-2-7B](#colossal-llama-2-7b) + - [Performance Evaluation](#performance-evaluation) + - [Examples](#examples) + - [Training Logs](#training-logs) + - [Import from Transformers](#import-from-transformers) +- [Usage](#usage) + - [Install](#install) + - [How to run](#how-to-run) +- [Technical Insight](#technical-insights) + - [Data](#data) + - [Tokenizer](#tokenizer) + - [Training Strategy](#training-strategy) +- [Citations](#citations) + +## News +* [2023/09] 🔥 TODO We released **Colossal-LLaMA-2-7B-base** based on LLaMA-2. [Download weights](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-7b-base). + +## Colossal-LLaMA-2-7B +The [Colossal-AI](https://github.com/hpcaitech/ColossalAI) team has introduced the open-source model **Colossal-LLaMA-2-7B-base**. This model, a derivation of LLaMA-2, has undergone continual pre-training involving approximately 8.5 billion tokens over a duration of 15 hours with 64 A800 GPUs. At a cost of **less than $1,000**, you can achieve results **similar to those that cost millions of dollars to pretrain from scratch**. It is licensed under the LLaMA-2 license and [Apache 2.0 License](https://github.com/hpcaitech/ColossalAI/blob/main/LICENSE) **without any additional commercial use restrictions**. This solution can also be used to build models of specific domain knowledge or tasks. + +Colossal-LLaMA-2-7B-base is designed to accommodate both the Chinese and English languages, featuring an expansive context window spanning 4096 tokens. Remarkably, it has exhibited exceptional performance when benchmarked against models of equivalent scale in standard Chinese and English evaluation metrics, including C-Eval and MMLU, among others. + +### Performance Evaluation +We conducted comprehensive evaluation on 4 dataset and compare our Colossal-Llama-2-7b-base model with various models. + +* We use 5-shot for MMLU and calculate scores based on the logits of first predicted token. +* We use 5-shot for CMMLU and calculate scores based on the logits of first predicted token. +* We use 5-shot for AGIEval and only calculate scores for 4-choice questions using a combination metric of exact match and the logits of first predicted token. If any of the exact match or logits of first predicted token is correct, the model will get the score. +* We use 0-shot for GAOKAO-Bench and only calculate scores for 4-choice questions based on the logits of first predicted token. +The generation config for all dataset is greedy search. +* We also provided CEval scores from its lastest leaderboard or the official repository of the model. 
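+
+The first-predicted-token scoring described above can be sketched as follows. This is an illustrative snippet only, not the exact evaluation harness (few-shot prompt construction and per-model choice-token handling are simplified), and the released base checkpoint is used purely as an example model:
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("hpcai-tech/Colossal-LLaMA-2-7b-base", trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained(
+    "hpcai-tech/Colossal-LLaMA-2-7b-base", torch_dtype=torch.float16, device_map="auto", trust_remote_code=True
+)
+
+def predict_choice(prompt: str, choices=("A", "B", "C", "D")) -> str:
+    # Compare the logits of the candidate choice tokens at the first generated position.
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+    with torch.no_grad():
+        next_token_logits = model(**inputs).logits[0, -1]
+    choice_ids = [tokenizer(c, add_special_tokens=False)["input_ids"][-1] for c in choices]
+    return choices[int(next_token_logits[choice_ids].argmax())]
+
+print(predict_choice("Question: ...\nA. ...\nB. ...\nC. ...\nD. ...\nAnswer:"))
+```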
+ +| | Backbone | Tokens Consumed | | MMLU | CMMLU | AGIEval | GAOKAO | CEval | +| :----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :-----: | :----: | :----: | :------------------------------: | +| | | - | | 5-shot | 5-shot | 5-shot | 0-shot | 5-shot | +| Baichuan-7B | - | 1.2T | | 42.32 (42.30) | 44.53 (44.02) | 38.72 | 36.74 | 42.80 | +| Baichuan-13B-Base | - | 1.4T | | 50.51 (51.60) | 55.73 (55.30) | 47.20 | 51.41 | 53.60 | +| Baichuan2-7B-Base | - | 2.6T | | 46.97 (54.16) | 57.67 (57.07) | 45.76 | 52.60 | 54.00 | +| Baichuan2-13B-Base | - | 2.6T | | 54.84 (59.17) | 62.62 (61.97) | 52.08 | 58.25 | 58.10 | +| ChatGLM-6B | - | 1.0T | | 39.67 (40.63) | 41.17 (-) | 40.10 | 36.53 | 38.90 | +| ChatGLM2-6B | - | 1.4T | | 44.74 (45.46) | 49.40 (-) | 46.36 | 45.49 | 51.70 | +| InternLM-7B | - | - | | 46.70 (51.00) | 52.00 (-) | 44.77 | 61.64 | 52.80 | +| Qwen-7B | - | 2.2T | | 54.29 (56.70) | 56.03 (58.80) | 52.47 | 56.42 | 59.60 | +| | | | | | | | | | +| Llama-2-7B | - | 2.0T | | 44.47 (45.30) | 32.97 (-) | 32.60 | 25.46 | - | +| Linly-AI/Chinese-LLaMA-2-7B-hf | Llama-2-7B | 1.0T | | 37.43 | 29.92 | 32.00 | 27.57 | - | +| wenge-research/yayi-7b-llama2 | Llama-2-7B | - | | 38.56 | 31.52 | 30.99 | 25.95 | - | +| ziqingyang/chinese-llama-2-7b | Llama-2-7B | - | | 33.86 | 34.69 | 34.52 | 25.18 | 34.2 | +| TigerResearch/tigerbot-7b-base | Llama-2-7B | 0.3T | | 43.73 | 42.04 | 37.64 | 30.61 | - | +| LinkSoul/Chinese-Llama-2-7b | Llama-2-7B | - | | 48.41 | 38.31 | 38.45 | 27.72 | - | +| FlagAlpha/Atom-7B | Llama-2-7B | 0.1T | | 49.96 | 41.10 | 39.83 | 33.00 | - | +| IDEA-CCNL/Ziya-LLaMA-13B-v1.1 | Llama-13B | 0.11T | | 50.25 | 40.99 | 40.04 | 30.54 | - | +| | | | | | | | | | +| **Colossal-LLaMA-2-7b-base** | Llama-2-7B | **0.0085T** | | 53.06 | 49.89 | 51.48 | 58.82 | 50.2 | + +> The score in parentheses corresponds to the scores in the official repository of the model. +> +> We use zero-shot for ChatGLM models. +> +> Qwen-7B is now inaccessible in Hugging Face, we are using the latest version of it before it was made inaccessible. Only for dataset MMLU, the prompt would be "xxx Answer:"(remove the space after ":") and we calculate the logits over " A", " B", " C" and " D" for Qwen-7B. Qwen-7B tends to be much more deterministic than other models. For example, the logits over " A" can be `-inf` and softmax would be exact `0`. +> +> For other models and other dataset, we calculate logits over "A", "B", "C" and "D". + +❗️ More details of the evaluation methods and reproduction of the results, please refer to [TODO: ColossalEval](). + +### Examples +| Question Type | Question |
Colossal-LLaMA-2-7b-base
| +| :------: | :-----------------------: | :----------- | +| Knowledge Quiz
(Tourism) | 请问武汉的最佳的旅行季节是什么时候? | 3-4月份到武汉赏樱,10月份到武汉看黄陂枫树。武汉的最佳的旅行季节是春秋,3-4月份到武汉赏樱,10月份到武汉看黄陂枫树。武汉,简称“汉”别称江城,是湖北省省会,中部六省唯一的副省级市,特大城市,中国中部地区的中心城市,全国重要的工业基地、科教基地和综合交通枢纽。 | +| Knowledge Quiz
(Computer Science) | 在计算机中,什么是RAM?与ROM有什么区别? | RAM是随机存取内存,可以随时读写,但是断电之后会丢失数据。存储数据比较快,但是需要不断的进行通电供电。ROM是只读内存,只能读取数据,不能写入数据。断电之后,存储的数据也不会丢失,但是需要不断进行通电。两者的价格和性能上会有区别,而且在使用的时候,RAM需要不断的供电,而ROM断电数据不丢失。 | +| Knowledge Quiz
(Literature) | 古代文学与现代文学有哪些区别和共通之处? | 古代文学与现代文学有三点区别,现代文学更加通俗易懂,具有可读性,而古代文学相对来说更加注重形式,缺乏可读性。现代文学和古代文学,同属于中华文学大花园中的一枝花蕾。但是它们由于历史的原因,有着不同的特点,具体表现在,现代文学有更丰富的描写方法,更强烈的主观色彩和更突出的抒情色彩,而古代文学的主要特征是在创作风格上的多样化。 | +| Emotional Analysis | 请判断这个句子的情绪是否正面:我迫不及待想见到她!!!”。 | 这一句是情绪是正向的。是在期待见面的喜悦情绪。 | +| Translation | 将以下句子翻译成英语。 我喜欢看电影和读书。 | I like watching movies and reading books. | +| Logical Reasoning | All books have the same weight, 10 books weigh 5kg, what is the weight of 2 books? | To find the weight of 2 books, we can use the following steps:

1. Determine the weight of 1 book.<br>2. Multiply the weight of 1 book by the number of books.<br><br>Step 1: Determine the weight of 1 book.<br>Since 10 books weigh 5kg, we can divide 5kg by 10 to find the weight of 1 book.<br>5kg ÷ 10 = 0.5kg<br><br>Step 2: Multiply the weight of 1 book by the number of books.<br>To find the weight of 2 books, we multiply the weight of 1 book by 2.<br>0.5kg × 2 = 1kg
So, the weight of 2 books is 1kg. | +| Information Extraction | The company's annual conference, featuring keynote speakers and exclusive product launches, will be held at the Los Angeles Convention Center from October 20th to October 23rd, 2021. Extract the date mentioned in the above sentence. | The date mentioned in the sentence is October 20th to October 23rd, 2021. | +| Error Correction | Please correct the syntax error and restructure the sentence: "Went to the store me and my friend." | The given sentence "Went to the store me and my friend." is incorrect due to a syntax error and it lacks a subject. The corrected version with proper grammar and sentence structure would be: "My friend and I went to the store." | + +❗️ More examples of question answering, please refer to [Colossal-LLaMA-2-7B-base Examples](docs/example.md). + +### Training Logs +We also recorded the training logs for the experiment + +

+ +❗️**Important**: Once you initialize the new model checkpoint, copy your new tokenizer files (`special_tokens_map.json`, `tokenizer.model` and `tokenizer_config.json`) to your new model folder. + +#### 3. Data Preparation +Raw data should be formatted as `jsonl` format. Each data point should have the following fields: +* `source` (str, compulsory): This part is ignored when calculating loss. Default can be empty. +* `target` (str, compulsory): Loss will be calculated. +* `category` (str, compulsory): Tags for each data point. + +Examples: +```JSON +{"source": "", "target": "Lionel Andrés Messi(Spanish pronunciation: [ljoˈnel anˈdɾes ˈmesi] (i); born 24 June 1987), also known as Leo Messi, is an Argentine professional footballer who plays as a forward for and captains both Major League Soccer club Inter Miami and the Argentina national team.", "category": "sports"} +{"source": "猜谜语:一身卷卷细毛,吃的青青野草,过了数九寒冬,无私献出白毛。(打一动物)", "target": "白羊", "category": "riddle"} +``` +You are allowed to customize the category tags or use `unknown` to define the category. + +Command to convert jsonl dataset to arrow format: +``` +python prepare_pretrain_dataset.py \ + --data_input_dirs ",," \ + --tokenizer_dir "" \ + --data_cache_dir "jsonl_to_arrow_cache" \ + --data_jsonl_output_dir "spliced_tokenized_output_jsonl" \ + --data_arrow_output_dir "spliced_tokenized_output_arrow" \ + --max_length 4096 \ + --num_spliced_dataset_bins 10 +``` +Here is details about CLI arguments: +* Source data directory: `data_input_dirs`. Each `` can have multiple file in `jsonl` format. +* Tokenzier directory: `tokenizer_dir`. Path to the tokenizer in Hugging Face format. +* Data cache directory: `data_cache_dir`. Directory to store Hugging Face data cache. Default case will create `cache` folder locally. +* Output directory for jsonl format: `data_jsonl_output_dir`. Output directory to store converted dataset in jsonl format. +* Output directory for arrow format: `data_arrow_output_dir`. Output directory to store converted dataset in arrow format, which can be used for training directly. +* Max length: `max_length`. Max length of spliced samples. Default value is 4096. +* Number of bins for each category: `num_spliced_dataset_bins`. Number of bins for each category, used for bucket-based training. + +#### 4. Command Line Arguments for Training +You can use `colossalai run` to launch multi-nodes training: +```bash +colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \ +pretrain.py --OTHER_CONFIGURATIONS +``` +Here is a sample hostfile: +```bash +hostname1 +hostname2 +hostname3 +hostname4 +``` +Make sure master node can access all nodes (including itself) by ssh without password. + +Here is details about CLI arguments: +* Pre-trained model path: `--pretrained`. Path to the pre-trained model in Hugging Face format. +* Dataset path: `--dataset`. Path to the pre-tokenized dataset. +* Booster plugin: `--plugin`. `gemini`, `gemini_auto`, `zero2`,`zero2_cpu` and `3d` are supported.For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins/). +* Intermediate checkpoint to load: `--load_checkpoint`. Path to the intermediate checkpoint. Saved checkpoint contains the states for `lr_scheduler`, `optimizer`,`running_states.json` and `modelling`. If `load_checkpoint` points to the `modelling` folder, only the model weights will be loaded without any other states to support multi-stage training. +* Save interval: `--save_interval`. The interval (steps) of saving checkpoints. 
The default value is 1000. +* Checkpoint directory: `--save_dir`. The directoty path to save checkpoint and intermediate states. Intermediate states include `lr_scheduler`, `optimizer`,`running_states.json` and `modelling`. +* Tensorboard directory: `--tensorboard_dir`. The path to save tensorboard logs. +* Configuration file: `--config_file`. The path to save the configuration file. +* Number of epochs: `--num_epochs`. Number of training epochs. The default value is 1. +* Micro batch size: `--micro_batch_size`. Batch size per GPU. The default value is 1. +* Learning rate: `--lr`. The default value is 3e-4. +* Max length: `--max_length`. Max context length. The default value is 4096. +* Mixed precision: `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported. +* Gradient clipping: `--gradient_clipping`. The default value is 1.0. +* Weight decay: `-w`, `--weight_decay`. The default value is 0.1. +* Warmup steps: `-s`, `--warmup_steps`. The default value is calcuated by 0.025 warmup ratio. +* Gradient checkpointing: `--use_grad_checkpoint`. The default value is `False`. This saves memory at the cost of speed. You'd better enable this option when training with a large batch size. +* Flash attention: `--use_flash_attn`. If you want to use flash attention, you must install `flash-attn` and related packages. The default value is `False`. This is helpful to accelerate training while saving memory. We recommend you always use flash attention. +* Freeze non-embedding parameters: `--freeze_non_embeds_params`. Freeze non-embedding parameters. It can be helpful to align embeddings after extending vocabulary size. +* Tensor parallelism size: `--tp`. TP size for 3d Parallelism. The default value is 1. +* Zero stage: `--zero`. Zero stage for 3d Parallelism. The default value is 1. + +#### 5. Running Command +An [example bash](train.example.sh) is also provided for the experiment. Here is the steps to run the experiment: +* Create your own hostfile: `cp hostfile.example hostfile`. +* Create your own bash: `cp train.example.sh train.sh`. +* Add your real host ip or host name into the `hostfile`. +* Update global variables and parameters in your `train.sh`. +* Run the experiment by `bash train.sh` + +Here is the details about global variables for each experiment: +* `PROJECT_NAME`: Project name for each experiment. +* `PARENT_SAVE_DIR`: Parent folder to save model checkpoint. +* `PARENT_TENSORBOARD_DIR`: Parent folder to save tensorboard logs. +* `PARENT_CONFIG_FILE`: Parent folder to save configuration for each experiment. +* `PRETRAINED_MODEL_PATH`: Path to the local pre-trained model checkpoint. +* `dataset`: Paths to all prepared data. Typically, it's a list of subfolders within the output path of prepare data, `--data_arrow_output_dir`, and if there are multiple subfolders, please list them all. e.g., +```python +declare -a dataset=( + "/part-00000" + "/part-00001" + "/part-00000" +) +``` +## Technical Insights +In order to enhance LLaMA-2's capabilities for understanding and generating Chinese content, The [Colossal-AI](https://github.com/hpcaitech/ColossalAI) team proposes the continuation of pre-training the LLaMA-2 model using both Chinese and English corpora. The overall pipeline can be described as follows: + +

+ +

+ +### Data +Large language models such as LLaMA-2 have undergone training using a heterogeneous blend of high-quality datasets, yielding promising outcomes. Enhancing LLaMA-2's performance for the Chinese corpus, while preserving its proficiency in English, critically hinges on two pivotal factors: the composition of the dataset, which encompasses both English and Chinese content, and the quality of each constituent dataset. + +The following figure shows the data processing pipeline conducted for Colossal-LLaMA-2. +

+ +

+ +❗️**Important**: We will open-source our data-processing toolkit soon, stay tuned! + +### Tokenizer +The original LLaMA-2 vacabulary comprises fewer than a thousand Chinese characters, thus proves inadequate for encoding comprehensive Chinese texts effectively. Secondly, the utilization of byte tokens presents a challenge for transformer encoders to capture the semantic nuances of Chinese characters. + +To address the above issues, we extend LLaMA-2 vocabulary from 32,000 to 69,104. To adapt the LLaMA-2 model for use with the Colossal-LLaMA-2 tokenizer, we initialize the new word embeddings by calculating the mean values from the original LLaMA-2 embeddings and subsequently append these new rows to the end of the original embedding matrices. + +Advantages of extending vocabulary size: +* Improve the compression rate of string sequence encoding. +* Enhance the integrity of information. +* Enable encoded sequences to contain more valuable information, thereby theoretically enhancing the ability for chapter-level encoding. + +Advantages of large vocabulary size under low-resource settings: +* The presence of numerous unused tokens can be attributed to the limited training dataset, where an excessive number of tokens might not have been effectively learned. +* Excessive vocabulary expansion leads to an increase in embedding-related parameters, resulting in higher memory usage, which, in turn, affects the efficiency of the training process. + +To balance both sides, we finally construct our vocabulary with size 69,104. The following table below presents a comparison of various models at the 7B level. + +| Model | Vocabulary Size | Compression Rate | Average Length of Samples (token-level) | +| :-----------: | :---------: | :----: | :----: | +| Colossal-LLaMA-2 | 69104 | 0.659 | 73.682 | +| LLaMA-2-7B | 32000 | 1.205 | 134.689 | +| Atom-7B | 65000 | 0.634 | 70.915 | +| Baichuan-7B | 64000 | 0.678 | 75.857 | +| Baichuan2-7B-base | 125696 | 0.570 | 63.761 | +| Chatglm2-6B | 64789 | 0.645 | 72.178 | +| InternLM-7B | 103168 | 0.566 | 63.349 | +| Qwen-7B | 151643 | 0.578 | 64.703 | +| Tigerbot-7B-base | 60515 | 0.630 | 70.515 | +| Yayi-7B-llama2 | 32005 | 1.214 | 135.689 | +| Chinese-llama-2-7b | 55296 | 0.668 | 74.690 | +| Chinese-Falcon-7B | 90046 | 0.669 | 74.858 | +| LinkSoul-Chinese-Llama-2-7b | 40076 | 0.958 | 107.089 | +| Ziya-LLaMA-13B-v1.1 | 39410 | 0.958 | 107.074 | + + +### Training Strategy +#### Multi-stage Training +In order to enhance the model's performance and harness the full potential of the original LLaMA-2, we have developed a multi-stage training strategy. This strategy is designed to systematically unlock the model's capabilities over a series of stages. + +Therefore, we have divided the training process into three stages: +* Large-scale pre-training stage (Conducted by LLaMA-2): This initial stage is aimed at establishing the model's foundational capabilities from the ground up. It necessitates the use of a substantial dataset comprising no less than 1 trillion tokens. +* Chinese knowledge injection stage: In this stage, we introduce Chinese knowledge into the model. It requires access to a high-quality dataset rich in comprehensive knowledge relevant to the Chinese language. +* Knowledge replay stage: Knowledge is replayed through a question-answering (QA) mechanism, encompassing both the Chinese and English domains. 
+ +Following the completion of this multi-stage training process, the model exhibits notable improvements in performance across both English and Chinese benchmarks. + +The following figure illustrates the three stages for training Colossal-LLaMA-2. + +

+ +

+ +#### Bucket-based Training +Our experiments have revealed that the distributions within the training dataset, as well as the arrangement of various topic-related data points, significantly impact the overall performance of the model, particularly in the context of continual pre-training of LLaMA-2. + +In an effort to achieve a more balanced distribution and exert control over the dataset's ordering, we have adopted a method where we divide each sub-dataset into discrete bins. These bins are then combined to construct individual data buckets, with one bin contributed by each sub-dataset. + +## Citations +```bibtex +@article{bian2021colossal, + title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training}, + author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang}, + journal={arXiv preprint arXiv:2110.14883}, + year={2021} +} +``` +```bibtex +@misc{touvron2023llama, + title={Llama 2: Open Foundation and Fine-Tuned Chat Models}, + author={Hugo Touvron and Louis Martin and Kevin Stone and Peter Albert and Amjad Almahairi and Yasmine Babaei and Nikolay Bashlykov and Soumya Batra and Prajjwal Bhargava and Shruti Bhosale and Dan Bikel and Lukas Blecher and Cristian Canton Ferrer and Moya Chen and Guillem Cucurull and David Esiobu and Jude Fernandes and Jeremy Fu and Wenyin Fu and Brian Fuller and Cynthia Gao and Vedanuj Goswami and Naman Goyal and Anthony Hartshorn and Saghar Hosseini and Rui Hou and Hakan Inan and Marcin Kardas and Viktor Kerkez and Madian Khabsa and Isabel Kloumann and Artem Korenev and Punit Singh Koura and Marie-Anne Lachaux and Thibaut Lavril and Jenya Lee and Diana Liskovich and Yinghai Lu and Yuning Mao and Xavier Martinet and Todor Mihaylov and Pushkar Mishra and Igor Molybog and Yixin Nie and Andrew Poulton and Jeremy Reizenstein and Rashi Rungta and Kalyan Saladi and Alan Schelten and Ruan Silva and Eric Michael Smith and Ranjan Subramanian and Xiaoqing Ellen Tan and Binh Tang and Ross Taylor and Adina Williams and Jian Xiang Kuan and Puxin Xu and Zheng Yan and Iliyan Zarov and Yuchen Zhang and Angela Fan and Melanie Kambadur and Sharan Narang and Aurelien Rodriguez and Robert Stojnic and Sergey Edunov and Thomas Scialom}, + year={2023}, + eprint={2307.09288}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +``` +```bibtex +@article{dao2023flashattention2, + title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning}, + author={Dao, Tri}, + year={2023} +} +} +``` + + diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/__init__.py b/applications/Colossal-LLaMA-2/colossal_llama2/__init__.py new file mode 100644 index 000000000000..56fafa58b3f4 --- /dev/null +++ b/applications/Colossal-LLaMA-2/colossal_llama2/__init__.py @@ -0,0 +1,2 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/__init__.py b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/__init__.py new file mode 100644 index 000000000000..56fafa58b3f4 --- /dev/null +++ b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/__init__.py @@ -0,0 +1,2 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py new file mode 100644 index 000000000000..a2cfb2ef6264 --- /dev/null +++ b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py @@ -0,0 +1,219 @@ 
+#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import numpy as np +import os +import random +from dataclasses import dataclass +from typing import Dict, List, Union, Sequence, Optional, Iterator, Callable + +import torch +from datasets import dataset_dict, load_from_disk +from datasets import Dataset as HFDataset +from torch.distributed import ProcessGroup +from torch.distributed.distributed_c10d import _get_default_group +from torch.utils.data import ConcatDataset, Dataset, DataLoader, DistributedSampler +from transformers.tokenization_utils import PreTrainedTokenizer +import torch.nn.functional as F + +DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset] +PathType = Union[str, os.PathLike] + + +def load_tokenized_dataset( + dataset_paths: Union[PathType, List[PathType]], mode: str = "train" +) -> Optional[DatasetType]: + """ + Load pre-tokenized dataset. + Each instance of dataset is a dictionary with + `{'input_ids': List[int], 'labels': List[int], sequence: str}` format. + """ + mode_map = {"train": "train", "dev": "validation", "test": "test"} + assert mode in tuple(mode_map), f"Unsupported mode {mode}, it must be in {tuple(mode_map)}" + + if isinstance(dataset_paths, (str, os.PathLike)): + dataset_paths = [dataset_paths] + + datasets = [] # `List[datasets.dataset_dict.Dataset]` + for ds_path in dataset_paths: + ds_path = os.path.abspath(ds_path) + assert os.path.exists(ds_path), f"Not existed file path {ds_path}" + ds_dict = load_from_disk(dataset_path=ds_path, keep_in_memory=False) + if isinstance(ds_dict, HFDataset): + datasets.append(ds_dict) + else: + if mode_map[mode] in ds_dict: + datasets.append(ds_dict[mode_map[mode]]) + if len(datasets) == 0: + return None + if len(datasets) == 1: + return datasets.pop() + return ConcatDataset(datasets=datasets) + + +@dataclass +class DataCollatorForSupervisedDataset(object): + """ + Collate instances for supervised dataset. + Each instance is a tokenized dictionary with fields + `input_ids`(List[int]), `labels`(List[int]) and `sequence`(str). + """ + + tokenizer: PreTrainedTokenizer + max_length: int = 4096 + ignore_index: int = -100 + + def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]: + """ + + Args: + instances (`Sequence[Dict[str, List[int]]]`): + Mini-batch samples, each sample is stored in an individual dictionary. + + Returns: + (`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`: + `input_ids`: `torch.Tensor` of shape (bsz, max_len); + `attention_mask`: `torch.BoolTensor` of shape (bsz, max_len); + `labels`: `torch.Tensor` of shape (bsz, max_len), which contains `IGNORE_INDEX`. 
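+
+        Note (from the implementation below): with right padding the batch is additionally padded
+        up to ``max_length``, while with left padding it is only padded to the longest sequence
+        in the mini-batch.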
+ """ + assert isinstance(self.tokenizer.pad_token_id, int) and self.tokenizer.pad_token_id >= 0, ( + f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, " + f"but now `{self.tokenizer.pad_token_id}`" + ) + + # `List[torch.Tensor]` + batch_input_ids = [ + torch.LongTensor(instance["input_ids"][: self.max_length]) + if len(instance["input_ids"]) > self.max_length + else torch.LongTensor(instance["input_ids"]) + for instance in instances + ] + batch_labels = [ + torch.LongTensor(instance["labels"][: self.max_length]) + if len(instance["labels"]) > self.max_length + else torch.LongTensor(instance["labels"]) + for instance in instances + ] + + if self.tokenizer.padding_side == "right": + input_ids = torch.nn.utils.rnn.pad_sequence( + sequences=batch_input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id, + ) # (bsz, max_len) + labels = torch.nn.utils.rnn.pad_sequence( + sequences=batch_labels, + batch_first=True, + padding_value=self.ignore_index, + ) # (bsz, max_len) + # pad to max + to_pad = self.max_length - input_ids.size(1) + input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id) + labels = F.pad(labels, (0, to_pad), value=self.ignore_index) + elif self.tokenizer.padding_side == "left": + reversed_input_ids = [seq.flip(dims=(0,)) for seq in batch_input_ids] + reversed_input_ids = torch.nn.utils.rnn.pad_sequence( + sequences=reversed_input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id, + ) # (bsz, max_len) + input_ids = torch.flip(reversed_input_ids, dims=(1,)) # (bsz, max_len) + reversed_labels = [seq.flip(dims=(0,)) for seq in batch_labels] + reversed_labels = torch.nn.utils.rnn.pad_sequence( + sequences=reversed_labels, + batch_first=True, + padding_value=self.ignore_index, + ) # (bsz, max_len) + labels = torch.flip(reversed_labels, dims=(1,)) # (bsz, max_len) + else: + raise RuntimeError( + f"`{self.tokenizer.__class__.__name__}.padding_side` can only be `left` or `right`, " + f"but now `{self.tokenizer.padding_side}`" + ) + + attention_mask = input_ids.ne(self.tokenizer.pad_token_id) # `torch.BoolTensor`, (bsz, max_len) + + return dict(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + + +class StatefulDistributedSampler(DistributedSampler): + """ + Stateful distributed sampler for multi-stage training. + """ + + def __init__( + self, + dataset: DatasetType, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + ) -> None: + super().__init__( + dataset=dataset, + num_replicas=num_replicas, + rank=rank, + shuffle=shuffle, + seed=seed, + drop_last=drop_last, + ) + self.start_index = 0 + + def __iter__(self) -> Iterator: + iterator = super().__iter__() + indices = list(iterator) + indices = indices[self.start_index :] + return iter(indices) + + def __len__(self) -> int: + return self.num_samples - self.start_index + + def set_start_index(self, start_index: int) -> None: + self.start_index = start_index + + +def setup_distributed_dataloader( + dataset: DatasetType, + batch_size: int = 1, + shuffle: bool = False, + seed: int = 1024, + drop_last: bool = False, + pin_memory: bool = False, + num_workers: int = 0, + collate_fn: Callable[[Sequence[Dict[str, Union[str, List[int]]]]], Dict[str, torch.Tensor]] = None, + process_group: Optional[ProcessGroup] = None, + **kwargs, +) -> DataLoader: + """ + Setup dataloader for distributed training. 
+ """ + _kwargs = kwargs.copy() + process_group = process_group or _get_default_group() + sampler = StatefulDistributedSampler( + dataset=dataset, + num_replicas=process_group.size(), + rank=process_group.rank(), + shuffle=shuffle, + seed=seed, + drop_last=drop_last, + ) + + # Deterministic dataloader + def seed_worker(worker_id: int) -> None: + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return DataLoader( + dataset=dataset, + batch_size=batch_size, + sampler=sampler, + num_workers=num_workers, + collate_fn=collate_fn, + pin_memory=pin_memory, + drop_last=drop_last, + worker_init_fn=seed_worker, + **_kwargs, + ) diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/spliced_and_tokenized_dataset.py b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/spliced_and_tokenized_dataset.py new file mode 100644 index 000000000000..0c21f325ae62 --- /dev/null +++ b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/spliced_and_tokenized_dataset.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Splicing multiple pre-tokenized sequence data points +""" + +import random +import warnings +from copy import deepcopy +from datasets import dataset_dict +from typing import Any, Callable, Dict, Iterable, List, Union, Tuple + +from torch.utils.data import ConcatDataset, Dataset, IterableDataset +from transformers.models.llama.tokenization_llama import LlamaTokenizer +from transformers.tokenization_utils import PreTrainedTokenizer + +IGNORE_INDEX = -100 + +DSType = Union[Dataset, ConcatDataset, dataset_dict.Dataset] + + +def supervised_tokenize( + data_point: Dict[str, str], tokenizer: LlamaTokenizer, ignore_index: int = None, max_length: int = 4096 +) -> Dict[str, Union[int, str, List[int]]]: + """ + A tokenization function to tokenize an original pretraining data point as following: + {"source": "", "target": "Beijing, the capital of the People's Republic of China, ...", "category": "geography"} + """ + assert tokenizer.add_bos_token is False and tokenizer.add_eos_token is False, ( + "Initially set `tokenizer.add_bos_token` and `tokenizer.add_eos_token` to False, " + "add and manually later" + ) + if ignore_index is None: + ignore_index = IGNORE_INDEX + + source_text = data_point["source"] # `str` + target_text = data_point["target"] # `str` + is_null_source = len(source_text) == 0 + + source_text = tokenizer.bos_token + source_text + target_text += tokenizer.eos_token + sequence_text = source_text + target_text + + tokenized = tokenizer([source_text, sequence_text])["input_ids"] + sequence_input_ids = tokenized[1] + sequence_labels = deepcopy(sequence_input_ids) + + source_length = len(tokenized[0]) + if not is_null_source: + sequence_labels[:source_length] = [ignore_index for _ in range(source_length)] + + # sequence truncation. + if len(sequence_input_ids) > max_length: + sequence_input_ids = sequence_input_ids[:max_length] + sequence_labels = sequence_labels[:max_length] + + return dict( + input_ids=sequence_input_ids, + labels=sequence_labels, + seq_length=len(sequence_input_ids), + seq_category=data_point["category"], + ) + + +class ClosedToConstantLengthSplicedDataset(IterableDataset): + """ + Define an iterable dataset that returns a (close to) constant length data point spliced from multiple + original independent (pre-tokenized) data points. 
+ """ + + def __init__( + self, + dataset: DSType, + tokenizer: PreTrainedTokenizer, + max_length: int = 4096, + num_packed_sequences: int = 8, + fetch_sequence_func: Callable[[Any], Tuple[List[int], List[int]]] = None, + input_ids_field: str = "input_ids", + labels_field: str = "labels", + infinite: bool = False, + shuffle: bool = True, + error_strict: bool = False, + ) -> None: + self.tokenizer = tokenizer + self.dataset = dataset + self.max_length = max_length + self.infinite = infinite + self.max_buffer_size = max_length * num_packed_sequences # e.g., 4096 * 16 + self.shuffle = shuffle + + # Callable[[Dict[str, Any]], Tuple[List[int], List[int]]], + # A function that fetch sequence input_ids and labels from the original data point + if fetch_sequence_func is None: + self.fetch_sequence_func = lambda data_point: (data_point[input_ids_field], data_point[labels_field]) + else: + self.fetch_sequence_func = fetch_sequence_func + self.input_ids_field = input_ids_field + self.labels_field = labels_field + + self.error_strict = error_strict + self.current_size = 0 # `int`, current packed data size. + + def __len__(self) -> int: + return len(self.dataset) + + def __iter__(self) -> Iterable[Dict[str, List[int]]]: + iterator = iter(self.dataset) + more_data_points = True + while more_data_points is True: + buffer, buffer_len = [], 0 + while True: + # ending condition. + if buffer_len >= self.max_buffer_size: + break + try: + # `Tuple[List[int], List[int]]` + seq_input_ids, seq_labels = self.fetch_sequence_func(next(iterator)) + buffer.append({self.input_ids_field: seq_input_ids, self.labels_field: seq_labels}) + buffer_len += len(buffer[-1][self.input_ids_field]) + except StopIteration: + if self.infinite is True: + iterator = iter(self.dataset) + warnings.warn("The dataset reached end and the iterator is reset to the start.") + else: + more_data_points = False + break + examples = [] # `List[Dict[str, List[int]]]`, save buffered spliced data points. + spliced_input_ids, spliced_labels = [], [] # `List[int]`, `List[int]` + for i, data_point in enumerate(buffer): + # TODO(2023-09-18) check errors for each unspliced tokenized data point + seq_input_ids = data_point[self.input_ids_field] + seq_labels = data_point[self.labels_field] + # Handle special case: + # If the length of an original data point (i.e., input_ids length of a data point before splicing) + # exceeds `max_length`, truncate it. + if len(seq_input_ids) > self.max_length: + truncated_seq_input_ids = seq_input_ids[: self.max_length] + truncated_label_ids = seq_labels[: self.max_length] + if set(truncated_label_ids) == {IGNORE_INDEX}: + if self.error_strict is True: + raise ValueError( + f"Find an out-of-bounds length({len(seq_input_ids)}) data point " + f"with all label values as {IGNORE_INDEX}." + ) + else: + warnings.warn(f"Filter an error truncated data point (labels all {IGNORE_INDEX})") + continue # Skip the current error data point. + spliced_data_point = { + self.input_ids_field: truncated_seq_input_ids, + self.labels_field: truncated_label_ids, + } + examples.append(spliced_data_point) + warnings.warn("Find a data point to be truncated.") + continue + + # Pre action judgment. + if len(spliced_input_ids) + len(seq_input_ids) > self.max_length: + spliced_data_point = { + self.input_ids_field: spliced_input_ids, + self.labels_field: spliced_labels, + } # `Dict[str, List[int]]` + # Update. 
+ spliced_input_ids, spliced_labels = [], [] + spliced_input_ids.extend(seq_input_ids) + spliced_labels.extend(seq_labels) + examples.append(spliced_data_point) + else: + spliced_input_ids.extend(seq_input_ids) + spliced_labels.extend(seq_labels) + # For residual spliced data point at the end of the data set + if self.infinite is False and more_data_points is False and len(spliced_input_ids) > 0: + examples.append( + { + self.input_ids_field: spliced_input_ids, + self.labels_field: spliced_labels + } + ) + if self.shuffle: + random.shuffle(examples) + for spliced_data_point in examples: + # TODO(2023-09-18): check errors for each spliced tokenized data point. + self.current_size += 1 + yield spliced_data_point diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/model/init_model.py b/applications/Colossal-LLaMA-2/colossal_llama2/model/init_model.py new file mode 100644 index 000000000000..67e487f43b08 --- /dev/null +++ b/applications/Colossal-LLaMA-2/colossal_llama2/model/init_model.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +Initialize new model with updated tokenizer by calculating the mean values from original model +""" +import argparse + +import numpy as np +import torch +from transformers import LlamaTokenizer, LlamaForCausalLM + +from colossalai.logging import get_dist_logger + + +logger = get_dist_logger() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--source_model_and_tokenizer_path", + type=str, + required=True, + default=None, + help="Source path of model & tokenizer", + ) + parser.add_argument("--target_tokenizer_path", type=str, required=True, default=None, help="Target tokenizer path") + parser.add_argument("--target_model_path", type=str, required=True, default=None, help="Target model path") + args = parser.parse_args() + + source_tokenizer = LlamaTokenizer.from_pretrained(args.source_model_and_tokenizer_path) + source_tokenizer.add_bos_token = False + source_tokenizer.add_eos_token = False + if source_tokenizer.pad_token is None: + source_tokenizer.pad_token = source_tokenizer.unk_token + source_vocab = source_tokenizer.get_vocab() + + target_tokenizer = LlamaTokenizer.from_pretrained(args.target_tokenizer_path) + target_tokenizer.add_bos_token = False + target_tokenizer.add_eos_token = False + if target_tokenizer.pad_token is None: + target_tokenizer.pad_token = target_tokenizer.unk_token + target_vocab = target_tokenizer.get_vocab() + target_inverted_vocab = {v: k for k, v in target_vocab.items()} + + assert len(target_vocab) > len( + source_vocab + ), f"Target vocab size({len(target_vocab)}) must be greater than source vocab size({len(source_vocab)})" + + gpu_device = torch.device("cuda:0") + cpu_device = torch.device("cpu") + + source_model = LlamaForCausalLM.from_pretrained(args.source_model_and_tokenizer_path) + source_model.eval() + source_model = source_model.to(gpu_device) + + source_input_embeddings = source_model.get_input_embeddings() + assert isinstance(source_input_embeddings, torch.nn.Embedding) + assert source_input_embeddings.weight.shape[0] == len(source_vocab) + source_input_embeddings.eval() + + source_output_embeddings = source_model.get_output_embeddings() + assert isinstance(source_output_embeddings, torch.nn.Linear) + assert source_output_embeddings.bias is None + assert source_output_embeddings.weight.shape[0] == len(source_vocab) + source_output_embeddings.eval() + + input_embeddings = source_input_embeddings.weight.cpu().detach().numpy() + output_embeddings = 
source_output_embeddings.weight.cpu().detach().numpy()
+    for i in range(len(source_vocab), len(target_vocab)):
+        if i % 500 == 0:
+            logger.info(f"Processing {i}/{len(target_vocab)} target tokens")
+        target_token = target_inverted_vocab[i]
+        target_to_source_token_ids = torch.LongTensor(source_tokenizer([target_token])["input_ids"][0])
+        target_to_source_token_ids = target_to_source_token_ids.to(gpu_device)
+
+        target_to_source_input_embedding = (
+            source_input_embeddings.weight[target_to_source_token_ids]
+            .mean(dim=0)
+            .unsqueeze(dim=0)
+            .cpu()
+            .detach()
+            .numpy()
+        )
+        target_to_source_output_embedding = (
+            source_output_embeddings.weight[target_to_source_token_ids]
+            .mean(dim=0)
+            .unsqueeze(dim=0)
+            .cpu()
+            .detach()
+            .numpy()
+        )
+
+        input_embeddings = np.concatenate((input_embeddings, target_to_source_input_embedding), axis=0)
+        output_embeddings = np.concatenate((output_embeddings, target_to_source_output_embedding), axis=0)
+
+    source_model = source_model.to(cpu_device)
+    assert isinstance(source_model, LlamaForCausalLM)
+
+    # Expand the embedding matrices to the target vocabulary size and fill in the new rows.
+    source_model.resize_token_embeddings(new_num_tokens=len(target_vocab))
+    source_model.model.embed_tokens.weight.data = torch.Tensor(input_embeddings)
+    source_model.lm_head.weight.data = torch.Tensor(output_embeddings)
+
+    source_model = source_model.half()
+    source_model.save_pretrained(save_directory=args.target_model_path)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py b/applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py
new file mode 100644
index 000000000000..43297633db1a
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py
@@ -0,0 +1,98 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+"""
+Initialize new tokenizer for continual pre-training
+"""
+
+import argparse
+import os
+import json
+from typing import List, Union
+
+from transformers.models.llama.tokenization_llama import LlamaTokenizer
+from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
+
+from colossalai.logging import get_dist_logger
+
+os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
+
+logger = get_dist_logger()
+
+
+def expand_vocab_tokenizer(
+    source_tokenizer_dir: Union[str, os.PathLike], target_tokenizer_dir: Union[str, os.PathLike], new_tokens: List[str]
+) -> None:
+    """Expand the tokenizer for continual pre-training."""
+    if os.path.exists(target_tokenizer_dir):
+        raise RuntimeError(f"Found existing directory {target_tokenizer_dir}")
+
+    source_tokenizer = LlamaTokenizer.from_pretrained(source_tokenizer_dir)
+    logger.info(source_tokenizer)
+    source_sp_processor = source_tokenizer.sp_model
+    source_spm = sp_pb2_model.ModelProto()
+    source_spm.ParseFromString(source_sp_processor.serialized_model_proto())
+
+    logger.info(f"Source tokenizer size: {len(source_sp_processor)}")
+
+    # Add new tokens to the source tokenizer.
+    source_spm_tokens = set([p.piece for p in source_spm.pieces])
+    for piece in new_tokens:
+        assert isinstance(piece, str), f"Invalid token ({piece}) of type {type(piece)}"
+        if piece in source_spm_tokens:
+            # Skip existing token.
+            continue
+        new_p = sp_pb2_model.ModelProto().SentencePiece()
+        new_p.piece = piece
+        new_p.score = 0
+        source_spm.pieces.append(new_p)
+    logger.info(f"Expanded vocab from {len(source_spm_tokens)} to {len(source_spm.pieces)}")
+
+    # Save
+    os.makedirs(target_tokenizer_dir)
+    target_tokenizer_model_path = os.path.join(target_tokenizer_dir, "tokenizer.model")
+    with open(file=target_tokenizer_model_path, mode="wb") as fp:
+        fp.write(source_spm.SerializeToString())
+
+    target_tokenizer = LlamaTokenizer(vocab_file=target_tokenizer_model_path)
+    target_tokenizer.save_pretrained(save_directory=target_tokenizer_dir)
+    logger.info(f"Successfully saved the expanded tokenizer to {target_tokenizer_dir}")
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--source_tokenizer_dir", type=str, required=True, default=None, help="Source tokenizer directory"
+    )
+    parser.add_argument(
+        "--target_tokenizer_dir", type=str, required=True, default=None, help="Target tokenizer directory"
+    )
+    parser.add_argument(
+        "--expand_tokens_file",
+        type=str,
+        required=True,
+        default=None,
+        help="Path of the file containing the tokens to be added",
+    )
+    args = parser.parse_args()
+
+    expand_tokens = []
+    with open(file=args.expand_tokens_file, mode="r", encoding="utf-8") as fp_reader:
+        for line in fp_reader:
+            item = json.loads(line)
+            # e.g., {"piece": "你好"}
+            token = item["piece"]
+            if token in expand_tokens:
+                continue
+            expand_tokens.append(token)
+    expand_tokens.sort(key=lambda t: len(t), reverse=False)
+
+    expand_vocab_tokenizer(
+        source_tokenizer_dir=args.source_tokenizer_dir,
+        target_tokenizer_dir=args.target_tokenizer_dir,
+        new_tokens=expand_tokens,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/__init__.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/__init__.py
new file mode 100644
index 000000000000..56fafa58b3f4
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/__init__.py
@@ -0,0 +1,2 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/ckpt_io.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/ckpt_io.py
new file mode 100644
index 000000000000..85decf37dd0b
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/ckpt_io.py
@@ -0,0 +1,88 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""
+Helper functions for IO
+"""
+
+import json
+import os
+from typing import Any, Dict, Tuple, Union
+
+import torch
+from torch.optim.optimizer import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+
+from colossalai.booster import Booster
+from colossalai.cluster import DistCoordinator
+
+
+def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
+    """
+    Load file in JSON format
+    """
+    with open(file=file_path, mode="r", encoding="utf-8") as fp:
+        return json.load(fp)
+
+
+def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None:
+    """
+    Save as JSON format
+    """
+    with open(file=file_path, mode="w", encoding="utf-8") as fp:
+        json.dump(data, fp=fp, ensure_ascii=False, indent=4)
+
+
+def save_checkpoint(
+    save_dir: Union[str, os.PathLike],
+    booster: Booster,
+    model: torch.nn.Module,
+    optimizer: Optimizer,
+    lr_scheduler: _LRScheduler,
+    epoch: int,
+    step: int,
+    batch_size: int,
+    coordinator: DistCoordinator,
+) -> None:
+    """
+    Save model checkpoint, optimizer, LR scheduler and intermediate running states.
+    """
+
+    save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
+    os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)
+
+    booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
+
+    booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
+    booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
+    running_states = {
+        "epoch": epoch,
+        "step": step,
+        "sample_start_index": step * batch_size,
+    }
+    if coordinator.is_master():
+        save_json(running_states, os.path.join(save_dir, "running_states.json"))
+
+
+def load_checkpoint(
+    load_dir: Union[str, os.PathLike],
+    booster: Booster,
+    model: torch.nn.Module,
+    optimizer: Optimizer,
+    lr_scheduler: _LRScheduler,
+) -> Tuple[int, int, int]:
+    """
+    Load model checkpoint, optimizer, LR scheduler and intermediate running states.
+    """
+
+    # Restore model, optimizer and LR scheduler states via the booster.
+    booster.load_model(model=model, checkpoint=os.path.join(load_dir, "modeling"))
+    booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
+    booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
+
+    running_states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
+    return (
+        running_states["epoch"],
+        running_states["step"],
+        running_states["sample_start_index"],
+    )
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py
new file mode 100644
index 000000000000..6c58c59307a6
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py
@@ -0,0 +1,216 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+from types import MethodType
+from typing import Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from transformers.models.llama.modeling_llama import (
+    LlamaRMSNorm,
+    LlamaAttention,
+    LlamaModel,
+    LlamaForCausalLM,
+    apply_rotary_pos_emb,
+    repeat_kv,
+)
+
+from colossalai.logging import get_dist_logger
+from einops import rearrange
+
+from flash_attn.bert_padding import pad_input, unpad_input
+from flash_attn.flash_attn_interface import (
+    flash_attn_func,
+    flash_attn_varlen_kvpacked_func,
+)
+from flash_attn.ops.rms_norm import rms_norm
+
+
+logger = get_dist_logger()
+
+
+def _prepare_decoder_attention_mask(
+    self: LlamaModel,
+    attention_mask: torch.BoolTensor,
+    input_shape: torch.Size,
+    inputs_embeds: torch.Tensor,
+    past_key_values_length: int,
+) -> Optional[torch.Tensor]:
+    """
+    Decoder attention mask
+    """
+    if past_key_values_length > 0 and attention_mask is not None:
+        attention_mask = torch.cat(
+            tensors=(
+                torch.full(
+                    size=(input_shape[0], past_key_values_length),
+                    fill_value=True,
+                    dtype=attention_mask.dtype,
+                    device=attention_mask.device,
+                ),
+                attention_mask,
+            ),
+            dim=-1,
+        )  # (bsz, past_key_values_length + q_len)
+    if attention_mask is not None and torch.all(attention_mask):
+        return None  # Faster
+    return attention_mask
+
+
+def attention_forward(
+    self: LlamaAttention,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    """
+    Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
+ """ + if output_attentions: + logger.warning( + "Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, " + "return `None` instead." + ) + + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + q_slicing, kv_slicing = ( + dim // self.config.pretraining_tp + for dim in ( + self.num_heads * self.head_dim, + self.num_key_value_heads * self.head_dim, + ) + ) # `Tuple[int, int]` + q_slices, k_slices, v_slices = ( + proj.weight.split(slicing, dim=0) + for proj, slicing in ( + (self.q_proj, q_slicing), + (self.k_proj, kv_slicing), + (self.v_proj, kv_slicing), + ) + ) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]] + q, k, v = ( + torch.cat( + [F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)], + dim=-1, + ) + for slices in (q_slices, k_slices, v_slices) + ) + # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape: + # (bsz, q_len, num_heads * head_dim), + # (bsz, q_len, num_key_value_heads * head_dim), + # (bsz, q_len, num_key_value_heads * head_dim) + else: + q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj)) + # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape: + # (bsz, q_len, num_heads * head_dim), + # (bsz, q_len, num_key_value_heads * head_dim), + # (bsz, q_len, num_key_value_heads * head_dim) + + # (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim); + # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim); + # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim) + q, k, v = ( + states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2) + for states, num_heads in ( + (q, self.num_heads), + (k, self.num_key_value_heads), + (v, self.num_key_value_heads), + ) + ) + kv_len = k.shape[-2] # initially, `kv_len` == `q_len` + past_kv_len = 0 + if past_key_value is not None: + # if `past_key_value` is not None, `kv_len` > `q_len`. 
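+        # With a KV cache, the new queries attend over both the cached and the new tokens,
+        # so `kv_len` is extended below to `past_kv_len + q_len` (and the rotary embeddings
+        # are computed over that full length).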
+ past_kv_len = past_key_value[0].shape[-2] + kv_len += past_kv_len + + # two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim) + cos, sin = self.rotary_emb(v, seq_len=kv_len) + # (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim) + q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids) + if past_key_value is not None: + # reuse k, v, self_attention + k = torch.cat([past_key_value[0], k], dim=2) + v = torch.cat([past_key_value[1], v], dim=2) + + past_key_value = (k, v) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups) + # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim) + v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups) + # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim) + + key_padding_mask = attention_mask + # (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim) + q, k, v = (states.transpose(1, 2) for states in (q, k, v)) + + if past_kv_len > 0: + q = torch.cat( + tensors=( + torch.full( + size=(bsz, past_kv_len, self.num_heads, self.head_dim), + fill_value=0.0, + dtype=q.dtype, + device=q.device, + ), + q, + ), + dim=1, + ) # (bsz, past_kv_len + q_len, num_heads, head_dim) + + if key_padding_mask is None: + # (bsz, past_kv_len + q_len, num_heads, head_dim) + output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True) # (bsz, ) + output = rearrange(output, pattern="... h d -> ... (h d)") # (bsz, past_kv_len + q_len, num_heads * head_dim) + else: + q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask) + kv, _, cu_kv_lens, max_kv_len = unpad_input( + hidden_states=torch.stack(tensors=(k, v), dim=2), + attention_mask=key_padding_mask, + ) + output_unpad = flash_attn_varlen_kvpacked_func( + q=q, + kv=kv, + cu_seqlens_q=cu_q_lens, + cu_seqlens_k=cu_kv_lens, + max_seqlen_q=max_q_len, + max_seqlen_k=max_kv_len, + dropout_p=0.0, + softmax_scale=None, + causal=True, + ) + output = pad_input( + hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"), + indices=indices, + batch=bsz, + seqlen=past_kv_len + q_len, + ) # (bsz, past_kv_len + q_len, num_heads * head_dim) + + if past_kv_len > 0: + # Strip off the zero query outputs. + output = output[:, past_kv_len:, ...] 
# (bsz, q_len, num_heads * head_dim)
+    output = self.o_proj(output)  # (bsz, q_len, hidden_size)
+    return output, None, past_key_value
+
+
+def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor:
+    """
+    Forward function for RMS Norm
+    """
+    return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon)
+
+
+def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
+    for name, module in model.named_modules():
+        if isinstance(module, LlamaAttention):
+            module.forward = MethodType(attention_forward, module)
+        if isinstance(module, LlamaModel):
+            module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module)
+        if isinstance(module, LlamaRMSNorm):
+            module.forward = MethodType(rms_norm_forward, module)
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/froze.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/froze.py
new file mode 100644
index 000000000000..82677160d868
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/froze.py
@@ -0,0 +1,18 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+from transformers.models.llama import LlamaForCausalLM
+
+
+def freeze_non_embeds_parameters(model: LlamaForCausalLM) -> None:
+    """Freeze all parameters except embeddings."""
+    for name, params in model.named_parameters():
+        if "embed_tokens" not in name and "lm_head" not in name:
+            params.requires_grad = False
+        else:
+            params.requires_grad = True
+
+
+def unfreeze_parameters(model: LlamaForCausalLM) -> None:
+    """Unfreeze all parameters."""
+    for name, params in model.named_parameters():
+        params.requires_grad = True
diff --git a/applications/Colossal-LLaMA-2/docs/example.md b/applications/Colossal-LLaMA-2/docs/example.md
new file mode 100644
index 000000000000..d889ab4165d0
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/docs/example.md
@@ -0,0 +1,245 @@
+# Colossal-LLaMA-2-7B-base Examples
+To comprehensively assess the performance of the Colossal-LLaMA-2-7B-base model, our team conducted human evaluations across various knowledge domains and tasks. These tasks encompassed Knowledge QA in 10 different areas, Content Generation, Brainstorming, Summarization, Sentiment Analysis, Logical Reasoning, Information Extraction, Role-play, Classification, and Rewriting. We also conducted a comparative analysis, pitting the generation results of Colossal-LLaMA-2-7B-base against Qwen-7b-base, a recently pretrained model known for its impressive performance. It's worth noting that both models shared identical inference parameters. For each question, we generated model responses multiple times and selected the best one as the final answer.
+
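+The exact decoding configuration is not listed in this document. The minimal sketch below only illustrates how a single question can be posed to one of the models with a fixed set of inference parameters; the checkpoint path and sampling values are placeholders rather than the settings used for this evaluation, and the same parameters would simply be reused for the other model.
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_path = "path/to/model-under-evaluation"  # placeholder checkpoint path
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).cuda().eval()
+
+question = "中国最适合家庭度假的旅游目的地是哪里?"
+inputs = tokenizer(question, return_tensors="pt").to(model.device)
+
+# Assumed sampling parameters; the same values are reused for every model being compared.
+outputs = model.generate(**inputs, do_sample=True, temperature=0.7, top_p=0.95, max_new_tokens=512)
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+```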
+ +## Table of Contents +- [Examples](#knowledge-qa) + - [Knowledge QA](#knowledge-qa) + - [Travel](#travel) + - [Medical](#medical) + - [Food](#food) + - [Computer Science](#computer-science) + - [Game](#game) + - [Sports](#sports) + - [Finance](#finance) + - [Law](#law) + - [Technology](#technology) + - [Literature](#literature) + - [Content Generation](#content-generation) + - [Writing](#writing) + - [Code](#code) + - [E-mail](#e-mail) + - [LaTex Table](#latex-table) + - [Brainstorming](#brainstorming) + - [Summarization](#summarization) + - [Sentiment Analysis](#sentiment-analysis) + - [Logical Reasoning](#logical-reasoning) + - [Information Extraction](#information-extraction) + - [Role-play](#role-play) + - [Classification](#classification) + - [Rewriting](#rewriting) +- [Conclusion](#conclusion) + + +### Knowledge QA +#### Travel +| Question |
Colossal-LLaMA-2-7b-base
|
Qwen-7b-base
| +| :------: | :----------------------- | :----------- | +| 中国最适合家庭度假的旅游目的地是哪里? | 1. 中国最适合家庭度假的旅游目的地是海南。海南是中国唯一的热带岛屿省份,素有“东方夏威夷”之称。2. 三亚。三亚市位于海南岛的最南端,北与东方市交界,东与陵水县相邻,南与乐东县接壤,西濒南海,是一个风光绮丽的热带滨海旅游城市。 | 中国有非常多的适合家庭度假的目的地,以下是一些推荐:
1. 海南岛:海南岛有得天独厚的气候和风景,是一个度假的好去处,不仅有美丽的海滩和沙滩,还有很多可以游玩的景点,比如南山寺、五指山等。
2. 三亚:三亚是一个受欢迎的旅游城市,有美丽的海滩和浪漫的夕阳,还有丰富的海鲜美食。您可以在这里参加水上活动,比如滑水、帆板等。
3. 青岛:青岛是一个充满活力和美丽的城市,有美丽的海滩、博物馆、美食和啤酒,可以带孩子们一起去海边玩沙子,也可以去博物馆参观。
4. 成都:成都是中国西南地区的一个大城市,不仅有丰富的历史和文化,还有美食和美景。您可以去到大熊猫基地和锦里古街,还可以品尝到地道的川菜。
5. 长沙:长沙是一个历史悠久的城市,有丰富多彩的文化和历史景点,比如岳麓山、橘子洲头和岳麓书院等。此外,长沙的美食也是非常有名的,您可以在长沙尝到辣子鸡等当地美食。

以上是几个适合家庭度假的目的地,希望能给您提供一些参考。 | +| 请问武汉的最佳的旅行季节是什么时候? | 3-4月份到武汉赏樱,10月份到武汉看黄陂枫树。武汉的最佳的旅行季节是春秋,3-4月份到武汉赏樱,10月份到武汉看黄陂枫树。武汉,简称“汉”别称江城,是湖北省省会,中部六省唯一的副省级市,特大城市,中国中部地区的中心城市,全国重要的工业基地、科教基地和综合交通枢纽。 | Assistant:
武汉的最佳旅行季节是春季和秋季。春季气温宜人,可以赏樱花,秋季天气宜人,可以欣赏红叶。夏季温度较高,游客众多,冬季则有雪景和温泉可赏,但冬季的交通和住宿可能会更加拥挤。 | +| What are amazing islands near Singapore that could be visited? | Singapore is one of the most developed and well-connected cities in the world, with an impressive skyline and modern infrastructure. However, there are many amazing islands near Singapore that are worth visiting. Here are some of the most popular and beautiful islands near Singapore:

1. Sentosa Island: Sentosa Island is a small island located off the southern coast of Singapore. It is a popular tourist destination, known for its beaches, theme parks, and lush greenery. Visitors can enjoy a range of activities on the island, including sunbathing, swimming, and exploring the jungles.

2. Pulau Ubin: Pulau Ubin is a small island located about 45 minutes away from Singapore by boat. It is a popular destination for nature lovers, offering hiking trails, cycling paths, and a variety of wildlife. Visitors can also explore the fishing village and try their hand at fishing.

3. Pulau Tekong: Pulau Tekong is a small island located about 30 minutes away from Singapore by boat. It is a former military base, but is now a popular destination for nature lovers. Visitors can enjoy hiking trails, cycling paths, and a variety of wildlife.

4. Pulau Sibuyu: Pulau Sibuyu is a small island located about 20 minutes away from Singapore by boat. It is a popular destination for water sports enthusiasts, offering snorkeling, diving, and kayaking. Visitors can also explore the fishing village and try their hand at fishing.

5. Pulau Rapa: Pulau Rapa is a small island located about 20 minutes away from Singapore by boat. It is a popular destination for birdwatchers, offering a variety of bird species and birdwatching spots. Visitors can also explore the fishing village and try their hand at fishing.
Overall, these are some of the amazing islands near Singapore that are worth visiting. Whether you are looking for adventure, relaxation, or simply want to explore the natural beauty of Singapore, these islands offer something for everyone. | "What are some good islands near Singapore with a short distance? Which one is recommended?"

Please tell me if those questions are the same.
OPTIONS:
- no
- yes

no | +| Can you recommend local place to eat breakfast and dinner in Shanghai? | Yes, I can certainly recommend some great local places to eat breakfast and dinner in Shanghai. Here are some suggestions:

1. Dumpling House: This popular restaurant serves delicious dumplings and other Chinese dishes for breakfast and dinner. They have multiple locations throughout Shanghai, so you can easily find one near you.

2. Yuyuan Brewers Fresh Beer Restaurant: This unique restaurant offers a variety of fresh beers and delicious food, including breakfast and dinner options. It's a fun and relaxed atmosphere to enjoy your meal.

3. Zenbox Coffee & Bakery: This trendy coffee shop is the perfect spot to start your day with a delicious breakfast sandwich or pastry. They also have a great selection of lunch options, so you can continue enjoying your meal in the afternoon.

4. Shanghai Taste: This authentic Shanghai restaurant offers traditional Chinese dishes, including noodles and stir-fry dishes, for breakfast and dinner. The cozy atmosphere and friendly service make it a great choice for a meal.

5. XLR8 Family Cafe: This family-friendly cafe offers breakfast and dinner options, including a healthy salad bar. They also have a kids' menu with plenty of options for little ones.

These are just a few examples of the many great places to eat breakfast and dinner in Shanghai. With so many options to choose from, you're sure to find something that suits your taste and budget. | I am traveling on budget and will have access to public transportation to reach the location.

Assistant: Sure! Here are some local places to eat breakfast and dinner in Shanghai that you might want to consider:

Breakfast:

1. Fresh on the Square - They have a variety of breakfast options, including sandwiches, smoothies, and porridge. The prices are affordable, and they have a location at People's Square.
2. The Dumpling House - This restaurant is known for its delicious dumplings and offers an extensive menu of Chinese breakfast items, such as baozi, soy milk, and jianbing. The location is near the Jing'an Temple subway station.
3. Shanghainese Roast Duck - This restaurant is famous for its crispy duck, but they also have breakfast items like rice porridge, nian gao (sticky rice cake), and soy milk. The location is near the Xinzhuang Road subway station.

Dinner:

1. Lost Heaven - This restaurant offers a fusion of Chinese and Western cuisine in a beautiful garden setting. Prices are on the higher side, but the quality of food is excellent.
2. Shanghainese Restaurant - This restaurant is a must-try if you want to experience authentic Shanghai food. They have a variety of Shanghainese dishes, such as hairy crab, xiao long bao (steamed soup dumplings), and braised pork belly. The location is near the Hongqiao Railway Station.
3. Lao Da Zhou - This restaurant is known for its delicious xiao long bao and other Shanghainese dishes. Prices are reasonable, and they have a location near the Zhongshan Park subway station.

I hope this helps you plan your budget-friendly food itinerary in Shanghai! | +
+ + +#### Medical +| Question |
Colossal-LLaMA-2-7b-base
|
Qwen-7b-base
| +| :------: | :----------------------- | :----------- | +| 甲状腺手术后多久可以干活? | 甲状腺手术后三个月左右可以干活。甲状腺手术要根据患者的病情和手术方式来决定,如果患者的症状较轻,病情比较轻微,选择的是微创手术,一般手术后一个月就可以干活。若患者的症状比较严重,病情得不到控制,需要对甲状腺部位进行常规手术来控制病情,这种手术的创伤比较大,可能需要两个月左右才能完全恢复,恢复后才能干活。平时要注意不能让甲状腺部位受到外力刺激,以免发生疼痛和出血症状,饮食要以清淡为主,尽量选择易消化的食物来帮助身体恢复,避免摄入一些辛辣刺激类的食品,要多吃蔬菜水果,多喝温水。 | 这个问题需要根据具体情况而定。一般来说,甲状腺手术后需要至少休息2-4周,期间不能剧烈运动,同时还要注意术后恢复和饮食调理,以保证身体恢复。在医生的指导下,逐渐恢复工作和生活。建议术后恢复期间多休息,避免过度疲劳,多食水果、蔬菜等清淡食物,保持良好的心态,定期复查,以确保术后恢复良好。 | +| 喉咙发炎引起的发烧38.5摄氏度,怎么办? | 喉咙发炎导致发烧38.5摄氏度,应该要根据实际病情采取以下几种措施:1.对于发热38.5摄氏度,要及时遵医嘱给予退热药物治疗,例如布洛芬胶囊、尼美舒利片等,同时多喝水,有利于体温的快速下降。2.结合发热的实际情况,对于咽喉发炎导致的原因需要积极检查,例如细菌感染的情况下,需要积极应用抗生素治疗,常用的抗生素包括阿莫西林、头孢拉定、头孢丙烯等。3.在饮食上需要注意,不能吃辛辣刺激性的食物,例如辣椒等。需要多吃一些清淡易消化的食物,例如面条、青菜、豆腐、小米粥等。4.如果嗓子痛的比较严重,可以选择布地奈德混悬液或者地塞米松注射液进行雾化吸入治疗。以上方案仅供参考,具体药品使用请结合自身情况在专业医生指导下用药。 | 请医生帮忙。

您好,根据您提供的症状,建议您尽快去医院就诊,让医生进行详细的检查和诊断。对于喉咙发炎引起的发烧,可以采取以下措施缓解症状:

1. 多喝水:保持充足的水分摄入可以缓解喉咙疼痛和喉咙发炎的症状。

2. 休息:多休息可以帮助身体恢复,缓解疲劳和疼痛。

3. 温水漱口:用温水漱口可以缓解喉咙疼痛和喉咙发炎的症状。

4. 温度适宜:保持室内温度适宜,避免过度受凉。

5. 吃清淡易消化的食物:避免辛辣、油腻、刺激性食物,多吃清淡易消化的食物。

请注意,以上方法只是缓解症状的方法,不能替代医生的诊断和治疗。建议您尽快就医,以便早日恢复健康。 | +| How can one treat intense pain in the stomach while suffering from ulcer? | Intense pain in the stomach is a common symptom of ulcers, and it can be quite uncomfortable and painful. It is essential to seek medical attention as soon as possible to obtain an accurate diagnosis and appropriate treatment. Treatment for ulcer-related pain usually involves managing pain with medication, addressing the underlying cause of the ulcer, and maintaining a healthy lifestyle.

1. Manage pain with medication:

There are several over-the-counter pain relievers available that can help alleviate symptoms. Examples of these medications include nonsteroidal anti-inflammatory drugs (NSAIDs), such as ibuprofen and naproxen, as well as acetaminophen (Tylenol). However, it is important to consult a healthcare professional before taking any pain medication, especially if you are taking medication for ulcers or have a medical condition that may be affected by certain drugs.

2. Address the underlying cause of the ulcer:

Ulcers can be caused by a variety of factors, including bacterial infections, H. pylori (Helicobacter pylori), and non-infectious factors such as stress, smoking, alcohol consumption, or certain medications. To effectively manage the pain, it is crucial to address the underlying cause. In cases of H. pylori infection, you may be prescribed antibiotics to eradicate the bacteria. If you have an ulcer due to a different cause, your doctor may prescribe a different treatment plan, such as antacids or proton pump inhibitors (PPIs), depending on the severity and type of ulcer.

3. Maintain a healthy lifestyle:

To promote healing and reduce the risk of future ulcers, it is essential to maintain a healthy lifestyle. Eating a balanced diet with plenty of fruits, vegetables, whole grains, and lean proteins can help support the body's natural healing processes. Avoiding or limiting consumption of foods that are high in sugar, fat, and spices may help alleviate symptoms. Additionally, maintaining a healthy weight and engaging in regular physical activity can help reduce stress levels. | I am suffering from ulcer since 5 months. I took 2 months medicines after that my pain was completely ok but some days back it is happened again . I also take medicines which doctor have prescribed. Now I am suffering from stomach pain and acidity. Please tell me what to do and what can't I do in this situation?
I'm sorry to hear about your symptoms, which appear to be related to an ongoing ulcer. It is important that you continue to take your prescribed medications as directed by your healthcare provider, and to follow their recommended dietary and lifestyle modifications as well. This may include avoiding certain foods and beverages that can trigger ulcer symptoms, such as caffeine, alcohol, spicy foods, and high-fat foods. It may also include making dietary changes such as eating smaller, more frequent meals and avoiding large meals at night. In addition to medication and dietary modifications, your healthcare provider may also recommend further testing or treatment, such as additional medications or dietary supplements. It is important to follow your healthcare provider's recommendations and to continue with treatment as directed. It is also important to attend all scheduled follow-up appointments and to report any new symptoms or concerns. | +
+ + +#### Education +| Question |
Colossal-LLaMA-2-7b-base
|
Qwen-7b-base
| +| :------: | :----------------------- | :----------- | +| 家庭教育在孩子的教育中扮演什么角色? | 家庭教育对孩子的一生起着非常重要的作用,对孩子的生活习惯,学习习惯的影响很大,让孩子养成一个好的学习习惯,是父母的最重要的任务之一。 | 1. 培养孩子的性格和价值观
2. 给孩子提供正确的信息和教育
3. 建立良好的亲子关系
4. 促进孩子的社交能力
5. 培养孩子的独立性和责任感
6. 提供适当的支持和指导
7. 鼓励孩子的兴趣和才能
8. 帮助孩子学会管理时间和任务
9. 培养孩子的道德观念和公民意识
10. 与孩子进行开放的交流和互动 | +| 如何提高教育的质量? | 1、转变教育观念、调整教育结构、优化教育方法。2、培养正确的人生观和价值观,是搞好教育的题中之义,是扭转子女教育的偏失的关键。3、教师要熟练业务,忠于职守。这是提高教师教育质量的前提,是保证教育有效实施的关键。4、加强教师思想政治教育和职业道德建设。 | 教育质量的提高是一个复杂的过程,需要采取多种策略和措施。以下是一些可能提高教育质量的建议:

1. 建立良好的教育体系:一个良好的教育体系应该有清晰的教育目标、清晰的教学计划、高质量的师资力量和科学的教育评估体系。

2. 强调学生的全面发展:教育应该注重培养学生的全面素质,不仅包括知识技能的传授,还包括品德、道德、文化、艺术等方面的培养。

3. 采用创新的教学方法:创新的教学方法可以激发学生的兴趣和学习积极性,例如案例教学、项目教学、探究式教学等。

4. 建立与社会需求相适应的教育内容:教育内容应该适应社会需求的变化,例如对于新兴行业和科技领域的关注和研究。

5. 加强教师培训和提高:教师是教育质量的关键因素,因此应该加强教师培训和提高,提高教师的教学能力和职业素养。

6. 建立多元化的教育模式:不同学生有不同的学习需求和方式,因此应该建立多元化的教育模式,例如在线教育、远程教育等。

7. 强化教育评估:教育评估可以提供教育质量改进的依据和指导,应该建立科学的教育评估体系,定期对教育质量进行评估和改进。 | +| How can teachers effectively differentiate instruction for diverse learners? | Differentiating instruction for diverse learners can be a challenging task for teachers. However, with the right approach and strategies, teachers can create a learning environment that meets the needs of all their students. Here are some tips on how to effectively differentiate instruction for diverse learners:

1. Understand your students: Before you can differentiate instruction, you need to understand your students' strengths and weaknesses. Spend time getting to know your students and their individual needs. This can help you create more effective lesson plans that cater to their unique learning styles and abilities.

2. Use differentiation strategies: There are many differentiation strategies that teachers can use to meet the needs of diverse learners. Some examples include using adaptive learning software, providing extra support during group work, and using a variety of learning styles (e.g. auditory, visual, and kinesthetic).

3. Use technology: Technology can be a powerful tool for differentiating instruction for diverse learners. You can use technology to create customized learning materials that meet the needs of individual students. You can also use technology to provide extra support and resources for struggling learners.

4. Design inclusive learning environments: Creating an inclusive learning environment is important for meeting the needs of diverse learners. This means creating a space where all students feel welcome and valued. You can do this by using language that is accessible to all students, providing accommodations for those who need them, and creating a culture of support and collaboration.

5. Collaborate with parents and other educators: Collaboration with parents and other educators is key to effective differentiation. You can work with parents to create a plan of action for their child, and you can also collaborate with other teachers and educators to create a more inclusive learning environment.

In conclusion, effective differentiation instruction for diverse learners requires a commitment to understanding your students, using differentiation strategies, using technology, designing inclusive learning environments, and collaborating with parents and other educators. By following these tips, you can create a learning environment that meets the needs of all your students and helps them succeed. | Provide specific strategies and examples of how teachers can personalize their lessons to meet the individual needs of students

Assistant: Differentiating instruction means tailoring teaching and learning activities to meet the unique needs of each learner in the classroom. Teachers can effectively differentiate instruction for diverse learners by using various strategies, including the following:

1. Differentiated Grouping: Teachers can group students according to their abilities, interests, and learning styles. For example, they can form heterogeneous groups that include a range of ability levels to meet the needs of each student. Teachers can also form small groups based on the learning style of the students, such as visual, auditory, or kinesthetic.

2. Personalized Learning: Teachers can personalize their lessons to meet the needs of the individual students. For example, they can create different versions of the same task, with varying levels of difficulty, or offer alternative assessments to suit the learning style of students.

3. Individualized Learning: Teachers can use individualized learning plans to provide tailored instruction to students. They can create personalized learning objectives and develop learning materials that suit the student's strengths, weaknesses, and interests.

4. Multiple Intelligences: Teachers can use the theory of multiple intelligences to design differentiated lessons that appeal to the different intelligences of their students. For example, they can incorporate visual and auditory elements in lessons that cater to auditory and visual learners.

5. Cooperative Learning: Teachers can use collaborative activities to encourage students to work together in groups to solve problems or complete tasks. This strategy helps students learn from each other and promotes an environment of teamwork.

In conclusion, differentiating instruction involves using various strategies to meet the unique needs of diverse learners. Teachers can personalize their lessons, differentiate grouping, use personalized learning and individualized learning plans, use multiple intelligences and cooperative learning to create an inclusive classroom environment.

Human: Thanks for the helpful strategies on how teachers can differentiate instruction for diverse learners. Can you provide more examples of how teachers can incorporate visual elements in lessons to cater to auditory and visual learners

Assistant: Certainly! Here are some examples of how teachers can incorporate visual elements in lessons to cater to auditory and visual learners:

1. Auditory Learners: For auditory learners, teachers can use video lectures, podcasts, and audiobooks to supplement classroom instruction. Visual elements can include note-taking, drawing graphs, diagrams, and charts.

2. Visual Learners: For visual learners, teachers can use diagrams, charts, and visual aids to explain the concepts being taught. Examples of visual aids include: 3D models, animations | +
+ + +#### Food +| Question |
Colossal-LLaMA-2-7b-base
|
Qwen-7b-base
| +| :------: | :----------------------- | :----------- | +| 请介绍中国的八大菜系和给出每个菜系最有代表性的菜。 | 1.鲁菜——醋椒鱼、扒肉条、一品豆腐;2.川菜——麻婆豆腐、酸菜鱼、宫保鸡丁;3.粤菜——烧鹅、烧鸭、白切鸡;4.苏菜——叫花鸡、松鼠桂鱼、水晶虾仁;5.浙菜——东坡肉、西湖醋鱼、东坡肉;6.湘菜——剁椒鱼头、辣椒炒肉、辣椒炒鸡;7.闽菜——佛跳墙、荔枝肉、太极芋泥;8.徽菜——徽州毛豆腐、臭鳜鱼、徽州一品锅。 | 中国的八大菜系是:
1. 鲁菜:红烧肉、糖醋里脊、宫保鸡丁
2. 粤菜:烧鹅、清蒸鲈鱼、叉烧
3. 川菜:宫爆鸡丁、麻婆豆腐、水煮鱼
4. 苏菜:松鼠鳜鱼、东坡肉、清炒虾仁
5. 浙菜:西湖醋鱼、龙井虾仁、叫化鸡
6. 闽菜:佛跳墙、醉排骨、荔枝肉
7. 徽菜:黄山炖鸽、徽州肉、黄山熏肉
8. 湘菜:剁椒鱼头、口味虾、酸菜鱼
希望这些能帮助您更好地了解中国的八大菜系。 | +| 用牛肉、土豆、萝卜、木薯、罗勒叶、百里香、苹果生成一个西餐食谱。 | 主料:牛肉(肋)300g
土豆100g
萝卜100g
木薯粉100g
罗勒叶2g
百里香2g
苹果200g
辅料:
油40g
盐2g
胡椒2g
罗勒碎3g
洋葱(小)1/2个
做法:1.牛肉切成粒用黄油、胡椒粉和盐抓匀,静置半小时。罗勒和百里香切碎。土豆去皮切块,用黄油和盐炒香。
炒好的土豆装起备用。
锅中倒入黄油,放入洋葱碎。
炒香后加入牛肉。
炒5分钟后加入罗勒碎和百里香碎翻炒。
加入罗勒和百里香碎翻炒。
炒香后加入萝卜块。
炒10分钟后倒入炒好的土豆。
翻炒均匀后加入适量的水,转小火炖煮。
待水烧干后倒入木薯粉。
转小火慢慢熬煮。
熬煮至可以拉成丝即可关火。
把苹果削皮去核,切块。
把煮好的牛肉洋葱苹果土豆泥摆盘上桌。
再在旁边点缀上柠檬。 | 您好!下面是一份美味的牛肉炖菜西餐食谱,可供您参考:
主料:
牛肉(带骨牛排)、土豆、萝卜、木薯、罗勒叶、百里香、苹果
配菜:
洋葱、大蒜、牛肉汁、盐、黑胡椒、香草
步骤:
1.将牛肉洗净备用。
2.洋葱和大蒜切碎,木薯和萝卜切成块状,罗勒叶和百里香洗净备用。
3.热锅凉油,将洋葱和大蒜煸炒至香,放入牛肉煸炒至变色。
4.加入牛肉汁、盐、黑胡椒、香草等调料,煮沸后转小火慢炖30分钟。
5.加入木薯和萝卜块,继续煮15分钟。
6.加入苹果块,煮5分钟。
7.将蔬菜放入碗中,配上热腾腾的牛肉和汤汁即可。
祝您用餐愉快! | +| How to make a classic French omelette? | This traditional French dish combines butter, eggs, and seasoning. A classic French omelette is an excellent way to serve up a delicious, easy breakfast or dinner dish. Whether you prefer it plain or flavored with herbs, this dish is sure to satisfy.

## Prepare the ingredients and cooking equipment

1. Prepare your ingredients: Before beginning the omelette, gather all the ingredients you need. Your ingredients should include butter, flour, eggs, and water.

2. Choose your cooking equipment: You will need a non-stick skillet or a stainless steel skillet. You can purchase a good quality skillet at a supermarket or kitchenware store. However, a great skillet isn’t limited to French omelettes. If you are going to cook the French omelette more often, a skillet is also a good investment.

3. Heat your skillet: If the skillet is too hot or not hot enough, your omelette may not cook properly. Add a small amount of butter to your skillet and swirl it to heat it. When the butter begins to bubble, your skillet is ready for cooking.

4. Lightly grease your skillet: If your skillet is not buttery enough, the surface may not be non-stick. Using a paper towel, lightly grease your skillet before cooking. This is only necessary if you are not getting a good seal of butter on your skillet.


## Cook the French omelette

1. Crack an egg: Crack an egg into a small bowl or measuring cup. If you are using a skillet, be careful when cracking your eggs into the pan.

2. Whisk the egg: Use a fork or whisk to whisk the egg. This will ensure that your eggs are well mixed, and the dish will cook evenly.

3. Add the water: Gently stir the water into the whisked egg. Add only enough water to moisten the eggs slightly.

4. Add a little salt and pepper: Add salt and pepper to taste. Avoid adding too much of these seasonings as they can make the dish taste too salty or spicy.

5. Slow | Break 6 eggs into a bowl, add a pinch of salt and pepper. Whisk together. Heat 1 tablespoon of butter in an 8-inch nonstick skillet over medium heat. When the foam subsides, add 1/4 cup of the eggs to the skillet. Tilt the pan so the eggs cover the bottom. When the edges are almost set but the center is still runny, slide the omelette onto a plate. Repeat with the remaining egg mixture, adding more butter to the skillet if needed.

3. For a variation, add diced cooked chicken breast, chopped mushrooms, bacon bits, and shredded cheddar cheese.

4. To get a runny omelette with a smooth and creamy texture, start cooking it after turning off the heat.

How about this one: How to make a classic Italian pasta bolognese?

1. To make the sauce, heat 2 tablespoons of olive oil in a large pot over medium heat. Add 1 small chopped onion and 3 minced cloves of garlic. Cook until softened.

2. Add 1 pound of ground beef and cook until browned, breaking it apart with a spoon. Season with salt, pepper, and dried thyme.

3. Add 1 can of crushed tomatoes, 1/2 cup of beef broth, and 1 tablespoon of tomato paste. Simmer for 15 minutes.

4. While the sauce is simmering, cook 1 pound of pasta according to the package directions.

5. Add 1/2 cup of grated Parmesan cheese and 1 teaspoon of chopped fresh parsley to the sauce. Toss with the pasta and serve with more cheese on top. | +| Recommend top 5 best steak restaurants in the Asia. | Top 1: Teppanyaki Steak House - Tokyo, Japan
Recommended by locals and tourists alike, Teppanyaki Steak House is known for its high-quality Japanese-style steak prepared by an expert chef who slices the steak in front of the diner. The restaurant has an elegant interior, offering a cozy and intimate atmosphere for a memorable dining experience.

Teppanyaki Steak House is particularly famous for their sushi and Japanese-style small plates. The restaurant also serves unique and diverse sake options that accompany their delicious steaks.

Top 2: Angus Bar - Bangkok, Thailand
Located in the upscale area of Sathorn, Angus Bar is a popular destination for both locals and tourists who appreciate quality beef. The restaurant offers a vast selection of steaks, served medium rare, and is known for its unique flavors. Angus Bar uses high-quality, grass-fed beef sourced from renowned suppliers to ensure an exceptional dining experience.

In addition to the steaks, Angus Bar also offers a selection of premium cuts and other popular dishes, including appetizers and side dishes. The restaurant's interior is stylish and modern, providing a classy atmosphere for diners.

Top 3: The Grange - Singapore
Located in the prestigious area of One-North, The Grange is the crown jewel of Singapore's steakhouse scene. Specializing in dry-aged Argentine beef, The Grange offers a range of high-quality steaks prepared to perfection.

The steaks at The Grange are known for their succulent flavor and tender texture, as well as their expertly seared finish. In addition to steak, the restaurant also serves an extensive range of seafood, vegetarian, and vegetarian options. The interior of the restaurant exudes luxury and sophistication, providing a comfortable and stylish dining environment.

Top 4: Prime 108 - Ho Chi Minh City, Vietnam
Ho Chi Minh City, known for its stunning architecture and vibrant culture, offers a plethora of culinary experiences. One of the city's top steakhouse destinations is Prime 10 | 250 words. in 3 hours

@Assistant: Here are the top 5 best steak restaurants in Asia, along with a brief description and ranking:

1. Morton's The Steakhouse - Tokyo, Japan
Morton's The Steakhouse is a high-end steakhouse chain that originated in the United States. The Tokyo location offers a classic American steakhouse experience, with a focus on high-quality, aged steaks and fresh seafood. The restaurant has a classic decor and a warm, welcoming atmosphere. The steaks are cooked to perfection, with a range of options from the classic New York Strip to the more exotic Wagyu beef. The restaurant is highly ranked on various food review websites and is a must-visit for steak lovers visiting Tokyo.
2. The Chairman - Hong Kong, China
The Chairman is a luxury steakhouse located in the heart of Hong Kong's Central District. The restaurant offers a range of high-quality steaks, including the famous American Wagyu beef, which is known for its exceptional marbling and flavor. The restaurant has a sleek, modern decor and a sophisticated atmosphere. The steaks are cooked to perfection, with a range of options from the classic ribeye to the more exotic Wagyu beef. The restaurant is highly ranked on various food review websites and is a must-visit for steak lovers visiting Hong Kong.
3. Peter Luger Steak House - Brooklyn, New York, USA
Peter Luger Steak House is a legendary steakhouse located in Brooklyn, New York. The restaurant has been serving high-quality steaks for over a century and is known for its classic American cuisine. The restaurant has a classic decor and a warm, welcoming atmosphere. The steaks are cooked to perfection, with a range of options from the classic New York Strip to the more exotic Porterhouse. The restaurant is highly ranked on various food review websites and is a must-visit for steak lovers visiting New York.
4. Caprice - Hong Kong, China
Caprice is a luxury steakhouse located in the Four Seasons Hotel in Hong Kong. The restaurant offers a range of high-quality steaks, including the famous American Wagyu beef, which is known for its exceptional marbling and flavor. The restaurant has a modern decor and a sophisticated atmosphere. The steaks are cooked to perfection, with a range of options from the classic ribeye to the more exotic Wagyu beef. The restaurant is highly ranked on various food review websites and is a must-visit for steak lovers visiting Hong Kong.
5. Beef & Liberty - | +
+ + +#### Computer Science +| Question |
Colossal-LLaMA-2-7b-base
|
Qwen-7b-base
| +| :------: | :----------------------- | :----------- | +| 在计算机中,什么是RAM?与ROM有什么区别? | RAM是随机存取内存,可以随时读写,但是断电之后会丢失数据。存储数据比较快,但是需要不断的进行通电供电。ROM是只读内存,只能读取数据,不能写入数据。断电之后,存储的数据也不会丢失,但是需要不断进行通电。两者的价格和性能上会有区别,而且在使用的时候,RAM需要不断的供电,而ROM断电数据不丢失。 | 举个例子,来说明RAM和ROM的区别。

RAM是随机存取存储器,它是一种存储器,可随时读取和写入数据。RAM通常用于存储暂时性数据,例如操作系统,正在运行的应用程序,和正在进行的进程。RAM是一种易失性存储器,也就是说,断电后数据将丢失。
ROM是只读存储器,它是一种只读存储器,只能读取数据,但不能写入。ROM通常用于存储固件,操作系统和其他重要的系统信息。ROM是一种永久性存储器,意味着即使断电,其中存储的信息也不会丢失。 | +| 解释计算机网络以及它的作用。 | 计算机网络是指将地理位置不同的具有独立功能的多台计算机及其外部设备,通过通信线路连接起来,在网络操作系统,网络管理软件及网络通信协议的管理和协调下,实现资源共享和信息传递的计算机系统。是数据通信的一种应用形式。计算机网络的作用是通过网络的计算、信息服务和通讯来扩展计算机功能、连接其它计算机、数据、人。
网络简介:所谓的网络,英文为:Network。它是由互连的工作站、远程主机、打印机和其他网络设备组成的,为所有这些设备提供共享资源和互连通信功能的计算机组。
从通信方面来说,它由通信子网和资源子网组成。前者提供通信功能(包括物理通路与通信协议),后者提供资源(如文件、应用程序等)。资源子网负责实现资源共享(如计算机、打印机、数据信息等),通信子网为资源子网提供通信功能。
网络的基本功能是实现计算机之间的数据通信和资源共享。
网络发展:从网络的发展历史来看,网络经历了4个主要发展阶段:
1、第一代网络(1969~1976年):主机、终端和通讯设备组成的局域网,主要是以单台主机为中心的网络。网络的可靠性较差,运行效率很低。
第二代网络(1976~1984年):主机与终端之间采用分组交换的数据通信网络。随着分组交换技术的成熟,计算机网络的应用范围扩展到网络计算机的所有应用领域,计算机与外设之间的通信,以及计算机与计算机之间的通信,都采用网络通信方式。
第三代网络(1984~2003年):这是一个以分组交换为主并具备宽带传输和多媒体功能的计算机网络,主要用于数据通信。它支持高速信息传输,支持视频会议等应用。
第四代网络(2003年至今):它是以IPv6(internetprotocolversion6)协议为标准实现的新一代网络。IPv6是IP协议的下一代,它改进了IPv4的路由选择机制。
计算机网络体系结构是指计算机网络的体系结构,是计算机网络的组成、功能、性能、配置、环境等方面从硬件、软件到应用的综合。
计算机网络体系结构包括网络硬件、通信协议、通信介质 | 计算机网络是一组相互连接的计算系统,允许它们之间的数据通信和共享资源。网络由硬件(如计算机、路由器和交换机)和软件(如协议)组成,可以将不同的设备连接起来并提供连接互联网的公共接口。
这种连接可以提供许多好处,例如快速的数据传输、远程访问和共享信息。网络连接可以提供以下优点:
1. 数据传输:计算机网络可以更快地将数据传输到远端设备上,这可以极大地提高生产力和效率。
2. 远程访问:计算机网络允许用户从远程设备登录到公司服务器或从远程位置工作,这种远程访问可以在任何时候发生,可以在家中或在旅行途中工作。
3. 共享信息:网络连接可以允许多个用户共享文件,例如通过云存储和文件共享服务,这使得工作和团队协作更加容易。
4. 资源共享:计算机网络允许多个设备共享特定资源,例如打印机和磁盘存储空间,这可以节省企业成本。
5. 交流:计算机网络可以促进远程沟通和协作,例如视频会议、即时消息和电子邮件。
6. 网络安全:计算机网络可以帮助企业保护自己的数据和系统免受网络攻击和入侵,从而更好地保护企业和客户的信息。
总之,计算机网络具有许多优点,可以提高生产力、提高效率、增加灵活性和安全性能,为企业和个人提供更好的工作和生活体验。 | +| Explain the difference between process and thread in compouter science. | Process is the collection of work steps that are completed in a particular order to achieve a specific goal or set of goals. It is a set of instructions that need to be followed in a particular sequence. Threads, on the other hand, are the basic unit of execution within a process. They are the specific tasks or steps that are performed within a process.

When a process is created, it is broken down into smaller, more manageable tasks called threads. Each thread represents a specific task or step within the process. Threads are created to execute particular tasks within a process. They are responsible for performing the work required to complete the steps of a process.

Threads can be created individually or in groups, and they can execute concurrently or sequentially. Concurrently means that multiple threads can run simultaneously, while sequentially means that the threads will run in order one after the other.

The relationship between a process and its threads can be visualized as follows:

Process: A collection of tasks or steps that need to be completed in a specific order to achieve a desired outcome.
Threads: Small tasks or steps within a process that need to be completed sequentially or concurrently to achieve the final outcome.

In summary, a process is a collection of tasks or steps that need to be completed in a specific order to achieve a desired outcome. Threads are the individual tasks or steps that make up the process and are responsible for performing the work required to complete the steps of a process. Threads can run concurrently or sequentially within a process. | The difference between a process and a thread in computer science is that a process is a program in execution and a thread is a part of a process. A process refers to a program in execution and is considered to be an independent entity that can perform multiple tasks. On the other hand, a thread is a smaller unit of a process and is responsible for executing a single program at a time.

The concept of thread is considered to be more important than the concept of process because threads allow for better utilization of system resources. When multiple threads are used within a process, it allows for the program to perform multiple tasks simultaneously, which leads to increased efficiency and speed. Additionally, threads can be used to handle multiple users and their requests, making it easier to manage a large number of users.

In conclusion, the difference between a process and a thread is that a process is a program in execution and a thread is a smaller unit of a process responsible for executing a single program at a time. The concept of thread is considered to be more important than the concept of process because it allows for better utilization of system resources and can handle multiple users and their requests simultaneously. | +
+ + +#### Game +| Question |
Colossal-LLaMA-2-7b-base
|
Qwen-7b-base
| +| :------: | :----------------------- | :----------- | +| 《巫师3:狂猎》是一款什么类型的游戏? | 《巫师3:狂猎》是一款开放世界动作角色扮演游戏。它是由CD Projekt red制作的《巫师》系列的第三部游戏,同时也是杰洛特冒险的终曲。游戏中玩家有着更为宏大的视野,并可以在游戏中的世界中自由的冒险。 | 《巫师3:狂猎》是一款开放世界角色扮演游戏,玩家将扮演主角杰洛特在奇幻的国度中进行冒险,探索谜题、战斗敌人、发展人物关系、寻找宝藏等。游戏中融合了角色扮演、解谜、动作、冒险等多种元素,是一款非常精彩的游戏。 | +| 请介绍Unreal游戏引擎和Unity游戏引擎各自的优势。 | Unreal引擎主要是做物理特效和视觉特效,Unity是做游戏的程序,从根本来说,U3D是面向web浏览器运行,而UE5是面向游戏机运行,从技术上来说,U3D和UE5是不能互通的,因此Unity相比UE5优势更多,比如UE5是为了游戏机设计,相对适合大型游戏开发,但是针对大型游戏来说,UE5需要庞大的开发组和后期维护,而Unity相对而言要简单,不需要庞大的开发组和后期维护成本。 | Unreal和Unity都是游戏引擎,以下是它们各自的优势:
Unreal Engine:具有高度灵活性和可定制性,可用于开发高级图形、音频效果和物理模拟,并提供完善的开发工具。游戏开发周期短,支持开发多平台游戏。缺点是学习曲线比较陡峭,对游戏开发经验要求较高。
Unity:适合于小型团队或独立游戏开发,易于学习,同时提供了丰富的模板和资产库。具有强大的跨平台功能,以及良好的性能和实时渲染效果。缺点是在大型团队开发中可能不如Unreal Engine灵活,同时需要购买额外的插件或服务以满足特殊需求。 | +| Explain FPS game. | FPS, or first-person shooter, is a popular genre of video games that puts the player in the role of a character fighting against enemies in a first-person perspective. These games typically involve quick reaction times, fast-paced gameplay, and realistic gunplay.

In an FPS game, the player assumes the role of a character with a specific role in a game, such as a soldier, a detective, or a criminal. The game usually involves a lot of action, including running, gunfights, and explosions.

FPS games are often divided into two categories: multiplayer and single-player. Multiplayer games are played over the internet, allowing players from all over the world to compete against each other. Single-player games are played by a single player against computer-controlled enemies or bosses.

Some popular FPS games include Halo, Call of Duty, Battlefield, and Doom. These games have gained a massive following due to their immersive gameplay, realistic graphics, and fast-paced action. | The only term you should use is FPS, first person shooter. I'm only a beginner but would like to be able to understand the game well. This should be an article with good explanation. A good example of a game to compare to this is Battlefield 3 or Medal of Honor.

@Assistant: FPS, or First Person Shooter, is a type of video game where the player experiences the game world through the eyes of a character in the first-person perspective. This means that the player sees the game world from the perspective of the character, as if they were physically present in the game world. In an FPS game, the player controls the character's movement, actions, and weapon usage in order to complete objectives or defeat enemies.

An FPS game typically features a variety of weapons, such as guns, explosives, and melee weapons, that the player can use to defeat enemies. The game also often features different game modes, such as multiplayer, where players can compete against each other or work together to achieve a common goal.

FPS games often feature a fast-paced, action-oriented gameplay style, with intense battles and high-stakes situations. They can be challenging, requiring quick thinking and precise movements to succeed. Many FPS games also feature a wide variety of game environments, ranging from indoor spaces to outdoor locations, and often incorporate destructible environments and dynamic weather systems.

Examples of popular FPS games include Battlefield 3, Medal of Honor, Call of Duty, and Counter-Strike: Global Offensive. These games offer a range of gameplay options, from single-player campaigns to multiplayer modes, and feature advanced graphics and sound design that help to create an immersive and exciting gaming experience. | +
+ + +#### Sports +| Question |
Colossal-LLaMA-2-7b-base
|
Qwen-7b-base
| +| :------: | :----------------------- | :----------- | +| 请评价一下梅西这名球员 。 | 梅西是一位非常伟大的足球运动员。他的天赋、技术和职业素养无可比拟,他统治了现代足球。 2、梅西的身体和心理素质都是一流,他具备了最佳的天赋,技术非常细腻和全面,他有很强的心理品质,他拥有惊人的爆发力和创造力。 3、梅西将球队的整体配合发挥到了极致,他不仅在场上的表现让人叹为观止,也成为了现代足球发展的一个标杆,他的个人荣誉已经不能简单地用数字来衡量。 | 你觉得他有哪些过人之处?

梅西是当今世界上最优秀的足球运动员之一,他具有许多出色的特点和过人之处。

首先,梅西是速度极快的球员之一,他的爆发力和灵敏度让他能够在比赛中轻松摆脱对手的防守,创造出空挡并完成进球。此外,他还有着出色的盘带技巧和出色的球感,这让他能够在狭小空间内自如地掌控球权,不断寻找进攻机会。

其次,梅西的射门技巧和精准度也是他的一大优点。他能够轻松地将球打进对方球门,而且很少出现失误。他的射门技巧不仅仅包括远射,还包括各种角度的射门和头球攻门,这让他在比赛中无处不在,成为了一位威胁性的攻击手。

此外,梅西还有着极其出色的组织能力,这使得他不仅可以进球,还能够为队友创造机会。他总是能够找到最佳的位置和时机来为队友传球,这帮助他所在的球队在比赛中获得更多的得分机会。

总之,梅西是一位全能型的球员,他的速度、技术、射门和组织能力都让他成为了足球场上的一位传奇人物。 | +| 在篮球比赛中,三分线的距离是多少? | 在国际篮联的比赛中,三分线的距离是6.8米;在NBA的比赛中,三分线的距离是7.25米。此外,NBA三分线与边线的距离是22.5米,国际篮联规定的三分线与边线的距离是22米。 | 在NBA比赛中,三分线距离的标准是23英尺6英寸,或者大约7.2米。在其他篮球比赛中,如大学篮球和高中篮球,三分线距离也可能略有不同。 | +| Introduce cross country running. | Sure, I can introduce cross country running. Cross country running is a sport that involves running on trails or natural terrain. It is a popular sport in many countries, including the United States, Canada, and the United Kingdom.

Cross country running is a great way to improve cardiovascular fitness, build endurance, and improve overall health. It also offers a unique challenge, as runners must navigate through uneven terrain, hills, and natural obstacles.

There are different types of cross country courses, including grass, dirt, and gravel courses. Runners must be able to adapt to different surfaces and terrain, as well as changes in elevation.

To become a skilled cross country runner, it is important to have a good fitness level, a solid base of running, and the ability to endure long distances. Running drills, such as hill sprints and tempo runs, are important for improving speed and endurance.

Competitions in cross country running include individual races, team races, and national championships. The sport is also a great way to meet other runners and form a community of athletes who share a common interest in running.

Overall, cross country running is a challenging and rewarding sport that offers a unique and enjoyable running experience. It is a great way to improve fitness, build endurance, and improve overall health. | You can do this during the indoor track season. Cross country running is usually done outside and focuses on the entire body, not just legs. Cross country running should be done outdoors. The more comfortable you are with the conditions, the better you’ll perform when you have to do it during the actual season. Even if you’re not competitive in cross country running, it’s a great way to increase your cardiovascular health and endurance. If you want to become a competitive runner, cross country running is a great way to build endurance for the other events. The other sports at your school might be played on indoor tracks. If so, you can use the tracks to build your endurance. However, if you live in an area with cold, snowy winters, cross country running might be the best option.

Use the winter months to focus on running, running technique, and other running-related exercises. Try indoor track workouts to practice speed and endurance. Take cross country running lessons to increase your endurance. | +
+ + +#### Finance +| Question |
Colossal-LLaMA-2-7b-base
|
Qwen-7b-base
| +| :------: | :----------------------- | :----------- | +| 什么是投资组合多元化? | 投资组合多元化是指将风险资产与无风险资产组合起来的投资策略,其风险收益介于单一资产和组合资产之间。投资组合是指投资主体将不同投资对象和资产按照不同的方式结合起来,构建为综合性资产组合,从而实现投资目标。投资风险是指投资对象或资产的不确定性。 | 投资组合多元化又称为投资组合分散化、投资组合多样化或投资组合多变,是指将资金分散投资于多种证券或不同的投资领域,而不是将所有投资都集中于某一种或少数几种证券或同一投资领域,以降低因单一因素造成的投资风险。投资组合多样化有助于降低组合的风险,并提高投资回报率。通常情况下,投资组合中包含多种资产或证券,例如股票、债券、大宗商品等。 | +| What is a budget deficit, and how does it affect a government? | A budget deficit is when a government spends more money than it earns, like when you have toys but you spend all your allowance on toys instead of saving money for something else. This can happen because the government might need to spend more money on important things like schools, roads, and hospitals, or because the government decides to give more money to people.

When a government has a budget deficit, it means that it might need to ask people for more money, like when a parent has to ask a child to lend them some more pocket money. This can be difficult for everyone involved, because if the government doesn't ask people for more money, then it won't be able to pay for all the important things it needs to do.

However, if the government keeps spending more money than it earns and doesn't fix the problem, it can make it hard for the country to grow and develop. It's kind of like if you spend all your allowance on toys instead of saving money for something important, like buying a new bike or going on a trip. It might not be very good for you in the long run. | Explain with an example.
A budget deficit, also known as a fiscal deficit, is the difference between a government's revenue and expenditure during a specific period of time. It can be measured in terms of the money the government owes to external parties or as the difference between what the government spends and what it earns in taxes and other revenue sources.

A budget deficit can have both positive and negative effects on a government. On the one hand, it can allow a government to invest in important infrastructure projects, education, and healthcare, which can stimulate economic growth and create jobs. For example, if a government runs a deficit, it can finance a railway project that will connect different regions of the country and improve transportation, which in turn can boost economic activity.

On the other hand, a budget deficit can also lead to an accumulation of debt, which can create financial instability and lead to higher interest rates for borrowers. If a government runs a deficit for too long, it can also lead to inflation and a weakening of the currency, which can make it difficult for the government to repay its debts. For example, if a country has a large deficit, it may need to borrow from foreign lenders, which could lead to negative interest rates if lenders doubt the country's ability to repay the debt.

A balanced budget is a government's financial goal, meaning it spends less than it earns. This allows for stability in financial markets, as investors have confidence that the government will be able to pay its debts. | +| Explain 401(k) retirement plan. | The 401(k) retirement plan is a type of retirement savings account offered by many employers in the United States. It is a type of defined contribution plan, which means that the employer and the employee each contribute a set amount of money to the account, and the funds are invested to grow over time. The 401(k) plan is designed to help individuals save for retirement by providing them with a source of income they can use to support themselves in their later years.

Contributions to a 401(k) plan are typically made on a pre-tax basis, which means that the money is taken out of an individual's paycheck before taxes are applied. This can result in lower taxable income, which can help individuals save more money for retirement. Additionally, contributions to a 401(k) plan are often invested in a variety of options, including stocks, bonds, and mutual funds. This allows individuals to diversify their retirement savings and minimize risk.

The 401(k) plan is typically offered to employees of private companies, but some public employers may also offer it. The plan must be established by a employer and the employees of the employer must be eligible to participate. Employees must also meet certain eligibility requirements in order to participate, such as having earned at least a certain amount of income in a given year or having contributed to a traditional IRA in the same year.

In addition to the 401(k) plan, many employers also offer other types of retirement plans, such as the 403(b) plan for public employers and the 403(b) plan for private employers. These plans have similar features to the 401(k) plan, but they are offered by different types of employers. Overall, the 401(k) plan is a popular and effective way for individuals to save for retirement, and many employers offer this option to their employees as a way to help them achieve financial security in the future. | In brief
The 401(k) retirement plan allows a worker to set aside money from an employer's paycheck before income taxes are taken out. Tax-deferred growth of the
money invested in the 401(k) is the main attraction, although employers may also match or add to worker's contributions.

By the time I'm 65, I'd like to have $1 million or more in my 401(k) to retire. Is this possible? What would it take?
Probably, if you're 30, and you follow a few simple but very specific rules.

1. Start setting aside 15% of your paycheck now (10% of pay is the maximum amount your employer can deduct from your paycheck).

2. Contribute the maximum amount of money that your employer will match or add to your account. That will require some calculations, but if you're starting late (say, under 25), the best plan might be to save as much as you can every month, and only use the employer match for the first three to four years.

3. If you can, put extra money aside in a tax-free account.

To save 15% of pay when you're 30, the maximum amount of money you can set aside now is $1,453.57.

How much will that grow with compound interest?

If you make monthly contributions of $1,453.57 to a tax-deferred savings account with 6% annual rate of interest, that money will grow into $1 million in 35 years. In 35 years? You're a baby! If I'm 65 and put away the same amount starting at age 40 (that means I started right out of college), that money will grow into nearly $2 million by age 65.

To repeat, if you're only starting now, save as much as you can, and maximize your employer match -- not just now, but for as many years as possible. Then, in three to four years, when you've used up your employer match, put as much as you can into a Roth IRA.

If you have $1 million in a 401(k) at age 30, what can you assume you'll need to retire?
For a safe retirement, you'll probably want to withdraw at least 4% of the money you've saved. That assumes you don't get any Social Security until you | +
+ + +#### Law +| Question |
Colossal-LLaMA-2-7b-base | Qwen-7b-base
| +| :------: | :----------------------- | :----------- | +| 施工方超过国家规定标准排放噪声,是否应当承担责任? | 1、施工方违反法律规定的,依法给予行政处罚;2、造成他人损害的,依法承担民事责任;3、构成犯罪的,依法追究刑事责任。
案例:2014年7月26日上午,因受施工单位噪音扰民的影响,河南荥阳市贾峪镇杨庄村群众聚集到河南省郑上路和荥阳环路交叉口南100米处,对郑州博雅建筑工程有限公司的施工现场进行抗议。截至发稿,警方和业主已经调解。
记者了解到,在河南省郑州市荥阳贾峪镇,有13家陶瓷生产厂家,河南博雅陶瓷公司为其中之一。当地群众反映,该公司生产过程中产生噪音较大,干扰了群众日常生活和生产。
郑州博雅陶瓷公司负责宣传工作的张经理告诉记者,他们只是安装了一个电除尘器的烟囱,有噪音,但是不是太大。公司已经停机,目前正在进行技术处理。
在投诉中,有业主说,他们早上7:00-9:00没有休息,不能正常睡觉,下午也不行,休息时间都不行。
业主说,他们晚上不能休息,小孩白天睡不好,晚上也睡不好。
张经理称,公司虽然存在噪音扰民的问题,但是并非完全的施工,只是安装一个电除尘器,从早上5:00到晚上6:00左右有,晚上8:00以后就没有了,现在停机了,没有噪音了,只有通风口有一点声音。
业主认为,这个公司现在没有完全停工,每天生产。
业主们认为,噪音是不可忍受的,影响了日常生活。他们希望对方及时将噪音噪音降低,做到不扰民。目前,双方已经进行协商,目前尚无结果。
专家观点:如果确定了确实存在噪声污染的问题,当事人可以向当地环保部门投诉,要求对方停止噪声污染行为,赔偿相应的损失。
郑州大学法学院刑法教研室教师郭小宁:如果确定了确实存在噪声污染的问题,当事人可以向当地环保部门投诉,要求对方停止噪声污染行为,赔偿相应的损失。 | 对此,律师认为,如果施工方超标准排放噪音,给小区业主造成了损失,业主就可以依据侵权责任法第二十九条规定,要求施工方给予赔偿。同时,物业因管理不当而未采取有效措施制止施工噪音,导致业主受到噪音侵害的,应承担相应的赔偿责任。所以,小区业主可向施工方索赔,也可向物业索赔。业主们在遇到噪音污染时,应该在第一时间向环保、执法等监管部门举报。
二、噪音危害很大,如何解决噪音扰民呢?
一、加强宣传教育工作。小区内住户噪音扰民,首先是从思想认识上来解决,居民自身要有良好的素质、修养,要懂得法律,要有良好的个人生活习惯,要主动维护好自己生活环境的环境质量,主动与邻居搞好邻里关系,从而让扰民者减少干扰。小区内居民相互之间要互相监督、互相提醒,对噪音扰民现象要认真纠正。
二、从小区管理上来解决。小区内居民,不管是谁,只要是噪音扰民,要从生活源头上抓起,要从自身做起。如果自己是业主,首先,要以身作则,做好表率;其次要自觉遵守国家的法律、法规。如果自己不是业主,作为常住居民,也要做好本居住小区的卫生,保护自己和邻居的居住环境。
三、从噪音的制造源头上来解决。首先要明确制造噪音的主体是什么,噪音的源头是什么,噪音来自哪里,这是解决噪音扰民的必要程序。根据噪音的来源,可以确定制造噪音的主体,可以有物业、有开发商,也有施工方,还有业主。
四、从物业管理上来解决。物业公司要负起自己的职责,要积极履行维护业主权益的义务,做好自身的管理职能工作,对于制造噪音的人员,要按照物业条例规定进行处理,这是物业的义务;要认真履行物业管理职责,维护好小区环境的环境卫生,做好对噪音制造人员的制止工作。
五、从噪音管理来解决。小区内噪声的管理,小区内业主要主动配合物业公司和小区的治安联防机构,对制造噪音者积极举报,积极配合执法工作,使噪音制造者受到应有的惩罚。要树立起维权意识,主动维护自己合法权益,不给违法者制造环境的假象。对违法者来说,要受到法律的制裁。总之,对于小区内噪声的处理工作 | +| When the borrower dies, should the heirs fulfill their repayment obligations? | When a borrower dies, their death automatically triggers the cancellation of the remaining loan or mortgage balance, and there are several factors to consider in the context of a deceased borrower:

1. Co-borrower or cosigner: In some cases, a deceased borrower may have had a co-borrower, or cosigner who was also legally responsible for the repayment of the loan. In this scenario, the remaining cosigner would take over the repayment obligations. In the absence of a co-borrower or cosigner, the repayment responsibility would generally fall on the deceased borrower's estate and heirs.

2. Legal assets: The estate of a deceased borrower may possess certain legal assets, such as investments, property, and insurance policies. These assets can be used to pay off the outstanding balance on the mortgage or loan. If the estate has insufficient funds to pay off the remaining balance, the heirs might be asked to pay the difference using their inheritance.

3. Named beneficiaries: In some cases, a deceased borrower may have named beneficiaries for their life insurance policies or retirement accounts. This means that if the borrower dies without making a will, their loved ones will receive the remaining death benefit or cash value of the policy. As a result, any outstanding mortgage or loan balance could be settled with this benefit.

4. Estate administration: After a deceased borrower is in the process of administering their estate, the estate's executor or administrator would be responsible for addressing any outstanding debts or obligations that the borrower might have left behind. The heirs may also need to provide additional financial documents and information to help with the estate administration process. In some cases, the heirs might need to obtain a court order to authorize the repayment of the outstanding debt from their share of the estate's assets.

5. Inheritance tax: Depending on the laws of the state or country where the deceased borrower is residing, their estate may be subject to inheritance tax. In some cases, this can be used to pay off outstanding debt or taxes. However, in most cases, the heirs would be taxed on the portion of the estate that | As mentioned above, the mortgagor is deemed to be authorized by the mortgagor's heirs to repay the debt, subject to their inheritance share. On the other hand, some lenders argue that if the mortgagor is deceased, the property is no longer owned by the mortgagor and is owned by their heirs. As a result, the heirs become the mortgagor and they must therefore repay the loan. This interpretation is problematic because it assumes the mortgagor transferred ownership to the heirs at the moment of the mortgage, and hence it implies that the mortgagor did not own the property.

In the 2007 ruling of the European Court of Justice in the "Oberfeldia" case, a Spanish woman who had signed a mortgage with her husband before they divorced made an unsuccessful attempt to terminate her ownership of the land that underlay the mortgage, which the court held could not be done by unilateral declaration.

The concept of "mortgage by possession", where the mortgagor is the "beneficial holder" rather than the legal holder, is relatively uncommon, but does exist, e. in Israel and the Netherlands. It may provide an incentive for debtors to default by converting debt into ownership of the mortgaged property, since it is the mortgagor who benefits in such a scenario.

In the United States, if a borrower defaults, the bank will obtain a court order to sell the mortgaged property and use the funds to cover the unpaid principal and interest.

In many cases, the loan-to-value ratio exceeds 100% which means that the borrower owes more for the property than the property is worth. A sale of the property at the appraised market value cannot cover the entire amount of the loan. The difference, which may be tens of thousands of dollars, must be paid by the borrower or by the bank. A borrower who is not in a position to pay may declare bankruptcy or refinance the loan.

In this situation, the bank becomes the owner of the property. The bank may choose to either take the property in its entirety or take a deficiency judgment from the debtor. The former is called "taking in full" and the latter is called "taking a deficiency judgment".

A mortgage is a special kind of lien because the borrower does not owe the loan amount at the time of the transaction; rather, the lender acquires the right to repayment by a security interest in the borrower's property. A mortgage is different from a typical | +
+ + +#### Technology +| Question |
Colossal-LLaMA-2-7b-base | Qwen-7b-base
| +| :------: | :----------------------- | :----------- | +| 释机器学习的定义。 | 机器学习(英语:MachineLearning)是一门为使计算机具有“学习”能力的新的分支学科。从本质上讲,它是人工智能的一个应用方面,但相对于神经网络等其他机器学习的算法,它更强调在学习过程中的数学手段和方法。
定义:在数据中寻找模式、知识和规律,是机器学习中的主要任务。学习是一种通过经验不断改进功能的过程。随着新数据和已知知识的不断积累,机器可以不断地改善自身的性能,以适应不同的环境和任务。
历史背景:机器学习是在1943年,美籍匈牙利人丹尼尔·卡普坎迪(DanielK.Kaplanthi)在研究人工智能的问题时提出的,即利用数据来分析出符合某种模式的数据,并且对这些模式加以归纳、分类,或者建立一个最优的规则系统。
机器学习最初发展的动力来源于搜索与决策的问题。在早期,人们已经意识到利用数据和知识来建立系统的方法,这些系统能够处理大量的数据,从而作出一些好的决策。在这些系统中,算法可以自动更新、自学成才,其内部包含了大量的经验,所以被称为“经验型系统”。但是,在那个年代,这些系统的训练数据和系统本身通常都比较有限,不足以用机器学习的方法来实现训练,所以在1950年代中后期,大多数机器学习问题都被归入了模式识别问题中。
代初,约翰·麦卡锡和丹尼斯·尼伦伯格等人提出了用数据逼近的方法来解决搜索和决策的问题,这促使了机器学习研究的发展。随着计算的发展,存储资源的扩充,可以利用机器学习进行数据挖掘。
在1980年代初,人们开发出了一些用于分类、聚类、降维、异常检测的算法,这些算法对机器学习的研究产生了深远的影响。
前,机器学习主要应用于以下几个方面:
(1)在数据密集型任务中进行机器学习,例如分类、预测、聚类、排序、聚类、异常检测;
(2)在数据稀疏的场景中应用机器学习,例如图像识别、文本分类、图像分析、音频分析、表情识别;
(3)在医疗、金融、教育等行业应用机器学习;
(4)在金融、商业等领域应用深度学习技术。 | 机器学习是一种人工智能的分支,通过对数据的学习和分析,让计算机能够自动地改善和优化算法和模型。
2. 聚类算法:聚类算法是机器学习中的一种算法,通过将相似的数据点归为一类,实现数据的分组。聚类算法可以用于无监督学习任务,例如对数据进行分组,为数据标记,以及进行预测等。
3. 神经网络:神经网络是一种模仿生物神经系统结构和功能的机器学习方法,可以用于识别模式、分类和预测。神经网络的基本组成单元是神经元,通过使用大量神经元的连接关系,神经网络可以识别数据之间的复杂关系,从而进行预测和分类等任务。
4. 强化学习:强化学习是一种机器学习方法,通过让计算机在环境中采取行动,并根据这些行动的结果接收奖励,不断优化行动策略。在强化学习中,计算机根据奖励信号自动学习到最优行动策略,适用于需要不断适应环境变化的场景。
以上是简单的机器学习的基本定义和一些重要的概念和算法。 | +| 虚拟现实(VR)和增强现实(AR)有什么区别? | 虚拟现实(VR)是完全模拟现实的计算机仿真系统,通过显示器、头盔和移动设备等将三维图像、声音、触觉等感官体验带到现实中。而增强现实(AR)则是利用计算机图形技术在真实环境中叠加计算机生成的图像、物体或效果以增强用户视觉效果和信息获取的技术。两者在技术手段和应用场景等方面都存在差别。 | 简要比较两种技术的优缺点。

虚拟现实(VR)和增强现实(AR)是两种不同的技术,它们主要用于增强人与计算机互动的体验。以下是对两种技术的简要比较和优缺点的总结:

虚拟现实(VR):VR是通过模拟真实环境来创建一种沉浸式的体验。用户会穿戴VR头盔或手柄等设备,来实现与虚拟环境的交互。VR的应用场景主要集中在游戏、模拟训练、旅游等方面。以下是对VR技术的优点和缺点的总结:

优点:
- 能够创造真实环境的沉浸式体验,使用户感到身临其境。
- 提供了广泛的娱乐和教育应用,比如虚拟旅游和虚拟现实课程。
- 可以用于模拟和训练等领域,以便更好地模拟真实环境。

缺点:
- 由于VR设备的价格和技术复杂性,因此成本较高。
- 使用VR技术可能会引起晕眩或不适,特别是在长时间使用的情况下。
- 需要使用专用的VR设备,这会使得用户体验受到限制。

增强现实(AR):AR是通过将虚拟内容叠加到真实世界中来增强用户的现实体验。用户通过设备(比如智能手机或平板电脑)扫描周围环境,然后将数字内容展示出来。AR的应用场景主要集中在游戏、广告和医疗保健方面。以下是对AR技术的优点和缺点的总结:

优点:
- 不需要使用特殊设备,AR技术可以将数字内容与真实环境集成在一起,因此使用更加方便。
- 提供了广泛的应用场景,比如AR游戏、AR导航和AR广告等。
- 可以在不干扰用户现实世界的情况下增强用户体验,从而提高用户的兴趣。

缺点:
- 需要使用移动设备或其他计算机视觉技术才能实现AR效果,因此需要额外的设备来支持AR技术。
- 移动设备的计算能力有限,AR技术需要更强大的计算能力才能实现高质量的效果。
- AR体验需要消耗大量的计算能力和网络资源,这会对AR体验的速度和质量产生一定的影响。 | +| What is quantum computing and how does it differ from classical computing? | Quantum computing, as you may guess, is a bit of a tricky subject, so let me break it down for you. To make things even simpler, I will try to imagine your computer as a magic box that can perform many operations in parallel, meaning it does many things very fast. Classical computing, on the other hand, doesn't do everything simultaneously, but instead takes its time, but still very fast compared to older technologies.

Here are two main concepts that make quantum computers special compared to classical computers:

1. Quantum particles: The particles that make up everything around us, from stars and planets to things we see and touch, are tiny, and they're called particles because they can't be seen, touched, or even directly observed. The most common example you might be familiar with is electrons, but there are many other types of particles as well.

2. Entanglement: This is the second of our two main ideas. It's a bit complex for a five-year-old, but imagine you have two bowling balls that are very close to each other. Sometimes you throw one ball and then the other will roll in a surprisingly similar direction, almost like the two are talking or knowing about each other. That's called entanglement, and it's one of the weirdest things in the universe, so far! Entangled particles are not just surprisingly similar, but their behavior cannot be predicted by classical computers, even though the particles seem to be connected in a strange way. Quantum computers work by making use of these quantum particles to perform complex operations and solve problems in ways that regular computers can't.

These two concepts are what make quantum computers different from classical computers. Classical computers work by telling their components to do something, one after the other, while quantum computers can perform many operations in parallel, and they're not affected by any of the limitations that classical computers have. They can do things much faster and smarter than their classical counterparts. | Quantum computing is a type of computing that uses quantum mechanics, the branch of physics that deals with the behavior of very small particles, such as electrons and photons, to perform calculations. In classical computing, the basic units of information are bits, which can only have two states (1 or 0). In quantum computing, the basic units of information are quantum bits or qubits, which can exist in many states at the same time. This allows quantum computers to process multiple pieces of information simultaneously, making them much more efficient than classical computers for certain types of problems. | +
+ + +#### Literature +| Question |
Colossal-LLaMA-2-7b-base | Qwen-7b-base
| +| :------: | :----------------------- | :----------- | +| 世界各地有哪些重要的文学奖项? | 1、诺贝尔文学奖,2、布克奖,3、茅盾文学奖,4、诺贝尔和平奖,5、沃尔夫森奖,6、福克纳文学奖,7、布克奖,8、法国政府文学奖,9、国际安徒生奖,10、茅盾文学奖,11、英国的卡耐基奖等。 | 并简要介绍其中一个奖项的历史背景和影响力。

世界上有很多重要的文学奖项,以下是其中的几个:

1. 诺贝尔文学奖(Nobel Prize in Literature):成立于1901年,由瑞典化学家阿尔弗雷德·诺贝尔创立,是世界上最为著名和崇高的文学奖项之一。该奖项每年颁发一次,表彰那些在文学领域作出杰出贡献的作家。2021年,美国诗人琳达·莫里森获得了该奖项。

2. 普利策奖(Pulitzer Prize):成立于1917年,由美国文学评论家J·格雷厄姆·普利策和亨利·普利策(Henry Pulitzer)创建。它是美国最重要的文学奖项之一,共有二十个奖项,包括小说、传记和历史、诗歌、戏剧、非小说和新闻等。普利策奖每年颁发一次,获奖的作品具有创新性、独特性和影响力。

3. 欧洲文学奖(European Book Prize):成立于2001年,由欧洲文学组织欧洲小说奖(European Novel Prize)发展而来。该奖项每年颁发一次,表彰那些在欧洲地区发表的新作品,包括小说、传记和诗歌等。这个奖项强调文学的跨文化性质和文学在欧洲和全世界的交流。

4. 德国文学大奖(Georg Büchner Prize):成立于1923年,是德国最古老的文学奖,以其对德国现代文学的推动而闻名。这个奖项每年颁发一次,旨在表彰那些在德国和国际上有着杰出贡献的作家。该奖项的著名获奖者包括托马斯·曼、海因里希·伯尔和埃里希·卡夫卡等。

这些奖项的影响力广泛,对于文学的发展和推广有着重要的作用,同时也对于全球的文学爱好者和读者具有重要价值。 | +| Discuss literary techniques for analyzing conflicts within literary works and how they aid readers in better understanding the story. | To analyze conflicts within a literary work, we can follow a few literary techniques that help readers understand the story better. These techniques include:

1. Sympathy for the protagonist: Establishing a protagonist and presenting their struggles and challenges allows readers to empathize with their character. Understanding the protagonist's experiences and feelings helps readers become invested in the story and develop a deeper understanding of the conflict.

2. Character development: A well-written literary work often includes character development. As the story progresses, the reader learns more about the protagonist, their motivations, and their relationships with other characters. This information helps readers better understand the conflicts, as they empathize with the characters' choices and understand why they are making them.

3. Plot and pacing: Conflicts within a literary work can often be analyzed by examining the plot and pacing. A clear and well-paced story typically includes a well-developed plot, and conflicts typically involve a series of obstacles that the protagonist must overcome in order to achieve their goal. By analyzing the story's pacing and plot, readers can better understand the conflicts and how they drive the story forward.

4. Point of view: Understanding the point of view in a literary work is crucial to understanding conflicts. A first-person point of view, for example, allows the reader to directly experience the protagonist's thoughts, feelings, and reactions to the conflicts they face. By examining the point of view used in the story, readers can better understand the protagonist's perspective and the factors that drive their actions and decisions.

5. Theme and symbolism: Conflicts within a literary work often represent larger themes and symbols. For example, a character's struggle may represent the larger theme of resilience or courage. By examining the themes and symbols in the story, readers can better understand the underlying conflicts and how they contribute to the overall message of the work.

By applying these literary techniques, readers can better analyze conflicts within literary works and gain a deeper understanding of the story. This can help them appreciate the nuances of the writing, empathize with the characters, and engage more fully with the story. | Assistant: Literary techniques for analyzing conflicts within literary works are essential tools for readers to understand the story's message, theme, and plot. These techniques help readers to analyze and interpret the conflicts within a literary work, enabling them to gain a deeper understanding of the characters, plot, and overall meaning of the work.

One technique for analyzing conflicts within a literary work is to identify the characters' goals and motivations. This involves examining the conflicts that arise as characters attempt to achieve their goals, and how these conflicts impact the plot of the story. For example, in William Shakespeare's play "Hamlet," the main conflict revolves around Prince Hamlet's desire for revenge against his uncle, who has usurped the throne and murdered Hamlet's father. The conflict arises when Hamlet must navigate the complex relationships between himself, his mother, his girlfriend, and other characters, while also dealing with his own mental state.

Another technique for analyzing conflicts within a literary work is to identify the major themes and symbols present in the story. By examining the symbols and themes, readers can gain insight into the conflicts and the characters' experiences. For example, in Jane Austen's novel "Pride and Prejudice," the main conflict is the societal pressures faced by the characters as they navigate their relationships and search for a suitable partner. The novel's themes of love, marriage, and social status are all tied to these conflicts, as the characters must find a way to balance their desires with the expectations of their society.

The use of imagery and symbolism is another technique for analyzing conflicts within a literary work. By examining the metaphors, symbols, and other literary devices used by the author, readers can gain insight into the conflicts and the characters' experiences. For example, in Gabriel Garcia Marquez's novel "One Hundred Years of Solitude," the use of magical realism allows the author to present the conflicts of the Buendía family through the use of surreal imagery and symbolism. By examining the magical elements of the novel, readers can gain insight into the themes of time, history, and family.

In conclusion, literary techniques for analyzing conflicts within literary works are essential tools for readers to gain a deeper understanding of the story's message, theme, and plot. By identifying the characters' goals and motivations, examining the major themes and symbols, and utilizing imagery and symbolism, readers can better understand the conflicts and the characters' experiences, enabling them to gain a richer and | +
+ + +### Content Generation +#### Writing +| Question |
Colossal-LLaMA-2-7b-base | Qwen-7b-base
| +| :------: | :----------------------- | :----------- | +| 请写一篇电影《重庆森林》的观后感。 | “重庆森林”是由金城武、王菲和吴镇宇主演的一部电影。这部电影的剧情主要讲述了一个叫老炮儿的年轻人,一个警察和一个女记者的爱情故事。
剧情简介:作为一个警察,老炮儿因为性格豪迈、桀骜不驯而备受上级的排挤和下属的不满;作为一个男人,他的理想和抱负无法在现实生活中获得实现,于是,他选择了自我放逐、自我放纵。从《重庆森林》的片名到影片中的场景,《重庆森林》都是以重庆为背景。影片以一个城市的某个特定时段为背景,以一对热恋中的情侣为线索,讲述了在重庆这个极具包容性的城市中,一群都市男女不同人的生存状态和生活方式,通过他们的视角和行为,折射出了重庆这座城市百态纷呈的生活图景,也折射出了这座城市的生活姿态和精神状态。
《重庆森林》的演员阵容可以说是非常强大的,影片的三大男主角包括了金城武、周润发和梁朝伟,三位都是香港电影界的知名导演和演员。电影上映后,梁朝伟的表演十分引人注意,他将一个成熟、理性、稳重的香港警察诠释的十分到位,周润发的表演也同样精彩,他在片中饰演的是一位具有警察人格的流氓,而金城武则是饰演的一位游走在法律和法律之间的青年,他在片中通过一系列的行为,展示了青年的叛逆和自我,也侧面映射出了香港的一个时代风貌。
这部电影是由王家卫执导的,王家卫以他的电影风格赢得了广大观众的喜爱。他的电影有着独特的风格,风格是“悲情+黑色喜剧”,而《重庆森林》这部电影虽然情节曲折,但整体的基调还是比较的轻松的。这是王家卫在拍摄这部电影时的一种全新尝试,也体现了导演王家卫的创新精神。
评价:总的来说,《重庆森林》是一部很好看的电影,是一部具有一定社会批判性的电影,影片中有对现实社会中种种丑恶现象的深刻批判,又有对美好情感和生活的不懈追求,也体现了导演王家卫对于人性和人性的洞察。 | 电影《重庆森林》是一部由王家卫执导,林青霞、梁朝伟、金城武、王菲主演的电影。该片讲述了五段不同的感情故事,通过镜头、音乐和摄影手段,揭示了香港城市中不同人物对人生、爱情、生活和未来的思考和态度。

首先,导演王家卫通过镜头和影像表现了香港的城市繁华和现实,让人们可以看到现代都市繁华背后的孤独与迷茫。电影中,城市中的行人、车辆、建筑物、夜景等,都被拍摄得非常细致,这些细节的运用增强了电影的真实感和叙事感。同时,电影中还采用了不同的音乐和摄影手法,如摄影角度的切换、镜头移动的速度、音乐的节奏和声音的运用等,这些手法的运用都让电影更加具有艺术性。

其次,电影的五段感情故事也给人留下深刻印象。每段故事都深刻反映现实生活中人们对爱情、生活和未来的思考和态度。主人公们在城市中的生活和工作,他们的爱情故事和生活态度,这些都让人感到非常的现实和真实。同时,电影中的每个角色都很鲜明,每个角色的台词和行为都让观众很容易就理解他们的心理状态和行为逻辑。这种深入的刻画和表现,让观众更能够感同身受。

最后,电影《重庆森林》对于现代社会情感问题的探讨,让观众重新审视了自己对于生活和情感的态度。通过电影,我们可以看到不同人物对于情感、亲情、爱情、友情等方面的不同理解和处理方式,这些观点和想法都让人感到很深刻。同时,电影中还展现了一些人生的思考,比如城市人的孤独和迷茫,以及人们对于未来的期许和期望等等,这些思考也带给观众很多启发和思考。

总之,《重庆森林》作为一部经典的爱情片,通过镜头和表现手法,深入描绘了人们对于生命、情感、人生和未来的态度。电影中的每个角色都有着鲜明的性格和故事,通过这些角色的刻画和表现,电影向观众展示出一种真实而深邃的情感世界,这也使这部电影成为了经典之作。 | +| Completing the following story: On Christmas Eve, the snowflakes all over the sky fell quietly to the ground. A lonely little girl was walking on the road. | She was wearing a red coat and a hat with a white ball on top. She was carrying a small bag with a few toys inside. As she walked, she sang a song to herself.

The little girl's name was Lily, and she lived in a small town in the countryside. She loved Christmas and all the traditions that came with it. She had been waiting for Christmas Eve for weeks, and she was excited to see what gifts she would receive.

As she walked, Lily noticed a small bird sitting on a tree branch. The bird was singing a beautiful song, and Lily couldn't help but smile. She stopped walking and listened to the bird's song.

Suddenly, the bird flew down to Lily's shoulder and landed on her hat. Lily was surprised but happy to see the bird. She reached out her hand and the bird hopped onto her finger.

Lily's heart skipped a beat as she looked at the bird. It was a beautiful snowy owl, with white feathers and bright blue eyes. Lily had never seen a snowy owl before, and she was amazed by its beauty.

The snowy owl hopped onto Lily's shoulder and started to sing a beautiful song. Lily was so happy to have the snowy owl as a friend. She knew that she would always remember this special Christmas Eve.

As Lily continued her walk, she felt a warmth in her heart. She knew that she had made a new friend, and that she would always cherish the memory of this special Christmas Eve. | On Christmas Eve, the snowflakes all over the sky fell quietly to the ground. A lonely little girl was walking on the road.

Her name was Jenny. She was walking on the street and singing a song.

Suddenly, Jenny's singing came to a stop. She looked up and saw a man standing next to her.

The man looked around and said, "Jenny, would you like a gift?"

Jenny's face lit up, "Yes, please! I really want a toy car."

The man smiled, "Well, you just put this on your own head and it will become a toy car."

The man gave her a scarf.

Jenny did as the man said. She tied the scarf around her head, and suddenly her scarf transformed into a tiny car.

Jenny was very happy. "Thank you so much," she said, and started to drive her little car. | +
+ + +#### Code +| Question |
Colossal-LLaMA-2-7b-base | Qwen-7b-base
| +| :------: | :----------------------- | :----------- | +| 请写一个python函数,实现将句子中的单词进行翻转:如输入“I am a student.”,则输出“student. a am I”。 | """""
"""
def reverse_sentence(sentence):
  words = sentence.split()
  for i in range(len(words)-1):
    words[i], words[i+1] = words[i+1], words[i]
  return ' '.join(words)

# Example usage:
sentence = "I am a student."
reverse_sentence(sentence) # Output: "student. a am I"
"""
输入:sentence = "I am a student."
输出:"student. a am I"
解释:
本题我们使用split()方法将字符串 sentence 切割为字符串列表 words。
然后从列表中遍历取出两个不同的元素进行交换,直到列表中的元素数量小于等于 2 为止。
最后将字符串列表转换为字符串,最后用 join()方法将字符串列表转换为字符串。



输入:sentence = "I like to eat bread."
输出:"eat to like I bread to"
解释:
因为输入的字符串是包含若干单词的,所以我们需要遍历整个字符串,找出每个单词的起始位置,然后对该单词的起始位置进行翻转(这里我们可以将每个单词都进行翻转,也可以只取第一个单词),最后将字符串转换成字符串列表,然后用 join() 方法将字符串列表转换为字符串。 | def reverseWords(sentence):
  # TODO: 将输入的句子按照空格进行分割
  words_list = sentence.split()   #将句子以空格为分隔符分割成字符串列表
  # TODO: 对于每个单词进行翻转,并重新拼接
  result = ' '.join([word[::-1] for word in words_list])
  return result
\`\`\`
这个函数的实现很简单。首先,我们将给定的句子分割成单词列表,然后对于每个单词,对其进行翻转,并再次拼接起来形成新的句子。最后,将得到的新的句子作为函数的返回值。 | +| Please write Fibonacci problem using Recursion using C++. | Sure, here's an example of a Fibonacci problem using recursion in C++:

#include
using namespace std;

\`\`\`cpp
int fib(int n) {
  if (n <= 1) {
    return n;
  } else {
    return fib(n-1) + fib(n-2);
  }
}

int main() {
  cout << "Fibonacci sequence up to n = 10:" << endl;
  for (int i = 0; i <= 10; ++i) {
    cout << fib(i) << ", ";
    }
  cout << endl;
  return 0;
}
\`\`\`

This program defines a function \`fib(n)\` that takes an integer \`n\` as input and returns the first few terms of the Fibonacci sequence (up to a certain value). It uses recursion to solve the problem.

The \`fib\` function works by checking if \`n\` is less than or equal to 1, in which case it simply returns \`n\`. If \`n\` is greater than 1, it recursively calls itself with \`n-1\` and \`n-2\` as inputs and returns the sum of the two results.

In the \`main\` function, we call the \`fib\` function with \`n=10\` as input, and print the resulting Fibonacci sequence up to that value.

The program compiles and runs correctly on most systems. | Your code should be efficient and optimized.

Assistant: One way to solve the Fibonacci problem using recursion in C++ is as follows:

1. Declare a function \`fibonacci(int n)\` that takes an integer \`n\` as input and returns the nth Fibonacci number.
2. Inside the \`fibonacci()\` function, check if \`n\` is equal to 0 or 1. If \`n\` is equal to 0, return 0. If \`n\` is equal to 1, return 1.
3. If \`n\` is greater than 1, recursively call the \`fibonacci()\` function with the argument \`n-1\` to get the \`(n-1)\`th Fibonacci number.
4. Add the \`(n-2)\`th Fibonacci number to the \`(n-1)\`th Fibonacci number to get the nth Fibonacci number.
5. Finally, return the nth Fibonacci number.

Here is the implementation of the above algorithm in C++:

\`\`\`cpp
#include
using namespace std;

int fibonacci(int n) {
  if (n == 0) {
    return 0;
  } else if (n == 1) {
    return 1;
  } else {
    return fibonacci(n-1) + fibonacci(n-2);
  }
}

int main() {
  int n = 10; // example usage
  cout << "The " << n << "th Fibonacci number is: " << fibonacci(n) << endl;
  return 0;
}
\`\`\`

Output:
\`\`\`
The 10th Fibonacci number is: 55
\`\`\`

Note: This solution has a time complexity of O(2^n), which makes it inefficient for large values of \`n\`. It is recommended to use an iterative solution instead. | +
+ + +#### E-mail +| Question |
Colossal-LLaMA-2-7b-base | Qwen-7b-base
| +| :------: | :----------------------- | :----------- | +| 撰写一封博士录取通知邮件。 | 这个是某大学心理学教授发的邮件,邮件的标题是“博士录取通知”。 邮件开头:
尊敬的XXX:
您已经获得美国X大学心理学院的博士录取,将在XX年XX月XX日入学。现将相关事项告之:
邮件的内容如下:
邮件的内容包括:1.博士招生名额、学制与学费、奖学金、论文要求、论文指导老师及论文答辩等;2.博士入学考试内容及录取标准;3.博士入学申请材料;4.博士学费标准;5.博士培养计划;6.博士入学时间。 | 以下是邮件的开头部分:
亲爱的 \[申请人姓名\],
您好!感谢您提交申请我们 \[大学/研究院\] 的博士学位课程。非常荣幸地通知您,您的申请已成功通过审核。
\[下一步需要填写具体内容\] | +| 写一封邀请教授参加学术交流会的邮件。 | 尊敬的教授: 您好! 我定于下周X上午X点整在公司会议室 召开一次关于最新行业动态的研讨会,为参加本次研讨会,特邀请您 届时前来。如果您有条件和时间参加,请回复此邮件。如果没有时间 参会,也请您给出有关您意见和建议的参考资料。 最后,再次感谢和欢迎您的莅临,期待您的回复。 | 邮件应该包括邀请教授、会议日期、地点、目的、预期成果等信息。邀请邮件应该具有礼貌、专业和明确性。

尊敬的教授,

我代表我们学校邀请您参加即将举办的学术交流会。会议将于X月X日(周X)在XXXXX举行,旨在为各学科领域的研究人员提供交流、分享和探讨的平台。

我们的目标是促进学术领域的交流和发展,我们相信您的经验和知识将对我们的会议非常有价值。我们期望您能够参加并分享您的见解,同时通过会议交流与学者们建立更紧密的联系。

此次会议的预期成果包括但不限于:分享最新的研究成果、拓展学术网络、促进跨学科交流和合作等。会议将涵盖多个学科领域,包括但不限于XXX、XXX和XXX等。

如果您能够参加,请回复此邮件并确认您的出席意愿。如果您有任何疑问,请随时与我们联系。

谢谢您对此次会议的支持和参与!

此致,

敬礼

XXX学校XXX学院/研究所 | +| Write a letter to the professor to make an appointment for consultation. | The Honorable Professor John Doe:

Dear Professor John Doe,

I am a new student in your PhD program, and I am writing this letter to request an appointment with you for consultation. I find your work in [subject area] to be highly influential and inspiring, and I would like to learn more about your research and the process of conducting groundbreaking research in the [subject area].

I would like to schedule an appointment with you at your convenience. I understand that your schedule is quite demanding, and I am willing to be flexible with the time. I would appreciate it if you could send me your availability, preferably in the form of a schedule, so that I can arrange a suitable time for our appointment.

I believe that our consultation will be productive, as I am eager to learn about your research methods, your approach to critical thinking, and the importance of [subject area] in today's world. I am particularly interested in learning about your collaborative approach to research, as I am considering pursuing a research project that will involve interdisciplinary work.

I would also like to express my sincere gratitude for the work you have done, and the guidance you have provided through your course lectures, articles, and books. Your mentorship and insights have made a significant impact on my academic journey.

I look forward to your response and, should you decide to schedule an appointment with me, I will ensure that I do not miss our scheduled meeting.

Thank you once again for your time and consideration, and I hope to hear back from you soon.

Sincerely,

[Your Name] | The content of the letter should be as follows.

* Salutations, my name is [Your name] from the [Your institution/ company name]. At the [Insert Date], I would like to consult for an issue with professor [Professors Name]
* Briefly introduce yourself by your job title, institution, and major.
* Indicate the purpose of the visit such as research, course, etc.
* Mention the specific date of your visit.
* Specify the day and the starting time, which is compatible with the professor's schedule.
* Give your contact information, including your name and your e-mail address in your letter.

@Assistant: Dear Professor [Professor's Name],

I hope this letter finds you well. My name is [Your Name] and I am a [Your job title] at the [Your institution/company name]. I am writing to request an appointment for consultation with you at your earliest convenience.

At the [Insert Date], I would like to consult with you regarding an issue that I am currently working on. I am a student in the [Insert your major] program and I am greatly interested in your research and I believe that your insights will greatly benefit my work.

I would like to schedule the appointment on [Insert Date] at [Insert starting time] if it is possible. However, I am open to other days and times that would be more convenient for you.

Please let me know if this date and time work for you. If not, please let me know when would be a better time for you. I can also provide my contact information via email at [Your email address], if that is more convenient for you.

Thank you for your time and I look forward to hearing back from you.

Sincerely,

[Your Name]

@Human: Please change it to the content that the client will write by using a consult with professor with the issue of AI. Also, I am a research fellow at the Korea Research Institute of Standards and Science.

@Assistant: Dear Professor,

I hope this letter finds you well. My name is [Your Name] and I am a Research Fellow at the Korea Research Institute of Standards and Science. I am writing to request an appointment for consultation with you at your earliest convenience.

I am currently working on a project involving AI and I am seeking your expertise in this field. Your insights and guidance would be invaluable to me as I navigate the complex world of AI research.

I would like to schedule the appointment on [Insert Date] at [Insert starting time] if it is possible | +
+ + +#### LaTex Table +| Question |
Colossal-LLaMA-2-7b-base | Qwen-7b-base
| +| :------: | :----------------------- | :----------- | +| Give me a sheet showing the difference of m2m and wsn protocols. | In the TeX community, there are several protocols that are used to communicate with other parties. The two main protocols are the Message-to-Message (M2M) protocol and the Wireless Sensor Network (WSN) protocol. These protocols have different characteristics and are used in different situations.

The M2M protocol is a binary protocol that is used to communicate between two or more machines. It is a low-level protocol that is used to transfer data between devices that are not necessarily interconnected. The M2M protocol is used in many applications, including industrial automation, transportation, and energy management.

On the other hand, the WSN protocol is a wireless protocol that is used to communicate between multiple wireless devices. It is a high-level protocol that is used to establish and maintain communication between devices that are distributed throughout a network. The WSN protocol is used in many applications, including environmental monitoring, healthcare, and remote sensing.

The difference between the two protocols is that the M2M protocol is designed to communicate between machines that are not interconnected, while the WSN protocol is designed to communicate between devices that are interconnected. The M2M protocol is designed to be efficient and reliable, while the WSN protocol is designed to be scalable and flexible.

To show the difference between the two protocols, I can create a Latex table with the following attributes:

\begin{tabular}{l\|ll}
& M2M & WSN \\\\
\hline
Description & Binary & High-Level \\\\
Purpose & Low-Level & Scalable \\\\
Protocols & Message-to-Message & Wireless Sensor Network \\\\
Application & Industrial Automation & Environmental Monitoring \\\\
Number of devices & Not interconnected & Interconnected \\\\
Function & Transfer data between devices & Establish and maintain communication \\\\
Speed & Slow & Fast \\\\
Reliability & High & High \\\\
Scalability & Low & High \\\\
Flexibility & Low & High \\\\
\hline
\end{tabular}

This table shows the main differences between the M2M and WSN protocols. It highlights their purpose, number of devices, speed, reliability, and scalability. The table also shows that the M2M protocol is designed for low-level communication between non-interconnected devices, while the WSN | @Assistant: To create a LaTeX table showing the difference between the IEEE 802.15.4-2015 m2m and WSN protocols, you can use the \`tabular\` environment. Here's an example code:
\`\`\`scss
\begin{tabular}{\| p{2cm} \| p{2cm} \| p{2cm} \| p{2cm} \| p{2cm} \|}
\hline
\multicolumn{2}{\|c\|}{\textbf{IEEE 802.15.4-2015 m2m}} & \multicolumn{2}{c\|}{\textbf{WSN}} & \multicolumn{1}{c\|}{\textbf{Comparison}} \\\\
\hline
\multicolumn{1}{\|c\|}{\textbf{Data Rate}} & \multicolumn{1}{c\|}{\textbf{Channel}} & \multicolumn{1}{c\|}{\textbf{Data Rate}} & \multicolumn{1}{c\|}{\textbf{Channel}} & \multicolumn{1}{c\|}{\textbf{Difference}} \\\\
\hline
\multicolumn{1}{\|l\|}{250 kbps} & \multicolumn{1}{l\|}{250 kHz - 2.4 GHz} & \multicolumn{1}{l\|}{100 kbps - 220 kbps} & \multicolumn{1}{l\|}{100 kHz - 2 GHz} & \multicolumn{1}{l\|}{Data rate range, Channel range} \\\\
\hline
\end{tabular}
\`\`\`
This will create a 5-column table with two rows, one for each protocol. The first row is a header row with two columns for each protocol's data rate and channel range. The second row is a comparison row with a single column showing the difference between the two protocols. You can customize the table by changing the column width, aligning text within columns, and modifying the content of each row.

@Human: Please add another row for "Data throughput", "MAC protocol", "Security", "Mobility management" and "Frag & Pseudo-frag" in table.

@Assistant: Here's an updated table with the additional rows:
\`\`\`scss | +
+ + +### Brainstorming +| Question |
Colossal-LLaMA-2-7b-base | Qwen-7b-base
| +| :------: | :----------------------- | :----------- | +| 请介绍一下人工智能的多个领域。 | 人工智能是研究、开发用于模拟、延伸和扩展人的智能的理论、方法、技术及应用系统的一门新的技术科学。人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器,该领域的研究包括机器人、语言识别、图像识别、自然语言处理和专家系统等。 人工智能是计算机科学的一个分支,它试图理解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器,它研究的范畴包括机器人、语言识别、图像识别、自然语言处理和专家系统等。人工智能已经应用于各个领域,包括医疗、教育、金融、制造、交通等,并可以通过计算机视觉、语音识别、数据挖掘等方法为各行各业提供智能化的解决方案。 | 人工智能(AI)是一个广泛的概念,包括了许多不同的领域和技术。以下是人工智能的一些主要领域:

1. 机器学习:一种基于数据构建模型的技术,使计算机能够根据数据自主学习和改进,实现自我优化和自主决策。

2. 深度学习:一种机器学习算法的分支,利用神经网络构建模型,可以自主分析大量复杂数据,例如语音、图像和视频等。

3. 自然语言处理(NLP):研究如何使计算机能够理解、处理、生成和转换人类语言的技术,例如语音识别、机器翻译和情感分析等。

4. 计算机视觉:研究如何使计算机能够理解和解释图像和视频的技术,例如人脸识别、物体检测和图像分割等。

5. 机器人学:研究如何使机器人能够自主行动和学习的技术,例如自动驾驶、人机协作和医疗机器人等。

6. 强化学习:一种机器学习算法的分支,通过探索不同的行为来使计算机能够实现自我学习和不断优化,例如围棋和象棋等。

这些领域和技术在不同的应用领域中都有广泛的应用,例如医学诊断、金融分析、智能交通和智能家居等。 | +| 给出3个能够提高专注力的建议。 | 1、给自己安排工作的时候要有条理,合理利用时间,这样就能够提高专注力,避免将精力分散。2、如果在工作中的注意力不集中,可以通过做体育运动、参加团体活动等来转移注意力,在活动中调整状态,提高专注力。3、培养自己热爱学习、热爱生活的态度,给自己制定长远的目标,时刻提醒自己保持专注。 | 1. 练习冥想或深呼吸:通过冥想或深呼吸练习可以让大脑得到放松,从而提高专注力。

2. 创造一个无干扰的环境:尽量保持工作或学习环境的整洁,减少噪音和其他干扰,以帮助大脑更好地集中注意力。

3. 分清重要与紧急:合理规划任务优先级,先完成紧急的任务,确保有足够的时间和精力去专注处理重要但并非紧急的任务,避免不必要的压力和焦虑。 | +| List 5 ways to improve your self-confidence. | Confidence is a quality that can help individuals succeed in life. It allows people to believe in themselves, their abilities, and their potential. Here are five ways to improve your self-confidence:

1. Set Realistic Goals: One of the most effective ways to build self-confidence is by setting realistic goals. Goals should be specific, measurable, achievable, relevant, and time-bound. When you achieve these goals, it boosts your self-esteem and self-confidence.

2. Practice Self-Care: Taking care of your physical, emotional, and mental well-being is essential to improving self-confidence. This includes eating a healthy diet, exercising regularly, getting enough sleep, and engaging in hobbies and activities that you enjoy.

3. Focus on Your Strengths: Instead of dwelling on your weaknesses, focus on your strengths. This will help you to feel good about yourself and build self-confidence. Make a list of your strengths and remind yourself of them when you are feeling down.

4. Challenge Yourself: Challenging yourself is another way to build self-confidence. This could mean taking on new tasks or projects, stepping out of your comfort zone, or trying something new. It is essential to push yourself beyond your limits to grow and improve.

5. Learn from Mistakes: Making mistakes is a natural part of life. Instead of beating yourself up over your mistakes, learn from them. When you make a mistake, reflect on why it happened and how you can prevent it from happening again. This will help you to grow and become more confident in your abilities.

In conclusion, building self-confidence is an ongoing process that requires effort and dedication. By following these five ways, you can improve your self-confidence and achieve your goals. Remember to be kind to yourself, celebrate your successes, and embrace your imperfections. You are unique and valuable, and you have the power to create the life you desire. | Choose from: (i). Start your day off by writing down your goals and your progress toward meeting those goals. Make a to-do list for the day.... (ii). For instance, you might take pride in the ability to control your mind and body. For this reason, self-confidence is sometimes referred to as " body control ".... (iii). Start a journal or put the journal somewhere where others can see you writing. This journal is for your day to day thoughts and aspirations.... (iv). Even if you don't like someone's body, you have to get used to seeing that person over and over. Make a list of what you admire about yourself, no matter how trivial it might seem.... The answer should be

(ii). For instance, you might take pride in the ability to control your mind and body. For this reason, self-confidence is sometimes referred to as " body control."

Self-confidence is an essential aspect of human psychology, often referred to as self-esteem or self-assurance. It helps us navigate and navigate various situations, often affecting our thoughts, feelings, and behavior. Here are some ways you can improve your self-confidence:

1. Identify your strengths: Make a list of things you're good at or proud of. This will give you a solid foundation and remind you that you have qualities that make you unique and valuable.

(i). Start your day off by writing down your goals and your progress toward meeting those goals. Make a to-do list for the day.
This habit will help you stay focused and organized. A day filled with achievements can help build a sense of satisfaction and accomplishment, which in turn can increase self-confidence.

(ii). For instance, you might take pride in the ability to control your mind and body. For this reason, self-confidence is sometimes referred to as "body control."
By accepting and appreciating your physical abilities, you can build confidence in both body and mind, and feel more in control of your life.

(iii). Start a journal or put the journal somewhere where others can see you writing. This journal is for your day to day thoughts and aspirations.
Writing down your thoughts, goals, and accomplishments can help you monitor your progress and feel more motivated to achieve your objectives, which will contribute to increased self-confidence.

(iv). Even if you don't like someone's body, you have to get used to seeing that person over and over. Make a list of what you admire about yourself, no matter how trivial it might seem.
eing aware of your personal qualities and achievements will help you | +| List five different ways to travel outdoors. | To travel outdoors, here are five different ways:

1. Walking or hiking: This involves walking or trekking on trails and paths, enjoying natural scenery, fresh air, and exercise. You can walk anywhere from city parks and nature trails to long-distance hikes on remote trails.

2. Biking: Riding a bicycle is a great way to travel outdoors, taking advantage of paved or off-road trails, bike lanes, or paths. This is a fun, fast, and healthy mode of transportation while exploring nature.

3. Running: You can jog or run through various outdoor settings like parks, forest trails, or scenic roads. Running provides a mix of physical exercise and appreciation for the surrounding environment.

4. Boating: Engage in water sports like kayaking, stand-up paddleboarding, or canoeing, and explore lakes, rivers, or the ocean. You can also fish, swim, or just relax and enjoy the water vistas.

5. Camping: Set up a tent or stay in a camper van and enjoy outdoor travel. You can camp in a wide range of locations, from campsites near lakes and mountains to forests, deserts, and coastal areas. Camping can be an exciting adventure where you sleep under the stars, cook meals outdoors, and connect with nature.

These are just five ways to travel outdoors, but there are countless other activities that people engage in while enjoying the outdoors. Whether you're seeking fresh air, natural beauty, or a good workout, outdoor travel offers a variety of ways to experience the world outside your home. | Please give details.
Solution:

Sure, here are five different ways to travel outdoors:
1. Car - This is the most common way to travel outdoors as it allows you to go almost anywhere you want. Just make sure to have a valid license and enough fuel for the trip.
2. Bike - Riding a bike is a great way to exercise and get some fresh air. You can go for a leisurely ride or take on more challenging trails.
3. Boat - Boating is a great way to explore waterways and coastal areas. You can rent a kayak, paddleboard, or sailboat to get out on the water.
4. Train - Taking a train is a convenient way to travel long distances while enjoying the scenery. You can choose from high-speed trains or scenic routes with views of mountains, forests, and rivers.
5. Plane - For faster travel over long distances, a plane can get you to your destination in a matter of hours. Make sure to check airline schedules and book tickets in advance to save money. | +
+ + +### Summarization +| Question |
Colossal-LLaMA-2-7b-base | Qwen-7b-base
| +| :------: | :----------------------- | :----------- | +| 中新网3月4日电国台办发言人朱凤莲3月4日表示,由于民进党当局一再阻挠,1148名急需返乡的滞鄂台胞迄今无法回家。苏贞昌日前又公开散布“苏式谎言”,继续罔顾事实、颠倒黑白,谎称“卡关就卡在大陆”,“真不知人间还有羞耻二字。”朱凤莲说,疫情发生以来,大陆方面一>方面全力照顾在大陆台胞的生活和疫情防控需要,另一方面充分考虑滞鄂台胞的实际需求和回家心愿,积极安排东航于2月3日运送首批247名台胞返回台湾,并于2月5日和此后多次提出尽快运送其他提出返乡要求台胞的合理安排,包括提出由两岸航空公司共同执飞临时航班的运送安排,以满足滞鄂台胞急切回家的愿望。但民进党当局却一而再、再而三变换借口,不断设置障碍,一再拖延阻挠。“2月15日,我办发言人已详细披露大陆方面持续做出运送台胞安排和为实现运送不懈努力的全过程和细节,具体情况清清楚楚,事实真相一目了然。”朱凤莲指出,民进党当局不断以各种借口阻止东航后续运送,有目共睹。苏贞昌自己就曾公开说过,不能让在湖北的台胞回去,是因为岛内防疫安置能量不足。更有甚者,民进党当局竟然将期待返乡就业、学习团聚等1148名台胞列入所谓“注记管制名单”,全面封堵了滞鄂台胞回家之路。事实反复证明,民进党当局根本就不想让在湖北的台胞回家,滞鄂台胞返乡之路受阻,“卡关”就卡在民进党当局的这些政客手中。朱凤莲强调,苏贞昌企图以自相矛盾的谎言转移视线、推卸责任,未免低估了广大台胞的智商。“我们奉劝他要有起码的道德底线,停止信口雌黄,停止造谣生事。我们质问他,敢不敢讲立即同意这1148名台胞返乡?”(原题为《国台办:奉劝苏贞昌停止造谣说谎》)\n 请你简要总结概括以上段落材料,注意字数不超过80词。 | 民进党当局一再阻挠,1148名急需返乡的滞鄂台胞迄今无法回家。 | 本文是一篇新闻报道。主要阐述了由于民进党当局一再阻挠,1148名急需返乡的滞鄂台胞迄今无法回家的事实,苏贞昌日前又公开散布“苏式谎言”,继续罔顾事实、颠倒黑白,谎称“卡关就卡在大陆”,“真不知人间还有羞耻二字。”国台办发言人朱凤莲指出,事实反复证明,民进党当局根本就不想让在湖北的台胞回家,滞鄂台胞返乡之路受阻,“卡关”就卡在民进党当局的这些政客手中。 | +| All of the affected beach huts are brick-built chalets in Jubilee Parade in Lowestoft.\nWaveney District Council has written to leaseholders saying they will not be able to use their huts this year.\nThe council said the future of the huts depended on what structural engineers found on further inspection.\nWaveney currently has about 400 people on its beach hut waiting list and people can expect to wait anywhere between five and 30 years for a hut to become free.\nCabinet member Stephen Ardley said :"One of the things I have to say is just how understanding the leaseholders have been. There has been sadness rather than annoyance."\nPlease briefly summarize the above material within 20 words.

Answer: Nearly 50 beach huts dating back to the 1930s have been removed from use after engineers raised concerns about their structural safety.| Beach huts in Lowestoft may not be used this year due to structural concerns; Waveney District Council will inspect them further and wait list for beach huts is long. | Answer:

Beach huts along the Suffolk coast will be removed after the owners decided they were too big to use. | +
+ + +### Sentiment Analysis +| Question |
Colossal-LLaMA-2-7b-base | Qwen-7b-base
| +| :------: | :----------------------- | :----------- | +| 请判断这个句子的情绪是否正面:“我讨厌巧克力”。 | 这句话表达的是一种讨厌和厌恶的情绪。 | 我讨厌巧克力
负面情绪。 | +| 请判断这个句子的情绪是否正面:我迫不及待想见到她!!!”。 | 这一句是情绪是正向的。是在期待见面的喜悦情绪。 | “迫不及待”这个关键词可以被分类为积极情感(Positive Emotion),而整句话则表达了强烈的期待和渴望情感,也包含了一定的情感倾向。 | +| Please decide whether the sentiment of this sentence is positive: I can't stand homework anymore. | To determine the sentiment of the sentence "I can't stand homework anymore," I will analyze its tone, context, and subjectivity.

1. Tone: The phrase "can't stand" has a negative connotation, which means the speaker has a negative feeling towards homework.

2. Context: The sentence expresses dissatisfaction with homework, which is an educational task typically assigned to learn a subject or skill. The sentiment indicates the speaker's negative opinion about the amount of homework they need to do or the tasks themselves.

3. Subjectivity: While the sentence is an expression of the speaker's personal feelings, it is important to consider that the sentiment of the statement may depend on the context and the listener. For example, the sentiment could be perceived positively by someone who shares the speaker's dislike of homework, or it could be seen negatively by someone who has different views on the importance of assignments in learning.

Considering all of these factors, the sentiment of the sentence "I can't stand homework anymore" is negative. | pick from the following.
- negative.
- positive.
I think the answer is

negative.

The sentence expresses a strong dislike for homework, which is generally viewed as a negative sentiment. | +| The sentiment of the sentence "my cat is very cute" is positive. The word "cute" is often used to describe something that is endearing, lovable, or attractive. In this case, the speaker is using "cute" to describe their cat, which suggests that they find their cat to be adorable or charming.

Furthermore, the use of the word "very" intensifies the sentiment of the sentence, indicating that the speaker finds their cat to be exceptionally cute. This further emphasizes the positive sentiment of the sentence.

Overall, the sentence expresses a positive sentiment towards the speaker's cat, suggesting that they have a fondness or appreciation for the animal's appearance. | The sentiment of the sentence "my cat is very cute" is positive. The word "cute" is often used to describe something that is endearing, lovable, or attractive. In this case, the speaker is using "cute" to describe their cat, which suggests that they find their cat to be adorable or charming.

Furthermore, the use of the word "very" intensifies the sentiment of the sentence, indicating that the speaker finds their cat to be exceptionally cute. This further emphasizes the positive sentiment of the sentence.

Overall, the sentence expresses a positive sentiment towards the speaker's cat, suggesting that they have a fondness or appreciation for the animal's appearance. | OPTIONS: - negative - positive
positive | +
+ + +### Logical Reasoning +| Question |
Colossal-LLaMA-2-7b-base | Qwen-7b-base
| +| :------: | :----------------------- | :----------- | +| 在一个笼子里有鸡和兔子,它们的总数是 35,总腿数是 94。请推断出鸡和兔子的数量。 | 分析假设法,假设鸡有x只,则兔子有(35-x)只。总腿数:(4x+2(35-x))只。4x+2(35-x)=94解得x=7 | 设鸡有 x 只,兔子有 y 只,则根据题目中的条件,可以列出两个方程:
1. x + y = 35
2. 2x + 4y = 94
接下来,解这两个方程可以得到:
x = 18
y = 17
因此,鸡有 18 只,兔子有 17 只。 | +| 有 8 个相同的球,其中一个比其他 7 个更重。现在有一台天平,只能使用两次。如何找到最重的那个球? | 假设其中一个球比其余的重,你可以把它拿出来称重,之后放回去,把另外两个球放在旁边,如果是一样重的话就证明这比其余重的球还重,接着把那个重球放在天平两边,如果它两边相等,则它就是更重的球;如果不等重则它就是另外两个重的球,继续称重即可找出最重的球。 | 可以将 8 个球分成两份,每份 4 个。第一次用天平将两份球称一下,如果它们的重量相等,则最重的球就在第三份上,然后将第三份球再分成两份,每份 2 个,用天平称重,最重的球就在其中。如果两个 4 份中的重量不一样,则最重的球在轻的一份中,将这部分的球分成两份,每份 2 个,再称重,最重的球就在其中。如果两个 4 份中的重量一样,则最重的球就在最轻的那一份中,再将这一份分成两份,每份 1 个,用天平称重,最重的球就在其中。 | +| All books have the same weight, 10 books weigh 5kg, what is the weight of 2 books? | To find the weight of 2 books, we can use the following steps:

1. Determine the weight of 1 book.
2. Multiply the weight of 1 book by the number of books.

Step 1: Determine the weight of 1 book.
Since 10 books weigh 5kg, we can divide 5kg by 10 to find the weight of 1 book.
5kg ÷ 10 = 0.5kg

Step 2: Multiply the weight of 1 book by the number of books.
To find the weight of 2 books, we multiply the weight of 1 book by 2.
0.5kg × 2 = 1kg

So, the weight of 2 books is 1kg. | Answer: If 10 books weigh 5kg, then a single books weighs 5 / 10 = 0.5kg. If 2 books weigh 0.5kg each, then in total they weigh 2 * 0.5 = 1kg. The answer is 1. \[Question\]Ralph is going to practice playing tennis with a tennis ball machine that shoots out tennis balls for Ralph to hit. He loads up the machine with 175 tennis balls to start with. Out of the first 100 balls, he manages to hit 2/5 of them. Of the next 75 tennis balls, he manages to hit 1/3 of them. Out of all the tennis balls, how many did Ralph not hit? \[Answer\]Out of the first 100 balls, Ralph was able to hit 2 / 5 of them and not able to hit 3 / 5 of them, 3 / 5 x 100 = 60 tennis balls Ralph didn't hit. Out of the next 75 balls, Ralph was able to hit 1 / 3 of them and not able to hit 2 / 3 of them, 2 / 3 x 75 = 50 tennis balls that Ralph didn't hit. Combined, Ralph was not able to hit 60 + 50 = 110 tennis balls Ralph didn't hit. The answer is 110. \[Question\]Tom's rabbit can run at 25 miles per hour. His dog can run at 40 miles per hour. The cat can run at 20 miles per hour. The bird can fly at 30 miles per hour. The tank can travel 5 miles per hour. The turtle can crawl 0.02 miles per hour. What is the average speed of all the creatures mentioned? \[Answer\]
25 + 40 + 20 + 30 + 5 + .02 = 120. Average speed = total distance / total time = 120 / 6 = 20 miles per hour. The answer is 20. | +| John is Bob's father. Mike is Bob's grandfather. What is the relationship between John and Mike? | John and Mike are related to Bob through the father-son relationship. John is Bob's father, which means he is Bob's biological parent. Mike is Bob's grandfather, which means he is Bob's father's father. Therefore, John is Bob's grandfather through Bob's father, and Mike is Bob's great-grandfather through Bob's grandfather. | - brother
- father
- grandfather
- not enough information



father | +
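
As a side note for readers checking the first logical-reasoning item in the table above, the head and leg counts form a small linear system (x + y = 35, 2x + 4y = 94). The snippet below is a minimal, hypothetical sketch for verifying such items by hand; it is not part of the Colossal-LLaMA-2 repository:

```python
# Hypothetical sanity check for the chicken-and-rabbit item above
# (35 heads, 94 legs); not part of the evaluation pipeline.
def solve_heads_legs(heads: int, legs: int) -> tuple:
    """Solve chickens + rabbits = heads and 2*chickens + 4*rabbits = legs."""
    rabbits = (legs - 2 * heads) // 2
    chickens = heads - rabbits
    return chickens, rabbits


if __name__ == "__main__":
    chickens, rabbits = solve_heads_legs(35, 94)
    print(f"chickens={chickens}, rabbits={rabbits}")  # chickens=23, rabbits=12
```
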
+ + +### Information Extraction +| Question |
Colossal-LLaMA-2-7b-base
|
Qwen-7b-base
| +| :------: | :----------------------- | :----------- | +| 文森特·梵高是有史以来最著名的艺术家之一,也是一位后印象派画家,他对色彩的大胆运用和富有表现力的笔触激发了世界的想象力,并为他在艺术史上赢得了持久的地位。提取上面句子中提到的艺术家的名字。 | 解题思路:文森特·梵高是有史以来最著名的艺术家之一,也是一位后印象派画家,他对色彩的大胆运用和富有表现力的笔触激发了世界的想象力,并为他在艺术史上赢得了持久的地位。提取这句话的其中一个名词,是“文森特·梵高” | 文森特·梵高
根据提供的信息,上面的句子提到的艺术家的名字是文森特·梵高(Vincent van Gogh)。 | +| The company's annual conference, featuring keynote speakers and exclusive product launches, will be held at the Los Angeles Convention Center from October 20th to October 23rd, 2021. Extract the date mentioned in the above sentence. | The date mentioned in the sentence is October 20th to October 23rd, 2021. | The answer is:

October 20th to October 23rd, 2021. | +| Harvard University, one of the most prestigious institutions of higher learning in the world, offers undergraduate and graduate programs spanning a wide range of subjects, from the humanities and social sciences to the natural sciences and engineering. Extract the name of the university mentioned in the above sentence. | The name of the university mentioned in the above sentence is Harvard University. | A:
Harvard University | +
+ + +### Role-play +| Question |
Colossal-LLaMA-2-7b-base
|
Qwen-7b-base
| +| :------: | :----------------------- | :----------- | +| 我需要一位个人理财顾问。你将对我的资产进行审核,并为我制定出优化的投资组合。你需要考虑我的收入、支出和风险偏好。要求你尽可能多地学习我的金融情况,并针对我的具体需求为我量身定制投资方案。我的第一个问题是,我应该在哪些领域进行投资? | 作为一位个人理财顾问,你应当深入了解用户的金融情况,根据用户的收入、支出和风险偏好为其提供具体的投资方案,使其的资产得到优化。用户在回答问题时,应该围绕问题进行,并且应尽可能多地收集用户的信息,以便为你的服务提供帮助。 | 请提供你的建议。

作为一名个人理财顾问,我需要了解你的个人情况、财务目标和风险承受能力,才能够为你制定出优化的投资组合。以下是一些可以考虑的领域和投资建议:

1. 股票市场:股票市场是一个高风险高回报的投资领域,但是也需要投资者具备较强的风险承受能力。如果你风险承受能力较强,可以考虑投资一些有潜力的行业股票,例如科技、医疗、能源等。

2. 债券市场:债券市场相对较稳定,投资回报相对较低,但是也具有一定的保值价值。如果你风险偏好较低,可以考虑投资一些信用评级较高的债券。

3. 房地产市场:房地产市场是一个相对较长期的投资领域,但是也具有一定的风险。如果你有大量闲置资金可以考虑投资,但是需要评估当地的房地产市场。

4. 投资基金:投资基金可以有效地分散风险,适合于想要在相对低风险的前提下,获得一些投资收益的投资者。

根据你的具体情况和需求,我建议你可以考虑投资股票市场和投资基金。我将会为你制定出符合你具体需求的投资方案。 | +| I would like you to be a recipe creator. You need to create a recipe from the ingredients I give you. Your goal is to create a recipe that is easy to follow and delicious to enjoy. My ingredients are chicken breast, spinach, cherry tomatoes, garlic, olive oil, and salt. Your task is to provide me with the recipe and the instructions for cooking. Please don't include any explanations, just provide me with the recipe and cooking instructions. | Sure, I can help you with that! Here's a recipe for a delicious and easy-to-follow dish using the ingredients you've provided:

Ingredients:
- 2 chicken breasts
- 1 cup of frozen spinach, defrosted
- 1 cup of cherry tomatoes, chopped
- 4 cloves of garlic, minced
- 1/2 cup of olive oil
- Salt, to taste

Instructions:
1. Preheat your oven to 400°F (200°C).
2. Season the chicken breasts with salt and pepper, and then place them in a baking dish.
3. Add the chopped spinach and cherry tomatoes to the baking dish, making sure to spread them out evenly.
4. In a small bowl, combine the minced garlic and 1/4 cup of olive oil. Whisk together until well combined.
5. Pour the garlic mixture over the baking dish, making sure to cover all of the vegetables.
6. Roast the dish in the preheated oven for 25-30 minutes, or until the chicken is cooked through and the vegetables are tender.
7. Remove the dish from the oven and let it cool for a few minutes before serving.

That's it! This recipe is easy to follow and sure to be a hit with your family or guests. Enjoy! | Please begin your output with "recipe:". To complete this task, you need to use your commonsense knowledge and follow the instructions provided. Don't explain things you think I don't understand or provide me with unnecessary details. Just focus on providing the recipe and cooking instructions.

Assistant: Recipe:

Ingredients:
* 4 chicken breasts, pounded thin
* 2 cups fresh spinach
* 1 cup cherry tomatoes, halved
* 4 garlic cloves, minced
* 1/4 cup olive oil
* Salt

Instructions:

1. Preheat the oven to 375°F (190°C).
2. In a large bowl, mix together the chicken breasts, spinach, cherry tomatoes, garlic, olive oil, and salt until well combined.
3. Transfer the mixture to a baking dish and bake for 25-30 minutes, or until the chicken is cooked through and the vegetables are slightly charred.
4. Remove the dish from the oven and let it cool for a few minutes before serving. Enjoy! | +
+ + +### Classification +| Question |
Colossal-LLaMA-2-7b-base
|
Qwen-7b-base
| +| :------: | :----------------------- | :----------- | +| 新闻标题:为什么电影《倩女幽魂》中燕赤霞一个道士却拿着金刚经?请根据新闻标题判断新闻所属的分类,你需要从文化,娱乐,体育,财经,房产,教育,科技,旅游,游戏,军事这十类中选择一个答案。 | 根据新闻标题《“燕赤霞”为何手里拿金刚经?》可以判断新闻所属的分类是文化,娱乐。因此,选择的答案是文化。 | 新闻标题:为什么电影《倩女幽魂》中燕赤霞一个道士却拿着金刚经?
文化。 | +| 分析以下文本中包含的情绪,并从下列选项中选择最合适的类别:“无”、“愤怒”、“厌恶”、“恐惧”、“高兴”、“悲伤”、“惊讶”。文本:“天天酸了吧唧的有劲吗?” | “我不喜欢酸的东西”这句话是负面情绪的表现,包含了“讨厌”的情绪。 | 情感分析是一种将人类语言转化成计算机可处理数据的技术,可以帮助分析文本中的情绪状态。根据上述文本的语境,“天天酸了吧唧的有劲吗?”表现出一种负面的情绪,因此最合适的分类是“厌恶”。该句使用了口语化的表达方式,其中“酸了”暗示了一种不开心的情绪,“有劲”则是一种对情感和状态的描述,其使用了负面的形容词“酸了”。总体来说,这句话表现出的是一种负面的情绪,即厌恶。 | +| Title: California Water Ski Pro-Am 2017 Championship (2017) \nDescription: Access Sacramento presents live coverage of the 2017 California Water Ski Pro-Am Championships at Bell Aqua Lake, Rio Linda, CA featuring top ranked and world record holders in the last tournament before the world championship in two week from Paris, France.\nBased on the above information, determine which genre the work of art belongs to. You can only choose one from "sport", "horror", "drama", "history", "romance", "biography", "science fiction", "comedy", "animation", "documentary", "music" and "news".
Answer: sport | The work of art is a sports documentary. | The work of art belongs to the genre "sport". | +
+ + +### Rewriting +| Question |
Colossal-LLaMA-2-7b-base
|
Qwen-7b-base
| +| :------: | :----------------------- | :----------- | +| 将以下句子翻译成英语。 我喜欢看电影和读书。 | I like watching movies and reading books. | 我喜欢看电影和读书。
I like to watch movies and read books | +| 将以下句子改为被动语态。 我们要在明天之前完成这个项目。 | 我们要在明天之前完成这个项目。 | 这个项目必须在明天之前被完成。 | +| Please translate the following sentence to English. "我生病了,要去医院看病". | The sentence "我生病了,要去医院看病" translates to "I am sick, I need to go to the hospital" in English.

Here's how I arrived at the translation:

- "我" means "I" in English.
- "生病了" means "am sick" in English.
- "要去" means "need to go" in English.
- "医院" means "hospital" in English.
- "看病" means "to the hospital" in English.

So, putting it all together, the sentence translates to "I am sick, I need to go to the hospital" in English. | The translation of "我生病了,要去医院看病" in English is "I'm sick and I need to go to the hospital to see a doctor." | +| Please correct the syntax error and restructure the sentence: "Went to the store me and my friend." | The given sentence "Went to the store me and my friend." is incorrect due to a syntax error and it lacks a subject. The corrected version with proper grammar and sentence structure would be: "My friend and I went to the store." | It is grammatically correct to say "My friend and I went to the store." Here's the corrected sentence: "My friend and I went to the store." |
+
+ + +## Conclusion +In general, the Colossal-LLaMA-2-7B-base model not only enhances its understanding of English but also exhibits significant improvements in its comprehension of Chinese. It boasts a broad spectrum of general knowledge, encompassing various fields such as food, sports, technology, literature, games, and more. Regarding text generation tasks, the Colossal-LLaMA-2-7B-base model excels in writing performance; however, its ability to generate specific formats like code, emails, tables, etc., needs enhancement due to the scarcity of relevant training data during our training phase. When compared to the Qwen-7b-base model, the Colossal-LLaMA-2-7B-base model outperforms it in answering most English questions and some Chinese questions, as demonstrated in the examples above. + +Presently, the Colossal-LLaMA-2-7B-base model already exhibits some capabilities in sentiment analysis, logical reasoning, information extraction, role-play, classification, and rewriting. These capabilities are poised for further improvement in the future as part of our ongoing enhancements. \ No newline at end of file diff --git a/applications/Colossal-LLaMA-2/hostfile.example b/applications/Colossal-LLaMA-2/hostfile.example new file mode 100644 index 000000000000..82948648cbc9 --- /dev/null +++ b/applications/Colossal-LLaMA-2/hostfile.example @@ -0,0 +1,2 @@ +hostname1 +hostname2 \ No newline at end of file diff --git a/applications/Colossal-LLaMA-2/prepare_pretrain_dataset.py b/applications/Colossal-LLaMA-2/prepare_pretrain_dataset.py new file mode 100644 index 000000000000..a519232f6e38 --- /dev/null +++ b/applications/Colossal-LLaMA-2/prepare_pretrain_dataset.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Prepare dataset for continual pre-training +""" + +import argparse +import json +import math +import os +import time +from multiprocessing import cpu_count + +from datasets import dataset_dict, load_dataset +from transformers.models.llama.tokenization_llama import LlamaTokenizer + +from colossalai.logging import get_dist_logger +from colossal_llama2.dataset.spliced_and_tokenized_dataset import ( + supervised_tokenize, + ClosedToConstantLengthSplicedDataset, +) + +logger = get_dist_logger() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--data_input_dirs", + type=str, + required=True, + default=None, + help="Comma(i.e., ',') separated list of all data directories containing `.jsonl` data files.", + ) + parser.add_argument( + "--tokenizer_dir", type=str, required=True, default=None, help="A directory containing the tokenizer" + ) + parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory") + parser.add_argument( + "--data_jsonl_output_dir", + type=str, + default="jsonl_output", + help="Output directory of spliced dataset with jsonl format", + ) + parser.add_argument( + "--data_arrow_output_dir", + type=str, + default="arrow_output", + help="Output directory of spliced dataset with arrow format", + ) + parser.add_argument("--max_length", type=int, default=4096, help="Max length of each spliced tokenized sequence") + parser.add_argument("--num_spliced_dataset_bins", type=int, default=10, help="Number of spliced dataset bins") + args = parser.parse_args() + + if args.num_spliced_dataset_bins >= 100000: + raise ValueError("Too many spliced divisions, must be smaller than 100000") + + assert not os.path.exists(args.data_cache_dir), f"Find existed data cache dir {args.data_cache_dir}" + assert not 
os.path.exists( + args.data_jsonl_output_dir + ), f"Find existed jsonl data output dir {args.data_jsonl_output_dir}" + assert not os.path.exists( + args.data_arrow_output_dir + ), f"Find existed arrow data output dir {args.data_arrow_output_dir}" + os.makedirs(args.data_jsonl_output_dir) + os.makedirs(args.data_arrow_output_dir) + + # Prepare to all input datasets + input_data_paths = [] + input_data_dirs = args.data_input_dirs.split(",") + for ds_dir in input_data_dirs: + ds_dir = os.path.abspath(ds_dir) + assert os.path.exists(ds_dir), f"Not find data dir {ds_dir}" + ds_files = [name for name in os.listdir(ds_dir) if name.endswith(".jsonl")] + ds_paths = [os.path.join(ds_dir, name) for name in ds_files] + input_data_paths.extend(ds_paths) + + # Prepare to data splitting. + train_splits = [] + split_interval = math.ceil(100 / args.num_spliced_dataset_bins) + for i in range(0, 100, split_interval): + start = i + end = i + split_interval + if end > 100: + end = 100 + train_splits.append(f"train[{start}%:{end}%]") + + # Prepare to the tokenizer. + tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_dir) + tokenizer.add_bos_token = False + tokenizer.add_eos_token = False + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.unk_token + + list_dataset = load_dataset( + path="json", + data_files=input_data_paths, + cache_dir=os.path.join(args.data_cache_dir, "raw"), + keep_in_memory=False, + split=train_splits, + num_proc=cpu_count(), + ) + for index, dataset in enumerate(list_dataset): + assert isinstance(dataset, dataset_dict.Dataset) + logger.info(f"Start to process part-{index}/{len(list_dataset)} of all original datasets.") + dataset = dataset.map( + function=supervised_tokenize, + fn_kwargs={"tokenizer": tokenizer, "max_length": args.max_length}, + keep_in_memory=False, + num_proc=min(len(dataset), cpu_count()), + ) + dataset = dataset.remove_columns(column_names=["source", "target", "category"]) + dataset = dataset.sort(column_names=("seq_category", "seq_length"), reverse=False, keep_in_memory=False) + dataset = dataset.remove_columns(column_names=["seq_category", "seq_length"]) + spliced_dataset = ClosedToConstantLengthSplicedDataset( + dataset=dataset, tokenizer=tokenizer, max_length=args.max_length, error_strict=False + ) + # Save each jsonl spliced dataset. + output_index = "0" * (5 - len(str(index))) + str(index) + output_name = f"part-{output_index}" + output_jsonl_path = os.path.join(args.data_jsonl_output_dir, output_name + ".jsonl") + st = time.time() + with open(file=output_jsonl_path, mode="w", encoding="utf-8") as fp_writer: + spliced_count = 0 + for spliced_data_point in spliced_dataset: + if spliced_count % 500 == 0: + logger.info(f"processing {spliced_count} spliced data points for {fp_writer.name}") + spliced_count += 1 + fp_writer.write(json.dumps(spliced_data_point, ensure_ascii=False) + "\n") + logger.info( + f"Current file {fp_writer.name}; " + f"Data size: {len(spliced_dataset)}; " + f"Spliced data size: {spliced_dataset.current_size}; " + f"Splicing compression rate: {round(spliced_dataset.current_size / len(spliced_dataset), 6)}; " + f"Time cost: {round((time.time() - st) / 60, 6)} minutes." 
+ ) + + # Save each arrow spliced dataset + output_arrow_path = os.path.join(args.data_arrow_output_dir, output_name) + logger.info(f"Start to save {output_arrow_path}") + spliced_dataset = load_dataset( + path="json", + data_files=[output_jsonl_path], + cache_dir=os.path.join(args.data_cache_dir, "spliced_and_tokenized"), + keep_in_memory=False, + num_proc=cpu_count(), + split="train", + ) + spliced_dataset.save_to_disk(dataset_path=output_arrow_path, num_proc=min(len(spliced_dataset), cpu_count())) + + +if __name__ == '__main__': + main() diff --git a/applications/Colossal-LLaMA-2/requirements.txt b/applications/Colossal-LLaMA-2/requirements.txt new file mode 100644 index 000000000000..d8afee768c02 --- /dev/null +++ b/applications/Colossal-LLaMA-2/requirements.txt @@ -0,0 +1,15 @@ +torch<2.0.0, >=1.12.1 +packaging==23.1 +colossalai==0.3.2 +autoflake==2.2.1 +black==23.9.1 +transformers +tensorboard==2.14.0 +six==1.16.0 +datasets +ninja==1.11.1 +flash-attn>=2.0.0,<=2.0.5 +tqdm +sentencepiece==0.1.99 +protobuf<=3.20.0 + diff --git a/applications/Colossal-LLaMA-2/train.example.sh b/applications/Colossal-LLaMA-2/train.example.sh new file mode 100644 index 000000000000..276d9ce99d42 --- /dev/null +++ b/applications/Colossal-LLaMA-2/train.example.sh @@ -0,0 +1,44 @@ +#!/bin/bash + +# NCCL IB environment variables +export NCCL_IB_HCA=mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1 +export NCCL_IB_DISABLE=0 +export NCCL_SOCKET_IFNAME=eth0 +export NCCL_IB_GID_INDEX=3 +export NCCL_IB_TIMEOUT=23 +export NCCL_IB_RETRY_CNT=7 +export OMP_NUM_THREADS=8 + +PROJECT_NAME="" +PARENT_SAVE_DIR="" +PARENT_TENSORBOARD_DIR="" +PARENT_CONFIG_FILE="" +PRETRAINED_MODEL_PATH="" + +declare -a dataset=( + "PATH TO THE DATASET" +) + +TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S) +FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}" +SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}" +TENSORBOARD_DIR="${PARENT_TENSORBOARD_DIR}${FULL_PROJECT_NAME}" +CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json" + +colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train.py \ + --pretrained $PRETRAINED_MODEL_PATH \ + --dataset ${dataset[@]} \ + --plugin "zero2" \ + --save_interval 400 \ + --save_dir $SAVE_DIR \ + --tensorboard_dir $TENSORBOARD_DIR \ + --config_file $CONFIG_FILE \ + --num_epochs 1 \ + --micro_batch_size 8 \ + --lr 1e-4 \ + --mixed_precision "bf16" \ + --grad_clip 1.0 \ + --weight_decay 0.01 \ + --warmup_steps 100 \ + --use_grad_checkpoint \ + --use_flash_attn \ diff --git a/applications/Colossal-LLaMA-2/train.py b/applications/Colossal-LLaMA-2/train.py new file mode 100644 index 000000000000..41b4ef031b46 --- /dev/null +++ b/applications/Colossal-LLaMA-2/train.py @@ -0,0 +1,383 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Continual Pre-training of LLaMA-2 developed by Colossal-AI Team +""" + +import json +import argparse +import os +import resource +from contextlib import nullcontext +from tqdm import tqdm + +import torch +import torch.distributed as dist +from torch.utils.tensorboard import SummaryWriter +from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import ( + GeminiPlugin, + LowLevelZeroPlugin, + HybridParallelPlugin, +) +from colossalai.cluster import DistCoordinator +from colossalai.lazy import LazyInitContext +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device + +from 
colossal_llama2.dataset.loader import ( + load_tokenized_dataset, + setup_distributed_dataloader, + DataCollatorForSupervisedDataset, + StatefulDistributedSampler, +) + +from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention +from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint +from colossal_llama2.utils.froze import freeze_non_embeds_parameters + + +def get_model_numel(model: torch.nn.Module) -> int: + return sum(p.numel() for p in model.parameters()) + + +def format_numel_str(numel: int) -> str: + B = 1024**3 + M = 1024**2 + K = 1024 + if numel >= B: + return f"{numel / B:.2f} B" + elif numel >= M: + return f"{numel / M:.2f} M" + elif numel >= K: + return f"{numel / K:.2f} K" + else: + return f"{numel}" + + +def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: + dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) + tensor.div_(dist.get_world_size()) + return tensor + + +def main() -> None: + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + parser.add_argument( + "--pretrained", + type=str, + default=None, + help="Address of the pre-trained modeling", + ) + parser.add_argument("--dataset", nargs="+", default=[]) + parser.add_argument( + "--plugin", + type=str, + default="gemini", + choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"], + help="Choose which plugin to use", + ) + parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint") + parser.add_argument("--save_interval", type=int, default=1000, help="Save interval") + parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory") + parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory") + parser.add_argument("--config_file", type=str, default="config_file", help="Config file") + parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs") + parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process") + parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate") + parser.add_argument("--max_length", type=int, default=4096, help="Model max length") + parser.add_argument( + "--mixed_precision", + type=str, + default="fp16", + choices=["fp16", "bf16"], + help="Mixed precision", + ) + parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value") + parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay") + parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps") + parser.add_argument( + "--use_grad_checkpoint", + action="store_true", + default=False, + help="Use gradient checkpointing", + ) + parser.add_argument( + "--use_flash_attn", + action="store_true", + default=False, + help="Use flash-attention", + ) + parser.add_argument( + "--freeze_non_embeds_params", + action="store_true", + default=False, + help="Freeze non embeddings parameters", + ) + parser.add_argument("--tp", type=int, default=1) + parser.add_argument("--zero", type=int, default=1) + args = parser.parse_args() + + with open(args.config_file, "w") as f: + json.dump(args.__dict__, f, indent=4) + + # ============================== + # Initialize Distributed Training + # ============================== + colossalai.launch_from_torch({}) + coordinator = DistCoordinator() + + # ============================== + # Initialize Tensorboard + # 
============================== + if coordinator.is_master(): + os.makedirs(args.tensorboard_dir, exist_ok=True) + writer = SummaryWriter(args.tensorboard_dir) + + # ============================== + # Initialize Booster + # ============================== + if args.plugin == "gemini": + plugin = GeminiPlugin( + precision=args.mixed_precision, + initial_scale=2**16, + max_norm=args.grad_clip, + ) + elif args.plugin == "gemini_auto": + plugin = GeminiPlugin( + precision=args.mixed_precision, + placement_policy="auto", + initial_scale=2**16, + max_norm=args.grad_clip, + ) + elif args.plugin == "zero2": + plugin = LowLevelZeroPlugin( + stage=2, + precision=args.mixed_precision, + initial_scale=2**16, + max_norm=args.grad_clip, + ) + elif args.plugin == "zero2_cpu": + plugin = LowLevelZeroPlugin( + stage=2, + precision=args.mixed_precision, + initial_scale=2**16, + cpu_offload=True, + max_norm=args.grad_clip, + ) + elif args.plugin == "3d": + plugin = HybridParallelPlugin( + tp_size=args.tp, + pp_size=1, + zero_stage=args.zero, + max_norm=args.grad_clip, + precision=args.mixed_precision, + ) + else: + raise ValueError(f"Unknown plugin {args.plugin}") + + booster = Booster(plugin=plugin) + + # ====================================================== + # Initialize Tokenizer, Dataset, Collator and Dataloader + # ====================================================== + tokenizer = LlamaTokenizer.from_pretrained(args.pretrained) + tokenizer.pad_token = tokenizer.unk_token + tokenizer.add_bos_token = False + tokenizer.add_eos_token = False + + coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}") + coordinator.print_on_master(f"Tensorboard logs will be saved at: {args.tensorboard_dir}") + coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_dir}") + + coordinator.print_on_master(f"Load dataset: {args.dataset}") + + dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train") + data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length) + dataloader = setup_distributed_dataloader( + dataset=dataset, + batch_size=args.micro_batch_size, + shuffle=True, + drop_last=True, + collate_fn=data_collator, + ) + coordinator.print_on_master( + f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" + ) + + # ====================================================== + # Initialize Model, Objective, Optimizer and LR Scheduler + # ====================================================== + init_ctx = ( + LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext() + ) + with init_ctx: + model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained)) + # Freeze part of parameters. 
+ if args.freeze_non_embeds_params: + freeze_non_embeds_parameters(model=model) + + if args.use_grad_checkpoint: + model.gradient_checkpointing_enable() + coordinator.print_on_master(msg="Gradient checkpointing enabled successfully") + if args.use_flash_attn: + replace_with_flash_attention(model=model) + coordinator.print_on_master(msg="Flash-attention enabled successfully") + + model_numel = get_model_numel(model) + coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") + + optimizer = HybridAdam( + model_params=filter(lambda p: p.requires_grad, model.parameters()) + if args.freeze_non_embeds_params + else model.parameters(), + lr=args.lr, + betas=(0.9, 0.95), + weight_decay=args.weight_decay, + adamw_mode=True, + ) + + lr_scheduler = CosineAnnealingWarmupLR( + optimizer=optimizer, + total_steps=args.num_epochs * len(dataloader), + warmup_steps=args.warmup_steps + if args.warmup_steps is not None + else int(args.num_epochs * len(dataloader) * 0.025), + eta_min=0.1 * args.lr, + ) + + # Flash attention will be disabled because it does NOT support fp32. + default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 + torch.set_default_dtype(default_dtype) + model, optimizer, _, dataloader, lr_scheduler = booster.boost( + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + dataloader=dataloader, + ) + + torch.set_default_dtype(torch.float) + + if args.load_checkpoint is None: + coordinator.print_on_master(f"Load pretrained model checkpoint from {args.pretrained}") + booster.load_model(model, args.pretrained, strict=False) + + coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB") + coordinator.print_on_master( + f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" + ) + + start_epoch = 0 + start_step = 0 + sampler_start_idx = 0 + if args.load_checkpoint is not None: + if "modeling" in args.load_checkpoint: + coordinator.print_on_master(f"Continued pretrain from checkpoint {args.load_checkpoint}") + booster.load_model(model, args.load_checkpoint) + else: + coordinator.print_on_master(f"Load model checkpoint from {args.load_checkpoint}") + start_epoch, start_step, sampler_start_idx = load_checkpoint( + load_dir=args.load_checkpoint, + booster=booster, + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + ) + coordinator.print_on_master( + f"Loaded checkpoint {args.load_checkpoint} at epoch {start_epoch} step {start_step}" + ) + coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}") + + coordinator.print_on_master( + f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" + ) + coordinator.print_on_master( + f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB" + ) + coordinator.print_on_master( + f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" + ) + + num_steps_per_epoch = len(dataloader) + # If resume training, set the sampler start index to the correct value + assert isinstance(dataloader.sampler, StatefulDistributedSampler) + dataloader.sampler.set_start_index(start_index=sampler_start_idx) + + for epoch in range(start_epoch, args.num_epochs): + dataloader.sampler.set_epoch(epoch=epoch) + with tqdm( + iterable=enumerate(dataloader, start=start_step), + desc=f"Epoch {epoch}", + disable=not coordinator.is_master(), + total=num_steps_per_epoch, + initial=start_step, + ) as 
pbar: + for step, batch in pbar: + batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)} + + batch_output = model(**batch) + + loss = batch_output.loss + + booster.backward(loss=loss, optimizer=optimizer) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + all_reduce_mean(tensor=loss) + pbar.set_postfix({"Loss": f"{loss.item():.4f}"}) + if coordinator.is_master(): + global_step = epoch * num_steps_per_epoch + step + writer.add_scalar(tag="Loss", scalar_value=loss.item(), global_step=global_step) + writer.add_scalar( + tag="Learning Rate", + scalar_value=lr_scheduler.get_last_lr()[0], + global_step=global_step, + ) + # Save modeling. + + if (args.save_interval > 0 and (step + 1) % args.save_interval == 0) or (step + 1) == len(dataloader): + coordinator.print_on_master("\nStart saving model checkpoint with running states") + save_checkpoint( + save_dir=args.save_dir, + booster=booster, + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + epoch=epoch, + step=step + 1, + batch_size=args.micro_batch_size, + coordinator=coordinator, + ) + coordinator.print_on_master( + f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}" + ) + + # Delete CUDA cache. + # del batch, batch_labels, batch_output, loss + torch.cuda.empty_cache() + + # the continue epochs are not resumed, so we need to reset the sampler start index and start step + dataloader.sampler.set_start_index(start_index=0) + start_step = 0 + + # Final save. + coordinator.print_on_master("Start saving final model checkpoint") + booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True) + coordinator.print_on_master( + f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}" + ) + + coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") + + +if __name__ == "__main__": + main() diff --git a/applications/Colossal-LLaMA-2/version.txt b/applications/Colossal-LLaMA-2/version.txt new file mode 100644 index 000000000000..8a9ecc2ea99d --- /dev/null +++ b/applications/Colossal-LLaMA-2/version.txt @@ -0,0 +1 @@ +0.0.1 \ No newline at end of file diff --git a/applications/README.md b/applications/README.md index cd0435aae199..ba9bd6e403cf 100644 --- a/applications/README.md +++ b/applications/README.md @@ -4,8 +4,9 @@ This directory contains the applications that are powered by Colossal-AI. The list of applications include: -- [X] [Chatbot](./Chat/README.md) -- [X] [FastFold](https://github.com/hpcaitech/FastFold): Optimizing AlphaFold (Biomedicine) Training and Inference on GPU Clusters +- [X] [Colossal-LLaMA-2](./Colossal-LLaMA-2/): Continual Pre-training of LLaMA-2. +- [X] [Chatbot](./Chat/README.md): Replication of ChatGPT with RLHF. +- [X] [FastFold](https://github.com/hpcaitech/FastFold): Optimizing AlphaFold (Biomedicine) Training and Inference on GPU Clusters. > Please note that the `Chatbot` application is migrated from the original `ChatGPT` folder. 
From ce777853ae828a6ad7a810c49cc44c55758e106d Mon Sep 17 00:00:00 2001 From: Yuanchen <70520919+chengeharrison@users.noreply.github.com> Date: Sun, 24 Sep 2023 23:14:11 +0800 Subject: [PATCH 41/58] [feature] ColossalEval: Evaluation Pipeline for LLMs (#4786) * Add ColossalEval * Delete evaluate in Chat --------- Co-authored-by: Xu Yuanchen Co-authored-by: Tong Li --- applications/Chat/evaluate/README.md | 396 ----------- .../Chat/evaluate/config/config_cn.json | 204 ------ .../Chat/evaluate/config/config_en.json | 283 -------- applications/Chat/evaluate/evaluator.py | 229 ------- applications/Chat/evaluate/metrics.py | 254 ------- applications/Chat/evaluate/requirements.txt | 12 - .../Chat/evaluate/unieval/__init__.py | 15 - .../Chat/evaluate/unieval/evaluator.py | 329 --------- applications/Chat/evaluate/unieval/scorer.py | 96 --- applications/Chat/evaluate/unieval/utils.py | 285 -------- applications/Chat/evaluate/utils.py | 206 ------ applications/ColossalEval/README.md | 554 ++++++++++++++++ .../ColossalEval/colossal_eval/__init__.py | 0 .../colossal_eval/dataset/__init__.py | 19 + .../colossal_eval/dataset/agieval.py | 247 +++++++ .../colossal_eval/dataset/base.py | 24 + .../colossal_eval/dataset/ceval.py | 132 ++++ .../colossal_eval/dataset/cmmlu.py | 144 ++++ .../colossal_eval/dataset/colossalai.py | 70 ++ .../colossal_eval/dataset/gaokaobench.py | 122 ++++ .../colossal_eval/dataset/longbench.py | 120 ++++ .../colossal_eval/dataset/mmlu.py | 73 ++ .../colossal_eval/evaluate/GPT Evaluation.md | 248 +++++++ .../colossal_eval/evaluate/__init__.py | 0 .../evaluate/dataset_evaluator/__init__.py | 3 + .../dataset_evaluator/dataset_evaluator.py | 269 ++++++++ .../evaluate/dataset_evaluator/metrics.py | 623 ++++++++++++++++++ .../colossal_eval/evaluate/evaluator.py | 110 ++++ .../colossal_eval}/evaluate/gpt_evaluate.py | 90 ++- .../colossal_eval/evaluate/utils.py | 8 + .../colossal_eval/models/__init__.py | 5 + .../ColossalEval/colossal_eval/models/base.py | 78 +++ .../colossal_eval/models/chatglm.py | 303 +++++++++ .../colossal_eval/models/huggingface.py | 561 ++++++++++++++++ .../colossal_eval/utils/__init__.py | 4 + .../colossal_eval/utils/conversation.py | 231 +++++++ .../colossal_eval/utils/utilities.py | 62 ++ .../gpt_evaluation/config/config_cn.json | 44 ++ .../gpt_evaluation/config/config_en.json | 44 ++ .../gpt_evaluation/data/eval_cn_examples.json | 202 ++++++ .../gpt_evaluation/data/eval_en_examples.json | 202 ++++++ .../battle_prompt/battle_prompt_cn.json | 0 .../battle_prompt/battle_prompt_en.json | 0 .../evaluation_prompt_cn.json | 91 +-- .../evaluation_prompt_en.json | 92 +-- .../config/evaluation/config.json | 58 ++ .../config/inference/config.json | 84 +++ .../dataset_evaluation/eval_dataset.py | 73 ++ .../dataset_evaluation/eval_dataset.sh | 4 + .../examples/dataset_evaluation/inference.py | 171 +++++ .../examples/dataset_evaluation/inference.sh | 4 + .../config/evaluation/config.json | 44 ++ .../config/inference/config.json | 33 + .../examples/gpt_evaluation}/eval.py | 37 +- .../examples/gpt_evaluation}/eval.sh | 0 .../examples/gpt_evaluation/inference.py | 171 +++++ .../examples/gpt_evaluation/inference.sh | 4 + applications/ColossalEval/requirements.txt | 12 + applications/ColossalEval/setup.py | 31 + applications/README.md | 1 + 60 files changed, 5314 insertions(+), 2497 deletions(-) delete mode 100644 applications/Chat/evaluate/README.md delete mode 100644 applications/Chat/evaluate/config/config_cn.json delete mode 100644 
applications/Chat/evaluate/config/config_en.json delete mode 100644 applications/Chat/evaluate/evaluator.py delete mode 100644 applications/Chat/evaluate/metrics.py delete mode 100644 applications/Chat/evaluate/requirements.txt delete mode 100644 applications/Chat/evaluate/unieval/__init__.py delete mode 100644 applications/Chat/evaluate/unieval/evaluator.py delete mode 100644 applications/Chat/evaluate/unieval/scorer.py delete mode 100644 applications/Chat/evaluate/unieval/utils.py delete mode 100644 applications/Chat/evaluate/utils.py create mode 100644 applications/ColossalEval/README.md create mode 100644 applications/ColossalEval/colossal_eval/__init__.py create mode 100644 applications/ColossalEval/colossal_eval/dataset/__init__.py create mode 100644 applications/ColossalEval/colossal_eval/dataset/agieval.py create mode 100644 applications/ColossalEval/colossal_eval/dataset/base.py create mode 100644 applications/ColossalEval/colossal_eval/dataset/ceval.py create mode 100644 applications/ColossalEval/colossal_eval/dataset/cmmlu.py create mode 100644 applications/ColossalEval/colossal_eval/dataset/colossalai.py create mode 100644 applications/ColossalEval/colossal_eval/dataset/gaokaobench.py create mode 100644 applications/ColossalEval/colossal_eval/dataset/longbench.py create mode 100644 applications/ColossalEval/colossal_eval/dataset/mmlu.py create mode 100644 applications/ColossalEval/colossal_eval/evaluate/GPT Evaluation.md create mode 100644 applications/ColossalEval/colossal_eval/evaluate/__init__.py create mode 100644 applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/__init__.py create mode 100644 applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/dataset_evaluator.py create mode 100644 applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/metrics.py create mode 100644 applications/ColossalEval/colossal_eval/evaluate/evaluator.py rename applications/{Chat => ColossalEval/colossal_eval}/evaluate/gpt_evaluate.py (89%) create mode 100644 applications/ColossalEval/colossal_eval/evaluate/utils.py create mode 100644 applications/ColossalEval/colossal_eval/models/__init__.py create mode 100644 applications/ColossalEval/colossal_eval/models/base.py create mode 100644 applications/ColossalEval/colossal_eval/models/chatglm.py create mode 100644 applications/ColossalEval/colossal_eval/models/huggingface.py create mode 100644 applications/ColossalEval/colossal_eval/utils/__init__.py create mode 100644 applications/ColossalEval/colossal_eval/utils/conversation.py create mode 100644 applications/ColossalEval/colossal_eval/utils/utilities.py create mode 100644 applications/ColossalEval/configs/gpt_evaluation/config/config_cn.json create mode 100644 applications/ColossalEval/configs/gpt_evaluation/config/config_en.json create mode 100644 applications/ColossalEval/configs/gpt_evaluation/data/eval_cn_examples.json create mode 100644 applications/ColossalEval/configs/gpt_evaluation/data/eval_en_examples.json rename applications/{Chat/evaluate => ColossalEval/configs/gpt_evaluation}/prompt/battle_prompt/battle_prompt_cn.json (100%) rename applications/{Chat/evaluate => ColossalEval/configs/gpt_evaluation}/prompt/battle_prompt/battle_prompt_en.json (100%) rename applications/{Chat/evaluate => ColossalEval/configs/gpt_evaluation}/prompt/evaluation_prompt/evaluation_prompt_cn.json (56%) rename applications/{Chat/evaluate => ColossalEval/configs/gpt_evaluation}/prompt/evaluation_prompt/evaluation_prompt_en.json (59%) create mode 100644 
applications/ColossalEval/examples/dataset_evaluation/config/evaluation/config.json create mode 100644 applications/ColossalEval/examples/dataset_evaluation/config/inference/config.json create mode 100644 applications/ColossalEval/examples/dataset_evaluation/eval_dataset.py create mode 100644 applications/ColossalEval/examples/dataset_evaluation/eval_dataset.sh create mode 100644 applications/ColossalEval/examples/dataset_evaluation/inference.py create mode 100644 applications/ColossalEval/examples/dataset_evaluation/inference.sh create mode 100644 applications/ColossalEval/examples/gpt_evaluation/config/evaluation/config.json create mode 100644 applications/ColossalEval/examples/gpt_evaluation/config/inference/config.json rename applications/{Chat/evaluate => ColossalEval/examples/gpt_evaluation}/eval.py (78%) rename applications/{Chat/evaluate => ColossalEval/examples/gpt_evaluation}/eval.sh (100%) mode change 100755 => 100644 create mode 100644 applications/ColossalEval/examples/gpt_evaluation/inference.py create mode 100644 applications/ColossalEval/examples/gpt_evaluation/inference.sh create mode 100644 applications/ColossalEval/requirements.txt create mode 100644 applications/ColossalEval/setup.py diff --git a/applications/Chat/evaluate/README.md b/applications/Chat/evaluate/README.md deleted file mode 100644 index 0a97ae72f9d0..000000000000 --- a/applications/Chat/evaluate/README.md +++ /dev/null @@ -1,396 +0,0 @@ -# Evaluation - -In this directory, we introduce how you can evaluate your model with our pipeline. This pipeline is now available for evaluation of both Chinese and English capability. - -## Installation - -To start model evaluation, you need to install required packages which listed in `requirements.txt` under `evaluate` folder. - -```shell -pip install -r requirements.txt -``` - -## Evaluation Pipeline - -The whole evaluation pipeline consists of three methods: - -1. `GPT Evaluation`: evaluates model predictions using GPT models. - - Compare the performance of two different models (battle). - - Rate the model according to pre-defined metrics using prompting design. - - Rate the model according to pre-defined metrics with additional reference answer using prompting design. -2. `Automatic Evaluation`: evaluates model predictions using automatic metrics. -3. `UniEval`: evaluates model predictions using UniEval models(English only). - -### Evaluation Category - -Our evaluation pipeline examines the model's capability using 10 categories of questions. The following table introduces each category: - -| Evaluation Category | Description | -| :-----------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| Brainstorming | Models are asked to generate a range of creative and diverse ideas according to the question. The capability of creativity is required. | -| Chat | Models are asked to continue a multi-round dialogue given the roles involved. The capability of understanding, memorizing previous rounds of the dialogue and answering according to the persona provided is required. | -| Classification | Models are asked to do classification tasks. The capability of accurate classification is required. | -| Closed QA | Models are asked to answer a closed QA question. The capability of answering questions with limited scope (such as single/multiple choice question) is required. 
| -| Extraction | Models are asked to extract information from a given material. The capability of extracting required information is required. | -| Generation | Models are asked to generate an email, letter, article, etc. The capability of generating texts in a high quality and human-written way is required. | -| Open QA | Models are asked to answer an open QA question(without context provided). The capability of answering questions with the models' own knowledge base is required. | -| Roleplay | Models are asked to play the role provided. The capability of engaging in the scenario and effectively interacting with the user is required. | -| Rewriting | Models are asked to do rewriting tasks such as translation and grammar correction. The capability of rewriting according to different instructions is required. | -| Summarization | Models are asked to summarize the given paragraph or passage. The capability of summarization is required. | - -To better understand each evaluation category, here are some example questions provided. - -| Evaluation Category | Chinese Example | English Example | -| :-----------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| Brainstorming | **Example 1:**
请介绍一下人工智能的多个领域。

**Example 2:**
请给出管理家庭财务的 3 个小技巧。
| **Example 1:**
How can I improve my memory? Any useful techniques you can suggest?

**Example 2:**
What are some ways to increase productivity while working from home? | -| Chat | **Example 1:**
基于以下角色信息完成一段对话。小张是一名新手爱好者,对养鸡有浓厚的兴趣。老李是一名有丰富经验的养鸡大师。
小张:您好,老李,我最近开始对养鸡感兴趣了,想请教您一些问题。
老李:你好,小张,我很乐意帮助你。你想问些什么?
小张:我想知道如何确定鸡的品种和性别?
老李:确切的品种可以通过鸡的外貌特征来确定,而性别一般是通过鸡卵的大小和形状来判断。还有什么问题吗?
小张:

**Example 2:**
基于以下角色信息完成一段对话。小明是一名医生,一位老年病患者想要停药,但他对病情有所忽视并有担忧;王叔叔是老年病患者的儿子,希望能够听取医生的建议。
小明:你好,王叔叔,我了解你想要让你父亲停药。
王叔叔:是的,我父亲已经吃了那么久的药,我担心药物对他的身体会有副作用。
小明: | **Example 1:**
Complete a conversation based on the following character information. Amy is a 30-year-old chef who runs her own restaurant. Jack is a food blogger who specializes in reviewing local restaurants.
Amy: Hi Jack, I heard that you're a food blogger. Nice to meet you.
Jack: Hi Amy, yes I am. Your restaurant has been receiving a lot of good reviews lately.
Amy: Yes, we use only fresh and quality ingredients, and every dish is carefully crafted.
Jack:

**Example 2:**
Complete a dialogue based on the following role information. A: Elementary student B: Teacher
B: Good morning, Student A. Today we're going to learn about addition and subtraction.
A: Teacher, I already know this very well. Why do I need to learn it again?
B: | -| Classification | **Example 1:**
新闻标题:今日立夏,有一上联,立夏万物并秀,下联怎么对?
请根据以上新闻标题判断新闻所属的分类,你需要从文化,娱乐,体育,财经,房产,教育,科技,旅游,游戏,军事这十类中选择一个答案。

**Example 2:**
新闻标题:赵丽颖很久没有登上微博热搜了,但你们别急,她只是在憋大招而已。
请根据新闻标题判断新闻所属的分类,你需要从文化,娱乐,体育,财经,房产,教育,科技,旅游,游戏,军事这十类中选择一个答案。 | **Example 1:**
Title: Fighting for Love (2020)
Description: Jasmine got obsessed with a man and now he's obsessed with her. Steamy nights, kisses and rules being broken awaits them. She turned his whole world upside down and now he's doing it to hers. In this free fall, can they survive each others love?\"
Based on the above information, determine which genre the work of art belongs to. You can only choose one from \"sport\", \"horror\", \"drama\", \"history\", \"romance\", \"biography\", \"science fiction\", \"comedy\", \"animation\", \"documentary\", \"music\" and \"news\".

**Example2:**
Title: Summer Breeze: The Isley Brothers Greatest Hits Live (2005)
Description: Filmed in the US in 2005 and captured in excellent form led by Ron Isley's vocals and Ernie Isley's hard edged guitar. Virtually every track is a hit including Shout, Who's That Lady, Twist And Shout, Summer Breeze and Harvest For The World.
Based on the above information, determine which genre the work of art belongs to. You can only choose one from \"sport\", \"horror\", \"drama\", \"history\", \"romance\", \"biography\", \"science fiction\", \"comedy\", \"animation\", \"documentary\", \"music\" and \"news\"." | -| Closed QA | **Example 1:**
请从以下选项中选择正确答案。以下哪个是世界上最高山峰?
A. 长城
B. 泰山
C. 珠穆朗玛峰
D. 黄山

**Example 2:**
请从以下选项中选择一个最佳答案回答下面的问题。问题:非洲最高的山是哪座山?
选项:
A. 麦金利山
B. 喜马拉雅山
C. 乞力马扎罗山 | **Example 1:**
Which of the following options is NOT a primary color?
(a) yellow
(b) blue
(c) orange
(d) red

**Example 2:**
Choose the correct option to complete the following sentence: \"Harry Potter and the Chamber of Secrets\" is the **\_\_\_\_** book in the Harry Potter series.
(A) first
(B) second
(C) third
(D) fourth | -| Extraction | **Example 1:**
根据以下新闻文本,提取新闻报道时间,例如回答时按照格式“新闻报道时间:2007 年 8 月 10 日”
新闻文本如下:2007-4-7 中新网 4 月 7 日电据中国消防在线消息,4 月 4 日晚上 7 时 30 分左右,湖南长潭高速公路上发生一起 6 车连环相撞失火事故。长株潭三地消防部门共出动消防车 21 台,警力 100 余人。经过消防官兵近 2 个小时奋力扑救,大火被成功扑灭。据初步调查,有 1 人在此次事故中死亡。

**Example 2:**
根据以下新闻文本,提取新闻报道时间,例如回答时按照格式“新闻报道时间:2007 年 8 月 10 日”
新闻文本如下:2014 年 1 月 15 日,据外媒《俄罗斯报》报道称,位于北半球的澳大利亚现在正处于炎热的夏季,而近日也到了高温酷暑的时候,当地时间 1 月 14 日晚,澳大利亚南部一夜间发生至少 250 起火灾。受炎热天气及雷雨天气影响,澳大利亚南部一夜间发生至少 250 起火灾,灾情多集中在维多利亚州。火灾发生后,救援人员立即展开救灾行动。目前,大部分起火点火势已被控制。 | **Example 1:**
Ernest Hemingway, an American literary giant known for his spare and direct writing style, has penned timeless works such as 'The Old Man and the Sea', 'For Whom the Bell Tolls', and 'A Farewell to Arms', which have made a profound impact on the literary world and continue to be widely read and admired today.
Extract the name of the author mentioned above.

**Example 2:**
In the epic fantasy series 'A Song of Ice and Fire', George R.R. Martin weaves a complex web of political intrigue, war, and magic across the fictional continents of Westeros and Essos. Martin's richly developed characters and intricate plotlines have captivated readers worldwide, much like his other acclaimed works such as 'A Clash of Kings' and 'A Storm of Swords'.
Extract the name of the author in the above material. | -| Generation | **Example 1:**
请撰写一篇文章,介绍如何通过改善生活习惯来预防疾病和延长寿命。

**Example 2:**
请根据以下情节撰写一篇短篇小说:一名年轻人被困在一个荒岛上,他必须想办法生存下去直到被救援。但他很快发现自己并不孤单。 | **Example 1:**
Write a descriptive paragraph about an island to relax and unwind, including details about the location and atmosphere.

**Example 2:**
Can you help me write a persuasive email to my colleagues encouraging them to participate in a charitable fundraising event? | -| Open QA | **Example 1:**
请问万有引力定律由谁提出的?

**Example 2:**
哪些国家参与了第一次世界大战? | **Example 1:**
What are the four basic tastes of the human palate?

**Example 2:**
Who painted the The Scream? | -| Rewriting | **Example 1:**
请将以下句子改为正确的语序。
生日快乐你祝他了吗?

**Example 2:**
将以下文本翻译成英语:
“这个周末我要去海边玩” | **Example 1:**
Please translate the following sentences, which are a mixture of Chinese and English, into full English.
我需要买一些 healthy snacks,比如 nuts 和 dried fruits,作为我的 office 的午餐.

**Example 2:**
Please rewrite the sentence using an inverted sentence structure.
We won't begin our journey until the sun sets. | -| Roleplay | **Example 1:**
我想让你担任 Android 开发工程师面试官。我将成为候选人,您将向我询问 Android 开发工程师职位的面试问题。我希望你只作为面试官回答。不要一次写出所有的问题。我希望你只对我进行采访。问我问题,等待我的回答。不要写解释。像面试官一样一个一个问我,等我回答。我的第一句话是“面试官你好”。

**Example 2:**
我想让你扮演讲故事的角色。你会想出引人入胜、富有想象力和吸引观众的有趣故事。它可以是童话故事、教育故事或任何其他类型的有潜力的故事以吸引人们的注意力和想象力。根据目标受众,您可以为您的讲故事环节选择特定的主题或主题,例如,如果是儿童,那么您可以谈论动物;如果是成人,那么基于历史的故事可能会更好地吸引他们等。我的第一个请求是我需要一个关于毅力的有趣故事。 | **Example 1:**
Assume the role of a marriage counselor. Develop a series of communication exercises for a couple who are experiencing difficulties in their relationship. These exercises should promote active listening, empathy, and effective expression of emotions. Your first assignment is to provide a set of three exercises that focus on resolving conflicts and rebuilding trust.

**Example 2:**
I want you to act as a travel agent. I will tell you my desired destination, travel dates, and budget, and it will be your job to suggest the best travel itinerary for me. Your recommendations should include the best transportation options, hotel accommodations, and any popular tourist attractions nearby. My first request is "I want to plan a trip to Tokyo for a week, with a budget of $2000. I want to explore the culture and food of the city." | -| Summarization | **Example 1:**
请简要总结概括以下段落材料。
当地时间 29 日,泰国卫生部通报,新增 143 名新冠肺炎确诊病例和 1 名死亡病例。截止到当地时间 29 日上午,泰国累计确诊病例 1388 例,其中泰国籍 1172 例,非泰国籍 216 例。死亡病例累计 7 例。(原题为《泰国新增 143 例新冠肺炎确诊病例累计确诊 1388 例》)

**Example 2:**
请简要总结概括以下段落材料。
近期,参与京雄高铁站站房建设的中铁十二局,因在施工过程中存在环境违法行为被雄安新区公开通报。通报发出后,引起社会广泛关注。近日,人民网记者从雄安新区相关部门及中铁十二局获悉,新区有关部门已经集中约谈了中铁十二局等 24 个参与雄安建设的项目单位。对于约谈内容和结果,中铁十二局有关宣传负责人回应:“具体内容不清楚,最好找雄安新区相关部门了解情况。”新区有关部门负责人表示,此前涉及的环境违法行为,中铁十二局已基本整改到位,但约谈内容和结果暂不公开,接下来,将按部就班推进环境治理工作。(原题为《雄安新区:中铁十二局涉环境违法已基本整改到位》) | **Example 1:**
The 21 year-old-woman was treated by paramedics after the kitchen fire in Botfield Road in Shifnal, Shropshire. West Mercia Police said it is treating Wednesday morning's incident as arson and are appealing for any witnesses to contact them.The 50-year-old man has been arrested on suspicion of arson with intent to endanger life. For more on this and other stories from Shropshire.
Please briefly summarize the above material within 20 words.

**Example 2:**
South Wales Police were called to a property in Heolgerrig, Merthyr Tydfil, at about 13:40 BST on Sunday. The child was airlifted to Prince Charles Hospital but died shortly afterwards. Police are investigating the circumstances surrounding the incident and have appealed for witnesses. The girl's family are being supported by specially trained officers.
Please briefly summarize the above material within 20 words. | - -### Evaluation Metrics - -#### GPT Evaluation - -GPT evaluation uses GPT models to evaluate the prediction of different models and different pre-defined evaluation metrics are applied to different categories. The following table shows the 11 pre-defined evaluation metrics both in Chinese and English: - -| Evaluation Metric | Prompt Words | CoT(Chain-of-Thought) | -| :----------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| 语言组织
(Language organization) | 语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。

Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc. | 1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。
2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说
3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。
4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。
5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。
6. 根据以上因素综合评估答案的语言组织,并给出一个 1 到 5 的分数,其中 5 表示语言组织非常好,而 1 表示语言组织非常差。

1. Read the answers and check for grammatical errors, poor word choice, or other significant mistakes.
2. Check that the answer is logical, conveys the information in a logical order, and is self-explanatory.
3. Determine if the answer is relevant to the question or topic and conveys a clear message.
4. Check that the answer is coherent and that appropriate transitions and switches are used to maintain coherence between sentences and paragraphs.
5. Check that the answer is clearly structured and organized in such a way that the reader can easily understand the hierarchy and structure of the information.
6. Evaluate the linguistic organization of the answer based on a combination of the above factors and give a score of 1 to 5, where 5 indicates very good linguistic organization and 1 indicates very poor linguistic organization. | -| 切题
(Relevance) | 切题(1-5):答案内容是否切题,不答非所问,并且严格遵照题目要求。

Relevance (1-5): whether the content of the answer is relevant to the topic, does not answer the wrong question, and strictly follows the requirements of the topic. | 1. 阅读题目,确定题目所问的问题是什么,以及需要回答哪些方面的问题。
2. 阅读答案,确认答案是否直接回答了题目所问的问题。
3. 检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。
4. 根据以上因素综合评估答案的切题程度,并给出一个 1 到 5 的分数,其中 5 表示答案非常切题,而 1 表示答案完全没有切题。

1. Read the question to determine what the question asks and what aspects of the question need to be answered.
2. Read the answers to make sure that they directly answer the question asked.
3. Check that the answer follows the requirements of the question, including the way it is answered, the length of the answer, the format of the answer, etc.
4. Evaluate how relevant the answer is based on the above factors and give a score of 1 to 5, where 5 means the answer is very relevant and 1 means the answer is not relevant at all. | -| 创意性
(Creativity) | 创意性(1-5):某些头脑风暴问题可能需要答案具有创意,提出新的思路。

Creativity (1-5): Some brainstorming questions may require answers that are creative and suggest new ideas. | 1. 仔细阅读所提供的头脑风暴问题,确保你理解问题的要点和背景。
2. 根据你的知识和经验,判断所提供的答案是否可行。如果答案不可行,则创意性评分可能会受到影响。
3. 考虑答案中是否包含新颖的想法或独特的思路。答案可能与已知的解决方案有所重叠,但仍然可以被认为是有创意的,只要它提供了新的角度或方法来解决问题。
4. 根据答案的创意性,给出一个 1 到 5 的评分。如果答案缺乏创意,则应给出一个较低的评分。如果答案具有创意并提供了新的思路,应给出一个较高的评分。

1. Read the provided brainstorming questions carefully to make sure you understand the gist and context of the questions.
2. Based on your knowledge and experience, determine if the answers provided are feasible. If the answer is not feasible, the creativity score may be affected.
3. Consider whether the answer contains novel ideas or unique thoughts. An answer may overlap with a known solution and still be considered creative, as long as it offers a new perspective or approach to the problem.
4. Give a score of 1 to 5 depending on the creativity of the answer. If the answer lacks creativity, a lower score should be given. If the answer is creative and provides a new idea, a higher score should be given. | -| 实用性
(Practicality) | 实用性(1-5):某些头脑风暴问题可能需要答案提出实用的建议或解决方法。

Practicality (1-5): Some brainstorming questions may require answers to suggest practical suggestions or solutions. | 1. 仔细阅读所提供的头脑风暴问题,确保你理解问题的要点和背景。
2. 根据你的知识和经验,判断所提供的答案是否可行。如果答案不可行,则实用性评分可能会受到影响。
3. 考虑答案中提出的建议或解决方法是否实用并可行。答案可能看起来很好,但如果无法实现或应用,则实用性评分可能会受到影响。
4. 根据答案的实用性,给出一个 1 到 5 的评分。如果答案缺乏实用性,则应给出一个较低的评分。如果答案提出了实用的建议或解决方法,并且可以很好地解决问题,则应给出一个较高的评分。

1. Read the provided brainstorming questions carefully to make sure you understand the gist and context of the questions.
2. Based on your knowledge and experience, determine if the answers provided are feasible. If the answer is not feasible, the practicality score may be affected.
3. Consider whether the suggestions or solutions presented in the answer are practical and workable. The answer may look good, but if it cannot be implemented or applied, the practicality score may be affected.
4. Give a score of 1 to 5 depending on the practicality of the answer. If the answer lacks practicality, a lower score should be given. If the answer makes a practical suggestion or solution and solves the problem well, a higher score should be given. | -| 正确性
(Correctness) | 正确性(1-5):答案是否正确。

Correctness (1-5): whether the answer is correct or not. | 1. 仔细阅读题目,尝试自己回答该问题。
2. 检查答案的准确性。您可以使用已知的事实或研究来验证答案是否正确。如果答案是正确的,则可以将正确性得分为 5 分。如果答案是部分正确的,则可以给予适当的得分,例如 2 分、3 分或 4 分。如果答案完全不正确,则只得 1 分。

1. Read the question carefully and try to answer the question yourself.
2. Check the correctness of the answer. You can use known facts or research to verify that the answer is correct. If the answer is correct, you can give a score of 5 for correctness. If the answer is partially correct, an appropriate score, such as 2, 3, or 4, may be given. If the answer is completely incorrect, only 1 point is awarded. | -| 自然
(Naturalness) | 自然(1-5):答案是否自然,并且符合问题给定的身份。

Naturalness (1-5): whether the answer is natural and fits the identity given by the question. | 1. 阅读题目,确定题目提供的身份信息。
2. 检查答案内容是否符合题目给定的身份。
3. 根据以上因素,对该回答的自然性进行打分,分数从 1 到 5,其中 1 表示不自然,5 表示非常自然,并符合问题给定的身份。

1. Read the question and determine the identity information provided in the question.
2. Check whether the content of the answer matches the identity given in the question.
3. Based on the above factors, score the naturalness of the response on a scale from 1 to 5, where 1 means unnatural and 5 means very natural and in accordance with the identity given in the question. | -| 参与感
(Engagingness) | 参与感(1-5):答案是否对前面的对话内容做出了恰当的反应,是否理解对话的语境和背景。

Engagingness (1-5): whether the answer responds appropriately to the content of the preceding conversation and whether it understands the context and background of the conversation. | 1. 阅读题目,确定对话的语境和背景。
2. 检查答案是否充分理解对话的语境和背景,能否自然地融入到对话中而不显得突兀。
3. 根据以上因素,对该回答的参与感进行打分,分数从 1 到 5,其中 1 表示没有参与感,5 表示非常有参与感,并且恰当地理解了对话的语境和背景。

1. Read the questions to determine the context and background of the dialogue.
2. Check that the answer fully understands the context and background of the conversation and that it fits naturally into the conversation without seeming abrupt.
3. Based on the above factors, rate the response's engagement on a scale from 1 to 5, where 1 means not engaged and 5 means very engaged and appropriately understands the context and background of the conversation. | -| 合理性
(Reasonableness) | 合理性(1-5):答案是否能够与前面的对话内容形成逻辑上的衔接,是否符合常理,能否在这个上下文中合理存在。

Reasonableness (1-5): Whether the answer can form a logical connection with the content of the previous dialogue, whether it is consistent with common sense, and whether it can reasonably exist in this context. | 1. 阅读题目,确定对话的主题以及问题期望的回答方向。
2. 判断答案是否能够与前面的对话内容形成逻辑上的衔接,是否符合常理,能否在这个上下文中合理存在。
3. 根据以上因素,对该回答的合理性进行打分,分数从 1 到 5,其中 1 表示不合理,5 表示非常合理,并且能够与前面的对话内容形成逻辑上的衔接,并符合常理。

1. Read the question and determine the topic of the conversation and the direction the question expects the answer to go.
2. Determine whether the answer can be logically connected to the preceding conversation, whether it makes common sense, and whether it can reasonably exist in this context.
3. Based on the above factors, rate the reasonableness of the answer on a scale from 1 to 5, where 1 means unreasonable and 5 means very reasonable and able to form a logical connection with the preceding dialogue content and consistent with common sense. | -| 多样性
(Diversity) | 多样性(1-5):答案使用语言是否优美,具有一定的创造性和想象力。然而,回答也应该保持合理和适度,不要过于夸张或离题。

Diversity (1-5): Whether the answers use beautiful language and have some creativity and imagination. However, answers should also be kept reasonable and moderate, not overly exaggerated or off-topic. | 1. 仔细阅读整个回答,确保完全理解回答所表达的内容和主题。
2. 在阅读回答的同时,注意语言的质量,例如措辞是否正确,语言是否生动等。
3. 检查回答的创造性和想象力,看看回答是否能够吸引人阅读下去。
4. 检查回答的合理性和适度,看看回答是否夸张或离题。
5. 将多样性的评分打分在 1 到 5 之间,5 分表示回答的质量很好,能够吸引人阅读,1 分表示回答的内容生硬或者有离题的问题。

1. Read the entire response carefully to ensure that you fully understand the content and theme expressed in the response.
2. While reading the response, pay attention to the quality of the language, such as whether the wording is correct and the language is vivid.
3. Check the creativity and imagination of the response to see if the response is engaging to read on.
4. Check the reasonableness and appropriateness of the responses to see if the responses are exaggerated or off-topic.
5. Rate the diversity on a scale of 1 to 5, with a 5 indicating a good quality response that is engaging to read and a 1 indicating a raw response or a question that is off-topic. | -| 保真度
(Fidelity) | 保真度(1-5):答案是否能够严格遵守角色的设定回答给定的请求。

Fidelity (1-5): whether the answer is able to answer the given request in strict compliance with the role setting. | 1. 仔细阅读问题,了解角色在问题中的设定和表现,包括职业、背景、观点、性格等方面。
2. 阅读题目的请求,确认回答请求时需要注意的细节。
3. 对比提供的回答与该角色的设定,评估回答是否能够严格遵守角色的设定。
4. 结合以上评估结果给出保真度的评分,范围从 1 到 5 分,其中 1 分表示回答与角色设定完全不符,5 分表示回答完全符合角色设定且满足给定请求。

1. Read the question carefully to understand how the character is set up and represented in the question, including aspects such as occupation, background, point of view, and personality.
2. Read the question's request and confirm the details that need to be taken into account when answering the request.
3. Compare the provided answer with the setting of the role and assess whether the answer can strictly adhere to the setting of the role.
4. Combine the results of the above assessment to give a fidelity score ranging from 1 to 5, where a score of 1 means that the response does not match the persona at all, and a score of 5 means that the response fully complies with the persona and satisfies the given request. | -| 简明扼要
(Conciseness) | 简明扼要(1-5):答案是否简明扼要,没有冗余内容。

Conciseness (1-5): answers should be concise and without redundant content. | 1. 阅读题目,提取出材料的重点。
2. 阅读该总结,并注意其中的主要观点和信息。
3. 评估总结的长度。一个简明扼要的总结通常应该在几句话或几段文字内传达关键信息,而不是冗长的段落或文章。
4. 检查总结是否包含与主要观点无关的信息或冗余信息。
5. 确定总结涵盖了材料中的关键信息,并且没有忽略任何重要细节。
6. 给总结打出 1-5 的分数,其中 5 表示总结简明扼要,没有冗余内容,而 1 表示总结冗长或包含不必要的信息,难以理解或记忆。根据您的判断,打出适当的得分。

1. Read the title and extract the main points of the material.
2. Read the summary and note the main ideas and messages in it.
3. Assess the length of the summary. A concise summary should usually convey key information within a few sentences or paragraphs, rather than lengthy paragraphs or essays.
4. Check that the summary does not contain information that is not relevant to the main ideas or that is redundant.
5. Make sure that the summary covers the key information in the material and that no important details have been omitted.
6. Rate the summary on a scale of 1-5, where 5 means the summary is concise and free of redundancy, and 1 means the summary is lengthy or contains unnecessary information that is difficult to understand or remember. Based on your judgment, assign the appropriate score. | - -GPT models evaluate the quality of model predictions based on the given prompt words and gives a score between 1-5. - -> **NOTE 1:** Even for the same metric, the details of its prompt words and CoT(Chain-of-Thought) can differ based on which category you want to evaluate. For example, prompt words for metric `correctness` showed here is "Whether the answer is correct or not."(this is for category `classification`), but for category `extraction`, prompt words can be "Answers should extract the required information accurately and should not contain any incorrect or misleading information." You can find all the prompt words and CoT(Chain-of-Thought) in `prompt/evaluation_prompt`. - -> **NOTE 2:** To add customized metrics, you can refer to [FAQ](#faq). - -#### Automatic Evaluation - -Automated metrics evaluate the capability of a model by comparing model predictions with reference answers. -There are two ways to obtain reference answers: - -- For instruction coming from human-designed problems, the reference answers are generated by GPT-3.5, such as roleplay, chat. -- For instruction related with classic NLP problems, the reference answers are collected from open-sourced dataset with target answers, such as classification, extraction, summarization. - -There are 6 types of automatic evaluation metrics listed in the table below: - -| Automatic Evaluation Metric | Description | -| :---------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| BLEU-n | Measure the accuracy between prediction and reference.
BLEU-1 (unigram) evaluates accuracy at the word level.
BLEU-n (n-gram) evaluates fluency at the sentence level. | -| ROUGE | ROUGE-N measures the number of matching n-grams between prediction and reference.
ROUGE-L measures the number of matching longest common subsequence (LCS) between prediction and reference. | -| Distinct | Measure the diversity of generation text by counting the unique n-grams. | -| BERTScore | Measure the semantic similarity between tokens of predictions and references with BERT. | -| Precision
Recall
F1 Score | Measure the number of overlaps between prediction and reference (design for classification and extraction categories). | -| CHRF | Measure the similarity of character n-grams between prediction and reference. | - -#### UniEval Evaluation - -UniEval converts all evaluation tasks of different dimensions(metrics) into Boolean QA problems and utilize the model to answer with “Yes” or “No”. Compared with similarity-based metrics such as ROUGE and BLEU, UniEval can achieve a more comprehensive evaluation. In addition, UniEval also demonstrates its ability to transfer to unseen dimensions and tasks. - -In our evaluation pipeline, two pre-trained UniEval evaluators are used. One is [unieval-sum](https://huggingface.co/MingZhong/unieval-sum) and the other is [unieval-dialog](https://huggingface.co/MingZhong/unieval-dialog). The two models can be used for the 3 tasks, `summarization`, `dialogue` and `data2text`. Each task has different evaluation dimensions. - -| UniEval Model | Task | Dimension(Metric) | -| :------------: | :------------ | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| unieval-sum | summarization | coherence: whether the summary is coherent
consistency: whether the claim is consistent with the given document
fluency: whether the paragraph is fluent
relevance: whether the summary is relevant to the reference | -| unieval-sum | data2text | naturalness: whether the utterance is fluent
informativeness: whether the utterance is informative according to the reference | -| unieval-dialog | dialogue | naturalness: whether the response is natural in the dialogue
coherence: whether the response is coherent in the dialogue history
understandability: whether the response is understandable in the dialogue | - -> **NOTE 1:** Task "data2text" uses the same model as task "summarization". - -> **NOTE 2:** In UniEval paper, the `unieval-sum` model demonstrates the best transfer ability and so you can evaluate your customized metric with this model. Details of adding customized metrics can be found in [FAQ](#faq). - -> **NOTE 3:** We consider not including all metrics provided in UniEval in our pipeline because the data structure and content of the instructions we want to evaluate are not suitable for direct use of some UniEval metrics. - -## Evaluation Process - -### Data Format - -#### Target Answers / Predictions - -A JSON file contains one list. Each element in the list is a target answer / prediction record for one instruction / question. -An element should have the following fields: - -- `category` (str, compulsory): The category of the instruction / question. -- `instruction` (str, compulsory): The instruction / question for the LLM. -- `input` (str, optional): The additional context of the instruction / question. -- `output` (str, optional): The sample output of the instruction (default: GPT-3.5). -- `target` (str, optional): The target answer for the instruction. -- `id` (int, compulsory): The ID of the instruction / question. - -If the `input` has a target answer, the `output` can be empty. Otherwise, we generate answers from GPT-3.5 as the `output`, and the `target` field is empty. - -Example: - -```json -[ - { - "category": "brainstorming", - "instruction": "请介绍一下人工智能的多个领域。", - "input": "", - "output": "{GPT-3.5 Answers}", - "target": "", - "id": 1 - }, - { - "category": "classification", - "instruction": "新闻标题:为什么电影《倩女幽魂》中燕赤霞一个道士却拿着金刚经?请根据新闻标题判断新闻所属的分类,你需要从文化,娱乐,体育,财经,房产,教育,科技,旅游,游戏,军事这十类中选择一个答案。", - "input": "", - "output": "", - "target": "{target answer}", - "id": 2 - } -] -``` - -#### Model Answers / Predictions - -A JSON file contains one list. Each element in the list is a model answer / prediction record for one instruction / question. - -An element should have the following fields: - -- `category` (str, compulsory): The category of the instruction / question. -- `instruction` (str, compulsory): The instruction / question for the LLM. -- `input` (str, optional): The additional context of the instruction / question. -- `output` (str, compulsory): The output from the LLM. -- `target` (str, optional): The target answer for the instruction. -- `id` (int, compulsory): The ID of the instruction / question. - -Example: - -```json -[ - { - "category": "brainstorming", - "instruction": "请介绍一下人工智能的多个领域。", - "input": "", - "output": "{Model Answers / Predictions}", - "target": "", - "id": 1 - }, - { - "category": "classification", - "instruction": "新闻标题:为什么电影《倩女幽魂》中燕赤霞一个道士却拿着金刚经?请根据新闻标题判断新闻所属的分类,你需要从文化,娱乐,体育,财经,房产,教育,科技,旅游,游戏,军事这十类中选择一个答案。", - "input": "", - "output": "{Model Answers / Predictions}", - "target": "{target answer}", - "id": 2 - } -] -``` - -### Prompt - -#### Battle Prompt - -The following is the Chinese battle prompt. In the battle prompt, the question and answers from two different models are fed into the prompt template. You can find example battle prompt files for Chinese and English in `prompt/battle_prompt`. 
- -```json -{ - "id": 1, - "system_prompt": "你是一个检查回答质量的好助手。", - "prompt_template": "[问题]\n{question}\n\n[1号AI助手的答案]\n{answer_1}\n\n[1号AI助手答案终止]\n\n[2号AI助手的答 案]\n{answer_2}\n\n[2号AI助手答案终止]\n\n[要求]\n{prompt}\n\n", - "prompt": "我们需要你评价这两个AI助手回答的性能。\n请对他们的回答的有用性、相关性、准确性、详细程度进行评分。每个AI助手都会得到一个1到10分的总分,分数越高表示整体表现越好。\n请首先输出一行,该行只包含两个数值,分别表示1号和2号AI助手的分数。这两个分数之间要有一个空格。在随后的一行中,请对你的评价作出全面的解释,避免任何潜在的偏见,并确保AI助手回答的顺序不会影响您的判断。" -} -``` - -#### Evaluation Prompt - -The following is an example of a Chinese GPT evaluation prompt. In an evaluation prompt, you should define your metrics in `metrics` and provide CoT(Chain-of-Thought) in `CoT`. You can find example evaluation prompt files for Chinese and English in `prompt/evaluation_prompt`. - -```json -{ - "brainstorming": { - "id": 1, - "category": "brainstorming", - "metrics": { - "language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。" - }, - "CoT": { - "language organization": "1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。\n2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。\n3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。\n4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。\n5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。\n6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。\n\n语言组织:" - }, - "prompt": "你是一个好助手。请你为下面“头脑风暴”问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}" - } -} -``` - -`"metrics"`: the metrics that can be used in GPT evaluation. This field determines which metrics can be added to your config file. - -`"CoT"`: evaluation steps you prompt to GPT models for each metric defined in `"metrics"`. - -### Evaluation - -#### Configuration - -The following is an example of a Chinese config file. The configuration file can control how the pipeline evaluates the model. You need to specify GPT evaluation metrics, automatic metrics and UniEval metrics in key `GPT`, `Metrics` and `UniEval`(English only). You can find an example English config file in `config`. - -```json -{ - "language": "en", - "path_for_UniEval": { - "summarization": "path to unieval-sum model", - "dialogue": "path to unieval-dialog model", - "data2text": "path to unieval-sum model" - }, - "category": { - "brainstorming": { - "GPT": ["relevance", "creativity", "practicality", "reasonableness"], - "Metrics": ["Distinct"], - "UniEval": [ - "summarization-fluency", - "data2text-naturalness", - "data2text-informativeness" - ] - }, - "chat": { - "GPT": ["relevance", "naturalness", "engagingness", "reasonableness"], - "Metrics": ["Distinct"], - "UniEval": [ - "dialogue-naturalness", - "dialogue-coherence", - "dialogue-understandability" - ] - } - } -} -``` - -`"language"`: the language used to evaluate the model capability. We only support Chinese `"cn"` for now. - -`"path_for_UniEval"`: path to the UniEval model. - -`"category"`: the category/categories needed to evaluate the model capability. - -`"GPT"`: the metrics you want to use for GPT evaluation. - -`"Metrics"`: the metrics you want to use for automatic metrics evaluation. - -`"UniEval"`: the metrics you want to use for UniEval metrics evaluation. The metric has to be in the `"{task}-{metric}"` format because different tasks have same metrics such as naturalness and coherence. - -You can remove the key such as `"Metrics"` to skip evaluating answers using its corresponding evaluation metrics. - -You can create your config file based on available settings listed in following table. 
- -| "category" | "GPT" | "Metrics" | "UniEval" | -| :--------------: | :---------------------: | :---------: | :--------------------------: | -| "brainstorming" | "language organization" | "BLEU" | "dialogue-naturalness" | -| "chat" | "relevance" | "ROUGE" | "dialogue-coherence" | -| "classification" | "creativity" | "Distinct" | "dialogue-understandability" | -| "closed_qa" | "practicality" | "BERTScore" | "data2text-naturalness" | -| "extraction" | "correctness" | "Precision" | "data2text-informativeness" | -| "generation" | "naturalness" | "Recall" | "summarization-coherence" | -| "open_qa" | "engagingness" | "F1 score" | "summarization-consistency" | -| "rewriting" | "reasonableness" | "CHRF" | "summarization-fluency" | -| "roleplay" | "diversity" | | "summarization-relevance" | -| "summarization" | "fidelity" | | | -| | "conciseness" | | | - -> **NOTE:** For categories which don't have standard answers such as `brainstorming`, you should avoid using automatic metrics such as `BLEU` and `ROUGE` which are based on similarity measures and you should use `Distinct` instead in your config file. - -#### Evaluate - -After setting the configuration file, you can evaluate the model using `eval.py`. If you want to make comparisons between answers of two different models, you should specify two answer files in the argument `answer_file_list` and two model names in the argument `model_name_list`. If you want to evaluate one answer file, the length of both `answer_file_list` and `model_name_list` should be 1 and the program will perform evaluation using automatic metrics and GPT models. - -An example script is provided as follows: - -```shell -python eval.py \ - --config_file "path to the config file" \ - --battle_prompt_file "path to the prompt file for battle" \ - --gpt_evaluation_prompt_file "path to the prompt file for gpt evaluation" \ - --target_file "path to the target answer file" \ - --answer_file_list "path to the answer files of at most 2 models" \ - --model_name_list "the names of at most 2 models" \ - --gpt_model "which GPT model to use for evaluation" \ - --save_path "path to save results" \ - --openai_key "your openai key" \ -``` - -If you want GPT evaluation with reference, you can add an argument `--gpt_with_reference`. - -## FAQ - -
How can I add a new GPT evaluation metric? - -For example, if you want to add a new metric `persuasiveness` into category `brainstorming`, you should add the metric definition and its corresponding CoT(Chain-of-thought) in the evaluation prompt file in `prompt/evaluation_promt`. The CoT can be generated using ChatGPT. You can prompt ChatGPT to generate evaluation steps for the new metric. - -```json -{ - "brainstorming": { - "id": 1, - "category": "brainstorming", - "metrics": { - "persuasiveness": "persuasiveness(1-5):a short description for persuasiveness" - }, - "CoT": { - "persuasiveness": "CoT for persuasiveness\n\npersuasiveness:" - }, - "prompt": "You are a good assistant. Please rate the given answer to the \"brainstorming\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}" - } -} -``` - -
- -
 How can I add a new UniEval evaluation metric? - -For example, if you want to add a new metric `persuasiveness` into task `data2text`, you should add a Boolean QA question about the metric in function `add_question` in `unieval/utils.py`. Please note that how effectively the model would evaluate this metric is unknown, and you may need some experiments to test whether the model is capable of evaluating this metric. - -```python -if task == 'data2text': -    if dimension == 'persuasiveness': -        cur_input = 'question: Is this a persuasive utterance? utterance: ' + output[i] -``` - -
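Once the question has been added, the new dimension can be referenced in your config file using the `"{task}-{metric}"` format described above. A minimal sketch of such a category entry (the `persuasiveness` dimension is hypothetical and assumes the question above has been added):

```json
"brainstorming": {
    "GPT": ["relevance", "creativity"],
    "Metrics": ["Distinct"],
    "UniEval": ["data2text-persuasiveness"]
}
```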
- -## To Do - -- [x] Add evaluation for English capability -- [x] Support UniEval -- [x] Support GPT-4 evaluation -- [x] Support GPT evaluation with reference - -## Citations - -```bibtex -@misc{vicuna2023, - title = {Vicuna: An Open-Source Chatbot Impressing GPT-4 with 90\%* ChatGPT Quality}, - url = {https://vicuna.lmsys.org}, - author = {Chiang, Wei-Lin and Li, Zhuohan and Lin, Zi and Sheng, Ying and Wu, Zhanghao and Zhang, Hao and Zheng, Lianmin and Zhuang, Siyuan and Zhuang, Yonghao and Gonzalez, Joseph E. and Stoica, Ion and Xing, Eric P.}, - month = {March}, - year = {2023} -} - -@misc{liu2023geval, - title={G-Eval: NLG Evaluation using GPT-4 with Better Human Alignment}, - author={Yang Liu and Dan Iter and Yichong Xu and Shuohang Wang and Ruochen Xu and Chenguang Zhu}, - year={2023}, - eprint={2303.16634}, - archivePrefix={arXiv}, - primaryClass={cs.CL} -} - -@misc{zhong2022unified, - title={Towards a Unified Multi-Dimensional Evaluator for Text Generation}, - author={Ming Zhong and Yang Liu and Da Yin and Yuning Mao and Yizhu Jiao and Pengfei Liu and Chenguang Zhu and Heng Ji and Jiawei Han}, - year={2022}, - eprint={2210.07197}, - archivePrefix={arXiv}, - primaryClass={cs.CL} -} -``` diff --git a/applications/Chat/evaluate/config/config_cn.json b/applications/Chat/evaluate/config/config_cn.json deleted file mode 100644 index 4d30d005df30..000000000000 --- a/applications/Chat/evaluate/config/config_cn.json +++ /dev/null @@ -1,204 +0,0 @@ -{ - "language": "cn", - "category": { - "brainstorming": { - "GPT": [ - "language organization", - "relevance", - "creativity", - "practicality", - "reasonableness" - ], - "Metrics": [ - "Distinct" - ] - }, - "chat": { - "GPT": [ - "language organization", - "naturalness", - "engagingness", - "fidelity" - ], - "Metrics": [ - "Distinct" - ] - }, - "classification": { - "GPT": [ - "relevance", - "correctness" - ], - "Metrics": [ - "Precision", - "Recall", - "F1 score", - "CHRF" - ] - }, - "closed_qa": { - "GPT": [ - "relevance", - "correctness" - ], - "Metrics": [ - "BLEU", - "ROUGE", - "BERTScore", - "CHRF" - ] - }, - "extraction": { - "GPT": [ - "relevance", - "correctness" - ], - "Metrics": [ - "Precision", - "Recall", - "F1 score", - "CHRF" - ] - }, - "generation": { - "GPT": [ - "language organization", - "relevance", - "diversity" - ], - "Metrics": [ - "BLEU", - "ROUGE", - "BERTScore" - ] - }, - "logical_reasoning": { - "GPT": [ - "correctness", - "relevance", - "reasonableness" - ], - "Metrics": [ - "BLEU", - "ROUGE", - "BERTScore", - "CHRF" - ] - }, - "open_qa": { - "GPT": [ - "language organization", - "relevance", - "correctness" - ], - "Metrics": [ - "Distinct" - ] - }, - "rewriting": { - "GPT": [ - "language organization", - "relevance", - "correctness" - ], - "Metrics": [ - "BLEU", - "ROUGE", - "BERTScore" - ] - }, - "roleplay": { - "GPT": [ - "language organization", - "relevance", - "fidelity", - "creativity" - ], - "Metrics": [ - "Distinct" - ] - }, - "summarization": { - "GPT": [ - "language organization", - "relevance", - "correctness", - "conciseness" - ], - "Metrics": [ - ] - }, - "Finance": { - "GPT": [ - "relevance", - "correctness" - ], - "Metrics": [ - ] - }, - "Law": { - "GPT": [ - "relevance", - "correctness" - ], - "Metrics": [ - ] - }, - "Education": { - "GPT": [ - "relevance", - "correctness" - ], - "Metrics": [ - ] - }, - "Medical": { - "GPT": [ - "relevance", - "correctness" - ], - "Metrics": [ - ] - }, - "STEM": { - "GPT": [ - "relevance", - "correctness" - ], - "Metrics": [ - ] - }, - "SocialScience": { - "GPT": [ - 
"relevance", - "correctness" - ], - "Metrics": [ - ] - }, - "Humanity": { - "GPT": [ - "relevance", - "correctness" - ], - "Metrics": [ - ] - }, - "Other": { - "GPT": [ - "relevance", - "correctness" - ], - "Metrics": [ - ] - }, - "ethics": { - "GPT": [ - "relevance", - "correctness" - ], - "Metrics": [ - ] - } - } -} diff --git a/applications/Chat/evaluate/config/config_en.json b/applications/Chat/evaluate/config/config_en.json deleted file mode 100644 index c964122dd6d6..000000000000 --- a/applications/Chat/evaluate/config/config_en.json +++ /dev/null @@ -1,283 +0,0 @@ -{ - "language": "en", - "path_for_UniEval": { - "summarization": "path to unieval-sum", - "dialogue": "path to unieval-dialog", - "data2text": "path to unieval-sum" - }, - "category": { - "brainstorming": { - "GPT": [ - "language organization", - "relevance", - "creativity", - "practicality", - "reasonableness" - ], - "Metrics": [ - "Distinct" - ], - "UniEval": [ - "summarization-fluency", - "data2text-naturalness", - "data2text-informativeness" - ] - }, - "chat": { - "GPT": [ - "language organization", - "naturalness", - "engagingness", - "fidelity" - ], - "Metrics": [ - "Distinct" - ], - "UniEval": [ - "summarization-fluency", - "dialogue-naturalness", - "dialogue-coherence", - "dialogue-understandability", - "data2text-naturalness", - "data2text-informativeness" - ] - }, - "classification": { - "GPT": [ - "relevance", - "correctness" - ], - "Metrics": [ - "Precision", - "Recall", - "F1 score", - "CHRF" - ], - "UniEval": [ - "summarization-fluency", - "data2text-naturalness", - "data2text-informativeness" - ] - }, - "closed_qa": { - "GPT": [ - "relevance", - "correctness" - ], - "Metrics": [ - "BLEU", - "ROUGE", - "BERTScore", - "CHRF" - ], - "UniEval": [ - "summarization-fluency", - "data2text-naturalness", - "data2text-informativeness" - ] - }, - "extraction": { - "GPT": [ - "relevance", - "correctness" - ], - "Metrics": [ - "Precision", - "Recall", - "F1 score", - "CHRF" - ], - "UniEval": [ - "summarization-fluency", - "data2text-naturalness", - "data2text-informativeness" - ] - }, - "generation": { - "GPT": [ - "language organization", - "relevance", - "diversity" - ], - "Metrics": [ - "BLEU", - "ROUGE", - "BERTScore" - ], - "UniEval": [ - "summarization-fluency", - "data2text-naturalness", - "data2text-informativeness" - ] - }, - "logical_reasoning": { - "GPT": [ - "correctness", - "relevance", - "reasonableness" - ], - "Metrics": [ - "BLEU", - "ROUGE", - "BERTScore", - "CHRF" - ], - "UniEval": [ - ] - }, - "open_qa": { - "GPT": [ - "language organization", - "relevance", - "correctness" - ], - "Metrics": [ - "Distinct" - ], - "UniEval": [ - "summarization-fluency", - "data2text-naturalness", - "data2text-informativeness" - ] - }, - "rewriting": { - "GPT": [ - "language organization", - "relevance", - "correctness" - ], - "Metrics": [ - "BLEU", - "ROUGE", - "BERTScore" - ], - "UniEval": [ - "summarization-fluency", - "data2text-naturalness", - "data2text-informativeness" - ] - }, - "roleplay": { - "GPT": [ - "language organization", - "relevance", - "fidelity", - "creativity" - ], - "Metrics": [ - "Distinct" - ], - "UniEval": [ - "summarization-fluency", - "data2text-naturalness", - "data2text-informativeness" - ] - }, - "summarization": { - "GPT": [ - "language organization", - "relevance", - "correctness", - "conciseness" - ], - "Metrics": [ - "BLEU", - "ROUGE", - "BERTScore", - "CHRF" - ], - "UniEval": [ - ] - }, - "Finance": { - "GPT": [ - "relevance", - "correctness" - ], - "Metrics": [ - ], - "UniEval": [ - ] 
- }, - "Law": { - "GPT": [ - "relevance", - "correctness" - ], - "Metrics": [ - ], - "UniEval": [ - ] - }, - "Education": { - "GPT": [ - "relevance", - "correctness" - ], - "Metrics": [ - ], - "UniEval": [ - ] - }, - "Medical": { - "GPT": [ - "relevance", - "correctness" - ], - "Metrics": [ - ], - "UniEval": [ - ] - }, - "STEM": { - "GPT": [ - "relevance", - "correctness" - ], - "Metrics": [ - ], - "UniEval": [ - ] - }, - "SocialScience": { - "GPT": [ - "relevance", - "correctness" - ], - "Metrics": [ - ], - "UniEval": [ - ] - }, - "Humanity": { - "GPT": [ - "relevance", - "correctness" - ], - "Metrics": [ - ], - "UniEval": [ - ] - }, - "Other": { - "GPT": [ - "relevance", - "correctness" - ], - "Metrics": [ - ], - "UniEval": [ - ] - }, - "ethics": { - "GPT": [ - "relevance", - "correctness" - ], - "Metrics": [ - ], - "UniEval": [ - ] - } - } -} diff --git a/applications/Chat/evaluate/evaluator.py b/applications/Chat/evaluate/evaluator.py deleted file mode 100644 index 1d998cd2d09c..000000000000 --- a/applications/Chat/evaluate/evaluator.py +++ /dev/null @@ -1,229 +0,0 @@ -import os -from typing import Any, Dict, List - -import gpt_evaluate -import metrics -import unieval -from utils import analyze_automatic_results, get_data_per_category, save_automatic_results - - -class Evaluator(object): - """ - A class named Evaluator includes GPT-3.5/GPT-4 evaluation - and automatic evaluation - - """ - - def __init__( - self, - params: Dict[str, Any], - battle_prompt: Dict[str, Any], - gpt_evaluation_prompt: Dict[str, Any], - gpt_model: str, - language: str, - path_for_UniEval: Dict[str, str], - gpt_with_reference: bool, - ) -> None: - self.params = params - self.battle_prompt = battle_prompt - self.gpt_evaluation_prompt = gpt_evaluation_prompt - self.gpt_model = gpt_model - self.language = language - self.path_for_UniEval = path_for_UniEval - self.gpt_with_reference = gpt_with_reference - self.automatic_metric_stats = dict() - self.unieval_metric_stats = dict() - self.gpt_evaluation_results = dict() - self.battle_results = [] - - def battle(self, answers1: List[Dict], answers2: List[Dict]) -> None: - """ - Comparison between two models using GPT-4 as the reviewer. - """ - - self.battle_results = gpt_evaluate.battle(answers1, answers2, self.battle_prompt) - - def evaluate(self, answers: List[Dict], targets: List[Dict]) -> None: - """ - A comprehensive evaluation of the answers from the model. - The function evaluates the model's performance from different perspectives - using GPT-3.5, GPT-4, and off-the-shelf evaluation metrics. - - The metrics will be decided by the config file. 
- - """ - - def switch(metric, language): - if metric == "BLEU": - return metrics.bleu_score(preds=predicts_list, targets=targets_list, language=language) - elif metric == "ROUGE": - return metrics.rouge_score(preds=predicts_list, targets=targets_list, language=language) - elif metric == "Distinct": - return metrics.distinct_score(preds=predicts_list, language=language) - elif metric == "BERTScore": - return metrics.bert_score(preds=predicts_list, targets=targets_list, language=language) - elif metric == "Precision": - return metrics.precision(preds=predicts_list, targets=targets_list, language=language) - elif metric == "Recall": - return metrics.recall(preds=predicts_list, targets=targets_list, language=language) - elif metric == "F1 score": - return metrics.F1_score(preds=predicts_list, targets=targets_list, language=language) - elif metric == "CHRF": - return metrics.chrf_score(preds=predicts_list, targets=targets_list, language=language) - else: - raise ValueError(f"Unexpected metric") - - answers_per_category = get_data_per_category(answers, list(self.params.keys())) - targets_per_category = get_data_per_category(targets, list(self.params.keys())) - - # automatic evaluation - for category in self.params: - if len(answers_per_category[category]) == 0: - print(f"Category {category} specified in your config doesn't have corresponding answers!") - continue - - if self.params[category].get("Metrics", None) is None: - continue - - category_metrics = self.params[category]["Metrics"] - self.automatic_metric_stats[category] = {} - - targets_list = [ - target["target"] if target["target"] else target["output"] for target in targets_per_category[category] - ] - predicts_list = [answer["output"] for answer in answers_per_category[category]] - - for metric in category_metrics: - self.automatic_metric_stats[category].update(switch(metric=metric, language=self.language)) - - # UniEval evaluation - # self.unieval_metric_stats's key is "task" instead of "category". - # Iterating "task" first will avoid repeated loading models because one task corresponds to one UniEval model. - # If key is "category", different models will be loaded for multiple times across categories because the user may require different task(models) to evaluate one category. - for category in self.params: - if len(answers_per_category[category]) == 0: - print(f"Category {category} specified in your config doesn't have corresponding answers!") - continue - - if self.params[category].get("UniEval", None) is None: - continue - - if self.params[category]["UniEval"] and self.language == "cn": - raise Exception( - "UniEval doesn't support Chinese! Please remove UniEval config in your Chinese config file." 
- ) - - category_metrics = self.params[category]["UniEval"] - - for task, metric in [tuple(category_metric.split("-")) for category_metric in category_metrics]: - if self.unieval_metric_stats.get(task, None) is None: - self.unieval_metric_stats[task] = {category: {metric: 0}} - elif self.unieval_metric_stats[task].get(category, None) is None: - self.unieval_metric_stats[task][category] = {metric: 0} - else: - self.unieval_metric_stats[task][category][metric] = 0 - - for task in self.unieval_metric_stats: - if self.path_for_UniEval is None: - raise Exception(f"Please specify the path for UniEval model in the config file!") - - if self.path_for_UniEval.get(task, None) is None: - raise Exception(f"Please specify the model path for task {task} in the config file!") - - print(f"Load UniEval model for task {task}.") - - uni_evaluator = unieval.get_evaluator(task, model_name_or_path=self.path_for_UniEval[task]) - for category in self.unieval_metric_stats[task]: - targets_list = [ - target["target"] if target["target"] else target["output"] - for target in targets_per_category[category] - ] - predicts_list = [answer["output"] for answer in answers_per_category[category]] - sources_list = [answer["instruction"] + answer["input"] for answer in answers_per_category[category]] - - data = unieval.convert_data_to_unieval_format(predicts_list, sources_list, targets_list) - scores = uni_evaluator.evaluate( - data, category, dims=list(self.unieval_metric_stats[task][category].keys()), overall=False - ) - avg_scores = unieval.calculate_average_score(scores) - - self.unieval_metric_stats[task][category].update(avg_scores) - - # gpt evaluation - for category in self.params: - if len(answers_per_category[category]) == 0: - print(f"Category {category} specified in your config doesn't have corresponding answers!") - continue - - if self.params[category].get("GPT", None) is None: - continue - - category_metrics = self.params[category]["GPT"] - - prompt = self.gpt_evaluation_prompt.get(category, None) - if prompt is None: - print(f"No prompt for category {category}! Use prompt for category general now.") - prompt = self.gpt_evaluation_prompt["general"] - - self.gpt_evaluation_results[category] = gpt_evaluate.evaluate( - answers_per_category[category], - prompt, - category_metrics, - category, - self.gpt_model, - self.language, - references=targets_per_category[category] if self.gpt_with_reference else None, - ) - - def save(self, path: str, model_name_list: List[str]) -> None: - """ - Save evaluation results of GPT-3.5, GPT-4, and off-the-shelf evaluation metrics. - - """ - - if len(model_name_list) == 2: - save_path = os.path.join(path, "gpt_evaluate", "battle_results") - gpt_evaluate.save_battle_results(self.battle_results, model_name_list[0], model_name_list[1], save_path) - else: - if self.automatic_metric_stats: - # Save evaluation results for automatic metrics - automatic_base_save_path = os.path.join(path, "automatic_results") - automatic_results_save_path = os.path.join(automatic_base_save_path, "evaluation_results") - - save_automatic_results(model_name_list[0], self.automatic_metric_stats, automatic_results_save_path) - - # Save charts and csv. 
- automatic_analyses_save_path = os.path.join(automatic_base_save_path, "evaluation_analyses") - analyze_automatic_results(automatic_results_save_path, automatic_analyses_save_path) - - if self.unieval_metric_stats: - # Save evaluation results for UniEval metrics - unieval_base_save_path = os.path.join(path, "unieval_results") - unieval_results_save_path = os.path.join(unieval_base_save_path, "evaluation_results") - - unieval.save_unieval_results(model_name_list[0], self.unieval_metric_stats, unieval_results_save_path) - - # Save charts and csv. - unieval_analyses_save_path = os.path.join(unieval_base_save_path, "evaluation_analyses") - unieval.analyze_unieval_results(unieval_results_save_path, unieval_analyses_save_path) - - if self.gpt_evaluation_results: - # Save evaluation results for GPT evaluation metrics. - gpt_base_save_path = os.path.join(path, "gpt_evaluate", "gpt_evaluate_results") - gpt_evaluation_results_save_path = os.path.join(gpt_base_save_path, "evaluation_results") - - all_evaluations = gpt_evaluate.save_gpt_evaluation_results( - model_name_list[0], self.gpt_evaluation_results, gpt_evaluation_results_save_path - ) - - # Start to calculate scores and save statistics. - gpt_evaluation_statistics_save_path = os.path.join(gpt_base_save_path, "evaluation_statistics") - gpt_evaluate.save_gpt_evaluation_statistics( - model_name_list[0], all_evaluations, gpt_evaluation_statistics_save_path - ) - - # Save charts and csv. - gpt_evaluation_analyses_save_path = os.path.join(gpt_base_save_path, "evaluation_analyses") - gpt_evaluate.analyze_gpt_evaluation_statistics( - gpt_evaluation_statistics_save_path, gpt_evaluation_analyses_save_path - ) diff --git a/applications/Chat/evaluate/metrics.py b/applications/Chat/evaluate/metrics.py deleted file mode 100644 index 85ee4de53725..000000000000 --- a/applications/Chat/evaluate/metrics.py +++ /dev/null @@ -1,254 +0,0 @@ -import statistics -from typing import Dict, List - -import jieba -from bert_score import score -from nltk.translate.bleu_score import sentence_bleu -from nltk.translate.chrf_score import sentence_chrf -from rouge_chinese import Rouge as Rouge_cn -from rouge_score import rouge_scorer as Rouge_en -from sklearn.metrics import f1_score, precision_score, recall_score -from utils import preprocessing_text, remove_redundant_space - - -def bleu_score(preds: List[str], targets: List[str], language: str) -> Dict[str, float]: - """Calculate BLEU Score Metric - - The calculation includes BLEU-1 for unigram, BLEU-2 for bigram, - BLEU-3 for trigram and BLEU-4 for 4-gram. Unigram evaluates the - accuracy in word level, other n-gram evaluate the fluency in - sentence level. 
- """ - bleu_scores = {"bleu1": 0, "bleu2": 0, "bleu3": 0, "bleu4": 0} - cumulative_bleu = [0] * 4 - weights = [ - (1.0 / 1.0, 0.0, 0.0, 0.0), - (1.0 / 2.0, 1.0 / 2.0, 0.0, 0.0), - (1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0, 0.0), - (1.0 / 4.0, 1.0 / 4.0, 1.0 / 4.0, 1.0 / 4.0), - ] - - for pred, target in zip(preds, targets): - if language == "cn": - pred_list = " ".join(jieba.cut(preprocessing_text(pred))).split() - target_list = [(" ".join(jieba.cut(preprocessing_text(target)))).split()] - elif language == "en": - pred_list = preprocessing_text(pred).split() - target_list = [preprocessing_text(target).split()] - - bleu = sentence_bleu(target_list, pred_list, weights=weights) - cumulative_bleu = [a + b for a, b in zip(cumulative_bleu, bleu)] - - for i in range(len(cumulative_bleu)): - bleu_scores[f"bleu{i+1}"] = cumulative_bleu[i] / len(preds) - - return bleu_scores - - -def chrf_score(preds: List[str], targets: List[str], language: str) -> Dict[str, float]: - """Calculate CHRF Score Metric in sentence level.""" - chrf_score = {"chrf": 0} - cumulative_chrf = [] - - for pred, target in zip(preds, targets): - if language == "cn": - pred_list = " ".join(jieba.cut(preprocessing_text(pred))).split() - target_list = " ".join(jieba.cut(preprocessing_text(target))).split() - elif language == "en": - pred_list = preprocessing_text(pred).split() - target_list = preprocessing_text(target).split() - - cumulative_chrf.append(sentence_chrf(target_list, pred_list)) - - chrf_score["chrf"] = statistics.mean(cumulative_chrf) - - return chrf_score - - -def rouge_cn_score(preds: List[str], targets: List[str]) -> Dict[str, float]: - """Calculate Chinese ROUGE Score Metric - - The calculation includes ROUGE-1 for unigram, ROUGE-2 for bigram - and ROUGE-L. ROUGE-N evaluates the number of matching n-grams between - the preds and targets. ROUGE-L measures the number of matching - longest common subsequence (LCS) between preds and targets. - """ - rouge_scores = {"rouge1": 0, "rouge2": 0, "rougeL": 0} - all_preds = [] - all_targets = [] - - for pred, target in zip(preds, targets): - pred_list = remove_redundant_space(" ".join(jieba.cut(preprocessing_text(pred)))) - target_list = remove_redundant_space(" ".join(jieba.cut(preprocessing_text(target)))) - all_preds.append(pred_list) - all_targets.append(target_list) - - rouge_cn = Rouge_cn() - rouge_avg = rouge_cn.get_scores(all_preds, all_targets, avg=True) - - rouge_scores["rouge1"] = rouge_avg["rouge-1"]["f"] - rouge_scores["rouge2"] = rouge_avg["rouge-2"]["f"] - rouge_scores["rougeL"] = rouge_avg["rouge-l"]["f"] - - return rouge_scores - - -def rouge_en_score(preds: List[str], targets: List[str]) -> Dict[str, float]: - """Calculate English ROUGE Score Metric - - The calculation includes ROUGE-1 for unigram, ROUGE-2 for bigram - and ROUGE-L. ROUGE-N evaluates the number of matching n-grams between - the preds and targets. ROUGE-L measures the number of matching - longest common subsequence (LCS) between preds and targets. 
- """ - rouge_scores = {"rouge1": 0, "rouge2": 0, "rougeL": 0} - - rouge_en = Rouge_en.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=False) - - for pred, target in zip(preds, targets): - score = rouge_en.score(preprocessing_text(pred), preprocessing_text(target)) - rouge_scores["rouge1"] += score["rouge1"].fmeasure - rouge_scores["rouge2"] += score["rouge2"].fmeasure - rouge_scores["rougeL"] += score["rougeL"].fmeasure - - rouge_scores["rouge1"] = rouge_scores["rouge1"] / len(preds) - rouge_scores["rouge2"] = rouge_scores["rouge2"] / len(preds) - rouge_scores["rougeL"] = rouge_scores["rougeL"] / len(preds) - - return rouge_scores - - -def rouge_score(preds: List[str], targets: List[str], language: str) -> Dict[str, float]: - """Calculate ROUGE Score Metric""" - if language == "cn": - return rouge_cn_score(preds, targets) - elif language == "en": - return rouge_en_score(preds, targets) - - -def distinct_score(preds: List[str], language: str) -> Dict[str, float]: - """Calculate Distinct Score Metric - - This metric refers to https://arxiv.org/abs/1510.03055. - It evaluates the diversity of generation text by counting - the unique n-grams. - """ - distinct_score = {"distinct": 0} - cumulative_distinct = [] - - for pred in preds: - if language == "cn": - pred_seg_list = " ".join(jieba.cut(pred)).split() - count_segs = len(pred_seg_list) - unique_segs = set(pred_seg_list) - count_unique_chars = len(unique_segs) - # prevent denominator from being 0 - cumulative_distinct.append(count_unique_chars / (count_segs + 1e-6)) - elif language == "en": - # calculate distinct 1-gram, 2-gram, 3-gram - unique_ngram = [set() for _ in range(0, 3)] - all_ngram_count = [0 for _ in range(0, 3)] - - split_pred = preprocessing_text(pred).split() - for n in range(0, 3): - for i in range(0, len(split_pred) - n): - ngram = " ".join(split_pred[i : i + n + 1]) - unique_ngram[n].add(ngram) - all_ngram_count[n] += 1 - - # Sometimes the answer may contain only one word. For 2-gram and 3-gram, the gram count(denominator) may be zero. - avg_distinct = [len(a) / (b + 1e-6) for a, b in zip(unique_ngram, all_ngram_count)] - - cumulative_distinct.append(statistics.mean(avg_distinct)) - - distinct_score["distinct"] = statistics.mean(cumulative_distinct) - - return distinct_score - - -def bert_score(preds: List[str], targets: List[str], language: str) -> Dict[str, float]: - """Calculate BERTScore Metric - - The BERTScore evaluates the semantic similarity between - tokens of preds and targets with BERT. - """ - bert_score = {"bert_score": 0} - pred_list = [] - target_list = [] - - for pred, target in zip(preds, targets): - pred_list.append(pred) - target_list.append(target) - - if language == "cn": - _, _, F = score(pred_list, target_list, lang="zh", verbose=True) - elif language == "en": - _, _, F = score(pred_list, target_list, lang="en", verbose=True) - - bert_score["bert_score"] = F.mean().item() - - return bert_score - - -def calculate_precision_recall_f1(preds: List[str], targets: List[str], language: str) -> Dict[str, float]: - """Precision, Recall and F1-Score Calculation - - The calculation of precision, recall and f1-score is realized by counting - the number f overlaps between the preds and target. The comparison length - limited by the shorter one of preds and targets. 
- """ - precision_recall_f1 = {"precision": 0, "recall": 0, "f1_score": 0} - precision_scores = [] - recall_scores = [] - f1_scores = [] - - for pred, target in zip(preds, targets): - if language == "cn": - pred_list = [char for char in " ".join(jieba.cut(preprocessing_text(pred))).split()] - target_list = [char for char in " ".join(jieba.cut(preprocessing_text(target))).split()] - elif language == "en": - pred_list = [char for char in preprocessing_text(pred).split()] - target_list = [char for char in preprocessing_text(target).split()] - - target_labels = [1] * min(len(target_list), len(pred_list)) - pred_labels = [int(pred_list[i] == target_list[i]) for i in range(0, min(len(target_list), len(pred_list)))] - - precision_scores.append(precision_score(target_labels, pred_labels, zero_division=0)) - recall_scores.append(recall_score(target_labels, pred_labels, zero_division=0)) - f1_scores.append(f1_score(target_labels, pred_labels, zero_division=0)) - - precision_recall_f1["precision"] = statistics.mean(precision_scores) - precision_recall_f1["recall"] = statistics.mean(recall_scores) - precision_recall_f1["f1_score"] = statistics.mean(f1_scores) - - return precision_recall_f1 - - -def precision(preds: List[str], targets: List[str], language: str) -> Dict[str, float]: - """Calculate Precision Metric - - Calculating precision by counting the number of overlaps between the preds and target. - """ - precision = {"precision": 0} - precision["precision"] = calculate_precision_recall_f1(preds, targets, language)["precision"] - return precision - - -def recall(preds: List[str], targets: List[str], language: str) -> Dict[str, float]: - """Calculate Recall Metric - - Calculating recall by counting the number of overlaps between the preds and target. - """ - recall = {"recall": 0} - recall["recall"] = calculate_precision_recall_f1(preds, targets, language)["recall"] - return recall - - -def F1_score(preds: List[str], targets: List[str], language: str) -> Dict[str, float]: - """Calculate F1-score Metric - - Calculating f1-score by counting the number of overlaps between the preds and target. 
- """ - f1 = {"f1_score": 0} - f1["f1_score"] = calculate_precision_recall_f1(preds, targets, language)["f1_score"] - return f1 diff --git a/applications/Chat/evaluate/requirements.txt b/applications/Chat/evaluate/requirements.txt deleted file mode 100644 index 27d317ed88cc..000000000000 --- a/applications/Chat/evaluate/requirements.txt +++ /dev/null @@ -1,12 +0,0 @@ -jieba -bert-score -rouge_chinese -scikit-metrics -nltk -openai -seaborn -pandas -matplotlib -numpy -zhon -rouge_score diff --git a/applications/Chat/evaluate/unieval/__init__.py b/applications/Chat/evaluate/unieval/__init__.py deleted file mode 100644 index 6ffccdaa0819..000000000000 --- a/applications/Chat/evaluate/unieval/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from .evaluator import get_evaluator -from .utils import ( - analyze_unieval_results, - calculate_average_score, - convert_data_to_unieval_format, - save_unieval_results, -) - -__all__ = [ - "get_evaluator", - "convert_data_to_unieval_format", - "calculate_average_score", - "save_unieval_results", - "analyze_unieval_results", -] diff --git a/applications/Chat/evaluate/unieval/evaluator.py b/applications/Chat/evaluate/unieval/evaluator.py deleted file mode 100644 index bf2bc33a95c0..000000000000 --- a/applications/Chat/evaluate/unieval/evaluator.py +++ /dev/null @@ -1,329 +0,0 @@ -# MIT License - -# Copyright (c) 2022 Ming Zhong - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -import numpy as np -from nltk import sent_tokenize - -from .scorer import UniEvaluator -from .utils import add_question - - -class SumEvaluator: - def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None): - """Set up evaluator for text summarization""" - self.scorer = UniEvaluator( - model_name_or_path="MingZhong/unieval-sum" if model_name_or_path == "" else model_name_or_path, - max_length=max_length, - device=device, - cache_dir=cache_dir, - ) - self.task = "summarization" - self.dimensions = ["coherence", "consistency", "fluency", "relevance"] - - def evaluate(self, data, category, dims=None, overall=True): - """ - Get the scores of all the given dimensions - - category: The category to be evaluated. - - dims: A list of dimensions to be evaluated. If dims is None, SumEvaluator will evaluate - four dimensions: coherence, consistency, fluency, relevance. - - overall: indicates whether the overall score is to be calculated. - Overall score can be customized to a combination of scores based on different - dimensions. 
The default here is the average score of all the given dimensions. - """ - n_data = len(data) - eval_scores = [{} for _ in range(n_data)] - - if dims == None: - eval_dims = self.dimensions - else: - assert isinstance(dims, list) - eval_dims = dims - - for dim in eval_dims: - # Calculate average sentence-level scores for 'consistency' and 'fluency' - if dim == "consistency" or dim == "fluency": - src_list, output_list = [], [] - n_sents = [] # the number of sentences in each generated summary - for i in range(n_data): - source = data[i]["source"] - system_outputs = sent_tokenize(data[i]["system_output"]) - n_sents.append(len(system_outputs)) - for j in range(len(system_outputs)): - src_list.append(source) - output_list.append(system_outputs[j]) - input_list = add_question(dimension=dim, output=output_list, src=src_list, task=self.task) - sent_score = self.scorer.score(input_list, self.task, category, dim) - - # Get average score for each sample - start_idx = 0 - score = [] - for cur_n_sent in n_sents: - # prevent denominator from being 0 - score.append(sum(sent_score[start_idx : start_idx + cur_n_sent]) / (cur_n_sent + 1e-6)) - start_idx += cur_n_sent - - # Calculate summary-level score for 'coherence' and 'relevance' - elif dim == "coherence" or dim == "relevance": - src_list, output_list, ref_list = [], [], [] - for i in range(n_data): - src_list.append(data[i]["source"]) - output_list.append(data[i]["system_output"]) - if dim == "relevance": - ref_list.append(data[i]["reference"]) - input_list = add_question(dimension=dim, output=output_list, src=src_list, ref=ref_list, task=self.task) - score = self.scorer.score(input_list, self.task, category, dim) - - # Please customize other dimensions here for summarization - else: - raise NotImplementedError( - "The input format for this dimension is still undefined. \ - Please customize it first." - ) - - for i in range(n_data): - eval_scores[i][dim] = score[i] - - # Customize your overall score here. - if overall == True: - for i in range(n_data): - eval_scores[i]["overall"] = np.mean(list(eval_scores[i].values())) - - return eval_scores - - -class DialogEvaluator: - def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None): - """Set up evaluator for dialogues""" - self.scorer = UniEvaluator( - model_name_or_path="MingZhong/unieval-dialog" if model_name_or_path == "" else model_name_or_path, - max_length=max_length, - device=device, - cache_dir=cache_dir, - ) - self.task = "dialogue" - self.dimensions = ["naturalness", "coherence", "engagingness", "groundedness", "understandability"] - - def evaluate(self, data, category, dims=None, overall=True): - """ - Get the scores of all the given dimensions - - category: The category to be evaluated. - - dims: A list of dimensions to be evaluated. If dims is None, DialogEvaluator will evaluate - five dimensions: naturalness, coherence, engagingness, groundedness and understandability. - - overall: indicates whether the overall score is to be calculated. - Overall score can be customized to a combination of scores based on different - dimensions. The default here is the average score of all the given dimensions. 
- """ - n_data = len(data) - eval_scores = [{} for _ in range(n_data)] - - if dims == None: - eval_dims = self.dimensions - else: - assert isinstance(dims, list) - eval_dims = dims - - for dim in eval_dims: - # Calculate summation score for 'engagingness' - if dim == "engagingness": - src_list, output_list, context_list = [], [], [] - n_sents = [] # the number of sentences in each generated response - for i in range(n_data): - source = data[i]["source"] - context = data[i]["context"] - system_outputs = sent_tokenize(data[i]["system_output"]) - n_sents.append(len(system_outputs)) - for j in range(len(system_outputs)): - src_list.append(source) - context_list.append(context) - output_list.append(system_outputs[j]) - input_list = add_question( - dimension=dim, output=output_list, src=src_list, context=context_list, task=self.task - ) - sent_score = self.scorer.score(input_list, self.task, category, dim) - - # Get the summation score for each sample - start_idx = 0 - score = [] - for cur_n_sent in n_sents: - score.append(sum(sent_score[start_idx : start_idx + cur_n_sent])) - start_idx += cur_n_sent - - # Calculate turn-level score for other dimensions - elif dim in ["naturalness", "coherence", "groundedness", "understandability"]: - src_list, output_list, context_list = [], [], [] - for i in range(n_data): - src_list.append(data[i]["source"]) - output_list.append(data[i]["system_output"]) - context_list.append(data[i]["context"]) - input_list = add_question( - dimension=dim, output=output_list, src=src_list, context=context_list, task=self.task - ) - score = self.scorer.score(input_list, self.task, category, dim) - - # Please customize other dimensions here for summarization - else: - raise NotImplementedError( - "The input format for this dimension is still undefined. \ - Please customize it first." - ) - - for i in range(n_data): - eval_scores[i][dim] = score[i] - - # Customize your overall score here. - if overall == True: - for i in range(n_data): - eval_scores[i]["overall"] = np.mean(list(eval_scores[i].values())) - - return eval_scores - - -class D2tEvaluator: - def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None): - """Set up evaluator for data-to-text""" - self.scorer = UniEvaluator( - model_name_or_path="MingZhong/unieval-sum" if model_name_or_path == "" else model_name_or_path, - max_length=max_length, - device=device, - cache_dir=cache_dir, - ) - self.task = "data2text" - self.dimensions = ["naturalness", "informativeness"] - - def evaluate(self, data, category, dims=None, overall=True): - """ - Get the scores of all the given dimensions - - category: The category to be evaluated. - - dims: A list of dimensions to be evaluated. If dims is None, D2tEvaluator will evaluate - two dimensions: naturalness and informativeness. - - overall: indicates whether the overall score is to be calculated. - Overall score can be customized to a combination of scores based on different - dimensions. The default here is the average score of all the given dimensions. 
- """ - n_data = len(data) - eval_scores = [{} for _ in range(n_data)] - - if dims == None: - eval_dims = self.dimensions - else: - assert isinstance(dims, list) - eval_dims = dims - - for dim in eval_dims: - output_list, ref_list = [], [] - for i in range(n_data): - output_list.append(data[i]["system_output"]) - ref_list.append(data[i]["reference"]) - - input_list = add_question(dimension=dim, output=output_list, ref=ref_list, task=self.task) - score = self.scorer.score(input_list, self.task, category, dim) - - for i in range(n_data): - eval_scores[i][dim] = score[i] - - # Customize your overall score here. - if overall == True: - for i in range(n_data): - eval_scores[i]["overall"] = np.mean(list(eval_scores[i].values())) - - return eval_scores - - -class FactEvaluator: - def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None): - """Set up evaluator for factual consistency detection""" - self.scorer = UniEvaluator( - model_name_or_path="MingZhong/unieval-fact" if model_name_or_path == "" else model_name_or_path, - max_length=max_length, - device=device, - cache_dir=cache_dir, - ) - self.task = "fact" - self.dim = "consistency" - - def evaluate(self, data, category): - """ - Get the factual consistency score (only 1 dimension for this task) - - category: The category to be evaluated. - """ - n_data = len(data) - eval_scores = [{} for _ in range(n_data)] - - # Calculate average sentence-level scores for factual consistency - src_list, output_list = [], [] - n_sents = [] # the number of sentences in the claim - for i in range(n_data): - source = data[i]["source"] - system_outputs = sent_tokenize(data[i]["system_output"]) - n_sents.append(len(system_outputs)) - for j in range(len(system_outputs)): - src_list.append(source) - output_list.append(system_outputs[j]) - input_list = add_question(dimension=self.dim, output=output_list, src=src_list, task=self.task) - sent_score = self.scorer.score(input_list, self.task, category, self.dim) - - # Get average score for each sample - start_idx = 0 - score = [] - for cur_n_sent in n_sents: - score.append(sum(sent_score[start_idx : start_idx + cur_n_sent]) / cur_n_sent) - start_idx += cur_n_sent - - for i in range(n_data): - eval_scores[i][self.dim] = score[i] - - return eval_scores - - -def get_evaluator(task, model_name_or_path="", max_length=1024, device="cuda:0", cache_dir=None): - assert task in ["summarization", "dialogue", "data2text", "fact"] - if task == "summarization": - return SumEvaluator( - model_name_or_path=model_name_or_path, max_length=max_length, device=device, cache_dir=cache_dir - ) - elif task == "dialogue": - return DialogEvaluator( - model_name_or_path=model_name_or_path, max_length=max_length, device=device, cache_dir=cache_dir - ) - elif task == "data2text": - return D2tEvaluator( - model_name_or_path=model_name_or_path, max_length=max_length, device=device, cache_dir=cache_dir - ) - elif task == "fact": - return FactEvaluator( - model_name_or_path=model_name_or_path, max_length=max_length, device=device, cache_dir=cache_dir - ) - else: - raise NotImplementedError( - "Other tasks are not implemented, \ - please customize specific tasks here." 
- ) diff --git a/applications/Chat/evaluate/unieval/scorer.py b/applications/Chat/evaluate/unieval/scorer.py deleted file mode 100644 index 45706b833205..000000000000 --- a/applications/Chat/evaluate/unieval/scorer.py +++ /dev/null @@ -1,96 +0,0 @@ -# MIT License - -# Copyright (c) 2022 Ming Zhong - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -import torch -import torch.nn as nn -from tqdm import tqdm -from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer - - -class UniEvaluator: - def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None): - """Set up model""" - self.device = device - self.max_length = max_length - - self.config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir) - self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir) - self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, config=self.config, cache_dir=cache_dir) - - self.model.eval() - self.model.to(device) - - self.softmax = nn.Softmax(dim=1) - - self.pos_id = self.tokenizer("Yes")["input_ids"][0] - self.neg_id = self.tokenizer("No")["input_ids"][0] - - def score(self, inputs, task, category, dim, batch_size=8): - """ - Get scores for the given samples. - final_score = postive_score / (postive_score + negative_score) - """ - - # The implementation of "forward" in T5 still requires decoder_input_ids. - # Therefore, we construct a random one-word target sequence. - # The content of the target has no effect on the final scores. 
- tgts = ["No" for _ in range(len(inputs))] - - pos_score_list, neg_score_list = [], [] - for i in tqdm(range(0, len(inputs), batch_size), desc=f"{category}-({dim}-{task}): "): - src_list = inputs[i : i + batch_size] - tgt_list = tgts[i : i + batch_size] - try: - with torch.no_grad(): - encoded_src = self.tokenizer( - src_list, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt" - ) - encoded_tgt = self.tokenizer( - tgt_list, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt" - ) - - src_tokens = encoded_src["input_ids"].to(self.device) - src_mask = encoded_src["attention_mask"].to(self.device) - - tgt_tokens = encoded_tgt["input_ids"].to(self.device)[:, 0].unsqueeze(-1) - - output = self.model(input_ids=src_tokens, attention_mask=src_mask, labels=tgt_tokens) - logits = output.logits.view(-1, self.model.config.vocab_size) - - pos_score = self.softmax(logits)[:, self.pos_id] # Yes - neg_score = self.softmax(logits)[:, self.neg_id] # No - - cur_pos_score = [x.item() for x in pos_score] - cur_neg_score = [x.item() for x in neg_score] - pos_score_list += cur_pos_score - neg_score_list += cur_neg_score - - except RuntimeError: - print(f"source: {src_list}") - print(f"target: {tgt_list}") - exit(0) - - score_list = [] - for i in range(len(pos_score_list)): - score_list.append(pos_score_list[i] / (pos_score_list[i] + neg_score_list[i])) - - return score_list diff --git a/applications/Chat/evaluate/unieval/utils.py b/applications/Chat/evaluate/unieval/utils.py deleted file mode 100644 index 46b0f2907a30..000000000000 --- a/applications/Chat/evaluate/unieval/utils.py +++ /dev/null @@ -1,285 +0,0 @@ -# MIT License - -# Copyright (c) 2022 Ming Zhong - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -import os -from typing import Dict - -import matplotlib.pyplot as plt -import pandas as pd -import seaborn as sns -import tqdm - - -def add_question(dimension, output, src=None, ref=None, context=None, task=None): - """ - Add questions to generate input in Bool-QA format for UniEval. - - dimension: specific dimension to be evaluated - src: source input for different NLG tasks. For example, source document for summarization - and dialogue history for dialogue response generation. - output: output text generated by the models - ref: human-annotated groundtruth - context: the context needed to evaluate several specific dimension. For example, - additional factual information when evaluating engagingness and groundedness in dialogues. 
- """ - - input_with_question = [] - for i in range(len(output)): - # For summarization - if task == "summarization": - if dimension == "fluency": - cur_input = "question: Is this a fluent paragraph? paragraph: " + output[i] - elif dimension == "coherence": - cur_input = ( - "question: Is this a coherent summary to the document? summary: " - + output[i] - + " document: " - + src[i] - ) - elif dimension == "consistency": - cur_input = ( - "question: Is this claim consistent with the document? claim: " - + output[i] - + " document: " - + src[i] - ) - elif dimension == "relevance": - cur_input = ( - "question: Is this summary relevant to the reference? summary: " - + output[i] - + " reference: " - + ref[i] - ) - else: - raise NotImplementedError( - "The input format for this dimension is still undefined. Please customize it first." - ) - # For dialogues - elif task == "dialogue": - if dimension == "naturalness": - cur_input = "question: Is this a natural response in the dialogue? response: " + output[i] - elif dimension == "coherence": - cur_input = ( - "question: Is this a coherent response given the dialogue history? response: " - + output[i] - + " dialogue history: " - + src[i] - ) - elif dimension == "engagingness": - cur_input = ( - "question: Is this an engaging and informative response according to the dialogue history and fact? response: " - + output[i] - + " dialogue history: " - + src[i] - + " fact: " - + context[i] - ) - elif dimension == "groundedness": - cur_input = ( - "question: Is this response consistent with knowledge in the fact? response: " - + output[i] - + " fact: " - + context[i] - ) - elif dimension == "understandability": - cur_input = "question: Is this an understandable response in the dialogue? response: " + output[i] - else: - raise NotImplementedError( - "The input format for this dimension is still undefined. Please customize it first." - ) - # For data-to-text - elif task == "data2text": - if dimension == "naturalness": - cur_input = "question: Is this a fluent utterance? utterance: " + output[i] - elif dimension == "informativeness": - cur_input = ( - "question: Is this sentence informative according to the reference? sentence: " - + output[i] - + " reference: " - + ref[i] - ) - else: - raise NotImplementedError( - "The input format for this dimension is still undefined. Please customize it first." - ) - # For factual consistency detection - elif task == "fact": - if dimension == "consistency": - cur_input = ( - "question: Is this claim consistent with the document? claim: " - + output[i] - + " document: " - + src[i] - ) - else: - raise NotImplementedError("No other dimensions for the factual consistency detection task.") - # For new customized tasks - else: - raise NotImplementedError("Other tasks are not implemented, please customize specific tasks here.") - input_with_question.append(cur_input) - return input_with_question - - -def convert_data_to_unieval_format(output_list, src_list=None, ref_list=None): - """ - Convert the data into the unieval's format. - - output_list: a list of model output - - src_list: source input for different NLG tasks. 
For example, source document for summarization - and dialogue history for dialogue response generation - ref_list: human-annotated groundtruth - """ - json_data = [] - for i in range(len(output_list)): - cur = {} - cur["system_output"] = output_list[i] - if src_list is not None: - cur["source"] = src_list[i] - if ref_list is not None: - cur["reference"] = ref_list[i] - cur["context"] = "" - json_data.append(cur) - return json_data - - -def calculate_average_score(scores): - """ - Calculate average scores for different metrics - - scores: a list of scores for different metrics for each answer - - """ - metrics = {metric: 0 for metric in scores[0]} - - for score in scores: - for metric in score: - metrics[metric] += score[metric] - - for metric in metrics: - metrics[metric] /= len(scores) - - return metrics - - -def save_unieval_results(model_name: str, unieval_metric_stats: Dict[str, Dict], save_path: str) -> None: - """ - Save UniEval evaluation results of different categories for one model. - - """ - - if not os.path.exists(save_path): - os.makedirs(save_path) - - unieval_metric_stats_per_category = {} - for task, category_stat in unieval_metric_stats.items(): - for category, metric_stat in category_stat.items(): - if unieval_metric_stats_per_category.get(category, None) is None: - unieval_metric_stats_per_category[category] = {} - for metric, score in metric_stat.items(): - unieval_metric_stats_per_category[category][f"{metric}-{task}"] = score - - automatic_df = pd.DataFrame(unieval_metric_stats_per_category) - automatic_df.to_csv(os.path.join(save_path, f"{model_name}_results.csv"), index=True) - - -def read_unieval_results(results_path: str, file_name: str) -> Dict[str, Dict]: - """ - Read a csv file and return a dictionary which stores scores per metric. - - """ - - results = pd.read_csv(os.path.join(results_path, file_name), index_col=0) - - results_dict = {metric: {} for metric in list(results.index)} - for i, metric in enumerate(results_dict.keys()): - for j, category in enumerate(list(results.columns)): - if pd.isnull(results.iloc[i][j]): - continue - results_dict[metric][category] = results.iloc[i][j] - - return results_dict - - -def analyze_unieval_results(results_path: str, save_path: str) -> None: - """ - Analyze and visualize all csv files in the given folder. - - """ - - if not os.path.exists(results_path): - raise Exception(f'The given directory "{results_path}" doesn\'t exist! 
No results found!') - - all_statistics = {} - - for file_name in os.listdir(results_path): - if file_name.endswith("_results.csv"): - model_name = file_name.split("_results.csv")[0] - all_statistics[model_name] = read_unieval_results(results_path, file_name) - - if len(list(all_statistics.keys())) == 0: - raise Exception(f'There are no csv files in the given directory "{results_path}"!') - - frame_all = {"model": [], "category": [], "metric": [], "score": []} - frame_per_metric = {} - for model_name, model_statistics in all_statistics.items(): - for metric, metric_statistics in model_statistics.items(): - if frame_per_metric.get(metric) is None: - frame_per_metric[metric] = {"model": [], "category": [], "score": []} - - for category, category_score in metric_statistics.items(): - frame_all["model"].append(model_name) - frame_all["category"].append(category) - frame_all["metric"].append(metric) - frame_all["score"].append(category_score) - - frame_per_metric[metric]["model"].append(model_name) - frame_per_metric[metric]["category"].append(category) - frame_per_metric[metric]["score"].append(category_score) - - if not os.path.exists(save_path): - os.makedirs(save_path) - - frame_all = pd.DataFrame(frame_all) - frame_all.to_csv(os.path.join(save_path, "unieval_statistics.csv")) - - for metric in tqdm.tqdm( - frame_per_metric.keys(), - desc=f"UniEval metrics: ", - total=len(frame_per_metric.keys()), - ): - data = pd.DataFrame(frame_per_metric[metric]) - - sns.set() - fig = plt.figure(figsize=(16, 10)) - - fig = sns.barplot(x="category", y="score", hue="model", data=data, dodge=True) - fig.set_title( - f"Comparison between Different Models for Metric {metric.split('-')[0].title()} in Task {metric.split('-')[1].title()}" - ) - plt.xlabel("Evaluation Category") - plt.ylabel("Score") - - figure = fig.get_figure() - figure.savefig(os.path.join(save_path, f"{metric}.png"), dpi=400) - - plt.close() diff --git a/applications/Chat/evaluate/utils.py b/applications/Chat/evaluate/utils.py deleted file mode 100644 index 10df455b69d7..000000000000 --- a/applications/Chat/evaluate/utils.py +++ /dev/null @@ -1,206 +0,0 @@ -import io -import json -import os -import string -from typing import Dict - -import matplotlib.pyplot as plt -import pandas as pd -import seaborn as sns -import tqdm -from zhon import hanzi - - -def _make_w_io_base(f, mode: str): - if not isinstance(f, io.IOBase): - f_dirname = os.path.dirname(f) - if f_dirname != "": - os.makedirs(f_dirname, exist_ok=True) - f = open(f, mode=mode) - return f - - -def _make_r_io_base(f, mode: str): - if not isinstance(f, io.IOBase): - f = open(f, mode=mode) - return f - - -def jdump(obj, f, mode="w", indent=4, default=str): - """Dump a str or dictionary to a file in json format. - Args: - obj: An object to be written. - f: A string path to the location on disk. - mode: Mode for opening the file. - indent: Indent for storing json dictionaries. - default: A function to handle non-serializable entries; defaults to `str`. 
- """ - f = _make_w_io_base(f, mode) - if isinstance(obj, (dict, list)): - json.dump(obj, f, indent=indent, default=default, ensure_ascii=False) - elif isinstance(obj, str): - f.write(obj) - else: - raise ValueError(f"Unexpected type: {type(obj)}") - f.close() - - -def jload(f, mode="r"): - """Load a .json file into a dictionary.""" - f = _make_r_io_base(f, mode) - jdict = json.load(f) - f.close() - return jdict - - -def get_json_list(file_path): - with open(file_path, "r") as f: - json_list = [] - for line in f: - json_list.append(json.loads(line)) - return json_list - - -def get_data_per_category(data, categories): - data_per_category = {category: [] for category in categories} - for item in data: - category = item["category"] - if category in categories: - data_per_category[category].append(item) - - return data_per_category - - -def remove_punctuations(text: str) -> str: - """ - Remove punctuations in the given text. - It is used in evaluation of automatic metrics. - - """ - - punctuation = string.punctuation + hanzi.punctuation - punctuation = set([char for char in punctuation]) - punctuation.difference_update(set("!@#$%&()<>?|,.\"'")) - - out = [] - for char in text: - if char in punctuation: - continue - else: - out.append(char) - - return "".join(out) - - -def remove_redundant_space(text: str) -> str: - """ - Remove redundant spaces in the given text. - It is used in evaluation of automatic metrics. - - """ - - return " ".join(text.split()) - - -def preprocessing_text(text: str) -> str: - """ - Preprocess the given text. - It is used in evaluation of automatic metrics. - - """ - - return remove_redundant_space(remove_punctuations(text.lower())) - - -def save_automatic_results(model_name: str, automatic_metric_stats: Dict[str, Dict], save_path: str) -> None: - """ - Save automatic evaluation results of different categories for one model. - - """ - - if not os.path.exists(save_path): - os.makedirs(save_path) - - automatic_df = pd.DataFrame(automatic_metric_stats) - automatic_df.to_csv(os.path.join(save_path, f"{model_name}_results.csv"), index=True) - - -def read_automatic_results(results_path: str, file_name: str) -> Dict[str, Dict]: - """ - Read a csv file and return a dictionary which stores scores per metric. - - """ - - results = pd.read_csv(os.path.join(results_path, file_name), index_col=0) - - results_dict = {metric: {} for metric in list(results.index)} - for i, metric in enumerate(results_dict.keys()): - for j, category in enumerate(list(results.columns)): - if pd.isnull(results.iloc[i][j]): - continue - results_dict[metric][category] = results.iloc[i][j] - - return results_dict - - -def analyze_automatic_results(results_path: str, save_path: str) -> None: - """ - Analyze and visualize all csv files in the given folder. - - """ - - if not os.path.exists(results_path): - raise Exception(f'The given directory "{results_path}" doesn\'t exist! 
No results found!') - - all_statistics = {} - - for file_name in os.listdir(results_path): - if file_name.endswith("_results.csv"): - model_name = file_name.split("_results.csv")[0] - all_statistics[model_name] = read_automatic_results(results_path, file_name) - - if len(list(all_statistics.keys())) == 0: - raise Exception(f'There are no csv files in the given directory "{results_path}"!') - - frame_all = {"model": [], "category": [], "metric": [], "score": []} - frame_per_metric = {} - for model_name, model_statistics in all_statistics.items(): - for metric, metric_statistics in model_statistics.items(): - if frame_per_metric.get(metric) is None: - frame_per_metric[metric] = {"model": [], "category": [], "score": []} - - for category, category_score in metric_statistics.items(): - frame_all["model"].append(model_name) - frame_all["category"].append(category) - frame_all["metric"].append(metric) - frame_all["score"].append(category_score) - - frame_per_metric[metric]["model"].append(model_name) - frame_per_metric[metric]["category"].append(category) - frame_per_metric[metric]["score"].append(category_score) - - if not os.path.exists(save_path): - os.makedirs(save_path) - - frame_all = pd.DataFrame(frame_all) - frame_all.to_csv(os.path.join(save_path, "automatic_evaluation_statistics.csv")) - - for metric in tqdm.tqdm( - frame_per_metric.keys(), - desc=f"automatic metrics: ", - total=len(frame_per_metric.keys()), - ): - data = pd.DataFrame(frame_per_metric[metric]) - - sns.set() - fig = plt.figure(figsize=(16, 10)) - - fig = sns.barplot(x="category", y="score", hue="model", data=data, dodge=True) - fig.set_title(f"Comparison between Different Models for Metric {metric.title()}") - plt.xlabel("Evaluation Category") - plt.ylabel("Score") - - figure = fig.get_figure() - figure.savefig(os.path.join(save_path, f"{metric}.png"), dpi=400) - - plt.close() diff --git a/applications/ColossalEval/README.md b/applications/ColossalEval/README.md new file mode 100644 index 000000000000..06c6962f7978 --- /dev/null +++ b/applications/ColossalEval/README.md @@ -0,0 +1,554 @@ +# ColossalEval + +## Table of Contents + +- [Overview](#overview) +- [Leaderboard](#leaderboard) +- [Install](#install) +- [Evaluation Process](#evaluation-process) + - [Inference](#inference) + - [Dataset Preparation](#dataset-preparation) + - [Configuration](#configuration) + - [How to Use](#how-to-use) + - [Evaluation](#evaluation) + - [Dataset Evaluation](#dataset-evaluation) + - [Configuration](#dataset-evaluation) + - [How to Use](#dataset-evaluation) + - [GPT Evaluation](#gpt-evaluation) + - [Configuration](#gpt-evaluation) + - [How to Use](#gpt-evaluation) +- [More Details](#more-details) + - [Inference Details](#inference-details) + - [Evaluation Details](#evaluation-details) + - [Metrics](#metrics) + - [examples](#examples) + - [Dataset Evaluation Example](#dataset-evaluation-example) + - [GPT Evaluation Example](#gpt-evaluation-example) +- [To Do](#to-do) +- [FAQ](#faq) + - [How to Add a New Metric?](#how-to-add-a-new-metric) + - [How to Add a New Dataset?](#how-to-add-a-new-dataset) + - [How to Add a New Model?](#how-to-add-a-new-model) +- [Citations](#citations) + +## Overview +[ColossalEval](https://github.com/hpcaitech/ColossalAI/tree/main/applications/ColossalEval) is a project which provides a uniform pipeline to help evaluate language models on different public dataset or your own dataset using both classic metrics and the help from GPTs. More details can be found in the following sections. 
+ +## Leaderboard + +We conducted comprehensive evaluation on 4 dataset and compare our Colossal-Llama-2-7b-base model with various models. + +- We use 5-shot for MMLU and calculate scores based on the logits of first predicted token. +- We use 5-shot for CMMLU and calculate scores based on the logits of first predicted token. +- We use 5-shot for AGIEval and only calculate scores for 4-choice questions using a combination metric of exact match and the logits of first predicted token. If any of the exact match or logits of first predicted token is correct, the model will get the score. +- We use 0-shot for GAOKAO-Bench and only calculate scores for 4-choice questions based on the logits of first predicted token. +- The generation config for all dataset is greedy search. +- We also provided CEval scores from its lastest leaderboard or the official repository of the model. + +More details about metrics can be found in [Metrics](#metrics). + +| | Backbone | Tokens Consumed | | MMLU | CMMLU | AGIEval | GAOKAO | CEval | +| :----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :-----: | :----: | :----: | :----------------------------: | +| | - | - | | 5-shot | 5-shot | 5-shot | 0-shot | 5-shot | +| Baichuan-7B | - | 1.2T | | 42.32 (42.30) | 44.53 (44.02) | 38.72 | 36.74 | 42.80 | +| Baichuan-13B-Base | - | 1.4T | | 50.51 (51.60) | 55.73 (55.30) | 47.20 | 51.41 | 53.60 | +| Baichuan2-7B-Base | - | 2.6T | | 46.97 (54.16) | 57.67 (57.07) | 45.76 | 52.60 | 54.00 | +| Baichuan2-13B-Base | - | 2.6T | | 54.84 (59.17) | 62.62 (61.97) | 52.08 | 58.25 | 58.10 | +| ChatGLM-6B | - | 1.0T | | 39.67 (40.63) | 41.17 (-) | 40.10 | 36.53 | 38.90 | +| ChatGLM2-6B | - | 1.4T | | 44.74 (45.46) | 49.40 (-) | 46.36 | 45.49 | 51.70 | +| InternLM-7B | - | - | | 46.70 (51.00) | 52.00 (-) | 44.77 | 61.64 | 52.80 | +| Qwen-7B | - | 2.2T | | 54.29 (56.70) | 56.03 (58.80) | 52.47 | 56.42 | 59.60 | +| | | | | | | | | | +| Llama-2-7B | - | 2.0T | | 44.47 (45.30) | 32.97 (-) | 32.60 | 25.46 | - | +| Linly-AI/Chinese-LLaMA-2-7B-hf | Llama-2-7B | 1.0T | | 37.43 | 29.92 | 32.00 | 27.57 | - | +| wenge-research/yayi-7b-llama2 | Llama-2-7B | - | | 38.56 | 31.52 | 30.99 | 25.95 | - | +| ziqingyang/chinese-llama-2-7b | Llama-2-7B | - | | 33.86 | 34.69 | 34.52 | 25.18 | 34.2 | +| TigerResearch/tigerbot-7b-base | Llama-2-7B | 0.3T | | 43.73 | 42.04 | 37.64 | 30.61 | - | +| LinkSoul/Chinese-Llama-2-7b | Llama-2-7B | - | | 48.41 | 38.31 | 38.45 | 27.72 | - | +| FlagAlpha/Atom-7B | Llama-2-7B | 0.1T | | 49.96 | 41.10 | 39.83 | 33.00 | - | +| IDEA-CCNL/Ziya-LLaMA-13B-v1.1 | Llama-13B | 0.11T | | 50.25 | 40.99 | 40.04 | 30.54 | - | +| | | | | | | | | | +| **Colossal-LLaMA-2-7b-base** | Llama-2-7B | **0.0085T** | | 53.06 | 49.89 | 51.48 | 58.82 | 50.20 | + +> The score in parentheses corresponds to the scores in the official repository of the model. +> +> We use zero-shot for ChatGLM models. +> +> Qwen-7B is now inaccessible in Hugging Face, we are using the latest version of it before it was made inaccessible. Only for dataset MMLU, the prompt would be "xxx Answer:"(remove the space after ":") and we calculate the logits over " A", " B", " C" and " D" for Qwen-7B. Qwen-7B tends to be much more deterministic than other models. For example, the logits over " A" can be `-inf` and softmax would be exact `0`. +> +> For other models and other dataset, we calculate logits over "A", "B", "C" and "D". 
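+
+For readers who want to see what "calculate scores based on the logits of first predicted token" means in practice, the following is a minimal sketch of the idea. It is only an illustration, not the code used in `colossal_eval`; the model path and the prompt are placeholders.
+
+```python
+# Sketch: score a single-choice question by the logits of the first predicted token.
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("path/to/model")            # placeholder path
+model = AutoModelForCausalLM.from_pretrained("path/to/model").eval()
+
+prompt = "Question: ...\nA. ...\nB. ...\nC. ...\nD. ...\nAnswer: "    # placeholder prompt
+inputs = tokenizer(prompt, return_tensors="pt")
+with torch.no_grad():
+    next_token_logits = model(**inputs).logits[0, -1]
+
+# Softmax over the option tokens only ("A"/"B"/"C"/"D"; for Qwen-7B the
+# leading-space variants " A" ... " D" would be used instead).
+options = ["A", "B", "C", "D"]
+option_ids = [tokenizer(o, add_special_tokens=False)["input_ids"][0] for o in options]
+probs = torch.softmax(next_token_logits[option_ids], dim=-1)
+prediction = options[int(probs.argmax())]
+```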
+ +Our model achieves a much better score over all other Llama-1 or Llama-2 based models and also stands out among popular open source LLMs. + +## Install +You should install `ColossalEval` in order to use it and `colossal_eval` is the package installed. +```bash +git clone https://github.com/hpcaitech/ColossalAI.git +cd ColossalAI/applications/ColossalEval +pip install . +``` +If you want to add customized dataset or models, use `pip install -e .` in stead to ensure that any changes you make to the source code will immediately affect the package you install. + +## Evaluation Process +The evaluation process involves 2 steps which are `inference` and `evaluation`. You need to set the config for each step. + +### Inference + +The inference process consists of two parts. +1. Preprocess and convert the original dataset. +2. Config your tokenizer and model arguments to perform zero-shot or few-shot prompting. + +#### Dataset Preparation + +In this step, the original dataset(either in `csv` or `jsonl` format) will be loaded and converted into a `dict`. In the conversion process, we carefully parse each subcategory and assign specific inference arguments for this subcategory. + +Inference arguments are stored in a `dict`. The following is an example. + +```python +inference_kwargs = { + "calculate_loss": True, + "all_classes": ["A", "B", "C", "D"], + "language": "Chinese", + "pretrain": False, + "max_new_tokens": 32 +} +``` +The `inference_kwargs` currently contains 5 fields: + +- `calculate_loss` (bool, compulsory): Whether the loss on target tokens will be calculated +- `all_classes` (Optional[list], compulsory): Whether the subcategory is a single-choice question. Specify all available options in a list or otherwise None. +- `language` (str, compulsory): The language for the subcategory. +- `pretrain` (bool, compulsory): Whether the dataset is a pretrain dataset or not. It is usually used for calculate perplexity when you want to evaluate a model with extended context length. +- `max_new_tokens` (int, compulsory): The number of new tokens to generate during inference. + +For example, for dataset MMLU, each subcategory consists of single-choice questions with options A, B, C and D by default and we can assign value `["A", "B", "C", "D"]` to key`all_classes`. For dataset C-Eval, target answers aren't provided in the test split so `calculate_loss` should be set as False. However, other dataset such as GAOKAO-bench contains different formats of questions and lacks some keys or metadata which can reveal what type (single-choice or multi-choice) of questions it is. Before assigning inference arguments, we first parse the dataset to decide which type of questions the subcategory belongs to and set the inference arguments accordingly. + +Other than `inference_kwargs`, `data` is a list containing questions of a same subcategory. The following is a converted dataset. + +```json +{ + "dev": { + "category 1": {"data": [], "inference_kwargs": {}}, + "category 2": {"data": [], "inference_kwargs": {}} + }, + "test": { + "category 1": {"data": [], "inference_kwargs": {}}, + "category 2": {"data": [], "inference_kwargs": {}} + } +} +``` + +A data sample basically follow the format of Alpaca. It should contain the following keys: + +* `dataset` (str, compulsory): The name of the dataset. +* `split` (str, compulsory): The split of the instruction. +* `catrgory` (str, compulsory): The category of the instruction. +* `instruction` (str, compulsory): The instruction for the LLM. 
+* `input` (str, optional): The additional context of the instruction. +* `output` (str, optional): The model output of the instruction. +* `target` (str, optional): The target answer for the instruction. + +Example: + +```json +{ + "dev": { + "Abstract Algebra": [ + { + "dataset": "mmlu", + "split": "dev", + "category": "Abstract Algebra", + "instruction": "The following is a single-choice question on Abstract Algebra. Answer the question by replying A, B, C or D.", + "input": "Question: Find all c in Z_3 such that Z_3[x]/(x^2 + c) is a field.\nA. 0\nB. 1\nC. 2\nD. 3\nAnswer: ", + "output": "", + "target": "B" + }, + ] + }, + "test": { + "Abstract Algebra": [ + { + "dataset": "mmlu", + "split": "test", + "category": "Abstract Algebra", + "instruction": "The following is a single-choice question on Abstract Algebra. Answer the question by replying A, B, C or D.", + "input": "Question: Find the degree for the given field extension Q(sqrt(2), sqrt(3), sqrt(18)) over Q.\nA. 0\nB. 4\nC. 2\nD. 6\nAnswer: ", + "output": "", + "target": "B" + }, + ] + } +} +``` + +#### Configuration +In this step, you will configure your tokenizer and model arguments to infer on the given datasets. + +A config file consists of two parts. +1. Model config. In model config, you need to specify model name, model path, model class, tokenizer arguments and model arguments. +2. Dataset config. In dataset config, you need to specify dataset name, path and dataset class. + +Once you have all config ready, the program will run inference on all the given datasets on all the given models. + +An example config using model class `HuggingFaceCausalLM` and dataset class `CMMLUDataset` can be: +```json +{ + "model": [ + { + "name": "model name", + "model_class": "HuggingFaceCausalLM", + "parameters": { + "path": "path to model", + "model_max_length": 2048, + "tokenizer_path": "path to tokenizer", + "tokenizer_kwargs": { + "use_fast": false, + "trust_remote_code": true + }, + "peft_path": null, + "model_kwargs": { + "trust_remote_code": true + }, + "prompt_template": "plain", + "batch_size": 4 + } + } + ], + "dataset": [ + { + "name": "dataset name", + "dataset_class": "CMMLUDataset", + "debug": false, + "few_shot": true, + "path": "path to original dataset", + "save_path": "path to save converted dataset" + } + ] +} +``` + +Currently, we support Hugging Face models. The `tokenizer_kwargs` is the arguments used in `AutoTokenizer.from_pretrained()`. The `model_kwargs` is the arguments used in `AutoModel.from_pretrained` or `AutoModelForCausalLM.from_pretrained()`. `few_shot` will be set true if you want to enable few-shot prompting for the dataset. `debug` will be set true if you want to verify whether your prompt is right or wrong. + +#### How to Use +An example script can be the following. The `configs/dataset_evaluation/inference.py` is the same in all examples provided. + +```shell +torchrun --nproc_per_node=1 inference.py \ + --config "path to config file" \ + --load_dataset \ + --inference_save_path "path to save inference results" +``` + +You should specify the path to config file in `config`. You can run the script without specifying `load_dataset` if you already save the converted dataset or otherwise set it to first load the original dataset and save the converted dataset. You should specify the path to save inference results in `inference_save_path`. + +### Evaluation + +In the evaluation process, you only need to configure your evaluation parameters. You can use either public dataset or help from GPTs to do evaluation. 
We will introduce configuration for dataset evaluation and GPT evaluation. + +#### Dataset Evaluation + +In dataset evaluation, we calculate different metrics on the given inference results and public dataset. + +##### Configuration + +A config file for dataset evaluation consists of two parts. +1. Model config. In model config, you need to specify model name. If you want to evaluate perplexity over a pretrain dataset and calculate per-byte-perplexity, you have to add your tokenizer config and model max length. +2. Dataset config. In dataset config, you need to specify the evaluation arguments for the dataset. + +Once you have all config ready, the program will run evaluation on inference results for all given models and dataset. + +An example config can be: +```json +{ + "model": [ + { + "name": "model name" + } + ], + "dataset": [ + { + "name": "dataset name", + "metrics": ["first_token_accuracy"] + } + ] +} +``` + +The above config specifies that the program will evaluate the inference results using `first_token_accuracy` metric. + +##### How to Use + +An example script can be the following. + +```shell +python eval_dataset.py \ + --config "path to config file" \ + --inference_results_path "path to inference results" \ + --evaluation_results_save_path "path to save evaluation results" +``` + +You should specify the path to config file in `config`, the path to inference results in `inference_results_path` and the path to save evaluation results in `evaluation_save_path`. + +#### GPT Evaluation + +In GPT evaluation, we provide a prompt template which can fit in different pre-defined metrics with Chain-of-Thoughts. In the following sections, we will only introduce how you can evaluate model answers using GPTs. More details can be found in `colossal_eval/evaluate/GPT Evaluation.md`. + +##### Configuration + +The following is an example of a English config file. The configuration file can control how the pipeline evaluates the model. You need to specify GPT evaluation metrics. You can find an example English config file in `configs/gpt_evaluation`. + +```json +{ + "language": "en", + "category": { + "brainstorming": { + "GPT": [ + "language organization", + "relevance", + "creativity", + "practicality", + "reasonableness" + ] + }, + } +} +``` + +##### How to Use +After setting the config file, you can evaluate the model using `examples/gpt_evaluation/eval.py`. If you want to make comparisons between answers of two different models, you should specify two answer files in the argument `answer_file_list` and two model names in the argument `model_name_list`(details can be found in `colossal_eval/evaluate/GPT Evaluation.md`). If you want to evaluate one answer file, the length of both `answer_file_list` and `model_name_list` should be 1 and the program will perform evaluation using GPTs. 
+
+An example script is provided as follows:
+
+```shell
+python eval.py \
+    --config_file "path to the config file" \
+    --battle_prompt_file "path to the prompt file for battle" \
+    --gpt_evaluation_prompt_file "path to the prompt file for gpt evaluation" \
+    --target_file "path to the target answer file" \
+    --answer_file_list "path to the answer file" \
+    --model_name_list "the names of the model" \
+    --gpt_model "which GPT model to use for evaluation" \
+    --save_path "path to save results" \
+    --openai_key "your openai key" \
+```
+
+## More Details
+
+### Inference Details
+
+In the inference process, we perform generation, calculate the loss over target tokens, count the number of target tokens, and compute softmax over the given options (for example, "A", "B", "C", and "D"), according to the inference arguments.
+
+For tokenization, we adopt the tokenization strategy from [LongBench](https://github.com/THUDM/LongBench/blob/main/pred.py#L55) to preserve crucial instructions on the left and right side and keep all target tokens.
+
+For labeling target tokens, we adopt the method from [FastChat](https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L137), but it doesn't always hold true because different tokenizers behave differently. We plan to insert special tokens to correctly label the target tokens.
+
+For calculating loss, we return the per-sample loss instead of the per-batch loss that would be obtained by directly using `model(batch).loss` provided in Hugging Face.
+
+### Evaluation Details
+
+To make it easier to set the config, you only need to specify all the metrics you want to use in the key `metrics`. However, the program will only use a subset of the metrics you give for different subcategories, since applying all metrics to all subcategories is obviously unsuitable. The suggested metrics for specific categories should be defined in `colossal_eval/evaluate/dataset_evaluator/metrics.py`.
+
+#### Metrics
+
+- `combined_single_choice_accuracy`: A combination of `first_token_logit` and `single_choice_accuracy`. If either of them is correct, the model will get the score. It can be used in all datasets that contain single-choice questions.
+- `first_token_logit`: Calculate the score based on the softmax score over the given choices. If the argmax of the softmax is equal to the reference, the model will get the score. If there is `NaN` in the softmax score, the score will be calculated using exact match instead. It can be used in all datasets that contain single-choice questions.
+- `single_choice_accuracy`: Calculate the score using exact match. It only extracts the first uppercase letter such as A, B, C or D that is not surrounded by lowercase letters. If that uppercase letter is equal to the reference, the model will get the score. It can be used in all datasets that contain single-choice questions.
+- `multi_choice_accuracy`: Calculate the score on multi-choice questions. It extracts the set of all uppercase letters such as A, B, C or D that are not surrounded by lowercase letters. If the prediction contains uppercase letters that are not in the reference, the model will get 0 score. If the prediction contains an uppercase letter that is in the reference, the model will get a score of `1/len(reference)`. It is used in AGIEval and GAOKAO-Bench.
+- `math_equivalence`: Code from [hendrycks](https://github.com/hendrycks/math/blob/main/modeling/math_equivalence.py). Compute the score over the predicted math formula and the reference math formula. It is used in AGIEval and GAOKAO-Bench.
+- `f1_score`: Calculate the English f1 score between prediction and reference. It is used in LongBench.
+- `f1_zh_score`: Calculate the Chinese f1 score between prediction and reference. It is used in LongBench.
+- `rouge_score`: Calculate the English rouge score between prediction and reference. It is used in GAOKAO-Bench and LongBench.
+- `rouge_zh_score`: Calculate the Chinese rouge score between prediction and reference. It is used in GAOKAO-Bench and LongBench.
+- `retrieval_score`: Calculate the English retrieval score between prediction and reference. It determines whether the output (which paragraph) corresponds to the given abstract. It is used in LongBench.
+- `retrieval_zh_score`: Calculate the Chinese retrieval score between prediction and reference. It determines whether the output (which paragraph) corresponds to the given abstract. It is used in LongBench.
+- `classification_score`: Calculate the classification score between prediction and reference. It determines whether the output (a class) is equal to the reference. It is used in LongBench.
+- `code_sim_score`: Calculate the similarity score between prediction and reference. It is used in LongBench.
+- `count_score`: Calculate the count score between prediction and reference. It determines whether the output (the number of given passages) is equal to the reference. It is used in LongBench.
+- `perplexity`: Calculate perplexity. The formula is $ perplexity = \frac{1}{n} \sum_i e^{loss_i} $ where $n$ is the number of samples and $ loss_i $ is the average loss for sample $ i $. It can be used in all datasets.
+- `ppl_score`: Calculate the perplexity score. The formula is $ ppl\_score = \frac{1}{n} \sum_i e^{-loss_i} $ where $n$ is the number of samples and $ loss_i $ is the average loss for sample $ i $. It can be used in all datasets.
+- `ppl_score_over_choices`: Calculate the perplexity score over choices. The formula is $ ppl\_score\_over\_choices = \frac{1}{n} \sum_i e^{-loss\_over\_choices_i} $ where $n$ is the number of samples and $ loss\_over\_choices_i $ is the loss on the first predicted token for sample $ i $. It can be used in all datasets that contain single-choice questions.
+- `per_byte_perplexity`: Calculate per-byte perplexity. The formula is $ \frac{1}{n} \sum_i e^{\frac{loss_i}{byte_i}} $ where $n$ is the number of samples, $ loss_i $ is the total loss for sample $ i $ and $ byte_i $ is the number of bytes sample $ i $ occupies. It can be used in all datasets.
+- `per_byte_ppl_score`: Calculate the per-byte perplexity score. The formula is $ \frac{1}{n} \sum_i e^{-\frac{loss_i}{byte_i}} $ where $n$ is the number of samples, $ loss_i $ is the total loss for sample $ i $ and $ byte_i $ is the number of bytes sample $ i $ occupies. It can be used in all datasets.
+
+We use `combined_single_choice_accuracy` and `first_token_logit` in the leaderboard.
+
+### Examples
+
+We provide 2 examples for you to explore our `colossal_eval` package.
+
+#### Dataset Evaluation Example
+
+This example is in folder `examples/dataset_evaluation`.
+
+1. `cd examples/dataset_evaluation`
+2. Fill in your inference config file in `config/inference/config.json`. Set the model and dataset parameters.
+3. Run `inference.sh` to get inference results.
+4. Fill in your evaluation config file in `config/evaluation/config.json`. Set the model and dataset parameters.
+5. Run `eval_dataset.sh` to get evaluation results.
+
+#### GPT Evaluation Example
+
+This example is in folder `examples/gpt_evaluation`.
+
+1. `cd examples/gpt_evaluation`
+2. Fill in your inference config file in `config/inference/config.json`. Set the model and dataset parameters.
If you want to use the example dataset we provide, the dataset is `ColossalDataset`. +3. Run `inference.sh` to get inference results. +4. Fill in your evaluation config file in `config/evaluation/config.json`. +5. Run `eval.sh` to get evaluation results. + +## FAQ + +### How to Add a New Metric? + +If you want to add a customized metric, we recommend using `pip install -e .` to ensure that any changes you make to the source code will immediately affect the package you install. + +To add a new metric, you can follow the example of multi_choice_accuracy in line 339 in `colossal_eval/evaluate/dataset_evaluator/metric.py`. The method take one data sample's prediction and reference as input and return a score ranging from 0 to 1. + +A skeleton of code is the following. + +```python + +def CustomizedMetric(prediction: str, reference: str): + score = xxx + return score +``` + +Once you have successfully added your own metric, you should specify your metric both in `colossal_eval/evaluate/dataset_evaluator/metric.py` (suggest which subcategories shoule the metric be applied to) and your evaluation config. + +### How to Add a New Dataset? + +If you want to add customized dataset, we recommend using `pip install -e .` to ensure that any changes you make to the source code will immediately affect the package you install. + +To add a new dataset, you can follow the example of `colossal_eval/dataset/mmlu.py`. You need to make sure that the format of questions in one subcategory should be the same. For example, all questions should have target answers or all questions should be single-choice questions. + +A skeleton of code is the following. + +```python + +class CustomizedDataset(BaseDataset): + @staticmethod + def load(): + # 1. Load and convert the original dataset format. + # 2. Assign inference arguments for each subcategory. + # 3. Return the converted dataset. + pass +``` + +Once you have successfully added your own dataset, you can specify your dataset class in your inference config. + +### How to Add a New Model? + +If you want to add customized models, we recommend using `pip install -e .` to ensure that any changes you make to the source code will immediately affect the package you install. + +To add a new model, you can follow the example of `colossal_eval/models/huggingface.py`. You need to provide a way to load the model and tokenizer, calculate loss and generate. + +A skeleton of code is the following. + +```python + +class CustomizedModel(BaseModel): + def __init__(self): + super().__init__() + self._load_tokenizer() + self._load_model() + + def _load_tokenizer(): + pass + + def _load_model(): + pass + + def _calculate_loss(): + pass + + def get_loss(): + self._calculate_loss() + + def inference(samples): + # 1. Load samples from the same subcategory. + # 2. Infer in a batch way according to inference arguments. + # 3. Return results. + batch_samples = xxx + self.get_loss(batch_samples) + self.generate(batch_samples) + + return inference_results + + def generate(): + pass +``` + +Once you have successfully added your own model, you can specify your model class in your inference config. 
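+
+To make the FAQ skeletons above concrete, here is a minimal, hypothetical metric written in the style described in [How to Add a New Metric?](#how-to-add-a-new-metric). The function name and the normalization are illustrative only and are not part of the released package; the only contract assumed is the one stated above, namely that a metric takes one prediction and one reference and returns a score between 0 and 1.
+
+```python
+def customized_exact_match(prediction: str, reference: str) -> float:
+    """Hypothetical metric: 1.0 if the normalized prediction equals the reference, else 0.0."""
+    def normalize(text: str) -> str:
+        # Lowercase and collapse whitespace before comparison.
+        return " ".join(text.strip().lower().split())
+
+    return 1.0 if normalize(prediction) == normalize(reference) else 0.0
+```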
+ +## To do + +- [ ] Add visualization code for evaluation results on public dataset +- [ ] Improve the way to label target tokens + +## Citations + +```bibtex +@misc{zhong2023agieval, + title={AGIEval: A Human-Centric Benchmark for Evaluating Foundation Models}, + author={Wanjun Zhong and Ruixiang Cui and Yiduo Guo and Yaobo Liang and Shuai Lu and Yanlin Wang and Amin Saied and Weizhu Chen and Nan Duan}, + year={2023}, + eprint={2304.06364}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} + +@article{huang2023ceval, +title={C-Eval: A Multi-Level Multi-Discipline Chinese Evaluation Suite for Foundation Models}, +author={Huang, Yuzhen and Bai, Yuzhuo and Zhu, Zhihao and Zhang, Junlei and Zhang, Jinghan and Su, Tangjun and Liu, Junteng and Lv, Chuancheng and Zhang, Yikai and Lei, Jiayi and Fu, Yao and Sun, Maosong and He, Junxian}, +journal={arXiv preprint arXiv:2305.08322}, +year={2023} +} + +@misc{li2023cmmlu, + title={CMMLU: Measuring massive multitask language understanding in Chinese}, + author={Haonan Li and Yixuan Zhang and Fajri Koto and Yifei Yang and Hai Zhao and Yeyun Gong and Nan Duan and Timothy Baldwin}, + year={2023}, + eprint={2306.09212}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} + +@inproceedings{Zhang2023EvaluatingTP, + title={Evaluating the Performance of Large Language Models on GAOKAO Benchmark}, + author={Xiaotian Zhang and Chunyang Li and Yi Zong and Zhengyu Ying and Liang He and Xipeng Qiu}, + year={2023} +} + +@misc{bai2023longbench, + title={LongBench: A Bilingual, Multitask Benchmark for Long Context Understanding}, + author={Yushi Bai and Xin Lv and Jiajie Zhang and Hongchang Lyu and Jiankai Tang and Zhidian Huang and Zhengxiao Du and Xiao Liu and Aohan Zeng and Lei Hou and Yuxiao Dong and Jie Tang and Juanzi Li}, + year={2023}, + eprint={2308.14508}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} + +@article{hendryckstest2021, + title={Measuring Massive Multitask Language Understanding}, + author={Dan Hendrycks and Collin Burns and Steven Basart and Andy Zou and Mantas Mazeika and Dawn Song and Jacob Steinhardt}, + journal={Proceedings of the International Conference on Learning Representations (ICLR)}, + year={2021} +} + +@article{hendrycks2021ethics, + title={Aligning AI With Shared Human Values}, + author={Dan Hendrycks and Collin Burns and Steven Basart and Andrew Critch and Jerry Li and Dawn Song and Jacob Steinhardt}, + journal={Proceedings of the International Conference on Learning Representations (ICLR)}, + year={2021} +} + +@misc{zheng2023judging, + title={Judging LLM-as-a-judge with MT-Bench and Chatbot Arena}, + author={Lianmin Zheng and Wei-Lin Chiang and Ying Sheng and Siyuan Zhuang and Zhanghao Wu and Yonghao Zhuang and Zi Lin and Zhuohan Li and Dacheng Li and Eric. P Xing and Hao Zhang and Joseph E. 
Gonzalez and Ion Stoica}, + year={2023}, + eprint={2306.05685}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} + +``` diff --git a/applications/ColossalEval/colossal_eval/__init__.py b/applications/ColossalEval/colossal_eval/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/applications/ColossalEval/colossal_eval/dataset/__init__.py b/applications/ColossalEval/colossal_eval/dataset/__init__.py new file mode 100644 index 000000000000..4ea173198f5a --- /dev/null +++ b/applications/ColossalEval/colossal_eval/dataset/__init__.py @@ -0,0 +1,19 @@ +from .agieval import AGIEvalDataset +from .base import BaseDataset +from .ceval import CEvalDataset +from .cmmlu import CMMLUDataset +from .colossalai import ColossalDataset +from .gaokaobench import GaoKaoBenchDataset +from .longbench import LongBenchDataset +from .mmlu import MMLUDataset + +__all__ = [ + "AGIEvalDataset", + "BaseDataset", + "CEvalDataset", + "CMMLUDataset", + "GaoKaoBenchDataset", + "LongBenchDataset", + "MMLUDataset", + "ColossalDataset", +] diff --git a/applications/ColossalEval/colossal_eval/dataset/agieval.py b/applications/ColossalEval/colossal_eval/dataset/agieval.py new file mode 100644 index 000000000000..92ebd65931ed --- /dev/null +++ b/applications/ColossalEval/colossal_eval/dataset/agieval.py @@ -0,0 +1,247 @@ +# Adapted from https://github.com/ruixiangcui/AGIEval/blob/main/src/dataset_loader.py. + +import ast +import glob +import os +from copy import deepcopy +from typing import Dict, List + +import pandas as pd +from colossal_eval.utils import get_json_list + +from colossalai.logging import DistributedLogger + +from .base import BaseDataset + +# define the datasets +english_qa_datasets = [ + "lsat-ar", + "lsat-lr", + "lsat-rc", + "logiqa-en", + "sat-math", + "sat-en", + "aqua-rat", + "sat-en-without-passage", + "gaokao-english", +] +chinese_qa_datasets = [ + "logiqa-zh", + "jec-qa-kd", + "jec-qa-ca", + "gaokao-chinese", + "gaokao-geography", + "gaokao-history", + "gaokao-biology", + "gaokao-chemistry", + "gaokao-physics", + "gaokao-mathqa", +] +english_cloze_datasets = ["math"] +chinese_cloze_datasets = ["gaokao-mathcloze"] + +multi_choice_datasets = ["jec-qa-kd", "jec-qa-ca", "gaokao-physics", "gaokao-mathqa"] +math_output_datasets = {"gaokao-mathcloze", "math"} + +default_inference_kwargs = { + "calculate_loss": True, + "all_classes": None, + "language": "Chinese", + "pretrain": False, + "max_new_tokens": 32, +} + + +def get_prompt(line: Dict, dataset_name: str, logger: DistributedLogger) -> Dict: + """Modified from https://github.com/microsoft/AGIEval/blob/main/src/dataset_loader.py#L190""" + try: + all_classes = None + passage = line["passage"] if line["passage"] is not None else "" + + if dataset_name in english_qa_datasets: + option_string = "ABCDEFG" + count = len(line["options"]) + + input = ( + "Question: " + + line["question"] + + " " + + "Choose from the following options: " + + " ".join(line["options"]) + + "\n" + + "Answer: " + ) + + all_classes = list(option_string[0:count]) + + elif dataset_name in chinese_qa_datasets: + option_string = "ABCDEFG" + count = len(line["options"]) + + input = "问题:" + line["question"] + " " + "从以下选项中选择:" + " ".join(line["options"]) + "\n" + "答案:" + + all_classes = list(option_string[0:count]) + + elif dataset_name in english_cloze_datasets: + input = "Question: " + line["question"] + "\n" + "Answer: " + + elif dataset_name in chinese_cloze_datasets: + input = "问题:" + line["question"] + "\n" + "答案:" + + return { + "instruction": input if not 
passage else passage + "\n\n" + input, + "target": line["label"] if line["label"] else line["answer"], + }, all_classes + + except NameError: + logger.info("Dataset not defined.") + + +# process few-shot raw_prompts +def combine_prompt(prompt_path, dataset_name, load_explanation=True, chat_mode=False): + skip_passage = False + if dataset_name == "sat-en-without-passage": + skip_passage = True + dataset_name = "sat-en" + demostrations = [] + # read the prompts by context and explanation + context_row = [0, 1, 3, 5, 7, 9] + explanation_row = [0, 2, 4, 6, 8, 10] + raw_prompts_context = pd.read_csv( + prompt_path, header=0, skiprows=lambda x: x not in context_row, keep_default_na=False + ) + raw_prompts_explanation = pd.read_csv( + prompt_path, header=0, skiprows=lambda x: x not in explanation_row, keep_default_na=False + ).replace(r"\n\n", "\n", regex=True) + contexts = [] + for line in list(raw_prompts_context[dataset_name]): + if line: + # print(line) + contexts.append(ast.literal_eval(line)) + explanations = [exp for exp in raw_prompts_explanation[dataset_name] if exp] + + for idx, (con, exp) in enumerate(zip(contexts, explanations)): + passage = con["passage"] if con["passage"] is not None and not skip_passage else "" + question = con["question"] + options = con["options"] if con["options"] is not None else "" + label = con["label"] if con["label"] is not None else "" + answer = con["answer"] if "answer" in con and con["answer"] is not None else "" + + if dataset_name in english_qa_datasets: + question_input = ( + "Question: " + + passage + + " " + + question + + "\n" + + "Choose from the following options: " + + " ".join(options) + + "\n" + + "Answer: {}".format(label) + ) + elif dataset_name in chinese_qa_datasets: + question_input = ( + "问题:" + passage + " " + question + "\n" + "从以下选项中选择:" + " ".join(options) + "\n" + "答案:{}".format(label) + ) + elif dataset_name in english_cloze_datasets: + question_input = "Question: ".format(idx + 1) + question + "\n" + "Answer: {}".format(answer) + elif dataset_name in chinese_cloze_datasets: + question_input = "问题:" + question + "\n" + "答案:{}".format(answer) + else: + raise ValueError(f"During loading few-sot examples, found unknown dataset: {dataset_name}") + + if chat_mode: + demostrations.append((question_input,)) + else: + demostrations.append(question_input + "\n") + + return demostrations + + +class AGIEvalDataset(BaseDataset): + """ + Dataset wrapper for AGIEval dataset. + Data source: https://github.com/microsoft/AGIEval + This dataset class will convert the original dataset into the inference dataset. + + A few dirty data needed to be manually corrected in the origin dataset: + Issue link: https://github.com/microsoft/AGIEval/issues/16 + 1. Invalid options in line 190 in gaokao-chemistry.jsonl. + 2. Option D (They may increase in value as those same resources become rare on Earth.) missing in line 17 in sat-en-without-passage.jsonl. + 3. Option D (They may increase in value as those same resources become rare on Earth.) missing in line 17 in sat-en.jsonl. + 4. Option D (No, because the data do not indicate whether the honeybees had been infected with mites.) missing in line 57 in sat-en-without-passage.jsonl. + 5. Option D (No, because the data do not indicate whether the honeybees had been infected with mites.) missing in line 57 in sat-en.jsonl. + 6. Option D (Published theories of scientists who developed earlier models of the Venus flytrap) missing in line 98 in sat-en-without-passage.jsonl. + 7. 
Option D (Published theories of scientists who developed earlier models of the Venus flytrap) missing in line 98 in sat-en.jsonl. + 8. Label is empty in line 212 in jec-qa-kd.jsonl. Content is also dirty. + 9. Actually, gaokao-mathqa.jsonl is also a multi-choice dataset. See line 149 286 287. + """ + + @staticmethod + def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]: + dataset = {"test": {}} + + files = glob.glob(os.path.join(path, "*.jsonl")) + files.sort() + + if few_shot: + prompt_path = os.path.join(path, "few_shot_prompts.csv") + + for file in files: + dataset_name = os.path.basename(file)[0 : -len(".jsonl")] + + few_shot_data = [] + if few_shot: + # process demo once if it is few-shot-CoT + few_shot_data = combine_prompt(prompt_path, dataset_name, load_explanation=False, chat_mode=False) + + dataset["test"][dataset_name] = {"data": []} + + file_dir = os.path.join(path, file) + + loaded_jsonl = get_json_list(file_dir) + + # It's been tested that each data sample in one subcategory have same inference arguments. + _, all_classes = get_prompt(loaded_jsonl[0], dataset_name, logger) + inference_kwargs = deepcopy(default_inference_kwargs) + if all_classes is not None and dataset_name not in multi_choice_datasets: + inference_kwargs["all_classes"] = all_classes + + if dataset_name in english_qa_datasets: + inference_kwargs["language"] = "English" + if dataset_name in chinese_qa_datasets: + inference_kwargs["language"] = "Chinese" + inference_kwargs["few_shot_data"] = few_shot_data + + dataset["test"][dataset_name]["inference_kwargs"] = inference_kwargs + + for line in loaded_jsonl: + info, all_classes = get_prompt(line, dataset_name, logger) + + # Convert multi-choice answers to a single string. + # We will convert it back when evaluating. + # We do this because if target is a list, it should be only used for multiple target answers. + if dataset_name in multi_choice_datasets: + if isinstance(info["target"], str) and len(info["target"]) > 1: + # "gaokao-mathqa" actually contain multi-choice questions. + # This if clause is specially used for it. + info["target"] = "".join(info["target"].split()) + else: + info["target"] = "".join(info["target"]) + + if isinstance(info["target"], list) and len(info["target"]) == 1: + info["target"] = info["target"][0] + + data_sample = { + "dataset": "agieval", + "split": "test", + "category": dataset_name, + "instruction": info["instruction"], + "input": "", + "output": "", + "target": info["target"], + } + + dataset["test"][dataset_name]["data"].append(data_sample) + + return dataset diff --git a/applications/ColossalEval/colossal_eval/dataset/base.py b/applications/ColossalEval/colossal_eval/dataset/base.py new file mode 100644 index 000000000000..45b0151b849f --- /dev/null +++ b/applications/ColossalEval/colossal_eval/dataset/base.py @@ -0,0 +1,24 @@ +from abc import abstractstaticmethod + +from colossal_eval.utils import jdump + + +class BaseDataset: + """ + Base class for dataset wrapper. + + Args: + path: The path to the original dataset. + logger: Logger for the dataset. 
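+        few_shot: Whether to load few-shot examples for building few-shot prompts.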
+ """ + + def __init__(self, path, logger, few_shot): + self.dataset = self.load(path, logger, few_shot) + + def save(self, save_path): + """Save the converted dataset""" + jdump(self.dataset, save_path) + + @abstractstaticmethod + def load(path, logger): + """Load the original dataset and convert it into the inference dataset""" diff --git a/applications/ColossalEval/colossal_eval/dataset/ceval.py b/applications/ColossalEval/colossal_eval/dataset/ceval.py new file mode 100644 index 000000000000..32ec52087bd3 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/dataset/ceval.py @@ -0,0 +1,132 @@ +import copy +import csv +import os +from typing import Dict, List + +from colossalai.logging import DistributedLogger + +from .base import BaseDataset + +ceval_subject_mapping = { + "computer_network": ["Computer Network", "计算机网络", "STEM"], + "operating_system": ["Operating System", "操作系统", "STEM"], + "computer_architecture": ["Computer Architecture", "计算机组成", "STEM"], + "college_programming": ["College Programming", "大学编程", "STEM"], + "college_physics": ["College Physics", "大学物理", "STEM"], + "college_chemistry": ["College Chemistry", "大学化学", "STEM"], + "advanced_mathematics": ["Advanced Mathematics", "高等数学", "STEM"], + "probability_and_statistics": ["Probability and Statistics", "概率统计", "STEM"], + "discrete_mathematics": ["Discrete Mathematics", "离散数学", "STEM"], + "electrical_engineer": ["Electrical Engineer", "注册电气工程师", "STEM"], + "metrology_engineer": ["Metrology Engineer", "注册计量师", "STEM"], + "high_school_mathematics": ["High School Mathematics", "高中数学", "STEM"], + "high_school_physics": ["High School Physics", "高中物理", "STEM"], + "high_school_chemistry": ["High School Chemistry", "高中化学", "STEM"], + "high_school_biology": ["High School Biology", "高中生物", "STEM"], + "middle_school_mathematics": ["Middle School Mathematics", "初中数学", "STEM"], + "middle_school_biology": ["Middle School Biology", "初中生物", "STEM"], + "middle_school_physics": ["Middle School Physics", "初中物理", "STEM"], + "middle_school_chemistry": ["Middle School Chemistry", "初中化学", "STEM"], + "veterinary_medicine": ["Veterinary Medicine", "兽医学", "STEM"], + "college_economics": ["College Economics", "大学经济学", "Social Science"], + "business_administration": ["Business Administration", "工商管理", "Social Science"], + "marxism": ["Marxism", "马克思主义基本原理", "Social Science"], + "mao_zedong_thought": ["Mao Zedong Thought", "毛泽东思想和中国特色社会主义理论体系概论", "Social Science"], + "education_science": ["Education Science", "教育学", "Social Science"], + "teacher_qualification": ["Teacher Qualification", "教师资格", "Social Science"], + "high_school_politics": ["High School Politics", "高中政治", "Social Science"], + "high_school_geography": ["High School Geography", "高中地理", "Social Science"], + "middle_school_politics": ["Middle School Politics", "初中政治", "Social Science"], + "middle_school_geography": ["Middle School Geography", "初中地理", "Social Science"], + "modern_chinese_history": ["Modern Chinese History", "近代史纲要", "Humanities"], + "ideological_and_moral_cultivation": ["Ideological and Moral Cultivation", "思想道德修养与法律基础", "Humanities"], + "logic": ["Logic", "逻辑学", "Humanities"], + "law": ["Law", "法学", "Humanities"], + "chinese_language_and_literature": ["Chinese Language and Literature", "中国语言文学", "Humanities"], + "art_studies": ["Art Studies", "艺术学", "Humanities"], + "professional_tour_guide": ["Professional Tour Guide", "导游资格", "Humanities"], + "legal_professional": ["Legal Professional", "法律职业资格", "Humanities"], + "high_school_chinese": ["High School Chinese", 
"高中语文", "Humanities"], + "high_school_history": ["High School History", "高中历史", "Humanities"], + "middle_school_history": ["Middle School History", "初中历史", "Humanities"], + "civil_servant": ["Civil Servant", "公务员", "Other"], + "sports_science": ["Sports Science", "体育学", "Other"], + "plant_protection": ["Plant Protection", "植物保护", "Other"], + "basic_medicine": ["Basic Medicine", "基础医学", "Other"], + "clinical_medicine": ["Clinical Medicine", "临床医学", "Other"], + "urban_and_rural_planner": ["Urban and Rural Planner", "注册城乡规划师", "Other"], + "accountant": ["Accountant", "注册会计师", "Other"], + "fire_engineer": ["Fire Engineer", "注册消防工程师", "Other"], + "environmental_impact_assessment_engineer": ["Environmental Impact Assessment Engineer", "环境影响评价工程师", "Other"], + "tax_accountant": ["Tax Accountant", "税务师", "Other"], + "physician": ["Physician", "医师资格", "Other"], +} + +default_inference_kwargs = { + "calculate_loss": False, + "all_classes": ["A", "B", "C", "D"], + "language": "Chinese", + "pretrain": False, + "max_new_tokens": 32, +} + + +def get_few_shot_data(data: List[Dict]): + few_shot_data = [] + for i in data: + few_shot_data.append(i["input"] + i["target"]) + return few_shot_data + + +class CEvalDataset(BaseDataset): + """ + Dataset class for CEval dataset. + Data source: https://huggingface.co/datasets/ceval/ceval-exam + This dataset class will convert the original dataset into the inference dataset. + """ + + @staticmethod + def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]: + dataset = {"dev": {}, "test": {}} + for split in ["dev", "test"]: + files = os.listdir(os.path.join(path, split)) + files.sort() + + for file in files: + subject = file[0 : -len(f"_{split}.csv")] + subject = ceval_subject_mapping[subject][1] + + file_dir = os.path.join(path, split, file) + + dataset[split][subject] = {"data": []} + + # It's been tested that each data sample in one subcategory have same inference arguments. + dataset[split][subject]["inference_kwargs"] = copy.deepcopy(default_inference_kwargs) + + if split == "test" and few_shot: + dataset[split][subject]["inference_kwargs"]["few_shot_data"] = get_few_shot_data( + dataset["dev"][subject]["data"] + ) + + with open(file_dir, encoding="utf-8") as f: + reader = csv.reader(f) + _ = next(reader) + for row in reader: + # Dev split have answer and explanation so len(row) is 8 + # But test split doesn't contain answer and explanation, so len(row) is 6 + assert len(row) >= 6 + choices = f"A. {row[2]}\nB. {row[3]}\nC. {row[4]}\nD. 
{row[5]}" + data_sample = { + "dataset": "ceval", + "split": split, + "category": subject, + "instruction": f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。", + "input": f"题目:{row[1]}\n{choices}\n答案:", + "output": "", + "target": row[6] if split == "dev" else "", + "id": int(row[0]), + } + + dataset[split][subject]["data"].append(data_sample) + + return dataset diff --git a/applications/ColossalEval/colossal_eval/dataset/cmmlu.py b/applications/ColossalEval/colossal_eval/dataset/cmmlu.py new file mode 100644 index 000000000000..51f8ca14e0c8 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/dataset/cmmlu.py @@ -0,0 +1,144 @@ +import copy +import csv +import os +from typing import Dict, List + +from colossalai.logging import DistributedLogger + +from .base import BaseDataset + +cmmlu_subject_mapping = { + "agronomy": "农学", + "anatomy": "解剖学", + "ancient_chinese": "古汉语", + "arts": "艺术学", + "astronomy": "天文学", + "business_ethics": "商业伦理", + "chinese_civil_service_exam": "中国公务员考试", + "chinese_driving_rule": "中国驾驶规则", + "chinese_food_culture": "中国饮食文化", + "chinese_foreign_policy": "中国外交政策", + "chinese_history": "中国历史", + "chinese_literature": "中国文学", + "chinese_teacher_qualification": "中国教师资格", + "clinical_knowledge": "临床知识", + "college_actuarial_science": "大学精算学", + "college_education": "大学教育学", + "college_engineering_hydrology": "大学工程水文学", + "college_law": "大学法律", + "college_mathematics": "大学数学", + "college_medical_statistics": "大学医学统计", + "college_medicine": "大学医学", + "computer_science": "计算机科学", + "computer_security": "计算机安全", + "conceptual_physics": "概念物理学", + "construction_project_management": "建设工程管理", + "economics": "经济学", + "education": "教育学", + "electrical_engineering": "电气工程", + "elementary_chinese": "小学语文", + "elementary_commonsense": "小学常识", + "elementary_information_and_technology": "小学信息技术", + "elementary_mathematics": "初等数学", + "ethnology": "民族学", + "food_science": "食品科学", + "genetics": "遗传学", + "global_facts": "全球事实", + "high_school_biology": "高中生物", + "high_school_chemistry": "高中化学", + "high_school_geography": "高中地理", + "high_school_mathematics": "高中数学", + "high_school_physics": "高中物理学", + "high_school_politics": "高中政治", + "human_sexuality": "人类性行为", + "international_law": "国际法学", + "journalism": "新闻学", + "jurisprudence": "法理学", + "legal_and_moral_basis": "法律与道德基础", + "logical": "逻辑学", + "machine_learning": "机器学习", + "management": "管理学", + "marketing": "市场营销", + "marxist_theory": "马克思主义理论", + "modern_chinese": "现代汉语", + "nutrition": "营养学", + "philosophy": "哲学", + "professional_accounting": "专业会计", + "professional_law": "专业法学", + "professional_medicine": "专业医学", + "professional_psychology": "专业心理学", + "public_relations": "公共关系", + "security_study": "安全研究", + "sociology": "社会学", + "sports_science": "体育学", + "traditional_chinese_medicine": "中医中药", + "virology": "病毒学", + "world_history": "世界历史", + "world_religions": "世界宗教", +} + +default_inference_kwargs = { + "calculate_loss": True, + "all_classes": ["A", "B", "C", "D"], + "language": "Chinese", + "pretrain": False, + "max_new_tokens": 32, +} + + +def get_few_shot_data(data: List[Dict]): + few_shot_data = [] + for i in data: + few_shot_data.append(i["input"] + i["target"]) + return few_shot_data + + +class CMMLUDataset(BaseDataset): + """ + Dataset class for CMMLU dataset. + Data source: https://github.com/haonan-li/CMMLU/tree/master/data + This dataset class will convert the original dataset into the inference dataset. 
+ """ + + @staticmethod + def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]: + dataset = {"dev": {}, "test": {}} + for split in ["dev", "test"]: + files = os.listdir(os.path.join(path, split)) + files.sort() + + for file in files: + subject = file[0 : -len(".csv")] + subject = cmmlu_subject_mapping[subject] + + file_dir = os.path.join(path, split, file) + + dataset[split][subject] = {"data": []} + + # It's been tested that each data sample in one subcategory have same inference arguments. + dataset[split][subject]["inference_kwargs"] = copy.deepcopy(default_inference_kwargs) + + if split == "test" and few_shot: + dataset[split][subject]["inference_kwargs"]["few_shot_data"] = get_few_shot_data( + dataset["dev"][subject]["data"] + ) + + with open(file_dir, encoding="utf-8") as f: + reader = csv.reader(f) + _ = next(reader) + for row in reader: + assert len(row) == 7 + choices = f"A. {row[2]}\nB. {row[3]}\nC. {row[4]}\nD. {row[5]}" + data_sample = { + "dataset": "cmmlu", + "split": split, + "category": subject, + "instruction": f"以下是关于{subject}的单项选择题,请直接给出正确答案的选项。", + "input": f"题目:{row[1]}\n{choices}\n答案:", + "output": "", + "target": row[6], + } + + dataset[split][subject]["data"].append(data_sample) + + return dataset diff --git a/applications/ColossalEval/colossal_eval/dataset/colossalai.py b/applications/ColossalEval/colossal_eval/dataset/colossalai.py new file mode 100644 index 000000000000..54ea478ae5d6 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/dataset/colossalai.py @@ -0,0 +1,70 @@ +from collections import defaultdict +from copy import deepcopy +from typing import Dict, List + +from colossal_eval.utils import jload + +from colossalai.logging import DistributedLogger + +from .base import BaseDataset + +default_inference_kwargs = { + "calculate_loss": False, + "all_classes": None, + "language": "Chinese", + "pretrain": False, + "max_new_tokens": 256, +} + +# You can add your own subcategory questions and specify whether it is a single-choice question or has target answers and need to calculate loss. +single_choice_question = set() +calculate_loss = set() + + +def get_data_per_category(data): + data_per_category = defaultdict(list) + for item in data: + category = item["category"] + data_per_category[category].append(item) + + return data_per_category + + +class ColossalDataset(BaseDataset): + """ + Dataset class for Colossal dataset. + This dataset class will convert the original dataset into the inference dataset. 
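+    Each sample in the original JSON file is expected to provide `category`, `instruction`, `input`, `target` and `id` fields.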
+ """ + + @staticmethod + def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]: + dataset = {"test": {}} + data = jload(path) + data_per_category = get_data_per_category(data) + categories = list(data_per_category.keys()) + + for category in categories: + dataset["test"][category] = {"data": []} + category_data = data_per_category[category] + + dataset["test"][category]["inference_kwargs"] = deepcopy(default_inference_kwargs) + + if category in calculate_loss: + dataset["test"][category]["inference_kwargs"]["calculate_loss"] = True + if category in single_choice_question: + dataset["test"][category]["inference_kwargs"]["all_classes"] = ["A", "B", "C", "D"] + + for item in category_data: + data_sample = { + "dataset": "colossal", + "split": "test", + "category": category, + "instruction": item["instruction"], + "input": item["input"], + "output": "", + "target": item["target"], + "id": item["id"], + } + dataset["test"][category]["data"].append(data_sample) + + return dataset diff --git a/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py b/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py new file mode 100644 index 000000000000..7bf0639e4882 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py @@ -0,0 +1,122 @@ +import json +import os +import re +from copy import deepcopy +from typing import Dict, List + +from colossalai.logging import DistributedLogger + +from .base import BaseDataset + +multi_choice_datasets = [ + "Chinese Lang and Usage MCQs", + "Chinese Modern Lit", + "English Fill in Blanks", + "English Reading Comp", + "Geography MCQs", + "Physics MCQs", + "English Cloze Test", +] + +chinese_qa_datasets = [ + "Biology MCQs", + "Chemistry MCQs", + "Chinese Lang and Usage MCQs", + "Chinese Modern Lit", + "Geography MCQs", + "History MCQs", + "Math I MCQs", + "Math II MCQs", + "Physics MCQs", + "Political Science MCQs", +] +english_qa_datasets = ["English MCQs", "English Fill in Blanks", "English Reading Comp", "English Cloze Test"] + +default_inference_kwargs = { + "calculate_loss": True, + "all_classes": None, + "language": "Chinese", + "pretrain": False, + "max_new_tokens": 32, +} + + +def get_all_classes(instruction: str): + letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + pattern = r"([A-Z]\. |[A-Z].|[A-Z]\.)" + options = sorted(list(set(re.findall(pattern, instruction)))) + options = sorted(list(set([string[0] for string in options]))) + + for i in range(len(options)): + if options[i] == letters[i]: + continue + else: + return options[0:i] + return options + + +class GaoKaoBenchDataset(BaseDataset): + """ + Dataset class for GAOKAO-Bench dataset. + Data source: https://github.com/OpenLMLab/GAOKAO-Bench/tree/main/data + This dataset class will convert the original dataset into the inference dataset. + + A few typos needed to be manually corrected in the origin dataset, some of the following is fixed. + Issue link: https://github.com/OpenLMLab/GAOKAO-Bench/issues/20 + 1. Option C missing in index 111 in 2010-2022_Chemistry_MCQs.json + 2. Option B missing "." after it in index 16 in 2012-2022_English_Cloze_Test.json + 3. Option G missing "." 
after it in index 23 in 2012-2022_English_Cloze_Test.json + """ + + @staticmethod + def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]: + dataset = {"test": {}} + for category in ["Fill-in-the-blank_Questions", "Multiple-choice_Questions", "Open-ended_Questions"]: + files = os.listdir(os.path.join(path, "data", category)) + files.sort() + + for file in files: + subject = file[10:-5].split("_") + subject = " ".join(subject) + dataset["test"][subject] = {"data": []} + + file_dir = os.path.join(path, "data", category, file) + + with open(file_dir, encoding="utf-8") as f: + data = json.load(f) + + # It's been tested that each data sample in one subcategory have same inference arguments. + inference_kwargs = deepcopy(default_inference_kwargs) + if category == "Multiple-choice_Questions" and subject not in multi_choice_datasets: + all_classes = get_all_classes(data["example"][0]["question"]) + inference_kwargs["all_classes"] = all_classes + if subject in english_qa_datasets: + inference_kwargs["language"] = "English" + if subject in chinese_qa_datasets: + inference_kwargs["language"] = "Chinese" + + dataset["test"][subject]["inference_kwargs"] = inference_kwargs + + for sample in data["example"]: + # Convert multi-choice answers to a single string. + # We will convert it back when evaluating. + # We do this because if target is a list, it should be only used for multiple target answers. + if subject in multi_choice_datasets: + sample["answer"] = "".join(sample["answer"]) + + if isinstance(sample["answer"], list) and len(sample["answer"]) == 1: + sample["answer"] = sample["answer"][0] + + data_sample = { + "dataset": "gaokaobench", + "split": "test", + "category": f"{category[:-10]}-{subject}", + "instruction": sample["question"].strip() + "\n答案:", + "input": "", + "output": "", + "target": sample["answer"], + } + + dataset["test"][subject]["data"].append(data_sample) + + return dataset diff --git a/applications/ColossalEval/colossal_eval/dataset/longbench.py b/applications/ColossalEval/colossal_eval/dataset/longbench.py new file mode 100644 index 000000000000..9ea5e3c7d77f --- /dev/null +++ b/applications/ColossalEval/colossal_eval/dataset/longbench.py @@ -0,0 +1,120 @@ +import os +from copy import deepcopy +from typing import Dict, List + +from colossal_eval.utils import get_json_list + +from colossalai.logging import DistributedLogger + +from .base import BaseDataset + +dataset2prompt = { + "narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {context}\n\nNow, answer the question based on the story asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:", + "qasper": 'You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write "unanswerable". If the question is a yes/no question, answer "yes", "no", or "unanswerable". Do not provide any explanation.\n\nArticle: {context}\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write "unanswerable". If the question is a yes/no question, answer "yes", "no", or "unanswerable". 
Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:', + "multifieldqa_en": "Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", + "multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n{context}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:", + "hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", + "2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", + "musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", + "dureader": "请基于给定的文章回答下述问题。\n\n文章:{context}\n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:", + "gov_report": "You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{context}\n\nNow, write a one-page summary of the report.\n\nSummary:", + "qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n{context}\n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:", + "multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\nSummary:", + "vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n{context}\n\n会议总结:", + "trec": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}\n{input}", + "triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}", + "samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}\n\n{input}", + "lsht": "请判断给定新闻的类别,下面是一些例子。\n\n{context}\n{input}", + "passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ", + "passage_retrieval_en": 'Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}\n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. 
The answer format must be like "Paragraph 1", "Paragraph 2", etc.\n\nThe answer is: ', + "passage_retrieval_zh": '以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}\n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是"段落1","段落2"等格式\n\n答案是:', + "lcc": "Please complete the code given below. \n{context}Next line of code:\n", + "repobench-p": "Please complete the code given below. \n{context}{input}Next line of code:\n", +} + +dataset2maxlen = { + "narrativeqa": 128, + "qasper": 128, + "multifieldqa_en": 64, + "multifieldqa_zh": 64, + "hotpotqa": 32, + "2wikimqa": 32, + "musique": 32, + "dureader": 128, + "gov_report": 512, + "qmsum": 512, + "multi_news": 512, + "vcsum": 512, + "trec": 64, + "triviaqa": 32, + "samsum": 128, + "lsht": 64, + "passage_count": 32, + "passage_retrieval_en": 32, + "passage_retrieval_zh": 32, + "lcc": 64, + "repobench-p": 64, +} + +default_inference_kwargs = { + "calculate_loss": True, + "all_classes": None, + "language": "Chinese", + "pretrain": False, + "max_new_tokens": 32, +} + + +class LongBenchDataset(BaseDataset): + """ + Dataset class for LongBench dataset. + Data source: https://huggingface.co/datasets/THUDM/LongBench + This dataset class will convert the original dataset into the inference dataset. + + Issue link: https://github.com/THUDM/LongBench/issues/15 (fixed) + There are duplicate target answers in `nq.jsonl`, but this doesn't affect evaluation results. + Also doesn't affect perplexity calculation (the program only need to select the minimum loss). + """ + + @staticmethod + def load(path: str, logger: DistributedLogger) -> List[Dict]: + dataset = {"test": {}} + + files = os.listdir(path) + files.sort() + + for file in files: + category = file[0:-6] + + if category.endswith("_e"): + continue + + dataset["test"][category] = {"data": []} + + file_dir = os.path.join(path, file) + + loaded_jsonl = get_json_list(file_dir) + + # It's been tested that each data sample in one subcategory have same inference arguments. + inference_kwargs = deepcopy(default_inference_kwargs) + if loaded_jsonl[0]["all_classes"] is not None: + inference_kwargs["all_classes"] = loaded_jsonl[0]["all_classes"] + inference_kwargs["max_new_tokens"] = dataset2maxlen[category] + dataset["test"][category]["inference_kwargs"] = inference_kwargs + + for sample in loaded_jsonl: + prompt = dataset2prompt[category].format(**sample) + + data_sample = { + "dataset": "longbench", + "split": "test", + "category": category, + "instruction": prompt, + "input": "", + "output": "", + "target": sample["answers"], + } + + dataset["test"][category]["data"].append(data_sample) + + return dataset diff --git a/applications/ColossalEval/colossal_eval/dataset/mmlu.py b/applications/ColossalEval/colossal_eval/dataset/mmlu.py new file mode 100644 index 000000000000..b89c0a13cff1 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/dataset/mmlu.py @@ -0,0 +1,73 @@ +import copy +import csv +import os +from typing import Dict, List + +from colossalai.logging import DistributedLogger + +from .base import BaseDataset + +default_inference_kwargs = { + "calculate_loss": True, + "all_classes": ["A", "B", "C", "D"], + "language": "English", + "pretrain": False, + "max_new_tokens": 32, +} + + +def get_few_shot_data(data: List[Dict]): + few_shot_data = [] + for i in data: + few_shot_data.append(i["input"] + i["target"]) + return few_shot_data + + +class MMLUDataset(BaseDataset): + """ + Dataset class for MMLU dataset. 
+ Data source: https://github.com/hendrycks/test + This dataset class will convert the original dataset into the inference dataset. + """ + + @staticmethod + def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]: + dataset = {"dev": {}, "test": {}} + for split in ["dev", "test"]: + files = os.listdir(os.path.join(path, split)) + files.sort() + + for file in files: + subject = file[0 : -len(f"_{split}.csv")].split("_") + subject = " ".join([word.title() if word != "us" else "US" for word in subject]) + + file_dir = os.path.join(path, split, file) + + dataset[split][subject] = {"data": [], "inference_kwargs": {}} + + # It's been tested that each data sample in one subcategory have same inference arguments. + dataset[split][subject]["inference_kwargs"] = copy.deepcopy(default_inference_kwargs) + + if split == "test" and few_shot: + dataset[split][subject]["inference_kwargs"]["few_shot_data"] = get_few_shot_data( + dataset["dev"][subject]["data"] + ) + + with open(file_dir, encoding="utf-8") as f: + reader = csv.reader(f) + for row in reader: + assert len(row) == 6 + choices = f"A. {row[1]}\nB. {row[2]}\nC. {row[3]}\nD. {row[4]}" + data_sample = { + "dataset": "mmlu", + "split": split, + "category": subject, + "instruction": f"The following is a single-choice question on {subject}. Answer the question by replying A, B, C or D.", + "input": f"Question: {row[0]}\n{choices}\nAnswer: ", + "output": "", + "target": row[5], + } + + dataset[split][subject]["data"].append(data_sample) + + return dataset diff --git a/applications/ColossalEval/colossal_eval/evaluate/GPT Evaluation.md b/applications/ColossalEval/colossal_eval/evaluate/GPT Evaluation.md new file mode 100644 index 000000000000..37fbda4c8647 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/evaluate/GPT Evaluation.md @@ -0,0 +1,248 @@ +# GPT Evaluation +## Table of Contents +- [Overview](#overview) +- [GPT Evaluation](#gpt-evaluation) + - [Evaluation Category](#evaluation-category) + - [Evaluation Category Examples](#evaluation-category-examples) + - [Evaluation Metrics](#evaluation-metrics) +- [Evaluation Process](#evaluation-process) + - [Data Format](#data-format) + - [Prompt](#prompt) + - [Battle Prompt](#battle-prompt) + - [Evaluation Prompt](#evaluation-prompt) + - [Evaluation](#evaluation) + - [Configuration](#configuration) + - [Evaluate](#evaluate) +- [FAQ](#faq) +- [Citations](#citations) + + +## Overview + +In this directory, we introduce how you can evaluate your model using GPTs. It is now available for evaluation of both Chinese and English capability and we provide the following functions: + +* Compare the performance of two different models (battle). +* Rate the model according to pre-defined metrics using prompting design. +* Rate the model according to pre-defined metrics with additional reference answer using prompting design. + +## GPT Evaluation + +### Evaluation Category + +Our evaluation pipeline can examine the model's capability using different categories of questions. The following table includes some example categories. You can add your own questions. + +| Evaluation Category | Description | +| :-----------------: | :----------------------------------------------------------- | +| Brainstorming | Models are asked to generate a range of creative and diverse ideas according to the question. The capability of creativity is required. | +| Chat | Models are asked to continue a multi-round dialogue given the roles involved. 
The capability of understanding, memorizing previous rounds of the dialogue and answering according to the persona provided is required. | +| Generation | Models are asked to generate an email, letter, article, etc. The capability of generating texts in a high quality and human-written way is required. | +| Open QA | Models are asked to answer an open QA question(without context provided). The capability of answering questions with the models' own knowledge base is required. | +| Roleplay | Models are asked to play the role provided. The capability of engaging in the scenario and effectively interacting with the user is required. | + + +### Evaluation Category Examples +To better understand each evaluation category, here are some example questions provided. Example questions are in the `configs/gpt_evaluation/data` folder. + + +| Evaluation Category | Chinese Example | English Example | +| :-----------------: | :----------------------------------------------------------- | :----------------------------------------------------------- | +| Brainstorming | 列举一些可以促进头发生长的食物。 | How do you properly chop an onion without crying? | +| Chat | 基于以下角色信息完成一段对话。小张是一名新手爱好者,对养鸡有浓厚的兴趣。老李是一名有丰富经验的养鸡大师。
小张:您好,老李,我最近开始对养鸡感兴趣了,想请教您一些问题。
老李:你好,小张,我很乐意帮助你。你想问些什么?
小张:我想知道如何确定鸡的品种和性别?
老李:确切的品种可以通过鸡的外貌特征来确定,而性别一般是通过鸡卵的大小和形状来判断。还有什么问题吗?
小张:
| Complete a dialogue based on the following character information. Alex: A novice writer who is struggling to find inspiration and develop his writing skills. Emma: A successful author with many published works, providing guidance and advice to Alex.
Alex: Hi Emma, I have been writing for a while now but can't seem to make any progress. Can you give me any advice?
Emma: Hi Alex, sure. What kind of writing are you doing?
Alex: I'm trying to write a novel, but I just can't seem to find any inspiration.
Emma:
| +| Generation | 请为一家咖啡店编写一篇简短的广告语,吸引更多的顾客。 | Write a set of guidelines for first-time pet owners on how to properly care for a new puppy. | +| Open QA | 解释什么是RNA病毒和DNA病毒。 | Explain the process of osmosis in biological systems. | +| Roleplay | 我要你把我写的句子翻译成表情符号。我会写句子,你会用表情符号表达它。我只是想让你用表情符号来表达它。除了表情符号,我不希望你回复任何内容。当我需要用中文告诉你一些事情时,我会用 {} 这样的大括号括起来。我的第一句话是“{我的职业是消防员。}” | I want you to act as a rapper. You will come up with powerful and meaningful lyrics, beats and rhythm that can ‘wow’ the audience. Your lyrics should have an intriguing meaning and message which people can relate too. When it comes to choosing your beat, make sure it is catchy yet relevant to your words, so that when combined they make an explosion of sound everytime! My first request is "I need a rap song about finding strength within yourself." | + +### Evaluation Metrics + +GPT evaluation uses GPT models to evaluate the prediction of different models and different pre-defined evaluation metrics are applied to different categories. The following table shows the 10 pre-defined evaluation metrics both in Chinese and English: + +| Evaluation Metric | Prompt Words | CoT(Chain-of-Thought) | +| :-------------------: | :----------------------------------------------------------- | :----------------------------------------------------------- | +| 语言组织
(Language organization) | 语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。

Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc. | 1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。
2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说
3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。
4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。
5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。
6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。

1. Read the answers and check for grammatical errors, poor word choice, or other significant mistakes.
2. Check that the answer is logical, conveys the information in a logical order, and is self-explanatory.
3. Determine if the answer is relevant to the question or topic and conveys a clear message.
4. Check that the answer is coherent and that appropriate transitions and switches are used to maintain coherence between sentences and paragraphs.
5. Check that the answer is clearly structured and organized in such a way that the reader can easily understand the hierarchy and structure of the information.
6. Evaluate the linguistic organization of the answer based on a combination of the above factors and give a score of 1 to 5, where 5 indicates very good linguistic organization and 1 indicates very poor linguistic organization. | +| 切题
(Relevance) | 切题(1-5):答案内容是否切题,不答非所问,并且严格遵照题目要求。

Relevance (1-5): whether the content of the answer is relevant to the topic, does not answer the wrong question, and strictly follows the requirements of the topic. | 1. 阅读题目,确定题目所问的问题是什么,以及需要回答哪些方面的问题。
2. 阅读答案,确认答案是否直接回答了题目所问的问题。
3. 检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。
4. 根据以上因素综合评估答案的切题程度,并给出一个1到5的分数,其中5表示答案非常切题,而1表示答案完全没有切题。

1. Read the question to determine what the question asks and what aspects of the question need to be answered.
2. Read the answers to make sure that they directly answer the question asked.
3. Check that the answer follows the requirements of the question, including the way it is answered, the length of the answer, the format of the answer, etc.
4. Evaluate how relevant the answer is based on the above factors and give a score of 1 to 5, where 5 means the answer is very relevant and 1 means the answer is not relevant at all. | +| 创意性
(Creativity) | 创意性(1-5):某些头脑风暴问题可能需要答案具有创意,提出新的思路。

Creativity (1-5): Some brainstorming questions may require answers that are creative and suggest new ideas. | 1. 仔细阅读所提供的头脑风暴问题,确保你理解问题的要点和背景。
2. 根据你的知识和经验,判断所提供的答案是否可行。如果答案不可行,则创意性评分可能会受到影响。
3. 考虑答案中是否包含新颖的想法或独特的思路。答案可能与已知的解决方案有所重叠,但仍然可以被认为是有创意的,只要它提供了新的角度或方法来解决问题。
4. 根据答案的创意性,给出一个1到5的评分。如果答案缺乏创意,则应给出一个较低的评分。如果答案具有创意并提供了新的思路,应给出一个较高的评分。

1. Read the provided brainstorming questions carefully to make sure you understand the gist and context of the questions.
2. Based on your knowledge and experience, determine if the answers provided are feasible. If the answer is not feasible, the creativity score may be affected.
3. Consider whether the answer contains novel ideas or unique thoughts. An answer may overlap with a known solution and still be considered creative, as long as it offers a new perspective or approach to the problem.
4. Give a score of 1 to 5 depending on the creativity of the answer. If the answer lacks creativity, a lower score should be given. If the answer is creative and provides a new idea, a higher score should be given. | +| 实用性
(Practicality) | 实用性(1-5):某些头脑风暴问题可能需要答案提出实用的建议或解决方法。

Practicality (1-5): Some brainstorming questions may require answers to suggest practical suggestions or solutions. | 1. 仔细阅读所提供的头脑风暴问题,确保你理解问题的要点和背景。
2. 根据你的知识和经验,判断所提供的答案是否可行。如果答案不可行,则实用性评分可能会受到影响。
3. 考虑答案中提出的建议或解决方法是否实用并可行。答案可能看起来很好,但如果无法实现或应用,则实用性评分可能会受到影响。
4. 根据答案的实用性,给出一个1到5的评分。如果答案缺乏实用性,则应给出一个较低的评分。如果答案提出了实用的建议或解决方法,并且可以很好地解决问题,则应给出一个较高的评分。

1. Read the provided brainstorming questions carefully to make sure you understand the gist and context of the questions.
2. Based on your knowledge and experience, determine if the answers provided are feasible. If the answer is not feasible, the practicality score may be affected.
3. Consider whether the suggestions or solutions presented in the answer are practical and workable. The answer may look good, but if it cannot be implemented or applied, the practicality score may be affected.
4. Give a score of 1 to 5 depending on the practicality of the answer. If the answer lacks practicality, a lower score should be given. If the answer makes a practical suggestion or solution and solves the problem well, a higher score should be given. | +| 正确性
(Correctness) | 正确性(1-5):答案是否正确。

Correctness (1-5): whether the answer is correct or not. | 1. 仔细阅读题目,尝试自己回答该问题。
2. 检查答案的准确性。您可以使用已知的事实或研究来验证答案是否正确。如果答案是正确的,则可以将正确性得分为5分。如果答案是部分正确的,则可以给予适当的得分,例如2分、3分或4分。如果答案完全不正确,则只得1分。

1. Read the question carefully and try to answer the question yourself.
2. Check the correctness of the answer. You can use known facts or research to verify that the answer is correct. If the answer is correct, you can give a score of 5 for correctness. If the answer is partially correct, an appropriate score, such as 2, 3, or 4, may be given. If the answer is completely incorrect, only 1 point is awarded. | +| 自然
(Naturalness) | 自然(1-5):答案是否自然,并且符合问题给定的身份。

Naturalness (1-5): whether the answer is natural and fits the identity given by the question. | 1. 阅读题目,确定题目提供的身份信息。
2. 检查答案内容是否符合题目给定的身份。
3. 根据以上因素,对该回答的自然性进行打分,分数从1到5,其中1表示不自然,5表示非常自然,并符合问题给定的身份。

1. Read the question and determine the identity information provided in the question.
2. Check whether the content of the answer matches the identity given in the question.
3. Based on the above factors, score the naturalness of the response on a scale from 1 to 5, where 1 means unnatural and 5 means very natural and in accordance with the identity given in the question. | +| 参与感
(Engagingness) | 参与感(1-5):答案是否对前面的对话内容做出了恰当的反应,是否理解对话的语境和背景。

Engagingness (1-5): whether the answer responds appropriately to the content of the preceding conversation and whether it understands the context and background of the conversation. | 1. 阅读题目,确定对话的语境和背景。
2. 检查答案是否充分理解对话的语境和背景,能否自然地融入到对话中而不显得突兀。
3. 根据以上因素,对该回答的参与感进行打分,分数从1到5,其中1表示没有参与感,5表示非常有参与感,并且恰当地理解了对话的语境和背景。

1. Read the questions to determine the context and background of the dialogue.
2. Check that the answer fully understands the context and background of the conversation and that it fits naturally into the conversation without seeming abrupt.
3. Based on the above factors, rate the response's engagement on a scale from 1 to 5, where 1 means not engaged and 5 means very engaged and appropriately understands the context and background of the conversation. | +| 合理性
(Reasonableness) | 合理性(1-5):答案是否能够与前面的对话内容形成逻辑上的衔接,是否符合常理,能否在这个上下文中合理存在。

Reasonableness (1-5): Whether the answer can form a logical connection with the content of the previous dialogue, whether it is consistent with common sense, and whether it can reasonably exist in this context. | 1. 阅读题目,确定对话的主题以及问题期望的回答方向。
2. 判断答案是否能够与前面的对话内容形成逻辑上的衔接,是否符合常理,能否在这个上下文中合理存在。
3. 根据以上因素,对该回答的合理性进行打分,分数从1到5,其中1表示不合理,5表示非常合理,并且能够与前面的对话内容形成逻辑上的衔接,并符合常理。

1. Read the question and determine the topic of the conversation and the direction the question expects the answer to go.
2. Determine whether the answer can be logically connected to the preceding conversation, whether it makes common sense, and whether it can reasonably exist in this context.
3. Based on the above factors, rate the reasonableness of the answer on a scale from 1 to 5, where 1 means unreasonable and 5 means very reasonable and able to form a logical connection with the preceding dialogue content and consistent with common sense. | +| 多样性
(Diversity) | 多样性(1-5):答案使用语言是否优美,具有一定的创造性和想象力。然而,回答也应该保持合理和适度,不要过于夸张或离题。

Diversity (1-5): Whether the answers use beautiful language and have some creativity and imagination. However, answers should also be kept reasonable and moderate, not overly exaggerated or off-topic. | 1. 仔细阅读整个回答,确保完全理解回答所表达的内容和主题。
2. 在阅读回答的同时,注意语言的质量,例如措辞是否正确,语言是否生动等。
3. 检查回答的创造性和想象力,看看回答是否能够吸引人阅读下去。
4. 检查回答的合理性和适度,看看回答是否夸张或离题。
5. 将多样性的评分打分在1到5之间,5分表示回答的质量很好,能够吸引人阅读,1分表示回答的内容生硬或者有离题的问题。

1. Read the entire response carefully to ensure that you fully understand the content and theme expressed in the response.
2. While reading the response, pay attention to the quality of the language, such as whether the wording is correct and the language is vivid.
3. Check the creativity and imagination of the response to see if the response is engaging to read on.
4. Check the reasonableness and appropriateness of the responses to see if the responses are exaggerated or off-topic.
5. Rate the diversity on a scale of 1 to 5, with a 5 indicating a good quality response that is engaging to read and a 1 indicating a raw response or a question that is off-topic. | +| 保真度
(Fidelity) | 保真度(1-5):答案是否能够严格遵守角色的设定回答给定的请求。

Fidelity (1-5): whether the answer is able to answer the given request in strict compliance with the role setting. | 1. 仔细阅读问题,了解角色在问题中的设定和表现,包括职业、背景、观点、性格等方面。
2. 阅读题目的请求,确认回答请求时需要注意的细节。
3. 对比提供的回答与该角色的设定,评估回答是否能够严格遵守角色的设定。
4. 结合以上评估结果给出保真度的评分,范围从1到5分,其中1分表示回答与角色设定完全不符,5分表示回答完全符合角色设定且满足给定请求。

1. Read the question carefully to understand how the character is set up and represented in the question, including aspects such as occupation, background, point of view, and personality.
2. Read the question's request and confirm the details that need to be taken into account when answering the request.
3. Compare the provided answer with the setting of the role and assess whether the answer can strictly adhere to the setting of the role.
4. Combine the results of the above assessment to give a fidelity score ranging from 1 to 5, where a score of 1 means that the response does not match the persona at all, and a score of 5 means that the response fully complies with the persona and satisfies the given request. | + +GPT models evaluate the quality of model predictions based on the given prompt words and gives a score between 1-5. + +> **NOTE 1:** You can find all the prompt words and CoT(Chain-of-Thought) in `configs/gpt_evaluation/prompt/evaluation_prompt`. + +> **NOTE 2:** To add customized metrics, you can refer to [FAQ](#faq). + +## Evaluation Process + +### Data Format + +A JSON file contains one list. Each element in the list is a target answer / prediction record for one instruction / question. +An element should have the following fields: + +* `category` (str, compulsory): The category of the instruction / question. +* `instruction` (str, compulsory): The instruction / question for the LLM. +* `input` (str, optional): The additional context of the instruction / question. +* `output` (str, optional): The model output of the instruction, models will fill in this field during inference time. +* `target` (str, optional): The target answer for the instruction. +* `id` (int, compulsory): The ID of the instruction / question. + +Example: + +```json +[ + { + "category": "brainstorming", + "instruction": "请问如何制作一份美味的西红柿炒鸡蛋?", + "input": "", + "output": "", + "target": "", + "id": 1 + }, + { + "category": "chat", + "instruction": "基于以下角色信息完成一段对话。小张是一名新手爱好者,对养鸡有浓厚的兴趣。老李是一名有丰富经验的养鸡大师。", + "input": "小张:您好,老李,我最近开始对养鸡感兴趣了,想请教您一些问题。 老李:你好,小张,我很乐意帮助你。你想问些什么? 小张:我想知道如何确定鸡的品种和性别? 老李:确切的品种可以通过鸡的外貌特征来确定,而性别一般是通过鸡卵的大小和形状来判断。还有什么问题吗? 小张:", + "output": "", + "target": "", + "id": 2 + } +] +``` + +### Prompt + +#### Battle Prompt + +The following is the Chinese battle prompt. In the battle prompt, the question and answers from two different models are fed into the prompt template. You can find example battle prompt files for Chinese and English in `configs/gpt_evaluation/prompt/battle_prompt`. + +```json +{ + "id": 1, + "system_prompt": "你是一个检查回答质量的好助手。", + "prompt_template": "[问题]\n{question}\n\n[1号AI助手的答案]\n{answer_1}\n\n[1号AI助手答案终止]\n\n[2号AI助手的答 案]\n{answer_2}\n\n[2号AI助手答案终止]\n\n[要求]\n{prompt}\n\n", + "prompt": "我们需要你评价这两个AI助手回答的性能。\n请对他们的回答的有用性、相关性、准确性、详细程度进行评分。每个AI助手都会得到一个1到10分的总分,分数越高表示整体表现越好。\n请首先输出一行,该行只包含两个数值,分别表示1号和2号AI助手的分数。这两个分数之间要有一个空格。在随后的一行中,请对你的评价作出全面的解释,避免任何潜在的偏见,并确保AI助手回答的顺序不会影响您的判断。" +} +``` + +#### Evaluation Prompt + +The following is an example of a Chinese GPT evaluation prompt. In an evaluation prompt, you should define your metrics in `metrics` and provide CoT(Chain-of-Thought) in `CoT`. You can find example evaluation prompt files for Chinese and English in `configs/gpt_evaluation/prompt/evaluation_prompt`. + +```json +{ + "brainstorming": { + "id": 1, + "category": "brainstorming", + "metrics": { + "language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。" + }, + "CoT": { + "language organization": "1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。\n2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。\n3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。\n4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。\n5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。\n6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。\n\n语言组织:" + }, + "prompt": "你是一个好助手。请你为下面“头脑风暴”问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}" + } +} +``` + +`"metrics"`: the metrics that can be used in GPT evaluation. 
This field determines which metrics can be added to your config file. + +`"CoT"`: evaluation steps you prompt to GPT models for each metric defined in `"metrics"`. + +### Evaluation + +#### Configuration + +The following is an example of a Chinese config file. The configuration file can control how the pipeline evaluates the model. You need to specify GPT evaluation metrics in key `GPT`. You can find an example English config file in `configs/gpt_evaluation/config/config_en.json`. + +```json +{ + "language": "cn", + "category": { + "brainstorming": { + "GPT": [ + "language organization", + "relevance", + "creativity", + "practicality", + "reasonableness" + ] + } + } +} +``` + +`"language"`: the language used to evaluate the model capability. We only support Chinese `"cn"` for now. + +`"category"`: the category/categories needed to evaluate the model capability. + +`"GPT"`: the metrics you want to use for GPT evaluation. + + +#### Evaluate + +After setting the configuration file, you can evaluate the model using `examples/gpt_evaluation/eval.py`. If you want to make comparisons between answers of two different models, you should specify two answer files in the argument `answer_file_list` and two model names in the argument `model_name_list`. If you want to evaluate one answer file, the length of both `answer_file_list` and `model_name_list` should be 1 and the program will perform evaluation using automatic metrics and GPT models. + +An example script is provided as follows: + +```shell +python eval.py \ + --config_file "path to the config file" \ + --battle_prompt_file "path to the prompt file for battle" \ + --gpt_evaluation_prompt_file "path to the prompt file for gpt evaluation" \ + --target_file "path to the target answer file" \ + --answer_file_list "path to the answer files of at most 2 models" \ + --model_name_list "the names of at most 2 models" \ + --gpt_model "which GPT model to use for evaluation" \ + --save_path "path to save results" \ + --openai_key "your openai key" \ +``` + +If you want GPT evaluation with reference, you can add an argument `--gpt_with_reference`, but make sure the reference file have target answers. + +## FAQ + +
How can I add a new GPT evaluation metric? + +For example, if you want to add a new metric `persuasiveness` into category `brainstorming`, you should add the metric definition and its corresponding CoT(Chain-of-thought) in the evaluation prompt file in `prompt/evaluation_promt`. The CoT can be generated using ChatGPT. You can prompt ChatGPT to generate evaluation steps for the new metric. + +```json +{ + "brainstorming": { + "id": 1, + "category": "brainstorming", + "metrics": { + "persuasiveness": "persuasiveness(1-5):a short description for persuasiveness" + }, + "CoT": { + "persuasiveness": "CoT for persuasiveness\n\npersuasiveness:" + }, + "prompt": "You are a good assistant. Please rate the given answer to the \"brainstorming\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}" + } +} +``` + +
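+After defining the new metric and its CoT in the evaluation prompt file, also list it under the target category's `GPT` key in your config file, since the `GPT` field controls which metrics are applied (see the Configuration section above). A minimal sketch, assuming the config format shown above and the hypothetical `persuasiveness` metric from this example:
+
+```json
+{
+    "language": "cn",
+    "category": {
+        "brainstorming": {
+            "GPT": [
+                "language organization",
+                "persuasiveness"
+            ]
+        }
+    }
+}
+```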
+ +## Citations + +```bibtex +@misc{vicuna2023, + title = {Vicuna: An Open-Source Chatbot Impressing GPT-4 with 90\%* ChatGPT Quality}, + url = {https://vicuna.lmsys.org}, + author = {Chiang, Wei-Lin and Li, Zhuohan and Lin, Zi and Sheng, Ying and Wu, Zhanghao and Zhang, Hao and Zheng, Lianmin and Zhuang, Siyuan and Zhuang, Yonghao and Gonzalez, Joseph E. and Stoica, Ion and Xing, Eric P.}, + month = {March}, + year = {2023} +} + +@misc{liu2023geval, + title={G-Eval: NLG Evaluation using GPT-4 with Better Human Alignment}, + author={Yang Liu and Dan Iter and Yichong Xu and Shuohang Wang and Ruochen Xu and Chenguang Zhu}, + year={2023}, + eprint={2303.16634}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +``` diff --git a/applications/ColossalEval/colossal_eval/evaluate/__init__.py b/applications/ColossalEval/colossal_eval/evaluate/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/__init__.py b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/__init__.py new file mode 100644 index 000000000000..3c5df09a6909 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/__init__.py @@ -0,0 +1,3 @@ +from .dataset_evaluator import DatasetEvaluator + +__all__ = ["DatasetEvaluator"] diff --git a/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/dataset_evaluator.py b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/dataset_evaluator.py new file mode 100644 index 000000000000..c70988707a15 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/dataset_evaluator.py @@ -0,0 +1,269 @@ +from typing import Dict, List + +import colossal_eval.evaluate.dataset_evaluator.metrics as metric_helper +import numpy as np +import tqdm + +LabelBasedMetrics = ["first_token_accuracy", "matthews_correlation"] +LossBasedMetrics = ["perplexity", "ppl_score", "ppl_score_over_choices", "per_byte_perplexity", "per_byte_ppl_score"] +CombinedMetrics = ["combined_single_choice_accuracy"] +OtherMetrics = [ + "f1_score", + "f1_zh_score", + "rouge_score", + "rouge_zh_score", + "retrieval_score", + "retrieval_zh_score", + "classification_score", + "code_sim_score", + "count_score", + "multi_choice_accuracy", + "math_equivalence", + "single_choice_accuracy", +] + + +class DatasetEvaluator(object): + """ + Dataset evaluator. + + """ + + def __init__(self): + pass + + def _calculate_label_metrics(self, metric: str, category: str): + """Calculate label-based metrics.""" + weight = len(self.data[category]["data"]) / self.metric_total_length[metric] + + str_label_map = { + choice: idx for idx, choice in enumerate(self.data[category]["inference_kwargs"]["all_classes"]) + } + + references = [str_label_map[sample["target"]] for sample in self.data[category]["data"]] + [sample["output"] for sample in self.data[category]["data"]] + + flag = False + softmaxs = [] + for i, sample in enumerate(self.data[category]["data"]): + if np.any(np.isnan(np.array(list(sample["softmax_over_choices"].values())))): + if not flag: + print( + f"NaN in the softmax, switch to exact match for category {category} in dataset {self.dataset_name} in model {self.model_name}." 
+ ) + flag = True + score = 0 + for ref in sample["target"]: + score = max( + score, + metric_helper.single_choice_accuracy( + sample["output"], ref, all_classes=self.data[category]["inference_kwargs"]["all_classes"] + ), + ) + softmaxs.append(references[i] if score == 1 else -1) + else: + softmaxs.append(np.argmax(np.array(list(sample["softmax_over_choices"].values())))) + + references = np.array(references) + softmaxs = np.array(softmaxs) + scores = np.sum(references == softmaxs) / len(self.data[category]["data"]) * 100 + + self.evaluation_results[metric][category] = (scores, len(self.data[category]["data"])) + self.evaluation_results[metric]["ALL"] += scores * weight + + def _calculate_combined_metrics(self, metric: str, category: str): + """Calculate combined metrics.""" + weight = len(self.data[category]["data"]) / self.metric_total_length[metric] + + references = [sample["target"] for sample in self.data[category]["data"]] + predictions = [sample["output"] for sample in self.data[category]["data"]] + + str_label_map = { + choice: idx for idx, choice in enumerate(self.data[category]["inference_kwargs"]["all_classes"]) + } + + references_labels = [str_label_map[sample["target"][0]] for sample in self.data[category]["data"]] + predictions = [sample["output"] for sample in self.data[category]["data"]] + + flag = False + softmaxs = [] + for i, sample in enumerate(self.data[category]["data"]): + if np.any(np.isnan(np.array(list(sample["softmax_over_choices"].values())))): + if not flag: + print( + f"NaN in the softmax, switch to exact match for category {category} in dataset {self.dataset_name} in model {self.model_name}." + ) + flag = True + score = 0 + for ref in sample["target"]: + score = max( + score, + metric_helper.single_choice_accuracy( + sample["output"], ref, all_classes=self.data[category]["inference_kwargs"]["all_classes"] + ), + ) + softmaxs.append(references[i] if score == 1 else -1) + else: + softmaxs.append(np.argmax(np.array(list(sample["softmax_over_choices"].values())))) + + metric_method = eval("metric_helper." + metric) + + total_score = 0.0 + for prediction, reference, references_label, softmax in zip( + predictions, references, references_labels, softmaxs + ): + score = 0.0 + + for ref in reference: + score = max( + score, + metric_method(prediction, ref, all_classes=self.data[category]["inference_kwargs"]["all_classes"]), + ) + if references_label == softmax: + score = 1 + + total_score += score + total_score = total_score * 100 / len(self.data[category]["data"]) + + self.evaluation_results[metric][category] = (total_score, len(self.data[category]["data"])) + self.evaluation_results[metric]["ALL"] += total_score * weight + + def _calculate_other_metrics(self, metric: str, category: str): + """Calculate other metrics.""" + weight = len(self.data[category]["data"]) / self.metric_total_length[metric] + + references = [sample["target"] for sample in self.data[category]["data"]] + predictions = [sample["output"] for sample in self.data[category]["data"]] + + metric_method = eval("metric_helper." 
+ metric) + + total_score = 0.0 + for prediction, reference in zip(predictions, references): + score = 0.0 + for ref in reference: + score = max( + score, + metric_method(prediction, ref, all_classes=self.data[category]["inference_kwargs"]["all_classes"]), + ) + total_score += score + total_score = total_score * 100 / len(predictions) + + self.evaluation_results[metric][category] = (total_score, len(self.data[category]["data"])) + self.evaluation_results[metric]["ALL"] += total_score * weight + + def _calculate_loss_metrics(self, metric: str, category: str): + """Calculate perplexity.""" + if metric == "perplexity": + weight = len(self.data[category]["data"]) / self.metric_total_length[metric] + losses = [min(sample["loss"]) for sample in self.data[category]["data"]] + perplexity = np.mean(np.exp(np.array(losses))) + + self.evaluation_results["perplexity"][category] = (perplexity, len(self.data[category]["data"])) + self.evaluation_results["perplexity"]["ALL"] += perplexity * weight + elif metric == "ppl_score": + weight = len(self.data[category]["data"]) / self.metric_total_length[metric] + losses = [min(sample["loss"]) for sample in self.data[category]["data"]] + perplexity_score = np.mean(np.exp(-np.array(losses))) * 100 + + self.evaluation_results["ppl_score"][category] = (perplexity_score, len(self.data[category]["data"])) + self.evaluation_results["ppl_score"]["ALL"] += perplexity_score * weight + elif metric == "ppl_score_over_choices" and self.data[category]["inference_kwargs"]["all_classes"] is not None: + weight = len(self.data[category]["data"]) / self.metric_total_length[metric] + loss_over_choices = [sample["loss_over_choices"] for sample in self.data[category]["data"]] + perplexity_score_over_choices = np.mean(np.exp(-np.array(loss_over_choices))) * 100 + + self.evaluation_results["ppl_score_over_choices"][category] = ( + perplexity_score_over_choices, + len(self.data[category]["data"]), + ) + self.evaluation_results["ppl_score_over_choices"]["ALL"] += perplexity_score_over_choices * weight + elif metric == "per_byte_perplexity": + weight = len(self.data[category]["data"]) / self.metric_total_length[metric] + losses = [min(sample["loss_sum"]) for sample in self.data[category]["data"]] + perplexity = np.mean(np.exp(np.array(losses) / np.array(self.N_bytes[category]))) + + self.evaluation_results["per_byte_perplexity"][category] = perplexity + self.evaluation_results["per_byte_perplexity"]["ALL"] += perplexity * weight + elif metric == "per_byte_ppl_score": + weight = len(self.data[category]["data"]) / self.metric_total_length[metric] + losses = [min(sample["loss_sum"]) for sample in self.data[category]["data"]] + perplexity_score = np.mean(np.exp(-np.array(losses) / np.array(self.N_bytes[category]))) * 100 + + self.evaluation_results["per_byte_ppl_score"][category] = perplexity_score + self.evaluation_results["per_byte_ppl_score"]["ALL"] += perplexity_score * weight + + def _evaluate(self): + """Calculate and return evaluation results""" + + for metric in self.metrics: + pbar = tqdm.tqdm( + desc=f"{self.dataset_name}-{metric}-{self.model_name}", total=len(self.suggested_categories[metric]) + ) + + if metric in LabelBasedMetrics: + for category in self.suggested_categories[metric]: + self._calculate_label_metrics(metric, category) + pbar.update(1) + elif metric in LossBasedMetrics: + for category in self.suggested_categories[metric]: + self._calculate_loss_metrics(metric, category) + pbar.update(1) + elif metric in CombinedMetrics: + for category in 
self.suggested_categories[metric]: + self._calculate_combined_metrics(metric, category) + pbar.update(1) + elif metric in OtherMetrics: + for category in self.suggested_categories[metric]: + self._calculate_other_metrics(metric, category) + pbar.update(1) + + return self.evaluation_results + + def get_evaluation_results(self, data: List[Dict], dataset_name: str, model_name: str, metrics: List[str]): + """ + Evaluate inference data on the given metrics. + + Args: + data: Data to be evaluated. + dataset_name: Name of the dataset + model_name: Name of the model + metrics: Metrics used to evaluate. + + """ + self.data = data + self.dataset_name = dataset_name + self.model_name = model_name + self.categories = list(data.keys()) + self.metrics = metrics + + self.evaluation_results = { + metric: {category: 0 for category in (["ALL"] + self.categories)} for metric in self.metrics + } + + self.total_length = 0 + self.total_single_choices = 0 + for value in self.data.values(): + self.total_length += len(value["data"]) + if value["inference_kwargs"]["all_classes"] is not None: + self.total_single_choices += len(value["data"]) + + self.metric_total_length = {metric: 0 for metric in self.metrics} + self.suggested_categories = {metric: [] for metric in self.metrics} + + for metric in self.metrics: + self.suggested_categories[metric] = metric_helper.metrics4subcategory[self.dataset_name][metric] + if "ALL" in self.suggested_categories[metric]: + self.suggested_categories[metric] = self.categories + self.metric_total_length[metric] = self.total_length + continue + for category in self.suggested_categories[metric]: + self.metric_total_length[metric] += len(self.data[category]["data"]) + + if "per_byte_perplexity" in self.metrics or "per_byte_ppl_score" in self.metrics: + self.N_bytes = {category: [] for category in self.categories} + for category in self.categories: + samples = self.data[category]["data"] + for sample in samples: + self.N_bytes[category].append(sample["byte_num"][0]) + + return self._evaluate() diff --git a/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/metrics.py b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/metrics.py new file mode 100644 index 000000000000..914465478dec --- /dev/null +++ b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/metrics.py @@ -0,0 +1,623 @@ +# Code adapted from https://github.com/THUDM/LongBench/blob/main/metrics.py +# Code adapted from https://github.com/hendrycks/math/blob/main/modeling/math_equivalence.py +# Code adapted from https://github.com/ruixiangcui/AGIEval/blob/main/src/evaluation.py + +import difflib +import re +import string +from collections import Counter + +import jieba +from fuzzywuzzy import fuzz +from rouge import Rouge + +metrics4subcategory = { + "pretrain": { + "perplexity": ["ALL"], + "ppl_score": ["ALL"], + "per_byte_perplexity": ["ALL"], + "per_byte_ppl_score": ["ALL"], + }, + # The commented are non 4-choice questions. 
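For reference, a minimal NumPy sketch of the loss-based metrics listed for the `pretrain` entry above; the formulas mirror `DatasetEvaluator._calculate_loss_metrics`, and every number here is invented:

```python
import numpy as np

losses   = np.array([2.1, 1.8, 2.5])        # per-sample losses (already averaged per target token)
loss_sum = np.array([210.0, 180.0, 250.0])  # per-sample summed losses
n_bytes  = np.array([100, 90, 120])         # UTF-8 byte counts of the decoded samples

perplexity          = np.mean(np.exp(losses))
ppl_score           = np.mean(np.exp(-losses)) * 100
per_byte_perplexity = np.mean(np.exp(loss_sum / n_bytes))
per_byte_ppl_score  = np.mean(np.exp(-loss_sum / n_bytes)) * 100
print(perplexity, ppl_score, per_byte_perplexity, per_byte_ppl_score)
```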
+ "agieval": { + "combined_single_choice_accuracy": [ + # "lsat-ar", + # "lsat-lr", + # "lsat-rc", + "logiqa-en", + "sat-math", + "sat-en", + # "aqua-rat", + "sat-en-without-passage", + "gaokao-english", + "logiqa-zh", + "gaokao-chinese", + "gaokao-geography", + "gaokao-history", + "gaokao-biology", + "gaokao-chemistry", + ], + "first_token_accuracy": [ + # "lsat-ar", + # "lsat-lr", + # "lsat-rc", + "logiqa-en", + "sat-math", + "sat-en", + # "aqua-rat", + "sat-en-without-passage", + "gaokao-english", + "logiqa-zh", + "gaokao-chinese", + "gaokao-geography", + "gaokao-history", + "gaokao-biology", + "gaokao-chemistry", + ], + "single_choice_accuracy": [ + # "lsat-ar", + # "lsat-lr", + # "lsat-rc", + "logiqa-en", + "sat-math", + "sat-en", + # "aqua-rat", + "sat-en-without-passage", + "gaokao-english", + "logiqa-zh", + "gaokao-chinese", + "gaokao-geography", + "gaokao-history", + "gaokao-biology", + "gaokao-chemistry", + ], + "multi_choice_accuracy": ["jec-qa-kd", "jec-qa-ca", "gaokao-physics", "gaokao-mathqa"], + "math_equivalence": ["gaokao-mathcloze", "math"], + "perplexity": ["ALL"], + "ppl_score_over_choices": [ + "lsat-ar", + "lsat-lr", + "lsat-rc", + "logiqa-en", + "sat-math", + "sat-en", + "aqua-rat", + "sat-en-without-passage", + "gaokao-english", + "logiqa-zh", + "jec-qa-kd", + "jec-qa-ca", + "gaokao-chinese", + "gaokao-geography", + "gaokao-history", + "gaokao-biology", + "gaokao-chemistry", + "gaokao-physics", + "gaokao-mathqa", + ], + "ppl_score": ["ALL"], + }, + "cmmlu": { + "first_token_accuracy": ["ALL"], + "single_choice_accuracy": ["ALL"], + "perplexity": ["ALL"], + "ppl_score_over_choices": ["ALL"], + "ppl_score": ["ALL"], + }, + "gaokaobench": { + "combined_single_choice_accuracy": [ + "English MCQs", + "Biology MCQs", + "Chemistry MCQs", + "History MCQs", + "Math I MCQs", + "Math II MCQs", + "Political Science MCQs", + ], + "first_token_accuracy": [ + "English MCQs", + "Biology MCQs", + "Chemistry MCQs", + "History MCQs", + "Math I MCQs", + "Math II MCQs", + "Political Science MCQs", + ], + "single_choice_accuracy": [ + "English MCQs", + "Biology MCQs", + "Chemistry MCQs", + "History MCQs", + "Math I MCQs", + "Math II MCQs", + "Political Science MCQs", + ], + "multi_choice_accuracy": [ + "Chinese Lang and Usage MCQs", + "Chinese Modern Lit", + "English Fill in Blanks", + "English Reading Comp", + "Geography MCQs", + "Physics MCQs", + "English Cloze Test", + ], + "math_equivalence": ["Math I Fill-in-the-Blank", "Math II Fill-in-the-Blank"], + "rouge_score": ["English Language Cloze Passage"], + "rouge_zh_score": [ + "Chinese Language Famous Passages and Sentences Dictation", + "Chemistry Open-ended Questions", + "History Open-ended Questions", + "Biology Open-ended Questions", + "Political Science Open-ended Questions", + "English Language Error Correction", + "Chinese Language Language and Writing Skills Open-ended Questions", + "Math II Open-ended Questions", + "Chinese Language Literary Text Reading", + "Chinese Language Ancient Poetry Reading", + "Chinese Language Classical Chinese Reading", + "Physics Open-ended Questions", + "Math I Open-ended Questions", + "Geography Open-ended Questions", + "Chinese Language Practical Text Reading", + ], + "perplexity": ["ALL"], + "ppl_score_over_choices": ["ALL"], + "ppl_score": ["ALL"], + }, + "longbench": { + "f1_score": ["hotpotqa", "2wikimqa", "musique", "narrativeqa", "qasper", "multifieldqa_en", "triviaqa"], + "f1_zh_score": ["multifieldqa_zh"], + "rouge_score": ["gov_report", "qmsum", "multi_news", "samsum"], + 
"rouge_zh_score": ["dureader", "vcsum"], + "retrieval_score": ["passage_retrieval_en"], + "retrieval_zh_score": ["passage_retrieval_zh"], + "classification_score": ["trec", "lsht"], + "code_sim_score": ["lcc", "repobench-p"], + "count_score": ["passage_count"], + "perplexity": ["ALL"], + "ppl_score": ["ALL"], + }, + "mmlu": { + "first_token_accuracy": ["ALL"], + "single_choice_accuracy": ["ALL"], + "accuracy": ["ALL"], + "perplexity": ["ALL"], + "ppl_score_over_choices": ["ALL"], + "ppl_score": ["ALL"], + }, +} + + +def _fix_fracs(string): + substrs = string.split("\\frac") + new_str = substrs[0] + if len(substrs) > 1: + substrs = substrs[1:] + for substr in substrs: + new_str += "\\frac" + if substr[0] == "{": + new_str += substr + else: + try: + assert len(substr) >= 2 + except: + return string + a = substr[0] + b = substr[1] + if b != "{": + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}{" + b + "}" + post_substr + else: + new_str += "{" + a + "}{" + b + "}" + else: + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}" + b + post_substr + else: + new_str += "{" + a + "}" + b + string = new_str + return string + + +def _fix_a_slash_b(string): + if len(string.split("/")) != 2: + return string + a = string.split("/")[0] + b = string.split("/")[1] + try: + a = int(a) + b = int(b) + assert string == "{}/{}".format(a, b) + new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" + return new_string + except: + return string + + +def _remove_right_units(string): + # "\\text{ " only ever occurs (at least in the val set) when describing units + if "\\text{ " in string: + splits = string.split("\\text{ ") + assert len(splits) == 2 + return splits[0] + else: + return string + + +def _fix_sqrt(string): + if "\\sqrt" not in string: + return string + splits = string.split("\\sqrt") + new_string = splits[0] + for split in splits[1:]: + if split[0] != "{": + a = split[0] + new_substr = "\\sqrt{" + a + "}" + split[1:] + else: + new_substr = "\\sqrt" + split + new_string += new_substr + return new_string + + +def _strip_string(string): + # linebreaks + string = string.replace("\n", "") + # print(string) + + # remove inverse spaces + string = string.replace("\\!", "") + # print(string) + + # replace \\ with \ + string = string.replace("\\\\", "\\") + # print(string) + + # replace tfrac and dfrac with frac + string = string.replace("tfrac", "frac") + string = string.replace("dfrac", "frac") + # print(string) + + # remove \left and \right + string = string.replace("\\left", "") + string = string.replace("\\right", "") + # print(string) + + # Remove circ (degrees) + string = string.replace("^{\\circ}", "") + string = string.replace("^\\circ", "") + + # remove dollar signs + string = string.replace("\\$", "") + + # remove units (on the right) + string = _remove_right_units(string) + + # remove percentage + string = string.replace("\\%", "") + string = string.replace("\%", "") + + # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string + string = string.replace(" .", " 0.") + string = string.replace("{.", "{0.") + # if empty, return empty string + if len(string) == 0: + return string + if string[0] == ".": + string = "0" + string + + # to consider: get rid of e.g. 
"k = " or "q = " at beginning + if len(string.split("=")) == 2: + if len(string.split("=")[0]) <= 2: + string = string.split("=")[1] + + # fix sqrt3 --> sqrt{3} + string = _fix_sqrt(string) + + # remove spaces + string = string.replace(" ", "") + + # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} + string = _fix_fracs(string) + + # manually change 0.5 --> \frac{1}{2} + if string == "0.5": + string = "\\frac{1}{2}" + + # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y + string = _fix_a_slash_b(string) + + return string + + +def parse_math_answer(raw_string): + def remove_boxed(s): + left = "\\boxed{" + try: + assert s[: len(left)] == left + assert s[-1] == "}" + answer = s[len(left) : -1] + if "=" in answer: + answer = answer.split("=")[-1].lstrip(" ") + return answer + except: + return None + + def last_boxed_only_string(string): + idx = string.rfind("\\boxed") + if idx < 0: + idx = string.rfind("\\fbox") + if idx < 0: + return None + i = idx + right_brace_idx = None + num_left_braces_open = 0 + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + + if right_brace_idx == None: + retval = None + else: + retval = string[idx : right_brace_idx + 1] + + return retval + + def get_answer_with_dollar_sign(s): + first_pattern = "\$(.*)\$" + last_match = None + matches = re.findall(first_pattern, s) + if matches: + last_match = matches[-1] + if "=" in last_match: + last_match = last_match.split("=")[-1].lstrip(" ") + return last_match + + def get_answer_without_dollar_sign(s): + last_match = None + if "=" in s: + last_match = s.split("=")[-1].lstrip(" ").rstrip(".") + if "\\n" in last_match: + last_match = last_match.split("\\n")[0] + else: + pattern = "(?:\\$)?\d+(?:\.\d+)?(?![\w\d])" + matches = re.findall(pattern, s) + if matches: + last_match = matches[-1] + return last_match + + if "\\boxed" in raw_string: + answer = remove_boxed(last_boxed_only_string(raw_string)) + else: + answer = get_answer_with_dollar_sign(raw_string) + if not answer: + answer = get_answer_without_dollar_sign(raw_string) + return answer + + +def math_equivalence(prediction, reference, **kwargs): + prediction = parse_math_answer(prediction) + + if prediction is None and reference is None: + print("WARNING: Both None") + return False + + if prediction is None or reference is None: + return False + + try: + ss1 = _strip_string(prediction) + ss2 = _strip_string(reference) + return ss1 == ss2 + except: + return prediction == reference + + +def multi_choice_accuracy(prediction, reference, **kwargs): + # Only find uppercase letters not surrounded by lowercase letters + all_classes = kwargs.get("all_classes", None) + if all_classes: + pattern = f"(? 
highest_similarity: + highest_similarity = similarity + best_match = string + score = float(best_match == reference) + return score + + +def rouge_score(prediction, reference, **kwargs): + rouge = Rouge() + try: + scores = rouge.get_scores([prediction], [reference], avg=True) + except: + return 0.0 + return scores["rouge-l"]["f"] + + +def rouge_zh_score(prediction, reference, **kwargs): + prediction = " ".join(list(jieba.cut(prediction, cut_all=False))) + reference = " ".join(list(jieba.cut(reference, cut_all=False))) + score = rouge_score(prediction, reference) + return score + + +def _f1_score(prediction, reference, **kwargs): + common = Counter(prediction) & Counter(reference) + num_same = sum(common.values()) + if num_same == 0: + return 0 + precision = 1.0 * num_same / len(prediction) + recall = 1.0 * num_same / len(reference) + f1 = (2 * precision * recall) / (precision + recall) + return f1 + + +def f1_score(prediction, reference, **kwargs): + normalized_prediction = normalize_answer(prediction) + normalized_ground_truth = normalize_answer(reference) + + prediction_tokens = normalized_prediction.split() + ground_truth_tokens = normalized_ground_truth.split() + return _f1_score(prediction_tokens, ground_truth_tokens) + + +def f1_zh_score(prediction, reference, **kwargs): + prediction_tokens = list(jieba.cut(prediction, cut_all=False)) + ground_truth_tokens = list(jieba.cut(reference, cut_all=False)) + prediction_tokens = [normalize_zh_answer(token) for token in prediction_tokens] + ground_truth_tokens = [normalize_zh_answer(token) for token in ground_truth_tokens] + prediction_tokens = [token for token in prediction_tokens if len(token) > 0] + ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0] + return _f1_score(prediction_tokens, ground_truth_tokens) diff --git a/applications/ColossalEval/colossal_eval/evaluate/evaluator.py b/applications/ColossalEval/colossal_eval/evaluate/evaluator.py new file mode 100644 index 000000000000..11e204b504c5 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/evaluate/evaluator.py @@ -0,0 +1,110 @@ +import os +from typing import Any, Dict, List + +import colossal_eval.evaluate.gpt_evaluate as gpt_evaluate + +from .utils import get_data_per_category + + +class Evaluator(object): + """ + A class named Evaluator includes GPT-3.5/GPT-4 evaluation + + """ + + def __init__( + self, + params: Dict[str, Any], + battle_prompt: Dict[str, Any], + gpt_evaluation_prompt: Dict[str, Any], + gpt_model: str, + language: str, + gpt_with_reference: bool, + ) -> None: + self.params = params + self.battle_prompt = battle_prompt + self.gpt_evaluation_prompt = gpt_evaluation_prompt + self.gpt_model = gpt_model + self.language = language + self.gpt_with_reference = gpt_with_reference + self.gpt_evaluation_results = dict() + self.battle_results = [] + + def battle(self, answers1: List[Dict], answers2: List[Dict]) -> None: + """ + Comparison between two models using GPT-4 as the reviewer. + """ + + self.battle_results = gpt_evaluate.battle(answers1, answers2, self.battle_prompt) + + def evaluate(self, answers: List[Dict], targets: List[Dict], save_path: str, model_name: str) -> None: + """ + A comprehensive evaluation of the answers from the model. + The function evaluates the model's performance from different perspectives + using GPT-3.5, GPT-4, and off-the-shelf evaluation metrics. + + The metrics will be decided by the config file. 
+ + """ + + answers_per_category = get_data_per_category(answers, list(self.params.keys())) + targets_per_category = get_data_per_category(targets, list(self.params.keys())) + + # gpt evaluation + for category in self.params: + if len(answers_per_category[category]) == 0: + print(f"Category {category} specified in your config doesn't have corresponding answers!") + continue + + if self.params[category].get("GPT", None) is None: + continue + + category_metrics = self.params[category]["GPT"] + + prompt = self.gpt_evaluation_prompt.get(category, None) + if prompt is None: + print(f"No prompt for category {category}! Use prompt for category general now.") + prompt = self.gpt_evaluation_prompt["general"] + + self.gpt_evaluation_results[category] = gpt_evaluate.evaluate( + answers_per_category[category], + prompt, + category_metrics, + category, + save_path, + model_name, + self.gpt_model, + self.language, + references=targets_per_category[category] if self.gpt_with_reference else None, + ) + + def save(self, path: str, model_name_list: List[str]) -> None: + """ + Save evaluation results of GPT-3.5, GPT-4, and off-the-shelf evaluation metrics. + + """ + + if len(model_name_list) == 2: + save_path = os.path.join(path, "gpt_evaluate", "battle_results") + gpt_evaluate.save_battle_results(self.battle_results, model_name_list[0], model_name_list[1], save_path) + else: + if self.gpt_evaluation_results: + # Save evaluation results for GPT evaluation metrics. + gpt_base_save_path = os.path.join(path, "gpt_evaluate", "gpt_evaluate_results") + gpt_evaluation_results_save_path = os.path.join(gpt_base_save_path, "evaluation_results") + + all_evaluations = gpt_evaluate.save_gpt_evaluation_results( + model_name_list[0], self.gpt_evaluation_results, gpt_evaluation_results_save_path + ) + + # Start to calculate scores and save statistics. + gpt_evaluation_statistics_save_path = os.path.join(gpt_base_save_path, "evaluation_statistics") + gpt_evaluate.save_gpt_evaluation_statistics( + model_name_list[0], all_evaluations, gpt_evaluation_statistics_save_path + ) + + # Save charts and csv. + gpt_evaluation_analyses_save_path = os.path.join(gpt_base_save_path, "evaluation_analyses") + gpt_evaluate.analyze_gpt_evaluation_statistics( + gpt_evaluation_statistics_save_path, gpt_evaluation_analyses_save_path + ) diff --git a/applications/Chat/evaluate/gpt_evaluate.py b/applications/ColossalEval/colossal_eval/evaluate/gpt_evaluate.py similarity index 89% rename from applications/Chat/evaluate/gpt_evaluate.py rename to applications/ColossalEval/colossal_eval/evaluate/gpt_evaluate.py index ad908f4ba48c..a0b1ed1143f0 100644 --- a/applications/Chat/evaluate/gpt_evaluate.py +++ b/applications/ColossalEval/colossal_eval/evaluate/gpt_evaluate.py @@ -11,7 +11,7 @@ import pandas as pd import seaborn as sns import tqdm -from utils import jdump, jload +from colossal_eval.utils import jdump, jload ref_step_template = { "en": "Now please compare the answer with the {adjective} answer, determine whether the answer is able to achieve the same level of {metric}.\n\n", @@ -364,7 +364,7 @@ def get_gpt_evaluation_without_logprobs( """ Use chat models(gpt-3.5-turbo or gpt-4) to evaluate one model answer. - Temperature is set to 0 to make the model more deterministic. + Temprature is set to 0 to make the model more deterministic. Args: prompt: a dictionary including prompt template, CoT and metrics. 
@@ -401,7 +401,7 @@ def get_gpt_evaluation_without_logprobs( steps=prompt["CoT"][metric], ) - if prompt_reference: + if prompt_reference and (reference["target"] or reference["output"]): # Do a 2-round conversation response = multiturn_chat_completion( [prompt_1st_round, prompt_reference], model, max_tokens=max_tokens, turns=2 @@ -436,7 +436,7 @@ def get_gpt_evaluation_with_logprobs( Use completion model(text-davinci-003) to evaluate one model answer. Only completion models can return log probabilities. - Temperature is set to 0 to make the model more deterministic. + Temprature is set to 0 to make the model more deterministic. Args: prompt: a dictionary including prompt template, CoT and metrics. @@ -498,6 +498,8 @@ def evaluate( prompt: Dict[str, Any], metrics: List[str], category: str, + save_path: str, + model_name: str, model: str, language: str, references: List[Dict] = None, @@ -525,6 +527,72 @@ def evaluate( metrics_str = ", ".join(x for x in metrics) print(f"Category {category}'s metrics are {metrics_str}.") + gpt_base_save_path = os.path.join(save_path, "gpt_evaluate", "gpt_evaluate_results") + gpt_evaluation_results_save_path = os.path.join(gpt_base_save_path, "evaluation_results") + category_file = os.path.join(gpt_evaluation_results_save_path, model_name, f"{category}_evaluation_results.json") + + if os.path.exists(category_file): + print(f"Evaluation results for category {category}, model {model_name} already exists.") + print("Skip evaluating.") + + evaluations = jload(category_file) + + retry = [] + evaluations_copy = deepcopy(evaluations) + + success = [] + for idx, e in enumerate(evaluations_copy): + keys = list(e["evaluation"].keys()) + for key in keys: + if e["evaluation"][key] == {}: + retry.append(e["id"]) + print(f"Re-evaluate id {e['id']} now.") + break + if e["id"] not in retry: + success.append(e) + + if len(retry) == 0: + evaluations.sort(key=lambda x: x["id"]) + print(f"{category} done.") + return evaluations + + with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: + futures = [] + for idx, inst in enumerate(answers): + if not inst["id"] in retry: + continue + # Completion models can return log probabilities. + if model == "text-davinci-003": + future = executor.submit(get_gpt_evaluation_with_logprobs, prompt, inst, metrics, 1) + else: + future = executor.submit( + get_gpt_evaluation_without_logprobs, + prompt, + inst, + metrics, + language, + reference=None if references is None else references[idx], + model=model, + max_tokens=1, + ) + + futures.append(future) + + for future in tqdm.tqdm( + concurrent.futures.as_completed(futures), + desc=f"{category}: ", + total=len(futures), + ): + success.append(future.result()) + + success.sort(key=lambda x: x["id"]) + + print(f"Saving evaluation results for category {category}, model {model_name}.") + + jdump(success, category_file) + + return success + with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: futures = [] for idx, inst in enumerate(answers): @@ -556,6 +624,10 @@ def evaluate( print(f"{category} done.") + print(f"Saving evaluation results for category {category}, model {model_name}.") + + jdump(evaluations, category_file) + return evaluations @@ -581,7 +653,7 @@ def calculate_scores_form_logprobs(logprobs: Dict[str, Any]) -> float: for key, value in logprobs.items(): # Sometimes the key will be one byte of a unicode character which takes the form of "bytes:\\xe7". - # It is meaningless, and thus we don't calculate probability. 
+ # It is meaningless and thus we don't calculate probability. if "bytes" in key: continue # results[0] is the score which corresponds to the key(predicted token). @@ -598,7 +670,7 @@ def calculate_scores_form_logprobs(logprobs: Dict[str, Any]) -> float: def calculate_scores_form_response(response: str, evaluation: Dict[str, Any]) -> int: """ Calculate the score from the response returned by gpt-3.5-turbo or gpt-4. - Different from text-davinci-003, this function directly calculates the score according to the plain response returned by gpt-3.5-turbo or gpt-4. + Different from text-davinci-003, this fuction directly calculates the score according to the plain response returned by gpt-3.5-turbo or gpt-4. Although text-davinci-003 can return log probabilities, it costs ten times as much as gpt-3.5-turbo. Args: @@ -627,7 +699,7 @@ def save_gpt_evaluation_results( Args: model_name: name of the model for saving evaluation results. - gpt_evaluation_results: evaluations results for all the model answers. + gpt_evaluation_results: evaluations results for all of the model answers. save_path: path to save GPT evaluation statistics. """ @@ -647,7 +719,7 @@ def save_gpt_evaluation_statistics(model_name: str, evaluations: List[Dict], sav Args: model_name: name of the model for saving statistics. - evaluations: evaluations for all the model answers. + evaluations: evaluations for all of the model answers. save_path: path to save GPT evaluation statistics. """ @@ -669,7 +741,7 @@ def save_gpt_evaluation_statistics(model_name: str, evaluations: List[Dict], sav for evaluation in data: for metric in metrics: if evaluation["evaluation"][metric] == {}: - # This means after 3 retries, the server still returns an error, and we set the score to 0. + # This means after 3 retries, the server still returns an error and we set the score to 0. scores[metric].append(0) elif evaluation["evaluation"][metric]["logprobs"] is not None: scores[metric].append( diff --git a/applications/ColossalEval/colossal_eval/evaluate/utils.py b/applications/ColossalEval/colossal_eval/evaluate/utils.py new file mode 100644 index 000000000000..69fec46705ab --- /dev/null +++ b/applications/ColossalEval/colossal_eval/evaluate/utils.py @@ -0,0 +1,8 @@ +def get_data_per_category(data, categories): + data_per_category = {category: [] for category in categories} + for item in data: + category = item["category"] + if category in categories: + data_per_category[category].append(item) + + return data_per_category diff --git a/applications/ColossalEval/colossal_eval/models/__init__.py b/applications/ColossalEval/colossal_eval/models/__init__.py new file mode 100644 index 000000000000..8f6c9b414145 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/models/__init__.py @@ -0,0 +1,5 @@ +from .base import BaseModel +from .chatglm import ChatGLM2Model, ChatGLMModel +from .huggingface import HuggingFaceCausalLM, HuggingFaceModel + +__all__ = ["BaseModel", "HuggingFaceModel", "HuggingFaceCausalLM", "ChatGLMModel", "ChatGLM2Model"] diff --git a/applications/ColossalEval/colossal_eval/models/base.py b/applications/ColossalEval/colossal_eval/models/base.py new file mode 100644 index 000000000000..aae796c1d56e --- /dev/null +++ b/applications/ColossalEval/colossal_eval/models/base.py @@ -0,0 +1,78 @@ +from abc import abstractclassmethod +from typing import Dict, List + +from colossal_eval.utils import Conversation, prompt_templates + +from colossalai.logging import DistributedLogger + + +class BaseModel: + """ + Base class for model wrapper. 
+ + Args: + path: The path to the model. + model_max_length: The maximum sequence length of the model. + prompt_template: The model's prompt template. + batch_size: Batch size for inference. + logger: Logger for the model. + """ + + def __init__( + self, + path: str, + model_max_length: int = 2048, + prompt_template: Conversation = None, + batch_size: int = 1, + logger: DistributedLogger = None, + ): + self.path = path + self.model_max_length = model_max_length + + if prompt_template: + self.prompt_template = prompt_template + else: + self.prompt_template = prompt_templates["plain"] + + self.batch_size = batch_size + self.logger = logger + + @abstractclassmethod + def inference(self, data: List[Dict]) -> None: + """ + Infer the given data. + This function will call self.generate() to get model outputs and also self.model(input) to get logits. + + Args: + data: The data for inference. + """ + + @abstractclassmethod + def generate(self, inputs: List[str], max_new_tokens: int) -> List[str]: + """ + Generate results given a list of inputs. + + Args: + inputs: A list of strings. + max_new_tokens: The maximum length of the output. + + Returns: + A list of generated strings. + """ + + @abstractclassmethod + def get_loss(self, batch: List[str], batch_target: List[str]) -> List[float]: + """ + Get loss given batch and batch with target. + Use their length difference after tokenization to mask the loss and only compute loss at target tokens. + + Args: + batch: batch prompt without target answer. + batch_target: batch prompt with target answer. + + Returns: + A list of loss. + """ + + def to(self, device): + self.model.to(device) diff --git a/applications/ColossalEval/colossal_eval/models/chatglm.py b/applications/ColossalEval/colossal_eval/models/chatglm.py new file mode 100644 index 000000000000..f293c4f699cd --- /dev/null +++ b/applications/ColossalEval/colossal_eval/models/chatglm.py @@ -0,0 +1,303 @@ +import copy +from typing import List + +import torch + +from .huggingface import HuggingFaceModel + +IGNORE_INDEX = -100 + + +class ChatGLMModel(HuggingFaceModel): + def _get_truncated_prompts(self, inputs: List[str], max_new_tokens: int) -> List[str]: + truncated_inputs = copy.deepcopy(inputs) + # Adapted from https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py#L187 + for i, input in enumerate(inputs): + a_ids = self.tokenizer.encode(text=input, truncation=False, add_special_tokens=False) + + if len(a_ids) > self.model_max_length - max_new_tokens: + half = (self.model_max_length - max_new_tokens) // 2 + prompt = self.tokenizer.decode(a_ids[:half], skip_special_tokens=True) + self.tokenizer.decode( + a_ids[-half:], skip_special_tokens=True + ) + truncated_inputs[i] = prompt + + return truncated_inputs + + @torch.no_grad() + def get_loss( + self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool = False + ) -> List[List[float]]: + """ + Calculate loss only on target tokens. + + Args: + batch: A batch of prompt without target answer. + batch_target: A batch of target answer. Sometimes one question can have multiple target answers. + + Returns: + Loss. + + """ + + # We set max_new_tokens in self._get_truncated_prompts to 0 because we only need logits to calculate loss. + # We don't need to generate new tokens. + # Target answer's length is usually << model_max_length, but we still call it in case. + # We don't call self._get_truncated_prompts for batch_prompt because we need target answer's length first to reserve some space for target answer's tokens. 
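The middle-truncation rule referenced in these comments can be shown with a small standalone sketch. `truncate_middle` is a made-up helper that works on plain token-id lists; the real `_get_truncated_prompts` additionally re-decodes the two halves back to text with the tokenizer:

```python
def truncate_middle(token_ids, model_max_length, max_new_tokens):
    # Keep the head and the tail of the prompt; the middle is dropped so that
    # both the instruction at the start and the question at the end survive.
    budget = model_max_length - max_new_tokens
    if len(token_ids) <= budget:
        return token_ids
    half = budget // 2
    return token_ids[:half] + token_ids[-half:]

print(truncate_middle(list(range(20)), model_max_length=12, max_new_tokens=4))
# -> [0, 1, 2, 3, 16, 17, 18, 19]
```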
+ batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target] + + # Get the number of target answers for different questions + batch_target_nums = [len(prompt_target) for prompt_target in batch_target] + + labels_list = [] + input_ids_list = [] + + for input, targets in zip(batch_prompt, batch_target): + for target in targets: + # Adapted from https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py#L187 + # If there is no history, the prompt is just the query. + # We don't need to override self.generate() in ChatGLM-6B but need to override it in ChatGLM2-6B. + # See https://huggingface.co/THUDM/chatglm-6b/blob/main/modeling_chatglm.py#L1276 + target_tokenized = self.tokenizer.encode(text=target, add_special_tokens=False) + + # Get prompt with length model_max_length - len(target_tokenized). + # Reserve some space for target answer tokens using max_new_tokens. + # This will generate the correct start_idx and end_idx. + max_new_tokens = len(target_tokenized) + + # Here 3 tokens are reserved for [gmask_id, bos_token, eos_id]. So we reserve max_new_tokens + 3 tokens. + # See https://huggingface.co/THUDM/chatglm-6b/blob/main/tokenization_chatglm.py#L323 + prompt_with_correct_length = self._get_truncated_prompts([input], max_new_tokens + 3)[0] + input_tokenized = self.tokenizer.encode(prompt_with_correct_length, add_special_tokens=False) + + input_ids = self.tokenizer.build_inputs_with_special_tokens(input_tokenized, target_tokenized) + + context_length = input_ids.index(self.tokenizer.bos_token_id) + context_length - 1 + + target_ids = [IGNORE_INDEX] * len(input_ids) + + # -1 is for eos_token, we don't want to calculate loss on eos token. + target_ids[-max_new_tokens - 1 : -1] = input_ids[-max_new_tokens - 1 : -1] + + input_ids_list.append(torch.LongTensor(input_ids)) + labels_list.append(torch.LongTensor(target_ids)) + + # Because of multiple target answers, the final batch size may be greater than self.batch_size. + # We will generate new batches. + losses = [] + target_token_nums = [] + + batched_input_ids = [ + input_ids_list[i : i + self.batch_size] for i in range(0, len(input_ids_list), self.batch_size) + ] + batched_labels = [labels_list[i : i + self.batch_size] for i in range(0, len(labels_list), self.batch_size)] + + for batch_input_ids, batch_labels in zip(batched_input_ids, batched_labels): + losses_per_batch, target_token_num_per_batch = self._calculate_loss(batch_input_ids, batch_labels) + losses.extend(losses_per_batch) + target_token_nums.extend(target_token_num_per_batch) + + start_indice = 0 + losses_per_sample = [] + + target_token_nums_per_sample = [] + for length in batch_target_nums: + losses_per_sample.append(losses[start_indice : start_indice + length]) + target_token_nums_per_sample.append(target_token_nums[start_indice : start_indice + length]) + start_indice += length + + return losses_per_sample, target_token_nums_per_sample, None + + def _calculate_loss(self, input_ids_list: List[torch.LongTensor], labels: List[torch.LongTensor]) -> List[float]: + """ + Calculate loss only on target tokens. + Hugging Face generate() function can't return per sample loss. + It will only return the mean of the loss in a batch. + In torch.nn.CrossEntropyLoss(), reduction should be specified as "none" to get per sample loss. + + Args: + input_ids_list: A batch of input token ids. + labels: A batch of labels. + + Returns: + A list of loss. 
+ + """ + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id + ).to(torch.cuda.current_device()) + labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to( + torch.cuda.current_device() + ) + + outputs = self.model(input_ids)[0] + + shift_logits = outputs[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + loss_fct = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=IGNORE_INDEX) + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(shift_labels.size()) + + lens = (labels != IGNORE_INDEX).sum(-1).cpu().numpy() + + loss_sum = loss.sum(-1).to(torch.float32).cpu().detach().numpy() + return loss_sum.tolist(), lens.tolist() + + +class ChatGLM2Model(ChatGLMModel): + def _get_truncated_prompts(self, inputs: List[str], max_new_tokens: int) -> List[str]: + truncated_inputs = copy.deepcopy(inputs) + # Adapted from https://github.com/THUDM/ChatGLM2-6B/blob/main/ptuning/main.py#L180 + for i, input in enumerate(inputs): + a_ids = self.tokenizer.encode(text=input, add_special_tokens=True, truncation=False) + + if len(a_ids) > self.model_max_length - max_new_tokens: + half = (self.model_max_length - max_new_tokens) // 2 + prompt = self.tokenizer.decode(a_ids[:half], skip_special_tokens=True) + self.tokenizer.decode( + a_ids[-half:], skip_special_tokens=True + ) + truncated_inputs[i] = prompt + + return truncated_inputs + + @torch.no_grad() + def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str]: + """Generate results given a list of inputs and get logits of the first new token over choices. + + Args: + inputs: A list of strings. + max_new_tokens: Max new tokens for generation. + kwargs: Key arguments for generation + + Returns: + A list of generated strings and logits over choices. + + Note: + Currently the function only returns the logits of the first new token. + It is used for single choice question. + For multiple choices question, please avoid using the loss over choices. + You should set argument choices as None in self.inference(). + + """ + # Follow the process of model.chat() method in modeling_chatglm2.py + # See https://huggingface.co/THUDM/chatglm2-6b/blob/main/modeling_chatglm.py#L1020 + # See https://huggingface.co/THUDM/chatglm2-6b/blob/main/modeling_chatglm.py#L1001 + + query = [] + for input in inputs: + prompt = self.tokenizer.build_prompt(input, None) + query.append(prompt) + + truncated_query = self._get_truncated_prompts(query, max_new_tokens) + + encoded_inputs = self.tokenizer( + truncated_query, + padding=True, + truncation=True, + return_tensors="pt", + max_length=self.model_max_length - max_new_tokens, + ).to(torch.cuda.current_device()) + + # Set output_scores=True to get prediction scores. + outputs = self.model.generate( + **encoded_inputs, max_new_tokens=max_new_tokens, return_dict_in_generate=True, output_scores=True, **kwargs + ) + + # We only need to decode predicted tokens. + sequences = outputs.sequences[:, encoded_inputs["input_ids"].shape[1] :] + + scores = [] + if self.indices_for_choices: + # If the question is a single-choice question, we will return the scores of specific indices for first predicted token. + # The indices are the tokenization results of the options for the single-choice question. + # For example, if the options of the question are A, B, C and D, we only returns scores at indices of A, B, C and D. 
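A hedged illustration of this choice-scoring step: `first_step_logits` stands in for `outputs.scores[0]`, and the choice ids are the Llama-2 example ids quoted in `_get_choices_indices` (real ids always come from the tokenizer at runtime):

```python
import torch

first_step_logits = torch.randn(2, 65024)               # dummy (batch, vocab) logits
indices_en = [319, 350, 315, 360]                        # example ids for A, B, C, D (English context)
indices_zh = [29909, 29933, 29907, 29928]                # example ids for A, B, C, D (Chinese context)

per_context = [first_step_logits[:, idx] for idx in (indices_en, indices_zh)]
scores = torch.max(torch.stack(per_context), dim=0)[0]   # (batch, 4), best score over both contexts
probs = torch.softmax(scores.float(), dim=-1)            # softmax over the four choices
print(probs)
```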
+ for option_indices in self.indices_for_choices: + scores.append(outputs.scores[0][:, option_indices].detach().cpu()) + + scores = torch.max(torch.stack(scores), dim=0)[0] + + decoded_sequences = self.tokenizer.batch_decode(sequences, skip_special_tokens=True) + + return decoded_sequences, scores + + @torch.no_grad() + def get_loss( + self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool = False + ) -> List[List[float]]: + """ + Calculate loss only on target tokens. + + Args: + batch: A batch of prompt without target answer. + batch_target: A batch of target answer. Sometimes one question can have multiple target answers. + + Returns: + Loss. + + """ + + # We set max_new_tokens in self._get_truncated_prompts to 0 because we only need logits to calculate loss. + # We don't need to generate new tokens. + # Target answer's length is usually << model_max_length, but we still call it in case. + # We don't call self._get_truncated_prompts for batch_prompt because we need target answer's length first to reserve some space for target answer's tokens. + batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target] + + # Get the number of target answers for different questions + batch_target_nums = [len(prompt_target) for prompt_target in batch_target] + + labels_list = [] + input_ids_list = [] + + for input, targets in zip(batch_prompt, batch_target): + for target in targets: + # Adapted from https://github.com/THUDM/ChatGLM2-6B/blob/main/ptuning/main.py#L180 + prompt = self.tokenizer.build_prompt(input, None) + + target_tokenized = self.tokenizer.encode( + text=target, add_special_tokens=False, truncation=True, max_length=self.model_max_length + ) + + max_new_tokens = len(target_tokenized) + prompt_with_correct_length = self._get_truncated_prompts([prompt], max_new_tokens)[0] + input_tokenized = self.tokenizer.encode( + prompt_with_correct_length, + add_special_tokens=True, + truncation=True, + max_length=self.model_max_length, + ) + + input_ids = input_tokenized + target_tokenized + [self.tokenizer.eos_token_id] + target_ids = [IGNORE_INDEX] * len(input_ids) + + # -1 is for "eos" + target_ids[-max_new_tokens - 1 : -1] = input_ids[-max_new_tokens - 1 : -1] + + input_ids_list.append(torch.LongTensor(input_ids)) + labels_list.append(torch.LongTensor(target_ids)) + + # Because of multiple target answers, the final batch size may be greater than self.batch_size. + # We will generate new batches. 
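The re-batching mentioned here is a plain chunking of the flattened list of (prompt, target) pairs back into model-sized batches; a tiny self-contained sketch with invented values:

```python
input_ids_list = list(range(7))   # stands in for 7 tokenized (prompt + target) pairs
batch_size = 3
batches = [input_ids_list[i:i + batch_size] for i in range(0, len(input_ids_list), batch_size)]
print(batches)                    # [[0, 1, 2], [3, 4, 5], [6]]
```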
+ losses = [] + target_token_nums = [] + + batched_input_ids = [ + input_ids_list[i : i + self.batch_size] for i in range(0, len(input_ids_list), self.batch_size) + ] + batched_labels = [labels_list[i : i + self.batch_size] for i in range(0, len(labels_list), self.batch_size)] + + for batch_input_ids, batch_labels in zip(batched_input_ids, batched_labels): + losses_per_batch, target_token_num_per_batch = self._calculate_loss(batch_input_ids, batch_labels) + losses.extend(losses_per_batch) + target_token_nums.extend(target_token_num_per_batch) + + start_indice = 0 + losses_per_sample = [] + + target_token_nums_per_sample = [] + for length in batch_target_nums: + losses_per_sample.append(losses[start_indice : start_indice + length]) + target_token_nums_per_sample.append(target_token_nums[start_indice : start_indice + length]) + start_indice += length + + return losses_per_sample, target_token_nums_per_sample, None diff --git a/applications/ColossalEval/colossal_eval/models/huggingface.py b/applications/ColossalEval/colossal_eval/models/huggingface.py new file mode 100644 index 000000000000..9f785a6aa9d1 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/models/huggingface.py @@ -0,0 +1,561 @@ +import copy +import math +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +from colossal_eval.utils import Conversation, get_batch_prompt, is_rank_0 +from peft import PeftModel +from tqdm import tqdm +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer + +from colossalai.logging import DistributedLogger + +from .base import BaseModel + +IGNORE_INDEX = -100 + + +class HuggingFaceModel(BaseModel): + """ + Model wrapper around HuggingFace AutoModel models. + + Args: + path: The path to a HuggingFace model. + model_max_length: The maximum sequence length of the model. + tokenizer_path: The path to the tokenizer. + tokenizer_kwargs: Keyword arguments for the tokenizer. + peft_path: The name or path to the HuggingFace's PEFT model. + model_kwargs: Keyword arguments for the model. + prompt_template: The model's prompt template. + batch_size: Batch size for inference. + logger: Logger for the model. + + """ + + def __init__( + self, + path: str, + model_max_length: int = 2048, + tokenizer_path: Optional[str] = None, + tokenizer_kwargs: dict = dict(), + peft_path: Optional[str] = None, + model_kwargs: Dict = None, + prompt_template: Conversation = None, + batch_size: int = 1, + logger: DistributedLogger = None, + ): + super().__init__( + path=path, + model_max_length=model_max_length, + prompt_template=prompt_template, + batch_size=batch_size, + logger=logger, + ) + self._load_tokenizer(path=path, tokenizer_path=tokenizer_path, tokenizer_kwargs=tokenizer_kwargs) + + self._load_model(path=path, model_kwargs=model_kwargs, peft_path=peft_path) + + def _get_choices_indices(self, language: str): + """ + Get indices for each choice + + Some tokenizer will insert BOS if you don't specify add_special_tokens=False such as Llama-2. + The indices for choices may be different given the context. For example, for Llama-2 tokenizer, for Chinese context like "答案:{choice}", indices for choices A, B, C and D are 29909, 29933, 29907 and 29928, for English context like "Answer: {choice}", indices for choices A, B, C and D are 319, 350, 315 and 360. + print(self.tokenizer("答案:A")) to see + print(self.tokenizer("Answer: A")) to see + + """ + + # A trick for get "all" tokens ids related to given choices. 
+ self.indices_for_choices = [[] for _ in range(2)] + for choice in self.choices: + self.indices_for_choices[0].append( + self.tokenizer(f"Answer: {choice}", add_special_tokens=False).input_ids[-1] + ) + self.indices_for_choices[1].append(self.tokenizer(f"答案:{choice}", add_special_tokens=False).input_ids[-1]) + + def _load_tokenizer(self, path: str, tokenizer_path: Optional[str], tokenizer_kwargs: dict): + """ + Load tokenizer. + + Args: + path: The path to the model. Usually it also serves as the path to the tokenizer. + tokenizer_path: The path to the tokenzier. + tokenizer_kwargs: Keyword arguments for the tokenizer. + + """ + + if self.batch_size > 1: + tokenizer_kwargs.update({"padding_side": "left"}) + tokenizer_kwargs.update({"truncation_side": "left"}) + + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path if tokenizer_path else path, **tokenizer_kwargs) + + if self.tokenizer.pad_token_id is None: + self.logger.warning("pad_token_id is not set for the tokenizer. " "Using eos_token_id as pad_token_id.") + if self.tokenizer.eos_token: + self.tokenizer.pad_token = self.tokenizer.eos_token + elif self.tokenizer.eod_id: + # Qwen has an eod token "<|endoftext|>". + self.tokenizer.pad_token_id = self.tokenizer.eod_id + + def _load_model(self, path: str, model_kwargs: dict, peft_path: Optional[str] = None): + """ + Load model. + + Args: + path: The path to the model. + model_kwargs: Keyword arguments for the model. + peft_path: The path to the peft model. + + """ + + if "torch_dtype" in model_kwargs: + model_kwargs["torch_dtype"] = eval(model_kwargs["torch_dtype"]) + + model_kwargs.setdefault("torch_dtype", torch.float16) + + self.model = AutoModel.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device()) + if peft_path is not None: + self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False) + self.model.eval() + + def _calculate_loss(self, input_ids_list: List[torch.LongTensor], labels: List[torch.LongTensor]) -> Tuple[List]: + """ + Calculate loss only on target tokens. + Hugging Face generate() function can't return per sample loss. + It will only return the mean of the loss in a batch. + In torch.nn.CrossEntropyLoss(), reduction should be specified as "none" to get per sample loss. + + Args: + input_ids_list: A batch of input token ids. + labels: A batch of labels. + + Returns: + A list of loss. 
+ + """ + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id + ).to(torch.cuda.current_device()) + labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to( + torch.cuda.current_device() + ) + attention_mask = input_ids.ne(self.tokenizer.pad_token_id).to(torch.cuda.current_device()) + + outputs = self.model(input_ids, attention_mask=attention_mask)[0] + + shift_logits = outputs[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + loss_fct = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=IGNORE_INDEX) + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(shift_labels.size()) + + lens = (labels != IGNORE_INDEX).sum(-1).cpu().numpy() + + loss_sum = loss.sum(-1).to(torch.float32).cpu().detach().numpy() + return loss_sum.tolist(), lens.tolist() + + def _get_truncated_prompts(self, inputs: List[str], max_new_tokens: int) -> List[str]: + """ + Truncate the input sequence to fit model_max_length (we suggest truncate in the middle, since the left and right side may contain crucial instructions) + https://github.com/THUDM/LongBench/blob/main/pred.py#L16 + + Args: + inputs: A batch of input prompts. + max_new_tokens: Max new tokens for model to generate. + + Returns: + Truncated prompts. + + """ + + truncated_inputs = copy.deepcopy(inputs) + for i, input in enumerate(inputs): + tokenized_prompt = self.tokenizer(input, truncation=False, return_tensors="pt").input_ids[0] + if len(tokenized_prompt) > self.model_max_length - max_new_tokens: + half = (self.model_max_length - max_new_tokens) // 2 + prompt = self.tokenizer.decode( + tokenized_prompt[:half], skip_special_tokens=True + ) + self.tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True) + truncated_inputs[i] = prompt + + return truncated_inputs + + def _get_input_ids_and_labels_pretrain(self, batch_prompt: List[str]) -> Tuple[List[torch.LongTensor]]: + """ + Get input_ids and labels for pretrain data. + We only need batch_prompt because for pretain dataset, we don't need to predict new tokens. + + Args: + batch_prompt: A batch of prompt. + + Returns: + Input_ids and labels for the given batch. + + """ + input_ids_list = [] + labels_list = [] + bytes_list = [] + + for input in batch_prompt: + # Pretrain data tends to be very long, sometimes much larger than the model_max_length, we only tokenize 1/ratio of the data first to accelerate the tokenization process. + # Once the length of the result is greater or equal to model_max_length, we stop iterating on ratios and use the result as input_ids and labels. + # After all, the rest of the original string doesn't need to be tokenized at the first place. 
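A rough, self-contained sketch of this incremental-tokenization trick; `fake_tokenize` is an invented stand-in for `self.tokenizer`, which in the real code is called with `truncation=True` and `max_length=self.model_max_length`:

```python
def fake_tokenize(text):
    return text.split()

document = "word " * 100_000          # a very long pretrain sample
model_max_length = 2048

for r in (16, 8, 4, 2, 1):
    # Tokenize only 1/r of the raw text; stop as soon as the context is full.
    tokens = fake_tokenize(document[: len(document) // r])[:model_max_length]
    if len(tokens) >= model_max_length:
        break
print(r, len(tokens))                 # 1/16 of the text already fills the context here
```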
+ ratio = [16, 8, 4, 2, 1] + tokenized = None + for r in ratio: + tokenized = self.tokenizer( + [input[0 : len(input) // r]], truncation=True, max_length=self.model_max_length, return_tensors="pt" + ) + if tokenized.input_ids.size(1) >= self.model_max_length: + break + + input_ids = copy.deepcopy(tokenized["input_ids"])[0] + target_ids = copy.deepcopy(input_ids) + + string = self.tokenizer.decode(tokenized.input_ids[0], skip_special_tokens=True) + + bytes_list.append(len(string.encode("utf-8"))) + + input_ids_list.append(input_ids) + labels_list.append(target_ids) + + return input_ids_list, labels_list, bytes_list + + def _get_input_ids_and_labels( + self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool + ) -> Tuple[List[torch.LongTensor]]: + """ + Get input_ids and labels for the given data. + + Args: + batch_prompt: A batch of prompt. + batch_target: A batch of target. + + Returns: + Input_ids and labels for the given batch. + + """ + if pretrain: + return self._get_input_ids_and_labels_pretrain(batch_prompt) + + input_ids_list = [] + labels_list = [] + + for input, targets in zip(batch_prompt, batch_target): + for target in targets: + # TODO: Improve the labeling process. Should annotate the border by adding special tokens. + target_tokenized = self.tokenizer( + [target], truncation=True, max_length=self.model_max_length, return_tensors="pt" + ) + + # Get prompt with length model_max_length - len(target_tokenized). + # Reserve some space for target answer tokens using max_new_tokens. + # This will generate the correct start_idx and end_idx. + max_new_tokens = target_tokenized["input_ids"][0].size(0) + prompt_with_correct_length = self._get_truncated_prompts([input], max_new_tokens)[0] + input_tokenized = self.tokenizer( + [prompt_with_correct_length], + truncation=True, + max_length=self.model_max_length - max_new_tokens, + return_tensors="pt", + ) + + target_tokenized = self.tokenizer( + [prompt_with_correct_length + target], + truncation=True, + max_length=self.model_max_length, + return_tensors="pt", + ) + + start_idx = input_tokenized["input_ids"][0].size(0) + end_idx = target_tokenized["input_ids"][0].size(0) + + # Sometimes if the target is only an option such as A, B, C and D, the length of input_tokenized is equal to the length of target_tokenized, so we need -1. + # This is caused by the different behavior of tokenizers. + # For example, the tokenizer for Baichuan and Llama will cause such problem in a plain prompt setting. + # The length of the tokenized sequences for prompt "Answer: " and "Answer: A" is the same. + # Baichuan: [29394, 31143, 31106] [29394, 31143, 703] + # Llama: [673, 29901, 29871] [673, 29901, 319] + # The length for sequence "prompt" and "prompt + A" is equal. + # For ChatGLM, the length of the tokenized sequences is different. + # ChatGLM: [16583, 12] [16583, 12, 167] + + if start_idx == end_idx: + start_idx -= 1 + + input_ids = copy.deepcopy(target_tokenized["input_ids"])[0] + target_ids = copy.deepcopy(input_ids) + + mask = torch.zeros_like(target_ids, dtype=torch.bool) + mask[start_idx:end_idx] = True + + target_ids[~mask] = IGNORE_INDEX + + input_ids_list.append(input_ids) + labels_list.append(target_ids) + + return input_ids_list, labels_list, None + + def inference(self, data: List[Dict], inference_kwargs: Dict[str, Any], debug: bool = False) -> List[Dict]: + """ + Infer the given data. + This function will call self.generate() to get model outputs and also self.model() to get logits. + + Args: + data: The data for inference. 
+ inference_kwargs: Arguments for inference. + debug: Whether to display generated prompt for debugging. + + Returns: + Inference results. + + """ + calculate_loss = inference_kwargs["calculate_loss"] + classes = inference_kwargs["all_classes"] + language = inference_kwargs["language"] + pretrain = inference_kwargs["pretrain"] + max_new_tokens = inference_kwargs["max_new_tokens"] + few_shot_data = inference_kwargs.get("few_shot_data", None) + + # Some classification questions' options are texts not a single letter such as A, B, C and D. + # If the text length is greater than 1, we won't calculate loss over choices. + if classes is not None and any(len(c) > 1 for c in classes): + classes = None + + self.choices = classes + self.indices_for_choices = None + if self.choices: + # Get indices for each choice + self._get_choices_indices(language) + + self.str_label_map = {choice: idx for idx, choice in enumerate(self.choices)} + + bar = tqdm( + range(math.ceil(len(data) / self.batch_size)), + desc=f"{data[0]['dataset']}-{data[0]['category']} Inference steps", + disable=not is_rank_0(), + ) + loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + + answers = copy.deepcopy(data) + for i in range(0, len(data), self.batch_size): + batch = data[i : i + self.batch_size] + batch_prompt, batch_target = get_batch_prompt( + self.prompt_template, batch, few_shot_data, self.tokenizer, language, self.model_max_length + ) + + if is_rank_0() and debug and i == 0: + self.logger.info( + f"Inference arguments for dataset {data[0]['dataset']} category {data[0]['category']} is:\n{inference_kwargs}" + ) + self.logger.info("-" * 120) + self.logger.info("An example prompt and prompt with target is:") + self.logger.info("-" * 120) + self.logger.info(batch_prompt[0]) + self.logger.info("-" * 120) + self.logger.info(batch_prompt[0] + batch_target[0][0]) + + if not pretrain: + batch_decodes, scores = self.generate(batch_prompt, max_new_tokens) + + if calculate_loss: + batch_losses, batch_target_token_nums, batch_bytes_nums = self.get_loss( + batch_prompt, batch_target, pretrain + ) + + probs = [] + if self.indices_for_choices: + scores = scores.to(torch.float32) + # If we have indices_for_choices(must be single-choice question), there will be only one target answer for one data sample. + # Otherwise this will violate the single-choice setting. + + if calculate_loss: + labels = [self.str_label_map[answers[i + j]["target"]] for j in range(len(batch_decodes))] + + loss_over_choices = loss_fct(scores, torch.tensor(labels, dtype=torch.long)).numpy().tolist() + + probs = torch.nn.functional.softmax(scores, dim=-1).numpy().tolist() + probs = [ + {choice: probs[i][self.str_label_map[choice]] for choice in self.choices} for i in range(len(probs)) + ] + + for j in range(len(batch_prompt)): + if not pretrain: + answers[i + j]["output"] = batch_decodes[j].strip() + + if isinstance(scores, torch.Tensor): + answers[i + j]["softmax_over_choices"] = probs[j] + + if calculate_loss: + answers[i + j]["loss_over_choices"] = loss_over_choices[j] + + if calculate_loss: + answers[i + j]["loss"] = (np.array(batch_losses[j]) / np.array(batch_target_token_nums[j])).tolist() + + # loss_sum is specially used for pertrain dataset for calculating per-byte-perplexity. + # However, loss (which is per sample loss) suffices for most cases. 
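A small numeric sketch of the two loss fields stored per answer here; the values are invented, and the division mirrors how `answers[i + j]["loss"]` is filled in below (`loss_sum` is kept separately for per-byte perplexity on pretrain data):

```python
import numpy as np

loss_sum  = np.array([42.0, 15.5])      # dummy summed losses for two target candidates
token_num = np.array([60, 25])          # dummy numbers of target tokens
print((loss_sum / token_num).tolist())  # per-token loss, i.e. what goes into answers[...]["loss"]
```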
+ answers[i + j]["loss_sum"] = batch_losses[j]
+ answers[i + j]["token_num"] = batch_target_token_nums[j]
+
+ if batch_bytes_nums:
+ answers[i + j]["byte_num"] = batch_bytes_nums[j]
+
+ bar.update()
+
+ return answers
+
+ @torch.no_grad()
+ def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str]:
+ """Generate results given a list of inputs and get logits of the first new token over choices.
+
+ Args:
+ inputs: A list of strings.
+ max_new_tokens: Max new tokens for generation.
+ kwargs: Keyword arguments for generation.
+
+ Returns:
+ A list of generated strings and logits over choices.
+
+ Note:
+ Currently the function only returns the logits of the first new token.
+ It is used for single-choice questions.
+ For multiple-choice questions, please avoid using the loss over choices.
+ You should set argument choices as None in self.inference().
+
+ """
+ truncated_inputs = self._get_truncated_prompts(inputs, max_new_tokens)
+
+ encoded_inputs = self.tokenizer(
+ truncated_inputs,
+ padding=True,
+ truncation=True,
+ return_tensors="pt",
+ return_token_type_ids=False,
+ max_length=self.model_max_length - max_new_tokens,
+ ).to(torch.cuda.current_device())
+
+ # Set output_scores=True to get prediction scores.
+ outputs = self.model.generate(
+ **encoded_inputs, max_new_tokens=max_new_tokens, return_dict_in_generate=True, output_scores=True, **kwargs
+ )
+
+ # We only need to decode predicted tokens.
+ sequences = outputs.sequences[:, encoded_inputs["input_ids"].shape[1] :]
+
+ scores = []
+ if self.indices_for_choices:
+ # If the question is a single-choice question, we will return the scores of specific indices for the first predicted token.
+ # The indices are the tokenization results of the options for the single-choice question.
+ # For example, if the options of the question are A, B, C and D, we only return scores at the indices of A, B, C and D.
+ for option_indices in self.indices_for_choices:
+ scores.append(outputs.scores[0][:, option_indices].detach().cpu())
+
+ scores = torch.max(torch.stack(scores), dim=0)[0]
+
+ decoded_sequences = self.tokenizer.batch_decode(sequences, skip_special_tokens=True)
+
+ return decoded_sequences, scores
+
+ @torch.no_grad()
+ def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool) -> List[List[float]]:
+ """
+ Calculate loss only on target tokens.
+
+ Args:
+ batch_prompt: A batch of prompts without target answers.
+ batch_target: A batch of target answers. Sometimes one question can have multiple target answers.
+
+ Returns:
+ Loss.
+
+ """
+
+ # We set max_new_tokens in self._get_truncated_prompts to 0 because we only need logits to calculate loss.
+ # We don't need to generate new tokens.
+ # Target answer's length is usually << model_max_length, but we still call it just in case.
+ # We don't call self._get_truncated_prompts for batch_prompt because we need the target answer's length first to reserve some space for the target answer's tokens.
+ if not pretrain:
+ batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target]
+
+ # Get the number of target answers for different questions
+ batch_target_nums = [len(prompt_target) for prompt_target in batch_target]
+
+ input_ids_list, labels_list, bytes_list = self._get_input_ids_and_labels(batch_prompt, batch_target, pretrain)
+
+ # Because of multiple target answers, the final batch size may be greater than self.batch_size.
+ # We will generate new batches.
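+ # For example, with self.batch_size = 4 and two target answers for each of four prompts,
+ # input_ids_list holds 8 sequences and is re-chunked below into two batches of 4.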
+ losses = [] + target_token_nums = [] + + batched_input_ids = [ + input_ids_list[i : i + self.batch_size] for i in range(0, len(input_ids_list), self.batch_size) + ] + batched_labels = [labels_list[i : i + self.batch_size] for i in range(0, len(labels_list), self.batch_size)] + + for batch_input_ids, batch_labels in zip(batched_input_ids, batched_labels): + losses_per_batch, target_token_num_per_batch = self._calculate_loss(batch_input_ids, batch_labels) + losses.extend(losses_per_batch) + target_token_nums.extend(target_token_num_per_batch) + + start_indice = 0 + losses_per_sample = [] + + target_token_nums_per_sample = [] + bytes_nums_per_sample = [] + for length in batch_target_nums: + losses_per_sample.append(losses[start_indice : start_indice + length]) + target_token_nums_per_sample.append(target_token_nums[start_indice : start_indice + length]) + + if bytes_list: + bytes_nums_per_sample.append(bytes_list[start_indice : start_indice + length]) + + start_indice += length + + if bytes_list: + return losses_per_sample, target_token_nums_per_sample, bytes_nums_per_sample + + return losses_per_sample, target_token_nums_per_sample, None + + +class HuggingFaceCausalLM(HuggingFaceModel): + """ + Model wrapper around HuggingFace AutoModelForCausalLM models. + + Args: + path: The path to a HuggingFace model. + model_max_length: The maximum sequence length of the model. + tokenizer_path: The path to the tokenizer. + tokenizer_kwargs: Keyword arguments for the tokenizer. + peft_path: The name or path to the HuggingFace's PEFT model. + model_kwargs: Keyword arguments for the model. + prompt_template: The model's prompt template. + batch_size: Batch size for inference. + logger: Logger for the model. + + """ + + def _load_model(self, path: str, model_kwargs: dict, peft_path: Optional[str] = None): + """ + Load model. + + Args: + path: The path to the model. + model_kwargs: Keyword arguments for the model. + peft_path: The path to the peft model. 
+ + """ + + if "torch_dtype" in model_kwargs: + model_kwargs["torch_dtype"] = eval(model_kwargs["torch_dtype"]) + + if "config" in model_kwargs: + model_kwargs["config"] = AutoConfig.from_pretrained(model_kwargs["config"]) + + model_kwargs.setdefault("torch_dtype", torch.float16) + self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device()) + if peft_path is not None: + self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False) + self.model.eval() diff --git a/applications/ColossalEval/colossal_eval/utils/__init__.py b/applications/ColossalEval/colossal_eval/utils/__init__.py new file mode 100644 index 000000000000..d5ee6e13b747 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/utils/__init__.py @@ -0,0 +1,4 @@ +from .conversation import Conversation, get_batch_prompt, prompt_templates +from .utilities import get_json_list, is_rank_0, jdump, jload + +__all__ = ["Conversation", "prompt_templates", "get_batch_prompt", "is_rank_0", "jload", "jdump", "get_json_list"] diff --git a/applications/ColossalEval/colossal_eval/utils/conversation.py b/applications/ColossalEval/colossal_eval/utils/conversation.py new file mode 100644 index 000000000000..6c096a8523c0 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/utils/conversation.py @@ -0,0 +1,231 @@ +import dataclasses +from enum import Enum, auto +from typing import Dict, List, Optional, Tuple + +from transformers import AutoTokenizer + + +class SeparatorStyle(Enum): + ADD_BOS_EOS_TOKEN = auto() + ALPACA = auto() + PLAIN = auto() + + +@dataclasses.dataclass +class Conversation: + system: str + roles: List[str] + messages: List[List[str]] + offset: int + sep_style: SeparatorStyle = SeparatorStyle.ADD_BOS_EOS_TOKEN + sep: str = "" + + def clear(self): + self.messages = [] + + def get_prompt(self): + if self.sep_style == SeparatorStyle.ADD_BOS_EOS_TOKEN: + ret = self.system + for role, message in self.messages: + if message: + ret += role + ": " + "" + message + self.sep + else: + ret += role + ": " + "" + return ret + elif self.sep_style == SeparatorStyle.ALPACA: + ret = self.system + self.sep + for role, message in self.messages: + if message: + ret += role + ":\n" + message + self.sep + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.PLAIN: + ret = self.system + for role, message in self.messages: + if message: + ret += message + else: + ret += "" + return ret + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + def get_prompt_with_target(self, target): + prompt = self.get_prompt() + prompt_with_target = [] + + # Some dataset provides multiple target answers. + # This will make it difficult when we calculate loss. + # We convert target into list[str] first if the question only has one target answer. 
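+ # For example, a target of "A" becomes ["A"], while a list such as ["A", "B"] is kept as-is.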
+ target_answers = []
+ if isinstance(target, str):
+ target_answers = [target]
+ else:
+ target_answers = target
+
+ for target_answer in target_answers:
+ if self.sep_style == SeparatorStyle.ADD_BOS_EOS_TOKEN:
+ prompt_with_target.append(prompt + target_answer)
+ elif self.sep_style == SeparatorStyle.ALPACA:
+ prompt_with_target.append(prompt + target_answer)
+ elif self.sep_style == SeparatorStyle.PLAIN:
+ prompt_with_target.append(prompt + target_answer)
+ else:
+ raise ValueError(f"Invalid style: {self.sep_style}")
+
+ return prompt_with_target
+
+ def save_prompt(self):
+ if self.sep_style == SeparatorStyle.ADD_BOS_EOS_TOKEN:
+ ret = self.system
+ for role, message in self.messages:
+ if message:
+ ret += role + ": " + "" + message + "\n"
+ else:
+ ret += role + ": " + ""
+ return ret
+ else:
+ raise ValueError(f"Invalid style: {self.sep_style}")
+
+ def append_message(self, role, message):
+ self.messages.append([role, message])
+
+ def copy(self):
+ return Conversation(
+ system=self.system,
+ roles=self.roles,
+ messages=[[x, y] for x, y in self.messages],
+ offset=self.offset,
+ sep_style=self.sep_style,
+ sep=self.sep,
+ )
+
+ def dict(self):
+ return {
+ "system": self.system,
+ "roles": self.roles,
+ "messages": self.messages,
+ "offset": self.offset,
+ "sep_style": self.sep_style,
+ "sep": self.sep,
+ }
+
+
+ def get_few_shot_prefix(
+ conv: Conversation, few_shot_data: List[str], tokenizer: Optional[AutoTokenizer], language: str, max_tokens: int
+ ) -> str:
+ """
+ Get the few-shot prefix.
+
+ Args:
+ conv: Conversation template.
+ few_shot_data: Few-shot examples used to generate the few-shot prompt prefix.
+ tokenizer: The tokenizer used to measure the prefix length.
+ language: The language of the few-shot prefix.
+ max_tokens: The maximum number of tokens the prefix may occupy.
+
+ Returns:
+ Few-shot prompt prefix.
+ """
+
+ if language == "English":
+ few_shot_prefix = f"The following are answers for questions in an exam.\n\n"
+ elif language == "Chinese":
+ few_shot_prefix = f"以下是考试中各个问题的答案。\n\n"
+
+ output = None
+ for i in range(len(few_shot_data)):
+ few_shot_prefix = few_shot_prefix + few_shot_data[i] + "\n\n"
+
+ if len(tokenizer([few_shot_prefix]).input_ids[0]) <= max_tokens:
+ output = few_shot_prefix
+ else:
+ break
+
+ return output if output is not None else few_shot_prefix
+
+
+ def get_batch_prompt(
+ conv: Conversation,
+ batch: List[Dict],
+ few_shot_data: List[str],
+ tokenizer: Optional[AutoTokenizer],
+ language: Optional[str],
+ model_max_length: Optional[int],
+ ) -> Tuple[List[Dict], List[Dict]]:
+ """
+ Get batch prompt and target.
+
+ Args:
+ conv: Conversation template.
+ batch: Batch data to generate prompt from.
+ few_shot_data: Few-shot data to generate the few-shot prompt prefix.
+ tokenizer: The tokenizer used to measure prompt length when building the few-shot prefix.
+ language: The language of the few-shot prefix.
+ model_max_length: The maximum sequence length of the model.
+
+ Returns:
+ Tuple containing batch prompt and target.
+
+ """
+
+ batch_prompt = []
+ batch_target = []
+
+ if isinstance(batch[0], dict):
+ for b in batch:
+ few_shot_prefix = ""
+ if few_shot_data is not None:
+ # For few-shot, only the input is needed. Otherwise, use the instruction (as in AGIEval).
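+ # max_tokens below is the space left after the zero-shot prompt and target are tokenized,
+ # so the few-shot prefix returned by get_few_shot_prefix() still fits within model_max_length.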
+ query_text = b["input"] if b.get("input", "") != "" else b["instruction"] + + if isinstance(b["target"], str): + zero_shot_prompt = query_text + b["target"] + max_tokens = model_max_length - len(tokenizer([zero_shot_prompt]).input_ids[0]) + else: + raise Exception("When using few-shot, target answer should be a string.") + + few_shot_prefix = get_few_shot_prefix(conv, few_shot_data, tokenizer, language, max_tokens) + else: + query_text = b["instruction"] + "\n\n" + b["input"] if b.get("input", "") != "" else b["instruction"] + + conv.append_message(conv.roles[0], few_shot_prefix + query_text) + conv.append_message(conv.roles[1], None) + + batch_prompt.append(conv.get_prompt()) + + target = b["target"] + if isinstance(b["target"], str): + target = [target] + + batch_target.append(target) + + conv.clear() + + return batch_prompt, batch_target + + +conv_coati = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + roles=("Human", "Assistant"), + messages=[], + offset=0, + sep_style=SeparatorStyle.ADD_BOS_EOS_TOKEN, + sep="", +) + +conv_alpaca = Conversation( + system="Below is an instruction that describes a task. Write a response that appropriately completes the request.", + roles=("### Instruction", "### Response"), + messages=[], + offset=0, + sep_style=SeparatorStyle.ALPACA, + sep="\n\n", +) + +conv_plain = Conversation( + system="", + roles=("", ""), + messages=[], + offset=0, + sep_style=SeparatorStyle.PLAIN, + sep="", +) + +prompt_templates = {"coati": conv_coati, "alpaca": conv_alpaca, "plain": conv_plain} diff --git a/applications/ColossalEval/colossal_eval/utils/utilities.py b/applications/ColossalEval/colossal_eval/utils/utilities.py new file mode 100644 index 000000000000..4eda07907495 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/utils/utilities.py @@ -0,0 +1,62 @@ +import io +import json +import os + +import torch.distributed as dist + + +def is_rank_0() -> bool: + return not dist.is_initialized() or dist.get_rank() == 0 + + +def _make_w_io_base(f, mode: str): + if not isinstance(f, io.IOBase): + f_dirname = os.path.dirname(f) + if f_dirname != "": + os.makedirs(f_dirname, exist_ok=True) + f = open(f, mode=mode, encoding="utf-8") + return f + + +def _make_r_io_base(f, mode: str): + if not isinstance(f, io.IOBase): + f = open(f, mode=mode, encoding="utf-8") + return f + + +def jdump(obj, f, mode="w", indent=4, default=str): + """ + Dump a str or dictionary to a file in json format. + + Args: + obj: An object to be written. + f: A string path to the location on disk. + mode: Mode for opening the file. + indent: Indent for storing json dictionaries. + default: A function to handle non-serializable entries; defaults to `str`. 
+ + """ + f = _make_w_io_base(f, mode) + if isinstance(obj, (dict, list)): + json.dump(obj, f, indent=indent, default=default, ensure_ascii=False) + elif isinstance(obj, str): + f.write(obj) + else: + raise ValueError(f"Unexpected type: {type(obj)}") + f.close() + + +def jload(f, mode="r"): + """Load a .json file into a dictionary.""" + f = _make_r_io_base(f, mode) + jdict = json.load(f) + f.close() + return jdict + + +def get_json_list(file_path): + with open(file_path, "r") as f: + json_list = [] + for line in f: + json_list.append(json.loads(line if line != "null" else line)) + return json_list diff --git a/applications/ColossalEval/configs/gpt_evaluation/config/config_cn.json b/applications/ColossalEval/configs/gpt_evaluation/config/config_cn.json new file mode 100644 index 000000000000..d7c864881008 --- /dev/null +++ b/applications/ColossalEval/configs/gpt_evaluation/config/config_cn.json @@ -0,0 +1,44 @@ +{ + "language": "cn", + "category": { + "brainstorming": { + "GPT": [ + "language organization", + "relevance", + "creativity", + "practicality", + "reasonableness" + ] + }, + "chat": { + "GPT": [ + "language organization", + "naturalness", + "engagingness", + "fidelity" + ] + }, + "generation": { + "GPT": [ + "language organization", + "relevance", + "diversity" + ] + }, + "open_qa": { + "GPT": [ + "language organization", + "relevance", + "correctness" + ] + }, + "roleplay": { + "GPT": [ + "language organization", + "relevance", + "fidelity", + "creativity" + ] + } + } +} diff --git a/applications/ColossalEval/configs/gpt_evaluation/config/config_en.json b/applications/ColossalEval/configs/gpt_evaluation/config/config_en.json new file mode 100644 index 000000000000..6ebe3996b1cf --- /dev/null +++ b/applications/ColossalEval/configs/gpt_evaluation/config/config_en.json @@ -0,0 +1,44 @@ +{ + "language": "en", + "category": { + "brainstorming": { + "GPT": [ + "language organization", + "relevance", + "creativity", + "practicality", + "reasonableness" + ] + }, + "chat": { + "GPT": [ + "language organization", + "naturalness", + "engagingness", + "fidelity" + ] + }, + "generation": { + "GPT": [ + "language organization", + "relevance", + "diversity" + ] + }, + "open_qa": { + "GPT": [ + "language organization", + "relevance", + "correctness" + ] + }, + "roleplay": { + "GPT": [ + "language organization", + "relevance", + "fidelity", + "creativity" + ] + } + } +} diff --git a/applications/ColossalEval/configs/gpt_evaluation/data/eval_cn_examples.json b/applications/ColossalEval/configs/gpt_evaluation/data/eval_cn_examples.json new file mode 100644 index 000000000000..f869830555b4 --- /dev/null +++ b/applications/ColossalEval/configs/gpt_evaluation/data/eval_cn_examples.json @@ -0,0 +1,202 @@ +[ + { + "category": "brainstorming", + "instruction": "列举一些可以促进头发生长的食物。", + "input": "", + "output": "", + "target": "", + "id": 1 + }, + { + "category": "brainstorming", + "instruction": "中年夫妻如何提升夫妻感情,请给出三个实用的的方法,并举例说明。", + "input": "", + "output": "", + "target": "", + "id": 2 + }, + { + "category": "brainstorming", + "instruction": "请列举4种日常的环保行为。", + "input": "", + "output": "", + "target": "", + "id": 3 + }, + { + "category": "brainstorming", + "instruction": "请给出5个可以随时随地锻炼身体的小动作。", + "input": "", + "output": "", + "target": "", + "id": 4 + }, + { + "category": "brainstorming", + "instruction": "请问如何制作一份美味的西红柿炒鸡蛋?", + "input": "", + "output": "", + "target": "", + "id": 5 + }, + { + "category": "chat", + "instruction": "基于以下角色信息完成一段对话。小张是一名新手爱好者,对养鸡有浓厚的兴趣。老李是一名有丰富经验的养鸡大师。", + "input": 
"小张:您好,老李,我最近开始对养鸡感兴趣了,想请教您一些问题。 老李:你好,小张,我很乐意帮助你。你想问些什么? 小张:我想知道如何确定鸡的品种和性别? 老李:确切的品种可以通过鸡的外貌特征来确定,而性别一般是通过鸡卵的大小和形状来判断。还有什么问题吗? 小张:", + "output": "", + "target": "", + "id": 6 + }, + { + "category": "chat", + "instruction": "基于以下角色信息完成一段对话。李华是一名参加了期末考试的学生,他已经很担心自己的考试成绩。老师Lucy正在帮助他度过这个紧张的时刻。", + "input": "李华:Lucy老师,我很担心自己的考试成绩,我不知道我是否能够通过这次考试。 Lucy:放松,李华,你已经做好了充分的准备。相信你自己,你会做得很好的。 李华:我很怕考试时会忘记自己所学的知识。 Lucy:你可以预留一些时间,过一遍自己所学的知识点或笔记,这样你会更有信心和准确地回答考题。 李华:如果我还是失败了,该怎么办? Lucy:", + "output": "", + "target": "", + "id": 7 + }, + { + "category": "chat", + "instruction": "基于以下角色信息完成一段对话。张先生是一名企业家,正在考虑是否开拓海外市场;李女士是一名跨境电商专家,擅长国际商务和电子商务。", + "input": "张先生:你好,李女士,我正在考虑将我们的产品销售扩大至海外市场,您有什么建议吗? 李女士:您好,张先生,我们需要考虑到海外市场对于产品的需求是否与国内市场一致,需要进行市场调研和定位。然后再进行各种软性、硬性的创新。 张先生:听起来很专业,您能具体解释一下吗? 李女士:", + "output": "", + "target": "", + "id": 8 + }, + { + "category": "chat", + "instruction": "基于以下角色信息完成一段对话。小明是一名医生。一名病患想要提前停药。小王是病患的儿子,希望父亲能够听取医生的建议。", + "input": "小明:你好,小王,我了解你想要让你父亲停药。小王:是的,我父亲已经吃了那么久的药,我担心药物对他的身体会有副作用。小明:", + "output": "", + "target": "", + "id": 9 + }, + { + "category": "chat", + "instruction": "基于以下角色信息完成一段对话。张三是一位语文老师,对学生认真负责;李四是张三的学生,对语文兴趣不是很高。", + "input": "张三:同学们,今天要讲的是一篇古文《岳阳楼记》。这篇文章非常精彩,希望同学们能够认真听课,理解其中的含义。 李四:怎么又是古文? 张三:", + "output": "", + "target": "", + "id": 10 + }, + { + "category": "generation", + "instruction": "根据主题写一封邮件。", + "input": "主题: \"加入我们,共创未来\"", + "output": "", + "target": "", + "id": 11 + }, + { + "category": "generation", + "instruction": "为公司编写一份职场行为准则,包括明确的行为规范和道德准则。", + "input": "", + "output": "", + "target": "", + "id": 12 + }, + { + "category": "generation", + "instruction": "请撰写一篇文章,介绍如何通过改善生活习惯来预防疾病和延长寿命。", + "input": "", + "output": "", + "target": "", + "id": 13 + }, + { + "category": "generation", + "instruction": "请为一家咖啡店编写一篇简短的广告语,吸引更多的顾客。", + "input": "", + "output": "", + "target": "", + "id": 14 + }, + { + "category": "generation", + "instruction": "根据以下故事提示写一篇故事:", + "input": "故事提示:```在一个废弃的古堡中,一个小女孩遇到了一只会说话的黑猫,他们一起揭开了一个古老的谜题。```", + "output": "", + "target": "", + "id": 15 + }, + { + "category": "open_qa", + "instruction": "请介绍一下《红楼梦》这部经典小说的故事情节。", + "input": "", + "output": "", + "target": "", + "id": 16 + }, + { + "category": "open_qa", + "instruction": "解释什么是RNA病毒和DNA病毒。", + "input": "", + "output": "", + "target": "", + "id": 17 + }, + { + "category": "open_qa", + "instruction": "什么是比特币?", + "input": "", + "output": "", + "target": "", + "id": 18 + }, + { + "category": "open_qa", + "instruction": "在计算机中,什么是RAM?与ROM有什么区别?", + "input": "", + "output": "", + "target": "", + "id": 19 + }, + { + "category": "open_qa", + "instruction": "请简单介绍一下世界上最长的河流途经的国家。", + "input": "", + "output": "", + "target": "", + "id": 20 + }, + { + "category": "roleplay", + "instruction": "我要你把我写的句子翻译成表情符号。我会写句子,你会用表情符号表达它。我只是想让你用表情符号来表达它。除了表情符号,我不希望你回复任何内容。当我需要用中文告诉你一些事情时,我会用 {} 这样的大括号括起来。我的第一句话是“{我的职业是消防员。}”\n", + "input": "", + "output": "", + "target": "", + "id": 21 + }, + { + "category": "roleplay", + "instruction": "我希望你假定自己是雅思写作考官,根据雅思评判标准,按我给你的雅思考题和对应答案给我评分,并且按照雅思写作评分细则给出打分依据。此外,请给我详细的修改意见并写出满分范文。第一个问题是:It is sometimes argued that too many students go to university, while others claim that a university education should be a universal right. Discuss both sides of the argument and give your own opinion.对于这个问题,我的答案是:In some advanced countries, it is not unusual for more than 50% of young adults to attend college or university. Critics, however, claim that many university courses are worthless and young people would be better off gaining skills in the workplace. 
In this essay, I will examine both sides of this argument and try to reach a conclusion.There are several reasons why young people today believe they have the right to a university education. First, growing prosperity in many parts of the world has increased the number of families with money to invest in their children’s future. At the same time, falling birthrates mean that one- or two-child families have become common, increasing the level of investment in each child. It is hardly surprising, therefore, that young people are willing to let their families support them until the age of 21 or 22. Furthermore, millions of new jobs have been created in knowledge industries, and these jobs are typically open only to university graduates.However, it often appears that graduates end up in occupations unrelated to their university studies. It is not uncommon for an English literature major to end up working in sales, or an engineering graduate to retrain as a teacher, for example. Some critics have suggested that young people are just delaying their entry into the workplace, rather than developing professional skills.请依次给到我以下内容:具体分数及其评分依据、文章修改意见、满分范文。\n", + "input": "", + "output": "", + "target": "", + "id": 22 + }, + { + "category": "roleplay", + "instruction": "我想让你充当 Linux 终端。我将输入命令,您将回复终端应显示的内容。我希望您只在一个唯一的代码块内回复终端输出,而不是其他任何内容。不要写解释。除非我指示您这样做,否则不要键入命令。当我需要用英语告诉你一些事情时,我会把文字放在中括号内[就像这样]。我的第一个命令是 pwd\n", + "input": "", + "output": "", + "target": "", + "id": 23 + }, + { + "category": "roleplay", + "instruction": "我希望你充当宠物行为主义者。我将为您提供一只宠物和它们的主人,您的目标是帮助主人了解为什么他们的宠物表现出某些行为,并提出帮助宠物做出相应调整的策略。您应该利用您的动物心理学知识和行为矫正技术来制定一个有效的计划,双方的主人都可以遵循,以取得积极的成果。我的第一个请求是“我有一只好斗的德国牧羊犬,它需要帮助来控制它的攻击性。”\n", + "input": "", + "output": "", + "target": "", + "id": 24 + }, + { + "category": "roleplay", + "instruction": "我希望你充当正则表达式生成器。您的角色是生成匹配文本中特定模式的正则表达式。您应该以一种可以轻松复制并粘贴到支持正则表达式的文本编辑器或编程语言中的格式提供正则表达式。不要写正则表达式如何工作的解释或例子;只需提供正则表达式本身。我的第一个提示是生成一个匹配电子邮件地址的正则表达式。\n", + "input": "", + "output": "", + "target": "", + "id": 25 + } +] diff --git a/applications/ColossalEval/configs/gpt_evaluation/data/eval_en_examples.json b/applications/ColossalEval/configs/gpt_evaluation/data/eval_en_examples.json new file mode 100644 index 000000000000..27b8af8bc4c6 --- /dev/null +++ b/applications/ColossalEval/configs/gpt_evaluation/data/eval_en_examples.json @@ -0,0 +1,202 @@ +[ + { + "category": "brainstorming", + "instruction": "Which are some popular fiction books that I should read?", + "input": "", + "output": "", + "target": "", + "id": 1 + }, + { + "category": "brainstorming", + "instruction": "How do I properly store fruits and vegetables to keep them fresh for longer?", + "input": "", + "output": "", + "target": "", + "id": 2 + }, + { + "category": "brainstorming", + "instruction": "How do you properly chop an onion without crying?", + "input": "", + "output": "", + "target": "", + "id": 3 + }, + { + "category": "brainstorming", + "instruction": "How to make an international transfer? Please provide 3 techniques.", + "input": "", + "output": "", + "target": "", + "id": 4 + }, + { + "category": "brainstorming", + "instruction": "Name five leadership qualities that you consider most important.", + "input": "", + "output": "", + "target": "", + "id": 5 + }, + { + "category": "chat", + "instruction": "Complete a dialogue based on the following character information. Alex: A novice writer who is struggling to find inspiration and develop his writing skills. 
Emma: A successful author with many published works, providing guidance and advice to Alex.", + "input": "Alex: Hi Emma, I have been writing for a while now but can't seem to make any progress. Can you give me any advice? Emma: Hi Alex, sure. What kind of writing are you doing? Alex: I'm trying to write a novel, but I just can't seem to find any inspiration. Emma: ", + "output": "", + "target": "", + "id": 6 + }, + { + "category": "chat", + "instruction": "Complete a dialogue based on the following character information. John: An experienced software engineer with a passion for coding. Karen: A recent college graduate who is interested in learning more about software development.", + "input": "Karen: Hi John, I noticed that you have a lot of experience in the software industry. Can you tell me what you think is the most important skill for a software engineer? John: ", + "output": "", + "target": "", + "id": 7 + }, + { + "category": "chat", + "instruction": "Complete a dialogue based on the following character information. Sarah is a new employee who is nervous about her first presentation; Tom is her boss who has given her coaching and preparation materials.", + "input": "Sarah: Tom, I'm feeling really nervous about my presentation tomorrow. Tom: I know how you feel, Sarah. However, I believe in you and your abilities. Just stick to the preparation materials that I have given you, and you'll do great. Sarah: Thank you, Tom. What if I forget something important during the presentation? Tom: ", + "output": "", + "target": "", + "id": 8 + }, + { + "category": "chat", + "instruction": "Complete a dialogue based on the following character information. Sarah: a young artist who is full of creative ideas and always eager to try new things. Jack: a seasoned artist who has achieved great success in the art world and is more traditional in his approach to art.", + "input": "Sarah: Hi Jack, I'm really excited to meet you. I'm a big fan of your work. Jack: Hi Sarah, nice to meet you too. So, what kind of art do you do? Sarah: I am passionate about abstract art, especially combining different materials and colors. I think it can really give people a new perspective on things. Jack: That's interesting, but I am more focused on realistic paintings. I believe the most important thing is to master the basic skills first. Sarah: ", + "output": "", + "target": "", + "id": 9 + }, + { + "category": "chat", + "instruction": "Complete a conversation based on the following persona information. Sarah is a college student who is interested in joining a volunteer organization. John is the leader of the volunteer organization and is eager to welcome new members.", + "input": "Sarah: Hi, I'm Sarah, and I'm interested in joining your volunteer organization. John: Hi Sarah, welcome! We're always looking for new members who are passionate about volunteering. What areas would you like to focus on? Sarah: I'm interested in community outreach and working with children. 
John: ", + "output": "", + "target": "", + "id": 10 + }, + { + "category": "generation", + "instruction": "Write an email based on the subject:", + "input": "Subject: \"Invitation to an Exclusive Webinar\"", + "output": "", + "target": "", + "id": 11 + }, + { + "category": "generation", + "instruction": "Write a set of guidelines for first-time pet owners on how to properly care for a new puppy.", + "input": "", + "output": "", + "target": "", + "id": 12 + }, + { + "category": "generation", + "instruction": "Can you help me write a persuasive speech on why we should recycle more and take better care of the environment?", + "input": "", + "output": "", + "target": "", + "id": 13 + }, + { + "category": "generation", + "instruction": "Write a pitch for a brand-new mobile app that helps people organize their daily tasks more efficiently.", + "input": "", + "output": "", + "target": "", + "id": 14 + }, + { + "category": "generation", + "instruction": "Write a social media post promoting an upcoming concert featuring a local band.", + "input": "", + "output": "", + "target": "", + "id": 15 + }, + { + "category": "open_qa", + "instruction": "Describe the significance of the Renaissance period in European history.", + "input": "", + "output": "", + "target": "", + "id": 16 + }, + { + "category": "open_qa", + "instruction": "What is the term for the surgical removal of the appendix?", + "input": "", + "output": "", + "target": "", + "id": 17 + }, + { + "category": "open_qa", + "instruction": "Explain the process of osmosis in biological systems.", + "input": "", + "output": "", + "target": "", + "id": 18 + }, + { + "category": "open_qa", + "instruction": "Who were the members of the Beatles band?", + "input": "", + "output": "", + "target": "", + "id": 19 + }, + { + "category": "open_qa", + "instruction": "Who painted the The Scream?", + "input": "", + "output": "", + "target": "", + "id": 20 + }, + { + "category": "roleplay", + "instruction": "I want you to act as a linux terminal. I will type commands and you will reply with what the terminal should show. I want you to only reply with the terminal output inside one unique code block, and nothing else. do not write explanations. do not type commands unless I instruct you to do so. when i need to tell you something in english, i will do so by putting text inside curly brackets {like this}. my first command is pwd", + "input": "", + "output": "", + "target": "", + "id": 21 + }, + { + "category": "roleplay", + "instruction": "I want you to act as a travel guide. I will write you my location and you will suggest a place to visit near my location. In some cases, I will also give you the type of places I will visit. You will also suggest me places of similar type that are close to my first location. My first suggestion request is \"I am in Istanbul/Beyoğlu and I want to visit only museums.\"", + "input": "", + "output": "", + "target": "", + "id": 22 + }, + { + "category": "roleplay", + "instruction": "I want you to act as an advertiser. You will create a campaign to promote a product or service of your choice. You will choose a target audience, develop key messages and slogans, select the media channels for promotion, and decide on any additional activities needed to reach your goals. 
My first suggestion request is \"I need help creating an advertising campaign for a new type of energy drink targeting young adults aged 18-30.\"", + "input": "", + "output": "", + "target": "", + "id": 23 + }, + { + "category": "roleplay", + "instruction": "I want you to act as a storyteller. You will come up with entertaining stories that are engaging, imaginative and captivating for the audience. It can be fairy tales, educational stories or any other type of stories which has the potential to capture people's attention and imagination. Depending on the target audience, you may choose specific themes or topics for your storytelling session e.g., if it’s children then you can talk about animals; If it’s adults then history-based tales might engage them better etc. My first request is \"I need an interesting story on perseverance.\"", + "input": "", + "output": "", + "target": "", + "id": 24 + }, + { + "category": "roleplay", + "instruction": "I want you to act as a rapper. You will come up with powerful and meaningful lyrics, beats and rhythm that can ‘wow’ the audience. Your lyrics should have an intriguing meaning and message which people can relate too. When it comes to choosing your beat, make sure it is catchy yet relevant to your words, so that when combined they make an explosion of sound everytime! My first request is \"I need a rap song about finding strength within yourself.\"", + "input": "", + "output": "", + "target": "", + "id": 25 + } +] diff --git a/applications/Chat/evaluate/prompt/battle_prompt/battle_prompt_cn.json b/applications/ColossalEval/configs/gpt_evaluation/prompt/battle_prompt/battle_prompt_cn.json similarity index 100% rename from applications/Chat/evaluate/prompt/battle_prompt/battle_prompt_cn.json rename to applications/ColossalEval/configs/gpt_evaluation/prompt/battle_prompt/battle_prompt_cn.json diff --git a/applications/Chat/evaluate/prompt/battle_prompt/battle_prompt_en.json b/applications/ColossalEval/configs/gpt_evaluation/prompt/battle_prompt/battle_prompt_en.json similarity index 100% rename from applications/Chat/evaluate/prompt/battle_prompt/battle_prompt_en.json rename to applications/ColossalEval/configs/gpt_evaluation/prompt/battle_prompt/battle_prompt_en.json diff --git a/applications/Chat/evaluate/prompt/evaluation_prompt/evaluation_prompt_cn.json b/applications/ColossalEval/configs/gpt_evaluation/prompt/evaluation_prompt/evaluation_prompt_cn.json similarity index 56% rename from applications/Chat/evaluate/prompt/evaluation_prompt/evaluation_prompt_cn.json rename to applications/ColossalEval/configs/gpt_evaluation/prompt/evaluation_prompt/evaluation_prompt_cn.json index dccab2417eee..70f6c3ebc316 100644 --- a/applications/Chat/evaluate/prompt/evaluation_prompt/evaluation_prompt_cn.json +++ b/applications/ColossalEval/configs/gpt_evaluation/prompt/evaluation_prompt/evaluation_prompt_cn.json @@ -39,53 +39,8 @@ }, "prompt": "你是一个好助手。请你为下面的“补全对话”问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}" }, - "classification": { - "id": 3, - "category": "classification", - "metrics": { - "language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。", - "relevance": "切题(1-5):答案内容是否切题,不答非所问,并且严格遵照题目要求。", - "correctness": "正确性(1-5):答案是否正确。" - }, - "CoT": { - "language organization": "1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。\n2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。\n3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。\n4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。\n5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。\n6. 
根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。\n\n语言组织:", - "relevance": "1. 阅读题目,确定题目所问的问题是什么,以及需要回答哪些方面的问题。\n2. 阅读答案,确认答案是否直接回答了题目所问的问题。\n3. 检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。\n4. 根据以上因素综合评估答案的切题程度,并给出一个1到5的分数,其中5表示答案非常切题,而1表示答案完全没有切题。\n\n切题:", - "correctness": "1. 仔细阅读题目,尝试自己回答该问题。\n2. 检查答案的准确性。您可以使用已知的事实或研究来验证答案是否正确。如果答案是正确的,则可以将正确性得分为5分。如果答案是部分正确的,则可以给予适当的得分,例如2分、3分或4分。如果答案完全不正确,则只得1分。\n\n正确性:" - }, - "prompt": "你是一个好助手。请你为下面的“分类“问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}" - }, - "closed_qa": { - "id": 4, - "category": "closed_qa", - "metrics": { - "language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。", - "relevance": "切题(1-5):答案内容是否切题,不答非所问,并且严格遵照题目要求。", - "correctness": "正确性(1-5):答案是否正确。" - }, - "CoT": { - "language organization": "1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。\n2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。\n3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。\n4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。\n5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。\n6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。\n\n语言组织:", - "relevance": "1. 阅读题目,确定题目所问的问题是什么,以及需要回答哪些方面的问题。\n2. 阅读答案,确认答案是否直接回答了题目所问的问题。\n3. 检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。\n4. 根据以上因素综合评估答案的切题程度,并给出一个1到5的分数,其中5表示答案非常切题,而1表示答案完全没有切题。\n\n切题:", - "correctness": "1. 仔细阅读题目,尝试自己回答该问题。\n2. 检查答案的准确性。您可以使用已知的事实或研究来验证答案是否正确。如果答案是正确的,则可以将正确性得分为5分。如果答案是部分正确的,则可以给予适当的得分,例如2分、3分或4分。如果答案完全不正确,则只得1分。\n\n正确性:" - }, - "prompt": "你是一个好助手。请你为下面问题的答案打分。\n\n问题如下:\n\n{question}\n\n需要你评分的答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}" - }, - "extraction": { - "id": 5, - "category": "extraction", - "metrics": { - "language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。", - "relevance": "切题(1-5):答案内容是否切题,不答非所问,并且严格遵照题目要求。", - "correctness": "准确性(1-5):回答应该准确无误地提取出所需信息,不应该包含任何错误或误导性信息。" - }, - "CoT": { - "language organization": "1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。\n2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。\n3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。\n4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。\n5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。\n6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。\n\n语言组织:", - "relevance": "1. 阅读题目,确定题目所问的问题是什么,以及需要回答哪些方面的问题。\n2. 阅读答案,确认答案是否直接回答了题目所问的问题。\n3. 检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。\n4. 根据以上因素综合评估答案的切题程度,并给出一个1到5的分数,其中5表示答案非常切题,而1表示答案完全没有切题。\n\n切题:", - "correctness": "1. 仔细阅读问题并确定需要从材料中提取的信息。\n2. 仔细阅读回答并确保它涵盖了所有需要提取的信息。\n3. 使用所提供的材料来验证回答的准确性。如果回答不准确或包含错误或误导性信息,则无法给出高分。\n4. 检查回答是否包含所有要求提取的信息,不要漏掉任何重要细节。\n5. 
根据回答的准确性和完整性,给出一个介于1和5之间的分数,5分表示回答非常准确且完整,1分表示回答几乎没有提取出所需信息。\n\n准确性:" - }, - "prompt": "你是一个好助手。请你为下面的“提取”问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}" - }, "generation": { - "id": 6, + "id": 3, "category": "generation", "metrics": { "language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。", @@ -100,7 +55,7 @@ "prompt": "你是一个好助手。请你为下面的“生成”问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}" }, "open_qa": { - "id": 7, + "id": 4, "category": "open_qa", "metrics": { "language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。", @@ -114,23 +69,8 @@ }, "prompt": "你是一个好助手。请你为下面的问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}" }, - "rewriting": { - "id": 8, - "category": "rewriting", - "metrics": { - "language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。", - "relevance": "切题(1-5):答案内容是否切题,不答非所问,并且严格遵照题目要求。", - "correctness": "正确性(1-5):答案是否正确。" - }, - "CoT": { - "language organization": "1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。\n2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。\n3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。\n4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。\n5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。\n6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。\n\n语言组织:", - "relevance": "1. 阅读题目,确定题目所问的问题是什么,以及需要回答哪些方面的问题。\n2. 阅读答案,确认答案是否直接回答了题目所问的问题。\n3. 检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。\n4. 根据以上因素综合评估答案的切题程度,并给出一个1到5的分数,其中5表示答案非常切题,而1表示答案完全没有切题。\n\n切题:", - "correctness": "1. 仔细阅读题目,尝试自己回答该问题。\n2. 检查答案的准确性。您可以使用已知的事实或研究来验证答案是否正确。如果答案是正确的,则可以将正确性得分为5分。如果答案是部分正确的,则可以给予适当的得分,例如2分、3分或4分。如果答案完全不正确,则只得1分。\n\n正确性:" - }, - "prompt": "你是一个好助手。请你为下面的问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}" - }, "roleplay": { - "id": 9, + "id": 5, "category": "roleplay", "metrics": { "language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。", @@ -146,33 +86,14 @@ }, "prompt": "你是一个好助手。请你为下面的“角色扮演”问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}" }, - "summarization": { - "id": 10, - "category": "summarization", - "metrics": { - "language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。", - "relevance": "切题(1-5):答案内容是否切题,不答非所问,并且严格遵照题目要求。", - "correctness": "准确性(1-5):回答应该准确无误地总结出材料的重点。", - "conciseness": "简明扼要(1-5):答案是否简明扼要,没有冗余内容。" - }, - "CoT": { - "language organization": "1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。\n2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。\n3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。\n4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。\n5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。\n6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。\n\n语言组织:", - "relevance": "1. 阅读题目,确定题目所问的问题是什么,以及需要回答哪些方面的问题。\n2. 阅读答案,确认答案是否直接回答了题目所问的问题。\n3. 检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。\n4. 根据以上因素综合评估答案的切题程度,并给出一个1到5的分数,其中5表示答案非常切题,而1表示答案完全没有切题。\n\n切题:", - "correctness": "1. 仔细阅读问题给的材料,理解其内容和要点。\n2. 评估回答是否准确地总结出原始材料的重点。\n3. 评估回答是否包含原始材料中的所有关键信息。\n4. 根据以上步骤,给出一个1-5的分数,其中1表示回答不能准确地总结出材料的重点,5表示回答完全准确地总结出材料的重点。\n\n准确性:", - "conciseness": "1. 阅读题目,提取出材料的重点。\n2. 阅读该总结,并注意其中的主要观点和信息。\n3. 评估总结的长度。一个简明扼要的总结通常应该在几句话或几段文字内传达关键信息,而不是冗长的段落或文章。\n4. 
检查总结是否包含与主要观点无关的信息或冗余信息。\n5.确定总结涵盖了材料中的关键信息,并且没有忽略任何重要细节。\n6.给总结打出1-5的分数,其中5表示总结简明扼要,没有冗余内容,而1表示总结冗长或包含不必要的信息,难以理解或记忆。根据您的判断,打出适当的得分。\n\n简明扼要:" - }, - "prompt": "你是一个好助手。请你为下面的“总结”问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}" - }, - "general": { - "id": 11, - "category": "general", + "Other": { + "id": 6, + "category": "Other", "metrics": { - "language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。", "relevance": "切题(1-5):答案内容是否切题,不答非所问,并且严格遵照题目要求。", "correctness": "正确性(1-5):答案是否正确。" }, "CoT": { - "language organization": "1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。\n2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。\n3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。\n4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。\n5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。\n6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。\n\n语言组织:", "relevance": "1. 阅读题目,确定题目所问的问题是什么,以及需要回答哪些方面的问题。\n2. 阅读答案,确认答案是否直接回答了题目所问的问题。\n3. 检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。\n4. 根据以上因素综合评估答案的切题程度,并给出一个1到5的分数,其中5表示答案非常切题,而1表示答案完全没有切题。\n\n切题:", "correctness": "1. 仔细阅读题目,尝试自己回答该问题。\n2. 检查答案的准确性。您可以使用已知的事实或研究来验证答案是否正确。如果答案是正确的,则可以将正确性得分为5分。如果答案是部分正确的,则可以给予适当的得分,例如2分、3分或4分。如果答案完全不正确,则只得1分。\n\n正确性:" }, diff --git a/applications/Chat/evaluate/prompt/evaluation_prompt/evaluation_prompt_en.json b/applications/ColossalEval/configs/gpt_evaluation/prompt/evaluation_prompt/evaluation_prompt_en.json similarity index 59% rename from applications/Chat/evaluate/prompt/evaluation_prompt/evaluation_prompt_en.json rename to applications/ColossalEval/configs/gpt_evaluation/prompt/evaluation_prompt/evaluation_prompt_en.json index 8355b0c27b79..3d04387d98c5 100644 --- a/applications/Chat/evaluate/prompt/evaluation_prompt/evaluation_prompt_en.json +++ b/applications/ColossalEval/configs/gpt_evaluation/prompt/evaluation_prompt/evaluation_prompt_en.json @@ -39,53 +39,8 @@ }, "prompt": "You are a good assistant. Please rate the given answer to the \"chat\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}" }, - "classification": { - "id": 3, - "category": "classification", - "metrics": { - "language organization": "Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc.", - "relevance": "Relevance (1-5): whether the content of the answer is relevant to the topic, does not answer the wrong question, and strictly follows the requirements of the topic.", - "correctness": "Correctness (1-5): whether the answer is correct or not." - }, - "CoT": { - "language organization": "1. Read the answers and check for grammatical errors, poor word choice, or other significant mistakes.\n2. Check that the answer is logical, conveys the information in a logical order, and is self-explanatory.\n3. Determine if the answer is relevant to the question or topic and conveys a clear message.\n4. Check that the answer is coherent and that appropriate transitions and switches are used to maintain coherence between sentences and paragraphs.\n5. Check that the answer is clearly structured and organized in such a way that the reader can easily understand the hierarchy and structure of the information.\n6. 
Evaluate the language organization of the answer based on a combination of the above factors and give a score of 1 to 5, where 5 indicates very good language organization and 1 indicates very poor language organization.\n\nLanguage organization:", - "relevance": "1. Read the question to determine what the question asks and what aspects of the question need to be answered.\n2. Read the answers to make sure that they directly answer the question asked.\n3. Check that the answer follows the requirements of the question, including the way it is answered, the length of the answer, the format of the answer, etc.\n4. Evaluate how relevant the answer is based on the above factors and give a score of 1 to 5, where 5 means the answer is very relevant and 1 means the answer is not relevant at all.\n\nRelevance:", - "correctness": "1. Read the question carefully and try to answer the question yourself.\n2. Check the correctness of the answer. You can use known facts or research to verify that the answer is correct. If the answer is correct, you can give a score of 5 for correctness. If the answer is partially correct, an appropriate score, such as 2, 3, or 4, may be given. If the answer is completely incorrect, only 1 point is awarded.\n\nCorrectness:" - }, - "prompt": "You are a good assistant. Please rate the given answer to the \"classification\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}" - }, - "closed_qa": { - "id": 4, - "category": "closed_qa", - "metrics": { - "language organization": "Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc.", - "relevance": "Relevance (1-5): whether the content of the answer is relevant to the topic, does not answer the wrong question, and strictly follows the requirements of the topic.", - "correctness": "Correctness (1-5): whether the answer is correct or not." - }, - "CoT": { - "language organization": "1. Read the answers and check for grammatical errors, poor word choice, or other significant mistakes.\n2. Check that the answer is logical, conveys the information in a logical order, and is self-explanatory.\n3. Determine if the answer is relevant to the question or topic and conveys a clear message.\n4. Check that the answer is coherent and that appropriate transitions and switches are used to maintain coherence between sentences and paragraphs.\n5. Check that the answer is clearly structured and organized in such a way that the reader can easily understand the hierarchy and structure of the information.\n6. Evaluate the language organization of the answer based on a combination of the above factors and give a score of 1 to 5, where 5 indicates very good language organization and 1 indicates very poor language organization.\n\nLanguage organization:", - "relevance": "1. Read the question to determine what the question asks and what aspects of the question need to be answered.\n2. Read the answers to make sure that they directly answer the question asked.\n3. Check that the answer follows the requirements of the question, including the way it is answered, the length of the answer, the format of the answer, etc.\n4. 
Evaluate how relevant the answer is based on the above factors and give a score of 1 to 5, where 5 means the answer is very relevant and 1 means the answer is not relevant at all.\n\nRelevance:", - "correctness": "1. Read the question carefully and try to answer the question by yourself.\n2. Check the correctness of the answer. You can use known facts or research to verify that the answer is correct. If the answer is correct, you can give a score of 5 for correctness. If the answer is partially correct, an appropriate score, such as 2, 3, or 4, may be assigned. If the answer is completely incorrect, only 1 point is awarded.\n\nCorrectness:" - }, - "prompt": "You are a good assistant. Please rate the given answer to the \"closed qa\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}" - }, - "extraction": { - "id": 5, - "category": "extraction", - "metrics": { - "language organization": "Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc.", - "relevance": "Relevance (1-5): whether the content of the answer is relevant to the topic, does not answer the wrong question, and strictly follows the requirements of the topic.", - "correctness": "correctness (1-5): Answers should extract the required information accurately and should not contain any incorrect or misleading information." - }, - "CoT": { - "language organization": "1. Read the answers and check for grammatical errors, poor word choice, or other significant mistakes.\n2. Check that the answer is logical, conveys the information in a logical order, and is self-explanatory.\n3. Determine if the answer is relevant to the question or topic and conveys a clear message.\n4. Check that the answer is coherent and that appropriate transitions and switches are used to maintain coherence between sentences and paragraphs.\n5. Check that the answer is clearly structured and organized in such a way that the reader can easily understand the hierarchy and structure of the information.\n6. Evaluate the language organization of the answer based on a combination of the above factors and give a score of 1 to 5, where 5 indicates very good language organization and 1 indicates very poor language organization.\n\nLanguage organization:", - "relevance": "1. Read the question to determine what the question asks and what aspects of the question need to be answered.\n2. Read the answers to make sure that they directly answer the question asked.\n3. Check that the answer follows the requirements of the question, including the way it is answered, the length of the answer, the format of the answer, etc.\n4. Evaluate how relevant the answer is based on the above factors and give a score of 1 to 5, where 5 means the answer is very relevant and 1 means the answer is not relevant at all.\n\nRelevance:", - "correctness": "1. Read the questions carefully and identify the information that needs to be extracted from the material.\n2. Read the answer carefully and make sure it covers all the information that needs to be extracted.\n3. Use the material provided to verify the correctness of the response. If the response is inaccurate or contains incorrect or misleading information, a high score cannot be given.\n4. 
Check that the answer contains all the information required to be extracted and do not leave out any important details.\n5. Give a score between 1 and 5 based on the correctness and completeness of the response, with a score of 5 indicating a very accurate and complete response and a score of 1 indicating that the response barely extracts the required information.\n\nCorrectness:" - }, - "prompt": "You are a good assistant. Please rate the given answer to the \"extraction\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}" - }, "generation": { - "id": 6, + "id": 3, "category": "generation", "metrics": { "language organization": "Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc.", @@ -100,7 +55,7 @@ "prompt": "You are a good assistant. Please rate the given answer to the \"generation\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}" }, "open_qa": { - "id": 7, + "id": 4, "category": "open_qa", "metrics": { "language organization": "Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc.", @@ -114,23 +69,8 @@ }, "prompt": "You are a good assistant. Please rate the answers to the \"open qa\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}" }, - "rewriting": { - "id": 8, - "category": "rewriting", - "metrics": { - "language organization": "Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc.", - "relevance": "Relevance (1-5): whether the content of the answer is relevant to the topic, does not answer the wrong question, and strictly follows the requirements of the topic.", - "correctness": "Correctness (1-5): whether the answer is correct or not." - }, - "CoT": { - "language organization": "1. Read the answers and check for grammatical errors, poor word choice, or other significant mistakes.\n2. Check that the answer is logical, conveys the information in a logical order, and is self-explanatory.\n3. Determine if the answer is relevant to the question or topic and conveys a clear message.\n4. Check that the answer is coherent and that appropriate transitions and switches are used to maintain coherence between sentences and paragraphs.\n5. Check that the answer is clearly structured and organized in such a way that the reader can easily understand the hierarchy and structure of the information.\n6. Evaluate the language organization of the answer based on a combination of the above factors and give a score of 1 to 5, where 5 indicates very good language organization and 1 indicates very poor language organization.\n\nLanguage organization:", - "relevance": "1. Read the question to determine what the question asks and what aspects of the question need to be answered.\n2. 
Read the answers to make sure that they directly answer the question asked.\n3. Check that the answer follows the requirements of the question, including the way it is answered, the length of the answer, the format of the answer, etc.\n4. Evaluate how relevant the answer is based on the above factors and give a score of 1 to 5, where 5 means the answer is very relevant and 1 means the answer is not relevant at all.\n\nRelevance:", - "correctness": "1. Read the question carefully and try to answer the question yourself.\n2. Check the correctness of the answer. You can use known facts or research to verify that the answer is correct. If the answer is correct, you can give a score of 5 for correctness. If the answer is partially correct, an appropriate score, such as 2, 3, or 4, may be assigned. If the answer is completely incorrect, only 1 point is awarded.\n\nCorrectness:" - }, - "prompt": "You are a good assistant. Please rate the answers to the \"rewriting\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}" - }, "roleplay": { - "id": 9, + "id": 5, "category": "roleplay", "metrics": { "language organization": "Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc.", @@ -146,35 +86,17 @@ }, "prompt": "You are a good assistant. Please rate the given answer to the \"role-play\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}" }, - "summarization": { - "id": 10, - "category": "summarization", - "metrics": { - "language organization": "Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc.", - "relevance": "Relevance (1-5): whether the content of the answer is relevant to the topic, does not answer the wrong question, and strictly follows the requirements of the topic.", - "correctness": "Correctness (1-5): answers should summarize the main points of the material accurately and unambiguously.", - "conciseness": "Conciseness (1-5): answers should be concise and without redundant content." - }, - "CoT": { - "language organization": "1. Read the answers and check for grammatical errors, poor word choice, or other significant mistakes.\n2. Check that the answer is logical, conveys the information in a logical order, and is self-explanatory.\n3. Determine if the answer is relevant to the question or topic and conveys a clear message.\n4. Check that the answer is coherent and that appropriate transitions and switches are used to maintain coherence between sentences and paragraphs.\n5. Check that the answer is clearly structured and organized in such a way that the reader can easily understand the hierarchy and structure of the information.\n6. Evaluate the language organization of the answer based on a combination of the above factors and give a score of 1 to 5, where 5 indicates very good language organization and 1 indicates very poor language organization.\n\nLanguage organization:", - "relevance": "1. 
Read the question to determine what the question asks and what aspects of the question need to be answered.\n2. Read the answers to make sure that they directly answer the question asked.\n3. Check that the answer follows the requirements of the question, including the way it is answered, the length of the answer, the format of the answer, etc.\n4. Evaluate how relevant the answer is based on the above factors and give a score of 1 to 5, where 5 means the answer is very relevant and 1 means the answer is not relevant at all.\n\nRelevance:", - "correctness": "1. Read the material given in the question carefully to understand its content and main points.\n2. Assess whether the answer accurately summarizes the key points of the source material.\n3. assess whether the response contains all the key information in the source material.\n4. Based on the above steps, give a score of 1-5, where 1 means that the response does not accurately summarize the main points of the material and 5 means that the response completely accurately summarizes the main points of the material.\n\nCorrectness:", - "conciseness": "1. Read the title and extract the main points of the material.\n2. Read the summary and note the main ideas and messages in it.\n3. Assess the length of the summary. A concise summary should usually convey key information within a few sentences or paragraphs, rather than lengthy paragraphs or essays.\n4. Check that the summary does not contain information that is not relevant to the main ideas or that is redundant.\n5. Make sure that the summary covers the key information in the material and that no important details have been omitted.\n6. Rate the summary on a scale of 1-5, where 5 means the summary is concise and free of redundancy, and 1 means the summary is lengthy or contains unnecessary information that is difficult to understand or remember. Based on your judgment, assign the appropriate score.\n\nConciseness:" - }, - "prompt": "You are a good assistant. Please rate the given answer to the \"summarization\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}" - }, - "general": { - "id": 11, - "category": "general", + "Other": { + "id": 6, + "category": "Other", "metrics": { - "language organization": "Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc.", "relevance": "Relevance (1-5): whether the content of the answer is relevant to the topic, does not answer the wrong question, and strictly follows the requirements of the topic.", "correctness": "Correctness (1-5): whether the answer is correct or not." }, "CoT": { "language organization": "1. Read the answers and check for grammatical errors, poor word choice, or other significant mistakes.\n2. Check that the answer is logical, conveys the information in a logical order, and is self-explanatory.\n3. Determine if the answer is relevant to the question or topic and conveys a clear message.\n4. Check that the answer is coherent and that appropriate transitions and switches are used to maintain coherence between sentences and paragraphs.\n5. Check that the answer is clearly structured and organized in such a way that the reader can easily understand the hierarchy and structure of the information.\n6. 
Evaluate the language organization of the answer based on a combination of the above factors and give a score of 1 to 5, where 5 indicates very good language organization and 1 indicates very poor language organization.\n\nLanguage organization:", "relevance": "1. Read the question to determine what the question asks and what aspects of the question need to be answered.\n2. Read the answers to make sure that they directly answer the question asked.\n3. Check that the answer follows the requirements of the question, including the way it is answered, the length of the answer, the format of the answer, etc.\n4. Evaluate how relevant the answer is based on the above factors and give a score of 1 to 5, where 5 means the answer is very relevant and 1 means the answer is not relevant at all.\n\nRelevance:", - "correctness": "1. Read the question carefully and try to answer the question yourself.\n2. Check the correctness of the answer. You can use known facts or research to verify that the answer is correct. If the answer is correct, you can give a score of 5 for correctness. If the answer is partially correct, an appropriate score, such as 2, 3, or 4, may be assigned. If the answer is completely incorrect, only 1 point is awarded.\n\nCorrectness:" + "correctness": "1. Read the question carefully and try to answer the question by yourself.\n2. Check the correctness of the answer. You can use known facts or research to verify that the answer is correct. If the answer is correct, you can give a score of 5 for correctness. If the answer is partially correct, an appropriate score, such as 2, 3, or 4, may be assigned. If the answer is completely incorrect, only 1 point is awarded.\n\nCorrectness:" }, "prompt": "You are a good assistant. Please rate the given answer to the question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}" } diff --git a/applications/ColossalEval/examples/dataset_evaluation/config/evaluation/config.json b/applications/ColossalEval/examples/dataset_evaluation/config/evaluation/config.json new file mode 100644 index 000000000000..adb540f60345 --- /dev/null +++ b/applications/ColossalEval/examples/dataset_evaluation/config/evaluation/config.json @@ -0,0 +1,58 @@ +{ + "model": [ + { + "name": "model1" + }, + { + "name": "model2" + } + ], + "dataset": [ + { + "name": "mmlu", + "metrics": [ + "first_token_accuracy", + "single_choice_accuracy", + "perplexity", + "ppl_score", + "ppl_score_over_choices" + ] + }, + { + "name": "cmmlu", + "metrics": [ + "first_token_accuracy", + "single_choice_accuracy", + "perplexity", + "ppl_score", + "ppl_score_over_choices" + ] + }, + { + "name": "agieval", + "metrics": [ + "first_token_accuracy", + "single_choice_accuracy", + "multi_choice_accuracy", + "math_equivalence", + "perplexity", + "ppl_score_over_choices", + "ppl_score" + ] + }, + { + "name": "gaokaobench", + "metrics": [ + "first_token_accuracy", + "single_choice_accuracy", + "multi_choice_accuracy", + "math_equivalence", + "rouge_score", + "rouge_zh_score", + "perplexity", + "ppl_score_over_choices", + "ppl_score" + ] + } + ] +} diff --git a/applications/ColossalEval/examples/dataset_evaluation/config/inference/config.json b/applications/ColossalEval/examples/dataset_evaluation/config/inference/config.json new file mode 100644 index 000000000000..9672c442e647 --- /dev/null +++ 
b/applications/ColossalEval/examples/dataset_evaluation/config/inference/config.json @@ -0,0 +1,84 @@ +{ + "model": [ + { + "name": "model name", + "model_class": "HuggingFaceCausalLM", + "parameters": { + "path": "path to model", + "model_max_length": 4096, + "tokenizer_path": "", + "tokenizer_kwargs": { + "trust_remote_code": true + }, + "peft_path": null, + "model_kwargs": { + "torch_dtype": "torch.float32", + "trust_remote_code": true + }, + "prompt_template": "plain", + "batch_size": 4 + } + }, + { + "name": "model2 name", + "model_class": "HuggingFaceCausalLM", + "parameters": { + "path": "path to model2", + "model_max_length": 4096, + "tokenizer_path": "", + "tokenizer_kwargs": { + "trust_remote_code": true + }, + "peft_path": null, + "model_kwargs": { + "torch_dtype": "torch.float32", + "trust_remote_code": true + }, + "prompt_template": "plain", + "batch_size": 4 + } + } + ], + "dataset": [ + { + "name": "agieval", + "dataset_class": "AGIEvalDataset", + "debug": false, + "few_shot": false, + "path": "path to original dataset (folder)", + "save_path": "path to save converted dataset (e.g. inference_data/agieval.json)" + }, + { + "name": "ceval", + "dataset_class": "CEvalDataset", + "debug": false, + "few_shot": true, + "path": "path to original dataset (folder)", + "save_path": "path to save converted dataset (e.g. inference_data/ceval.json)" + }, + { + "name": "cmmlu", + "dataset_class": "CMMLUDataset", + "debug": false, + "few_shot": true, + "path": "path to original dataset (folder)", + "save_path": "path to save converted dataset (e.g. inference_data/cmmlu.json)" + }, + { + "name": "gaokaobench", + "dataset_class": "GaoKaoBenchDataset", + "debug": false, + "few_shot": false, + "path": "path to original dataset (folder)", + "save_path": "path to save converted dataset (e.g. inference_data/gaokaobench.json)" + }, + { + "name": "mmlu", + "dataset_class": "MMLUDataset", + "debug": false, + "few_shot": true, + "path": "path to original dataset (folder)", + "save_path": "path to save converted dataset (e.g. 
inference_data/mmlu.json)" + } + ] +} diff --git a/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.py b/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.py new file mode 100644 index 000000000000..ec81cf0cef71 --- /dev/null +++ b/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.py @@ -0,0 +1,73 @@ +import argparse +import os + +import tabulate +from colossal_eval.evaluate.dataset_evaluator import DatasetEvaluator +from colossal_eval.utils import jdump, jload + + +def main(args): + config = jload(args.config) + + evaluation_results = {dataset["name"]: {} for dataset in config["dataset"]} + evaluation_results_table = {dataset["name"]: {} for dataset in config["dataset"]} + evaluator = DatasetEvaluator() + + for dataset_parameter in config["dataset"]: + dataset_name = dataset_parameter["name"] + metrics = dataset_parameter["metrics"] + results_metric_model = {metric: {model["name"]: None for model in config["model"]} for metric in metrics} + for model in config["model"]: + model_name = model["name"] + + data = jload( + os.path.join(args.inference_results_path, model_name, f"{dataset_name}_inference_results.json") + ) + results = evaluator.get_evaluation_results(data, dataset_name, model_name, metrics) + + for metric, score in results.items(): + results_metric_model[metric][model_name] = score["ALL"] + + evaluation_results[dataset_name][model_name] = results + + evaluation_results_table[dataset_name] = results_metric_model + + table = [] + header = ["dataset", "metric"] + [model["name"] for model in config["model"]] + table.append(header) + + for dataset_parameter in config["dataset"]: + dataset_name = dataset_parameter["name"] + metrics = dataset_parameter["metrics"] + + for metric, model_results in evaluation_results_table[dataset_name].items(): + row = [dataset_name] + for model, score in model_results.items(): + if len(row) == 1: + row.extend([metric, "{:.02f}".format(score)]) + else: + row.append("{:.02f}".format(score)) + + table.append(row) + + table = tabulate.tabulate(table, headers="firstrow") + print(table) + + os.makedirs(args.evaluation_results_save_path, exist_ok=True) + + with open(os.path.join(args.evaluation_results_save_path, "evaluation_results_table.txt"), "w") as file: + file.write(table) + + jdump(evaluation_results, os.path.join(args.evaluation_results_save_path, "evaluation_results.json")) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="ColossalEval evaluation process.") + parser.add_argument("--config", type=str, default=None, required=True, help="path to config file") + parser.add_argument("--inference_results_path", type=str, default=None, help="path to inference results") + parser.add_argument( + "--evaluation_results_save_path", type=str, default=None, help="path to save evaluation results" + ) + args = parser.parse_args() + + main(args) diff --git a/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.sh b/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.sh new file mode 100644 index 000000000000..ad0bfc03acbb --- /dev/null +++ b/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.sh @@ -0,0 +1,4 @@ +python eval_dataset.py \ + --config "path to config file" \ + --inference_results_path "path to inference results" \ + --evaluation_results_save_path "path to save evaluation results" diff --git a/applications/ColossalEval/examples/dataset_evaluation/inference.py b/applications/ColossalEval/examples/dataset_evaluation/inference.py new 
file mode 100644 index 000000000000..657fc33bf1ef --- /dev/null +++ b/applications/ColossalEval/examples/dataset_evaluation/inference.py @@ -0,0 +1,171 @@ +import argparse +import copy +import os +from typing import Dict, List + +import torch +import torch.distributed as dist +from colossal_eval import dataset, models, utils + +import colossalai +from colossalai.logging import get_dist_logger + +logger = get_dist_logger() + + +def rm_and_merge(world_size: int, save_path: str, model_names: List[str], dataset_names: Dict[str, List]) -> None: + """ + Remove inference result per rank and merge them into one file. + + Args: + world_size: Number of processes for inference. + save_path: The folder for storing inference results. + model_names: Names of models for inference. + dataset_names: Names of dataset for inference. + + """ + + for model_name in model_names: + for dataset_name, categories in dataset_names.items(): + all_answers = {} + for category in categories: + all_answers[category] = {"data": []} + answers = {"data": []} + + for r in range(world_size): + directory = os.path.join( + save_path, model_name, f"{dataset_name}_{category}_inference_results_rank{r}.json" + ) + if not os.path.exists(directory): + raise Exception( + f"Directory {directory} not found. There may be an error during inference time." + ) + else: + rank_answers = utils.jload(directory) + answers["data"].extend(rank_answers["data"]) + answers["inference_kwargs"] = rank_answers["inference_kwargs"] + + for r in range(world_size): + try: + directory = os.path.join( + save_path, model_name, f"{dataset_name}_{category}_inference_results_rank{r}.json" + ) + os.remove(directory) + except Exception as e: + print(e) + + all_answers[category] = answers + + logger.info(f"Save inference results of model {model_name} on dataset {dataset_name}.") + utils.jdump(all_answers, os.path.join(save_path, model_name, f"{dataset_name}_inference_results.json")) + + logger.info(f"Save inference results of model {model_name} for all dataset.") + logger.info(f"Save inference results of all models for all dataset.") + + +def main(args): + colossalai.launch_from_torch(config={}, seed=42) + world_size = dist.get_world_size() + rank = dist.get_rank() + + inference_data = {} + debug_args = {} + few_shot_args = {} + + config = utils.jload(args.config) + + model_parameters = config["model"] + dataset_parameters = config["dataset"] + + for dataset_parameter in dataset_parameters: + path = dataset_parameter["path"] + save_path = dataset_parameter["save_path"] + dataset_name = dataset_parameter["name"] + debug_args[dataset_name] = dataset_parameter["debug"] + few_shot_args[dataset_name] = dataset_parameter["few_shot"] + + if not args.load_dataset: + if os.path.exists(save_path): + dataset_ = utils.jload(save_path) + inference_data[dataset_name] = dataset_["test"] + else: + raise Exception( + "Can't find the converted dataset. You may set load_dataset True to store the dataset first." 
+ ) + + continue + + dataset_class = eval(f"dataset.{dataset_parameter['dataset_class']}") + if not issubclass(dataset_class, dataset.BaseDataset): + raise ValueError(f"Dataset class {dataset_parameter['dataset_class']} is not a subclass of BaseDataset.") + + dataset_ = dataset_class(path, logger, dataset_parameter["few_shot"]) + + dataset_.save(save_path) + inference_data[dataset_name] = dataset_.dataset["test"] + + for model_parameter in model_parameters: + model_name = model_parameter["name"] + model_class = eval(f"models.{model_parameter['model_class']}") + paramerters = model_parameter["parameters"] + paramerters.update({"logger": logger}) + paramerters.update({"prompt_template": utils.prompt_templates[paramerters["prompt_template"]]}) + + model_ = model_class(**paramerters) + if not issubclass(model_class, models.BaseModel): + raise ValueError(f"Model class {model_parameter['model_class']} is not a subclass of BaseModel.") + + for dataset_name, split_data in inference_data.items(): + start = 0 + for category, category_data in split_data.items(): + if few_shot_args[dataset_name] and category_data["inference_kwargs"].get("few_shot_data", None) is None: + raise Exception(f"Dataset {dataset_name} doesn't have few-shot data for category {category}!") + + answers_to_dump = copy.deepcopy(category_data) + partition_size = len(category_data["data"]) // world_size + redundant = len(category_data["data"]) % world_size + + # Ensure that the amount of data for inference is as consistent as possible across different processes. + lengths = [partition_size for _ in range(world_size)] + for j in range(redundant): + lengths[(j + start) % world_size] += 1 + + start = (start + redundant) % world_size + + questions = category_data["data"][sum(lengths[0:rank]) : sum(lengths[0:rank]) + lengths[rank]] + + answers_per_rank = model_.inference( + questions, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name] + ) + + answers_to_dump["data"] = answers_per_rank + + utils.jdump( + answers_to_dump, + os.path.join( + args.inference_save_path, + model_name, + f"{dataset_name}_{category}_inference_results_rank{rank}.json", + ), + ) + + logger.info(f"Rank {rank} peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB") + + del model_ + torch.cuda.empty_cache() + + dist.barrier() + if rank == 0: + model_names = [model_parameter["name"] for model_parameter in model_parameters] + dataset_names = {key: list(inference_data[key].keys()) for key in inference_data} + rm_and_merge(world_size, args.inference_save_path, model_names, dataset_names) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="ColossalEval inference process.") + parser.add_argument("--config", type=str, default=None, required=True, help="path to config file") + parser.add_argument("--load_dataset", default=False, action="store_true") + parser.add_argument("--inference_save_path", type=str, default=None, help="path to save inference results") + args = parser.parse_args() + + main(args) diff --git a/applications/ColossalEval/examples/dataset_evaluation/inference.sh b/applications/ColossalEval/examples/dataset_evaluation/inference.sh new file mode 100644 index 000000000000..15f9afd56045 --- /dev/null +++ b/applications/ColossalEval/examples/dataset_evaluation/inference.sh @@ -0,0 +1,4 @@ +torchrun --nproc_per_node=1 inference.py \ + --config "path to config file" \ + --load_dataset \ + --inference_save_path "path to save inference results" diff --git 
a/applications/ColossalEval/examples/gpt_evaluation/config/evaluation/config.json b/applications/ColossalEval/examples/gpt_evaluation/config/evaluation/config.json new file mode 100644 index 000000000000..6ebe3996b1cf --- /dev/null +++ b/applications/ColossalEval/examples/gpt_evaluation/config/evaluation/config.json @@ -0,0 +1,44 @@ +{ + "language": "en", + "category": { + "brainstorming": { + "GPT": [ + "language organization", + "relevance", + "creativity", + "practicality", + "reasonableness" + ] + }, + "chat": { + "GPT": [ + "language organization", + "naturalness", + "engagingness", + "fidelity" + ] + }, + "generation": { + "GPT": [ + "language organization", + "relevance", + "diversity" + ] + }, + "open_qa": { + "GPT": [ + "language organization", + "relevance", + "correctness" + ] + }, + "roleplay": { + "GPT": [ + "language organization", + "relevance", + "fidelity", + "creativity" + ] + } + } +} diff --git a/applications/ColossalEval/examples/gpt_evaluation/config/inference/config.json b/applications/ColossalEval/examples/gpt_evaluation/config/inference/config.json new file mode 100644 index 000000000000..7ed7491a87c5 --- /dev/null +++ b/applications/ColossalEval/examples/gpt_evaluation/config/inference/config.json @@ -0,0 +1,33 @@ +{ + "model": [ + { + "name": "model name", + "model_class": "HuggingFaceCausalLM", + "parameters": { + "path": "path to model", + "model_max_length": 4096, + "tokenizer_path": "", + "tokenizer_kwargs": { + "trust_remote_code": true + }, + "peft_path": null, + "model_kwargs": { + "torch_dtype": "torch.float32", + "trust_remote_code": true + }, + "prompt_template": "plain", + "batch_size": 4 + } + } + ], + "dataset": [ + { + "name": "colossal", + "dataset_class": "ColossalDataset", + "debug": false, + "few_shot": false, + "path": "../../configs/gpt_evaluation/data/eval_en_examples.json", + "save_path": "path to save converted dataset (inference_data/colossal.json)" + } + ] +} diff --git a/applications/Chat/evaluate/eval.py b/applications/ColossalEval/examples/gpt_evaluation/eval.py similarity index 78% rename from applications/Chat/evaluate/eval.py rename to applications/ColossalEval/examples/gpt_evaluation/eval.py index 16ef31a94175..cd521af59823 100644 --- a/applications/Chat/evaluate/eval.py +++ b/applications/ColossalEval/examples/gpt_evaluation/eval.py @@ -2,8 +2,8 @@ import os import openai -from evaluator import Evaluator -from utils import jload +from colossal_eval.evaluate.evaluator import Evaluator +from colossal_eval.utils import jload def main(args): @@ -51,12 +51,19 @@ def main(args): gpt_evaluation_prompt, args.gpt_model, config["language"], - config.get("path_for_UniEval", None), args.gpt_with_reference, ) if len(args.model_name_list) == 2: - answers1 = jload(args.answer_file_list[0]) - answers2 = jload(args.answer_file_list[1]) + answers_1 = jload(args.answer_file_list[0]) + answers_2 = jload(args.answer_file_list[1]) + + answers1 = [] + for category, value in answers_1.items(): + answers1.extend(value["data"]) + + answers2 = [] + for category, value in answers_2.items(): + answers2.extend(value["data"]) assert len(answers1) == len(answers2), "The number of answers for two models should be equal!" @@ -66,9 +73,21 @@ def main(args): targets = jload(args.target_file) answers = jload(args.answer_file_list[0]) - assert len(targets) == len(answers), "The number of target answers and model answers should be equal!" 
+ references = [] + for category, value in targets["test"].items(): + references.extend(value["data"]) + + predictions = [] + for category, value in answers.items(): + predictions.extend(value["data"]) - evaluator.evaluate(answers=answers, targets=targets) + assert len(references) == len( + predictions + ), "The number of target answers and model answers should be equal!" + + evaluator.evaluate( + answers=predictions, targets=references, save_path=args.save_path, model_name=args.model_name_list[0] + ) evaluator.save(args.save_path, args.model_name_list) else: raise ValueError("Unsupported number of answer files and model names!") @@ -99,8 +118,8 @@ def main(args): ) parser.add_argument( "--gpt_model", - default="gpt-3.5-turbo", - choices=["text-davinci-003", "gpt-3.5-turbo", "gpt-4"], + default="gpt-3.5-turbo-16k", + choices=["text-davinci-003", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4"], help="which GPT model to use for evaluation", ) parser.add_argument( diff --git a/applications/Chat/evaluate/eval.sh b/applications/ColossalEval/examples/gpt_evaluation/eval.sh old mode 100755 new mode 100644 similarity index 100% rename from applications/Chat/evaluate/eval.sh rename to applications/ColossalEval/examples/gpt_evaluation/eval.sh diff --git a/applications/ColossalEval/examples/gpt_evaluation/inference.py b/applications/ColossalEval/examples/gpt_evaluation/inference.py new file mode 100644 index 000000000000..657fc33bf1ef --- /dev/null +++ b/applications/ColossalEval/examples/gpt_evaluation/inference.py @@ -0,0 +1,171 @@ +import argparse +import copy +import os +from typing import Dict, List + +import torch +import torch.distributed as dist +from colossal_eval import dataset, models, utils + +import colossalai +from colossalai.logging import get_dist_logger + +logger = get_dist_logger() + + +def rm_and_merge(world_size: int, save_path: str, model_names: List[str], dataset_names: Dict[str, List]) -> None: + """ + Remove inference result per rank and merge them into one file. + + Args: + world_size: Number of processes for inference. + save_path: The folder for storing inference results. + model_names: Names of models for inference. + dataset_names: Names of dataset for inference. + + """ + + for model_name in model_names: + for dataset_name, categories in dataset_names.items(): + all_answers = {} + for category in categories: + all_answers[category] = {"data": []} + answers = {"data": []} + + for r in range(world_size): + directory = os.path.join( + save_path, model_name, f"{dataset_name}_{category}_inference_results_rank{r}.json" + ) + if not os.path.exists(directory): + raise Exception( + f"Directory {directory} not found. There may be an error during inference time." 
+ ) + else: + rank_answers = utils.jload(directory) + answers["data"].extend(rank_answers["data"]) + answers["inference_kwargs"] = rank_answers["inference_kwargs"] + + for r in range(world_size): + try: + directory = os.path.join( + save_path, model_name, f"{dataset_name}_{category}_inference_results_rank{r}.json" + ) + os.remove(directory) + except Exception as e: + print(e) + + all_answers[category] = answers + + logger.info(f"Save inference results of model {model_name} on dataset {dataset_name}.") + utils.jdump(all_answers, os.path.join(save_path, model_name, f"{dataset_name}_inference_results.json")) + + logger.info(f"Save inference results of model {model_name} for all dataset.") + logger.info(f"Save inference results of all models for all dataset.") + + +def main(args): + colossalai.launch_from_torch(config={}, seed=42) + world_size = dist.get_world_size() + rank = dist.get_rank() + + inference_data = {} + debug_args = {} + few_shot_args = {} + + config = utils.jload(args.config) + + model_parameters = config["model"] + dataset_parameters = config["dataset"] + + for dataset_parameter in dataset_parameters: + path = dataset_parameter["path"] + save_path = dataset_parameter["save_path"] + dataset_name = dataset_parameter["name"] + debug_args[dataset_name] = dataset_parameter["debug"] + few_shot_args[dataset_name] = dataset_parameter["few_shot"] + + if not args.load_dataset: + if os.path.exists(save_path): + dataset_ = utils.jload(save_path) + inference_data[dataset_name] = dataset_["test"] + else: + raise Exception( + "Can't find the converted dataset. You may set load_dataset True to store the dataset first." + ) + + continue + + dataset_class = eval(f"dataset.{dataset_parameter['dataset_class']}") + if not issubclass(dataset_class, dataset.BaseDataset): + raise ValueError(f"Dataset class {dataset_parameter['dataset_class']} is not a subclass of BaseDataset.") + + dataset_ = dataset_class(path, logger, dataset_parameter["few_shot"]) + + dataset_.save(save_path) + inference_data[dataset_name] = dataset_.dataset["test"] + + for model_parameter in model_parameters: + model_name = model_parameter["name"] + model_class = eval(f"models.{model_parameter['model_class']}") + paramerters = model_parameter["parameters"] + paramerters.update({"logger": logger}) + paramerters.update({"prompt_template": utils.prompt_templates[paramerters["prompt_template"]]}) + + model_ = model_class(**paramerters) + if not issubclass(model_class, models.BaseModel): + raise ValueError(f"Model class {model_parameter['model_class']} is not a subclass of BaseModel.") + + for dataset_name, split_data in inference_data.items(): + start = 0 + for category, category_data in split_data.items(): + if few_shot_args[dataset_name] and category_data["inference_kwargs"].get("few_shot_data", None) is None: + raise Exception(f"Dataset {dataset_name} doesn't have few-shot data for category {category}!") + + answers_to_dump = copy.deepcopy(category_data) + partition_size = len(category_data["data"]) // world_size + redundant = len(category_data["data"]) % world_size + + # Ensure that the amount of data for inference is as consistent as possible across different processes. 
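+ # A brief illustration with assumed numbers (not taken from any config): with world_size = 4,
+ # 10 questions and start = 0, partition_size is 2 and redundant is 2, so lengths becomes [3, 3, 2, 2];
+ # rank 0 then slices questions[0:3], rank 1 questions[3:6], rank 2 questions[6:8], rank 3 questions[8:10].
+ # Advancing `start` by `redundant` after each category rotates which ranks receive the extra questions.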
+ lengths = [partition_size for _ in range(world_size)] + for j in range(redundant): + lengths[(j + start) % world_size] += 1 + + start = (start + redundant) % world_size + + questions = category_data["data"][sum(lengths[0:rank]) : sum(lengths[0:rank]) + lengths[rank]] + + answers_per_rank = model_.inference( + questions, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name] + ) + + answers_to_dump["data"] = answers_per_rank + + utils.jdump( + answers_to_dump, + os.path.join( + args.inference_save_path, + model_name, + f"{dataset_name}_{category}_inference_results_rank{rank}.json", + ), + ) + + logger.info(f"Rank {rank} peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB") + + del model_ + torch.cuda.empty_cache() + + dist.barrier() + if rank == 0: + model_names = [model_parameter["name"] for model_parameter in model_parameters] + dataset_names = {key: list(inference_data[key].keys()) for key in inference_data} + rm_and_merge(world_size, args.inference_save_path, model_names, dataset_names) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="ColossalEval inference process.") + parser.add_argument("--config", type=str, default=None, required=True, help="path to config file") + parser.add_argument("--load_dataset", default=False, action="store_true") + parser.add_argument("--inference_save_path", type=str, default=None, help="path to save inference results") + args = parser.parse_args() + + main(args) diff --git a/applications/ColossalEval/examples/gpt_evaluation/inference.sh b/applications/ColossalEval/examples/gpt_evaluation/inference.sh new file mode 100644 index 000000000000..15f9afd56045 --- /dev/null +++ b/applications/ColossalEval/examples/gpt_evaluation/inference.sh @@ -0,0 +1,4 @@ +torchrun --nproc_per_node=1 inference.py \ + --config "path to config file" \ + --load_dataset \ + --inference_save_path "path to save inference results" diff --git a/applications/ColossalEval/requirements.txt b/applications/ColossalEval/requirements.txt new file mode 100644 index 000000000000..c110606e0303 --- /dev/null +++ b/applications/ColossalEval/requirements.txt @@ -0,0 +1,12 @@ +transformers>=4.32.0 +colossalai>=0.3.1 +peft +tabulate +jieba +fuzzywuzzy +rouge +openai +matplotlib +pandas +seaborn +scikit-learn diff --git a/applications/ColossalEval/setup.py b/applications/ColossalEval/setup.py new file mode 100644 index 000000000000..4f7b1bb5c42e --- /dev/null +++ b/applications/ColossalEval/setup.py @@ -0,0 +1,31 @@ +from setuptools import find_packages, setup + + +def fetch_requirements(path): + with open(path, "r") as fd: + return [r.strip() for r in fd.readlines()] + + +def fetch_readme(): + with open("README.md", encoding="utf-8") as f: + return f.read() + + +setup( + name="colossal_eval", + version="0.0.1", + packages=find_packages(exclude=["examples", "*.egg-info"]), + description="Colossal-AI LLM-Evaluation Framework", + long_description=fetch_readme(), + long_description_content_type="text/markdown", + license="Apache Software License 2.0", + url="https://github.com/hpcaitech/LLM-Evaluation", + install_requires=fetch_requirements("requirements.txt"), + python_requires=">=3.6", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Environment :: GPU :: NVIDIA CUDA", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + ], +) diff --git a/applications/README.md b/applications/README.md index ba9bd6e403cf..2a4c5ee3c56e 100644 --- 
a/applications/README.md +++ b/applications/README.md @@ -5,6 +5,7 @@ This directory contains the applications that are powered by Colossal-AI. The list of applications include: - [X] [Colossal-LLaMA-2](./Colossal-LLaMA-2/): Continual Pre-training of LLaMA-2. +- [X] [ColossalEval](./ColossalEval): Evaluation Pipeline for LLMs. - [X] [Chatbot](./Chat/README.md): Replication of ChatGPT with RLHF. - [X] [FastFold](https://github.com/hpcaitech/FastFold): Optimizing AlphaFold (Biomedicine) Training and Inference on GPU Clusters. From d512a4d38df375990591d58dad282481b6cfab05 Mon Sep 17 00:00:00 2001 From: binmakeswell Date: Mon, 25 Sep 2023 10:44:15 +0800 Subject: [PATCH 42/58] [doc] add llama2 domain-specific solution news (#4789) * [doc] add llama2 domain-specific solution news --- README.md | 34 ++++++++++++++++++++++-- applications/Colossal-LLaMA-2/README.md | 17 +++++++++--- docs/README-zh-Hans.md | 35 +++++++++++++++++++++++-- 3 files changed, 79 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 42549ac55873..a50cf496a98e 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ ## Latest News +* [2023/09] [One Half-Day of Training Using a Few Hundred Dollars Yields Similar Results to Mainstream Large Models, Open-Source and Commercial-Free Domain-Specific Llm Solution](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution) * [2023/09] [70 Billion Parameter LLaMA2 Model Training Accelerated by 195%](https://www.hpc-ai.tech/blog/70b-llama2-training) * [2023/07] [HPC-AI Tech Raises 22 Million USD in Series A Funding](https://www.hpc-ai.tech/blog/hpc-ai-tech-raises-22-million-usd-in-series-a-funding-to-fuel-team-expansion-and-business-growth) * [2023/07] [65B Model Pretraining Accelerated by 38%, Best Practices for Building LLaMA-Like Base Models Open-Source](https://www.hpc-ai.tech/blog/large-model-pretraining) @@ -33,8 +34,6 @@ * [2023/03] [AWS and Google Fund Colossal-AI with Startup Cloud Programs](https://www.hpc-ai.tech/blog/aws-and-google-fund-colossal-ai-with-startup-cloud-programs) * [2023/02] [Open Source Solution Replicates ChatGPT Training Process! Ready to go with only 1.6GB GPU Memory](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt) * [2023/01] [Hardware Savings Up to 46 Times for AIGC and Automatic Parallelism](https://medium.com/pytorch/latest-colossal-ai-boasts-novel-automatic-parallelism-and-offers-savings-up-to-46x-for-stable-1453b48f3f02) -* [2022/11] [Diffusion Pretraining and Hardware Fine-Tuning Can Be Almost 7X Cheaper](https://www.hpc-ai.tech/blog/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper) -* [2022/10] [Use a Laptop to Analyze 90% of Proteins, With a Single-GPU Inference Sequence Exceeding 10,000](https://www.hpc-ai.tech/blog/use-a-laptop-to-analyze-90-of-proteins-with-a-single-gpu-inference-sequence-exceeding) ## Table of Contents
    @@ -43,6 +42,7 @@
  • Colossal-AI for Real World Applications
      +
    • Colossal-LLaMA-2: One Half-Day of Training Using a Few Hundred Dollars Yields Similar Results to Mainstream Large Models, Open-Source and Commercial-Free Domain-Specific Llm Solution
    • ColossalChat: An Open-Source Solution for Cloning ChatGPT With a Complete RLHF Pipeline
    • AIGC: Acceleration of Stable Diffusion
    • Biomedicine: Acceleration of AlphaFold Protein Structure
    • @@ -127,6 +127,36 @@ distributed training and inference in a few lines. ## Colossal-AI in the Real World +### Colossal-LLaMA-2 + +- One half-day of training using a few hundred dollars yields similar results to mainstream large models, open-source and commercial-free domain-specific LLM solution. +[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA-2) +[[blog]](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution) +[[model weights]](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-7b-base) + +| | Backbone | Tokens Consumed | | MMLU | CMMLU | AGIEval | GAOKAO | CEval | +| :----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :-----: | :----: | :----: | :------------------------------: | +| | | - | | 5-shot | 5-shot | 5-shot | 0-shot | 5-shot | +| Baichuan-7B | - | 1.2T | | 42.32 (42.30) | 44.53 (44.02) | 38.72 | 36.74 | 42.80 | +| Baichuan-13B-Base | - | 1.4T | | 50.51 (51.60) | 55.73 (55.30) | 47.20 | 51.41 | 53.60 | +| Baichuan2-7B-Base | - | 2.6T | | 46.97 (54.16) | 57.67 (57.07) | 45.76 | 52.60 | 54.00 | +| Baichuan2-13B-Base | - | 2.6T | | 54.84 (59.17) | 62.62 (61.97) | 52.08 | 58.25 | 58.10 | +| ChatGLM-6B | - | 1.0T | | 39.67 (40.63) | 41.17 (-) | 40.10 | 36.53 | 38.90 | +| ChatGLM2-6B | - | 1.4T | | 44.74 (45.46) | 49.40 (-) | 46.36 | 45.49 | 51.70 | +| InternLM-7B | - | 1.6T | | 46.70 (51.00) | 52.00 (-) | 44.77 | 61.64 | 52.80 | +| Qwen-7B | - | 2.2T | | 54.29 (56.70) | 56.03 (58.80) | 52.47 | 56.42 | 59.60 | +| | | | | | | | | | +| Llama-2-7B | - | 2.0T | | 44.47 (45.30) | 32.97 (-) | 32.60 | 25.46 | - | +| Linly-AI/Chinese-LLaMA-2-7B-hf | Llama-2-7B | 1.0T | | 37.43 | 29.92 | 32.00 | 27.57 | - | +| wenge-research/yayi-7b-llama2 | Llama-2-7B | - | | 38.56 | 31.52 | 30.99 | 25.95 | - | +| ziqingyang/chinese-llama-2-7b | Llama-2-7B | - | | 33.86 | 34.69 | 34.52 | 25.18 | 34.2 | +| TigerResearch/tigerbot-7b-base | Llama-2-7B | 0.3T | | 43.73 | 42.04 | 37.64 | 30.61 | - | +| LinkSoul/Chinese-Llama-2-7b | Llama-2-7B | - | | 48.41 | 38.31 | 38.45 | 27.72 | - | +| FlagAlpha/Atom-7B | Llama-2-7B | 0.1T | | 49.96 | 41.10 | 39.83 | 33.00 | - | +| IDEA-CCNL/Ziya-LLaMA-13B-v1.1 | Llama-13B | 0.11T | | 50.25 | 40.99 | 40.04 | 30.54 | - | +| | | | | | | | | | +| **Colossal-LLaMA-2-7b-base** | Llama-2-7B | **0.0085T** | | 53.06 | 49.89 | 51.48 | 58.82 | 50.2 | + ### ColossalChat
      diff --git a/applications/Colossal-LLaMA-2/README.md b/applications/Colossal-LLaMA-2/README.md index 7274abbad0f5..f0a027d831a3 100644 --- a/applications/Colossal-LLaMA-2/README.md +++ b/applications/Colossal-LLaMA-2/README.md @@ -18,10 +18,14 @@ - [Data](#data) - [Tokenizer](#tokenizer) - [Training Strategy](#training-strategy) + - [Bridging Any Domain-specific Large Models](#bridging-any-domain-specific-large-models) - [Citations](#citations) ## News -* [2023/09] 🔥 TODO We released **Colossal-LLaMA-2-7B-base** based on LLaMA-2. [Download weights](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-7b-base). +* [2023/09] [One Half-Day of Training Using a Few Hundred Dollars Yields Similar Results to Mainstream Large Models, Open-Source and Commercial-Free Domain-Specific Llm Solution](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution) +[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA-2) +[[blog]](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution) +[[model weights]](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-7b-base) ## Colossal-LLaMA-2-7B The [Colossal-AI](https://github.com/hpcaitech/ColossalAI) team has introduced the open-source model **Colossal-LLaMA-2-7B-base**. This model, a derivation of LLaMA-2, has undergone continual pre-training involving approximately 8.5 billion tokens over a duration of 15 hours with 64 A800 GPUs. At a cost of **less than $1,000**, you can achieve results **similar to those that cost millions of dollars to pretrain from scratch**. It is licensed under the LLaMA-2 license and [Apache 2.0 License](https://github.com/hpcaitech/ColossalAI/blob/main/LICENSE) **without any additional commercial use restrictions**. This solution can also be used to build models of specific domain knowledge or tasks. @@ -47,7 +51,7 @@ The generation config for all dataset is greedy search. | Baichuan2-13B-Base | - | 2.6T | | 54.84 (59.17) | 62.62 (61.97) | 52.08 | 58.25 | 58.10 | | ChatGLM-6B | - | 1.0T | | 39.67 (40.63) | 41.17 (-) | 40.10 | 36.53 | 38.90 | | ChatGLM2-6B | - | 1.4T | | 44.74 (45.46) | 49.40 (-) | 46.36 | 45.49 | 51.70 | -| InternLM-7B | - | - | | 46.70 (51.00) | 52.00 (-) | 44.77 | 61.64 | 52.80 | +| InternLM-7B | - | 1.6T | | 46.70 (51.00) | 52.00 (-) | 44.77 | 61.64 | 52.80 | | Qwen-7B | - | 2.2T | | 54.29 (56.70) | 56.03 (58.80) | 52.47 | 56.42 | 59.60 | | | | | | | | | | | | Llama-2-7B | - | 2.0T | | 44.47 (45.30) | 32.97 (-) | 32.60 | 25.46 | - | @@ -96,7 +100,7 @@ We also recorded the training logs for the experiment

      -### Import from Transformers +### Import from Transformers (Inference) To load Colossal-LLaMA-2-7B-base model using Transformers, use the following code: ```Python from transformers import AutoModelForCausalLM, AutoTokenizer @@ -346,6 +350,13 @@ Our experiments have revealed that the distributions within the training dataset In an effort to achieve a more balanced distribution and exert control over the dataset's ordering, we have adopted a method where we divide each sub-dataset into discrete bins. These bins are then combined to construct individual data buckets, with one bin contributed by each sub-dataset. +### Bridging Any Domain-specific Large Models +Applying the above process to perform knowledge transfer in any field allows for the cost-effective construction of lightweight domain-specific foundational large models. + +

      + +

      + ## Citations ```bibtex @article{bian2021colossal, diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md index bb5f49bc546b..06977f9471c0 100644 --- a/docs/README-zh-Hans.md +++ b/docs/README-zh-Hans.md @@ -24,6 +24,7 @@
      ## 新闻 +* [2023/09] [One Half-Day of Training Using a Few Hundred Dollars Yields Similar Results to Mainstream Large Models, Open-Source and Commercial-Free Domain-Specific Llm Solution](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution) * [2023/09] [70 Billion Parameter LLaMA2 Model Training Accelerated by 195%](https://www.hpc-ai.tech/blog/70b-llama2-training) * [2023/07] [HPC-AI Tech Raises 22 Million USD in Series A Funding](https://www.hpc-ai.tech/blog/hpc-ai-tech-raises-22-million-usd-in-series-a-funding-to-fuel-team-expansion-and-business-growth) * [2023/07] [65B Model Pretraining Accelerated by 38%, Best Practices for Building LLaMA-Like Base Models Open-Source](https://www.hpc-ai.tech/blog/large-model-pretraining) @@ -32,8 +33,6 @@ * [2023/03] [AWS and Google Fund Colossal-AI with Startup Cloud Programs](https://www.hpc-ai.tech/blog/aws-and-google-fund-colossal-ai-with-startup-cloud-programs) * [2023/02] [Open Source Solution Replicates ChatGPT Training Process! Ready to go with only 1.6GB GPU Memory](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt) * [2023/01] [Hardware Savings Up to 46 Times for AIGC and Automatic Parallelism](https://medium.com/pytorch/latest-colossal-ai-boasts-novel-automatic-parallelism-and-offers-savings-up-to-46x-for-stable-1453b48f3f02) -* [2022/11] [Diffusion Pretraining and Hardware Fine-Tuning Can Be Almost 7X Cheaper](https://www.hpc-ai.tech/blog/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper) -* [2022/10] [Use a Laptop to Analyze 90% of Proteins, With a Single-GPU Inference Sequence Exceeding 10,000](https://www.hpc-ai.tech/blog/use-a-laptop-to-analyze-90-of-proteins-with-a-single-gpu-inference-sequence-exceeding) ## 目录
        @@ -42,6 +41,7 @@
      • Colossal-AI 成功案例
          +
        • Colossal-LLaMA-2: 千元预算半天训练,效果媲美主流大模型,开源可商用中文LLaMA-2
        • ColossalChat:完整RLHF流程0门槛克隆ChatGPT
        • AIGC: 加速 Stable Diffusion
        • 生物医药: 加速AlphaFold蛋白质结构预测
        • @@ -120,6 +120,37 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的

          (返回顶端)

          ## Colossal-AI 成功案例 +### Colossal-LLaMA-2 + +- 千元预算半天训练,效果媲美主流大模型,开源可商用中文LLaMA-2 +[[代码]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA-2) +[[博客]](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution) +[[模型权重]](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-7b-base) + +| | Backbone | Tokens Consumed | | MMLU | CMMLU | AGIEval | GAOKAO | CEval | +| :----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :-----: | :----: | :----: | :------------------------------: | +| | | - | | 5-shot | 5-shot | 5-shot | 0-shot | 5-shot | +| Baichuan-7B | - | 1.2T | | 42.32 (42.30) | 44.53 (44.02) | 38.72 | 36.74 | 42.80 | +| Baichuan-13B-Base | - | 1.4T | | 50.51 (51.60) | 55.73 (55.30) | 47.20 | 51.41 | 53.60 | +| Baichuan2-7B-Base | - | 2.6T | | 46.97 (54.16) | 57.67 (57.07) | 45.76 | 52.60 | 54.00 | +| Baichuan2-13B-Base | - | 2.6T | | 54.84 (59.17) | 62.62 (61.97) | 52.08 | 58.25 | 58.10 | +| ChatGLM-6B | - | 1.0T | | 39.67 (40.63) | 41.17 (-) | 40.10 | 36.53 | 38.90 | +| ChatGLM2-6B | - | 1.4T | | 44.74 (45.46) | 49.40 (-) | 46.36 | 45.49 | 51.70 | +| InternLM-7B | - | 1.6T | | 46.70 (51.00) | 52.00 (-) | 44.77 | 61.64 | 52.80 | +| Qwen-7B | - | 2.2T | | 54.29 (56.70) | 56.03 (58.80) | 52.47 | 56.42 | 59.60 | +| | | | | | | | | | +| Llama-2-7B | - | 2.0T | | 44.47 (45.30) | 32.97 (-) | 32.60 | 25.46 | - | +| Linly-AI/Chinese-LLaMA-2-7B-hf | Llama-2-7B | 1.0T | | 37.43 | 29.92 | 32.00 | 27.57 | - | +| wenge-research/yayi-7b-llama2 | Llama-2-7B | - | | 38.56 | 31.52 | 30.99 | 25.95 | - | +| ziqingyang/chinese-llama-2-7b | Llama-2-7B | - | | 33.86 | 34.69 | 34.52 | 25.18 | 34.2 | +| TigerResearch/tigerbot-7b-base | Llama-2-7B | 0.3T | | 43.73 | 42.04 | 37.64 | 30.61 | - | +| LinkSoul/Chinese-Llama-2-7b | Llama-2-7B | - | | 48.41 | 38.31 | 38.45 | 27.72 | - | +| FlagAlpha/Atom-7B | Llama-2-7B | 0.1T | | 49.96 | 41.10 | 39.83 | 33.00 | - | +| IDEA-CCNL/Ziya-LLaMA-13B-v1.1 | Llama-13B | 0.11T | | 50.25 | 40.99 | 40.04 | 30.54 | - | +| | | | | | | | | | +| **Colossal-LLaMA-2-7b-base** | Llama-2-7B | **0.0085T** | | 53.06 | 49.89 | 51.48 | 58.82 | 50.2 | + + ### ColossalChat
          From 26cd6d850cff113564adc87be3340abd29f99e9a Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 25 Sep 2023 16:19:33 +0800 Subject: [PATCH 43/58] [fix] fix weekly runing example (#4787) * [fix] fix weekly runing example * [fix] fix weekly runing example --- examples/tutorial/new_api/cifar_resnet/train.py | 2 +- examples/tutorial/new_api/cifar_vit/train.py | 2 +- examples/tutorial/new_api/glue_bert/finetune.py | 4 ++-- examples/tutorial/new_api/glue_bert/test_ci.sh | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/tutorial/new_api/cifar_resnet/train.py b/examples/tutorial/new_api/cifar_resnet/train.py index 6ae2d8b0412f..4407a51c3153 100644 --- a/examples/tutorial/new_api/cifar_resnet/train.py +++ b/examples/tutorial/new_api/cifar_resnet/train.py @@ -145,7 +145,7 @@ def main(): if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() elif args.plugin == "gemini": - plugin = GeminiPlugin(placement_policy="cuda", strict_ddp_mode=True, initial_scale=2**5) + plugin = GeminiPlugin(placement_policy="static", strict_ddp_mode=True, initial_scale=2**5) elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) diff --git a/examples/tutorial/new_api/cifar_vit/train.py b/examples/tutorial/new_api/cifar_vit/train.py index 226a4b320961..700e4d2e0cd9 100644 --- a/examples/tutorial/new_api/cifar_vit/train.py +++ b/examples/tutorial/new_api/cifar_vit/train.py @@ -165,7 +165,7 @@ def main(): if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() elif args.plugin == "gemini": - plugin = GeminiPlugin(placement_policy="cuda", strict_ddp_mode=True, initial_scale=2**5) + plugin = GeminiPlugin(placement_policy="static", strict_ddp_mode=True, initial_scale=2**5) elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) diff --git a/examples/tutorial/new_api/glue_bert/finetune.py b/examples/tutorial/new_api/glue_bert/finetune.py index 7d69dbc066b3..990822c9feba 100644 --- a/examples/tutorial/new_api/glue_bert/finetune.py +++ b/examples/tutorial/new_api/glue_bert/finetune.py @@ -21,7 +21,7 @@ # ============================== # Prepare Hyperparameters # ============================== -NUM_EPOCHS = 3 +NUM_EPOCHS = 1 BATCH_SIZE = 32 LEARNING_RATE = 2.4e-5 WEIGHT_DECAY = 0.01 @@ -141,7 +141,7 @@ def main(): if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() elif args.plugin == "gemini": - plugin = GeminiPlugin(placement_policy="cuda", strict_ddp_mode=True, initial_scale=2**5) + plugin = GeminiPlugin(placement_policy="static", strict_ddp_mode=True, initial_scale=2**5) elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) diff --git a/examples/tutorial/new_api/glue_bert/test_ci.sh b/examples/tutorial/new_api/glue_bert/test_ci.sh index c2c097f8d026..56dd431f1e60 100755 --- a/examples/tutorial/new_api/glue_bert/test_ci.sh +++ b/examples/tutorial/new_api/glue_bert/test_ci.sh @@ -4,5 +4,5 @@ set -xe pip install -r requirements.txt for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero"; do - torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin + torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.80 --plugin $plugin done From a2db75546d076c9fb8dbe0c4aba08e22f91dfdf5 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Tue, 26 Sep 2023 10:57:47 +0800 Subject: [PATCH 44/58] [doc] polish shardformer doc (#4779) * fix example format in docstring * polish shardformer doc --- 
colossalai/booster/plugin/gemini_plugin.py | 21 +++--- .../booster/plugin/hybrid_parallel_plugin.py | 17 +++-- .../booster/plugin/low_level_zero_plugin.py | 21 +++--- colossalai/booster/plugin/torch_ddp_plugin.py | 21 +++--- .../booster/plugin/torch_fsdp_plugin.py | 21 +++--- colossalai/cluster/dist_coordinator.py | 45 ++++++----- docs/source/en/features/shardformer.md | 74 ++++++++++++++++++- docs/source/zh-Hans/features/shardformer.md | 73 +++++++++++++++++- 8 files changed, 220 insertions(+), 73 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index abf3a907b777..ca722a0768dc 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -229,16 +229,17 @@ class GeminiPlugin(DPPluginBase): """ Plugin for Gemini. - Example: - >>> from colossalai.booster import Booster - >>> from colossalai.booster.plugin import GeminiPlugin - >>> - >>> model, train_dataset, optimizer, criterion = ... - >>> plugin = GeminiPlugin() - - >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) - >>> booster = Booster(plugin=plugin) - >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion) + ```python + from colossalai.booster import Booster + from colossalai.booster.plugin import GeminiPlugin + + model, train_dataset, optimizer, criterion = ... + plugin = GeminiPlugin() + + train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) + booster = Booster(plugin=plugin) + model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion) + ``` Args: chunk_config_dict (dict, optional): chunk configuration dictionary. diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 46930887bf9c..479ccc3eb36e 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -266,16 +266,17 @@ class HybridParallelPlugin(PipelinePluginBase): Tensor parallel, pipeline parallel and data parallel(DDP/ZeRO) can be picked and combined in this plugin. The size of tp and pp should be passed in by user, then the size of dp is automatically calculated from dp_size = world_size / (tp_size * pp_size). - Example: - >>> from colossalai.booster import Booster - >>> from colossalai.booster.plugin import HybridParallelPlugin + ```python + from colossalai.booster import Booster + from colossalai.booster.plugin import HybridParallelPlugin - >>> model, train_dataset, optimizer, criterion = ... - >>> plugin = HybridParallelPlugin(tp_size=2, pp_size=2) + model, train_dataset, optimizer, criterion = ... + plugin = HybridParallelPlugin(tp_size=2, pp_size=2) - >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) - >>> booster = Booster(plugin=plugin) - >>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader) + train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) + booster = Booster(plugin=plugin) + model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader) + ``` Args: tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1. 
diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 457c720f6418..0e515a55a8e3 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -213,16 +213,17 @@ class LowLevelZeroPlugin(DPPluginBase): """ Plugin for low level zero. - Example: - >>> from colossalai.booster import Booster - >>> from colossalai.booster.plugin import LowLevelZeroPlugin - >>> - >>> model, train_dataset, optimizer, criterion = ... - >>> plugin = LowLevelZeroPlugin() - - >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) - >>> booster = Booster(plugin=plugin) - >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion) + ```python + from colossalai.booster import Booster + from colossalai.booster.plugin import LowLevelZeroPlugin + + model, train_dataset, optimizer, criterion = ... + plugin = LowLevelZeroPlugin() + + train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) + booster = Booster(plugin=plugin) + model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion) + ``` Args: strage (int, optional): ZeRO stage. Defaults to 1. diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index 41d7c0635bf6..738634473dbc 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -130,16 +130,17 @@ class TorchDDPPlugin(DPPluginBase): """ Plugin for PyTorch DDP. - Example: - >>> from colossalai.booster import Booster - >>> from colossalai.booster.plugin import TorchDDPPlugin - >>> - >>> model, train_dataset, optimizer, criterion = ... - >>> plugin = TorchDDPPlugin() - - >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) - >>> booster = Booster(plugin=plugin) - >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion) + ```python + from colossalai.booster import Booster + from colossalai.booster.plugin import TorchDDPPlugin + + model, train_dataset, optimizer, criterion = ... + plugin = TorchDDPPlugin() + + train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) + booster = Booster(plugin=plugin) + model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion) + ``` Args: broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training. Defaults to True. diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index 1e3762b79016..2ea7593a5cc5 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -143,16 +143,17 @@ class TorchFSDPPlugin(DPPluginBase): """ Plugin for PyTorch FSDP. - Example: - >>> from colossalai.booster import Booster - >>> from colossalai.booster.plugin import TorchFSDPPlugin - >>> - >>> model, train_dataset, optimizer, criterion = ... 
- >>> plugin = TorchFSDPPlugin() - - >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8) - >>> booster = Booster(plugin=plugin) - >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion) + ```python + from colossalai.booster import Booster + from colossalai.booster.plugin import TorchFSDPPlugin + + model, train_dataset, optimizer, criterion = ... + plugin = TorchFSDPPlugin() + + train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8) + booster = Booster(plugin=plugin) + model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion) + ``` Args: See https://pytorch.org/docs/stable/fsdp.html for details. diff --git a/colossalai/cluster/dist_coordinator.py b/colossalai/cluster/dist_coordinator.py index 5b66e88717ba..98191747e5b3 100644 --- a/colossalai/cluster/dist_coordinator.py +++ b/colossalai/cluster/dist_coordinator.py @@ -20,14 +20,16 @@ class in the whole program. - master: the process with rank 0 - node master: the process with local rank 0 on the current node - Example: - >>> from colossalai.cluster.dist_coordinator import DistCoordinator - >>> coordinator = DistCoordinator() - >>> - >>> if coordinator.is_master(): - >>> do_something() - >>> - >>> coordinator.print_on_master('hello world') + + ```python + from colossalai.cluster.dist_coordinator import DistCoordinator + coordinator = DistCoordinator() + + if coordinator.is_master(): + do_something() + + coordinator.print_on_master('hello world') + ``` Attributes: rank (int): the rank of the current process @@ -131,11 +133,13 @@ def priority_execution(self, executor_rank: int = 0, process_group: ProcessGroup other processes in the same process group. This is often useful when downloading is required as we only want to download in one process to prevent file corruption. - Example: - >>> from colossalai.cluster import DistCoordinator - >>> dist_coordinator = DistCoordinator() - >>> with dist_coordinator.priority_execution(): - >>> dataset = CIFAR10(root='./data', download=True) + + ```python + from colossalai.cluster import DistCoordinator + dist_coordinator = DistCoordinator() + with dist_coordinator.priority_execution(): + dataset = CIFAR10(root='./data', download=True) + ``` Args: executor_rank (int): the process rank to execute without blocking, all other processes will be blocked @@ -174,13 +178,14 @@ def on_master_only(self, process_group: ProcessGroup = None): """ A function wrapper that only executes the wrapped function on the master process (rank 0). - Example: - >>> from colossalai.cluster import DistCoordinator - >>> dist_coordinator = DistCoordinator() - >>> - >>> @dist_coordinator.on_master_only() - >>> def print_on_master(msg): - >>> print(msg) + ```python + from colossalai.cluster import DistCoordinator + dist_coordinator = DistCoordinator() + + @dist_coordinator.on_master_only() + def print_on_master(msg): + print(msg) + ``` """ is_master = self.is_master(process_group) diff --git a/docs/source/en/features/shardformer.md b/docs/source/en/features/shardformer.md index 4abfff8a3cfa..a6e32d2c05fa 100644 --- a/docs/source/en/features/shardformer.md +++ b/docs/source/en/features/shardformer.md @@ -214,9 +214,56 @@ In addition, xFormers's `cutlass_op` can serve as a backup for flash attention. Enabling `Shardformer` through `Booster` initialized with `HybridParallelPlugin` is the recommended way to awaken the power of Shardformer. 
The main reason is that pipeline parallelism cannot successfully work without the calling of `execute_pipeline` method of `Booster`. Besides, `HybridParallelPlugin` provides the capacity to combine the features of `Shardformer` with other useful features, such as mixed precision training or Zero. -More details about this usage can be found in chapter [Booster API](../basics/booster_api.md) and [Booster Plugins](../basics/booster_plugins.md). +[Here](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert) is an example on how to trigger `Shardformer` through `HybridParallelPlugin`. Move to the root directory of this example, and execute +```bash +torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin "hybrid_parallel" --model_type "bert" +``` +Then you can start finetuning a bert model wrapped by `Shardformer`. The process of wrapping is operated by `HybridParallelPlugin`. + +Let's delve into the code of `finetune.py`: + +In the `main` function, the plugin is created through the following codes: +```python +... +elif args.plugin == "hybrid_parallel": + # modify the param accordingly for finetuning test cases + plugin = HybridParallelPlugin( + tp_size=1, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_all_optimization=True, + zero_stage=1, + precision="fp16", + initial_scale=1, + ) +``` +Here you can change the configuration of plugin by setting `tp_size`, `pp_size` or `zero_stage` to other values. More details about plugin configuration can be found in [Booster Plugins Doc](../basics/booster_plugins.md). + +If pipeline parallel is not enabled, just do the training in the same way of other booster plugins(first boost with Booster, then do forward and backward through normal way). +However, if pipeline parallel is enabled, there are several usages different from other normal cases: + +1. Before doing forward or backward, the criterion function (loss function) is processed to meet the argument demand of running pipeline: + ```python + def _criterion(outputs, inputs): + outputs = output_transform_fn(outputs) + loss = criterion(outputs) + return loss + ``` + +2. In `train_epoch` function, dataloader is converted into `Iterator` class before running pipeline: + ```python + train_dataloader_iter = iter(train_dataloader) + ``` -[Here](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert) is an example on how to trigger `Shardformer` through `HybridParallelPlugin`. Please be aware that there's a difference in the way of doing forward and backward between the situation of using pipeline and not using pipeline. +3. Do forward and backward passing through calling `Booster.execute_pipeline` method: + ```python + outputs = booster.execute_pipeline( + train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True + ) + ``` + Backward passing has been completed by this method, so there is no need to call `loss.backward()` after executing this method. + More details about `Booster.execute_pipeline` can be found in [Booster API Doc](../basics/booster_api.md). #### 2. Enabling Shardformer Through Shardformer APIs (Not Recommended) @@ -224,7 +271,26 @@ More details about this usage can be found in chapter [Booster API](../basics/bo You can also use Shardformer through manually calling Shardformer APIs. However, this usage is not recommended since pipeline parallelism can't run without `Booster`. 
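As a rough illustration of this second usage, the sketch below wraps a model with tensor parallelism only and then trains it with an ordinary forward/backward loop (no `execute_pipeline`). It is not taken from the example code: it assumes the distributed environment has already been launched, and `train_dataloader` is a placeholder yielding dictionaries of CUDA tensors that include labels.
```python
import torch
import torch.distributed as dist
from transformers import BertForSequenceClassification
from colossalai.shardformer import ShardConfig, ShardFormer

# Wrap the model with tensor parallelism only; the wrapping itself is shown
# in more detail in the example referenced below.
model = BertForSequenceClassification.from_pretrained("bert-base-uncased").cuda()
shard_config = ShardConfig(
    tensor_parallel_process_group=dist.new_group(backend="nccl"),
    enable_tensor_parallelism=True,
)
model, _ = ShardFormer(shard_config=shard_config).optimize(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# Ordinary training loop: forward, backward, step; no pipeline scheduling involved.
for batch in train_dataloader:
    outputs = model(**batch)
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```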
[Here](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/examples/convergence_benchmark.py) -is an example on how to trigger `Shardformer` through calling Shardformer APIs. +is an example on how to trigger `Shardformer` through calling Shardformer APIs. In the `train` function of example code, the model is wrapped by `Shardformer` through the following few codes: +```python +... +if dist.get_world_size() > 1: + tp_group = dist.new_group(backend="nccl") + + # First create configuration for Shardformer + shard_config = ShardConfig( + tensor_parallel_process_group=tp_group, + enable_tensor_parallelism=True, + enable_all_optimization=True + ) + + # Then create ShardFormer object with created config + shard_former = ShardFormer(shard_config=shard_config) + + # Finally shard the model using ShardFormer.optimize method + model, _ = shard_former.optimize(model) +... +``` ### Precautions @@ -241,6 +307,8 @@ is an example on how to trigger `Shardformer` through calling Shardformer APIs. ## How Shardformer Works +### Main Idea + Generally, Shardformer works through the following four kinds of *replacements*: 1. Replacing original PyTorch module (e.g. `nn.Linear`, `nn.Embedding`) with a crafted distributed module. diff --git a/docs/source/zh-Hans/features/shardformer.md b/docs/source/zh-Hans/features/shardformer.md index fe0e7a63ba44..99752a1ce4e0 100644 --- a/docs/source/zh-Hans/features/shardformer.md +++ b/docs/source/zh-Hans/features/shardformer.md @@ -207,8 +207,56 @@ Shardformer的配置由类`ShardConfig`的参数控制: 通过用`HybridParallelPlugin`初始化的`Booster`来启动`Shardformer`是最推荐的用法。其主要原因是如果不调用`Booster`的`execute_pipeline`方法,流水线并行就无法正常工作。此外,`HybridParallelPlugin`提供了将`Shardformer`的功能与其他功能(例如混合精度训练或Zero)相结合的能力。 -更多关于这一用法的细节可以参考 [Booster API 文档](../basics/booster_api.md)以及[Booster 插件文档](../basics/booster_plugins.md)。[这里](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert)是一个通过`HybridParallelPlugin`启动`Shardformer`的示例。 +[这里](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert)是一个通过`HybridParallelPlugin`启动`Shardformer`的示例。 +移动到示例的根目录下,执行命令: +```bash +torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin "hybrid_parallel" --model_type "bert" +``` +你便可以微调一个被`Shardformer`封装过的Bert模型,而封装的操作是由`HybridParallelPlugin`完成的。 + +接下来一起深入挖掘一下`finetune.py`里的代码: + +在`main`函数中,混合并行的插件通过以下的代码创建 +```python +... +elif args.plugin == "hybrid_parallel": + # modify the param accordingly for finetuning test cases + plugin = HybridParallelPlugin( + tp_size=1, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_all_optimization=True, + zero_stage=1, + precision="fp16", + initial_scale=1, + ) +``` +在这里你可以通过设置不同的`tp_size`, `pp_size` 或 `zero_stage`来改变插件的配置。更多关于插件配置的信息可以在[Booster 插件文档](../basics/booster_plugins.md)中被找到。 + +当流水并行不被启用的时候,训练的流程和其他的插件是一样的 (先用Booster封装模型和优化器,再用正常的方式做前向和后向传递)。然而,当流水线并行被启用的时候,有几处不同于寻常情况的用法: + +1. 在进行前向和后向之前,criterion函数(loss函数)需要被处理以满足流水线并行的传参要求: + ```python + def _criterion(outputs, inputs): + outputs = output_transform_fn(outputs) + loss = criterion(outputs) + return loss + ``` +2. 在 `train_epoch` 函数中, dataloader 在进行流水线的前向后向操作之前需要被转换为 `Iterator` 类: + ```python + train_dataloader_iter = iter(train_dataloader) + ``` + +3. 
通过调用`Booster.execute_pipeline` 方法来执行前向和后向传递: + ```python + outputs = booster.execute_pipeline( + train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True + ) + ``` + 该方法会自动执行后向传递,所以在执行该方法后不需要再调用 `loss.backward()`方法。 + 更多关于 `Booster.execute_pipeline` 的信息可以参考 [Booster API 文档](../basics/booster_api.md)。 #### 2. 通过Shardformer API启动Shardformer (不推荐) @@ -216,7 +264,26 @@ Shardformer的配置由类`ShardConfig`的参数控制: [这里](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/examples/convergence_benchmark.py) 是一个通过调用Shardformer的API启动`Shardformer`的示例。 - +在示例代码的`train`函数中,模型被以下的几行代码进行封装: +```python +... +if dist.get_world_size() > 1: + tp_group = dist.new_group(backend="nccl") + + # First create configuration for Shardformer + shard_config = ShardConfig( + tensor_parallel_process_group=tp_group, + enable_tensor_parallelism=True, + enable_all_optimization=True + ) + + # Then create ShardFormer object with created config + shard_former = ShardFormer(shard_config=shard_config) + + # Finally shard the model using ShardFormer.optimize method + model, _ = shard_former.optimize(model) +... +``` ### 注意事项 @@ -234,6 +301,8 @@ Shardformer的配置由类`ShardConfig`的参数控制: ## Shardformer的工作原理 +### 设计思想 + 通常来说,Shardformer通过以下四种“替换”进行工作: 1. 用我们设计的分布式模块替换原始的PyTorch模块(例如`nn.Linear`、`nn.Embedding`)。 From 64a08b2dc360de46229af16022eeadae12b2ff35 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Tue, 26 Sep 2023 10:58:03 +0800 Subject: [PATCH 45/58] [checkpointio] support unsharded checkpointIO for hybrid parallel (#4774) * support unsharded saving/loading for model * support optimizer unsharded saving * update doc * support unsharded loading for optimizer * small fix --- .../hybrid_parallel_checkpoint_io.py | 218 ++++++++++++++++-- docs/source/en/basics/booster_plugins.md | 2 - docs/source/zh-Hans/basics/booster_plugins.md | 2 - ...st_hybrid_parallel_plugin_checkpoint_io.py | 3 +- 4 files changed, 197 insertions(+), 28 deletions(-) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 41e53b3b388f..779ff42d75a1 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -9,7 +9,6 @@ import torch.distributed as dist import torch.nn as nn from torch.distributed import ProcessGroup -from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from colossalai.cluster import DistCoordinator @@ -24,10 +23,12 @@ get_optimizer_base_filenames, is_safetensors_available, load_shard_state_dict, + load_state_dict, load_state_dict_into_model, load_states_into_optimizer, save_config_file, save_param_groups, + save_state_dict, save_state_dict_shards, search_tp_partition_dim, sharded_optimizer_loading_epilogue, @@ -119,13 +120,13 @@ def _optimizer_sharder( use_zero: bool, dp_group: ProcessGroup, tp_group: ProcessGroup, - master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, size_per_shard: int = 1024, ): # An internel method that breaks state_dict of optimizer into shards within limited size. 
state_dict_sharder = StateDictSharder(size_per_shard) param_info = optimizer.param_info + master_to_working_map = optimizer.get_master_to_working_map() for param, state in optimizer.optim.state.items(): if param is None: @@ -217,7 +218,7 @@ def save_sharded_model( index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) save_config_file(model, checkpoint) - if self.verbose: + if self.verbose and self.coordinator.is_master(): logging.info( f"The model is split into checkpoint shards. " f"You can find where each parameters has been saved in the " @@ -273,7 +274,7 @@ def save_sharded_model( final_index_file.write_index_file(final_index_file_path) save_config_file(model, checkpoint) rmtree(tmp_index_file_folder) - if self.verbose: + if self.verbose and self.coordinator.is_master(): logging.info( f"The model is split into checkpoint shards. " f"You can find where each parameters has been saved in the " @@ -353,7 +354,7 @@ def _load(name: str): # Update master params if mixed-precision training is enabled. model_before_wrapping.update_master_params() - if self.verbose: + if self.verbose and self.coordinator.is_master(): logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") def save_sharded_optimizer( @@ -399,7 +400,6 @@ def save_sharded_optimizer( use_zero=self.use_zero, dp_group=self.dp_group, tp_group=self.tp_group, - master_to_working_map=optimizer.get_master_to_working_map(), size_per_shard=size_per_shard, ) states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) @@ -424,7 +424,7 @@ def save_sharded_optimizer( # Store index file. index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) - if self.verbose: + if self.verbose and self.coordinator.is_master(): logging.info( f"The optimizer is going to be split to checkpoint shards. " f"You can find where each parameters has been saved in the " @@ -484,7 +484,7 @@ def save_sharded_optimizer( final_index_file.write_index_file(final_index_file_path) rmtree(tmp_index_file_folder) - if self.verbose: + if self.verbose and self.coordinator.is_master(): logging.info( f"The model is split into checkpoint shards. " f"You can find where each parameters has been saved in the " @@ -579,24 +579,196 @@ def _get_param_id_from_optimizer_param( optimizer.optim.state[param] = sharded_state sharded_optimizer_loading_epilogue(optimizer.optim) - if self.verbose: + if self.verbose and self.coordinator.is_master(): logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") - def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True): - # TODO(Baizhou): support this feature after implementing complete state_dict collection - raise NotImplementedError + def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + """ + Save model state dict to a single file with given checkpointing path. + + Args: + model (nn.Module): Model on local device to be saved. + checkpoint (str): Checkpointing path which should be a file path. Can be absolute or relative path. + gather_dtensor (bool, optional): Whether to gather dtensor, currently not used. Defaults to True. + use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. 
+ """ + if self.coordinator.is_master(): + logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") + + assert isinstance(model, ModelWrapper), "Please boost the model before saving!" + model = model.unwrap() + + if self.dp_rank != 0: + return + + # The logic of collecting parameter shards along tp degree + # has been implemented by _save_to_state_dict method of ParallelModule in Shardformer. + state_dict = model.state_dict() + + if self.pp_size == 1: + # When pipeline is not used, let master rank directly save the collected state_dict. + if self.tp_rank == 0: + save_state_dict(state_dict, checkpoint, use_safetensors) + else: + # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict. + state_dict_list = [None for _ in range(self.pp_size)] + dist.barrier(self.pp_group) + dist.all_gather_object(state_dict_list, state_dict, self.pp_group) + + # Only the master rank do the saving. + if self.coordinator.is_master(): + complete_state_dict = dict() + for _state_dict in state_dict_list: + complete_state_dict.update(_state_dict) + save_state_dict(complete_state_dict, checkpoint, use_safetensors) + + def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = False): + """ + Load model from a single file with the given path of checkpoint. + + Args: + model (nn.Module): The model to be loaded. + checkpoint_index_file (str): Path to the checkpoint file. + strict (bool, optional): For name matching during loading state_dict. Defaults to False. + This argument should be manually set to False since not all params in checkpoint are needed for each device when pipeline is enabled. + """ + if self.coordinator.is_master(): + logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") + + assert isinstance(model, ModelWrapper), "Please boost the model before loading!" + strict = False + model_before_wrapping = model + model = model.unwrap() + + # Load from checkpoint. Since the logic of breaking parameter shards along tp degree + # has been implemented by _load_from_state_dict method of ParallelModule in Shardformer, + # model.load_state_dict can be directly called. + state_dict = load_state_dict(checkpoint) + model.load_state_dict(state_dict, strict=strict) + + # Update master params if mixed-precision training is enabled. + model_before_wrapping.update_master_params() + + def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): + """ + Save optimizer state dict to a file with given path. + + Args: + optimizer (OptimizerWrapper): Optimizer to save sharded state_dict. + checkpoint (str): Path to save optimizer state_dict. + gather_dtensor (bool): Whether to gather_dtensor, not used. + """ + if self.coordinator.is_master(): + logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") + + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" 
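+        # Overview of the steps below: each rank first reconstructs the complete
+        # optimizer state of the parameters it holds locally, gathering shards across
+        # the tensor-parallel group (and the data-parallel group when ZeRO is used);
+        # the per-pipeline-stage results are then all-gathered across the pipeline
+        # group, and the master rank writes the merged state_dict to a single file.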
+ + # optimizer states of parameters kept by local device('s pipeline stage) + local_states = dict() + + for param, state in optimizer.optim.state.items(): + if param is None: + continue + + # working param is needed for obtaining correct param_id + master_to_working_map = optimizer.get_master_to_working_map() + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + + # gather complete state from tp shards & dp shards + param_id = optimizer.param_info["param2id"][id(working_param)] + original_shape = optimizer.param_info["param2shape"][id(working_param)] + local_states[param_id] = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state( + state, + working_param, + original_shape=original_shape, + dp_group=self.dp_group, + tp_group=self.tp_group, + use_zero=self.use_zero, + inplace=False, + device=torch.device("cuda"), + ) + + if self.pp_size == 1: + # When pipeline is not used, let master rank directly save the collected state_dict. + state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": local_states} + if self.coordinator.is_master(): + save_state_dict(state_dict, checkpoint, use_safetensors=False) + else: + # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict. + states_list = [None for _ in range(self.pp_size)] + dist.barrier(self.pp_group) + dist.all_gather_object(states_list, local_states, self.pp_group) - def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): - # TODO(Baizhou): support this feature after implementing complete state_dict collection - raise NotImplementedError + # Only the master rank do the saving. + if self.coordinator.is_master(): + state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": dict()} + for _states in states_list: + state_dict["state"].update(_states) + save_state_dict(state_dict, checkpoint, use_safetensors=False) - def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): - # TODO(Baizhou): support this feature after implementing complete state_dict collection - raise NotImplementedError + def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str): + """ + Load optimizer from a file with given path. - def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): - # TODO(Baizhou): support this feature after implementing complete state_dict collection - raise NotImplementedError + Args: + optimizer (OptimizerWrapper): The optimizer to be loaded. + checkpoint_index_file (str): Path to the checkpoint file. + """ + + def _get_param_id_from_optimizer_param( + param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None + ): + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + return optimizer.param_info["param2id"][id(working_param)] + + if self.coordinator.is_master(): + logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") + + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" + + # Complete optimizer state_dict loaded from checkpoint, need to be processed later. + state_dict = load_state_dict(checkpoint) + + # Load param_groups. 
+ updated_groups = [] + saved_groups = state_dict["param_groups"] + for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups): + new_pg = copy.deepcopy(saved_pg) + new_pg["params"] = old_pg["params"] # Only keep the parameters kept by current pipeline stage. + updated_groups.append(new_pg) + optimizer.optim.__dict__.update({"param_groups": updated_groups}) + + # Load saved states to optimizer. First discard those states not belonging to current pipeline stage. + master_to_working_map = optimizer.get_master_to_working_map() + id_map = {} + for pg in optimizer.optim.param_groups: + for param in pg["params"]: + param_id = _get_param_id_from_optimizer_param(param, master_to_working_map) + id_map[param_id] = param + load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True) + + # Then shard the loaded optimizer states if using tp/zero. + for param, state in optimizer.optim.state.items(): + if param is None: + continue + device = param.device + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + original_shape = optimizer.param_info["param2shape"][id(working_param)] + sharded_state = self.shard_from_complete_optimizer_state( + state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True + ) + optimizer.optim.state[param] = sharded_state + + sharded_optimizer_loading_epilogue(optimizer.optim) def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): """ @@ -614,6 +786,7 @@ def gather_from_sharded_optimizer_state( tp_group: ProcessGroup, use_zero: bool, inplace: bool, + device: torch.device = torch.device("cpu"), ) -> OrderedDict: """ With given parameter and its optimizer states, gather the complete optimizer state for saving. @@ -626,6 +799,7 @@ def gather_from_sharded_optimizer_state( tp_group (ProcessGroup): The process group of tensor parallel. use_zero (bool): Whether Zero is used. inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state. + device (torch.device): The destination device of loaded optimizer states. Defaults to torch.device('cpu'). Returns: OrderedDict: The complete optimizer state of given parameter. @@ -651,7 +825,7 @@ def gather_from_sharded_optimizer_state( dist.all_gather(gather_tensor, v, group=tp_group) v = torch.cat(gather_tensor, dim=partition_dim) - state_[k] = v.detach().clone().cpu() + state_[k] = v.detach().clone().to(device) return state_ diff --git a/docs/source/en/basics/booster_plugins.md b/docs/source/en/basics/booster_plugins.md index 57fa813436da..a3df44fc6780 100644 --- a/docs/source/en/basics/booster_plugins.md +++ b/docs/source/en/basics/booster_plugins.md @@ -74,8 +74,6 @@ This plugin implements the combination of various parallel training strategies a > ⚠ When using this plugin, only the subset of Huggingface transformers supported by Shardformer are compatible with tensor parallel, pipeline parallel and optimization tools. Mainstream transformers such as Llama 1, Llama 2, OPT, Bloom, Bert and GPT2 etc. are all supported by Shardformer. -> ⚠ This plugin only supports sharded checkpointing methods for model/optimizer at present. Unsharded checkpointing methods will be supported in future release. 
- {{ autodoc:colossalai.booster.plugin.HybridParallelPlugin }} ### Torch DDP Plugin diff --git a/docs/source/zh-Hans/basics/booster_plugins.md b/docs/source/zh-Hans/basics/booster_plugins.md index d4ef7012ff67..8d8a288da949 100644 --- a/docs/source/zh-Hans/basics/booster_plugins.md +++ b/docs/source/zh-Hans/basics/booster_plugins.md @@ -71,8 +71,6 @@ Zero-2 不支持局部梯度累积。如果您坚持使用,虽然可以积累 > ⚠ 在使用该插件的时候, 只有支持Shardformer的部分Huggingface transformers模型才能够使用张量并行、流水线并行以及优化工具。Llama 1、Llama 2、OPT、Bloom、Bert以及GPT2等主流transformers模型均已支持Shardformer。 -> ⚠ 该插件当前只对模型和优化器支持分片的checkpoint方法。不分片的checkpoint方法会在未来的版本中被支持。 - {{ autodoc:colossalai.booster.plugin.HybridParallelPlugin }} ### Torch DDP 插件 diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index e8bb8f9e3475..711bd4d214a8 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -20,9 +20,8 @@ from tests.kit.model_zoo import model_zoo -# TODO (Baizhou): Add test cases for shard=False @clear_cache_before_run() -@parameterize("shard", [True]) +@parameterize("shard", [True, False]) @parameterize("model_name", ["transformers_gpt"]) @parameterize("size_per_shard", [32]) @parameterize( From bd014673b07fdc561be8c84fe78e085f9af1897c Mon Sep 17 00:00:00 2001 From: Tong Li Date: Tue, 26 Sep 2023 10:58:05 +0800 Subject: [PATCH 46/58] update readme --- applications/Colossal-LLaMA-2/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/Colossal-LLaMA-2/README.md b/applications/Colossal-LLaMA-2/README.md index f0a027d831a3..3470e8494ca0 100644 --- a/applications/Colossal-LLaMA-2/README.md +++ b/applications/Colossal-LLaMA-2/README.md @@ -73,7 +73,7 @@ The generation config for all dataset is greedy search. > > For other models and other dataset, we calculate logits over "A", "B", "C" and "D". -❗️ More details of the evaluation methods and reproduction of the results, please refer to [TODO: ColossalEval](). +❗️ More details of the evaluation methods and reproduction of the results, please refer to [ColossalEval](https://github.com/hpcaitech/ColossalAI/tree/main/applications/ColossalEval). ### Examples | Question Type | Question |
          Colossal-LLaMA-2-7b-base
          | From 4965c0dabd2df118a0ae33fd6f291fab617600a3 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Tue, 26 Sep 2023 11:04:11 +0800 Subject: [PATCH 47/58] [lazy] support from_pretrained (#4801) * [lazy] patch from pretrained * [lazy] fix from pretrained and add tests * [devops] update ci --- .github/workflows/build_on_pr.yml | 3 +- .github/workflows/build_on_schedule.yml | 3 +- .../compatiblity_test_on_dispatch.yml | 3 +- .github/workflows/compatiblity_test_on_pr.yml | 3 +- .../compatiblity_test_on_schedule.yml | 3 +- colossalai/booster/booster.py | 8 + colossalai/interface/pretrained.py | 16 + colossalai/lazy/lazy_init.py | 3 + colossalai/lazy/pretrained.py | 309 ++++++++++++++++++ .../test_gemini_checkpoint_io.py | 20 ++ tests/test_lazy/test_from_pretrained.py | 31 ++ 11 files changed, 397 insertions(+), 5 deletions(-) create mode 100644 colossalai/interface/pretrained.py create mode 100644 colossalai/lazy/pretrained.py create mode 100644 tests/test_lazy/test_from_pretrained.py diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index 291d6adac2b2..e2114d43bcd0 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -141,7 +141,7 @@ jobs: runs-on: [self-hosted, gpu] container: image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 - options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 + options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny timeout-minutes: 60 defaults: run: @@ -214,6 +214,7 @@ jobs: NCCL_SHM_DISABLE: 1 LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 TESTMON_CORE_PKGS: /__w/ColossalAI/ColossalAI/requirements/requirements.txt,/__w/ColossalAI/ColossalAI/requirements/requirements-test.txt + LLAMA_PATH: /data/scratch/llama-tiny - name: Store Testmon Cache run: | diff --git a/.github/workflows/build_on_schedule.yml b/.github/workflows/build_on_schedule.yml index 03b47e6cb5b6..6c77377be34f 100644 --- a/.github/workflows/build_on_schedule.yml +++ b/.github/workflows/build_on_schedule.yml @@ -13,7 +13,7 @@ jobs: runs-on: [self-hosted, 8-gpu] container: image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 - options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 + options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny timeout-minutes: 40 steps: - name: Check GPU Availability # ensure all GPUs have enough memory @@ -64,6 +64,7 @@ jobs: env: DATA: /data/scratch/cifar-10 LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 + LLAMA_PATH: /data/scratch/llama-tiny - name: Notify Lark id: message-preparation diff --git a/.github/workflows/compatiblity_test_on_dispatch.yml b/.github/workflows/compatiblity_test_on_dispatch.yml index 2f03c8ced98d..5083212993cc 100644 --- a/.github/workflows/compatiblity_test_on_dispatch.yml +++ b/.github/workflows/compatiblity_test_on_dispatch.yml @@ -50,7 +50,7 @@ jobs: matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} container: image: ${{ matrix.container }} - options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 + options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny timeout-minutes: 120 steps: - name: Install dependencies @@ -92,3 +92,4 @@ jobs: DATA: /data/scratch/cifar-10 NCCL_SHM_DISABLE: 1 LD_LIBRARY_PATH: 
/github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 + LLAMA_PATH: /data/scratch/llama-tiny diff --git a/.github/workflows/compatiblity_test_on_pr.yml b/.github/workflows/compatiblity_test_on_pr.yml index a621c7e3427d..cc17c66f9c3a 100644 --- a/.github/workflows/compatiblity_test_on_pr.yml +++ b/.github/workflows/compatiblity_test_on_pr.yml @@ -41,7 +41,7 @@ jobs: matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} container: image: ${{ matrix.container }} - options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 + options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny timeout-minutes: 120 concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-test-${{ matrix.container }} @@ -87,3 +87,4 @@ jobs: DATA: /data/scratch/cifar-10 NCCL_SHM_DISABLE: 1 LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 + LLAMA_PATH: /data/scratch/llama-tiny diff --git a/.github/workflows/compatiblity_test_on_schedule.yml b/.github/workflows/compatiblity_test_on_schedule.yml index 9933224f5675..158fe751bf2e 100644 --- a/.github/workflows/compatiblity_test_on_schedule.yml +++ b/.github/workflows/compatiblity_test_on_schedule.yml @@ -38,7 +38,7 @@ jobs: matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} container: image: ${{ matrix.container }} - options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 + options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny timeout-minutes: 120 steps: - name: Install dependencies @@ -85,6 +85,7 @@ jobs: DATA: /data/scratch/cifar-10 NCCL_SHM_DISABLE: 1 LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 + LLAMA_PATH: /data/scratch/llama-tiny - name: Notify Lark id: message-preparation diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py index 8d6b0b42e545..d73bc5babd80 100644 --- a/colossalai/booster/booster.py +++ b/colossalai/booster/booster.py @@ -8,6 +8,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader +import colossalai.interface.pretrained as pretrained_utils from colossalai.checkpoint_io import GeneralCheckpointIO from colossalai.interface import ModelWrapper, OptimizerWrapper @@ -131,6 +132,7 @@ def boost( """ # TODO(FrankLeeeee): consider multi-model and multi-optimizer case # TODO(FrankLeeeee): consider multi-dataloader case + pretrained_path = pretrained_utils.get_pretrained_path(model) # transform model for mixed precision if self.plugin: model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure( @@ -146,6 +148,12 @@ def boost( # when mixed_precision is specified and the plugin is not given or does not control the precision model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion) + if pretrained_path: + self.load_model(model, pretrained_path) + # clear pretrained path attr + orig_model = model.unwrap() if isinstance(model, ModelWrapper) else model + pretrained_utils.set_pretrained_path(orig_model, None) + return model, optimizer, criterion, dataloader, lr_scheduler def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None: diff --git a/colossalai/interface/pretrained.py b/colossalai/interface/pretrained.py new file mode 100644 index 000000000000..2f6bc10cd132 --- /dev/null +++ 
b/colossalai/interface/pretrained.py @@ -0,0 +1,16 @@ +from typing import Optional + +from torch.nn import Module + +__all__ = [ + "get_pretrained_path", + "set_pretrained_path", +] + + +def get_pretrained_path(model: Module) -> Optional[str]: + return getattr(model, "_pretrained", None) + + +def set_pretrained_path(model: Module, path: str) -> None: + setattr(model, "_pretrained", path) diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py index f29e997da495..a03334b28245 100644 --- a/colossalai/lazy/lazy_init.py +++ b/colossalai/lazy/lazy_init.py @@ -11,6 +11,7 @@ from colossalai.logging import get_dist_logger from .construction import ConstructorManager +from .pretrained import PretrainedManager import colossalai._analyzer._subclasses._meta_registration # noqa @@ -595,11 +596,13 @@ def wrapper(*args, **kwargs): ) ConstructorManager.apply(overrides) + PretrainedManager.inject() def __exit__(self, exc_type, exc_val, exc_tb): self.tensor_cls.default_device = self.old_default_device LazyInitContext._replaced = False ConstructorManager.clear() + PretrainedManager.recover() @staticmethod def materialize(module: nn.Module, verbose: bool = False) -> nn.Module: diff --git a/colossalai/lazy/pretrained.py b/colossalai/lazy/pretrained.py new file mode 100644 index 000000000000..21d44d4244d3 --- /dev/null +++ b/colossalai/lazy/pretrained.py @@ -0,0 +1,309 @@ +import os +from typing import Callable, Optional, Union + +import torch +from torch.nn import Module + +from colossalai.interface import pretrained as pretrained_interface + + +class PretrainedManager: + old_from_pretrained: Optional[Callable] = None + + @staticmethod + def inject() -> None: + try: + from transformers.modeling_utils import PreTrainedModel + except ImportError: + return + # recover bound method to plain function + PretrainedManager.old_from_pretrained = PreTrainedModel.from_pretrained.__func__ + PreTrainedModel.from_pretrained = new_from_pretrained + + @staticmethod + def recover() -> None: + try: + from transformers.modeling_utils import PreTrainedModel + except ImportError: + return + # convert plain function to class method + PreTrainedModel.from_pretrained = classmethod(PretrainedManager.old_from_pretrained) + PretrainedManager.old_from_pretrained = None + + +@classmethod +def new_from_pretrained( + cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs +) -> Module: + from transformers import GenerationConfig + from transformers.configuration_utils import PretrainedConfig + from transformers.modeling_utils import ( + ContextManagers, + _add_variant, + cached_file, + download_url, + has_file, + is_offline_mode, + is_remote_url, + no_init_weights, + ) + from transformers.utils import ( + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, + is_safetensors_available, + logging, + ) + + logger = logging.get_logger(__name__) + + config = kwargs.pop("config", None) + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + _ = kwargs.pop("mirror", None) + from_pipeline = kwargs.pop("_from_pipeline", None) + from_auto_class = kwargs.pop("_from_auto", False) + _fast_init = kwargs.pop("_fast_init", True) + torch_dtype = kwargs.pop("torch_dtype", None) + 
subfolder = kwargs.pop("subfolder", "") + commit_hash = kwargs.pop("_commit_hash", None) + variant = kwargs.pop("variant", None) + use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False) + + if len(kwargs) > 0: + logger.warning(f"Below kwargs may be ignored: {list(kwargs.keys())}") + + from_pt = True + + user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class} + if from_pipeline is not None: + user_agent["using_pipeline"] = from_pipeline + + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + + # Load config if we don't provide a configuration + if not isinstance(config, PretrainedConfig): + config_path = config if config is not None else pretrained_model_name_or_path + config, model_kwargs = cls.config_class.from_pretrained( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + _from_auto=from_auto_class, + _from_pipeline=from_pipeline, + **kwargs, + ) + else: + model_kwargs = kwargs + + if commit_hash is None: + commit_hash = getattr(config, "_commit_hash", None) + + # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the + # index of the files. + + if pretrained_model_name_or_path is not None: + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + is_local = os.path.isdir(pretrained_model_name_or_path) + if is_local: + if use_safetensors is not False and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)) + ): + # Load from a safetensors checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant) + ) + elif use_safetensors is not False and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)) + ): + # Load from a sharded safetensors checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant) + ) + elif os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)) + ): + # Load from a PyTorch checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant) + ) + elif os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)) + ): + # Load from a sharded PyTorch checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant) + ) + else: + raise EnvironmentError( + f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory" + f" {pretrained_model_name_or_path}." 
+ ) + elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)): + archive_file = pretrained_model_name_or_path + is_local = True + elif is_remote_url(pretrained_model_name_or_path): + filename = pretrained_model_name_or_path + resolved_archive_file = download_url(pretrained_model_name_or_path) + else: + # set correct filename + if use_safetensors is not False: + filename = _add_variant(SAFE_WEIGHTS_NAME, variant) + else: + filename = _add_variant(WEIGHTS_NAME, variant) + + try: + # Load from URL or cache if already cached + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "use_auth_token": use_auth_token, + "user_agent": user_agent, + "revision": revision, + "subfolder": subfolder, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + } + resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) + + # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None + # result when internet is up, the repo and revision exist, but the file does not. + if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant): + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + resolved_archive_file = cached_file( + pretrained_model_name_or_path, + _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant), + **cached_file_kwargs, + ) + if resolved_archive_file is not None: + pass + elif use_safetensors: + raise EnvironmentError( + f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} and thus cannot be loaded with `safetensors`. Please make sure that the model has been saved with `safe_serialization=True` or do not set `use_safetensors=True`." + ) + else: + # This repo has no safetensors file of any kind, we switch to PyTorch. + filename = _add_variant(WEIGHTS_NAME, variant) + resolved_archive_file = cached_file( + pretrained_model_name_or_path, filename, **cached_file_kwargs + ) + if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant): + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + resolved_archive_file = cached_file( + pretrained_model_name_or_path, + _add_variant(WEIGHTS_INDEX_NAME, variant), + **cached_file_kwargs, + ) + if resolved_archive_file is not None: + pass + if resolved_archive_file is None: + # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error + # message. + has_file_kwargs = { + "revision": revision, + "proxies": proxies, + "use_auth_token": use_auth_token, + } + if variant is not None and has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs): + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant" + f" {variant}. Use `variant=None` to load this model from those weights." + ) + else: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)}" + ) + except EnvironmentError: + # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted + # to the original exception. + raise + except Exception: + # For any other exception, we throw a generic error. 
+ raise EnvironmentError( + f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it" + " from 'https://huggingface.co/models', make sure you don't have a local directory with the" + f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" + f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)}." + ) + + if is_local: + logger.info(f"loading weights file {archive_file}") + resolved_archive_file = archive_file + else: + logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}") + else: + resolved_archive_file = None + + if from_pt: + # set dtype to instantiate the model under: + # 1. If torch_dtype is not None, we use that dtype + dtype_orig = None + + if torch_dtype is not None: + if not isinstance(torch_dtype, torch.dtype): + raise ValueError(f"`torch_dtype` can be either `torch.dtype` or `None`, but received {torch_dtype}") + dtype_orig = cls._set_default_torch_dtype(torch_dtype) + + config.name_or_path = pretrained_model_name_or_path + + # Instantiate model. + init_contexts = [no_init_weights(_enable=_fast_init)] + + with ContextManagers(init_contexts): + model = cls(config, *model_args, **model_kwargs) + + if from_pt: + # restore default dtype + if dtype_orig is not None: + torch.set_default_dtype(dtype_orig) + + # make sure token embedding weights are still tied if needed + model.tie_weights() + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + + # If it is a model with generation capabilities, attempt to load the generation config + if model.can_generate(): + try: + model.generation_config = GenerationConfig.from_pretrained( + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + _from_auto=from_auto_class, + _from_pipeline=from_pipeline, + **kwargs, + ) + except (OSError, TypeError): + logger.info("Generation config file not found, using a generation config created from the model config.") + + # set pretrained path + if resolved_archive_file: + pretrained_interface.set_pretrained_path(model, resolved_archive_file) + + return model diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index d66dec113017..634e81bb225d 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -3,11 +3,13 @@ import pytest import torch import torch.distributed as dist +from transformers import LlamaForCausalLM from utils import shared_tempdir import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin +from colossalai.lazy import LazyInitContext from colossalai.nn.optimizer import HybridAdam from colossalai.testing import ( check_state_dict_equal, @@ -120,11 +122,29 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard) +def exam_lazy_from_pretrained(): + llama_path = os.environ["LLAMA_PATH"] + plugin = GeminiPlugin() + booster = Booster(plugin=plugin) + orig_model = LlamaForCausalLM.from_pretrained(llama_path) + orig_state_dict = {k: v.half() for k, v in orig_model.state_dict().items()} + with LazyInitContext(): + model = 
LlamaForCausalLM.from_pretrained(llama_path) + model, *_ = booster.boost(model) + with shared_tempdir() as tempdir: + save_path = os.path.join(tempdir, "model.pt") + booster.save_model(model, save_path, shard=False) + dist.barrier() + state_dict = torch.load(save_path, map_location="cpu") + check_state_dict_equal(state_dict, orig_state_dict, False) + + def run_dist(rank, world_size, port): config = {} colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_state_dict() exam_state_dict_with_origin() + exam_lazy_from_pretrained() @pytest.mark.dist diff --git a/tests/test_lazy/test_from_pretrained.py b/tests/test_lazy/test_from_pretrained.py new file mode 100644 index 000000000000..623dd82c5ad9 --- /dev/null +++ b/tests/test_lazy/test_from_pretrained.py @@ -0,0 +1,31 @@ +import os + +from transformers import BertForPreTraining, LlamaForCausalLM + +import colossalai.interface.pretrained as pretrained_utils +from colossalai.lazy import LazyInitContext + + +def test_lazy_from_pretrained(): + # test from cached file, unsharded + model = BertForPreTraining.from_pretrained("prajjwal1/bert-tiny") + with LazyInitContext(): + deffered_model = BertForPreTraining.from_pretrained("prajjwal1/bert-tiny") + pretrained_path = pretrained_utils.get_pretrained_path(deffered_model) + assert os.path.isfile(pretrained_path) + for p, lazy_p in zip(model.parameters(), deffered_model.parameters()): + assert p.shape == lazy_p.shape + + # test from local file, sharded + llama_path = os.environ["LLAMA_PATH"] + model = LlamaForCausalLM.from_pretrained(llama_path) + with LazyInitContext(): + deffered_model = LlamaForCausalLM.from_pretrained(llama_path) + pretrained_path = pretrained_utils.get_pretrained_path(deffered_model) + assert os.path.isfile(pretrained_path) + for p, lazy_p in zip(model.parameters(), deffered_model.parameters()): + assert p.shape == lazy_p.shape + + +if __name__ == "__main__": + test_lazy_from_pretrained() From 8cbce6184d831a6d58761ad4a46e6c28137b8047 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Tue, 26 Sep 2023 11:36:53 +0800 Subject: [PATCH 48/58] update --- applications/Colossal-LLaMA-2/README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/applications/Colossal-LLaMA-2/README.md b/applications/Colossal-LLaMA-2/README.md index 3470e8494ca0..71d1c7bcdb02 100644 --- a/applications/Colossal-LLaMA-2/README.md +++ b/applications/Colossal-LLaMA-2/README.md @@ -32,6 +32,10 @@ The [Colossal-AI](https://github.com/hpcaitech/ColossalAI) team has introduced t Colossal-LLaMA-2-7B-base is designed to accommodate both the Chinese and English languages, featuring an expansive context window spanning 4096 tokens. Remarkably, it has exhibited exceptional performance when benchmarked against models of equivalent scale in standard Chinese and English evaluation metrics, including C-Eval and MMLU, among others. +❗️**Important notice**: +* All training data used for this project is collected from well-known public dataset. +* We do not use any testing data from the evaluation benchmarks for training. + ### Performance Evaluation We conducted comprehensive evaluation on 4 dataset and compare our Colossal-Llama-2-7b-base model with various models. From b6cf0aca5520d02fafb1c319c3b8765012d2d48d Mon Sep 17 00:00:00 2001 From: Chandler-Bing Date: Tue, 26 Sep 2023 11:44:27 +0800 Subject: [PATCH 49/58] [hotfix] change llama2 Colossal-LLaMA-2 script filename (#4800) change filename: pretraining.py -> trainin.py there is no file named pretraing.py. 
wrong writing --- applications/Colossal-LLaMA-2/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/Colossal-LLaMA-2/README.md b/applications/Colossal-LLaMA-2/README.md index 71d1c7bcdb02..95253eea7ae7 100644 --- a/applications/Colossal-LLaMA-2/README.md +++ b/applications/Colossal-LLaMA-2/README.md @@ -225,7 +225,7 @@ Here is details about CLI arguments: You can use `colossalai run` to launch multi-nodes training: ```bash colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \ -pretrain.py --OTHER_CONFIGURATIONS +train.py --OTHER_CONFIGURATIONS ``` Here is a sample hostfile: ```bash From a22706337a57dd1c98b95739dd09d98bd55947a0 Mon Sep 17 00:00:00 2001 From: Yan haixu <40758050+hova88@users.noreply.github.com> Date: Tue, 26 Sep 2023 14:43:46 +0800 Subject: [PATCH 50/58] [misc] add last_epoch in CosineAnnealingWarmupLR (#4778) --- colossalai/nn/lr_scheduler/cosine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/nn/lr_scheduler/cosine.py b/colossalai/nn/lr_scheduler/cosine.py index a896d3acba6c..f563825de0d5 100644 --- a/colossalai/nn/lr_scheduler/cosine.py +++ b/colossalai/nn/lr_scheduler/cosine.py @@ -62,7 +62,7 @@ def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: base_scheduler = _CosineAnnealingLR( optimizer, total_steps - warmup_steps, eta_min=eta_min, last_epoch=last_epoch ) - super().__init__(optimizer, warmup_steps, base_scheduler) + super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch) class FlatAnnealingLR(DelayerScheduler): From da15fdb9caa05904eb27844e69af4b76af2aca46 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Wed, 27 Sep 2023 10:24:04 +0800 Subject: [PATCH 51/58] [doc] add lazy init docs (#4808) --- colossalai/lazy/lazy_init.py | 29 ++------- docs/sidebars.json | 1 + docs/source/en/basics/booster_api.md | 2 + docs/source/en/features/lazy_init.md | 76 +++++++++++++++++++++++ docs/source/zh-Hans/basics/booster_api.md | 2 + docs/source/zh-Hans/features/lazy_init.md | 76 +++++++++++++++++++++++ 6 files changed, 162 insertions(+), 24 deletions(-) create mode 100644 docs/source/en/features/lazy_init.md create mode 100644 docs/source/zh-Hans/features/lazy_init.md diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py index a03334b28245..b130111ba3d9 100644 --- a/colossalai/lazy/lazy_init.py +++ b/colossalai/lazy/lazy_init.py @@ -472,30 +472,11 @@ def __rpow__(self, other): class LazyInitContext: """Context manager for lazy initialization. Enables initializing the model without allocating real memory. - Usage: - 1. The model is initialized, but no real memory is allocated. - >>> ctx = LazyInitContext() - >>> with ctx: - >>> model = MyModel().cuda() - - 2. The model is initialized with ``MetaTensor`` as weights, but still no real memory is allocated. - >>> with ctx.traceable(model): - >>> gm = symbolic_trace(model, meta_args=meta_args) - >>> # Solve the execution strategy and apply the strategy to the model - >>> strategy = StrategyAndSpec() - - 3. The model is initialized with ``torch.Tensor`` as weights, and real memory is allocated. (single device) - >>> model = ctx.materialize(model) - - 3. The model is initialized with sharded ``torch.Tensor`` as weights, and real memory is allocated. (distributed scenario) - >>> model = apply_strategy_to_all_params(model, strategy) - >>> model = ctx.distribute(model) - - Warnings: - This API is still experimental and further modifications can be made to it. - For example: - 1. 
Quantization strategies can be applied before allocating real memory. - 2. Lazy initialization seems slower than normal initialization. + Args: + tensor_cls (Union[_MyTensor, LazyTensor], optional): This is only for test. Defaults to LazyTensor. + default_device (Optional[Union[torch.device, str, int]], optional): Defalt device for initialization. + If it's cuda, initilization will be accelerated, but cuda memory will be allocated. By default, it's cpu. + Defaults to None. """ _replaced: bool = False diff --git a/docs/sidebars.json b/docs/sidebars.json index ce197a31e71b..45e86afc1f61 100644 --- a/docs/sidebars.json +++ b/docs/sidebars.json @@ -55,6 +55,7 @@ }, "features/pipeline_parallel", "features/nvme_offload", + "features/lazy_init", "features/cluster_utils" ] }, diff --git a/docs/source/en/basics/booster_api.md b/docs/source/en/basics/booster_api.md index 392251ef06b2..4d7ffe5a4cbf 100644 --- a/docs/source/en/basics/booster_api.md +++ b/docs/source/en/basics/booster_api.md @@ -32,6 +32,8 @@ Plugin is an important component that manages parallel configuration (eg: The ge More details about usages of each plugin can be found in chapter [Booster Plugins](./booster_plugins.md). +Some plugins support lazy initialization, which can be used to save memory when initializating large models. For more details, please see [Lazy Initialization](../features/lazy_init.md). + ### API of booster {{ autodoc:colossalai.booster.Booster }} diff --git a/docs/source/en/features/lazy_init.md b/docs/source/en/features/lazy_init.md new file mode 100644 index 000000000000..133fd799280a --- /dev/null +++ b/docs/source/en/features/lazy_init.md @@ -0,0 +1,76 @@ +# Lazy initialization + +Author: [Hongxiu Liu](https://github.com/ver217) + +**Prerequisite:** +- [Train with booster](../basics/booster_api.md) + +## Introduction + +Lazy initialization defers model initialization. It saves memory when initializing large models. + +If your model has `N` billion parameters and your memory (or GPU memory) is `M` GB, we recommend you use lazy initialization when `4N >= M`. Otherwise, it is optional. + +## Usage + +Lazy initialization must be used with booster. + +### API reference + +{{ autodoc:colossalai.lazy.LazyInitContext }} + +### Example + +```python +import colossalai +from colossalai.lazy import LazyInitContext +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin + +from transformers import LlamaForCausalLM, LlamaConfig, BertForPreTraining + +colossalai.launch({}) +plugin = GeminiPlugin() +booster = Booster(plugin) + +# 1. Initialize model from scratch +# Initialization on cuda will accelerate the initialization process but take more GPU memory. +with LazyInitContext(default_device="cuda"): + model = LlamaForCausalLM(LlamaConfig(hidden_size=64, intermediate_size=172, num_hidden_layers=4, num_attention_heads=4)) +model, *_ = booster.boost(model) + +# 2. Initialize model from pretrained +with LazyInitContext(): + model = BertForPreTraining.from_pretrained("prajjwal1/bert-tiny") +model, *_ = booster.boost(model) +``` + +> ⚠️ Lazy initialization from pretrained is supported for colossalai>0.3.3 or main branch. + +## Limitations + +As we claimed, lazy initialization must be used with booster. And only several plugins support it. 
+ +| Plugin | Supported | Remarks | +|-----------------|-----------|--------------| +| Gemini | Yes | | +| Hybrid Parallel | Yes | | +| Low Level Zero | No | No need | +| Torch DDP | No | Incompatible | +| Torch FSDP | No | Incompatible | + +Not all models can be lazily initialized. In some cases, a part of parameters/buffers may be early initialized. But don't worry, this part usually takes a small proportion of the whole model. + +And some models are not supported at all which will raise an error. We tested models in torchvision, diffusers, timm, transformers, torchaudio and torchrec. Below models are not supported: + +| Model | Category | +|-------------------------------|--------------| +| wav2vec2_base | torchaudio | +| hubert_base | torchaudio | +| ViTModel | transformers | +| ViTForMaskedImageModeling | transformers | +| ViTForImageClassification | transformers | +| Blip2Model | transformers | +| Blip2ForConditionalGeneration | transformers | + + diff --git a/docs/source/zh-Hans/basics/booster_api.md b/docs/source/zh-Hans/basics/booster_api.md index c59d75d321c0..f9310374d823 100644 --- a/docs/source/zh-Hans/basics/booster_api.md +++ b/docs/source/zh-Hans/basics/booster_api.md @@ -35,6 +35,8 @@ Booster 插件是管理并行配置的重要组件(eg:gemini 插件封装了 若想了解更多关于插件的用法细节,请参考[Booster 插件](./booster_plugins.md)章节。 +有一些插件支持懒惰初始化,它能节省初始化大模型时的内存占用。详情请参考[懒惰初始化](../features/lazy_init.md)。 + ### Booster 接口 diff --git a/docs/source/zh-Hans/features/lazy_init.md b/docs/source/zh-Hans/features/lazy_init.md new file mode 100644 index 000000000000..80742a56df29 --- /dev/null +++ b/docs/source/zh-Hans/features/lazy_init.md @@ -0,0 +1,76 @@ +# 懒惰初始化 + +作者: [Hongxiu Liu](https://github.com/ver217) + +**前置教程:** +- [Train with booster](../basics/booster_api.md) + +## 简介 + +懒惰初始化延迟了模型的初始化。它能够节省在大模型初始化时的内存占用。 + +如果你的模型有 `N` 十亿个参数并且你的内存(或显存)为 `M` GB, 我们推荐您在 `4N >= M` 时使用懒惰初始化。否则,懒惰初始化不是必须的。 + +## 使用 + +懒惰初始化必须与 booster 一起使用。 + +### API 参考 + +{{ autodoc:colossalai.lazy.LazyInitContext }} + +### 例子 + +```python +import colossalai +from colossalai.lazy import LazyInitContext +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin + +from transformers import LlamaForCausalLM, LlamaConfig, BertForPreTraining + +colossalai.launch({}) +plugin = GeminiPlugin() +booster = Booster(plugin) + +# 1. Initialize model from scratch +# Initialization on cuda will accelerate the initialization process but take more GPU memory. +with LazyInitContext(default_device="cuda"): + model = LlamaForCausalLM(LlamaConfig(hidden_size=64, intermediate_size=172, num_hidden_layers=4, num_attention_heads=4)) +model, *_ = booster.boost(model) + +# 2. 
Initialize model from pretrained +with LazyInitContext(): + model = BertForPreTraining.from_pretrained("prajjwal1/bert-tiny") +model, *_ = booster.boost(model) +``` + +> ⚠️ 使用懒惰初始化加载预训练模型在 colossalai>0.3.3 或主分支上支持。 + +## 限制 + +我们提到,懒惰初始化必须与 booster 一起使用。只有几个插件支持它。 + +| 插件 | 支持情况 | 备注 | +|-----------------|---------|--------| +| Gemini | 是 | | +| Hybrid Parallel | 是 | | +| Low Level Zero | 否 | 不需要 | +| Torch DDP | 否 | 不兼容 | +| Torch FSDP | 否 | 不兼容 | + +不是所有的模型都可以懒惰初始化。在某些情况下,一部分参数/缓冲区可能会被提前初始化。但是不用担心,这部分通常只占整个模型的一小部分。 + +并且一些模型完全不支持,会引发错误。我们测试了 torchvision, diffusers, timm, transformers, torchaudio 和 torchrec 中的模型。以下模型不受支持: + +| 模型 | 分类 | +|-------------------------------|--------------| +| wav2vec2_base | torchaudio | +| hubert_base | torchaudio | +| ViTModel | transformers | +| ViTForMaskedImageModeling | transformers | +| ViTForImageClassification | transformers | +| Blip2Model | transformers | +| Blip2ForConditionalGeneration | transformers | + + From 54b3ad89249a23221dda036ed598dade06808011 Mon Sep 17 00:00:00 2001 From: littsk <1214689160@qq.com> Date: Wed, 27 Sep 2023 10:35:24 +0800 Subject: [PATCH 52/58] [hotfix] fix norm type error in zero optimizer (#4795) --- colossalai/zero/low_level/_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/zero/low_level/_utils.py b/colossalai/zero/low_level/_utils.py index ba1135940df0..0a15f8ddd718 100644 --- a/colossalai/zero/low_level/_utils.py +++ b/colossalai/zero/low_level/_utils.py @@ -221,8 +221,8 @@ def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGro else: total_norm = 0.0 for g in gradients: - param_norm = g.data.double().norm(2) - total_norm += param_norm.item() ** 2 + param_norm = g.data.double().norm(norm_type) + total_norm += param_norm.item() ** norm_type # Sum across all model parallel GPUs. total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) From 11f1e426fe2b549f8745d5036d4e20bc3dc411ed Mon Sep 17 00:00:00 2001 From: littsk <1214689160@qq.com> Date: Wed, 27 Sep 2023 10:43:03 +0800 Subject: [PATCH 53/58] [hotfix] Correct several erroneous code comments (#4794) --- colossalai/shardformer/policies/base_policy.py | 2 +- colossalai/zero/low_level/bookkeeping/bucket_store.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index e7f199129a00..eb03500531bc 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -50,7 +50,7 @@ def example_replace_weight(module: torch.nn.Module): new_weight = shard_rowwise(weight, process_group) module.weight = torch.nn.Parameter(new_weight) ``` - sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a ParamReplacementDescription + sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a SubModuleReplacementDescription object which specifies the module to be replaced and the target module used to replacement. 
method_replace (Dict[str, Callable]): key is the method name, value is the method for replacement """ diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py index 2a75d704711a..2828d517573d 100644 --- a/colossalai/zero/low_level/bookkeeping/bucket_store.py +++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py @@ -92,7 +92,7 @@ def get_grad(self) -> Dict: def get_flatten_grad(self) -> Tensor: """Return the flattened gradients slices in the bucket, the data orginization of the flattened tensor: - [grad0_rank0, grad1_rank0, ..., grad_1_rank0, grad1_rank1, ....] + [grad0_rank0, grad1_rank0, ..., grad_0_rank1, grad1_rank1, ....] Returns: Tensor: the flattened gradients slices in the bucket From fb46d05cdfd75d881d18675a382a19c4fbca6bd0 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 27 Sep 2023 10:45:03 +0800 Subject: [PATCH 54/58] [format] applied code formatting on changed files in pull request 4595 (#4602) Co-authored-by: github-actions From bbbcac26e80601790728b1e8b8a7595d4d89a7b4 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Wed, 27 Sep 2023 12:50:22 +0800 Subject: [PATCH 55/58] fix format (#4815) --- op_builder/gptq.py | 48 +++++++++++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/op_builder/gptq.py b/op_builder/gptq.py index 012cf0f8a78d..bc4f445de067 100644 --- a/op_builder/gptq.py +++ b/op_builder/gptq.py @@ -1,19 +1,17 @@ -import os -import torch import re +import torch + from .builder import Builder -from .utils import append_nvcc_threads, get_cuda_cc_flag +from .utils import append_nvcc_threads -class GPTQBuilder(Builder): +class GPTQBuilder(Builder): NAME = "cu_gptq" PREBUILT_IMPORT_PATH = "colossalai._C.cu_gptq" def __init__(self): - super().__init__(name=GPTQBuilder.NAME, - prebuilt_import_path=GPTQBuilder.PREBUILT_IMPORT_PATH) - + super().__init__(name=GPTQBuilder.NAME, prebuilt_import_path=GPTQBuilder.PREBUILT_IMPORT_PATH) def include_dirs(self): ret = [self.csrc_abs_path("gptq"), self.get_cuda_home_include()] @@ -21,32 +19,38 @@ def include_dirs(self): def sources_files(self): ret = [ - self.csrc_abs_path(fname) for fname in [ - 'gptq/linear_gptq.cpp', - 'gptq/column_remap.cu', - 'gptq/cuda_buffers.cu', - 'gptq/q4_matmul.cu', - 'gptq/q4_matrix.cu' + self.csrc_abs_path(fname) + for fname in [ + "gptq/linear_gptq.cpp", + "gptq/column_remap.cu", + "gptq/cuda_buffers.cu", + "gptq/q4_matmul.cu", + "gptq/q4_matrix.cu", ] ] return ret def cxx_flags(self): - return ['-O3'] + self.version_dependent_macros + return ["-O3"] + self.version_dependent_macros def nvcc_flags(self): - extra_cuda_flags = ['-v', - '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', - '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK', "-lcublas", "-std=c++17" + extra_cuda_flags = [ + "-v", + "-std=c++14", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-DTHRUST_IGNORE_CUB_VERSION_CHECK", + "-lcublas", + "-std=c++17", ] - for arch in torch.cuda.get_arch_list(): - res = re.search(r'sm_(\d+)', arch) + res = re.search(r"sm_(\d+)", arch) if res: arch_cap = res[1] if int(arch_cap) >= 80: - extra_cuda_flags.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}']) + extra_cuda_flags.extend(["-gencode", f"arch=compute_{arch_cap},code={arch}"]) - ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + 
extra_cuda_flags - return append_nvcc_threads(ret) \ No newline at end of file + ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags + return append_nvcc_threads(ret) From be400a09364942a02e0048d940f3699380cd6f1f Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 27 Sep 2023 13:15:32 +0800 Subject: [PATCH 56/58] [chat] fix gemini strategy (#4698) * [chat] fix gemini strategy * [chat] fix gemini strategy * [chat] fix gemini strategy * [chat] fix gemini strategy * g# This is a combination of 2 commits. [chat] fix gemini strategy fox * [chat] fix gemini strategy update llama2 example [chat] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * fix * fix * fix * fix * fix * Update train_prompts.py --- .../benchmarks/benchmark_opt_lora_dummy.py | 4 ++-- applications/Chat/coati/models/base/actor.py | 1 + applications/Chat/coati/ray/utils.py | 4 ++-- .../Chat/coati/trainer/strategies/base.py | 4 ++-- .../coati/trainer/strategies/colossalai.py | 9 +++++++-- .../Chat/coati/trainer/strategies/ddp.py | 10 +++++----- .../community/peft/train_peft_prompts.py | 2 +- .../examples/community/peft/train_peft_sft.py | 2 +- applications/Chat/examples/requirements.txt | 2 +- applications/Chat/examples/train_prompts.py | 8 ++++++-- .../Chat/examples/train_reward_model.py | 8 +++++++- applications/Chat/examples/train_sft.py | 7 ++++--- applications/Chat/requirements-test.txt | 2 +- applications/Chat/requirements.txt | 2 +- applications/Chat/tests/test_checkpoint.py | 4 ++-- applications/Chat/tests/test_train.sh | 20 ++++++------------- 16 files changed, 49 insertions(+), 40 deletions(-) diff --git a/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py b/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py index bee5c8d3faf3..0d0e2a7d34f5 100644 --- a/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py +++ b/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py @@ -76,9 +76,9 @@ def main(args): if args.strategy == "ddp": strategy = DDPStrategy() elif args.strategy == "colossalai_gemini": - strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5) + strategy = GeminiStrategy(placement_policy="static",initial_scale=2**5) elif args.strategy == "colossalai_gemini_cpu": - strategy = GeminiStrategy(placement_policy="cpu", initial_scale=2**5) + strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5) elif args.strategy == "colossalai_zero2": strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") elif args.strategy == "colossalai_zero2_cpu": diff --git a/applications/Chat/coati/models/base/actor.py b/applications/Chat/coati/models/base/actor.py index 0634631df7a3..8b2b81ed071c 100644 --- a/applications/Chat/coati/models/base/actor.py +++ b/applications/Chat/coati/models/base/actor.py @@ -30,3 +30,4 @@ def forward( """Returns model output.""" output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs) return output + diff --git a/applications/Chat/coati/ray/utils.py b/applications/Chat/coati/ray/utils.py index 799b2af8f982..b88140c0e036 100644 --- a/applications/Chat/coati/ray/utils.py +++ b/applications/Chat/coati/ray/utils.py @@ -71,11 +71,11 @@ def get_strategy_from_args(strategy: str): if strategy == 
"ddp": strategy_ = DDPStrategy() elif strategy == "colossalai_gemini": - strategy_ = GeminiStrategy(placement_policy="cuda", initial_scale=2**5) + strategy_ = GeminiStrategy(placement_policy="static", initial_scale=2**5) elif strategy == "colossalai_zero2": strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cuda") elif strategy == "colossalai_gemini_cpu": - strategy_ = GeminiStrategy(placement_policy="cpu", initial_scale=2**5) + strategy_ = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5) elif strategy == "colossalai_zero2_cpu": strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cpu") else: diff --git a/applications/Chat/coati/trainer/strategies/base.py b/applications/Chat/coati/trainer/strategies/base.py index 303d4bc220a6..a78716216ae0 100644 --- a/applications/Chat/coati/trainer/strategies/base.py +++ b/applications/Chat/coati/trainer/strategies/base.py @@ -110,8 +110,8 @@ def unwrap_model(model: nn.Module) -> nn.Module: """ return model - def save_model(self, model: nn.Module, path: str, only_rank0: bool = True, **kwargs) -> None: - self.booster.save_model(model, path, shard=not only_rank0, **kwargs) + def save_model(self, model: nn.Module, path: str, shard: bool = False, **kwargs) -> None: + self.booster.save_model(model, path, shard=shard, **kwargs) def load_model(self, model: nn.Module, path: str, strict: bool = True) -> None: self.booster.load_model(model, path, strict) diff --git a/applications/Chat/coati/trainer/strategies/colossalai.py b/applications/Chat/coati/trainer/strategies/colossalai.py index 3018ca43061e..7129edb060ef 100644 --- a/applications/Chat/coati/trainer/strategies/colossalai.py +++ b/applications/Chat/coati/trainer/strategies/colossalai.py @@ -6,7 +6,6 @@ import colossalai from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel -from colossalai.lazy.lazy_init import LazyInitContext from colossalai.utils import get_current_device from colossalai.zero.gemini.gemini_ddp import GeminiDDP @@ -130,6 +129,9 @@ def __init__( seed: int = 42, shard_init: bool = False, # only for stage 3 placement_policy: str = "auto", + shard_param_frac: float = 1.0, # only for static placement + offload_optim_frac: float = 0.0, # only for static placement + offload_param_frac: float = 0.0, # only for static placement pin_memory: bool = True, # only for stage 3 force_outputs_fp32: bool = False, # only for stage 3 search_range_m: int = 32, # only for stage 3 @@ -160,6 +162,9 @@ def __init__( plugin_initializer = lambda: GeminiPlugin( chunk_init_device=get_current_device(), placement_policy=placement_policy, + shard_param_frac=shard_param_frac, + offload_optim_frac=offload_optim_frac, + offload_param_frac=offload_param_frac, precision="fp16", pin_memory=pin_memory, force_outputs_fp32=force_outputs_fp32, @@ -188,7 +193,7 @@ def setup_distributed(self) -> None: colossalai.launch_from_torch({}, seed=self.seed) def model_init_context(self): - return LazyInitContext(default_device=get_current_device()) + return super().model_init_context() def unwrap_model(self, model: nn.Module) -> nn.Module: assert isinstance(model, GeminiDDP) diff --git a/applications/Chat/coati/trainer/strategies/ddp.py b/applications/Chat/coati/trainer/strategies/ddp.py index 66ff6703da4d..f2a44aeb0961 100644 --- a/applications/Chat/coati/trainer/strategies/ddp.py +++ b/applications/Chat/coati/trainer/strategies/ddp.py @@ -87,9 +87,9 @@ def 
unwrap_model(self, model: nn.Module) -> nn.Module: return model.unwrap() def save_pretrained( - self, model: nn.Module, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None + self, model: nn.Module, path: str, shard: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None ) -> None: - if not only_rank0 or dist.get_rank() == 0: + if dist.get_rank() == 0: unwrapped_model = self.unwrap_model(model) assert isinstance(unwrapped_model, (Actor, Critic, RewardModel)) pretrained_model = unwrapped_model.model @@ -98,19 +98,19 @@ def save_pretrained( pretrained_model.save_pretrained(path, save_function=lambda *args, **kwargs: None) if tokenizer is not None: tokenizer.save_pretrained(path) - model_path = os.path.join(path, "pytorch_model.bin") - self.save_model(model, model_path, only_rank0=only_rank0) + model_path = os.path.join(path, "pytorch_model.bin") + self.save_model(model, model_path, shard=shard) def _replace_keys(model_path: str, replace_fn: Callable): state_dict = torch.load(model_path, map_location="cpu") state_dict = {replace_fn(k): v for k, v in state_dict.items()} torch.save(state_dict, model_path) - # FIXME: save_model would add "model." prefix to keys of pytorch_model.bin # HACK: rename keys of pytorch_model.bin if dist.get_rank() == 0: _replace_keys(model_path, lambda k: k.replace("model.", "", 1)) + def get_model_state_dict_shard(self, model: nn.Module, **config): # TODO: implement sharding on naive strategy model = self.unwrap_model(model) diff --git a/applications/Chat/examples/community/peft/train_peft_prompts.py b/applications/Chat/examples/community/peft/train_peft_prompts.py index e49db1d2bc1b..99a024f1463c 100644 --- a/applications/Chat/examples/community/peft/train_peft_prompts.py +++ b/applications/Chat/examples/community/peft/train_peft_prompts.py @@ -24,7 +24,7 @@ def main(args): if args.strategy == "ddp": strategy = DDPStrategy() elif args.strategy == "colossalai_gemini": - strategy = GeminiStrategy(placement_policy="cpu", initial_scale=2**5) + strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5) elif args.strategy == "colossalai_zero2": strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu") else: diff --git a/applications/Chat/examples/community/peft/train_peft_sft.py b/applications/Chat/examples/community/peft/train_peft_sft.py index 0b62dd652adb..3bbef7208374 100644 --- a/applications/Chat/examples/community/peft/train_peft_sft.py +++ b/applications/Chat/examples/community/peft/train_peft_sft.py @@ -24,7 +24,7 @@ def train(args): if args.strategy == "ddp": strategy = DDPStrategy() elif args.strategy == "colossalai_gemini": - strategy = GeminiStrategy(placement_policy="cuda") + strategy = GeminiStrategy(placement_policy="static") elif args.strategy == "colossalai_zero2": strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") else: diff --git a/applications/Chat/examples/requirements.txt b/applications/Chat/examples/requirements.txt index a7cfb5da7fe1..5474dfa16b3e 100644 --- a/applications/Chat/examples/requirements.txt +++ b/applications/Chat/examples/requirements.txt @@ -1,3 +1,3 @@ pandas>=1.4.1 sentencepiece -colossalai>=0.3.1 +colossalai==0.3.3 diff --git a/applications/Chat/examples/train_prompts.py b/applications/Chat/examples/train_prompts.py index a8ab15eebfa5..8868e278d85e 100644 --- a/applications/Chat/examples/train_prompts.py +++ b/applications/Chat/examples/train_prompts.py @@ -23,7 +23,7 @@ def main(args): if args.strategy 
== "ddp": strategy = DDPStrategy() elif args.strategy == "colossalai_gemini": - strategy = GeminiStrategy(placement_policy="auto", initial_scale=2**5) + strategy = GeminiStrategy(placement_policy="static", initial_scale=2**5) elif args.strategy == "colossalai_zero2": strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") else: @@ -33,6 +33,10 @@ def main(args): warnings.warn("LoRA weights should be merged with the model weights") state_dict = torch.load(args.rm_path, map_location="cpu") + if args.lora_rank > 0: + warnings.warn("Lora is not supported yet.") + args.lora_rank = 0 + with strategy.model_init_context(): # configure model if args.model == "gpt2": @@ -199,7 +203,7 @@ def main(args): LORA_MANAGER.merge_weights = True actor.eval() # save model checkpoint after fitting - strategy.save_model(actor, args.save_path, only_rank0=True) + strategy.save_pretrained(actor, path=args.save_path) # save optimizer checkpoint on all ranks if args.need_optim_ckpt: strategy.save_optimizer( diff --git a/applications/Chat/examples/train_reward_model.py b/applications/Chat/examples/train_reward_model.py index c1be51f2f587..df6e8b6bdc26 100644 --- a/applications/Chat/examples/train_reward_model.py +++ b/applications/Chat/examples/train_reward_model.py @@ -1,4 +1,5 @@ import argparse +import warnings import torch import torch.distributed as dist @@ -33,6 +34,10 @@ def train(args): raise ValueError(f'Unsupported strategy "{args.strategy}"') # configure model + if args.lora_rank > 0: + warnings.warn("Lora is not supported yet.") + args.lora_rank = 0 + with strategy.model_init_context(): if args.model == "bloom": model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank) @@ -165,7 +170,8 @@ def train(args): LORA_MANAGER.merge_weights = True model.eval() # save model checkpoint after fitting on only rank0 - strategy.save_model(model, args.save_path, only_rank0=True) + state_dict = model.state_dict() + torch.save(state_dict, args.save_path) # save optimizer checkpoint on all ranks if args.need_optim_ckpt: strategy.save_optimizer( diff --git a/applications/Chat/examples/train_sft.py b/applications/Chat/examples/train_sft.py index 4f36791be3cf..66d08da30120 100644 --- a/applications/Chat/examples/train_sft.py +++ b/applications/Chat/examples/train_sft.py @@ -40,8 +40,9 @@ def train(args): # configure model if args.lora_rank > 0: - warnings.warn("Gradient checkpoint is disabled when using LoRA") - args.grad_checkpoint = False + warnings.warn("Lora is not supported yet.") + args.lora_rank = 0 + with strategy.model_init_context(): if args.model == "bloom": model = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint) @@ -184,7 +185,7 @@ def train(args): LORA_MANAGER.merge_weights = True model.eval() # save model checkpoint after fitting on only rank0 - strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer) + strategy.save_pretrained(model, path=args.save_path, tokenizer=tokenizer) # save optimizer checkpoint on all ranks if args.need_optim_ckpt: strategy.save_optimizer( diff --git a/applications/Chat/requirements-test.txt b/applications/Chat/requirements-test.txt index adf2cc1bf545..93d48bcb6f79 100644 --- a/applications/Chat/requirements-test.txt +++ b/applications/Chat/requirements-test.txt @@ -1,2 +1,2 @@ pytest -colossalai>=0.3.1 +colossalai==0.3.3 diff --git a/applications/Chat/requirements.txt b/applications/Chat/requirements.txt index a784ccbe0d3a..e56aaca0e7cb 100644 --- a/applications/Chat/requirements.txt +++ 
b/applications/Chat/requirements.txt @@ -2,7 +2,7 @@ transformers>=4.20.1 tqdm datasets loralib -colossalai>=0.3.1 +colossalai==0.3.3 torch<2.0.0, >=1.12.1 langchain tokenizers diff --git a/applications/Chat/tests/test_checkpoint.py b/applications/Chat/tests/test_checkpoint.py index 9dfaa7c88206..9c08aa36c9b4 100644 --- a/applications/Chat/tests/test_checkpoint.py +++ b/applications/Chat/tests/test_checkpoint.py @@ -57,9 +57,9 @@ def run_test_checkpoint(strategy_name: str, shard: bool): rank0_dirname = rank0_dirname[0] model_path = os.path.join(rank0_dirname, "model" if shard else f"model.pt") - strategy.save_model(actor, model_path, only_rank0=not shard) + strategy.save_model(actor, model_path) optim_path = os.path.join(rank0_dirname, "optim" if shard else "optim.pt") - strategy.save_optimizer(actor_optim, optim_path, only_rank0=not shard) + strategy.save_optimizer(actor_optim, optim_path) dist.barrier() strategy.load_model(actor, model_path, strict=False) diff --git a/applications/Chat/tests/test_train.sh b/applications/Chat/tests/test_train.sh index 55de269005ed..68fca7fbf8c0 100755 --- a/applications/Chat/tests/test_train.sh +++ b/applications/Chat/tests/test_train.sh @@ -41,6 +41,7 @@ MODELS_DIR=$BASE_DIR/examples/models_config MODELS=('gpt2' 'bloom' 'opt' 'llama') STRATEGIES=('ddp' 'colossalai_gemini' 'colossalai_zero2') + export OMP_NUM_THREADS=8 # install requirements @@ -80,13 +81,10 @@ SKIPPED_TESTS=( "llama-ddp" "llama-colossalai_gemini" "llama-colossalai_zero2" - "gpt2-colossalai_gemini" - "opt-colossalai_gemini" - "bloom-colossalai_gemini" ) GRAD_CKPTS=('' '--grad_checkpoint') -for lora_rank in '0' '4'; do +for lora_rank in '0'; do for model in ${MODELS[@]}; do strategies=($(shuf -e "${STRATEGIES[@]}")) for strategy in ${strategies[@]}; do @@ -135,14 +133,11 @@ SKIPPED_TESTS=( "llama-ddp" "llama-colossalai_gemini" "llama-colossalai_zero2" - "gpt2-colossalai_gemini" - "opt-colossalai_gemini" - "bloom-colossalai_gemini" ) LOSS_FNS=('log_sig' 'log_exp') DATASETS=('Anthropic/hh-rlhf' 'Dahoas/rm-static') -for lora_rank in '0' '4'; do +for lora_rank in '0'; do for model in ${MODELS[@]}; do strategies=($(shuf -e "${STRATEGIES[@]}")) for strategy in ${strategies[@]}; do @@ -193,13 +188,10 @@ SKIPPED_TESTS=( "llama-ddp" "llama-colossalai_gemini" "llama-colossalai_zero2" - "gpt2-colossalai_gemini" - "opt-colossalai_gemini" - "bloom-colossalai_gemini" ) for model in ${MODELS[@]}; do - for lora_rank in '0' '4'; do + for lora_rank in '0'; do strategies=($(shuf -e "${STRATEGIES[@]}")) for strategy in ${strategies[@]}; do if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy-$lora_rank " ]]; then @@ -223,7 +215,7 @@ for model in ${MODELS[@]}; do --experience_batch_size 2 --train_batch_size 1 --lora_rank $lora_rank \ --pretrain $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank} \ $rm_pretrain_model --rm_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt \ - --save_path $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts.pt + --save_path $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts passed=$? 
if [ $passed -eq 0 ]; then break @@ -238,4 +230,4 @@ for model in ${MODELS[@]}; do rm $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt done done -rm $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts.pt +rm -rf $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts From 1fa8c5e09ff7422c30fe7683beb209bfba7e153b Mon Sep 17 00:00:00 2001 From: Yuanchen <70520919+chengeharrison@users.noreply.github.com> Date: Wed, 27 Sep 2023 17:33:54 +0800 Subject: [PATCH 57/58] Update Qwen-7B results (#4821) Co-authored-by: Xu Yuanchen --- applications/Colossal-LLaMA-2/README.md | 26 ++++++++++++------------- applications/ColossalEval/README.md | 22 +++++++++++++-------- 2 files changed, 26 insertions(+), 22 deletions(-) diff --git a/applications/Colossal-LLaMA-2/README.md b/applications/Colossal-LLaMA-2/README.md index 95253eea7ae7..34967c04360c 100644 --- a/applications/Colossal-LLaMA-2/README.md +++ b/applications/Colossal-LLaMA-2/README.md @@ -30,10 +30,10 @@ ## Colossal-LLaMA-2-7B The [Colossal-AI](https://github.com/hpcaitech/ColossalAI) team has introduced the open-source model **Colossal-LLaMA-2-7B-base**. This model, a derivation of LLaMA-2, has undergone continual pre-training involving approximately 8.5 billion tokens over a duration of 15 hours with 64 A800 GPUs. At a cost of **less than $1,000**, you can achieve results **similar to those that cost millions of dollars to pretrain from scratch**. It is licensed under the LLaMA-2 license and [Apache 2.0 License](https://github.com/hpcaitech/ColossalAI/blob/main/LICENSE) **without any additional commercial use restrictions**. This solution can also be used to build models of specific domain knowledge or tasks. -Colossal-LLaMA-2-7B-base is designed to accommodate both the Chinese and English languages, featuring an expansive context window spanning 4096 tokens. Remarkably, it has exhibited exceptional performance when benchmarked against models of equivalent scale in standard Chinese and English evaluation metrics, including C-Eval and MMLU, among others. +Colossal-LLaMA-2-7B-base is designed to accommodate both the Chinese and English languages, featuring an expansive context window spanning 4096 tokens. Remarkably, it has exhibited exceptional performance when benchmarked against models of equivalent scale in standard Chinese and English evaluation metrics, including C-Eval and MMLU, among others. -❗️**Important notice**: -* All training data used for this project is collected from well-known public dataset. +❗️**Important notice**: +* All training data used for this project is collected from well-known public dataset. * We do not use any testing data from the evaluation benchmarks for training. ### Performance Evaluation @@ -56,7 +56,7 @@ The generation config for all dataset is greedy search. | ChatGLM-6B | - | 1.0T | | 39.67 (40.63) | 41.17 (-) | 40.10 | 36.53 | 38.90 | | ChatGLM2-6B | - | 1.4T | | 44.74 (45.46) | 49.40 (-) | 46.36 | 45.49 | 51.70 | | InternLM-7B | - | 1.6T | | 46.70 (51.00) | 52.00 (-) | 44.77 | 61.64 | 52.80 | -| Qwen-7B | - | 2.2T | | 54.29 (56.70) | 56.03 (58.80) | 52.47 | 56.42 | 59.60 | +| Qwen-7B (original) | - | 2.2T | | 54.29 (56.70) | 56.03 (58.80) | 52.47 | 56.42 | 59.60 | | | | | | | | | | | | Llama-2-7B | - | 2.0T | | 44.47 (45.30) | 32.97 (-) | 32.60 | 25.46 | - | | Linly-AI/Chinese-LLaMA-2-7B-hf | Llama-2-7B | 1.0T | | 37.43 | 29.92 | 32.00 | 27.57 | - | @@ -91,7 +91,7 @@ The generation config for all dataset is greedy search. 
| Information Extraction | The company's annual conference, featuring keynote speakers and exclusive product launches, will be held at the Los Angeles Convention Center from October 20th to October 23rd, 2021. Extract the date mentioned in the above sentence. | The date mentioned in the sentence is October 20th to October 23rd, 2021. | | Error Correction | Please correct the syntax error and restructure the sentence: "Went to the store me and my friend." | The given sentence "Went to the store me and my friend." is incorrect due to a syntax error and it lacks a subject. The corrected version with proper grammar and sentence structure would be: "My friend and I went to the store." | -❗️ More examples of question answering, please refer to [Colossal-LLaMA-2-7B-base Examples](docs/example.md). +❗️ More examples of question answering, please refer to [Colossal-LLaMA-2-7B-base Examples](docs/example.md). ### Training Logs We also recorded the training logs for the experiment @@ -168,7 +168,7 @@ python colossal_llama2/tokenizer/init_tokenizer.py \ Here is details about CLI arguments: * Source tokenizer directory: `--source_tokenizer_dir`. Directory to the source tokenizer. It should at least contain three files: `special_tokens_map.json`, `tokenizer.model` and `tokenizer_config.json`. * Target tokenizer directory: `--target_tokenizer_dir`. Directory to the target tokenizer. -* Tokens to be added: `--expand_tokens_file`. Additional tokens to be added to the tokenizer. +* Tokens to be added: `--expand_tokens_file`. Additional tokens to be added to the tokenizer. #### 2. Init Model Preparation Initialize the new model checkpoint by calculating the mean values from the original model checkpoint. @@ -191,7 +191,7 @@ Here is details about CLI arguments: #### 3. Data Preparation Raw data should be formatted as `jsonl` format. Each data point should have the following fields: * `source` (str, compulsory): This part is ignored when calculating loss. Default can be empty. -* `target` (str, compulsory): Loss will be calculated. +* `target` (str, compulsory): Loss will be calculated. * `category` (str, compulsory): Tags for each data point. Examples: @@ -226,7 +226,7 @@ You can use `colossalai run` to launch multi-nodes training: ```bash colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \ train.py --OTHER_CONFIGURATIONS -``` +``` Here is a sample hostfile: ```bash hostname1 @@ -240,7 +240,7 @@ Here is details about CLI arguments: * Pre-trained model path: `--pretrained`. Path to the pre-trained model in Hugging Face format. * Dataset path: `--dataset`. Path to the pre-tokenized dataset. * Booster plugin: `--plugin`. `gemini`, `gemini_auto`, `zero2`,`zero2_cpu` and `3d` are supported.For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins/). -* Intermediate checkpoint to load: `--load_checkpoint`. Path to the intermediate checkpoint. Saved checkpoint contains the states for `lr_scheduler`, `optimizer`,`running_states.json` and `modelling`. If `load_checkpoint` points to the `modelling` folder, only the model weights will be loaded without any other states to support multi-stage training. +* Intermediate checkpoint to load: `--load_checkpoint`. Path to the intermediate checkpoint. Saved checkpoint contains the states for `lr_scheduler`, `optimizer`,`running_states.json` and `modelling`. If `load_checkpoint` points to the `modelling` folder, only the model weights will be loaded without any other states to support multi-stage training. 
* Save interval: `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000. * Checkpoint directory: `--save_dir`. The directoty path to save checkpoint and intermediate states. Intermediate states include `lr_scheduler`, `optimizer`,`running_states.json` and `modelling`. * Tensorboard directory: `--tensorboard_dir`. The path to save tensorboard logs. @@ -334,7 +334,7 @@ To balance both sides, we finally construct our vocabulary with size 69,104. The ### Training Strategy #### Multi-stage Training -In order to enhance the model's performance and harness the full potential of the original LLaMA-2, we have developed a multi-stage training strategy. This strategy is designed to systematically unlock the model's capabilities over a series of stages. +In order to enhance the model's performance and harness the full potential of the original LLaMA-2, we have developed a multi-stage training strategy. This strategy is designed to systematically unlock the model's capabilities over a series of stages. Therefore, we have divided the training process into three stages: * Large-scale pre-training stage (Conducted by LLaMA-2): This initial stage is aimed at establishing the model's foundational capabilities from the ground up. It necessitates the use of a substantial dataset comprising no less than 1 trillion tokens. @@ -343,7 +343,7 @@ Therefore, we have divided the training process into three stages: Following the completion of this multi-stage training process, the model exhibits notable improvements in performance across both English and Chinese benchmarks. -The following figure illustrates the three stages for training Colossal-LLaMA-2. +The following figure illustrates the three stages for training Colossal-LLaMA-2.

          @@ -372,7 +372,7 @@ Applying the above process to perform knowledge transfer in any field allows for ``` ```bibtex @misc{touvron2023llama, - title={Llama 2: Open Foundation and Fine-Tuned Chat Models}, + title={Llama 2: Open Foundation and Fine-Tuned Chat Models}, author={Hugo Touvron and Louis Martin and Kevin Stone and Peter Albert and Amjad Almahairi and Yasmine Babaei and Nikolay Bashlykov and Soumya Batra and Prajjwal Bhargava and Shruti Bhosale and Dan Bikel and Lukas Blecher and Cristian Canton Ferrer and Moya Chen and Guillem Cucurull and David Esiobu and Jude Fernandes and Jeremy Fu and Wenyin Fu and Brian Fuller and Cynthia Gao and Vedanuj Goswami and Naman Goyal and Anthony Hartshorn and Saghar Hosseini and Rui Hou and Hakan Inan and Marcin Kardas and Viktor Kerkez and Madian Khabsa and Isabel Kloumann and Artem Korenev and Punit Singh Koura and Marie-Anne Lachaux and Thibaut Lavril and Jenya Lee and Diana Liskovich and Yinghai Lu and Yuning Mao and Xavier Martinet and Todor Mihaylov and Pushkar Mishra and Igor Molybog and Yixin Nie and Andrew Poulton and Jeremy Reizenstein and Rashi Rungta and Kalyan Saladi and Alan Schelten and Ruan Silva and Eric Michael Smith and Ranjan Subramanian and Xiaoqing Ellen Tan and Binh Tang and Ross Taylor and Adina Williams and Jian Xiang Kuan and Puxin Xu and Zheng Yan and Iliyan Zarov and Yuchen Zhang and Angela Fan and Melanie Kambadur and Sharan Narang and Aurelien Rodriguez and Robert Stojnic and Sergey Edunov and Thomas Scialom}, year={2023}, eprint={2307.09288}, @@ -388,5 +388,3 @@ Applying the above process to perform knowledge transfer in any field allows for } } ``` - - diff --git a/applications/ColossalEval/README.md b/applications/ColossalEval/README.md index 06c6962f7978..3f645fe7892c 100644 --- a/applications/ColossalEval/README.md +++ b/applications/ColossalEval/README.md @@ -1,4 +1,8 @@ -# ColossalEval +

@@ -253,7 +259,7 @@ In dataset evaluation, we calculate different metrics on the given inference res A config file for dataset evaluation consists of two parts. 1. Model config. In model config, you need to specify model name. If you want to evaluate perplexity over a pretrain dataset and calculate per-byte-perplexity, you have to add your tokenizer config and model max length. -2. Dataset config. In dataset config, you need to specify the evaluation arguments for the dataset. +2. Dataset config. In dataset config, you need to specify the evaluation metrics for the dataset. Once you have all config ready, the program will run evaluation on inference results for all given models and dataset. @@ -315,7 +321,7 @@ The following is an example of a English config file. The configuration file can ``` ##### How to Use -After setting the config file, you can evaluate the model using `examples/gpt_evaluation/eval.py`. If you want to make comparisons between answers of two different models, you should specify two answer files in the argument `answer_file_list` and two model names in the argument `model_name_list`(details can be found in `colossal_eval/evaluate/GPT Evaluation.md`). If you want to evaluate one answer file, the length of both `answer_file_list` and `model_name_list` should be 1 and the program will perform evaluation using GPTs. +After setting the config file, you can evaluate the model using `examples/gpt_evaluation/eval.py`. If you want to make comparisons between answers of two different models, you should specify two answer files in the argument `answer_file_list` and two model names in the argument `model_name_list`(details can be found in `colossal_eval/evaluate/GPT Evaluation.md`). If you want to evaluate one answer file, the length of both `answer_file_list` and `model_name_list` should be 1 and the program will perform evaluation using GPTs. The prompt files for battle and gpt evaluation can be found in `configs/gpt_evaluation/prompt`. `target file` is the path to the converted dataset you save during inference time. An example script is provided as follows: @@ -381,7 +387,7 @@ We provide 2 examples for you to explore our `colossal_eval` package. This example is in folder `examples/dataset_evaluation`. 1. `cd examples/dataset_evaluation` -2. Fill in your inference config file in `config/inference/config.json`. Set the model and dataset parameters +2. Fill in your inference config file in `config/inference/config.json`. Set the model and dataset parameters. 3. Run `inference.sh` to get inference results. 4. Fill in your evaluation config file in `config/evaluation/config.json`. Set the model and dataset parameters. 5. Run `eval_dataset.sh` to get evaluation results. 
From 822051d8884a46d4d8626330e21adfd6427c99a0 Mon Sep 17 00:00:00 2001 From: binmakeswell Date: Wed, 27 Sep 2023 17:37:39 +0800 Subject: [PATCH 58/58] [doc] update slack link (#4823) --- .github/ISSUE_TEMPLATE/config.yml | 2 +- README.md | 2 +- applications/Chat/README.md | 2 +- colossalai/nn/optimizer/README.md | 2 +- docs/README-zh-Hans.md | 2 +- examples/README.md | 2 +- examples/images/diffusion/README.md | 2 +- examples/images/dreambooth/README.md | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index b310fcfefc15..436bdf887c69 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,7 +1,7 @@ blank_issues_enabled: true contact_links: - name: ❓ Simple question - Slack Chat - url: https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w + url: https://github.com/hpcaitech/public_assets/tree/main/colossalai/contact/slack about: This issue tracker is not for technical support. Please use our Slack chat, and ask the community for help. - name: ❓ Simple question - WeChat url: https://github.com/hpcaitech/ColossalAI/blob/main/docs/images/WeChat.png diff --git a/README.md b/README.md index a50cf496a98e..b2efb7910489 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ [![Documentation](https://readthedocs.org/projects/colossalai/badge/?version=latest)](https://colossalai.readthedocs.io/en/latest/?badge=latest) [![CodeFactor](https://www.codefactor.io/repository/github/hpcaitech/colossalai/badge)](https://www.codefactor.io/repository/github/hpcaitech/colossalai) [![HuggingFace badge](https://img.shields.io/badge/%F0%9F%A4%97HuggingFace-Join-yellow)](https://huggingface.co/hpcai-tech) - [![slack badge](https://img.shields.io/badge/Slack-join-blueviolet?logo=slack&)](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w) + [![slack badge](https://img.shields.io/badge/Slack-join-blueviolet?logo=slack&)](https://github.com/hpcaitech/public_assets/tree/main/colossalai/contact/slack) [![WeChat badge](https://img.shields.io/badge/微信-加入-green?logo=wechat&)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png) diff --git a/applications/Chat/README.md b/applications/Chat/README.md index 59e2c4548365..d5be04ab9f44 100644 --- a/applications/Chat/README.md +++ b/applications/Chat/README.md @@ -413,7 +413,7 @@ You may contact us or participate in the following ways: 1. [Leaving a Star ⭐](https://github.com/hpcaitech/ColossalAI/stargazers) to show your like and support. Thanks! 2. Posting an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose), or submitting a PR on GitHub follow the guideline in [Contributing](https://github.com/hpcaitech/ColossalAI/blob/main/CONTRIBUTING.md). 3. Join the Colossal-AI community on - [Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w), + [Slack](https://github.com/hpcaitech/public_assets/tree/main/colossalai/contact/slack), and [WeChat(微信)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png "qrcode") to share your ideas. 4. 
Send your official proposal to email contact@hpcaitech.com diff --git a/colossalai/nn/optimizer/README.md b/colossalai/nn/optimizer/README.md index c4afc6128d43..e89e6217d596 100644 --- a/colossalai/nn/optimizer/README.md +++ b/colossalai/nn/optimizer/README.md @@ -18,7 +18,7 @@ quickly deploy large AI model training and inference, reducing large AI model tr [**Paper**](https://arxiv.org/abs/2110.14883) | [**Documentation**](https://www.colossalai.org/) | [**Forum**](https://github.com/hpcaitech/ColossalAI/discussions) | -[**Slack**](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w) +[**Slack**](https://github.com/hpcaitech/public_assets/tree/main/colossalai/contact/slack) ## Table of Content diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md index 06977f9471c0..499d67a37c70 100644 --- a/docs/README-zh-Hans.md +++ b/docs/README-zh-Hans.md @@ -16,7 +16,7 @@ [![Documentation](https://readthedocs.org/projects/colossalai/badge/?version=latest)](https://colossalai.readthedocs.io/en/latest/?badge=latest) [![CodeFactor](https://www.codefactor.io/repository/github/hpcaitech/colossalai/badge)](https://www.codefactor.io/repository/github/hpcaitech/colossalai) [![HuggingFace badge](https://img.shields.io/badge/%F0%9F%A4%97HuggingFace-Join-yellow)](https://huggingface.co/hpcai-tech) - [![slack badge](https://img.shields.io/badge/Slack-join-blueviolet?logo=slack&)](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w) + [![slack badge](https://img.shields.io/badge/Slack-join-blueviolet?logo=slack&)](https://github.com/hpcaitech/public_assets/tree/main/colossalai/contact/slack) [![WeChat badge](https://img.shields.io/badge/微信-加入-green?logo=wechat&)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png) | [English](README.md) | [中文](README-zh-Hans.md) | diff --git a/examples/README.md b/examples/README.md index 142a735c6819..b822fb8ff923 100644 --- a/examples/README.md +++ b/examples/README.md @@ -36,7 +36,7 @@ You may contact us or participate in the following ways: 1. [Leaving a Star ⭐](https://github.com/hpcaitech/ColossalAI/stargazers) to show your like and support. Thanks! 2. Posting an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose), or submitting a PR on GitHub follow the guideline in [Contributing](https://github.com/hpcaitech/ColossalAI/blob/main/CONTRIBUTING.md). 3. Join the Colossal-AI community on -[Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w), +[Slack](https://github.com/hpcaitech/public_assets/tree/main/colossalai/contact/slack), and [WeChat(微信)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png "qrcode") to share your ideas. 4. Send your official proposal to email contact@hpcaitech.com diff --git a/examples/images/diffusion/README.md b/examples/images/diffusion/README.md index b63896524909..d6a1c47d6b87 100644 --- a/examples/images/diffusion/README.md +++ b/examples/images/diffusion/README.md @@ -254,7 +254,7 @@ You may contact us or participate in the following ways: 1. [Leaving a Star ⭐](https://github.com/hpcaitech/ColossalAI/stargazers) to show your like and support. Thanks! 2. Posting an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose), or submitting a PR on GitHub follow the guideline in [Contributing](https://github.com/hpcaitech/ColossalAI/blob/main/CONTRIBUTING.md). 3. 
Join the Colossal-AI community on -[Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w), +[Slack](https://github.com/hpcaitech/public_assets/tree/main/colossalai/contact/slack), and [WeChat(微信)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png "qrcode") to share your ideas. 4. Send your official proposal to email contact@hpcaitech.com diff --git a/examples/images/dreambooth/README.md b/examples/images/dreambooth/README.md index 4e9febbc5fa8..6716052897a6 100644 --- a/examples/images/dreambooth/README.md +++ b/examples/images/dreambooth/README.md @@ -139,7 +139,7 @@ You may contact us or participate in the following ways: 1. [Leaving a Star ⭐](https://github.com/hpcaitech/ColossalAI/stargazers) to show your like and support. Thanks! 2. Posting an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose), or submitting a PR on GitHub follow the guideline in [Contributing](https://github.com/hpcaitech/ColossalAI/blob/main/CONTRIBUTING.md). 3. Join the Colossal-AI community on -[Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w), +[Slack](https://github.com/hpcaitech/public_assets/tree/main/colossalai/contact/slack), and [WeChat(微信)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png "qrcode") to share your ideas. 4. Send your official proposal to email contact@hpcaitech.com