29 changes: 18 additions & 11 deletions colossalai/auto_parallel/passes/runtime_preparation_pass.py
@@ -352,7 +352,7 @@ def _process_sharding_spec(sharding_spec):
return gm


-def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
+def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, overlap=False):
"""
Apply the sharding action to the module parameters and buffers following the
instructions of solver solution.
@@ -387,15 +387,18 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
# register hook to the parameters
if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:

-def wrapper(param, comm_spec, stream):
+def wrapper(param, comm_spec, stream, overlap):

     def hook_fn(grad):
-        with torch.cuda.stream(stream):
-            _all_reduce(grad, comm_spec, async_op=True)
+        if overlap:
+            with torch.cuda.stream(stream):
+                _all_reduce(grad, comm_spec, async_op=True)
+        else:
+            _all_reduce(grad, comm_spec, async_op=False)

     param.register_hook(hook_fn)

-wrapper(param, comm_spec_to_use, reduction_stream)
+wrapper(param, comm_spec_to_use, reduction_stream, overlap=overlap)

sharded_buffer_dict = {}
# apply the sharding spec of buffers
@@ -441,15 +444,18 @@ def hook_fn(grad):
# register hook to the parameters
if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:

-def wrapper(param, comm_spec, stream):
+def wrapper(param, comm_spec, stream, overlap):

     def hook_fn(grad):
-        with torch.cuda.stream(stream):
-            _all_reduce(grad, comm_spec, async_op=True)
+        if overlap:
+            with torch.cuda.stream(stream):
+                _all_reduce(grad, comm_spec, async_op=True)
+        else:
+            _all_reduce(grad, comm_spec, async_op=False)

     param.register_hook(hook_fn)

-wrapper(target, comm_spec_to_use, reduction_stream)
+wrapper(target, comm_spec_to_use, reduction_stream, overlap=overlap)
return gm


@@ -463,13 +469,14 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
 def runtime_preparation_pass(gm: torch.fx.GraphModule,
                              solution: List[int],
                              device_mesh: DeviceMesh,
-                             strategies_constructor: StrategiesConstructor = None):
+                             strategies_constructor: StrategiesConstructor = None,
+                             overlap=False):
gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation(
gm, solution, strategies_constructor)
gm = _size_value_converting(gm, device_mesh)
gm = _node_args_converting(gm, device_mesh)
# TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed.
# gm = implicit_comm_action_apply(gm)
-gm = _module_params_sharding(gm, device_mesh)
+gm = _module_params_sharding(gm, device_mesh, overlap=overlap)

return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
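For context, a minimal, self-contained sketch of the gradient-hook pattern these hunks implement, assuming torch.distributed is already initialized. The helper name `register_param_allreduce_hook` and the direct use of `dist.all_reduce` (standing in for the project's `_all_reduce` wrapper) are illustrative assumptions, not code from this PR:

```python
# Sketch only: register a backward hook that all-reduces a parameter's
# gradient, optionally on a side stream so the communication can overlap
# with the remaining backward computation.
import torch
import torch.distributed as dist


def register_param_allreduce_hook(param: torch.nn.Parameter,
                                  stream: torch.cuda.Stream,
                                  overlap: bool = False) -> None:

    def hook_fn(grad: torch.Tensor):
        if overlap:
            # Launch the all-reduce asynchronously on the side stream so it
            # runs concurrently with the rest of the backward pass.
            with torch.cuda.stream(stream):
                dist.all_reduce(grad, async_op=True)
        else:
            # Blocking all-reduce on the current stream.
            dist.all_reduce(grad, async_op=False)

    param.register_hook(hook_fn)
```

With `overlap=True` the gradients are only safe to read once the side stream has been synchronized, so the optimizer step should wait on the reduction stream before consuming them.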
19 changes: 14 additions & 5 deletions colossalai/auto_parallel/tensor_shard/initialize.py
@@ -98,16 +98,22 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstruc
return solution


-def transform_to_sharded_model(gm: ColoGraphModule, solution: List[int], device_mesh: DeviceMesh,
-                               strategies_constructor: StrategiesConstructor):
+def transform_to_sharded_model(gm: ColoGraphModule,
+                               solution: List[int],
+                               device_mesh: DeviceMesh,
+                               strategies_constructor: StrategiesConstructor,
+                               overlap: bool = False):
'''
This method is used to transform the original graph to the sharded graph.
The model parameters will be sharded according to the solution and the grad hooks
will be added to the sharded graph using the runtime_preparation_pass.
The communication node will be added into the graph using the runtime_apply_pass.
'''
-gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
-    gm, solution, device_mesh, strategies_constructor)
+gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm,
+                                                                                       solution,
+                                                                                       device_mesh,
+                                                                                       strategies_constructor,
+                                                                                       overlap=overlap)
gm = runtime_apply_pass(gm)
gm.recompile()
sharding_spec_dicts = (sharding_spec_dict, origin_spec_dict, comm_actions_dict)
@@ -176,6 +182,7 @@ def initialize_model(model: nn.Module,
meta_args: Dict[str, torch.Tensor],
device_mesh: DeviceMesh,
memory_budget: float = -1.0,
+overlap: bool = False,
save_solver_solution: bool = False,
load_solver_solution: bool = False,
solution_path: str = None,
@@ -189,6 +196,8 @@ def initialize_model(model: nn.Module,
device_mesh: the device mesh to execute the model.
memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0,
the memory budget will be infinity.
+overlap(optional): whether to overlap the gradient all-reduce communication with the
+    backward computation.
save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
to the solution_path.
load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
@@ -211,7 +220,7 @@
if save_solver_solution:
torch.save(solution, solution_path)

-gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor)
+gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor, overlap)
model_to_return = ModuleWrapper(gm, *sharding_spec_dicts)

if return_solution:
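A hedged usage sketch of the new `overlap` flag on the user-facing API follows. The toy MLP, the 2x2 device mesh, and the exact `DeviceMesh`/`meta_args` construction are assumptions for illustration; in practice this code runs inside a distributed launch (for example `torchrun` or `colossalai.launch`) with one process per GPU:

```python
# Usage sketch (assumptions: 4 GPUs, a process group already set up by the
# launcher, and a toy MLP; not taken verbatim from this PR or its tests).
import torch
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
from colossalai.device.device_mesh import DeviceMesh


class MLP(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(64, 128)
        self.fc2 = nn.Linear(128, 8)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


model = MLP().cuda()

# Map 4 ranks onto a 2x2 logical mesh; init_process_group=True builds the
# process groups needed by the sharded communication.
physical_mesh_id = torch.arange(0, 4)
device_mesh = DeviceMesh(physical_mesh_id, (2, 2), init_process_group=True)

# Meta tensors describe input shapes without allocating device memory.
meta_args = {'x': torch.rand(16, 64, device='meta')}

# overlap=True makes the registered gradient all-reduce hooks launch
# asynchronously on a dedicated CUDA stream, overlapping communication with
# the rest of the backward pass.
sharded_model = initialize_model(model, meta_args, device_mesh, overlap=True)
```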