From 60d104b667678afd1ad14826f8a2a2eaa2fa908c Mon Sep 17 00:00:00 2001
From: YuliangLiu0306
Date: Wed, 18 Jan 2023 10:23:40 +0800
Subject: [PATCH 1/2] [autoparallel] accelerate gpt2 training

---
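Notes: this patch speeds up GPT-2 autoparallel training by (1) moving the
parameter gradient all-reduce into hooks that launch asynchronously on a
dedicated CUDA stream, so reduction overlaps with the rest of the backward
pass, (2) pruning matmul sharding strategies that rarely win, and (3)
retuning solver cost constants (shard cost 10 -> 100, matmul compute cost
divided by 1e5 — reverted in patch 2/2 — and flattened-mesh beta from min
to max). A minimal sketch of the hook pattern, assuming `reduction_stream`
is created once by the preparation pass and that the default stream
synchronizes with it before the optimizer step; `dist.all_reduce` stands in
for ColossalAI's `_all_reduce`, and the explicit wait_stream call is an
extra safety added only in this sketch:

    import torch
    import torch.distributed as dist

    reduction_stream = torch.cuda.Stream()

    def register_grad_all_reduce(param: torch.nn.Parameter, group=None):
        def hook_fn(grad):
            # Let the side stream wait until backward has produced `grad`
            # on the current stream, then reduce without blocking it.
            reduction_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(reduction_stream):
                dist.all_reduce(grad, group=group, async_op=True)

        param.register_hook(hook_fn)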
 .../passes/runtime_preparation_pass.py        | 14 ++++++++------
 .../node_handler/matmul_handler.py            |  2 ++
 .../strategy/matmul_strategy_generator.py     | 20 ++++++++++----------
 colossalai/device/device_mesh.py              |  2 +-
 colossalai/tensor/comm_spec.py                |  6 +++---
 5 files changed, 24 insertions(+), 20 deletions(-)

diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
index 1c25e4c94f24..98897095753d 100644
--- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py
+++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
@@ -387,14 +387,15 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
                     # register hook to the parameters
                     if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:

-                        def wrapper(param, comm_spec):
+                        def wrapper(param, comm_spec, stream):

                             def hook_fn(grad):
-                                _all_reduce(grad, comm_spec, async_op=False)
+                                with torch.cuda.stream(stream):
+                                    _all_reduce(grad, comm_spec, async_op=True)

                             param.register_hook(hook_fn)

-                        wrapper(param, comm_spec_to_use)
+                        wrapper(param, comm_spec_to_use, reduction_stream)

     sharded_buffer_dict = {}
     # apply the sharding spec of buffers
@@ -440,14 +441,15 @@ def hook_fn(grad):
             # register hook to the parameters
             if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:

-                def wrapper(param, comm_spec):
+                def wrapper(param, comm_spec, stream):

                     def hook_fn(grad):
-                        _all_reduce(grad, comm_spec, async_op=False)
+                        with torch.cuda.stream(stream):
+                            _all_reduce(grad, comm_spec, async_op=True)

                     param.register_hook(hook_fn)

-                wrapper(target, comm_spec_to_use)
+                wrapper(target, comm_spec_to_use, reduction_stream)

     return gm

diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py
index d3f9fd01d891..131c35156dcd 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py
@@ -483,4 +483,6 @@ def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, Li
                 raise TypeError(
                     f"Found unexpected output type {type(output)} from the recover method of BmmTransform")
         strategies = recovered_stragies
+        for index, strategy in enumerate(strategies):
+            strategy.name = f"{strategy.name}_{index}"
         return strategies

diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
index fa2246f952a9..5c72c2dbbb18 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
@@ -61,9 +61,9 @@ def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
         sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
         fwd_compute_cost = sharded_input_shape[0]
         bwd_compute_cost = fwd_compute_cost * 2
-        compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
-                                      bwd=bwd_compute_cost,
-                                      total=fwd_compute_cost + bwd_compute_cost)
+        compute_cost = TrainCycleItem(fwd=fwd_compute_cost / 1e5,
+                                      bwd=bwd_compute_cost / 1e5,
+                                      total=fwd_compute_cost / 1e5 + bwd_compute_cost / 1e5)
         return compute_cost

     @ignore_sharding_exception
@@ -247,12 +247,12 @@ def collate_strategies(self) -> List[ShardingStrategy]:
         strategies.append(self.split_rhs_space_both_contract(1, 0))

         # RR= RS x SR
-        strategies.append(self.recompute_split_both_contract(0))
-        strategies.append(self.recompute_split_both_contract(1))
+        # strategies.append(self.recompute_split_both_contract(0))
+        # strategies.append(self.recompute_split_both_contract(1))

-        # RS = RR x RS
-        strategies.append(self.split_rhs_space_only(0))
-        strategies.append(self.split_rhs_space_only(1))
+        # # RS = RR x RS
+        # strategies.append(self.split_rhs_space_only(0))
+        # strategies.append(self.split_rhs_space_only(1))

         # S01R = S01R x RR
         strategies.append(self.split_lhs_1st_dim_1d(0, 1))
@@ -263,8 +263,8 @@ def collate_strategies(self) -> List[ShardingStrategy]:

         # RS01 = RR x RS01
         strategies.append(self.split_rhs_2nd_dim_1d(0, 1))

-        # RR = RR x RR
-        strategies.append(self.non_split())
+        # # RR = RR x RR
+        # strategies.append(self.non_split())

         return strategies

diff --git a/colossalai/device/device_mesh.py b/colossalai/device/device_mesh.py
index b5a97eded90c..22a01dddb869 100644
--- a/colossalai/device/device_mesh.py
+++ b/colossalai/device/device_mesh.py
@@ -98,7 +98,7 @@ def flatten(self):
         return DeviceMesh(self.physical_mesh_id,
                           tuple(flatten_mesh_shape),
                           mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
-                          mesh_beta=[min(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
+                          mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
                           init_process_group=self.init_process_group,
                           need_flatten=False)

diff --git a/colossalai/tensor/comm_spec.py b/colossalai/tensor/comm_spec.py
index 3c9e0fd56696..b31c06994190 100644
--- a/colossalai/tensor/comm_spec.py
+++ b/colossalai/tensor/comm_spec.py
@@ -463,7 +463,7 @@ def get_comm_cost(self):
         if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD:
             forward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
             # give a tiny cost to shard
-            backward_communication_cost = 10
+            backward_communication_cost = 100

         if self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD:
             forward_communication_cost = self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis)
@@ -481,13 +481,13 @@ def get_comm_cost(self):
         if self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD:
             # give a tiny cost to shard
-            forward_communication_cost = 10
+            forward_communication_cost = 100
             backward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)

         if self.comm_pattern == CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD:
             # no need for axis because all devices are used in mix_gather
             forward_communication_cost = self.device_mesh.mix_gather_cost(comm_size)
-            backward_communication_cost = 10
+            backward_communication_cost = 100

         if self.forward_only:
             cost_dict["forward"] = forward_communication_cost

From 0044d28fe0916e8872bbb1e3eeb87c334851f61a Mon Sep 17 00:00:00 2001
From: YuliangLiu0306
Date: Wed, 18 Jan 2023 12:27:13 +0800
Subject: [PATCH 2/2] polish

---
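Notes: this follow-up removes the temporary 1e5 scaling that patch 1/2
applied to the matmul compute cost, restoring the raw heuristic: forward
cost tracks the leading sharded input dimension and backward is twice
that. A simplified stand-in for the restored computation (the dataclass
is a hypothetical reduction of ColossalAI's TrainCycleItem, included only
to keep the sketch self-contained):

    from dataclasses import dataclass
    from typing import Sequence

    @dataclass
    class TrainCycleItem:
        fwd: float
        bwd: float
        total: float

    def matmul_compute_cost(sharded_input_shape: Sequence[int]) -> TrainCycleItem:
        # Forward cost is proportional to the leading sharded dimension;
        # backward is estimated at roughly twice the forward cost.
        fwd = float(sharded_input_shape[0])
        bwd = 2.0 * fwd
        return TrainCycleItem(fwd=fwd, bwd=bwd, total=fwd + bwd)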
 .../node_handler/strategy/matmul_strategy_generator.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
index 5c72c2dbbb18..9aa95b43a966 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
@@ -61,9 +61,9 @@ def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
         sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
         fwd_compute_cost = sharded_input_shape[0]
         bwd_compute_cost = fwd_compute_cost * 2
-        compute_cost = TrainCycleItem(fwd=fwd_compute_cost / 1e5,
-                                      bwd=bwd_compute_cost / 1e5,
-                                      total=fwd_compute_cost / 1e5 + bwd_compute_cost / 1e5)
+        compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
+                                      bwd=bwd_compute_cost,
+                                      total=fwd_compute_cost + bwd_compute_cost)
         return compute_cost

     @ignore_sharding_exception
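Postscript on the comm_spec.py constants from patch 1/2: the placeholder
shard cost in CommSpec.get_comm_cost rises from 10 to 100, keeping a split
nearly free relative to a real collective while presumably discouraging
the solver from inserting gratuitous shard/gather pairs. A condensed,
hypothetical sketch of that cost shape (pattern names mirror
CollectiveCommPattern; the real method also covers all-to-all, all-reduce,
and mix-gather, and queries the device mesh for collective costs):

    SHARD_COST = 100  # small placeholder cost for a split, raised from 10

    def comm_cost(pattern: str, all_gather_cost: float) -> dict:
        # A gather on the forward pass implies a cheap split on the backward
        # pass, and vice versa; only the gather direction pays the real cost.
        if pattern == "GATHER_FWD_SPLIT_BWD":
            fwd, bwd = all_gather_cost, SHARD_COST
        elif pattern == "SPLIT_FWD_GATHER_BWD":
            fwd, bwd = SHARD_COST, all_gather_cost
        else:
            raise ValueError(f"unhandled pattern: {pattern}")
        return {"forward": fwd, "backward": bwd, "total": fwd + bwd}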