Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions colossalai/autochunk/autochunk_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,13 @@ def _replace_name(context: str, name_from: str, name_to: str) -> str:
"""
replace node name
"""
patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")")]
patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")"), (" ", ""), ("", " ")]
for p in patterns:
source = p[0] + name_from + p[1]
target = p[0] + name_to + p[1]
if source in context:
context = context.replace(source, target)
break
return context


Expand All @@ -138,8 +139,7 @@ def _replace_reshape_size(context: str, node_name: str, reshape_size_dict: Dict)
"""
if node_name not in reshape_size_dict:
return context
for size_name, size_value in reshape_size_dict[node_name].items():
context = context.replace(size_name, size_value)
context = context.replace(reshape_size_dict[node_name][0], reshape_size_dict[node_name][1])
return context


Expand Down
67 changes: 20 additions & 47 deletions colossalai/autochunk/estimate_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ def _get_output_node_size(self, n):

def _add_active_node(self, n, active_list):
new_active = self._get_output_node(n)[1]
if n.op == "placeholder":
if n.op == "placeholder" and get_node_shape(n) is not None:
new_active.append(n.name)
for i in new_active:
if i not in active_list:
if i not in active_list and get_node_shape(n) is not None:
active_list.append(i)

def _get_delete_node(self, user, user_to_last_uses, to_keep=None):
Expand Down Expand Up @@ -77,15 +77,11 @@ def _remove_deactive_node(self, user, user_to_last_uses, active_list):
if i in active_list:
active_list.remove(i)

def _get_chunk_inputs_size(
self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx
):
def _get_chunk_inputs_size(self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx):
nodes_to_delete = []
for chunk_input in chunk_inputs + chunk_inputs_non_chunk:
chunk_input_users = chunk_input.users.keys()
chunk_input_users_idx = [
find_idx_by_name(i.name, node_list) for i in chunk_input_users
]
chunk_input_users_idx = [find_idx_by_name(i.name, node_list) for i in chunk_input_users]
if all(i <= chunk_end_idx for i in chunk_input_users_idx):
if chunk_input not in nodes_to_delete:
nodes_to_delete.append(chunk_input)
Expand All @@ -112,9 +108,7 @@ def _get_contiguous_memory(self, node, not_contiguous_list, delete=False):
not_contiguous_ops = ["permute"]
inherit_contiguous_ops = ["transpose", "view"]

if node.op == "call_function" and any(
n in node.name for n in ["matmul", "reshape"]
):
if node.op == "call_function" and any(n in node.name for n in ["matmul", "reshape"]):
for n in node.args:
if n in not_contiguous_list:
# matmul won't change origin tensor, but create a tmp copy
Expand All @@ -125,9 +119,7 @@ def _get_contiguous_memory(self, node, not_contiguous_list, delete=False):
# module will just make origin tensor to contiguous
if delete:
not_contiguous_list.remove(n)
elif node.op == "call_method" and any(
i in node.name for i in not_contiguous_ops
):
elif node.op == "call_method" and any(i in node.name for i in not_contiguous_ops):
if node not in not_contiguous_list:
not_contiguous_list.append(node)
return mem
Expand All @@ -142,9 +134,7 @@ def _get_chunk_ratio(self, node, chunk_node_dim, chunk_size):
else:
return float(chunk_size) / node_shape[chunk_dim]

def _get_chunk_delete_node_size(
self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names
):
def _get_chunk_delete_node_size(self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names):
# if any(j in user.name for j in ['transpose', 'permute', 'view']):
# return 0
if user.op in ("placeholder", "output"):
Expand Down Expand Up @@ -196,7 +186,7 @@ def estimate_chunk_inference_mem(
Returns:
act_memory_peak_log (List): peak memory of every node
act_memory_after_node_log (List): memory after excuting every node
active_node_list_log (List): active nodes of every node. active nodes refer to
active_node_list_log (List): active nodes of every node. active nodes refer to
nodes generated but not deleted.
"""
act_memory = 0.0
Expand All @@ -212,7 +202,7 @@ def estimate_chunk_inference_mem(
use_chunk = True if chunk_infos is not None else False
chunk_within = False
chunk_region_idx = None
chunk_ratio = 1 # use it to estimate chunk mem
chunk_ratio = 1 # use it to estimate chunk mem
chunk_inputs_names = []

if use_chunk:
Expand All @@ -221,23 +211,18 @@ def estimate_chunk_inference_mem(
chunk_ends = [i[1] for i in chunk_regions]
chunk_inputs = [i["inputs"] for i in chunk_infos]
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [
j.name for i in chunk_inputs_non_chunk for j in i
]
chunk_inputs_names = [j.name for i in chunk_inputs for j in i
] + [j.name for i in chunk_inputs_non_chunk for j in i]
chunk_outputs = [i["outputs"][0] for i in chunk_infos]
chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos]
chunk_sizes = [
i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos
]
chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos]

