2 changes: 1 addition & 1 deletion colossalai/amp/torch_amp/_grad_scaler.py
@@ -240,7 +240,7 @@ def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
for grads in per_dtype_grads.values():
torch._amp_foreach_non_finite_check_and_unscale_(grads, per_device_found_inf.get(device),
per_device_inv_scale.get(device))
- # For tensor parallel paramters it should be all-reduced over tensor parallel process group
+ # For tensor parallel parameters it should be all-reduced over tensor parallel process group
if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1:
vals = [val for val in per_device_found_inf._per_device_tensors.values()]
coalesced = _flatten_dense_tensors(vals)
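For context on the hunk above: with tensor parallelism, each rank unscales only its local gradient shard, so a non-finite flag raised on one rank must be propagated to every rank before the scaler decides whether to skip the step (the surrounding code does this by flattening the per-device flags and all-reducing the coalesced tensor). A minimal sketch of the underlying pattern, assuming a plain `torch.distributed` setup rather than Colossal-AI's `gpc` abstraction:

```python
import torch
import torch.distributed as dist

def sync_found_inf(found_inf: torch.Tensor, group=None) -> bool:
    """Make a per-rank non-finite flag visible to every rank in the group.

    `found_inf` is a 1-element float tensor (0.0 = clean, > 0.0 = inf/nan),
    matching what torch._amp_foreach_non_finite_check_and_unscale_ writes.
    MAX (SUM works too) makes the flag sticky: one bad rank poisons all.
    """
    dist.all_reduce(found_inf, op=dist.ReduceOp.MAX, group=group)
    return found_inf.item() > 0.0
```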
@@ -325,7 +325,7 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
else:
_is_batch_dims_same = False

- # retireve dimensions
+ # retrieve dimensions
input_dim_00 = input_tensors[0].shape[-2]
input_dim_01 = input_tensors[0].shape[-1]
input_dim_10 = input_tensors[1].shape[-2]
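The hunk above grabs the trailing two dimensions of both operands, the usual first step in estimating matmul cost. A rough sketch of where those dimensions lead in the simplest (2D, no batch dims) case; `naive_matmul_flops` is a hypothetical helper, not a function from the repository:

```python
import torch

def naive_matmul_flops(a: torch.Tensor, b: torch.Tensor) -> int:
    """Estimate forward FLOPs of a @ b from the trailing dimensions.

    For (m, k) @ (k, n), each of the m * n outputs costs k multiplies
    and k - 1 adds, conventionally rounded to 2 * m * k * n.
    """
    m, k = a.shape[-2], a.shape[-1]    # input_dim_00, input_dim_01
    k2, n = b.shape[-2], b.shape[-1]   # input_dim_10, input_dim_11
    assert k == k2, "inner dimensions must match"
    return 2 * m * k * n

print(naive_matmul_flops(torch.empty(128, 256), torch.empty(256, 64)))  # 4194304
```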
2 changes: 1 addition & 1 deletion colossalai/auto_parallel/passes/runtime_apply_pass.py
@@ -219,7 +219,7 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
return gm


- def _act_annotataion_pass(gm: torch.fx.GraphModule):
+ def _act_annotation_pass(gm: torch.fx.GraphModule):
"""
This pass is used to add the act annotation to the new inserted nodes.
"""
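The renamed `_act_annotation_pass` attaches activation metadata to newly inserted graph nodes. In `torch.fx`, the conventional place for such bookkeeping is the free-form `node.meta` dict, which survives graph transformations. A minimal sketch of the pattern; the `"act_annotation"` key and its value are illustrative, not necessarily what the pass uses:

```python
import torch
from torch.fx import symbolic_trace

class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.fc(x))

gm = symbolic_trace(Tiny())

# Stamp every node with an annotation; a pass that inserts new nodes
# would copy this from a neighboring node so the tag stays consistent.
for node in gm.graph.nodes:
    node.meta["act_annotation"] = 0  # illustrative value

print({node.name: node.meta["act_annotation"] for node in gm.graph.nodes})
```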
4 changes: 2 additions & 2 deletions colossalai/auto_parallel/passes/runtime_preparation_pass.py
@@ -54,7 +54,7 @@ def size_processing(size: Union[int, torch.Size],
return size


- def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
+ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int],
strategies_constructor: StrategiesConstructor):
"""
This method is used to stick the solution strategy to the nodes and add the information
@@ -496,7 +496,7 @@ def runtime_preparation_pass(gm: torch.fx.GraphModule,
device_mesh: DeviceMesh,
strategies_constructor: StrategiesConstructor,
overlap=False):
- gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotatation_pass(
+ gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotation_pass(
gm, solution, strategies_constructor)
gm = size_value_converting_pass(gm, device_mesh)
gm = node_args_converting_pass(gm, device_mesh)
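As the call site above shows, `runtime_preparation_pass` threads one `GraphModule` through several rewrite passes in sequence, each returning the (possibly rewritten) module for the next. A generic sketch of that composition style, with a placeholder pass standing in for the real ones:

```python
from typing import Callable, List
import torch
from torch.fx import GraphModule, symbolic_trace

PassFn = Callable[[GraphModule], GraphModule]

def run_pipeline(gm: GraphModule, passes: List[PassFn]) -> GraphModule:
    for p in passes:
        gm = p(gm)
        gm.recompile()  # keep the generated forward() in sync with the graph
    return gm

def noop_pass(gm: GraphModule) -> GraphModule:
    return gm  # placeholder for e.g. size_value_converting_pass

gm = run_pipeline(symbolic_trace(torch.nn.Linear(2, 2)), [noop_pass, noop_pass])
```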
6 changes: 3 additions & 3 deletions colossalai/autochunk/trace_flow.py
@@ -64,7 +64,7 @@ def check_index_compute(self, start_idx, end_dim, end_node, end_idx):
return False
return True

- def _assgin_single_node_flow(
+ def _assign_single_node_flow(
self,
arg_node: Node,
start_idx: int,
@@ -177,7 +177,7 @@ def _get_all_node_info(self, end_dim, start_idx, end_idx):
if get_node_shape(arg) is None:
continue
arg_list.append(arg)
- flow_flag = self._assgin_single_node_flow(
+ flow_flag = self._assign_single_node_flow(
arg,
start_idx,
end_idx,
@@ -315,7 +315,7 @@ def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int,
chunk_info["args"]["prepose_nodes"] = prepose_nodes

def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
- # we need to log input nodes to avoid deleteing them in the loop
+ # we need to log input nodes to avoid deleting them in the loop
chunk_node_list = self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1)
# also need to get some prepose node's arg out of non_chunk_inputs
for n in chunk_info["args"]["prepose_nodes"]:
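The corrected comment in `_get_non_chunk_inputs` names a classic pitfall: removing items from a collection while iterating over it. The standard fix, which logging the input nodes up front accomplishes, is to iterate over a snapshot and mutate the original. A tiny illustration, independent of the fx specifics:

```python
nodes = ["a", "b", "c", "d"]
for n in list(nodes):      # iterate over a snapshot of the list...
    if n in ("b", "c"):
        nodes.remove(n)    # ...so removing from the original is safe
print(nodes)               # ['a', 'd']
```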
8 changes: 4 additions & 4 deletions colossalai/autochunk/trace_indice.py
@@ -461,7 +461,7 @@ def _assign_elementwise_indice(self, node, idx):
nodes_in.append(node_in)
self._inherit_more_indice_from_node_with_exclude(node_in, node)

- def _assgin_no_change_indice(self, node, idx):
+ def _assign_no_change_indice(self, node, idx):
self._assign_indice_as_input(node, idx)
for node_in in node.args:
if type(node_in) == type(node):
@@ -792,7 +792,7 @@ def _assign_view_reshape_indice(self, node: Node, node_idx: int) -> None:
self._add_dim(node_idx, i)
dim_from.reverse()

- # inheirt indice from current node
+ # inherit indice from current node
if len(dim_from) != 0 and len(dim_to) != 0:
if dim_diff == 1:
if origin_shape[dim_from[0]] == 1:
@@ -852,7 +852,7 @@ def trace_indice(self) -> None:
elif "split" == node_name:
self._assign_split_indice(node, idx)
elif any(i == node_name for i in ["to", "contiguous", "clone", "type", "float"]):
- self._assgin_no_change_indice(node, idx)
+ self._assign_no_change_indice(node, idx)
elif "new_ones" == node_name:
self._assign_all_indice(node, idx)
elif "flatten" == node_name:
@@ -914,7 +914,7 @@ def trace_indice(self) -> None:
elif "conv2d" == node_name:
self._assign_conv2d_indice(node, idx)
elif "identity" == node_name:
- self._assgin_no_change_indice(node, idx)
+ self._assign_no_change_indice(node, idx)
elif any(n == node_name for n in ["sigmoid", "dropout", "relu", "silu", "gelu"]):
self._assign_elementwise_indice(node, idx)
else:
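Both `_assgin_no_change_indice` fixes land in `trace_indice`'s long `elif` chain, which dispatches on an op name to a per-op index handler, with several names sharing one handler. A table-driven sketch of the same dispatch pattern; the handlers here are hypothetical stand-ins:

```python
from typing import Callable, Dict

def handle_no_change(name: str) -> str:
    return f"{name}: indices pass through unchanged"

def handle_elementwise(name: str) -> str:
    return f"{name}: indices inherited elementwise from inputs"

# Many op names map to one handler, just as "to", "contiguous", "clone",
# "type", "float" and "identity" all use _assign_no_change_indice.
HANDLERS: Dict[str, Callable[[str], str]] = {
    **{n: handle_no_change for n in ("to", "contiguous", "clone", "identity")},
    **{n: handle_elementwise for n in ("sigmoid", "relu", "gelu")},
}

print(HANDLERS["clone"]("clone"))
print(HANDLERS["relu"]("relu"))
```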