Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def update_resharding_cost(self, strategy: ShardingStrategy) -> None:
prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in prev_strategy_vector
]

# create data structrure to store costs
# create data structure to store costs
if node not in resharding_costs:
resharding_costs[node] = []

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
To keep the math consistency, there are two ways to do BatchNorm if the input
shards on batch dimension:
1. We gather the input partitions through batch dimension, then do the normal BatchNorm.
2. We do the SyncBatchNorm on the each input partition seperately, the SyncBN op will help
2. We do the SyncBatchNorm on each input partition separately, the SyncBN op will help
us to keep the computing correctness.
In this generator, both methods will be considered.
"""
Expand Down Expand Up @@ -212,7 +212,7 @@ def split_input_batch(self, mesh_dim_0):

# set communication action
# For SyncBN case, we don't need to do communication for weight and bias.
# TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation
# TODO: the communication happens internally at SyncBN operation. We need to replace the BN operation
# to SyncBN operation instead of inserting a communication node.
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["output"],
Expand Down Expand Up @@ -250,7 +250,7 @@ def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):

# set communication action
# For SyncBN case, we don't need to do communication for gradients of weight and bias.
# TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation
# TODO: the communication happens internally at SyncBN operation. We need to replace the BN operation
# to SyncBN operation instead of inserting a communication node.
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["output"],
Expand Down Expand Up @@ -298,7 +298,7 @@ def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):

# set communication action
# For SyncBN case, we don't need to do communication for gradients of weight and bias.
# TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation
# TODO: the communication happens internally at SyncBN operation. We need to replace the BN operation
# to SyncBN operation instead of inserting a communication node.
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["output"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:

# compute fwd memory cost in bytes
# as the elementwise ops are not memory-intensive
# we approximate the fwd memroy cost to be the output
# we approximate the fwd memory cost to be the output
# and the backward memory cost to be grad of input and other
input_bytes = self._compute_size_in_bytes(strategy, 'input')
other_bytes = self._compute_size_in_bytes(strategy, 'other')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,15 +225,15 @@ def _compute_size_in_bytes_helper(sharding_spec, meta_data):
if isinstance(meta_data, torch.Tensor):
element_bytes = _compute_size_in_bytes_helper(sharding_spec, meta_data)
else:
# if meta_data is not a tensor, we count the memroy as 0
# if meta_data is not a tensor, we count the memory as 0
element_bytes = 0
total_bytes += element_bytes

else:
if isinstance(op_data.data, torch.Tensor):
total_bytes = _compute_size_in_bytes_helper(strategy.sharding_specs[op_data], op_data.data)
else:
# if op_data.data is not a tensor, we count the memroy as 0
# if op_data.data is not a tensor, we count the memory as 0
total_bytes = 0

return total_bytes
Expand Down
4 changes: 2 additions & 2 deletions colossalai/auto_parallel/tensor_shard/solver/cost_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class CostGraph:
1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in
CostGraph, and it stores every combination of strategies for a src-dst node pair in a 1D list.
2. To reduce the searching space, we merge computationally-trivial operators, such as
element-wise operators, transpose, and reduction, into their following nodes. The merging infomation will
element-wise operators, transpose, and reduction, into their following nodes. The merging information will
be given by the StrategiesVector depending on the type of target node and following nodes.

Argument:
Expand Down Expand Up @@ -90,7 +90,7 @@ def _check_tensor_in_node(data):
if self.simplify and strategies_vector.check_merge():
for followed_node in strategies_vector.predecessor_nodes:
# we only merge node pairs which src node has a tensor element inside.
# This is necessay because the node without a tensor element inside will not
# This is necessary because the node without a tensor element inside will not
# be assigned any strategy.
if _check_tensor_in_node(followed_node._meta_data):
self.merge_pair.append((followed_node, dst_node))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,15 @@ def graph(self) -> Graph:

def liveness_analysis(self) -> List[LiveStage]:
"""
Analyse the graph to obtain the variable liveness information. This function returns
Analyses the graph to obtain the variable liveness information. This function returns
an ordered dictionary where the key is the compute stage ID and the value is a LivenessStage object.
"""
compute_nodes = self.graph.nodes
liveness_list = []

# checked: record all variables created since the first stage
# all: record the live variables only exist until the current stage.
# this can be different from the `checked list`` as some varialbes may be destroyed prior to this stage.
# this can be different from the `checked list` as some variables may be destroyed prior to this stage.
# unique: record the unique live variables only exist until the current stage.
# this is different from `all list` as some variables are duplicated.
checked_variables = LiveVariableVector()
Expand All @@ -103,7 +103,7 @@ def liveness_analysis(self) -> List[LiveStage]:
# find new living variables #
#############################
# detect whether the current op is an in-place op
# if it is an in-place op, we would deem it as a duplciate var
# if it is an in-place op, we would deem it as a duplicate var
is_inplace = False
if node.op == 'call_function':
# check if this is an inplace op such as torch.nn.functional.relu(x, inplace=True)
Expand Down
2 changes: 1 addition & 1 deletion colossalai/auto_parallel/tensor_shard/solver/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self,
graph: The computing graph to be optimized.
strategies_constructor: It will provide all the possible strategies for each node in the computing graph.
cost_graph: A graph data structure to simplify the edge cost graph.
graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints.
graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints.
memory_budget: Memory constraint for the solution.
solution_numbers: If solution_numbers is larger than one, the solver will use a series of solutions based on different memory budgets.
memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate new memory budget.
Expand Down
2 changes: 1 addition & 1 deletion colossalai/testing/pytest_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_for_something():
assert isinstance(name, str)
flag = os.environ.get(name.upper(), '0')

reason = f'Environment varialbe {name} is {flag}'
reason = f'Environment variable {name} is {flag}'
if flag == '1':
return pytest.mark.skipif(False, reason=reason)
else:
Expand Down