for idx, node in enumerate(node_list):
# if node in chunk start nodes, change chunk ratio and add chunk_tensor
if use_chunk and idx in chunk_starts:
chunk_within = True
chunk_region_idx = chunk_starts.index(idx)
act_memory += self._get_output_node_size(
chunk_outputs[chunk_region_idx]
) / (1024**2)
act_memory += self._get_output_node_size(chunk_outputs[chunk_region_idx]) / (1024**2)

# determine chunk ratio for current node
if chunk_within:
Expand All @@ -262,22 +247,13 @@ def estimate_chunk_inference_mem(
else:
# forward memory
# TODO: contiguous_memory still not accurate for matmul, view, reshape and transpose
act_memory += (
self._get_contiguous_memory(node, not_contiguous_list)
* chunk_ratio
/ (1024**2)
)
act_memory += (
self._get_output_node_size(node) * chunk_ratio / (1024**2)
)
act_memory += (self._get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024**2))
act_memory += (self._get_output_node_size(node) * chunk_ratio / (1024**2))
# record max act memory
act_memory_peak_log.append(act_memory)
# delete useless memory
act_memory -= (
self._get_contiguous_memory(node, not_contiguous_list, delete=True)
* chunk_ratio
/ (1024**2)
)
act_memory -= (self._get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio /
(1024**2))
# delete unused vars not in chunk_input_list
# we can't delete input nodes until chunk ends
if chunk_within:
Expand All @@ -288,19 +264,16 @@ def estimate_chunk_inference_mem(
chunk_inputs_names,
) / (1024**2)
else:
act_memory -= self._get_delete_node_size(
node, user_to_last_uses_no_free_var, chunk_inputs_names
) / (1024**2)
act_memory -= self._get_delete_node_size(node, user_to_last_uses_no_free_var,
chunk_inputs_names) / (1024**2)

# log active node, only effective without chunk
self._add_active_node(node, active_node_list)
self._remove_deactive_node(node, user_to_last_uses, active_node_list)

