From aedc5901588a3ddf8cb8338fb0f35334f6560b44 Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 Date: Tue, 7 Feb 2023 11:31:03 +0800 Subject: [PATCH] add overlap option --- .../passes/runtime_preparation_pass.py | 29 ++++++++++++------- .../auto_parallel/tensor_shard/initialize.py | 19 ++++++++---- 2 files changed, 32 insertions(+), 16 deletions(-) diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py index 98897095753d..897602ce1d24 100644 --- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py +++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py @@ -352,7 +352,7 @@ def _process_sharding_spec(sharding_spec): return gm -def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): +def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, overlap=False): """ Apply the sharding action to the module parameters and buffers following the instructions of solver solution. @@ -387,15 +387,18 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): # register hook to the parameters if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK: - def wrapper(param, comm_spec, stream): + def wrapper(param, comm_spec, stream, overlap): def hook_fn(grad): - with torch.cuda.stream(stream): - _all_reduce(grad, comm_spec, async_op=True) + if overlap: + with torch.cuda.stream(stream): + _all_reduce(grad, comm_spec, async_op=True) + else: + _all_reduce(grad, comm_spec, async_op=False) param.register_hook(hook_fn) - wrapper(param, comm_spec_to_use, reduction_stream) + wrapper(param, comm_spec_to_use, reduction_stream, overlap=overlap) sharded_buffer_dict = {} # apply the sharding spec of buffers @@ -441,15 +444,18 @@ def hook_fn(grad): # register hook to the parameters if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK: - def wrapper(param, comm_spec, stream): + def wrapper(param, comm_spec, stream, overlap): def hook_fn(grad): - with torch.cuda.stream(stream): - _all_reduce(grad, comm_spec, async_op=True) + if overlap: + with torch.cuda.stream(stream): + _all_reduce(grad, comm_spec, async_op=True) + else: + _all_reduce(grad, comm_spec, async_op=False) param.register_hook(hook_fn) - wrapper(target, comm_spec_to_use, reduction_stream) + wrapper(target, comm_spec_to_use, reduction_stream, overlap=overlap) return gm @@ -463,13 +469,14 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule): def runtime_preparation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh: DeviceMesh, - strategies_constructor: StrategiesConstructor = None): + strategies_constructor: StrategiesConstructor = None, + overlap=False): gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation( gm, solution, strategies_constructor) gm = _size_value_converting(gm, device_mesh) gm = _node_args_converting(gm, device_mesh) # TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed. # gm = implicit_comm_action_apply(gm) - gm = _module_params_sharding(gm, device_mesh) + gm = _module_params_sharding(gm, device_mesh, overlap=overlap) return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict diff --git a/colossalai/auto_parallel/tensor_shard/initialize.py b/colossalai/auto_parallel/tensor_shard/initialize.py index 387a682a1ad9..23ed0f433731 100644 --- a/colossalai/auto_parallel/tensor_shard/initialize.py +++ b/colossalai/auto_parallel/tensor_shard/initialize.py @@ -98,16 +98,22 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstruc return solution -def transform_to_sharded_model(gm: ColoGraphModule, solution: List[int], device_mesh: DeviceMesh, - strategies_constructor: StrategiesConstructor): +def transform_to_sharded_model(gm: ColoGraphModule, + solution: List[int], + device_mesh: DeviceMesh, + strategies_constructor: StrategiesConstructor, + overlap: bool = False): ''' This method is used to transform the original graph to the sharded graph. The model parameters will be sharded according to the solution and the grad hooks will be added to the sharded graph using the runtime_preparation_pass. The communication node will be added into the graph using the runtime_apply_pass. ''' - gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass( - gm, solution, device_mesh, strategies_constructor) + gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, + solution, + device_mesh, + strategies_constructor, + overlap=overlap) gm = runtime_apply_pass(gm) gm.recompile() sharding_spec_dicts = (sharding_spec_dict, origin_spec_dict, comm_actions_dict) @@ -176,6 +182,7 @@ def initialize_model(model: nn.Module, meta_args: Dict[str, torch.Tensor], device_mesh: DeviceMesh, memory_budget: float = -1.0, + overlap: bool = False, save_solver_solution: bool = False, load_solver_solution: bool = False, solution_path: str = None, @@ -189,6 +196,8 @@ def initialize_model(model: nn.Module, device_mesh: the device mesh to execute the model. memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0, the memory budget will be infinity. + overlap(optional): the overlap is used to specify whether to overlap gradient communication and + backward computing. save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved to the solution_path. load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded @@ -211,7 +220,7 @@ def initialize_model(model: nn.Module, if save_solver_solution: torch.save(solution, solution_path) - gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor) + gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor, overlap) model_to_return = ModuleWrapper(gm, *sharding_spec_dicts) if return_solution: