From ee2f4496ebe7d28996d5b8acc5cb89302325dcd0 Mon Sep 17 00:00:00 2001 From: oahzxl <43881818+oahzxl@users.noreply.github.com> Date: Thu, 19 Jan 2023 12:42:45 +0800 Subject: [PATCH 1/9] support extramsa (#3) * init * init origin openfold * openfold can run now * remove useless attr * reformat * code format * use repo for simple evoformer * rename * rename * detect repo in test * delete dirs * rename tests * add test for evoformer * optimize import * support reshape for evoformer * fix some bugs * add getitem condition * add doc * partly support evoformer * update test * complete evoformer now * init extramsa * disable ckpt codegen * support some ops and fix some bugs for extramsa * rename * update support for better memory * finish extramsa --- colossalai/autochunk/estimate_memory.py | 9 +- colossalai/autochunk/trace_flow.py | 43 +++-- colossalai/autochunk/trace_indice.py | 56 +++++- colossalai/autochunk/utils.py | 20 ++- .../test_autochunk/test_evoformer_codegen.py | 2 +- tests/test_autochunk/test_extramsa_codegen.py | 164 ++++++++++++++++++ .../test_simple_evoformer_codegen.py | 2 +- .../test_simple_evoformer_search.py | 41 ++--- 8 files changed, 283 insertions(+), 54 deletions(-) create mode 100644 tests/test_autochunk/test_extramsa_codegen.py diff --git a/colossalai/autochunk/estimate_memory.py b/colossalai/autochunk/estimate_memory.py index d386253850a7..21f34481ba70 100644 --- a/colossalai/autochunk/estimate_memory.py +++ b/colossalai/autochunk/estimate_memory.py @@ -6,12 +6,7 @@ from colossalai.fx.profiler import activation_size, parameter_size -from .utils import ( - delete_free_var_from_last_use, - find_idx_by_name, - get_node_shape, - is_non_compute_node_except_placeholder, -) +from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape, is_non_memory_node class EstimateMemory(object): @@ -240,7 +235,7 @@ def estimate_chunk_inference_mem( elif node.op == "output": continue # no change for non compute node - elif is_non_compute_node_except_placeholder(node): + elif is_non_memory_node(node): act_memory_peak_log.append(act_memory) # node is a compute op # calculate tmp, output node and delete node memory diff --git a/colossalai/autochunk/trace_flow.py b/colossalai/autochunk/trace_flow.py index 04fa2b3bb480..e657c188ead2 100644 --- a/colossalai/autochunk/trace_flow.py +++ b/colossalai/autochunk/trace_flow.py @@ -118,16 +118,34 @@ def check_index_duplicate(self, chunk_infos, return_dim=False): def _assgin_single_node_flow( self, - arg_node, - start_idx, - end_idx, - cur_node_dim, - cur_node_compute, - cur_node_source, - cur_node_fix_dim, - all_node_info, - next_node_list, - ): + arg_node: Node, + start_idx: int, + end_idx: int, + cur_node_dim: int, + cur_node_compute: Dict, + cur_node_source: Dict, + cur_node_fix_dim: List, + all_node_info: Dict, + next_node_list: List, + ) -> bool: + """ + Given the current node and one of its arg node, + this function finds out arg node's chunk dim and fix dim + + Args: + arg_node (Node): input node + start_idx (int): chunk region start + end_idx (int): chunk region end + cur_node_dim (int): current node chunk dim + cur_node_compute (Dict): current node compute dict + cur_node_source (Dict): current node source dict + cur_node_fix_dim (List): current node fix dim + all_node_info (Dict): all node chunk info in the chunk region + next_node_list (List) + + Returns: + bool: True if this node can be added to the flow, vice versa. 
+ """ arg_idx = find_idx_by_name(arg_node.name, self.trace_indice.node_list) # arg in chunk range or be inputs if not (start_idx <= arg_idx < end_idx): @@ -142,6 +160,9 @@ def _assgin_single_node_flow( arg_dim = None else: arg_dim = cur_node_source[cur_node_dim][arg_idx][0] + # chunk dim should be None if shape size is 1 + if get_node_shape(arg_node)[arg_dim] == 1: + arg_dim = None else: arg_dim = None @@ -184,7 +205,7 @@ def _get_all_node_info(self, end_dim, start_idx, end_idx): # get all valid args arg_list = [] - for arg in cur_node.args: + for arg in cur_node.all_input_nodes: if type(arg) != type(cur_node): continue if is_non_compute_node(arg): diff --git a/colossalai/autochunk/trace_indice.py b/colossalai/autochunk/trace_indice.py index 862cd6b99ccc..5c2e9b5203b5 100644 --- a/colossalai/autochunk/trace_indice.py +++ b/colossalai/autochunk/trace_indice.py @@ -432,6 +432,38 @@ def _assign_ones_like_indice(self, node: Node, node_idx: int): """ self._assign_all_indice(node, node_idx) + def _assign_cat_indice(self, node: Node, node_idx: int): + """ + Assign indice for cat op. + + Args: + node (node) + node_idx (int) + """ + nodes_in = flat_list(node.args[0]) + self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0]) + for n in nodes_in[1:]: + self._mark_computation_from_node(n, node) + cat_dim = node.kwargs["dim"] + self._del_dim(node_idx, cat_dim) + self._add_dim(node_idx, cat_dim) + + def _assign_sum_indice(self, node: Node, node_idx: int): + """ + Assign indice for sum op. + + Args: + node (node) + node_idx (int) + """ + nodes_in = flat_list(node.args[0]) + self._add_dim(node_idx, 0) + self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0]) + for n in nodes_in[1:]: + self._mark_computation_from_node(n, node) + cat_dim = node.kwargs["dim"] + self._del_dim(node_idx, cat_dim) + def _assign_getitem_indice(self, node: Node, node_idx: int): """ Assign indice for getitem. 
@@ -442,7 +474,16 @@ def _assign_getitem_indice(self, node: Node, node_idx: int): node_idx (int) """ node_args = flat_list(node.args[1:]) - if not any(i == str(node_arg) for i in ["None", "Ellipsis"] for node_arg in node_args): + flag = False + for node_arg in node_args: + node_arg_str = str(node_arg) + if any(i == node_arg_str for i in ["None", "Ellipsis"]): + flag = True + break + if "slice" in node_arg_str: + flag = True + break + if flag == False: return # node args should be like [Ellipsis, slice(start, step, end), None] @@ -461,8 +502,11 @@ def _assign_getitem_indice(self, node: Node, node_idx: int): shape_gap = len(node_shape) - len(node_args) + 1 origin_idx_count += shape_gap new_idx_count += shape_gap - # slice(None, None, None) means all indexes, doesn't support other slice - elif "slice(None, None, None)" == node_arg_str: + # slice(None, None, None) means all indexes + elif "slice" in node_arg_str: + if "slice(None, None, None)" != node_arg_str: + self._del_dim(node_idx, new_idx_count) + self._add_dim(node_idx, new_idx_count) origin_idx_count += 1 new_idx_count += 1 # None means a new dim @@ -565,7 +609,7 @@ def trace_indice(self): self._assign_view_reshape_indice(node, idx) elif "unsqueeze" in node.name: self._assign_unsqueeze_indice(node, idx) - elif any(i in node.name for i in ["to", "contiguous"]): + elif any(i in node.name for i in ["to", "contiguous", "clone"]): self._assgin_no_change_indice(node, idx) elif "new_ones" in node.name: self._assign_ones_like_indice(node, idx) @@ -574,6 +618,8 @@ def trace_indice(self): elif node.op == "call_function": if "linear" in node.name: self._assign_linear_indice(node, idx) + elif "cat" in node.name: + self._assign_cat_indice(node, idx) elif "matmul" in node.name: self._assign_matmul_indice(node, idx) elif "softmax" in node.name: @@ -586,6 +632,8 @@ def trace_indice(self): self._assign_dropout_indice(node, idx) elif "einsum" in node.name: self._assign_einsum_indice(node, idx) + elif "sum" in node.name: + self._assign_sum_indice(node, idx) elif "layer_norm" in node.name: self._assign_layernorm_indice(node, idx) elif "getitem" in node.name: diff --git a/colossalai/autochunk/utils.py b/colossalai/autochunk/utils.py index 9c2363b544e2..ff1a64bc359d 100644 --- a/colossalai/autochunk/utils.py +++ b/colossalai/autochunk/utils.py @@ -3,10 +3,12 @@ from torch.fx.node import Node -def flat_list(inputs): +def flat_list(inputs: Any) -> List: """ flat a list by recursion """ + if not (isinstance(inputs, list) or isinstance(inputs, set) or isinstance(inputs, tuple)): + return [inputs] res = [] for i in inputs: if isinstance(i, list) or isinstance(i, set) or isinstance(i, tuple): @@ -16,7 +18,7 @@ def flat_list(inputs): return res -def find_first_tensor_arg(node): +def find_first_tensor_arg(node: Node) -> Node: """ Find the first input tensor arg for a node """ @@ -26,7 +28,7 @@ def find_first_tensor_arg(node): raise RuntimeError() -def is_non_compute_node(node): +def is_non_compute_node(node: Node) -> bool: if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(i in node.name for i in ["getattr"]): return True if "getitem" in node.name: @@ -34,16 +36,26 @@ def is_non_compute_node(node): for node_arg in node_args: if any(i == str(node_arg) for i in ["None", "Ellipsis"]): return False + if "slice" in str(node_arg): + return False return True return False -def get_node_shape(node): +def get_node_shape(node: Node) -> List: if hasattr(node.meta["tensor_meta"], "shape"): return node.meta["tensor_meta"].shape return None +def 
is_non_memory_node(node: Node) -> bool: + if "getitem" in node.name: + return True + if "output" in node.op: + return True + return is_non_compute_node(node) + + def is_non_compute_node_except_placeholder(node): if "placeholder" in node.op: return False diff --git a/tests/test_autochunk/test_evoformer_codegen.py b/tests/test_autochunk/test_evoformer_codegen.py index c5a893eda7cc..ba6a57a51ce3 100644 --- a/tests/test_autochunk/test_evoformer_codegen.py +++ b/tests/test_autochunk/test_evoformer_codegen.py @@ -130,7 +130,7 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory): }, ) graph.set_codegen(codegen) - gm = ColoGraphModule(model, graph) + gm = ColoGraphModule(model, graph, ckpt_codegen=False) gm.recompile() # assert we have inserted chunk diff --git a/tests/test_autochunk/test_extramsa_codegen.py b/tests/test_autochunk/test_extramsa_codegen.py new file mode 100644 index 000000000000..2a41452a2ad7 --- /dev/null +++ b/tests/test_autochunk/test_extramsa_codegen.py @@ -0,0 +1,164 @@ +from functools import partial + +import pytest +import torch +import torch.fx +import torch.multiprocessing as mp + +try: + from fastfold.model.nn.evoformer import ExtraMSABlock + HAS_REPO = True +except: + HAS_REPO = False + +import colossalai +from colossalai.core import global_context as gpc +from colossalai.fx._compatibility import is_compatible_with_meta +from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.utils import free_port + +if CODEGEN_AVAILABLE and is_compatible_with_meta(): + from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen + from colossalai.fx.profiler import MetaTensor + from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace + + +def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask): + # for memory test + # model = model.cuda() + # torch.cuda.reset_peak_memory_stats() + # now_mem = torch.cuda.memory_allocated() / 1024**2 + # with torch.no_grad(): + # node1 = node.clone() + # pair1 = pair.clone() + # node_mask1 = node_mask.clone() + # pair_mask1 = pair_mask.clone() + # gm(node1, pair1, node_mask1, pair_mask1) + # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2 + # print("autochunk max mem:%.2f"% (new_max_mem - now_mem)) + + # test forward + model = model.cuda() + with torch.no_grad(): + non_fx_out = model(node, pair, node_mask, pair_mask) + fx_out = gm(node, pair, node_mask, pair_mask) + + assert torch.allclose(non_fx_out[0], fx_out[0], + atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( + torch.abs(non_fx_out[0] - fx_out[0])) + assert torch.allclose(non_fx_out[1], fx_out[1], + atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( + torch.abs(non_fx_out[1] - fx_out[1])) + + +def _build_openfold(): + model = ExtraMSABlock( + c_m=256, + c_z=128, + c_hidden_msa_att=32, + c_hidden_opm=32, + c_hidden_mul=128, + c_hidden_pair_att=32, + no_heads_msa=8, + no_heads_pair=4, + transition_n=4, + msa_dropout=0.15, + pair_dropout=0.15, + inf=1e4, + eps=1e-4, + ckpt=False, + is_multimer=False, + ).eval().cuda() + return model + + +def _test_extramsa_codegen(rank, msa_len, pair_len, max_memory): + # launch colossalai + colossalai.launch( + config={}, + rank=rank, + world_size=1, + host="localhost", + port=free_port(), + backend="nccl", + ) + + # build model and input + model = 
_build_openfold() + node = torch.randn(1, msa_len, pair_len, 256).cuda() + node_mask = torch.randn(1, msa_len, pair_len).cuda() + pair = torch.randn(1, pair_len, pair_len, 128).cuda() + pair_mask = torch.randn(1, pair_len, pair_len).cuda() + + # trace the meta graph and setup codegen + meta_graph = symbolic_trace( + model, + meta_args={ + "m": node.to(torch.device("meta")), + "z": pair.to(torch.device("meta")), + "msa_mask": node_mask.to(torch.device("meta")), + "pair_mask": pair_mask.to(torch.device("meta")), + }, + concrete_args={ + "chunk_size": None, + "_chunk_logits": 1024, + }, + ) + interp = MetaInfoProp(meta_graph) + interp.propagate( + MetaTensor(node, fake_device="cuda:0"), + MetaTensor(pair, fake_device="cuda:0"), + MetaTensor(node_mask, fake_device="cuda:0"), + MetaTensor(pair_mask, fake_device="cuda:0"), + ) + codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False) + + # trace and recompile + # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer + graph = ColoTracer().trace( + model, + meta_args={ + "m": node.to(torch.device("meta")), + "z": pair.to(torch.device("meta")), + "msa_mask": node_mask.to(torch.device("meta")), + "pair_mask": pair_mask.to(torch.device("meta")), + }, + concrete_args={ + "chunk_size": None, + "_chunk_logits": 1024, + }, + ) + graph.set_codegen(codegen) + gm = ColoGraphModule(model, graph, ckpt_codegen=False) + gm.recompile() + + # assert we have inserted chunk + code = graph.python_code("self").src + # print(code) + assert "chunk_result = None; chunk_size = None;" in code + + _test_fwd(model, gm, node, pair, node_mask, pair_mask) + gpc.destroy() + + +@pytest.mark.skipif( + not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO), + reason="torch version is lower than 1.12.0", +) +@pytest.mark.parametrize("max_memory", [None, 24, 28, 32]) +@pytest.mark.parametrize("msa_len", [32]) +@pytest.mark.parametrize("pair_len", [64]) +def test_extramsa_codegen(msa_len, pair_len, max_memory): + run_func = partial( + _test_extramsa_codegen, + msa_len=msa_len, + pair_len=pair_len, + max_memory=max_memory, + ) + mp.spawn(run_func, nprocs=1) + + +if __name__ == "__main__": + _test_extramsa_codegen(0, 32, 64, None) diff --git a/tests/test_autochunk/test_simple_evoformer_codegen.py b/tests/test_autochunk/test_simple_evoformer_codegen.py index 8ab77024c1b9..7fe149c5784d 100644 --- a/tests/test_autochunk/test_simple_evoformer_codegen.py +++ b/tests/test_autochunk/test_simple_evoformer_codegen.py @@ -73,7 +73,7 @@ def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory): }, ) graph.set_codegen(codegen) - gm = ColoGraphModule(model, graph) + gm = ColoGraphModule(model, graph, ckpt_codegen=False) gm.recompile() # assert we have inserted chunk diff --git a/tests/test_autochunk/test_simple_evoformer_search.py b/tests/test_autochunk/test_simple_evoformer_search.py index 4c591c48319e..89f28d625cbe 100644 --- a/tests/test_autochunk/test_simple_evoformer_search.py +++ b/tests/test_autochunk/test_simple_evoformer_search.py @@ -13,6 +13,7 @@ import colossalai from colossalai.core import global_context as gpc +from colossalai.fx import symbolic_trace from colossalai.fx._compatibility import is_compatible_with_meta from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE from colossalai.fx.passes.meta_info_prop import MetaInfoProp @@ -28,10 +29,10 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len): if msa_len == 32 and pair_len == 64: if max_memory is None: - target_regions = 
[(142, 154), (366, 373), (233, 283), (301, 351), (127, 134), (204, 228), (167, 191), - (161, 166), (198, 203), (6, 69)] + target_regions = [(142, 154), (366, 373), (234, 283), (302, 351), (127, 134), (211, 228), (174, 191), + (161, 166), (198, 203), (7, 57)] elif max_memory == 20: - target_regions = [(142, 154), (369, 373), (233, 269), (301, 351)] + target_regions = [(142, 154), (369, 373), (235, 269), (303, 351), (130, 131)] elif max_memory == 25: target_regions = [(144, 154), (369, 370)] elif max_memory == 30: @@ -41,25 +42,10 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len): else: raise NotImplementedError() - assert len(found_regions) == len( - target_regions), "len of found regions %s doesn't equal len of target regions %s" % ( - str(found_regions), - str(target_regions), - ) - for region in target_regions: - assert (region in found_regions), "region:%s not in found regions for msa:%d, pair:%d, maxmem:%s" % ( - str(region), - msa_len, - pair_len, - str(max_memory), - ) - for region in found_regions: - assert (region in target_regions), "region:%s should not be found for msa:%d, pair:%d, maxmem:%d" % ( - str(region), - msa_len, - pair_len, - str(max_memory), - ) + assert found_regions == target_regions, "found regions %s doesn't equal target regions %s" % ( + str(found_regions), + str(target_regions), + ) def _test_simple_evoformer_search(rank, msa_len, pair_len, max_memory): @@ -78,11 +64,14 @@ def _test_simple_evoformer_search(rank, msa_len, pair_len, max_memory): node = torch.randn(1, msa_len, pair_len, 256).cuda() pair = torch.randn(1, pair_len, pair_len, 128).cuda() - gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace - interp = MetaInfoProp(gm_prop) + meta_graph = symbolic_trace(model, + meta_args={ + "node": node.to(torch.device("meta")), + "pair": pair.to(torch.device("meta")), + }) # must use symbolic_trace + interp = MetaInfoProp(meta_graph) interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")) - - codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory) + codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory) chunk_infos = codegen.chunk_infos assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len) From 69595cee0e5f09b16c8b8bc431bf3ea7ac8e642a Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 19 Jan 2023 15:10:29 +0800 Subject: [PATCH 2/9] set trace range before trace --- colossalai/autochunk/estimate_memory.py | 23 ++++++++++ colossalai/autochunk/search_chunk.py | 57 ++++++++++++++++++------- colossalai/autochunk/trace_indice.py | 4 ++ 3 files changed, 68 insertions(+), 16 deletions(-) diff --git a/colossalai/autochunk/estimate_memory.py b/colossalai/autochunk/estimate_memory.py index 21f34481ba70..bf0e43637e8e 100644 --- a/colossalai/autochunk/estimate_memory.py +++ b/colossalai/autochunk/estimate_memory.py @@ -294,3 +294,26 @@ def estimate_chunk_inference_mem( # param_memory = parameter_size(gm) # all_memory = act_memory + param_memory return act_memory_peak_log, act_memory_after_node_log, active_node_list_log + + def get_active_nodes(self, node_list: List) -> List: + """ + Get active nodes for every node + + Args: + node_list (List): _description_ + + Returns: + active_node_list_log (List): active nodes of every node. active nodes refer to + nodes generated but not deleted. 
+ """ + active_node_list = [] + active_node_list_log = [] + user_to_last_uses = self._get_last_usr(node_list) + user_to_last_uses_no_free_var = self._get_last_usr(node_list) + delete_free_var_from_last_use(user_to_last_uses_no_free_var) + for _, node in enumerate(node_list): + # log active node, only effective without chunk + self._add_active_node(node, active_node_list) + self._remove_deactive_node(node, user_to_last_uses, active_node_list) + active_node_list_log.append(copy.deepcopy(active_node_list)) + return active_node_list_log diff --git a/colossalai/autochunk/search_chunk.py b/colossalai/autochunk/search_chunk.py index 236f9697df5d..2c7196c19f59 100644 --- a/colossalai/autochunk/search_chunk.py +++ b/colossalai/autochunk/search_chunk.py @@ -41,13 +41,12 @@ class SearchChunk(object): """ def __init__(self, gm, max_memory=None, print_mem=False) -> None: - self.gm = gm self.print_mem = print_mem self.trace_indice = TraceIndice(list(gm.graph.nodes)) - self.trace_indice.trace_indice() + self.estimate_memory = EstimateMemory() + self._init_trace() self.trace_flow = TraceFlow(self.trace_indice) self.reorder_graph = ReorderGraph(self.trace_indice) - self.estimate_memory = EstimateMemory() self.select_chunk = SelectChunk( self.trace_indice, self.estimate_memory, @@ -55,7 +54,31 @@ def __init__(self, gm, max_memory=None, print_mem=False) -> None: max_memory=max_memory, ) - def _find_peak_node(self, mem_peak): + def _init_trace(self) -> None: + """ + find the max trace range for every node + reduce the computation complexity of trace_indice + """ + # find all max ranges + active_nodes = self.estimate_memory.get_active_nodes(self.trace_indice.node_list) + cur_node_idx = len(self._get_free_var_idx()) + max_chunk_region_list = [] + while True: + max_chunk_region = self._search_max_chunk_region(active_nodes, cur_node_idx) + cur_node_idx = max_chunk_region[1] + if cur_node_idx == len(active_nodes) - 1: + break + max_chunk_region_list.append(max_chunk_region) + + # the first and second are always overlaped, merge them + max_chunk_region_list[0] = (0, max_chunk_region_list[1][1]) + max_chunk_region_list.pop(1) + + # set trace range and do the trace + self.trace_indice.set_trace_range(max_chunk_region_list) + self.trace_indice.trace_indice() + + def _find_peak_node(self, mem_peak: List) -> int: max_value = max(mem_peak) max_idx = mem_peak.index(max_value) return max_idx @@ -73,7 +96,7 @@ def _get_free_var_idx(self) -> List: free_var_idx.append(idx) return free_var_idx - def _search_max_chunk_region(self, active_node: List, peak_node: Node, chunk_regions: List) -> Tuple: + def _search_max_chunk_region(self, active_node: List, peak_node_idx: int, chunk_regions: List = None) -> Tuple: """ Search max chunk region according to peak memory node @@ -81,7 +104,7 @@ def _search_max_chunk_region(self, active_node: List, peak_node: Node, chunk_reg Args: active_node (List): active node status for every node - peak_node (Node): peak memory node + peak_node_idx (int): peak memory node idx chunk_regions (List): chunk region infos Returns: @@ -97,7 +120,7 @@ def _search_max_chunk_region(self, active_node: List, peak_node: Node, chunk_reg # from peak_node to free_var inside_flag = False chunk_region_start = free_var_num - for i in range(peak_node, -1, -1): + for i in range(peak_node_idx, -1, -1): if active_node_num[i] <= threshold: inside_flag = True if inside_flag and active_node_num[i] > threshold: @@ -107,21 +130,23 @@ def _search_max_chunk_region(self, active_node: List, peak_node: Node, chunk_reg # from peak_node 
to len-2 inside_flag = False chunk_region_end = len(active_node) - 1 - for i in range(peak_node, len(active_node)): + for i in range(peak_node_idx, len(active_node)): if active_node_num[i] <= threshold: inside_flag = True if inside_flag and active_node_num[i] > threshold: chunk_region_end = i break - for i in chunk_regions: - region = i["region"] - if chunk_region_start >= region[0] and chunk_region_end <= region[1]: - return None - elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]): - chunk_region_start = region[1] + 1 - elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]): - chunk_region_end = region[0] - 1 + # avoid chunk regions overlap + if chunk_regions is not None: + for i in chunk_regions: + region = i["region"] + if chunk_region_start >= region[0] and chunk_region_end <= region[1]: + return None + elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]): + chunk_region_start = region[1] + 1 + elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]): + chunk_region_end = region[0] - 1 return chunk_region_start, chunk_region_end def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List: diff --git a/colossalai/autochunk/trace_indice.py b/colossalai/autochunk/trace_indice.py index 5c2e9b5203b5..28bf4d0d2107 100644 --- a/colossalai/autochunk/trace_indice.py +++ b/colossalai/autochunk/trace_indice.py @@ -33,6 +33,7 @@ def __init__(self, node_list: List[Node]) -> None: self.indice_trace_list = self._init_indice_trace_list() self.indice_view_list = {} self.indice_count = -1 + self.trace_range = [] def _init_indice_trace_list(self): indice_trace_list = [] @@ -48,6 +49,9 @@ def _init_indice_trace_list(self): indice_trace_list.append(cur_trace) return indice_trace_list + def set_trace_range(self, trace_range: List) -> None: + self.trace_range = trace_range + def _add_indice(self): """ Update the count and return it. To record the idx number. 
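Patch 2's `_init_trace` repeatedly calls `_search_max_chunk_region` to pre-compute, for every node, the widest window the indice trace has to cover. Stripped of class state, the bounding loop in that method reduces to the standalone sketch below; this is an illustrative rewrite rather than code from the patch, where `active_node_num[i]` stands for the number of tensors still alive after node `i` executes and `threshold` is taken as a parameter because its computation lies outside the hunks shown here.

from typing import List, Tuple


def search_max_chunk_region(active_node_num: List[int], peak_node_idx: int,
                            free_var_num: int, threshold: int) -> Tuple[int, int]:
    # walk from the peak-memory node toward the graph inputs: the region
    # starts just after the count of live tensors last rose above the threshold
    chunk_region_start = free_var_num
    inside_flag = False
    for i in range(peak_node_idx, -1, -1):
        if active_node_num[i] <= threshold:
            inside_flag = True
        if inside_flag and active_node_num[i] > threshold:
            chunk_region_start = i + 1
            break
    # walk from the peak-memory node toward the outputs: the region ends
    # where the count of live tensors next rises above the threshold
    chunk_region_end = len(active_node_num) - 1
    inside_flag = False
    for i in range(peak_node_idx, len(active_node_num)):
        if active_node_num[i] <= threshold:
            inside_flag = True
        if inside_flag and active_node_num[i] > threshold:
            chunk_region_end = i
            break
    return chunk_region_start, chunk_region_end
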
From 3c1bfb169a79375ef94958ecbe96f0d6e2fe436f Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 19 Jan 2023 16:09:24 +0800 Subject: [PATCH 3/9] support limit trace range --- colossalai/autochunk/search_chunk.py | 6 ++--- colossalai/autochunk/trace_indice.py | 35 ++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/colossalai/autochunk/search_chunk.py b/colossalai/autochunk/search_chunk.py index 2c7196c19f59..4a10ce908337 100644 --- a/colossalai/autochunk/search_chunk.py +++ b/colossalai/autochunk/search_chunk.py @@ -70,9 +70,9 @@ def _init_trace(self) -> None: break max_chunk_region_list.append(max_chunk_region) - # the first and second are always overlaped, merge them - max_chunk_region_list[0] = (0, max_chunk_region_list[1][1]) - max_chunk_region_list.pop(1) + # nothing to limit for the first range + max_chunk_region_list = max_chunk_region_list[1:] + max_chunk_region_list[0] = (0, max_chunk_region_list[0][1]) # set trace range and do the trace self.trace_indice.set_trace_range(max_chunk_region_list) diff --git a/colossalai/autochunk/trace_indice.py b/colossalai/autochunk/trace_indice.py index 28bf4d0d2107..c4d7a1b006b7 100644 --- a/colossalai/autochunk/trace_indice.py +++ b/colossalai/autochunk/trace_indice.py @@ -600,6 +600,38 @@ def _assign_view_reshape_indice(self, node: Node, node_idx: int): } self.indice_view_list[node] = view_dict + def _clear_trace(self, node_idx: int) -> None: + """ + clear too far trace to speed up computation + """ + trace_range = None + for i in range(len(self.trace_range)): + if self.trace_range[i][1] == node_idx: + if i <= 4: + break + # use previous range's start instead of this one's start + # 5 is the min safe range + trace_range = (self.trace_range[i - 5][0], self.trace_range[i][1]) + break + if self.trace_range[i][1] > node_idx: + break + if trace_range is None: + return + + for i in range(trace_range[0], trace_range[1] + 1): + trace = self.indice_trace_list[i] + # clear compute + for dim_compute in trace["compute"]: + for i in range(len(dim_compute) - 1, -1, -1): + if dim_compute[i] < trace_range[0]: + dim_compute.pop(i) + continue + # clear source + for dim_source in trace["source"]: + for k in list(dim_source.keys()): + if k < trace_range[0]: + dim_source.pop(k) + def trace_indice(self): for idx, node in enumerate(self.node_list): if node.op == "placeholder": @@ -659,3 +691,6 @@ def trace_indice(self): continue else: raise NotImplementedError(node.op, "op not implemented yet!") + + # limit trace range + self._clear_trace(idx) From 03606206fab6159bc1872e094d215380eb4da361 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 19 Jan 2023 16:23:09 +0800 Subject: [PATCH 4/9] init evo stack --- .../test_evoformer_stack_codegen.py | 167 ++++++++++++++++++ 1 file changed, 167 insertions(+) create mode 100644 tests/test_autochunk/test_evoformer_stack_codegen.py diff --git a/tests/test_autochunk/test_evoformer_stack_codegen.py b/tests/test_autochunk/test_evoformer_stack_codegen.py new file mode 100644 index 000000000000..0d12190d743b --- /dev/null +++ b/tests/test_autochunk/test_evoformer_stack_codegen.py @@ -0,0 +1,167 @@ +from functools import partial + +import pytest +import torch +import torch.fx +import torch.multiprocessing as mp + +try: + from fastfold.model.nn.evoformer import EvoformerStack + HAS_REPO = True +except: + HAS_REPO = False + +import colossalai +from colossalai.core import global_context as gpc +from colossalai.fx._compatibility import is_compatible_with_meta +from 
colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.utils import free_port + +if CODEGEN_AVAILABLE and is_compatible_with_meta(): + from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen + from colossalai.fx.profiler import MetaTensor + from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace + + +def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask): + # for memory test + # model = model.cuda() + # torch.cuda.reset_peak_memory_stats() + # now_mem = torch.cuda.memory_allocated() / 1024**2 + # with torch.no_grad(): + # node1 = node.clone() + # pair1 = pair.clone() + # node_mask1 = node_mask.clone() + # pair_mask1 = pair_mask.clone() + # gm(node1, pair1, node_mask1, pair_mask1) + # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2 + # print("autochunk max mem:%.2f"% (new_max_mem - now_mem)) + + # test forward + model = model.cuda() + with torch.no_grad(): + non_fx_out = model(node, pair, node_mask, pair_mask) + fx_out = gm(node, pair, node_mask, pair_mask) + + assert torch.allclose(non_fx_out[0], fx_out[0], + atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( + torch.abs(non_fx_out[0] - fx_out[0])) + assert torch.allclose(non_fx_out[1], fx_out[1], + atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( + torch.abs(non_fx_out[1] - fx_out[1])) + + +def _build_openfold(): + model = EvoformerStack( + c_m=256, + c_z=128, + c_hidden_msa_att=32, + c_hidden_opm=32, + c_hidden_mul=128, + c_hidden_pair_att=32, + c_s=384, + no_heads_msa=8, + no_heads_pair=4, + no_blocks=4, # 48 + transition_n=4, + msa_dropout=0.15, + pair_dropout=0.25, + blocks_per_ckpt=None, + inf=1000000000.0, + eps=1e-08, + clear_cache_between_blocks=False, + is_multimer=False, + ).eval().cuda() + return model + + +def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory): + # launch colossalai + colossalai.launch( + config={}, + rank=rank, + world_size=1, + host="localhost", + port=free_port(), + backend="nccl", + ) + + # build model and input + model = _build_openfold() + node = torch.randn(1, msa_len, pair_len, 256).cuda() + node_mask = torch.randn(1, msa_len, pair_len).cuda() + pair = torch.randn(1, pair_len, pair_len, 128).cuda() + pair_mask = torch.randn(1, pair_len, pair_len).cuda() + + # trace the meta graph and setup codegen + meta_graph = symbolic_trace( + model, + meta_args={ + "m": node.to(torch.device("meta")), + "z": pair.to(torch.device("meta")), + "msa_mask": node_mask.to(torch.device("meta")), + "pair_mask": pair_mask.to(torch.device("meta")), + }, + concrete_args={ + "chunk_size": None, + "_mask_trans": True, + }, + ) + interp = MetaInfoProp(meta_graph) + interp.propagate( + MetaTensor(node, fake_device="cuda:0"), + MetaTensor(pair, fake_device="cuda:0"), + MetaTensor(node_mask, fake_device="cuda:0"), + MetaTensor(pair_mask, fake_device="cuda:0"), + ) + codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False) + + # trace and recompile + # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer + graph = ColoTracer().trace( + model, + meta_args={ + "m": node.to(torch.device("meta")), + "z": pair.to(torch.device("meta")), + "msa_mask": node_mask.to(torch.device("meta")), + "pair_mask": pair_mask.to(torch.device("meta")), + }, + concrete_args={ + "chunk_size": None, + 
"_mask_trans": True, + }, + ) + graph.set_codegen(codegen) + gm = ColoGraphModule(model, graph, ckpt_codegen=False) + gm.recompile() + + # assert we have inserted chunk + code = graph.python_code("self").src + # print(code) + assert "chunk_result = None; chunk_size = None;" in code + + _test_fwd(model, gm, node, pair, node_mask, pair_mask) + gpc.destroy() + + +@pytest.mark.skipif( + not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO), + reason="torch version is lower than 1.12.0", +) +@pytest.mark.parametrize("max_memory", [None, 24, 28, 32]) +@pytest.mark.parametrize("msa_len", [32]) +@pytest.mark.parametrize("pair_len", [64]) +def test_evoformer_codegen(msa_len, pair_len, max_memory): + run_func = partial( + _test_evoformer_codegen, + msa_len=msa_len, + pair_len=pair_len, + max_memory=max_memory, + ) + mp.spawn(run_func, nprocs=1) + + +if __name__ == "__main__": + _test_evoformer_codegen(0, 32, 64, 24) From 358f05e3946e51bcd968580c7d5fbf744203f340 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 19 Jan 2023 16:30:28 +0800 Subject: [PATCH 5/9] rename --- .../test_evoformer_stack_codegen.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/tests/test_autochunk/test_evoformer_stack_codegen.py b/tests/test_autochunk/test_evoformer_stack_codegen.py index 0d12190d743b..2c17ec4339b4 100644 --- a/tests/test_autochunk/test_evoformer_stack_codegen.py +++ b/tests/test_autochunk/test_evoformer_stack_codegen.py @@ -77,7 +77,7 @@ def _build_openfold(): return model -def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory): +def _test_evoformer_stack_codegen(rank, msa_len, pair_len, max_memory): # launch colossalai colossalai.launch( config={}, @@ -110,12 +110,8 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory): }, ) interp = MetaInfoProp(meta_graph) - interp.propagate( - MetaTensor(node, fake_device="cuda:0"), - MetaTensor(pair, fake_device="cuda:0"), - MetaTensor(node_mask, fake_device="cuda:0"), - MetaTensor(pair_mask, fake_device="cuda:0"), - ) + interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"), + MetaTensor(node_mask, fake_device="cuda:0"), MetaTensor(pair_mask, fake_device="cuda:0"), None) codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False) # trace and recompile @@ -153,9 +149,9 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory): @pytest.mark.parametrize("max_memory", [None, 24, 28, 32]) @pytest.mark.parametrize("msa_len", [32]) @pytest.mark.parametrize("pair_len", [64]) -def test_evoformer_codegen(msa_len, pair_len, max_memory): +def test_evoformer_stack_codegen(msa_len, pair_len, max_memory): run_func = partial( - _test_evoformer_codegen, + _test_evoformer_stack_codegen, msa_len=msa_len, pair_len=pair_len, max_memory=max_memory, @@ -164,4 +160,4 @@ def test_evoformer_codegen(msa_len, pair_len, max_memory): if __name__ == "__main__": - _test_evoformer_codegen(0, 32, 64, 24) + _test_evoformer_stack_codegen(0, 32, 64, 24) From 2c09f775a95342a4e7e254540d376f9c4263ec60 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 19 Jan 2023 17:26:48 +0800 Subject: [PATCH 6/9] update logger, fix no user node --- colossalai/autochunk/autochunk_codegen.py | 12 +++++++++--- colossalai/autochunk/estimate_memory.py | 4 ++++ colossalai/autochunk/search_chunk.py | 15 +++++++++++++-- colossalai/autochunk/trace_indice.py | 6 ++++++ colossalai/autochunk/utils.py | 8 ++++++++ .../test_evoformer_stack_codegen.py | 8 ++++---- 6 files changed, 44 insertions(+), 9 
deletions(-) diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py index de5e7356bbfd..8c3155a60685 100644 --- a/colossalai/autochunk/autochunk_codegen.py +++ b/colossalai/autochunk/autochunk_codegen.py @@ -22,7 +22,7 @@ from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg from .search_chunk import SearchChunk -from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape +from .utils import delete_free_var_from_last_use, find_idx_by_name, get_logger, get_node_shape def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) -> str: @@ -276,11 +276,17 @@ def emit_code_with_chunk( class AutoChunkCodeGen(CodeGen): - def __init__(self, meta_graph, max_memory=None, print_mem=False): + def __init__(self, + meta_graph, + max_memory: int = None, + print_mem: bool = False, + print_progress: bool = False) -> None: super().__init__() # find the chunk regions - self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem) + self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem, print_progress) self.chunk_infos = self.search_chunk.search_region() + if print_progress: + get_logger().info("AutoChunk start codegen") def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode: free_vars: List[str] = [] diff --git a/colossalai/autochunk/estimate_memory.py b/colossalai/autochunk/estimate_memory.py index bf0e43637e8e..a03a5413bc34 100644 --- a/colossalai/autochunk/estimate_memory.py +++ b/colossalai/autochunk/estimate_memory.py @@ -43,6 +43,8 @@ def _get_delete_node(self, user, user_to_last_uses, to_keep=None): delete_node = [] if user.op not in ("output",): nodes_to_delete = user_to_last_uses.get(user, []) + if len(user.users) == 0: + nodes_to_delete.append(user) if to_keep is not None: keep_list = [] for n in nodes_to_delete: @@ -135,6 +137,8 @@ def _get_chunk_delete_node_size(self, user, user_to_last_uses, chunk_ratio, chun if user.op in ("placeholder", "output"): return 0 nodes_to_delete = user_to_last_uses.get(user, []) + if len(user.users) == 0: + nodes_to_delete.append(user) delete_size = 0 for n in nodes_to_delete: if n.name in chunk_inputs_names: diff --git a/colossalai/autochunk/search_chunk.py b/colossalai/autochunk/search_chunk.py index 4a10ce908337..5fd8621ca534 100644 --- a/colossalai/autochunk/search_chunk.py +++ b/colossalai/autochunk/search_chunk.py @@ -8,7 +8,7 @@ from .select_chunk import SelectChunk from .trace_flow import TraceFlow from .trace_indice import TraceIndice -from .utils import get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder +from .utils import get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder class SearchChunk(object): @@ -40,8 +40,9 @@ class SearchChunk(object): print_mem (bool): print estimated memory """ - def __init__(self, gm, max_memory=None, print_mem=False) -> None: + def __init__(self, gm, max_memory=None, print_mem=False, print_progress=False) -> None: self.print_mem = print_mem + self.print_progress = print_progress self.trace_indice = TraceIndice(list(gm.graph.nodes)) self.estimate_memory = EstimateMemory() self._init_trace() @@ -75,6 +76,8 @@ def _init_trace(self) -> None: max_chunk_region_list[0] = (0, max_chunk_region_list[0][1]) # set trace range and do the trace + if self.print_progress: + get_logger().info("AutoChunk start tracing indice") self.trace_indice.set_trace_range(max_chunk_region_list) self.trace_indice.trace_indice() @@ 
-278,6 +281,9 @@ def search_region(self) -> Dict: Returns: chunk_infos (Dict) """ + if self.print_progress: + get_logger().info("AutoChunk start searching chunk regions") + chunk_infos = [] ( init_mem_peak, @@ -297,6 +303,11 @@ def search_region(self) -> Dict: _, active_node, ) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos) + + if self.print_progress: + get_logger().info("AutoChunk find chunk region %d = (%d, %d)" % + (len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1])) + if self._stop_search(init_mem_peak, mem_peak): break if self.print_mem: diff --git a/colossalai/autochunk/trace_indice.py b/colossalai/autochunk/trace_indice.py index c4d7a1b006b7..b109fe964fad 100644 --- a/colossalai/autochunk/trace_indice.py +++ b/colossalai/autochunk/trace_indice.py @@ -497,6 +497,9 @@ def _assign_getitem_indice(self, node: Node, node_idx: int): new_dim_num = sum([1 if str(i) == "None" else 0 for i in node_args]) for _ in range(new_dim_num): self._del_dim(node_idx, 0) + delete_dim_num = sum([1 if str(i) == "0" else 0 for i in node_args]) + for _ in range(delete_dim_num): + self._add_dim(node_idx, 0) self._assign_indice_as_input(node, node_idx) for _, node_arg in enumerate(node_args): @@ -517,6 +520,9 @@ def _assign_getitem_indice(self, node: Node, node_idx: int): elif "None" == node_arg_str: self._add_dim(node_idx, new_idx_count) new_idx_count += 1 + elif "0" == node_arg_str: + self._del_dim(node_idx, new_idx_count) + origin_idx_count += 1 else: raise NotImplementedError() diff --git a/colossalai/autochunk/utils.py b/colossalai/autochunk/utils.py index ff1a64bc359d..e870685122e3 100644 --- a/colossalai/autochunk/utils.py +++ b/colossalai/autochunk/utils.py @@ -2,6 +2,14 @@ from torch.fx.node import Node +from colossalai.logging import get_dist_logger + +logger = get_dist_logger() + + +def get_logger(): + return logger + def flat_list(inputs: Any) -> List: """ diff --git a/tests/test_autochunk/test_evoformer_stack_codegen.py b/tests/test_autochunk/test_evoformer_stack_codegen.py index 2c17ec4339b4..37a66b888d79 100644 --- a/tests/test_autochunk/test_evoformer_stack_codegen.py +++ b/tests/test_autochunk/test_evoformer_stack_codegen.py @@ -42,8 +42,8 @@ def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask # test forward model = model.cuda() with torch.no_grad(): - non_fx_out = model(node, pair, node_mask, pair_mask) - fx_out = gm(node, pair, node_mask, pair_mask) + non_fx_out = model(node, pair, node_mask, pair_mask, None) + fx_out = gm(node, pair, node_mask, pair_mask, None) assert torch.allclose(non_fx_out[0], fx_out[0], atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( @@ -112,7 +112,7 @@ def _test_evoformer_stack_codegen(rank, msa_len, pair_len, max_memory): interp = MetaInfoProp(meta_graph) interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"), MetaTensor(node_mask, fake_device="cuda:0"), MetaTensor(pair_mask, fake_device="cuda:0"), None) - codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False) + codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False, print_progress=True) # trace and recompile # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer @@ -160,4 +160,4 @@ def test_evoformer_stack_codegen(msa_len, pair_len, max_memory): if __name__ == "__main__": - _test_evoformer_stack_codegen(0, 32, 64, 24) + _test_evoformer_stack_codegen(0, 32, 64, 28) From 
5a1fe092921a19aebc84759b576a8cb67b941947 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 19 Jan 2023 18:32:07 +0800 Subject: [PATCH 7/9] support full evoformer run --- colossalai/autochunk/search_chunk.py | 5 ++++- colossalai/autochunk/trace_indice.py | 17 +++++++++-------- .../test_evoformer_stack_codegen.py | 2 +- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/colossalai/autochunk/search_chunk.py b/colossalai/autochunk/search_chunk.py index 5fd8621ca534..a8619671268b 100644 --- a/colossalai/autochunk/search_chunk.py +++ b/colossalai/autochunk/search_chunk.py @@ -78,7 +78,7 @@ def _init_trace(self) -> None: # set trace range and do the trace if self.print_progress: get_logger().info("AutoChunk start tracing indice") - self.trace_indice.set_trace_range(max_chunk_region_list) + self.trace_indice.set_trace_range(max_chunk_region_list, active_nodes) self.trace_indice.trace_indice() def _find_peak_node(self, mem_peak: List) -> int: @@ -182,6 +182,9 @@ def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> Lis # dim size cannot be 1 if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1): continue + # must have users + if len(end_node.users) == 0: + continue # check index source align if not self.trace_flow.check_index_source(start_dim, start_node, start_idx, end_dim, end_node): continue diff --git a/colossalai/autochunk/trace_indice.py b/colossalai/autochunk/trace_indice.py index b109fe964fad..827f60d8b53d 100644 --- a/colossalai/autochunk/trace_indice.py +++ b/colossalai/autochunk/trace_indice.py @@ -34,6 +34,7 @@ def __init__(self, node_list: List[Node]) -> None: self.indice_view_list = {} self.indice_count = -1 self.trace_range = [] + self.active_node_list = [] def _init_indice_trace_list(self): indice_trace_list = [] @@ -49,8 +50,9 @@ def _init_indice_trace_list(self): indice_trace_list.append(cur_trace) return indice_trace_list - def set_trace_range(self, trace_range: List) -> None: + def set_trace_range(self, trace_range: List, active_node_list: List) -> None: self.trace_range = trace_range + self.active_node_list = active_node_list def _add_indice(self): """ @@ -613,29 +615,28 @@ def _clear_trace(self, node_idx: int) -> None: trace_range = None for i in range(len(self.trace_range)): if self.trace_range[i][1] == node_idx: - if i <= 4: - break - # use previous range's start instead of this one's start - # 5 is the min safe range - trace_range = (self.trace_range[i - 5][0], self.trace_range[i][1]) + trace_range = (self.trace_range[i][0], self.trace_range[i][1]) break if self.trace_range[i][1] > node_idx: break if trace_range is None: return + active_nodes = self.active_node_list[trace_range[0]:trace_range[1] + 1] + active_nodes = set(flat_list(active_nodes)) + active_nodes = [find_idx_by_name(i, self.node_list) for i in active_nodes] for i in range(trace_range[0], trace_range[1] + 1): trace = self.indice_trace_list[i] # clear compute for dim_compute in trace["compute"]: for i in range(len(dim_compute) - 1, -1, -1): - if dim_compute[i] < trace_range[0]: + if dim_compute[i] < trace_range[0] and dim_compute[i] not in active_nodes: dim_compute.pop(i) continue # clear source for dim_source in trace["source"]: for k in list(dim_source.keys()): - if k < trace_range[0]: + if k < trace_range[0] and k not in active_nodes: dim_source.pop(k) def trace_indice(self): diff --git a/tests/test_autochunk/test_evoformer_stack_codegen.py b/tests/test_autochunk/test_evoformer_stack_codegen.py index 37a66b888d79..be1b6b6775d5 100644 --- 
a/tests/test_autochunk/test_evoformer_stack_codegen.py +++ b/tests/test_autochunk/test_evoformer_stack_codegen.py @@ -160,4 +160,4 @@ def test_evoformer_stack_codegen(msa_len, pair_len, max_memory): if __name__ == "__main__": - _test_evoformer_stack_codegen(0, 32, 64, 28) + _test_evoformer_stack_codegen(0, 32, 64, None) From cb8cb941cdbfecbede4ef50dee47420c05d5589b Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 20 Jan 2023 00:40:13 +0800 Subject: [PATCH 8/9] fix a bug in input nodes --- colossalai/autochunk/trace_flow.py | 5 ++++- tests/test_autochunk/test_evoformer_stack_codegen.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/colossalai/autochunk/trace_flow.py b/colossalai/autochunk/trace_flow.py index e657c188ead2..830b4629ec1e 100644 --- a/colossalai/autochunk/trace_flow.py +++ b/colossalai/autochunk/trace_flow.py @@ -281,7 +281,10 @@ def _get_input_nodes_dim(self, inputs: List[Node], start_idx: int, end_idx: int, if chunk_dim is not None: user_source = self.trace_indice._find_source_trace_from_node(user)[chunk_dim] if input_node_idx in user_source: - input_dict[user_idx] = user_source[input_node_idx] + if get_node_shape(input_node)[user_source[input_node_idx][0]] == 1: + input_dict[user_idx] = [None] + else: + input_dict[user_idx] = user_source[input_node_idx] else: return None, None if len(input_dict) == 0: diff --git a/tests/test_autochunk/test_evoformer_stack_codegen.py b/tests/test_autochunk/test_evoformer_stack_codegen.py index be1b6b6775d5..e50fb2d0283c 100644 --- a/tests/test_autochunk/test_evoformer_stack_codegen.py +++ b/tests/test_autochunk/test_evoformer_stack_codegen.py @@ -35,7 +35,7 @@ def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask # pair1 = pair.clone() # node_mask1 = node_mask.clone() # pair_mask1 = pair_mask.clone() - # gm(node1, pair1, node_mask1, pair_mask1) + # gm(node1, pair1, node_mask1, pair_mask1, None) # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2 # print("autochunk max mem:%.2f"% (new_max_mem - now_mem)) @@ -64,7 +64,7 @@ def _build_openfold(): c_s=384, no_heads_msa=8, no_heads_pair=4, - no_blocks=4, # 48 + no_blocks=2, # 48 transition_n=4, msa_dropout=0.15, pair_dropout=0.25, From 5648bae413f78490a426069dd1f389b5755c9d7b Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 20 Jan 2023 10:05:05 +0800 Subject: [PATCH 9/9] turn off log --- tests/test_autochunk/test_evoformer_stack_codegen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_autochunk/test_evoformer_stack_codegen.py b/tests/test_autochunk/test_evoformer_stack_codegen.py index e50fb2d0283c..5fabb27028f9 100644 --- a/tests/test_autochunk/test_evoformer_stack_codegen.py +++ b/tests/test_autochunk/test_evoformer_stack_codegen.py @@ -112,7 +112,7 @@ def _test_evoformer_stack_codegen(rank, msa_len, pair_len, max_memory): interp = MetaInfoProp(meta_graph) interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"), MetaTensor(node_mask, fake_device="cuda:0"), MetaTensor(pair_mask, fake_device="cuda:0"), None) - codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False, print_progress=True) + codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False, print_progress=False) # trace and recompile # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
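
Taken together, the tests added and updated across these nine patches apply autochunk with one recipe: trace a meta graph for memory estimation, build the chunk-searching codegen from it, then retrace with ColoTracer and recompile. A minimal sketch condensing that recipe into a helper, assuming a model whose forward takes (node, pair) as in test_simple_evoformer_search.py (the helper name and its signature are illustrative; every call inside follows the tests above):

import torch

from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.fx import symbolic_trace
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import MetaTensor
from colossalai.fx.tracer.experimental import ColoTracer


def build_autochunk_module(model, node, pair, max_memory=None):
    meta_args = {
        "node": node.to(torch.device("meta")),
        "pair": pair.to(torch.device("meta")),
    }
    # trace a meta graph and propagate tensor metadata for memory estimation
    meta_graph = symbolic_trace(model, meta_args=meta_args)
    MetaInfoProp(meta_graph).propagate(
        MetaTensor(node, fake_device="cuda:0"),
        MetaTensor(pair, fake_device="cuda:0"),
    )
    # search chunk regions and build the chunked code generator
    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)
    # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer,
    # so trace again and attach the codegen before recompiling
    graph = ColoTracer().trace(model, meta_args=meta_args)
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
    gm.recompile()
    return gm

With max_memory left as None the search picks chunk regions on its own; passing a numeric budget, as the parametrized tests do with values such as 20-32, constrains which regions the search may select.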