# if node in chunk end nodes, restore chunk settings
if use_chunk and idx in chunk_ends:
act_memory -= (
self._get_output_node_size(node) * chunk_ratio / (1024**2)
)
act_memory -= (self._get_output_node_size(node) * chunk_ratio / (1024**2))
act_memory -= self._get_chunk_inputs_size(
chunk_inputs[chunk_region_idx],
chunk_inputs_non_chunk[chunk_region_idx],
Expand Down
83 changes: 22 additions & 61 deletions colossalai/autochunk/search_chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,7 @@
from .select_chunk import SelectChunk
from .trace_flow import TraceFlow
from .trace_indice import TraceIndice
from .utils import (
get_node_shape,
is_non_compute_node,
is_non_compute_node_except_placeholder,
)
from .utils import get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder


class SearchChunk(object):
Expand Down Expand Up @@ -73,13 +69,11 @@ def _get_free_var_idx(self) -> List:
"""
free_var_idx = []
for idx, n in enumerate(self.trace_indice.node_list):
if n.op == "placeholder":
if n.op == "placeholder" and get_node_shape(n) is not None:
free_var_idx.append(idx)
return free_var_idx

def _search_max_chunk_region(
self, active_node: List, peak_node: Node, chunk_regions: List
) -> Tuple:
def _search_max_chunk_region(self, active_node: List, peak_node: Node, chunk_regions: List) -> Tuple:
"""
Search max chunk region according to peak memory node

Expand Down Expand Up @@ -124,15 +118,9 @@ def _search_max_chunk_region(
region = i["region"]
if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
return None
elif (
region[0] <= chunk_region_start <= region[1]
and chunk_region_end > region[1]
):
elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
chunk_region_start = region[1] + 1
elif (
region[0] <= chunk_region_end <= region[1]
and chunk_region_start < region[0]
):
elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
chunk_region_end = region[0] - 1
return chunk_region_start, chunk_region_end

Expand Down Expand Up @@ -164,25 +152,16 @@ def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> Lis
for start_node, start_trace in start_traces.items():
for start_dim, _ in enumerate(start_trace["indice"]):
# dim size cannot be 1
if (
get_node_shape(end_node)[end_dim] == 1
or get_node_shape(start_node)[start_dim] == 1
):
if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
continue
# check index source align
if not self.trace_flow.check_index_source(
start_dim, start_node, start_idx, end_dim, end_node
):
if not self.trace_flow.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
continue
# check index copmute
if not self.trace_flow.check_index_compute(
start_idx, end_dim, end_node, end_idx
):
if not self.trace_flow.check_index_compute(start_idx, end_dim, end_node, end_idx):
continue
# flow search
chunk_info = self.trace_flow.flow_search(
start_idx, start_dim, end_idx, end_dim
)
chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
if chunk_info is None:
continue
# check index copmute
Expand All @@ -191,9 +170,7 @@ def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> Lis
chunk_infos.append(chunk_info)
return chunk_infos

def _search_possible_chunk_regions(
self, max_chunk_region: Tuple, peak_node: Node
) -> List:
def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List:
"""
Search every possible region within the max chunk region.

Expand All @@ -206,28 +183,23 @@ def _search_possible_chunk_regions(
"""
possible_chunk_region = []
output_trace = copy.deepcopy(self.trace_indice.indice_trace_list)
input_trace = [] # trace of a node's input nodes
input_trace = [] # trace of a node's input nodes
for _, n in enumerate(self.trace_indice.node_list):
cur_trace = {}
for arg in n.args:
if type(arg) == type(n) and not is_non_compute_node_except_placeholder(
arg
):
if type(arg) == type(n) and not is_non_compute_node_except_placeholder(arg):
cur_trace[arg] = self.trace_indice._find_trace_from_node(arg)
input_trace.append(cur_trace)

for start_idx in range(max_chunk_region[0], peak_node + 1):
for end_idx in range(peak_node, max_chunk_region[1] + 1):
# skip non compute nodes
if is_non_compute_node(
self.trace_indice.node_list[start_idx]
) or is_non_compute_node(self.trace_indice.node_list[end_idx]):
if is_non_compute_node(self.trace_indice.node_list[start_idx]) or is_non_compute_node(
self.trace_indice.node_list[end_idx]):
continue

# select free dim
chunk_info = self._find_chunk_info(
input_trace, output_trace, start_idx, end_idx
)
chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
if len(chunk_info) > 0:
possible_chunk_region.extend(chunk_info)
return possible_chunk_region
Expand Down Expand Up @@ -256,17 +228,12 @@ def _step_search(
best_chunk_region (Dict)
"""
peak_node = self._find_peak_node(mem_peak)
max_chunk_region = self._search_max_chunk_region(
active_node, peak_node, chunk_infos
)
max_chunk_region = self._search_max_chunk_region(active_node, peak_node, chunk_infos)
if max_chunk_region == None:
return None
possible_chunk_regions = self._search_possible_chunk_regions(
max_chunk_region, peak_node
)
best_chunk_region = self.select_chunk._select_best_chunk_region(
possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
)
possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node)
best_chunk_region = self.select_chunk._select_best_chunk_region(possible_chunk_regions, chunk_infos, peak_node,
max_chunk_region, mem_peak)
best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
return best_chunk_region

Expand All @@ -291,9 +258,7 @@ def search_region(self) -> Dict:
init_mem_peak,
_,
active_node,
) = self.estimate_memory.estimate_chunk_inference_mem(
self.trace_indice.node_list
)
) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list)
mem_peak = init_mem_peak

while True:
Expand All @@ -306,14 +271,10 @@ def search_region(self) -> Dict:
mem_peak,
_,
active_node,
) = self.estimate_memory.estimate_chunk_inference_mem(
self.trace_indice.node_list, chunk_infos
)
) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos)
if self._stop_search(init_mem_peak, mem_peak):
break
if self.print_mem:
self.print_mem = False
self.estimate_memory.estimate_chunk_inference_mem(
self.trace_indice.node_list, chunk_infos, print_mem=True
)
self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos, print_mem=True)
return chunk_infos
Loading