Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions colossalai/autochunk/autochunk_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg

from .search_chunk import SearchChunk
from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape
from .utils import delete_free_var_from_last_use, find_idx_by_name, get_logger, get_node_shape


def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) -> str:
Expand Down Expand Up @@ -276,11 +276,17 @@ def emit_code_with_chunk(

class AutoChunkCodeGen(CodeGen):

def __init__(self, meta_graph, max_memory=None, print_mem=False):
def __init__(self,
meta_graph,
max_memory: int = None,
print_mem: bool = False,
print_progress: bool = False) -> None:
super().__init__()
# find the chunk regions
self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem)
self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem, print_progress)
self.chunk_infos = self.search_chunk.search_region()
if print_progress:
get_logger().info("AutoChunk start codegen")

def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
free_vars: List[str] = []
Expand Down
27 changes: 27 additions & 0 deletions colossalai/autochunk/estimate_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def _get_delete_node(self, user, user_to_last_uses, to_keep=None):
delete_node = []
if user.op not in ("output",):
nodes_to_delete = user_to_last_uses.get(user, [])
if len(user.users) == 0:
nodes_to_delete.append(user)
if to_keep is not None:
keep_list = []
for n in nodes_to_delete:
Expand Down Expand Up @@ -135,6 +137,8 @@ def _get_chunk_delete_node_size(self, user, user_to_last_uses, chunk_ratio, chun
if user.op in ("placeholder", "output"):
return 0
nodes_to_delete = user_to_last_uses.get(user, [])
if len(user.users) == 0:
nodes_to_delete.append(user)
delete_size = 0
for n in nodes_to_delete:
if n.name in chunk_inputs_names:
Expand Down Expand Up @@ -294,3 +298,26 @@ def estimate_chunk_inference_mem(
# param_memory = parameter_size(gm)
# all_memory = act_memory + param_memory
return act_memory_peak_log, act_memory_after_node_log, active_node_list_log

def get_active_nodes(self, node_list: List) -> List:
    """
    Collect the active nodes after each node of ``node_list`` is visited.

    Args:
        node_list (List): nodes to walk through, in order

    Returns:
        active_node_list_log (List): one entry per node in ``node_list``;
            each entry is a deep copy of the active node list right after
            that node was processed. active nodes refer to nodes generated
            but not deleted.
    """
    last_uses = self._get_last_usr(node_list)
    # a second last-use map with free vars stripped; it is built here but not
    # consulted by the loop below
    last_uses_no_free_var = self._get_last_usr(node_list)
    delete_free_var_from_last_use(last_uses_no_free_var)

    current_active = []
    snapshots = []
    for node in node_list:
        # log active node, only effective without chunk
        self._add_active_node(node, current_active)
        self._remove_deactive_node(node, last_uses, current_active)
        # snapshot, because current_active keeps being mutated
        snapshots.append(copy.deepcopy(current_active))
    return snapshots
75 changes: 57 additions & 18 deletions colossalai/autochunk/search_chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .select_chunk import SelectChunk
from .trace_flow import TraceFlow
from .trace_indice import TraceIndice
from .utils import get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
from .utils import get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder


class SearchChunk(object):
Expand Down Expand Up @@ -40,22 +40,48 @@ class SearchChunk(object):
print_mem (bool): print estimated memory
"""

def __init__(self, gm, max_memory=None, print_mem=False) -> None:
self.gm = gm
def __init__(self, gm, max_memory=None, print_mem=False, print_progress=False) -> None:
self.print_mem = print_mem
self.print_progress = print_progress
self.trace_indice = TraceIndice(list(gm.graph.nodes))
self.trace_indice.trace_indice()
self.estimate_memory = EstimateMemory()
self._init_trace()
self.trace_flow = TraceFlow(self.trace_indice)
self.reorder_graph = ReorderGraph(self.trace_indice)
self.estimate_memory = EstimateMemory()
self.select_chunk = SelectChunk(
self.trace_indice,
self.estimate_memory,
self.reorder_graph,
max_memory=max_memory,
)

def _find_peak_node(self, mem_peak):
def _init_trace(self) -> None:
"""
find the max trace range for every node
reduce the computation complexity of trace_indice
"""
# find all max ranges
active_nodes = self.estimate_memory.get_active_nodes(self.trace_indice.node_list)
cur_node_idx = len(self._get_free_var_idx())
max_chunk_region_list = []
while True:
max_chunk_region = self._search_max_chunk_region(active_nodes, cur_node_idx)
cur_node_idx = max_chunk_region[1]
if cur_node_idx == len(active_nodes) - 1:
break
max_chunk_region_list.append(max_chunk_region)

# nothing to limit for the first range
max_chunk_region_list = max_chunk_region_list[1:]
max_chunk_region_list[0] = (0, max_chunk_region_list[0][1])

# set trace range and do the trace
if self.print_progress:
get_logger().info("AutoChunk start tracing indice")
self.trace_indice.set_trace_range(max_chunk_region_list, active_nodes)
self.trace_indice.trace_indice()

def _find_peak_node(self, mem_peak: List) -> int:
max_value = max(mem_peak)
max_idx = mem_peak.index(max_value)
return max_idx
Expand All @@ -73,15 +99,15 @@ def _get_free_var_idx(self) -> List:
free_var_idx.append(idx)
return free_var_idx

def _search_max_chunk_region(self, active_node: List, peak_node: Node, chunk_regions: List) -> Tuple:
def _search_max_chunk_region(self, active_node: List, peak_node_idx: int, chunk_regions: List = None) -> Tuple:
"""
Search max chunk region according to peak memory node

Chunk region starts extending from the peak node, stops where free var num is min

Args:
active_node (List): active node status for every node
peak_node (Node): peak memory node
peak_node_idx (int): peak memory node idx
chunk_regions (List): chunk region infos

Returns:
Expand All @@ -97,7 +123,7 @@ def _search_max_chunk_region(self, active_node: List, peak_node: Node, chunk_reg
# from peak_node to free_var
inside_flag = False
chunk_region_start = free_var_num
for i in range(peak_node, -1, -1):
for i in range(peak_node_idx, -1, -1):
if active_node_num[i] <= threshold:
inside_flag = True
if inside_flag and active_node_num[i] > threshold:
Expand All @@ -107,21 +133,23 @@ def _search_max_chunk_region(self, active_node: List, peak_node: Node, chunk_reg
# from peak_node to len-2
inside_flag = False
chunk_region_end = len(active_node) - 1
for i in range(peak_node, len(active_node)):
for i in range(peak_node_idx, len(active_node)):
if active_node_num[i] <= threshold:
inside_flag = True
if inside_flag and active_node_num[i] > threshold:
chunk_region_end = i
break

for i in chunk_regions:
region = i["region"]
if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
return None
elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
chunk_region_start = region[1] + 1
elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
chunk_region_end = region[0] - 1
# avoid chunk regions overlap
if chunk_regions is not None:
for i in chunk_regions:
region = i["region"]
if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
return None
elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
chunk_region_start = region[1] + 1
elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
chunk_region_end = region[0] - 1
return chunk_region_start, chunk_region_end

def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
Expand Down Expand Up @@ -154,6 +182,9 @@ def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> Lis
# dim size cannot be 1
if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
continue
# must have users
if len(end_node.users) == 0:
continue
# check index source align
if not self.trace_flow.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
continue
Expand Down Expand Up @@ -253,6 +284,9 @@ def search_region(self) -> Dict:
Returns:
chunk_infos (Dict)
"""
if self.print_progress:
get_logger().info("AutoChunk start searching chunk regions")

chunk_infos = []
(
init_mem_peak,
Expand All @@ -272,6 +306,11 @@ def search_region(self) -> Dict:
_,
active_node,
) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos)

if self.print_progress:
get_logger().info("AutoChunk find chunk region %d = (%d, %d)" %
(len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1]))

if self._stop_search(init_mem_peak, mem_peak):
break
if self.print_mem:
Expand Down
5 changes: 4 additions & 1 deletion colossalai/autochunk/trace_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,10 @@ def _get_input_nodes_dim(self, inputs: List[Node], start_idx: int, end_idx: int,
if chunk_dim is not None:
user_source = self.trace_indice._find_source_trace_from_node(user)[chunk_dim]
if input_node_idx in user_source:
input_dict[user_idx] = user_source[input_node_idx]
if get_node_shape(input_node)[user_source[input_node_idx][0]] == 1:
input_dict[user_idx] = [None]
else:
input_dict[user_idx] = user_source[input_node_idx]
else:
return None, None
if len(input_dict) == 0:
Expand Down
46 changes: 46 additions & 0 deletions colossalai/autochunk/trace_indice.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def __init__(self, node_list: List[Node]) -> None:
self.indice_trace_list = self._init_indice_trace_list()
self.indice_view_list = {}
self.indice_count = -1
self.trace_range = []
self.active_node_list = []

def _init_indice_trace_list(self):
indice_trace_list = []
Expand All @@ -48,6 +50,10 @@ def _init_indice_trace_list(self):
indice_trace_list.append(cur_trace)
return indice_trace_list

def set_trace_range(self, trace_range: List, active_node_list: List) -> None:
    """
    Store the trace ranges and the per-node active node lists on this object.

    Both values are kept by reference (no copy is made) and are read later
    by ``_clear_trace``.

    Args:
        trace_range (List): ranges to store in ``self.trace_range``
        active_node_list (List): value to store in ``self.active_node_list``
    """
    self.trace_range, self.active_node_list = trace_range, active_node_list

def _add_indice(self):
"""
Update the count and return it. To record the idx number.
Expand Down Expand Up @@ -493,6 +499,9 @@ def _assign_getitem_indice(self, node: Node, node_idx: int):
new_dim_num = sum([1 if str(i) == "None" else 0 for i in node_args])
for _ in range(new_dim_num):
self._del_dim(node_idx, 0)
delete_dim_num = sum([1 if str(i) == "0" else 0 for i in node_args])
for _ in range(delete_dim_num):
self._add_dim(node_idx, 0)
self._assign_indice_as_input(node, node_idx)

for _, node_arg in enumerate(node_args):
Expand All @@ -513,6 +522,9 @@ def _assign_getitem_indice(self, node: Node, node_idx: int):
elif "None" == node_arg_str:
self._add_dim(node_idx, new_idx_count)
new_idx_count += 1
elif "0" == node_arg_str:
self._del_dim(node_idx, new_idx_count)
origin_idx_count += 1
else:
raise NotImplementedError()

Expand Down Expand Up @@ -596,6 +608,37 @@ def _assign_view_reshape_indice(self, node: Node, node_idx: int):
}
self.indice_view_list[node] = view_dict

def _clear_trace(self, node_idx: int) -> None:
    """
    Clear too far trace to speed up computation.

    Nothing happens unless ``node_idx`` is the end index of one of the
    ranges in ``self.trace_range``. When it is, every trace in that range
    has its "compute" and "source" entries pruned in place: an entry is
    dropped when its index lies before the start of the range and is not
    the index of a node active somewhere inside the range.

    Args:
        node_idx (int): index of the node that has just been traced
    """
    # look for the range ending exactly at node_idx; scanning stops at the
    # first range ending beyond it (presumably ranges are ordered by end
    # index -- verify against the caller of set_trace_range)
    trace_range = None
    for region in self.trace_range:
        if region[1] == node_idx:
            trace_range = (region[0], region[1])
            break
        if region[1] > node_idx:
            break
    if trace_range is None:
        return
    range_start, range_end = trace_range

    # indices of all nodes active anywhere inside the range; kept as a set so
    # the membership tests in the loops below are O(1)
    active_in_range = set(flat_list(self.active_node_list[range_start:range_end + 1]))
    active_idx = {find_idx_by_name(n, self.node_list) for n in active_in_range}

    for trace_idx in range(range_start, range_end + 1):
        trace = self.indice_trace_list[trace_idx]
        # clear compute (slice assignment keeps the same list object)
        for dim_compute in trace["compute"]:
            dim_compute[:] = [c for c in dim_compute if c >= range_start or c in active_idx]
        # clear source (iterate over a copy of the keys since entries are removed)
        for dim_source in trace["source"]:
            for k in list(dim_source.keys()):
                if k < range_start and k not in active_idx:
                    dim_source.pop(k)

def trace_indice(self):
for idx, node in enumerate(self.node_list):
if node.op == "placeholder":
Expand Down Expand Up @@ -655,3 +698,6 @@ def trace_indice(self):
continue
else:
raise NotImplementedError(node.op, "op not implemented yet!")

# limit trace range
self._clear_trace(idx)
8 changes: 8 additions & 0 deletions colossalai/autochunk/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@

from torch.fx.node import Node

from colossalai.logging import get_dist_logger

logger = get_dist_logger()


def get_logger():
    """
    Return the module-level ``logger`` object, which is created once at
    import time by ``get_dist_logger()``.
    """
    return logger


def flat_list(inputs: Any) -> List:
"""
Expand Down
Loading