diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py index 1c25e4c94f24..98897095753d 100644 --- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py +++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py @@ -387,14 +387,15 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): # register hook to the parameters if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK: - def wrapper(param, comm_spec): + def wrapper(param, comm_spec, stream): def hook_fn(grad): - _all_reduce(grad, comm_spec, async_op=False) + with torch.cuda.stream(stream): + _all_reduce(grad, comm_spec, async_op=True) param.register_hook(hook_fn) - wrapper(param, comm_spec_to_use) + wrapper(param, comm_spec_to_use, reduction_stream) sharded_buffer_dict = {} # apply the sharding spec of buffers @@ -440,14 +441,15 @@ def hook_fn(grad): # register hook to the parameters if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK: - def wrapper(param, comm_spec): + def wrapper(param, comm_spec, stream): def hook_fn(grad): - _all_reduce(grad, comm_spec, async_op=False) + with torch.cuda.stream(stream): + _all_reduce(grad, comm_spec, async_op=True) param.register_hook(hook_fn) - wrapper(target, comm_spec_to_use) + wrapper(target, comm_spec_to_use, reduction_stream) return gm diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py index d3f9fd01d891..131c35156dcd 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py @@ -483,4 +483,6 @@ def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, Li raise TypeError( f"Found unexpected output type {type(output)} from the recover method of BmmTransform") strategies = recovered_stragies + for index, strategies in enumerate(strategies): + strategies.name = f"{strategies.name}_{index}" return strategies diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py index fa2246f952a9..9aa95b43a966 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py @@ -247,12 +247,12 @@ def collate_strategies(self) -> List[ShardingStrategy]: strategies.append(self.split_rhs_space_both_contract(1, 0)) # RR= RS x SR - strategies.append(self.recompute_split_both_contract(0)) - strategies.append(self.recompute_split_both_contract(1)) + # strategies.append(self.recompute_split_both_contract(0)) + # strategies.append(self.recompute_split_both_contract(1)) - # RS = RR x RS - strategies.append(self.split_rhs_space_only(0)) - strategies.append(self.split_rhs_space_only(1)) + # # RS = RR x RS + # strategies.append(self.split_rhs_space_only(0)) + # strategies.append(self.split_rhs_space_only(1)) # S01R = S01R x RR strategies.append(self.split_lhs_1st_dim_1d(0, 1)) @@ -263,8 +263,8 @@ def collate_strategies(self) -> List[ShardingStrategy]: # RS01 = RR x RS01 strategies.append(self.split_rhs_2nd_dim_1d(0, 1)) - # RR = RR x RR - strategies.append(self.non_split()) + # # RR = RR x RR + # strategies.append(self.non_split()) return strategies diff --git a/colossalai/device/device_mesh.py b/colossalai/device/device_mesh.py index b5a97eded90c..22a01dddb869 100644 --- a/colossalai/device/device_mesh.py +++ b/colossalai/device/device_mesh.py @@ -98,7 +98,7 @@ def flatten(self): return DeviceMesh(self.physical_mesh_id, tuple(flatten_mesh_shape), mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1), - mesh_beta=[min(self.mesh_beta)] * (flatten_mesh_shape_size - 1), + mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1), init_process_group=self.init_process_group, need_flatten=False) diff --git a/colossalai/tensor/comm_spec.py b/colossalai/tensor/comm_spec.py index 3c9e0fd56696..b31c06994190 100644 --- a/colossalai/tensor/comm_spec.py +++ b/colossalai/tensor/comm_spec.py @@ -463,7 +463,7 @@ def get_comm_cost(self): if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: forward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis) # give a tiny cost to shard - backward_communication_cost = 10 + backward_communication_cost = 100 if self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: forward_communication_cost = self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis) @@ -481,13 +481,13 @@ def get_comm_cost(self): if self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: # give a tiny cost to shard - forward_communication_cost = 10 + forward_communication_cost = 100 backward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis) if self.comm_pattern == CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD: # no need for axis because all devices are used in mix_gather forward_communication_cost = self.device_mesh.mix_gather_cost(comm_size) - backward_communication_cost = 10 + backward_communication_cost = 100 if self.forward_only: cost_dict["forward"] = forward_communication_cost