From f5c5d4c4f385f115d3a0f569c18e4b66106ca4ff Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 27 Oct 2022 10:57:07 +0800 Subject: [PATCH 001/209] init --- chunk_codegen.py | 1047 ++++++++++++++++++++++++++++++++++++++++++ chunk_codegen_run.py | 177 +++++++ 2 files changed, 1224 insertions(+) create mode 100644 chunk_codegen.py create mode 100644 chunk_codegen_run.py diff --git a/chunk_codegen.py b/chunk_codegen.py new file mode 100644 index 000000000000..684028c014de --- /dev/null +++ b/chunk_codegen.py @@ -0,0 +1,1047 @@ +import colossalai +import torch +from typing import List, Callable, Any, Tuple, Dict, Iterable + +try: + from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name + from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin + CODEGEN_AVAILABLE = True +except: + from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args, _CustomBuiltin + from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name + CODEGEN_AVAILABLE = False + +if CODEGEN_AVAILABLE: + __all__ = ['ActivationCheckpointCodeGen'] +else: + __all__ = ['python_code_with_activation_checkpoint'] + + +def _gen_saved_tensors_hooks(): + """ + Generate saved tensors hooks + """ + + pack_hook = """def pack_hook_input(self, x): + if getattr(x, "offload", False): + return (x.device, x.cpu()) + else: + return x + +def pack_hook_no_input(self, x): + if getattr(x, "offload", True): + return (x.device, x.cpu()) + else: + return x +""" + + unpack_hook = """def unpack_hook(self, packed): + if isinstance(packed, tuple): + device, tensor = packed + return tensor.to(device) + else: + return packed +""" + + return pack_hook, unpack_hook + + +def _gen_save_tensors_hooks_context(offload_input=True) -> str: + """Generate customized saved_tensors_hooks + + Args: 
+ offload_input (bool, optional): whether we need offload input, if offload_input=False, + we will use self.pack_hook_no_input instead. Defaults to True. + + Returns: + str: generated context + """ + + if offload_input: + context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):\n" + else: + context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):\n" + return context + + +def _gen_save_on_cpu_context(): + """ + Generate save on cpu context + """ + + context = "with torch.autograd.graph.save_on_cpu(pin_memory=True):\n" + return context + + +def _find_input_and_output_nodes(nodes: List[Node]): + """ + Find the input and output node names which are not found in the given list of nodes. + """ + input_nodes = [] + output_nodes = [] + + # if a node has an input node which is not in the node list + # we treat that input node as the input of the checkpoint function + for node in nodes: + for input_node in node._input_nodes.keys(): + node_repr = repr(input_node) + if input_node not in nodes and node_repr not in input_nodes: + input_nodes.append(node_repr) + + # if a node has a user node which is not in the node list + # we treat that user node as the node receiving the current node output + for node in nodes: + for output_node in node.users.keys(): + node_repr = repr(node) + if output_node not in nodes and node_repr not in output_nodes: + output_nodes.append(node_repr) + + return input_nodes, output_nodes + + +def _find_ckpt_regions(nodes: List[Node]): + """ + Find the checkpoint regions given a list of consecutive nodes. The outputs will be list + of tuples, each tuple is in the form of (start_index, end_index). 
+ """ + ckpt_nodes = [] + ckpt_regions = [] + start = -1 + end = -1 + current_region = None + + for idx, node in enumerate(nodes): + if hasattr(node, 'activation_checkpoint'): + act_ckpt_label = node.activation_checkpoint + + # this activation checkpoint label is not set yet + # meaning this is the first node of the activation ckpt region + if current_region is None: + current_region = act_ckpt_label + start = idx + + # if activation checkpoint has changed + # we restart the tracking + # e.g. node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2] + if act_ckpt_label != current_region: + assert start != -1 + ckpt_regions.append((start, idx - 1)) + current_region = act_ckpt_label + start = idx + end = -1 + elif current_region is not None and not hasattr(node, 'activation_checkpoint'): + # used to check the case below + # node ckpt states = [ckpt, ckpt, non-ckpt] + end = idx - 1 + assert start != -1 and end != -1 + ckpt_regions.append((start, end)) + start = end = -1 + current_region = None + else: + pass + return ckpt_regions + + +def _find_offload_regions(nodes: List[Node]): + """This function is to find the offload regions + In pofo algorithm, during annotation, we will annotate the offload region with the + list in the form of [idx, offload_input, offload_bar]. idx indicates the offload + region's index, offload_input is a bool type indicates whether we need to offload + the input, offload_bar is a bool type indicates whether we need to offload all the + intermediate x_bars of this region. 
+ """ + offload_regions = [] + offload_labels = [] + start = -1 + end = -1 + current_region = None + + for idx, node in enumerate(nodes): + if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', None), Iterable): + act_offload_label = node.activation_offload + + if current_region == None: + current_region = act_offload_label + start = idx + offload_labels.append(act_offload_label) + + if act_offload_label != current_region: + assert start != -1 + offload_regions.append((start, idx - 1)) + offload_labels.append(act_offload_label) + current_region = act_offload_label + start = idx + end = -1 + + else: + if current_region is not None: + end = idx - 1 + assert start != -1 and end != -1 + offload_regions.append((start, end)) + start = end = -1 + current_region = None + + else: + pass + + return offload_regions, offload_labels + + +def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str: + """ + Generate the checkpoint function definition + """ + return f"def checkpoint_{label}({', '.join(['self'] + free_vars)}):" + + +def _gen_ckpt_output(output_vars: List[str]) -> str: + """ + Generate the return statement for checkpoint region + """ + return f"return {', '.join(output_vars)}" + + +def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reentrant=True): + """ + Generate the checkpoint function call code text + """ + outputs = ', '.join(output_vars) + inputs = ', '.join(input_vars) + return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})' + + +def _end_of_ckpt(node: Node, check_idx: int) -> bool: + """Check if the node could end the ckpt region + + Args: + node (Node): torch.fx.Node + check_idx (int): the index of checkpoint level for + nested checkpoint + + Returns: + bool + """ + if hasattr(node, "activation_checkpoint"): + if isinstance(node.activation_checkpoint, list): + return 
node.activation_checkpoint[check_idx] == None + else: + return False + else: + return True + + +def _find_nested_ckpt_regions(nodes, check_idx=0): + """ + Find the nested checkpoint regions given a list of consecutive nodes. The outputs + will be list of tuples, each tuple is in the form of (start_index, end_index). + """ + ckpt_regions = [] + start = -1 + end = -1 + current_region = None + + for idx, node in enumerate(nodes): + if hasattr(node, 'activation_checkpoint'): + if isinstance(getattr(node, 'activation_checkpoint'), int): + act_ckpt_label = node.activation_checkpoint + else: + act_ckpt_label = node.activation_checkpoint[check_idx] + + # this activation checkpoint label is not set yet + # meaning this is the first node of the activation ckpt region + if current_region is None: + current_region = act_ckpt_label + start = idx + + # if activation checkpoint has changed + # we restart the tracking + # e.g. node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2] + if act_ckpt_label != current_region: + assert start != -1 + ckpt_regions.append((start, idx - 1)) + current_region = act_ckpt_label + start = idx + end = -1 + elif current_region is not None and _end_of_ckpt(node, check_idx): + # used to check the case below + # node ckpt states = [ckpt, ckpt, non-ckpt] + end = idx - 1 + assert start != -1 and end != -1 + ckpt_regions.append((start, end)) + start = end = -1 + current_region = None + else: + pass + + if current_region is not None: + end = len(nodes) - 1 + ckpt_regions.append((start, end)) + return ckpt_regions + + +def emit_ckpt_func(body, + ckpt_func, + node_list: List[Node], + emit_node_func, + delete_unused_value_func, + level=0, + in_ckpt=False): + """Emit ckpt fuction in nested way + + Args: + body: forward code, in recursive calls, this part will be checkpoint + functions code + ckpt_func: checkpoint functions code, in recursive calls, this part + will be a buffer + node_list (List[Node]): list of torch.fx.Node + emit_node_func: function to emit a node + 
delete_unused_value_func: function to delete unused value + level (int, optional): checkpoint level. Defaults to 0. + in_ckpt (bool, optional): indicates wether the func is in recursive + call. Defaults to False. + """ + inputs, outputs = _find_input_and_output_nodes(node_list) + + # if the current checkpoint function use int as label, using old generation method + if isinstance(node_list[0].activation_checkpoint, int): + label = node_list[0].activation_checkpoint + ckpt_fn_def = _gen_ckpt_fn_def(label, inputs) + ckpt_func.append(f'{ckpt_fn_def}\n') + for node in node_list: + emit_node_func(node, ckpt_func) + ckpt_func[-1] = ' ' + ckpt_func[-1] + delete_unused_value_func(node, ckpt_func) + + ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') + activation_offload = getattr(node_list[0], "activation_offload", False) + usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + usage += "\n" + body.append(usage) + + # use nested ckpt function codegen + else: + # label given by each layer, e.g. 
if you are currently at level [0, 1, 1] + # the label will be '0_1_1' + label = "_".join([str(idx) for idx in node_list[0].activation_checkpoint[:level + 1]]) + ckpt_fn_def = _gen_ckpt_fn_def(label, inputs) + ckpt_func.append(f'{ckpt_fn_def}\n') + + # if there is more level to fetch + if level + 1 < len(node_list[0].activation_checkpoint): + ckpt_regions = _find_nested_ckpt_regions(node_list, level + 1) + start_idx = [item[0] for item in ckpt_regions] + end_idx = [item[1] for item in ckpt_regions] + + # use ckpt_func_buffer to store nested checkpoint functions + ckpt_func_buffer = [] + node_idx = 0 + while 1: + if node_idx >= len(node_list): + break + + if node_idx in start_idx: + ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1] + emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func, + delete_unused_value_func, level + 1, True) + node_idx += len(ckpt_node_list) + + else: + node = node_list[node_idx] + emit_node_func(node, ckpt_func) + ckpt_func[-1] = ' ' + ckpt_func[-1] + delete_unused_value_func(node, ckpt_func) + node_idx += 1 + + ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') + ckpt_func += ckpt_func_buffer + activation_offload = getattr(node_list[0], "activation_offload", False) + usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n' + if in_ckpt: + usage = ' ' + usage + body.append(usage) + + # last level + else: + for node in node_list: + emit_node_func(node, ckpt_func) + ckpt_func[-1] = ' ' + ckpt_func[-1] + delete_unused_value_func(node, ckpt_func) + + ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') + activation_offload = getattr(node_list[0], "activation_offload", False) + usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n' + if in_ckpt: + usage = ' ' + usage + body.append(usage) + + +def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func): + """Emit code with nested 
def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
    """Emit code with nested activation checkpoint.

    When we detect that some node.activation_checkpoint is a List, we will use
    this function to emit the activation checkpoint codes.

    Args:
        body: forward code
        ckpt_func: checkpoint functions code
        nodes: graph.nodes
        emit_node_func: function to emit node
        delete_unused_value_func: function to remove the unused value
    """
    # NOTE(review): indent width in generated-code strings reconstructed as
    # 4 spaces — source formatting was mangled; verify against upstream.
    ckpt_regions = _find_nested_ckpt_regions(nodes, 0)
    start_idx = [item[0] for item in ckpt_regions]
    end_idx = [item[1] for item in ckpt_regions]

    # find the offload regions
    offload_regions, offload_labels = _find_offload_regions(nodes)
    offload_starts = [item[0] for item in offload_regions]
    offload_ends = [item[1] for item in offload_regions]
    offload_inputs = []
    offload_outputs = []
    within_offload_region = False

    node_list = list(nodes)

    # find the input and output var names for each offload region
    for idx, (start, end) in enumerate(offload_regions):
        offload_node_list = node_list[start:end + 1]
        inputs, outputs = _find_input_and_output_nodes(offload_node_list)
        offload_inputs.append(inputs)
        offload_outputs.append(outputs)

    # this flag is to prevent repeated insert of save tensors
    # hooks definition in ckpt_func
    is_hook_inserted = False
    node_idx = 0
    while 1:
        # break if we finish processing all the nodes
        if node_idx >= len(node_list):
            break

        # process ckpt_regions: hand the whole region to the nested emitter,
        # then skip past it
        if node_idx in start_idx:
            ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
            emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func)
            node_idx += len(ckpt_node_list)

        # process node in forward function
        else:
            node = node_list[node_idx]

            if node_idx in offload_starts:
                # offload label layout is [idx, offload_input, offload_bar]
                # (see _find_offload_regions docstring)
                offload_label = offload_labels[offload_starts.index(node_idx)]
                _, offload_input, offload_bar = offload_label
                within_offload_region = True

                # insert hook functions if needed
                if not is_hook_inserted:
                    pack_hook, unpack_hook = _gen_saved_tensors_hooks()
                    ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")
                    is_hook_inserted = True

                if offload_input and offload_bar:
                    # offload everything in the region via save_on_cpu
                    body.append(_gen_save_on_cpu_context())

                elif offload_input:
                    # mark region inputs for offload, then open the hook context
                    for par in offload_inputs[offload_label[0]]:
                        body.append(f"setattr({par}, 'offload', True)\n")
                    body.append(_gen_save_tensors_hooks_context(offload_input=True))

                else:
                    # keep region inputs on device; offload the rest
                    for par in offload_inputs[offload_label[0]]:
                        body.append(f"setattr({par}, 'offload', False)\n")
                    body.append(_gen_save_tensors_hooks_context(offload_input=False))

            if within_offload_region:
                # statements inside the `with` context need one indent level
                emit_node_func(node, body)
                body[-1] = '    ' + body[-1]
                delete_unused_value_func(node, body)

            else:
                emit_node_func(node, body)
                delete_unused_value_func(node, body)

            if node_idx in offload_ends:
                within_offload_region = False

            node_idx += 1


def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
    """Emit forward code plus `checkpoint_<i>` functions for flat (non-nested,
    int-labelled) activation checkpoint regions, and offload contexts."""
    # find the activation checkpoint regions
    ckpt_regions = _find_ckpt_regions(nodes)
    start_idx = [item[0] for item in ckpt_regions]
    end_idx = [item[1] for item in ckpt_regions]
    input_vars = []
    output_vars = []
    within_ckpt_region = False

    # find the offload regions
    offload_regions, offload_labels = _find_offload_regions(nodes)
    offload_starts = [item[0] for item in offload_regions]
    offload_ends = [item[1] for item in offload_regions]
    offload_inputs = []
    offload_outputs = []
    within_offload_region = False

    node_list = list(nodes)

    # use this variable to avoid inserting hook functions
    # to ckpt_func repeatedly
    is_hook_inserted = False

    # find the input and output var names for each region
    for idx, (start, end) in enumerate(ckpt_regions):
        ckpt_node_list = node_list[start:end + 1]
        inputs, outputs = _find_input_and_output_nodes(ckpt_node_list)
        input_vars.append(inputs)
        output_vars.append(outputs)

    # find the input and output var names for each offload region
    for idx, (start, end) in enumerate(offload_regions):
        offload_node_list = node_list[start:end + 1]
        inputs, outputs = _find_input_and_output_nodes(offload_node_list)
        offload_inputs.append(inputs)
        offload_outputs.append(outputs)

    # append code text to body
    for idx, node in enumerate(node_list):
        # if this is the first node of the ckpt region
        # append the ckpt function definition
        if idx in start_idx:
            label = start_idx.index(idx)
            ckpt_fn_def = _gen_ckpt_fn_def(label, input_vars[label])
            ckpt_func.append(f'{ckpt_fn_def}\n')
            within_ckpt_region = True

        if idx in offload_starts:
            offload_label = offload_labels[offload_starts.index(idx)]
            _, offload_input, offload_bar = offload_label
            within_offload_region = True

            # insert hook functions if needed
            if not is_hook_inserted:
                pack_hook, unpack_hook = _gen_saved_tensors_hooks()
                ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")
                is_hook_inserted = True

            if offload_input and offload_bar:
                body.append(_gen_save_on_cpu_context())

            elif offload_input:
                for par in offload_inputs[offload_label[0]]:
                    body.append(f"setattr({par}, 'offload', True)\n")
                body.append(_gen_save_tensors_hooks_context(offload_input=True))

            else:
                for par in offload_inputs[offload_label[0]]:
                    body.append(f"setattr({par}, 'offload', False)\n")
                body.append(_gen_save_tensors_hooks_context(offload_input=False))

        # NOTE: emit_node does not emit a string with newline. It depends
        # on delete_unused_values to append one
        # NOTE: currently we separate body and ckpt_func definition
        if within_ckpt_region:
            emit_node_func(node, ckpt_func)
            ckpt_func[-1] = '    ' + ckpt_func[-1]
            delete_unused_value_func(node, ckpt_func)

        elif within_offload_region:
            emit_node_func(node, body)
            body[-1] = '    ' + body[-1]
            delete_unused_value_func(node, body)

        else:
            emit_node_func(node, body)
            delete_unused_value_func(node, body)

        if idx in end_idx:
            # if this is the last node of the ckpt region
            # generate return statement
            label = end_idx.index(idx)
            return_statement = _gen_ckpt_output(output_vars[label])
            return_statement = f'    {return_statement}\n\n'
            ckpt_func.append(return_statement)

            # we need to check if the checkpoint need to offload the input
            start_node_idx = start_idx[label]
            if hasattr(node_list[start_node_idx], 'activation_offload'):
                activation_offload = node_list[start_node_idx].activation_offload
            else:
                activation_offload = False

            # we need to check if the checkpoint need use_reentrant=False:
            # reentrant checkpointing is disabled when a region input feeds an
            # in-place op inside this region, or when all inputs are leaves
            use_reentrant = True
            non_leaf_input = 0
            for var in input_vars[label]:
                input_node = next(item for item in node_list if item.name == var)
                if input_node.op != "placeholder":
                    non_leaf_input = 1
                for user in input_node.users:
                    if hasattr(user, "activation_checkpoint"):
                        if user.activation_checkpoint == label:
                            if user.op == "call_module":
                                if hasattr(user.graph.owning_module.get_submodule(user.target), "inplace"):
                                    use_reentrant = not user.graph.owning_module.get_submodule(user.target).inplace

                            elif user.op == "call_function":
                                if "inplace" in user.kwargs:
                                    use_reentrant = not user.kwargs["inplace"]

            # if all the inputs are leaf nodes, we need to set use_reentrant = False
            if not non_leaf_input:
                use_reentrant = False

            # generate checkpoint function call in a new line
            usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label], use_reentrant)
            usage += '\n'
            body.append(usage)
            within_ckpt_region = False

        if idx in offload_ends:
            within_offload_region = False
if CODEGEN_AVAILABLE:

    class ActivationCheckpointCodeGen(CodeGen):
        """torch.fx ``CodeGen`` subclass that prepends generated
        ``checkpoint_*`` wrapper functions (and saved-tensor hook code) to the
        emitted ``forward`` source.

        Used on newer torch versions where ``torch.fx.graph.CodeGen`` exists
        (see the CODEGEN_AVAILABLE probe at the top of this module).
        """

        def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
            # Largely copied from torch.fx.graph.CodeGen._gen_python_code;
            # modifications are marked with "Modified" / NOTE comments below.
            free_vars: List[str] = []
            body: List[str] = []
            globals_: Dict[str, Any] = {}
            wrapped_fns: Dict[str, None] = {}

            # Wrap string in list to pass by reference
            maybe_return_annotation: List[str] = ['']

            def add_global(name_hint: str, obj: Any):
                """Add an obj to be tracked as a global.

                We call this for names that reference objects external to the
                Graph, like functions or types.

                Returns: the global name that should be used to reference 'obj' in generated source.
                """
                if _is_from_torch(obj) and obj != torch.device:    # to support registering torch.device
                    # HACK: workaround for how torch custom ops are registered. We
                    # can't import them like normal modules so they must retain their
                    # fully qualified name.
                    return _get_qualified_name(obj)

                # normalize the name hint to get a proper identifier
                global_name = namespace.create_name(name_hint, obj)

                if global_name in globals_:
                    assert globals_[global_name] is obj
                    return global_name
                globals_[global_name] = obj
                return global_name

            # set _custom_builtins here so that we needn't import colossalai in forward
            _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai)

            # Pre-fill the globals table with registered builtins.
            for name, (_, obj) in _custom_builtins.items():
                add_global(name, obj)

            def type_repr(o: Any):
                if o == ():
                    # Empty tuple is used for empty tuple type annotation Tuple[()]
                    return '()'

                typename = _type_repr(o)

                if hasattr(o, '__origin__'):
                    # This is a generic type, e.g. typing.List[torch.Tensor]
                    origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
                    origin_typename = add_global(_type_repr(origin_type), origin_type)

                    if hasattr(o, '__args__'):
                        # Assign global names for each of the inner type variables.
                        args = [type_repr(arg) for arg in o.__args__]

                        if len(args) == 0:
                            # Bare type, such as `typing.Tuple` with no subscript
                            # This code-path used in Python < 3.9
                            return origin_typename

                        return f'{origin_typename}[{",".join(args)}]'
                    else:
                        # Bare type, such as `typing.Tuple` with no subscript
                        # This code-path used in Python 3.9+
                        return origin_typename

                # Common case: this is a regular module name like 'foo.bar.baz'
                return add_global(typename, o)

            def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
                # local override of torch.fx.graph._format_args so NamedTuple
                # types encountered in args get registered as globals

                def _get_repr(arg):
                    # Handle NamedTuples (if it has `_fields`) via add_global.
                    if isinstance(arg, tuple) and hasattr(arg, '_fields'):
                        qualified_name = _get_qualified_name(type(arg))
                        global_name = add_global(qualified_name, type(arg))
                        return f"{global_name}{repr(tuple(arg))}"
                    return repr(arg)

                args_s = ', '.join(_get_repr(a) for a in args)
                kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
                if args_s and kwargs_s:
                    return f'{args_s}, {kwargs_s}'
                return args_s or kwargs_s

            # Run through reverse nodes and record the first instance of a use
            # of a given node. This represents the *last* use of the node in the
            # execution order of the program, which we will use to free unused
            # values
            node_to_last_use: Dict[Node, Node] = {}
            user_to_last_uses: Dict[Node, List[Node]] = {}

            def register_last_uses(n: Node, user: Node):
                if n not in node_to_last_use:
                    node_to_last_use[n] = user
                    user_to_last_uses.setdefault(user, []).append(n)

            for node in reversed(nodes):
                map_arg(node.args, lambda n: register_last_uses(n, node))
                map_arg(node.kwargs, lambda n: register_last_uses(n, node))

            # NOTE: we add a variable to distinguish body and ckpt_func
            def delete_unused_values(user: Node, body):
                """
                Delete values after their last use. This ensures that values that are
                not used in the remainder of the code are freed and the memory usage
                of the code is optimal.
                """
                if user.op == 'placeholder':
                    return
                if user.op == 'output':
                    body.append('\n')
                    return
                nodes_to_delete = user_to_last_uses.get(user, [])
                if len(nodes_to_delete):
                    to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
                    body.append(f'; {to_delete_str}\n')
                else:
                    body.append('\n')

            # NOTE: we add a variable to distinguish body and ckpt_func
            def emit_node(node: Node, body):
                maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
                if node.op == 'placeholder':
                    assert isinstance(node.target, str)
                    maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
                    free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
                    raw_name = node.target.replace('*', '')
                    if raw_name != repr(node):
                        body.append(f'{repr(node)} = {raw_name}\n')
                    return
                elif node.op == 'call_method':
                    assert isinstance(node.target, str)
                    body.append(
                        f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
                        f'({_format_args(node.args[1:], node.kwargs)})')
                    return
                elif node.op == 'call_function':
                    assert callable(node.target)
                    # pretty print operators
                    if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
                        assert isinstance(node.args, tuple)
                        body.append(f'{repr(node)}{maybe_type_annotation} = '
                                    f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
                        return

                    # pretty print inplace operators; required for jit.script to work properly
                    # not currently supported in normal FX graphs, but generated by torchdynamo
                    if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods:
                        body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; '
                                    f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}')
                        return

                    qualified_name = _get_qualified_name(node.target)
                    global_name = add_global(qualified_name, node.target)
                    # special case for getattr: node.args could be 2-argument or 3-argument
                    # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
                    if global_name == 'getattr' and \
                            isinstance(node.args, tuple) and \
                            isinstance(node.args[1], str) and \
                            node.args[1].isidentifier() and \
                            len(node.args) == 2:
                        body.append(
                            f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
                        return
                    body.append(
                        f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
                    if node.meta.get('is_wrapped', False):
                        wrapped_fns.setdefault(global_name)
                    return
                elif node.op == 'call_module':
                    assert isinstance(node.target, str)
                    body.append(f'{repr(node)}{maybe_type_annotation} = '
                                f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
                    return
                elif node.op == 'get_attr':
                    assert isinstance(node.target, str)
                    body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
                    return
                elif node.op == 'output':
                    if node.type is not None:
                        maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
                    body.append(self.generate_output(node.args[0]))
                    return
                raise NotImplementedError(f'node: {node.op} {node.target}')

            # Modified for activation checkpointing
            ckpt_func = []

            # if any node has a list of labels for activation_checkpoint, we
            # will use nested type of activation checkpoint codegen
            if any(isinstance(getattr(node, "activation_checkpoint", None), Iterable) for node in nodes):
                emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
            else:
                emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)

            if len(body) == 0:
                # If the Graph has no non-placeholder nodes, no lines for the body
                # have been emitted. To continue to have valid Python code, emit a
                # single pass statement
                body.append('pass\n')

            if len(wrapped_fns) > 0:
                wrap_name = add_global('wrap', torch.fx.wrap)
                wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
            else:
                wrap_stmts = ''

            if self._body_transformer:
                body = self._body_transformer(body)

            for name, value in self.additional_globals():
                add_global(name, value)

            # as we need colossalai.utils.checkpoint, we need to import colossalai
            # in forward function
            prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
            # prepend the generated checkpoint functions before the forward def
            prologue = ''.join(ckpt_func) + prologue
            prologue = prologue    # no-op kept from upstream

            code = ''.join(body)
            code = '\n'.join('    ' + line for line in code.split('\n'))
            fn_code = f"""
{wrap_stmts}

{prologue}
{code}"""
            return PythonCode(fn_code, globals_)
def python_code_with_activation_checkpoint(self, root_module: str, namespace: _Namespace) -> PythonCode:
    """
    This method is copied from the _python_code of torch.fx.graph.Graph. Modifications are made so that it can generate
    code for activation checkpoint.

    Fallback used on older torch versions without ``torch.fx.graph.CodeGen``
    (CODEGEN_AVAILABLE is False); ``self`` is the ``torch.fx.Graph`` instance.
    """
    free_vars: List[str] = []
    body: List[str] = []
    globals_: Dict[str, Any] = {}
    wrapped_fns: Dict[str, None] = {}

    # Wrap string in list to pass by reference
    maybe_return_annotation: List[str] = ['']

    def add_global(name_hint: str, obj: Any):
        """Add an obj to be tracked as a global.

        We call this for names that reference objects external to the
        Graph, like functions or types.

        Returns: the global name that should be used to reference 'obj' in generated source.
        """
        if _is_from_torch(obj) and obj != torch.device:    # to support registering torch.device
            # HACK: workaround for how torch custom ops are registered. We
            # can't import them like normal modules so they must retain their
            # fully qualified name.
            return _get_qualified_name(obj)

        # normalize the name hint to get a proper identifier
        global_name = namespace.create_name(name_hint, obj)

        if global_name in globals_:
            assert globals_[global_name] is obj
            return global_name
        globals_[global_name] = obj
        return global_name

    # set _custom_builtins here so that we needn't import colossalai in forward
    _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai)

    # Pre-fill the globals table with registered builtins.
    for name, (_, obj) in _custom_builtins.items():
        add_global(name, obj)

    def type_repr(o: Any):
        if o == ():
            # Empty tuple is used for empty tuple type annotation Tuple[()]
            return '()'

        typename = _type_repr(o)

        # This is a generic type, e.g. typing.List[torch.Tensor]
        if hasattr(o, '__origin__'):
            origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
            origin_typename = add_global(_type_repr(origin_type), origin_type)

            # Assign global names for each of the inner type variables.
            args = [type_repr(arg) for arg in o.__args__]

            return f'{origin_typename}[{",".join(args)}]'

        # Common case: this is a regular module name like 'foo.bar.baz'
        return add_global(typename, o)

    # Run through reverse nodes and record the first instance of a use
    # of a given node. This represents the *last* use of the node in the
    # execution order of the program, which we will use to free unused
    # values
    node_to_last_use: Dict[Node, Node] = {}
    user_to_last_uses: Dict[Node, List[Node]] = {}

    def register_last_uses(n: Node, user: Node):
        if n not in node_to_last_use:
            node_to_last_use[n] = user
            user_to_last_uses.setdefault(user, []).append(n)

    for node in reversed(self.nodes):
        map_arg(node.args, lambda n: register_last_uses(n, node))
        map_arg(node.kwargs, lambda n: register_last_uses(n, node))

    # NOTE: we add a variable to distinguish body and ckpt_func
    def delete_unused_values(user: Node, body):
        """
        Delete values after their last use. This ensures that values that are
        not used in the remainder of the code are freed and the memory usage
        of the code is optimal.
        """
        if user.op == 'placeholder':
            return
        if user.op == 'output':
            body.append('\n')
            return
        nodes_to_delete = user_to_last_uses.get(user, [])
        if len(nodes_to_delete):
            to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
            body.append(f'; {to_delete_str}\n')
        else:
            body.append('\n')

    # NOTE: we add a variable to distinguish body and ckpt_func
    def emit_node(node: Node, body):
        maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
        if node.op == 'placeholder':
            assert isinstance(node.target, str)
            maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
            free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
            raw_name = node.target.replace('*', '')
            if raw_name != repr(node):
                body.append(f'{repr(node)} = {raw_name}\n')
            return
        elif node.op == 'call_method':
            assert isinstance(node.target, str)
            body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
                        f'({_format_args(node.args[1:], node.kwargs)})')
            return
        elif node.op == 'call_function':
            assert callable(node.target)
            # pretty print operators
            if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
                assert isinstance(node.args, tuple)
                body.append(f'{repr(node)}{maybe_type_annotation} = '
                            f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
                return
            qualified_name = _get_qualified_name(node.target)
            global_name = add_global(qualified_name, node.target)
            # special case for getattr: node.args could be 2-argument or 3-argument
            # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
            if global_name == 'getattr' and \
                    isinstance(node.args, tuple) and \
                    isinstance(node.args[1], str) and \
                    node.args[1].isidentifier() and \
                    len(node.args) == 2:
                body.append(
                    f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
                return
            body.append(
                f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
            if node.meta.get('is_wrapped', False):
                wrapped_fns.setdefault(global_name)
            return
        elif node.op == 'call_module':
            assert isinstance(node.target, str)
            body.append(f'{repr(node)}{maybe_type_annotation} = '
                        f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
            return
        elif node.op == 'get_attr':
            assert isinstance(node.target, str)
            body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
            return
        elif node.op == 'output':
            if node.type is not None:
                maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
            if self._pytree_info is None:
                body.append(f'return {repr(node.args[0])}')
            else:
                body.append(f'return pytree.tree_unflatten({repr(node.args[0])}, self._out_spec)')
            return
        raise NotImplementedError(f'node: {node.op} {node.target}')

    # Modified for activation checkpointing
    ckpt_func = []

    # if any node has a list of labels for activation_checkpoint, we
    # will use nested type of activation checkpoint codegen
    if any(isinstance(getattr(node, "activation_checkpoint", None), Iterable) for node in self.nodes):
        emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
    else:
        emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)

    if len(body) == 0:
        # If the Graph has no non-placeholder nodes, no lines for the body
        # have been emitted. To continue to have valid Python code, emit a
        # single pass statement
        body.append('pass\n')
    if self._pytree_info is not None:
        # pytree-flattened graphs receive their placeholders via tree_flatten_spec
        orig_args = self._pytree_info.orig_args
        has_orig_self = (orig_args[0] == 'self')
        if has_orig_self:
            free_vars.insert(0, 'self')
        if len(free_vars) > 0:    # pytree has placeholders in it
            body.insert(
                0,
                f"{', '.join(free_vars)}, = fx_pytree.tree_flatten_spec([{', '.join(orig_args)}], self._in_spec)\n")
    else:
        orig_args = free_vars

    if len(wrapped_fns) > 0:
        wrap_name = add_global('wrap', torch.fx.wrap)
        wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
    else:
        wrap_stmts = ''

    ckpt_func = ''.join(ckpt_func)

    # If the original function didn't have self as its first argument, we
    # would have added it.
    if len(orig_args) == 0 or orig_args[0] != 'self':
        orig_args.insert(0, 'self')
    code = ''.join(body)
    code = '\n'.join('    ' + line for line in code.split('\n'))

    # as we need colossalai.utils.checkpoint, we need to import colossalai
    # in forward function
    fn_code = f"""
{wrap_stmts}

{ckpt_func}
def forward({', '.join(orig_args)}){maybe_return_annotation[0]}:
{code}"""
    return PythonCode(fn_code, globals_)
+ def __init__(self) -> None: + super().__init__() + self.linear0 = torch.nn.Linear(4, 4) + self.linear1 = torch.nn.Linear(4, 4) + self.linear2 = torch.nn.Linear(4, 4) + self.linear3 = torch.nn.Linear(4, 4) + self.linear4 = torch.nn.Linear(4, 4) + self.linear5 = torch.nn.Linear(4, 4) + self.linear6 = torch.nn.Linear(4, 4) + + def forward(self, x): + x = self.linear0(x) + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + x = self.linear4(x) + x = self.linear5(x) + x = self.linear6(x) + return x + + +def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule) -> bool: + for m_p, gm_p in zip(m.parameters(), gm.parameters()): + if not torch.allclose(m_p.grad, gm_p.grad): + return False + return True + + +def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.Tensor): + + # test forward + non_fx_out = model(data) + fx_out = gm(data) + assert torch.equal(non_fx_out, fx_out), "fx_out doesn't comply with original output" + + # test barckward + loss0 = non_fx_out.sum() + loss0.backward() + loss1 = fx_out.sum() + loss1.backward() + assert _is_all_gradient_close(model, gm), "gm doesn't have the same gradient as original one" + + +def _run_offload_codegen(rank): + # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly + colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') + + # build model and input + model = MyNet().cuda() + data = torch.rand(4, 4).cuda() + + # trace the module and replace codegen + tracer = ColoTracer(trace_act_ckpt=True) + graph = tracer.trace(model) + codegen = ActivationCheckpointCodeGen() + graph.set_codegen(codegen) + + # annotate the activation offload part + # also annotate the activation_checkpoint so we could test both types + # of input offload + for node in graph.nodes: + if node.name == "linear0": + setattr(node, "activation_offload", [0, True, False]) + if node.name == "linear1": + setattr(node, "activation_offload", 
[0, True, False]) + if node.name == "linear2": + setattr(node, "activation_offload", [1, True, True]) + if node.name == "linear4": + setattr(node, "activation_offload", [2, False, True]) + if node.name == "linear5": + setattr(node, "activation_checkpoint", [0]) + setattr(node, "activation_offload", True) + + gm = ColoGraphModule(copy.deepcopy(model), graph) + gm.recompile() + + # assert we have all the components + code = graph.python_code("self").src + assert "def pack_hook_input(self, x):" in code and \ + "def unpack_hook(self, packed):" in code and \ + "def pack_hook_no_input(self, x):" in code and \ + "setattr(x, 'offload', True)" in code and \ + "setattr(linear3, 'offload', False)" in code and \ + "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \ + "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \ + "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \ + "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code + + _test_fwd_and_bwd(model, gm, data) + gpc.destroy() + + +@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') +def test_act_ckpt_codegen(): + mp.spawn(_run_offload_codegen, nprocs=1) + + +def _run_offload_codegen_torch11(rank): + # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly + colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') + + # build model and input + model = MyNet().cuda() + data = torch.rand(4, 4).cuda() + + # trace the module and replace codegen + tracer = ColoTracer(trace_act_ckpt=True) + graph = tracer.trace(model) + + # replace a bound method of an object + graph._python_code = python_code_with_activation_checkpoint.__get__(graph) + + # annotate the activation offload part + # also annotate the activation_checkpoint so we could test 
both types + # of input offload + for node in graph.nodes: + if node.name == "linear0": + setattr(node, "activation_offload", [0, True, False]) + if node.name == "linear1": + setattr(node, "activation_offload", [0, True, False]) + if node.name == "linear2": + setattr(node, "activation_offload", [1, True, True]) + if node.name == "linear4": + setattr(node, "activation_offload", [2, False, True]) + if node.name == "linear5": + setattr(node, "activation_checkpoint", [0]) + setattr(node, "activation_offload", True) + + gm = ColoGraphModule(copy.deepcopy(model), graph) + gm.recompile() + + # assert we have all the components + code = graph.python_code("self").src + assert "def pack_hook_input(self, x):" in code and \ + "def unpack_hook(self, packed):" in code and \ + "def pack_hook_no_input(self, x):" in code and \ + "setattr(x, 'offload', True)" in code and \ + "setattr(linear3, 'offload', False)" in code and \ + "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \ + "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \ + "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \ + "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code + + _test_fwd_and_bwd(model, gm, data) + gpc.destroy() + + +@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not implemented") +def test_act_ckpt_python_code_torch11(): + mp.spawn(_run_offload_codegen_torch11, nprocs=1) + + +if __name__ == "__main__": + _run_offload_codegen(0) From 87cddf7e147f8db1c9710eb37961c489c09bd5b9 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 27 Oct 2022 16:40:19 +0800 Subject: [PATCH 002/209] rename and remove useless func --- chunk_codegen.py | 398 +++---------------------------------------- chunk_codegen_run.py | 69 +------- 2 files changed, 27 insertions(+), 440 deletions(-) diff --git a/chunk_codegen.py 
b/chunk_codegen.py index 684028c014de..09fda2b988eb 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -12,7 +12,7 @@ CODEGEN_AVAILABLE = False if CODEGEN_AVAILABLE: - __all__ = ['ActivationCheckpointCodeGen'] + __all__ = ['ChunkCodeGen'] else: __all__ = ['python_code_with_activation_checkpoint'] @@ -375,7 +375,7 @@ def emit_ckpt_func(body, body.append(usage) -def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func): +def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func): """Emit code with nested activation checkpoint When we detect some of the node.activation_checkpoint is a List, we will use this function to emit the activation checkpoint codes. @@ -392,21 +392,21 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod end_idx = [item[1] for item in ckpt_regions] # find the offload regions - offload_regions, offload_labels = _find_offload_regions(nodes) - offload_starts = [item[0] for item in offload_regions] - offload_ends = [item[1] for item in offload_regions] - offload_inputs = [] - offload_outputs = [] - within_offload_region = False + chunk_regions, chunk_labels = _find_offload_regions(nodes) + chunk_starts = [item[0] for item in chunk_regions] + chunk_ends = [item[1] for item in chunk_regions] + chunk_inputs = [] + chunk_outputs = [] + within_chunk_region = False node_list = list(nodes) # find the input and output var names for each offload region - for idx, (start, end) in enumerate(offload_regions): + for idx, (start, end) in enumerate(chunk_regions): offload_node_list = node_list[start:end + 1] inputs, outputs = _find_input_and_output_nodes(offload_node_list) - offload_inputs.append(inputs) - offload_outputs.append(outputs) + chunk_inputs.append(inputs) + chunk_outputs.append(outputs) # this flag is to prevent repeated insert of save tensors # hooks definition in ckpt_func @@ -427,10 +427,10 @@ def 
emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod else: node = node_list[node_idx] - if node_idx in offload_starts: - offload_label = offload_labels[offload_starts.index(node_idx)] - _, offload_input, offload_bar = offload_label - within_offload_region = True + if node_idx in chunk_starts: + chunk_label = chunk_labels[chunk_starts.index(node_idx)] + _, chunk_input, chunk_bar = chunk_label + within_chunk_region = True # insert hook functions if needed if not is_hook_inserted: @@ -438,20 +438,20 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n") is_hook_inserted = True - if offload_input and offload_bar: + if chunk_input and chunk_bar: body.append(_gen_save_on_cpu_context()) - elif offload_input: - for par in offload_inputs[offload_label[0]]: + elif chunk_input: + for par in chunk_inputs[chunk_label[0]]: body.append(f"setattr({par}, 'offload', True)\n") body.append(_gen_save_tensors_hooks_context(offload_input=True)) else: - for par in offload_inputs[offload_label[0]]: + for par in chunk_inputs[chunk_label[0]]: body.append(f"setattr({par}, 'offload', False)\n") body.append(_gen_save_tensors_hooks_context(offload_input=False)) - if within_offload_region: + if within_chunk_region: emit_node_func(node, body) body[-1] = ' ' + body[-1] delete_unused_value_func(node, body) @@ -460,150 +460,15 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod emit_node_func(node, body) delete_unused_value_func(node, body) - if node_idx in offload_ends: - within_offload_region = False + if node_idx in chunk_ends: + within_chunk_region = False node_idx += 1 -def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func): - # find the activation checkpoint regions - ckpt_regions = _find_ckpt_regions(nodes) - start_idx = [item[0] for item in ckpt_regions] - end_idx = [item[1] for item in 
ckpt_regions] - input_vars = [] - output_vars = [] - within_ckpt_region = False - - # find the offload regions - offload_regions, offload_labels = _find_offload_regions(nodes) - offload_starts = [item[0] for item in offload_regions] - offload_ends = [item[1] for item in offload_regions] - offload_inputs = [] - offload_outputs = [] - within_offload_region = False - - node_list = list(nodes) - - # use this variable to avoid inserting hook functions - # to ckpt_func repeatedly - is_hook_inserted = False - - # find the input and output var names for each region - for idx, (start, end) in enumerate(ckpt_regions): - ckpt_node_list = node_list[start:end + 1] - inputs, outputs = _find_input_and_output_nodes(ckpt_node_list) - input_vars.append(inputs) - output_vars.append(outputs) - - # find the input and output var names for each offload region - for idx, (start, end) in enumerate(offload_regions): - offload_node_list = node_list[start:end + 1] - inputs, outputs = _find_input_and_output_nodes(offload_node_list) - offload_inputs.append(inputs) - offload_outputs.append(outputs) - - # append code text to body - for idx, node in enumerate(node_list): - # if this is the first node of the ckpt region - # append the ckpt function defition - if idx in start_idx: - label = start_idx.index(idx) - ckpt_fn_def = _gen_ckpt_fn_def(label, input_vars[label]) - ckpt_func.append(f'{ckpt_fn_def}\n') - within_ckpt_region = True - - if idx in offload_starts: - offload_label = offload_labels[offload_starts.index(idx)] - _, offload_input, offload_bar = offload_label - within_offload_region = True - - # insert hook functions if needed - if not is_hook_inserted: - pack_hook, unpack_hook = _gen_saved_tensors_hooks() - ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n") - is_hook_inserted = True - - if offload_input and offload_bar: - body.append(_gen_save_on_cpu_context()) - - elif offload_input: - for par in offload_inputs[offload_label[0]]: - body.append(f"setattr({par}, 'offload', 
True)\n") - body.append(_gen_save_tensors_hooks_context(offload_input=True)) - - else: - for par in offload_inputs[offload_label[0]]: - body.append(f"setattr({par}, 'offload', False)\n") - body.append(_gen_save_tensors_hooks_context(offload_input=False)) - - # NOTE: emit_node does not emit a string with newline. It depends - # on delete_unused_values to append one - # NOTE: currently we separate body and ckpt_func definition - if within_ckpt_region: - emit_node_func(node, ckpt_func) - ckpt_func[-1] = ' ' + ckpt_func[-1] - delete_unused_value_func(node, ckpt_func) - - elif within_offload_region: - emit_node_func(node, body) - body[-1] = ' ' + body[-1] - delete_unused_value_func(node, body) - - else: - emit_node_func(node, body) - delete_unused_value_func(node, body) - - if idx in end_idx: - # if this is the last node of the ckpt region - # generate return statement - label = end_idx.index(idx) - return_statement = _gen_ckpt_output(output_vars[label]) - return_statement = f' {return_statement}\n\n' - ckpt_func.append(return_statement) - - # we need to check if the checkpoint need to offload the input - start_node_idx = start_idx[label] - if hasattr(node_list[start_node_idx], 'activation_offload'): - activation_offload = node_list[start_node_idx].activation_offload - else: - activation_offload = False - - # we need to check if the checkpoint need use_reentrant=False - use_reentrant = True - non_leaf_input = 0 - for var in input_vars[label]: - input_node = next(item for item in node_list if item.name == var) - if input_node.op != "placeholder": - non_leaf_input = 1 - for user in input_node.users: - if hasattr(user, "activation_checkpoint"): - if user.activation_checkpoint == label: - if user.op == "call_module": - if hasattr(user.graph.owning_module.get_submodule(user.target), "inplace"): - use_reentrant = not user.graph.owning_module.get_submodule(user.target).inplace - - elif user.op == "call_function": - if "inplace" in user.kwargs: - use_reentrant = not 
user.kwargs["inplace"] - - # if all the inputs are leaf nodes, we need to set use_reentrant = False - if not non_leaf_input: - use_reentrant = False - - # generate checkpoint function call in a new line - usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label], use_reentrant) - usage += '\n' - body.append(usage) - within_ckpt_region = False - - if idx in offload_ends: - within_offload_region = False - - if CODEGEN_AVAILABLE: - class ActivationCheckpointCodeGen(CodeGen): + class ChunkCodeGen(CodeGen): def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode: free_vars: List[str] = [] @@ -796,10 +661,7 @@ def emit_node(node: Node, body): # if any node has a list of labels for activation_checkpoint, we # will use nested type of activation checkpoint codegen - if any(isinstance(getattr(node, "activation_checkpoint", None), Iterable) for node in nodes): - emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values) - else: - emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values) + emit_code_with_chunk(body, ckpt_func, nodes, emit_node, delete_unused_values) if len(body) == 0: # If the Graph has no non-placeholder nodes, no lines for the body @@ -833,215 +695,3 @@ def emit_node(node: Node, body): {prologue} {code}""" return PythonCode(fn_code, globals_) - -else: - - def python_code_with_activation_checkpoint(self, root_module: str, namespace: _Namespace) -> PythonCode: - """ - This method is copied from the _python_code of torch.fx.graph.Graph. Modifications are made so that it can generate - code for activation checkpoint. - """ - free_vars: List[str] = [] - body: List[str] = [] - globals_: Dict[str, Any] = {} - wrapped_fns: Dict[str, None] = {} - - # Wrap string in list to pass by reference - maybe_return_annotation: List[str] = [''] - - def add_global(name_hint: str, obj: Any): - """Add an obj to be tracked as a global. 
- - We call this for names that reference objects external to the - Graph, like functions or types. - - Returns: the global name that should be used to reference 'obj' in generated source. - """ - if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device - # HACK: workaround for how torch custom ops are registered. We - # can't import them like normal modules so they must retain their - # fully qualified name. - return _get_qualified_name(obj) - - # normalize the name hint to get a proper identifier - global_name = namespace.create_name(name_hint, obj) - - if global_name in globals_: - assert globals_[global_name] is obj - return global_name - globals_[global_name] = obj - return global_name - - # set _custom_builtins here so that we needn't import colossalai in forward - _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai) - - # Pre-fill the globals table with registered builtins. - for name, (_, obj) in _custom_builtins.items(): - add_global(name, obj) - - def type_repr(o: Any): - if o == (): - # Empty tuple is used for empty tuple type annotation Tuple[()] - return '()' - - typename = _type_repr(o) - - # This is a generic type, e.g. typing.List[torch.Tensor] - if hasattr(o, '__origin__'): - origin_type = _origin_type_map.get(o.__origin__, o.__origin__) - origin_typename = add_global(_type_repr(origin_type), origin_type) - - # Assign global names for each of the inner type variables. - args = [type_repr(arg) for arg in o.__args__] - - return f'{origin_typename}[{",".join(args)}]' - - # Common case: this is a regular module name like 'foo.bar.baz' - return add_global(typename, o) - - # Run through reverse nodes and record the first instance of a use - # of a given node. 
This represents the *last* use of the node in the - # execution order of the program, which we will use to free unused - # values - node_to_last_use: Dict[Node, Node] = {} - user_to_last_uses: Dict[Node, List[Node]] = {} - - def register_last_uses(n: Node, user: Node): - if n not in node_to_last_use: - node_to_last_use[n] = user - user_to_last_uses.setdefault(user, []).append(n) - - for node in reversed(self.nodes): - map_arg(node.args, lambda n: register_last_uses(n, node)) - map_arg(node.kwargs, lambda n: register_last_uses(n, node)) - - # NOTE: we add a variable to distinguish body and ckpt_func - def delete_unused_values(user: Node, body): - """ - Delete values after their last use. This ensures that values that are - not used in the remainder of the code are freed and the memory usage - of the code is optimal. - """ - if user.op == 'placeholder': - return - if user.op == 'output': - body.append('\n') - return - nodes_to_delete = user_to_last_uses.get(user, []) - if len(nodes_to_delete): - to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None']) - body.append(f'; {to_delete_str}\n') - else: - body.append('\n') - - # NOTE: we add a variable to distinguish body and ckpt_func - def emit_node(node: Node, body): - maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' - if node.op == 'placeholder': - assert isinstance(node.target, str) - maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}' - free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}') - raw_name = node.target.replace('*', '') - if raw_name != repr(node): - body.append(f'{repr(node)} = {raw_name}\n') - return - elif node.op == 'call_method': - assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}' - f'({_format_args(node.args[1:], node.kwargs)})') - return - elif node.op == 'call_function': - assert callable(node.target) - # pretty print 
operators - if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods: - assert isinstance(node.args, tuple) - body.append(f'{repr(node)}{maybe_type_annotation} = ' - f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}') - return - qualified_name = _get_qualified_name(node.target) - global_name = add_global(qualified_name, node.target) - # special case for getattr: node.args could be 2-argument or 3-argument - # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value - if global_name == 'getattr' and \ - isinstance(node.args, tuple) and \ - isinstance(node.args[1], str) and \ - node.args[1].isidentifier() and \ - len(node.args) == 2: - body.append( - f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}') - return - body.append( - f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})') - if node.meta.get('is_wrapped', False): - wrapped_fns.setdefault(global_name) - return - elif node.op == 'call_module': - assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = ' - f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})') - return - elif node.op == 'get_attr': - assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}') - return - elif node.op == 'output': - if node.type is not None: - maybe_return_annotation[0] = f" -> {type_repr(node.type)}" - if self._pytree_info is None: - body.append(f'return {repr(node.args[0])}') - else: - body.append(f'return pytree.tree_unflatten({repr(node.args[0])}, self._out_spec)') - return - raise NotImplementedError(f'node: {node.op} {node.target}') - - # Modified for activation checkpointing - ckpt_func = [] - - # if any node has a list of labels for activation_checkpoint, we - # will use nested type of activation checkpoint 
codegen - if any(isinstance(getattr(node, "activation_checkpoint", None), Iterable) for node in self.nodes): - emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values) - else: - emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values) - - if len(body) == 0: - # If the Graph has no non-placeholder nodes, no lines for the body - # have been emitted. To continue to have valid Python code, emit a - # single pass statement - body.append('pass\n') - if self._pytree_info is not None: - orig_args = self._pytree_info.orig_args - has_orig_self = (orig_args[0] == 'self') - if has_orig_self: - free_vars.insert(0, 'self') - if len(free_vars) > 0: # pytree has placeholders in it - body.insert( - 0, - f"{', '.join(free_vars)}, = fx_pytree.tree_flatten_spec([{', '.join(orig_args)}], self._in_spec)\n") - else: - orig_args = free_vars - - if len(wrapped_fns) > 0: - wrap_name = add_global('wrap', torch.fx.wrap) - wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns]) - else: - wrap_stmts = '' - - ckpt_func = ''.join(ckpt_func) - - # If the original function didn't have self as its first argument, we - # would have added it. 
- if len(orig_args) == 0 or orig_args[0] != 'self': - orig_args.insert(0, 'self') - code = ''.join(body) - code = '\n'.join(' ' + line for line in code.split('\n')) - - # as we need colossalai.utils.checkpoint, we need to import colossalai - # in forward function - fn_code = f""" -{wrap_stmts} - -{ckpt_func} -def forward({', '.join(orig_args)}){maybe_return_annotation[0]}: -{code}""" - return PythonCode(fn_code, globals_) diff --git a/chunk_codegen_run.py b/chunk_codegen_run.py index 9ac399a29b51..85164bdada96 100644 --- a/chunk_codegen_run.py +++ b/chunk_codegen_run.py @@ -11,7 +11,7 @@ from colossalai.fx.graph_module import ColoGraphModule try: - from chunk_codegen import ActivationCheckpointCodeGen + from chunk_codegen import ChunkCodeGen with_codegen = True except: # fall back to older pytorch version @@ -75,7 +75,7 @@ def _run_offload_codegen(rank): # trace the module and replace codegen tracer = ColoTracer(trace_act_ckpt=True) graph = tracer.trace(model) - codegen = ActivationCheckpointCodeGen() + codegen = ChunkCodeGen() graph.set_codegen(codegen) # annotate the activation offload part @@ -99,15 +99,7 @@ def _run_offload_codegen(rank): # assert we have all the components code = graph.python_code("self").src - assert "def pack_hook_input(self, x):" in code and \ - "def unpack_hook(self, packed):" in code and \ - "def pack_hook_no_input(self, x):" in code and \ - "setattr(x, 'offload', True)" in code and \ - "setattr(linear3, 'offload', False)" in code and \ - "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \ - "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \ - "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \ - "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code + print(code) _test_fwd_and_bwd(model, gm, data) gpc.destroy() @@ -118,60 +110,5 @@ def 
test_act_ckpt_codegen(): mp.spawn(_run_offload_codegen, nprocs=1) -def _run_offload_codegen_torch11(rank): - # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly - colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') - - # build model and input - model = MyNet().cuda() - data = torch.rand(4, 4).cuda() - - # trace the module and replace codegen - tracer = ColoTracer(trace_act_ckpt=True) - graph = tracer.trace(model) - - # replace a bound method of an object - graph._python_code = python_code_with_activation_checkpoint.__get__(graph) - - # annotate the activation offload part - # also annotate the activation_checkpoint so we could test both types - # of input offload - for node in graph.nodes: - if node.name == "linear0": - setattr(node, "activation_offload", [0, True, False]) - if node.name == "linear1": - setattr(node, "activation_offload", [0, True, False]) - if node.name == "linear2": - setattr(node, "activation_offload", [1, True, True]) - if node.name == "linear4": - setattr(node, "activation_offload", [2, False, True]) - if node.name == "linear5": - setattr(node, "activation_checkpoint", [0]) - setattr(node, "activation_offload", True) - - gm = ColoGraphModule(copy.deepcopy(model), graph) - gm.recompile() - - # assert we have all the components - code = graph.python_code("self").src - assert "def pack_hook_input(self, x):" in code and \ - "def unpack_hook(self, packed):" in code and \ - "def pack_hook_no_input(self, x):" in code and \ - "setattr(x, 'offload', True)" in code and \ - "setattr(linear3, 'offload', False)" in code and \ - "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \ - "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \ - "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \ - 
"colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code - - _test_fwd_and_bwd(model, gm, data) - gpc.destroy() - - -@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not implemented") -def test_act_ckpt_python_code_torch11(): - mp.spawn(_run_offload_codegen_torch11, nprocs=1) - - if __name__ == "__main__": _run_offload_codegen(0) From 78cfe4362b4550635f609a8b52a8489c7f9aa564 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Wed, 2 Nov 2022 13:59:48 +0800 Subject: [PATCH 003/209] basic chunk --- chunk_codegen.py | 66 ++++++++++++++++++++++---------------------- chunk_codegen_run.py | 15 +++++----- 2 files changed, 41 insertions(+), 40 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 09fda2b988eb..c605e35f4725 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -46,6 +46,19 @@ def pack_hook_no_input(self, x): return pack_hook, unpack_hook +def _gen_loop_5(to_keep): + context = "chunk_result = []\nfor gen_loop_idx in range(4):\n" + context += " chunk_tensor = " + to_keep + "[gen_loop_idx, :]\n" + return context + + +def _gen_loop_5_final(final_name, to_keep): + context = " chunk_result.append(" + final_name + ")\n" + context += "chunk_result = torch.cat(chunk_result, dim=0); " + to_keep[0] + " = None\n" + context += final_name + " = chunk_result; chunk_result = None\n" + return context + + def _gen_save_tensors_hooks_context(offload_input=True) -> str: """Generate customized saved_tensors_hooks @@ -410,57 +423,40 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v # this flag is to prevent repeated insert of save tensors # hooks definition in ckpt_func - is_hook_inserted = False node_idx = 0 - while 1: + to_keep = [] + while node_idx < len(node_list): # break if we finish the processing all the nodes if node_idx >= len(node_list): break - # process ckpt_regions - if node_idx in start_idx: - ckpt_node_list = 
node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1] - emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func) - node_idx += len(ckpt_node_list) - # process node in forward function else: node = node_list[node_idx] if node_idx in chunk_starts: - chunk_label = chunk_labels[chunk_starts.index(node_idx)] - _, chunk_input, chunk_bar = chunk_label + # save chunk input var, dont delete it + to_keep.extend(node.args[0].name) within_chunk_region = True - - # insert hook functions if needed - if not is_hook_inserted: - pack_hook, unpack_hook = _gen_saved_tensors_hooks() - ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n") - is_hook_inserted = True - - if chunk_input and chunk_bar: - body.append(_gen_save_on_cpu_context()) - - elif chunk_input: - for par in chunk_inputs[chunk_label[0]]: - body.append(f"setattr({par}, 'offload', True)\n") - body.append(_gen_save_tensors_hooks_context(offload_input=True)) - - else: - for par in chunk_inputs[chunk_label[0]]: - body.append(f"setattr({par}, 'offload', False)\n") - body.append(_gen_save_tensors_hooks_context(offload_input=False)) + # add for loop + body.append(_gen_loop_5(to_keep[0])) + # change first node's input to new chunked var + node_args = list(node.args) + node_args[0] = 'chunk_tensor' if within_chunk_region: emit_node_func(node, body) body[-1] = ' ' + body[-1] - delete_unused_value_func(node, body) + delete_unused_value_func(node, body, to_keep) else: emit_node_func(node, body) - delete_unused_value_func(node, body) + if node_idx not in chunk_inputs: + delete_unused_value_func(node, body, to_keep) if node_idx in chunk_ends: + body.append(_gen_loop_5_final(node.name, to_keep)) + to_keep = [] within_chunk_region = False node_idx += 1 @@ -572,7 +568,7 @@ def register_last_uses(n: Node, user: Node): map_arg(node.kwargs, lambda n: register_last_uses(n, node)) # NOTE: we add a variable to distinguish body and ckpt_func - def delete_unused_values(user: Node, body): + def 
delete_unused_values(user: Node, body, to_keep=[]): """ Delete values after their last use. This ensures that values that are not used in the remainder of the code are freed and the memory usage @@ -584,6 +580,9 @@ def delete_unused_values(user: Node, body): body.append('\n') return nodes_to_delete = user_to_last_uses.get(user, []) + for n in nodes_to_delete: + if n.name in to_keep: + nodes_to_delete.remove(n) if len(nodes_to_delete): to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None']) body.append(f'; {to_delete_str}\n') @@ -693,5 +692,6 @@ def emit_node(node: Node, body): {wrap_stmts} {prologue} -{code}""" +{code}""" + print(fn_code) return PythonCode(fn_code, globals_) diff --git a/chunk_codegen_run.py b/chunk_codegen_run.py index 85164bdada96..69b327d4bd5b 100644 --- a/chunk_codegen_run.py +++ b/chunk_codegen_run.py @@ -54,6 +54,7 @@ def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.T # test forward non_fx_out = model(data) fx_out = gm(data) + print(non_fx_out.shape, fx_out.shape) assert torch.equal(non_fx_out, fx_out), "fx_out doesn't comply with original output" # test barckward @@ -86,13 +87,13 @@ def _run_offload_codegen(rank): setattr(node, "activation_offload", [0, True, False]) if node.name == "linear1": setattr(node, "activation_offload", [0, True, False]) - if node.name == "linear2": - setattr(node, "activation_offload", [1, True, True]) - if node.name == "linear4": - setattr(node, "activation_offload", [2, False, True]) - if node.name == "linear5": - setattr(node, "activation_checkpoint", [0]) - setattr(node, "activation_offload", True) + # if node.name == "linear2": + # setattr(node, "activation_offload", [1, True, True]) + # if node.name == "linear4": + # setattr(node, "activation_offload", [2, False, True]) + # if node.name == "linear5": + # setattr(node, "activation_checkpoint", [0]) + # setattr(node, "activation_offload", True) gm = ColoGraphModule(copy.deepcopy(model), graph) gm.recompile() From 
86f2a3147415f2afe53019cd7b9d9414de1510e9 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Wed, 2 Nov 2022 15:12:08 +0800 Subject: [PATCH 004/209] add evoformer --- evoformer/evoformer.py | 47 ++++++++++ evoformer/initializer.py | 29 ++++++ evoformer/kernel.py | 19 ++++ evoformer/msa.py | 95 +++++++++++++++++++ evoformer/ops.py | 176 +++++++++++++++++++++++++++++++++++ evoformer/triangle.py | 192 +++++++++++++++++++++++++++++++++++++++ 6 files changed, 558 insertions(+) create mode 100644 evoformer/evoformer.py create mode 100755 evoformer/initializer.py create mode 100644 evoformer/kernel.py create mode 100644 evoformer/msa.py create mode 100755 evoformer/ops.py create mode 100644 evoformer/triangle.py diff --git a/evoformer/evoformer.py b/evoformer/evoformer.py new file mode 100644 index 000000000000..ef3df2769840 --- /dev/null +++ b/evoformer/evoformer.py @@ -0,0 +1,47 @@ +import torch +import torch.nn as nn + +from .msa import MSAStack +from .ops import OutProductMean +from .triangle import PairStack + + +class EvoformerBlock(nn.Module): + + def __init__(self, d_node, d_pair): + super(EvoformerBlock, self).__init__() + + self.msa_stack = MSAStack(d_node, d_pair, p_drop=0.15) + self.communication = OutProductMean(n_feat=d_node, n_feat_out=d_pair, n_feat_proj=32) + self.pair_stack = PairStack(d_pair=d_pair) + + def forward(self, node, pair): + node = node + self.msa_stack(node, pair) + pair = pair + self.communication(node) + pair = pair + self.pair_stack(pair) + return node, pair + + +class Evoformer(nn.Module): + + def __init__(self, d_node, d_pair): + super(Evoformer, self).__init__() + + self.blocks = nn.ModuleList() + for _ in range(3): + self.blocks.append(EvoformerBlock(d_node, d_pair)) + + def forward(self, node, pair): + for b in self.blocks: + node, pair = b(node, pair) + return node, pair + +def evoformer_base(): + return Evoformer(d_node=256, d_pair=128) + + +def evoformer_large(): + return Evoformer(d_node=512, d_pair=256) + + +__all__ = ['Evoformer', 
'evoformer_base', 'evoformer_large'] diff --git a/evoformer/initializer.py b/evoformer/initializer.py new file mode 100755 index 000000000000..c6ce0659e597 --- /dev/null +++ b/evoformer/initializer.py @@ -0,0 +1,29 @@ +import math + +import numpy as np +import torch.nn as nn + + +def glorot_uniform_af(x, gain=1.0): + """ + initialize tensors the same as xavier_initializer in PyTorch, but the dimensions are different: + In PyTorch: + [feature_out, feature_in, n_head ...] + In Jax: + [... n_head, feature_in, feature_out] + However, there is a feature in original Alphafold2 code that they use the Jax version initializer to initialize tensors like: + [feature_in, n_head, feature_out] + + In this function, we keep this feature to initialize [feature_in, n_head, ..., feature_out] tensors + """ + fan_in, fan_out = x.shape[-2:] + if len(x.shape) > 2: + receptive_field_size = np.prod(x.shape[:-2]) + fan_in *= receptive_field_size + fan_out *= receptive_field_size + std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) + dev = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation + + nn.init.uniform_(x, -dev, dev) + + return x diff --git a/evoformer/kernel.py b/evoformer/kernel.py new file mode 100644 index 000000000000..2655901a2fe9 --- /dev/null +++ b/evoformer/kernel.py @@ -0,0 +1,19 @@ +import torch +import torch.nn.functional as F + + +def bias_sigmod_ele(y, bias, z): + return torch.sigmoid(y + bias) * z + + +def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor, + residual: torch.Tensor, prob: float) -> torch.Tensor: + out = (x + bias) * F.dropout(dropmask, p=prob, training=True) + out = residual + out + return out + + +def bias_ele_dropout_residual(ab: torch.Tensor, b: torch.Tensor, g: torch.Tensor, + dropout_mask: torch.Tensor, Z_raw: torch.Tensor, + prob: float) -> torch.Tensor: + return Z_raw + F.dropout(dropout_mask, p=prob, training=True) * (g * (ab + b)) \ No newline at end of file diff --git a/evoformer/msa.py 
b/evoformer/msa.py new file mode 100644 index 000000000000..ccefa38c48be --- /dev/null +++ b/evoformer/msa.py @@ -0,0 +1,95 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn import LayerNorm + +from .kernel import bias_dropout_add +from .ops import SelfAttention, Transition + + +class MSARowAttentionWithPairBias(nn.Module): + + def __init__(self, d_node, d_pair, c=32, n_head=8, p_drop=0.15): + super(MSARowAttentionWithPairBias, self).__init__() + self.d_node = d_node + self.d_pair = d_pair + self.c = c + self.n_head = n_head + self.p_drop = p_drop + + self.layernormM = LayerNorm(d_node) + self.layernormZ = LayerNorm(d_pair) + + _init_weights = torch.nn.init.normal_(torch.zeros([n_head, d_pair]), + std=1.0 / math.sqrt(d_pair)) + self.linear_b_weights = nn.parameter.Parameter(data=_init_weights, requires_grad=True) + + self.attention = SelfAttention(qkv_dim=d_node, + c=c, + n_head=n_head, + out_dim=d_node, + gating=True, + last_bias_fuse=True) + + self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_node,)), requires_grad=True) + + def forward(self, M_raw, Z): + ## Input projections + M = self.layernormM(M_raw) + Z = self.layernormZ(Z) + b = F.linear(Z, self.linear_b_weights) + b = b.permute(0, 3, 1, 2) + # b = rearrange(b, 'b q k h -> b h q k') + + M = self.attention(M, b) + dropout_mask = torch.ones_like(M[:, 0:1, :, :], device=M.device, dtype=M.dtype) + + return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop) + + +class MSAColumnAttention(nn.Module): + + def __init__(self, d_node, c=32, n_head=8): + super(MSAColumnAttention, self).__init__() + self.d_node = d_node + self.c = c + self.n_head = n_head + + self.layernormM = LayerNorm(d_node) + self.attention = SelfAttention(qkv_dim=d_node, + c=c, + n_head=n_head, + out_dim=d_node, + gating=True) + + def forward(self, M_raw): + M = M_raw.transpose(-2, -3) + M = self.layernormM(M) + + M = self.attention(M) 
+ + M = M.transpose(-2, -3) + return M_raw + M + + +class MSAStack(nn.Module): + + def __init__(self, d_node, d_pair, p_drop=0.15): + super(MSAStack, self).__init__() + + self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(d_node=d_node, + d_pair=d_pair, + p_drop=p_drop) + + self.MSAColumnAttention = MSAColumnAttention(d_node=d_node) + self.MSATransition = Transition(d=d_node) + + def forward(self, node, pair): + node = self.MSARowAttentionWithPairBias(node, pair) + node = self.MSAColumnAttention(node) + node = self.MSATransition(node) + + return node diff --git a/evoformer/ops.py b/evoformer/ops.py new file mode 100755 index 000000000000..ddbba441dd5f --- /dev/null +++ b/evoformer/ops.py @@ -0,0 +1,176 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn import LayerNorm + +from .initializer import glorot_uniform_af +from .kernel import bias_sigmod_ele + + +class DropoutRowwise(nn.Module): + + def __init__(self, p): + super(DropoutRowwise, self).__init__() + self.p = p + self.dropout = nn.Dropout(p=p) + + def forward(self, x): + dropout_mask = torch.ones_like(x[:, 0:1, :, :]) + dropout_mask = self.dropout(dropout_mask) + return dropout_mask * x + + +class DropoutColumnwise(nn.Module): + + def __init__(self, p): + super(DropoutColumnwise, self).__init__() + self.p = p + self.dropout = nn.Dropout(p=p) + + def forward(self, x): + dropout_mask = torch.ones_like(x[:, :, 0:1, :]) + dropout_mask = self.dropout(dropout_mask) + return dropout_mask * x + + +class Transition(nn.Module): + + def __init__(self, d, n=4): + super(Transition, self).__init__() + self.norm = LayerNorm(d) + self.linear1 = Linear(d, n * d, initializer='relu') + self.linear2 = Linear(n * d, d, initializer='zeros') + + def forward(self, src): + x = self.norm(src) + x = self.linear2(F.relu(self.linear1(x))) + return src + x + + +class OutProductMean(nn.Module): + + def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32): + 
super(OutProductMean, self).__init__() + + self.layernormM = LayerNorm(n_feat) + self.linear_a = Linear(n_feat, n_feat_proj) + self.linear_b = Linear(n_feat, n_feat_proj) + + self.o_linear = Linear(n_feat_proj * n_feat_proj, + n_feat_out, + initializer='zero', + use_bias=True) + + def forward(self, M): + M = self.layernormM(M) + left_act = self.linear_a(M) + right_act = self.linear_b(M) + + O = torch.einsum('bsid,bsje->bijde', left_act, right_act).contiguous() + # O = rearrange(O, 'b i j d e -> b i j (d e)') + O = O.reshape(O.shape[0], O.shape[1], O.shape[2], -1) + Z = self.o_linear(O) + + return Z + + +class Linear(nn.Linear): + """ + A Linear layer with built-in nonstandard initializations. Called just + like torch.nn.Linear. + Implements the initializers in 1.11.4, plus some additional ones found + in the code. + """ + + def __init__( + self, + feature_in: int, + feature_out: int, + initializer: str = 'linear', + use_bias: bool = True, + bias_init: float = 0., + ): + super(Linear, self).__init__(feature_in, feature_out, bias=use_bias) + + self.use_bias = use_bias + if initializer == 'linear': + glorot_uniform_af(self.weight, gain=1.0) + elif initializer == 'relu': + glorot_uniform_af(self.weight, gain=2.0) + elif initializer == 'zeros': + nn.init.zeros_(self.weight) + if self.use_bias: + with torch.no_grad(): + self.bias.fill_(bias_init) + + +class SelfAttention(nn.Module): + """ + Multi-Head SelfAttention dealing with [batch_size1, batch_size2, len, dim] tensors + """ + + def __init__(self, qkv_dim, c, n_head, out_dim, gating=True, last_bias_fuse=False): + super(SelfAttention, self).__init__() + self.qkv_dim = qkv_dim + self.c = c + self.n_head = n_head + self.out_dim = out_dim + self.gating = gating + self.last_bias_fuse = last_bias_fuse + + self.scaling = self.c**(-0.5) + + # self.to_qkv = Linear(qkv_dim, 3 * n_head * c, initializer='linear') + self.to_q = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False) + self.to_k = Linear(qkv_dim, n_head * 
c, initializer='linear', use_bias=False) + self.to_v = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False) + + if gating: + self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,))) + self.gating_linear = Linear(qkv_dim, n_head * c, initializer='zero', use_bias=False) + + self.o_linear = Linear(n_head * c, + out_dim, + initializer='zero', + use_bias=(not last_bias_fuse)) + + def forward(self, in_data, nonbatched_bias=None): + """ + :param in_data: [batch_size1, batch_size2, len_qkv, qkv_dim] + :param bias: None or [batch_size1, batch_size2, n_head, len_q, len_kv] + :param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv] + """ + + # qkv = self.to_qkv(in_data).chunk(3, dim=-1) + # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv) + + q = self.to_q(in_data) + k = self.to_k(in_data) + v = self.to_k(in_data) + + # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), + # [q, k, v]) + q, k, v = map(lambda t: t.view(t.shape[0], t.shape[1], t.shape[2], self.n_head, -1).permute(0, 1, 3, 2, 4), + [q, k, v]) + + q = q * self.scaling + + logits = torch.matmul(q, k.transpose(-1, -2)) + + if nonbatched_bias is not None: + logits += nonbatched_bias.unsqueeze(1) + weights = torch.softmax(logits, dim=-1) + # weights = softmax(logits) + + weighted_avg = torch.matmul(weights, v) + # weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)') + weighted_avg = weighted_avg.permute(0, 1, 3, 2, 4) + weighted_avg = weighted_avg.reshape(weighted_avg.shape[0], weighted_avg.shape[1], weighted_avg.shape[2], -1) + + if self.gating: + gate_values = self.gating_linear(in_data) + weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg) + + output = self.o_linear(weighted_avg) + return output diff --git a/evoformer/triangle.py b/evoformer/triangle.py new file mode 100644 index 000000000000..7db0482f5557 --- /dev/null +++ b/evoformer/triangle.py @@ -0,0 
+1,192 @@ +import math + +import torch +import torch.nn as nn +from torch.nn import LayerNorm + +from .kernel import bias_dropout_add, bias_ele_dropout_residual +from .ops import Linear, SelfAttention, Transition + + +def permute_final_dims(tensor, inds): + zero_index = -1 * len(inds) + first_inds = list(range(len(tensor.shape[:zero_index]))) + return tensor.permute(first_inds + [zero_index + i for i in inds]) + + +class TriangleMultiplicationOutgoing(nn.Module): + + def __init__(self, d_pair, p_drop, c=128): + super(TriangleMultiplicationOutgoing, self).__init__() + self.d_pair = d_pair + self.c = c + + self.layernorm1 = LayerNorm(d_pair) + self.left_projection = Linear(d_pair, c) + self.right_projection = Linear(d_pair, c) + self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.) + self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.) + + self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.) + self.layernorm2 = LayerNorm(c) + self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False) + self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True) + + self.p_drop = p_drop + + def forward(self, Z_raw): + Z = self.layernorm1(Z_raw) + left_proj_act = self.left_projection(Z) + right_proj_act = self.right_projection(Z) + + left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z)) + right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z)) + + g = torch.sigmoid(self.output_gate(Z)) + # p = torch.matmul( + # permute_final_dims(left_proj_act, (2, 0, 1)), + # permute_final_dims(right_proj_act, (2, 1, 0)), + # ) + # ab = permute_final_dims(p, (1, 2, 0)) + + ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act) + ab = self.output_projection(self.layernorm2(ab)) + dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype) + return bias_ele_dropout_residual(ab, + self.output_bias, + g, + dropout_mask, + Z_raw, + 
prob=self.p_drop) + + +class TriangleMultiplicationIncoming(nn.Module): + + def __init__(self, d_pair, p_drop, c=128): + super(TriangleMultiplicationIncoming, self).__init__() + self.d_pair = d_pair + self.c = c + + self.layernorm1 = LayerNorm(d_pair) + self.left_projection = Linear(d_pair, c) + self.right_projection = Linear(d_pair, c) + self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.) + self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.) + + self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.) + self.layernorm2 = LayerNorm(c) + self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False) + self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True) + + self.p_drop = p_drop + + def forward(self, Z_raw): + Z = self.layernorm1(Z_raw) + left_proj_act = self.left_projection(Z) + right_proj_act = self.right_projection(Z) + + left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z)) + right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z)) + + g = torch.sigmoid(self.output_gate(Z)) + # p = torch.matmul( + # permute_final_dims(left_proj_act, (2, 1, 0)), + # permute_final_dims(right_proj_act, (2, 0, 1)), + # ) + # ab = permute_final_dims(p, (1, 2, 0)) + + ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act) + ab = self.output_projection(self.layernorm2(ab)) + dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype) + return bias_ele_dropout_residual(ab, + self.output_bias, + g, + dropout_mask, + Z_raw, + prob=self.p_drop) + + +class TriangleAttentionStartingNode(nn.Module): + + def __init__(self, d_pair, p_drop, c=32, n_head=4): + super(TriangleAttentionStartingNode, self).__init__() + self.d_pair = d_pair + self.c = c + self.n_head = n_head + self.p_drop = p_drop + + self.layernorm1 = LayerNorm(d_pair) + _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]), + std=1.0 / 
math.sqrt(d_pair)) + self.linear_b_weights = nn.parameter.Parameter(data=_init_weights) + self.attention = SelfAttention(qkv_dim=d_pair, + c=c, + n_head=n_head, + out_dim=d_pair, + gating=True, + last_bias_fuse=True) + + self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True) + + def forward(self, Z_raw): + Z = self.layernorm1(Z_raw) + b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights) + + Z = self.attention(Z, b) + + dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype) + return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop) + + +class TriangleAttentionEndingNode(nn.Module): + + def __init__(self, d_pair, p_drop, c=32, n_head=4): + super(TriangleAttentionEndingNode, self).__init__() + self.d_pair = d_pair + self.c = c + self.n_head = n_head + self.p_drop = p_drop + + self.layernorm1 = LayerNorm(d_pair) + _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]), + std=1.0 / math.sqrt(d_pair)) + self.linear_b_weights = nn.parameter.Parameter(data=_init_weights) + self.attention = SelfAttention(qkv_dim=d_pair, + c=c, + n_head=n_head, + out_dim=d_pair, + gating=True, + last_bias_fuse=True) + + self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True) + + def forward(self, Z_raw): + Z = Z_raw.transpose(-2, -3) + Z = self.layernorm1(Z) + b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights) + + Z = self.attention(Z, b) + + Z = Z.transpose(-2, -3) + dropout_mask = torch.ones_like(Z[:, :, 0:1, :], device=Z.device, dtype=Z.dtype) + return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop) + + +class PairStack(nn.Module): + + def __init__(self, d_pair, p_drop=0.25): + super(PairStack, self).__init__() + + self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(d_pair, p_drop=p_drop) + self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(d_pair, p_drop=p_drop) + 
self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(d_pair, p_drop=p_drop) + self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(d_pair, p_drop=p_drop) + self.PairTransition = Transition(d=d_pair) + + def forward(self, pair): + pair = self.TriangleMultiplicationOutgoing(pair) + pair = self.TriangleMultiplicationIncoming(pair) + pair = self.TriangleAttentionStartingNode(pair) + pair = self.TriangleAttentionEndingNode(pair) + pair = self.PairTransition(pair) + return pair From 820ea4d056e4ca943ca1d143325fb582128a1b96 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Wed, 2 Nov 2022 15:49:25 +0800 Subject: [PATCH 005/209] align evoformer --- chunk_codegen.py | 143 ++++++----------------------------------- chunk_codegen_run.py | 97 ++++++++++------------------ evoformer/evoformer.py | 7 +- evoformer/kernel.py | 2 +- evoformer/msa.py | 2 +- evoformer/triangle.py | 8 +-- 6 files changed, 67 insertions(+), 192 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index c605e35f4725..cb2a3a8a90ee 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -1,5 +1,6 @@ import colossalai import torch +import copy from typing import List, Callable, Any, Tuple, Dict, Iterable try: @@ -17,74 +18,18 @@ __all__ = ['python_code_with_activation_checkpoint'] -def _gen_saved_tensors_hooks(): - """ - Generate saved tensors hooks - """ - - pack_hook = """def pack_hook_input(self, x): - if getattr(x, "offload", False): - return (x.device, x.cpu()) - else: - return x - -def pack_hook_no_input(self, x): - if getattr(x, "offload", True): - return (x.device, x.cpu()) - else: - return x -""" - - unpack_hook = """def unpack_hook(self, packed): - if isinstance(packed, tuple): - device, tensor = packed - return tensor.to(device) - else: - return packed -""" - - return pack_hook, unpack_hook - - -def _gen_loop_5(to_keep): - context = "chunk_result = []\nfor gen_loop_idx in range(4):\n" - context += " chunk_tensor = " + to_keep + "[gen_loop_idx, :]\n" +def 
_gen_loop_start(to_keep, chunk_size=2): + context = "chunk_result = []; chunk_size = %d\nfor gen_loop_idx in range(0, %s.shape[0], chunk_size):\n" % (chunk_size, to_keep[0]) + context += " chunk_tensor = " + to_keep + "[gen_loop_idx:gen_loop_idx + chunk_size, :]\n" return context -def _gen_loop_5_final(final_name, to_keep): +def _gen_loop_end(final_name, to_keep): context = " chunk_result.append(" + final_name + ")\n" context += "chunk_result = torch.cat(chunk_result, dim=0); " + to_keep[0] + " = None\n" context += final_name + " = chunk_result; chunk_result = None\n" return context - -def _gen_save_tensors_hooks_context(offload_input=True) -> str: - """Generate customized saved_tensors_hooks - - Args: - offload_input (bool, optional): whether we need offload input, if offload_input=False, - we will use self.pack_hook_no_input instead. Defaults to True. - - Returns: - str: generated context - """ - - if offload_input: - context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):\n" - else: - context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):\n" - return context - - -def _gen_save_on_cpu_context(): - """ - Generate save on cpu context - """ - - context = "with torch.autograd.graph.save_on_cpu(pin_memory=True):\n" - return context - def _find_input_and_output_nodes(nodes: List[Node]): """ @@ -112,49 +57,6 @@ def _find_input_and_output_nodes(nodes: List[Node]): return input_nodes, output_nodes -def _find_ckpt_regions(nodes: List[Node]): - """ - Find the checkpoint regions given a list of consecutive nodes. The outputs will be list - of tuples, each tuple is in the form of (start_index, end_index). 
- """ - ckpt_nodes = [] - ckpt_regions = [] - start = -1 - end = -1 - current_region = None - - for idx, node in enumerate(nodes): - if hasattr(node, 'activation_checkpoint'): - act_ckpt_label = node.activation_checkpoint - - # this activation checkpoint label is not set yet - # meaning this is the first node of the activation ckpt region - if current_region is None: - current_region = act_ckpt_label - start = idx - - # if activation checkpoint has changed - # we restart the tracking - # e.g. node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2] - if act_ckpt_label != current_region: - assert start != -1 - ckpt_regions.append((start, idx - 1)) - current_region = act_ckpt_label - start = idx - end = -1 - elif current_region is not None and not hasattr(node, 'activation_checkpoint'): - # used to check the case below - # node ckpt states = [ckpt, ckpt, non-ckpt] - end = idx - 1 - assert start != -1 and end != -1 - ckpt_regions.append((start, end)) - start = end = -1 - current_region = None - else: - pass - return ckpt_regions - - def _find_offload_regions(nodes: List[Node]): """This function is to find the offload regions In pofo algorithm, during annotation, we will annotate the offload region with the @@ -400,12 +302,9 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v emit_node_func: function to emit node delete_unused_value_func: function to remove the unused value """ - ckpt_regions = _find_nested_ckpt_regions(nodes, 0) - start_idx = [item[0] for item in ckpt_regions] - end_idx = [item[1] for item in ckpt_regions] # find the offload regions - chunk_regions, chunk_labels = _find_offload_regions(nodes) + chunk_regions = [(1, 4)] chunk_starts = [item[0] for item in chunk_regions] chunk_ends = [item[1] for item in chunk_regions] chunk_inputs = [] @@ -424,7 +323,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v # this flag is to prevent repeated insert of save tensors # hooks definition in ckpt_func 
node_idx = 0 - to_keep = [] + chunk_var = [] while node_idx < len(node_list): # break if we finish the processing all the nodes if node_idx >= len(node_list): @@ -435,28 +334,30 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v node = node_list[node_idx] if node_idx in chunk_starts: - # save chunk input var, dont delete it - to_keep.extend(node.args[0].name) within_chunk_region = True - # add for loop - body.append(_gen_loop_5(to_keep[0])) - # change first node's input to new chunked var - node_args = list(node.args) - node_args[0] = 'chunk_tensor' + # save chunk input var, dont delete it + chunk_var.append(node.args[0].name) + + # add for loop + body.append(_gen_loop_start(chunk_var[0])) + if within_chunk_region: emit_node_func(node, body) + # replace input var with chunk var + if node_idx in chunk_starts: + body[-1] = body[-1].replace("("+ chunk_var[0] +")", '(chunk_tensor)') body[-1] = ' ' + body[-1] - delete_unused_value_func(node, body, to_keep) + delete_unused_value_func(node, body, chunk_var) else: emit_node_func(node, body) if node_idx not in chunk_inputs: - delete_unused_value_func(node, body, to_keep) + delete_unused_value_func(node, body, chunk_var) if node_idx in chunk_ends: - body.append(_gen_loop_5_final(node.name, to_keep)) - to_keep = [] + body.append(_gen_loop_end(node.name, chunk_var)) + chunk_var = [] within_chunk_region = False node_idx += 1 @@ -580,9 +481,7 @@ def delete_unused_values(user: Node, body, to_keep=[]): body.append('\n') return nodes_to_delete = user_to_last_uses.get(user, []) - for n in nodes_to_delete: - if n.name in to_keep: - nodes_to_delete.remove(n) + nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep] if len(nodes_to_delete): to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None']) body.append(f'; {to_delete_str}\n') diff --git a/chunk_codegen_run.py b/chunk_codegen_run.py index 69b327d4bd5b..7667fa691558 100644 --- a/chunk_codegen_run.py +++ 
b/chunk_codegen_run.py @@ -9,60 +9,39 @@ from colossalai.utils import free_port from colossalai.core import global_context as gpc from colossalai.fx.graph_module import ColoGraphModule - -try: - from chunk_codegen import ChunkCodeGen - with_codegen = True -except: - # fall back to older pytorch version - from chunk_codegen import python_code_with_activation_checkpoint - with_codegen = False - - -class MyNet(torch.nn.Module): - - def __init__(self) -> None: - super().__init__() - self.linear0 = torch.nn.Linear(4, 4) - self.linear1 = torch.nn.Linear(4, 4) - self.linear2 = torch.nn.Linear(4, 4) - self.linear3 = torch.nn.Linear(4, 4) - self.linear4 = torch.nn.Linear(4, 4) - self.linear5 = torch.nn.Linear(4, 4) - self.linear6 = torch.nn.Linear(4, 4) - - def forward(self, x): - x = self.linear0(x) - x = self.linear1(x) - x = self.linear2(x) - x = self.linear3(x) - x = self.linear4(x) - x = self.linear5(x) - x = self.linear6(x) - return x +from evoformer.evoformer import evoformer_base +from chunk_codegen import ChunkCodeGen +with_codegen = True def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule) -> bool: for m_p, gm_p in zip(m.parameters(), gm.parameters()): - if not torch.allclose(m_p.grad, gm_p.grad): + if m_p.grad is not None and not torch.allclose(m_p.grad, gm_p.grad): return False return True -def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.Tensor): +def _is_all_param_close(m: torch.nn.Module, gm: GraphModule) -> bool: + for m_p, gm_p in zip(m.parameters(), gm.parameters()): + if m_p.grad is not None and not torch.allclose(m_p.data, gm_p.data): + return False + return True + +def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair): # test forward - non_fx_out = model(data) - fx_out = gm(data) - print(non_fx_out.shape, fx_out.shape) - assert torch.equal(non_fx_out, fx_out), "fx_out doesn't comply with original output" + non_fx_out = model(node.clone(), pair.clone()) + fx_out = gm(node.clone(), 
pair.clone()) + assert torch.equal(non_fx_out[0], fx_out[0]), "fx_out doesn't comply with original output" + assert torch.equal(non_fx_out[1], fx_out[1]), "fx_out doesn't comply with original output" # test barckward - loss0 = non_fx_out.sum() - loss0.backward() - loss1 = fx_out.sum() - loss1.backward() - assert _is_all_gradient_close(model, gm), "gm doesn't have the same gradient as original one" + # loss0 = non_fx_out[0].sum() + non_fx_out[1].sum() + # loss0.backward() + # loss1 = fx_out[0].sum() + fx_out[1].sum() + # loss1.backward() + # assert _is_all_param_close(model, gm) + # assert _is_all_gradient_close(model, gm), "gm doesn't have the same gradient as original one" def _run_offload_codegen(rank): @@ -70,30 +49,22 @@ def _run_offload_codegen(rank): colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') # build model and input - model = MyNet().cuda() - data = torch.rand(4, 4).cuda() + model = evoformer_base().cuda() + node = torch.randn(1, 16, 32, 256).cuda() + pair = torch.randn(1, 32, 32, 128).cuda() # trace the module and replace codegen tracer = ColoTracer(trace_act_ckpt=True) graph = tracer.trace(model) - codegen = ChunkCodeGen() - graph.set_codegen(codegen) - - # annotate the activation offload part - # also annotate the activation_checkpoint so we could test both types - # of input offload - for node in graph.nodes: - if node.name == "linear0": - setattr(node, "activation_offload", [0, True, False]) - if node.name == "linear1": - setattr(node, "activation_offload", [0, True, False]) - # if node.name == "linear2": - # setattr(node, "activation_offload", [1, True, True]) - # if node.name == "linear4": - # setattr(node, "activation_offload", [2, False, True]) - # if node.name == "linear5": - # setattr(node, "activation_checkpoint", [0]) - # setattr(node, "activation_offload", True) + # codegen = ChunkCodeGen() + # graph.set_codegen(codegen) + + # annotate the chunk part + # for node in graph.nodes: + # 
if node.name == "linear0": + # setattr(node, "activation_offload", [0, True, False]) + # if node.name == "linear1": + # setattr(node, "activation_offload", [0, True, False]) gm = ColoGraphModule(copy.deepcopy(model), graph) gm.recompile() @@ -102,7 +73,7 @@ def _run_offload_codegen(rank): code = graph.python_code("self").src print(code) - _test_fwd_and_bwd(model, gm, data) + _test_fwd_and_bwd(model, gm, node, pair) gpc.destroy() diff --git a/evoformer/evoformer.py b/evoformer/evoformer.py index ef3df2769840..0c5ab952a779 100644 --- a/evoformer/evoformer.py +++ b/evoformer/evoformer.py @@ -28,7 +28,7 @@ def __init__(self, d_node, d_pair): super(Evoformer, self).__init__() self.blocks = nn.ModuleList() - for _ in range(3): + for _ in range(1): self.blocks.append(EvoformerBlock(d_node, d_pair)) def forward(self, node, pair): @@ -36,6 +36,11 @@ def forward(self, node, pair): node, pair = b(node, pair) return node, pair + +def evoformer_tiny(): + return Evoformer(d_node=64, d_pair=32) + + def evoformer_base(): return Evoformer(d_node=256, d_pair=128) diff --git a/evoformer/kernel.py b/evoformer/kernel.py index 2655901a2fe9..26ab5dc53261 100644 --- a/evoformer/kernel.py +++ b/evoformer/kernel.py @@ -8,7 +8,7 @@ def bias_sigmod_ele(y, bias, z): def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor, residual: torch.Tensor, prob: float) -> torch.Tensor: - out = (x + bias) * F.dropout(dropmask, p=prob, training=True) + out = (x + bias) * F.dropout(dropmask, p=prob, training=False) out = residual + out return out diff --git a/evoformer/msa.py b/evoformer/msa.py index ccefa38c48be..cac456638a55 100644 --- a/evoformer/msa.py +++ b/evoformer/msa.py @@ -45,7 +45,7 @@ def forward(self, M_raw, Z): # b = rearrange(b, 'b q k h -> b h q k') M = self.attention(M, b) - dropout_mask = torch.ones_like(M[:, 0:1, :, :], device=M.device, dtype=M.dtype) + dropout_mask = torch.ones_like(M[:, 0:1, :, :]).to(M.device).to(M.dtype) return bias_dropout_add(M, 
self.out_bias, dropout_mask, M_raw, prob=self.p_drop) diff --git a/evoformer/triangle.py b/evoformer/triangle.py index 7db0482f5557..f479469c3836 100644 --- a/evoformer/triangle.py +++ b/evoformer/triangle.py @@ -51,7 +51,7 @@ def forward(self, Z_raw): ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act) ab = self.output_projection(self.layernorm2(ab)) - dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype) + dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype) return bias_ele_dropout_residual(ab, self.output_bias, g, @@ -97,7 +97,7 @@ def forward(self, Z_raw): ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act) ab = self.output_projection(self.layernorm2(ab)) - dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype) + dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype) return bias_ele_dropout_residual(ab, self.output_bias, g, @@ -134,7 +134,7 @@ def forward(self, Z_raw): Z = self.attention(Z, b) - dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype) + dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype) return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop) @@ -168,7 +168,7 @@ def forward(self, Z_raw): Z = self.attention(Z, b) Z = Z.transpose(-2, -3) - dropout_mask = torch.ones_like(Z[:, :, 0:1, :], device=Z.device, dtype=Z.dtype) + dropout_mask = torch.ones_like(Z[:, :, 0:1, :]).to(Z.device).to(Z.dtype) return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop) From f8aeecef46461ff574f51982d03310fa8c57888e Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 3 Nov 2022 14:33:35 +0800 Subject: [PATCH 006/209] add meta --- chunk_codegen.py | 3 +++ chunk_codegen_run.py | 13 +++++++++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index cb2a3a8a90ee..1f336eb2bf35 100644 --- a/chunk_codegen.py +++ 
b/chunk_codegen.py @@ -366,6 +366,9 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v if CODEGEN_AVAILABLE: class ChunkCodeGen(CodeGen): + def __init__(self, meta_graph): + super().__init__() + self.meta_node = list(meta_graph.graph.nodes) def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode: free_vars: List[str] = [] diff --git a/chunk_codegen_run.py b/chunk_codegen_run.py index 7667fa691558..b875b6308f55 100644 --- a/chunk_codegen_run.py +++ b/chunk_codegen_run.py @@ -9,6 +9,8 @@ from colossalai.utils import free_port from colossalai.core import global_context as gpc from colossalai.fx.graph_module import ColoGraphModule +from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata +from colossalai.fx.profiler import MetaTensor from evoformer.evoformer import evoformer_base from chunk_codegen import ChunkCodeGen with_codegen = True @@ -56,9 +58,10 @@ def _run_offload_codegen(rank): # trace the module and replace codegen tracer = ColoTracer(trace_act_ckpt=True) graph = tracer.trace(model) - # codegen = ChunkCodeGen() - # graph.set_codegen(codegen) - + gm_prop = torch.fx.GraphModule(model, graph) + interp = MetaInfoProp(gm_prop) + interp.propagate(MetaTensor(node, fake_device='cuda:0'), MetaTensor(pair, fake_device='cuda:0')) + # annotate the chunk part # for node in graph.nodes: # if node.name == "linear0": @@ -66,7 +69,9 @@ def _run_offload_codegen(rank): # if node.name == "linear1": # setattr(node, "activation_offload", [0, True, False]) - gm = ColoGraphModule(copy.deepcopy(model), graph) + codegen = ChunkCodeGen(gm_prop) + # graph.set_codegen(codegen) + gm = ColoGraphModule(model, graph) gm.recompile() # assert we have all the components From c35718e8db5f3fbbb5749a2a0b5f4b46241a43b1 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 4 Nov 2022 11:18:09 +0800 Subject: [PATCH 007/209] basic chunk --- chunk_codegen.py | 138 +++++++++++++++++++++++++++++-------------- 
chunk_codegen_run.py | 2 +- 2 files changed, 95 insertions(+), 45 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 1f336eb2bf35..1267f64cbbb2 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -18,16 +18,61 @@ __all__ = ['python_code_with_activation_checkpoint'] -def _gen_loop_start(to_keep, chunk_size=2): - context = "chunk_result = []; chunk_size = %d\nfor gen_loop_idx in range(0, %s.shape[0], chunk_size):\n" % (chunk_size, to_keep[0]) - context += " chunk_tensor = " + to_keep + "[gen_loop_idx:gen_loop_idx + chunk_size, :]\n" +def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape): + new_shape = "[" + for idx, i in enumerate(shape): + if idx == chunk_dim: + new_shape += "%s:%s + chunk_size" % (chunk_idx_name, chunk_idx_name) + else: + new_shape += ":" + new_shape += ", " + new_shape = new_shape[:-2] + "]" + return new_shape + + +def _get_first_non_single_dim(shape): + for idx, i in enumerate(shape): + if i == 1: + continue + else: + return idx + raise RuntimeError("can not get first non single dim for shape", shape) + + +def _gen_loop_start(chunk_input_meta, chunk_output, chunk_size=2): + if len(chunk_input_meta) == 1: + node = chunk_input_meta[0] + node_shape = node.meta['tensor_meta'].shape + chunk_dim = _get_first_non_single_dim(node_shape) + chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", node_shape) + out_shape = str(list(chunk_output.meta['tensor_meta'].shape)) + + context = "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor gen_chunk_idx in range" % ( + out_shape, node.name, node.name, chunk_size) + context += "(0, %s.shape[%d], chunk_size):\n" % (node.name, chunk_dim) + context += " chunk_tensor = %s%s\n" % (node.name, chunk_slice) + else: + raise NotImplementedError("input with size %d not implemented" % len(chunk_input_meta)) return context -def _gen_loop_end(final_name, to_keep): - context = " chunk_result.append(" + final_name + ")\n" - context += "chunk_result = 
torch.cat(chunk_result, dim=0); " + to_keep[0] + " = None\n" - context += final_name + " = chunk_result; chunk_result = None\n" +def _gen_loop_end(chunk_outputs, chunk_inputs, node_list): + chunk_inputs_name = chunk_inputs[0].name + chunk_outputs_name = chunk_outputs.name + chunk_outputs_idx = _find_idx_by_name(chunk_outputs_name, node_list) + chunk_output_shape = chunk_outputs.meta['tensor_meta'].shape + chunk_dim = _get_first_non_single_dim(chunk_output_shape) + chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", chunk_output_shape) + context = " chunk_result%s = %s\n" % (chunk_slice, chunk_outputs_name) + + context += chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None" + + # determine if its the last use for chunk input + users_name = list(chunk_inputs[0].users.keys()) + if all([_find_idx_by_name(user.name, node_list) <= chunk_outputs_idx for user in users_name]): + context += "; %s = None" % chunk_inputs_name + + context += "\n" return context @@ -44,7 +89,7 @@ def _find_input_and_output_nodes(nodes: List[Node]): for input_node in node._input_nodes.keys(): node_repr = repr(input_node) if input_node not in nodes and node_repr not in input_nodes: - input_nodes.append(node_repr) + input_nodes.append(input_node) # if a node has a user node which is not in the node list # we treat that user node as the node receiving the current node output @@ -52,11 +97,18 @@ def _find_input_and_output_nodes(nodes: List[Node]): for output_node in node.users.keys(): node_repr = repr(node) if output_node not in nodes and node_repr not in output_nodes: - output_nodes.append(node_repr) + output_nodes.append(output_node) return input_nodes, output_nodes +def _find_idx_by_name(name, nodes_list): + for idx, node in enumerate(nodes_list): + if node.name == name: + return idx + raise RuntimeError("name %s not found in node list" % name) + + def _find_offload_regions(nodes: List[Node]): """This function is to find the offload regions In pofo algorithm, 
during annotation, we will annotate the offload region with the @@ -290,7 +342,7 @@ def emit_ckpt_func(body, body.append(usage) -def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func): +def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func, meta_nodes): """Emit code with nested activation checkpoint When we detect some of the node.activation_checkpoint is a List, we will use this function to emit the activation checkpoint codes. @@ -304,7 +356,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v """ # find the offload regions - chunk_regions = [(1, 4)] + chunk_regions = [(2, 5)] chunk_starts = [item[0] for item in chunk_regions] chunk_ends = [item[1] for item in chunk_regions] chunk_inputs = [] @@ -319,48 +371,46 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v inputs, outputs = _find_input_and_output_nodes(offload_node_list) chunk_inputs.append(inputs) chunk_outputs.append(outputs) - + chunk_inputs_idx = [[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_inputs] + chunk_outputs_idx = [[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_outputs] + chunk_inputs_names = [] + for i in chunk_inputs: + for j in i: + chunk_inputs_names.append(j.name) + # this flag is to prevent repeated insert of save tensors # hooks definition in ckpt_func node_idx = 0 - chunk_var = [] + region_idx = 0 while node_idx < len(node_list): - # break if we finish the processing all the nodes - if node_idx >= len(node_list): - break + node = node_list[node_idx] - # process node in forward function - else: - node = node_list[node_idx] + if node_idx in chunk_starts: + within_chunk_region = True + + # add for loop + chunk_input_meta = [meta_nodes[i] for i in chunk_inputs_idx[region_idx]] + body.append(_gen_loop_start(chunk_input_meta, node_list[chunk_ends[region_idx]])) + if within_chunk_region: + emit_node_func(node, body) + # 
replace input var with chunk var if node_idx in chunk_starts: - within_chunk_region = True - - # save chunk input var, dont delete it - chunk_var.append(node.args[0].name) - - # add for loop - body.append(_gen_loop_start(chunk_var[0])) - - if within_chunk_region: - emit_node_func(node, body) - # replace input var with chunk var - if node_idx in chunk_starts: - body[-1] = body[-1].replace("("+ chunk_var[0] +")", '(chunk_tensor)') - body[-1] = ' ' + body[-1] - delete_unused_value_func(node, body, chunk_var) + body[-1] = body[-1].replace("("+ chunk_inputs[region_idx][0].name +")", '(chunk_tensor)') + body[-1] = ' ' + body[-1] + delete_unused_value_func(node, body, chunk_inputs_names) - else: - emit_node_func(node, body) - if node_idx not in chunk_inputs: - delete_unused_value_func(node, body, chunk_var) + else: + emit_node_func(node, body) + if node_idx not in chunk_inputs: + delete_unused_value_func(node, body, chunk_inputs_names) - if node_idx in chunk_ends: - body.append(_gen_loop_end(node.name, chunk_var)) - chunk_var = [] - within_chunk_region = False + if node_idx in chunk_ends: + body.append(_gen_loop_end(node, chunk_inputs[region_idx], node_list)) + within_chunk_region = False + region_idx += 1 - node_idx += 1 + node_idx += 1 if CODEGEN_AVAILABLE: @@ -562,7 +612,7 @@ def emit_node(node: Node, body): # if any node has a list of labels for activation_checkpoint, we # will use nested type of activation checkpoint codegen - emit_code_with_chunk(body, ckpt_func, nodes, emit_node, delete_unused_values) + emit_code_with_chunk(body, ckpt_func, nodes, emit_node, delete_unused_values, self.meta_node) if len(body) == 0: # If the Graph has no non-placeholder nodes, no lines for the body diff --git a/chunk_codegen_run.py b/chunk_codegen_run.py index b875b6308f55..547b983a9c0c 100644 --- a/chunk_codegen_run.py +++ b/chunk_codegen_run.py @@ -70,7 +70,7 @@ def _run_offload_codegen(rank): # setattr(node, "activation_offload", [0, True, False]) codegen = ChunkCodeGen(gm_prop) - 
# graph.set_codegen(codegen) + graph.set_codegen(codegen) gm = ColoGraphModule(model, graph) gm.recompile() From d95cfe26222427e483df7f23f4bb208cec6ae4c3 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 7 Nov 2022 18:26:13 +0800 Subject: [PATCH 008/209] basic memory --- chunk_codegen.py | 83 ++++++++++++++++++++++++++++++++++++++++++-- chunk_codegen_run.py | 20 +++++------ 2 files changed, 90 insertions(+), 13 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 1267f64cbbb2..4ca33a4d5914 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -6,6 +6,7 @@ try: from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin + from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, parameter_size, activation_size CODEGEN_AVAILABLE = True except: from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args, _CustomBuiltin @@ -18,6 +19,82 @@ __all__ = ['python_code_with_activation_checkpoint'] +def _get_meta_node_size(x): + x = x.meta['tensor_meta'] + x = x.numel * torch.tensor([], dtype=x.dtype).element_size() + return x + + +def _get_output_node_size(n): + fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')} + return activation_size(fwd_out) + + +def _get_delete_node_size(user, user_to_last_uses): + if user.op in ('placeholder', 'output'): + return 0 + nodes_to_delete = user_to_last_uses.get(user, []) + if len(nodes_to_delete): + delete_size = sum([_get_output_node_size(i) for i in nodes_to_delete]) + return delete_size + return 0 + + +def _get_last_usr(nodes): + node_to_last_use: Dict[Node, Node] = {} + user_to_last_uses: Dict[Node, List[Node]] = {} + + def register_last_uses(n: Node, user: Node): + if n not 
in node_to_last_use: + node_to_last_use[n] = user + user_to_last_uses.setdefault(user, []).append(n) + + for node in reversed(nodes): + map_arg(node.args, lambda n: register_last_uses(n, node)) + map_arg(node.kwargs, lambda n: register_last_uses(n, node)) + return user_to_last_uses + + +def _estimate_inference_mem(gm: torch.fx.GraphModule): + act_memory = 0 + act_memory_peak_log = [] + act_memory_after_node_log = [] + user_to_last_uses = _get_last_usr(list(gm.graph.nodes)) + for node in gm.graph.nodes: + # if node is placeholder, just add the size of the node + if node.op == 'placeholder': + act_memory += _get_meta_node_size(node) + # skip output + elif node.op == 'output': + continue + # node is an operation, calculate tmp, output node and delete node memory + else: + # forward memory + act_memory += calculate_fwd_tmp(node) + # act_memory += calculate_fwd_out(node) + act_memory += _get_output_node_size(node) + # record max act memory + act_memory_peak_log.append(act_memory) + # delete useless memory + act_memory -= calculate_fwd_tmp(node) + act_memory -= _get_delete_node_size(node, user_to_last_uses) + act_memory_after_node_log.append(act_memory) + + act_memory_peak_log = [float(i) / (1024 ** 2) for i in act_memory_peak_log] + param_memory = parameter_size(gm) + return (act_memory + param_memory) / (1024 ** 2), param_memory / (1024 ** 2) + + +def _estimate_chunk_forward_mem(gm: torch.fx.GraphModule, start_node, end_node, chunk_size): + node_size = 0 + param_size = 0 + for node in gm.graph.nodes: + node_size += calculate_fwd_tmp(node) + node_size += calculate_fwd_out(node) + param_size = parameter_size(gm) + return (node_size + param_size) / 1024**2, param_size / 1024**2 + + def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape): new_shape = "[" for idx, i in enumerate(shape): @@ -342,7 +419,7 @@ def emit_ckpt_func(body, body.append(usage) -def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func, meta_nodes): +def 
emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func, meta_nodes, meta_graph): """Emit code with nested activation checkpoint When we detect some of the node.activation_checkpoint is a List, we will use this function to emit the activation checkpoint codes. @@ -364,6 +441,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v within_chunk_region = False node_list = list(nodes) + _estimate_inference_mem(meta_graph) # find the input and output var names for each offload region for idx, (start, end) in enumerate(chunk_regions): @@ -418,6 +496,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v class ChunkCodeGen(CodeGen): def __init__(self, meta_graph): super().__init__() + self.meta_graph = meta_graph self.meta_node = list(meta_graph.graph.nodes) def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode: @@ -612,7 +691,7 @@ def emit_node(node: Node, body): # if any node has a list of labels for activation_checkpoint, we # will use nested type of activation checkpoint codegen - emit_code_with_chunk(body, ckpt_func, nodes, emit_node, delete_unused_values, self.meta_node) + emit_code_with_chunk(body, ckpt_func, nodes, emit_node, delete_unused_values, self.meta_node, self.meta_graph) if len(body) == 0: # If the Graph has no non-placeholder nodes, no lines for the body diff --git a/chunk_codegen_run.py b/chunk_codegen_run.py index 547b983a9c0c..1ab7d958b0a9 100644 --- a/chunk_codegen_run.py +++ b/chunk_codegen_run.py @@ -2,6 +2,7 @@ import torch import torch.nn.functional as F import pytest +import torch.fx import torch.multiprocessing as mp from torch.fx import GraphModule from colossalai.fx import ColoTracer @@ -56,18 +57,15 @@ def _run_offload_codegen(rank): pair = torch.randn(1, 32, 32, 128).cuda() # trace the module and replace codegen - tracer = ColoTracer(trace_act_ckpt=True) - graph = tracer.trace(model) - gm_prop = 
torch.fx.GraphModule(model, graph) - interp = MetaInfoProp(gm_prop) + graph = ColoTracer().trace(model, meta_args={'node': node.to(torch.device('meta')), 'pair': pair.to(torch.device('meta'))}) + gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace + interp = MetaInfoProp(gm_prop) + interp.propagate(MetaTensor(node, fake_device='cuda:0'), MetaTensor(pair, fake_device='cuda:0')) + + # now run it twice to get meta info in graph module, not necessary + gm = torch.fx.GraphModule(model, graph) + interp = MetaInfoProp(gm) interp.propagate(MetaTensor(node, fake_device='cuda:0'), MetaTensor(pair, fake_device='cuda:0')) - - # annotate the chunk part - # for node in graph.nodes: - # if node.name == "linear0": - # setattr(node, "activation_offload", [0, True, False]) - # if node.name == "linear1": - # setattr(node, "activation_offload", [0, True, False]) codegen = ChunkCodeGen(gm_prop) graph.set_codegen(codegen) From 12301dd2e9a1889fe76c6ab719aff1404e92aea0 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Tue, 8 Nov 2022 10:34:14 +0800 Subject: [PATCH 009/209] finish basic inference memory estimation --- chunk_codegen.py | 11 +++++++++++ chunk_codegen_run.py | 14 ++++++++++++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 4ca33a4d5914..01b29cb33d43 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -64,6 +64,8 @@ def _estimate_inference_mem(gm: torch.fx.GraphModule): # if node is placeholder, just add the size of the node if node.op == 'placeholder': act_memory += _get_meta_node_size(node) + act_memory_peak_log.append(act_memory) + act_memory_after_node_log.append(act_memory) # skip output elif node.op == 'output': continue @@ -81,6 +83,15 @@ def _estimate_inference_mem(gm: torch.fx.GraphModule): act_memory_after_node_log.append(act_memory) act_memory_peak_log = [float(i) / (1024 ** 2) for i in act_memory_peak_log] + act_memory_after_node_log = [float(i) / (1024 ** 2) for i in act_memory_after_node_log] 
+ + # for i in act_memory_peak_log: + # print("%.2f " % i, end='') + # print("\n") + # for i in act_memory_after_node_log: + # print("%.2f " % i, end='') + # print("\n") + param_memory = parameter_size(gm) return (act_memory + param_memory) / (1024 ** 2), param_memory / (1024 ** 2) diff --git a/chunk_codegen_run.py b/chunk_codegen_run.py index 1ab7d958b0a9..cc975f2eaf84 100644 --- a/chunk_codegen_run.py +++ b/chunk_codegen_run.py @@ -32,9 +32,19 @@ def _is_all_param_close(m: torch.nn.Module, gm: GraphModule) -> bool: def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair): + # now_mem = torch.cuda.memory_allocated() / 1024**2 + # max_mem = torch.cuda.max_memory_allocated() / 1024**2 + # print("now:%.2f max:%.2f" %(torch.cuda.memory_allocated() / 1024**2, torch.cuda.max_memory_allocated() / 1024**2)) + # with torch.no_grad(): + # fx_out = gm(node, pair) + # new_now_mem = torch.cuda.memory_allocated() / 1024**2 + # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2 + # print("now:%.2f max:%.2f" %(new_now_mem - now_mem, new_max_mem - max_mem)) + # test forward - non_fx_out = model(node.clone(), pair.clone()) - fx_out = gm(node.clone(), pair.clone()) + with torch.no_grad(): + non_fx_out = model(node, pair) + fx_out = gm(node, pair) assert torch.equal(non_fx_out[0], fx_out[0]), "fx_out doesn't comply with original output" assert torch.equal(non_fx_out[1], fx_out[1]), "fx_out doesn't comply with original output" From 8cca684c5684ffb0ac0b68d63df3cbde848d3d08 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Tue, 8 Nov 2022 14:41:57 +0800 Subject: [PATCH 010/209] finish memory estimation --- chunk_codegen.py | 103 ++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 88 insertions(+), 15 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 01b29cb33d43..baf207795b60 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -85,25 +85,97 @@ def _estimate_inference_mem(gm: torch.fx.GraphModule): act_memory_peak_log = [float(i) / 
(1024 ** 2) for i in act_memory_peak_log] act_memory_after_node_log = [float(i) / (1024 ** 2) for i in act_memory_after_node_log] - # for i in act_memory_peak_log: - # print("%.2f " % i, end='') - # print("\n") - # for i in act_memory_after_node_log: - # print("%.2f " % i, end='') - # print("\n") + print("no chunk") + _print_mem_log(act_memory_peak_log, "peak") + _print_mem_log(act_memory_after_node_log, "after") param_memory = parameter_size(gm) return (act_memory + param_memory) / (1024 ** 2), param_memory / (1024 ** 2) -def _estimate_chunk_forward_mem(gm: torch.fx.GraphModule, start_node, end_node, chunk_size): - node_size = 0 - param_size = 0 - for node in gm.graph.nodes: - node_size += calculate_fwd_tmp(node) - node_size += calculate_fwd_out(node) - param_size = parameter_size(gm) - return (node_size + param_size) / 1024**2, param_size / 1024**2 +def _get_chunk_ratio(node, chunk_dim, chunk_size): + shape = node.meta['tensor_meta'].shape + chunk_ratio = float(chunk_size) / shape[chunk_dim] + return chunk_ratio + + +def _get_chunk_delete_node_size(user, user_to_last_uses, chunk_ratio, node_list, start_node, end_node): + if user.op in ('placeholder', 'output'): + return 0 + nodes_to_delete = user_to_last_uses.get(user, []) + delete_size = 0 + for n in nodes_to_delete: + node_idx = _find_idx_by_name(n.name, node_list) + if start_node <= node_idx < end_node: + delete_size += _get_output_node_size(n) * chunk_ratio + return delete_size + + +def _print_mem_log(log, title=None): + if title: + print("%-8s" % title, end=' ') + for i in log: + print("%.2f " % i, end='') + print("") + + +def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nodes, chunk_dims, chunk_sizes): + act_memory = 0 + act_memory_peak_log = [] + act_memory_after_node_log = [] + user_to_last_uses = _get_last_usr(list(gm.graph.nodes)) + within_chunk = False + region_idx = 0 + chunk_ratio = 1 # use it to estimate chunk mem + node_list = list(gm.graph.nodes) + + for idx, node in 
enumerate(node_list): + # if node in chunk start nodes, change chunk ratio and add chunk_tensor + if idx in start_nodes: + within_chunk = True + chunk_ratio = _get_chunk_ratio(node, chunk_dims[region_idx], chunk_sizes[region_idx]) + act_memory += _get_output_node_size(node_list[end_nodes[region_idx]]) + + # if node is placeholder, just add the size of the node + if node.op == 'placeholder': + act_memory += _get_meta_node_size(node) * chunk_ratio + act_memory_peak_log.append(act_memory) + # skip output + elif node.op == 'output': + continue + # node is an operation, calculate tmp, output node and delete node memory + else: + # forward memory + act_memory += calculate_fwd_tmp(node) * chunk_ratio + # act_memory += calculate_fwd_out(node) + act_memory += _get_output_node_size(node) * chunk_ratio + # record max act memory + act_memory_peak_log.append(act_memory) + # delete useless memory + act_memory -= calculate_fwd_tmp(node) * chunk_ratio + if within_chunk: + act_memory -= _get_chunk_delete_node_size( + node, user_to_last_uses, chunk_ratio, node_list, start_nodes[region_idx], end_nodes[region_idx]) + else: + act_memory -= _get_delete_node_size(node, user_to_last_uses) + + if idx in end_nodes: + act_memory -= _get_output_node_size(node) * chunk_ratio + within_chunk = False + chunk_ratio = 1 + region_idx += 1 + + act_memory_after_node_log.append(act_memory) + + act_memory_peak_log = [float(i) / (1024 ** 2) for i in act_memory_peak_log] + act_memory_after_node_log = [float(i) / (1024 ** 2) for i in act_memory_after_node_log] + + print("chunk") + _print_mem_log(act_memory_peak_log, "peak") + _print_mem_log(act_memory_after_node_log, "after") + + param_memory = parameter_size(gm) + return (act_memory + param_memory) / (1024 ** 2), param_memory / (1024 ** 2) def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape): @@ -444,7 +516,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v """ # find the offload regions - chunk_regions = [(2, 5)] 
+ chunk_regions = [(2, 6)] chunk_starts = [item[0] for item in chunk_regions] chunk_ends = [item[1] for item in chunk_regions] chunk_inputs = [] @@ -452,6 +524,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v within_chunk_region = False node_list = list(nodes) + _estimate_chunk_inference_mem(meta_graph, chunk_starts, chunk_ends, [1], [2]) _estimate_inference_mem(meta_graph) # find the input and output var names for each offload region From 22f9c60b6bea147c38127f5a4420a91ab73dc84b Mon Sep 17 00:00:00 2001 From: oahzxl Date: Wed, 9 Nov 2022 17:50:39 +0800 Subject: [PATCH 011/209] fix bug --- evoformer/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/evoformer/ops.py b/evoformer/ops.py index ddbba441dd5f..611b7b0fe777 100755 --- a/evoformer/ops.py +++ b/evoformer/ops.py @@ -147,7 +147,7 @@ def forward(self, in_data, nonbatched_bias=None): q = self.to_q(in_data) k = self.to_k(in_data) - v = self.to_k(in_data) + v = self.to_v(in_data) # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), # [q, k, v]) From d7634af5c031aa9f4faaf6ee5ea0c1662d6c6f25 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 11 Nov 2022 15:43:03 +0800 Subject: [PATCH 012/209] finish memory estimation --- chunk_codegen.py | 107 ++++++++++++++++++++++++++++--------------- chunk_codegen_run.py | 20 ++++---- 2 files changed, 80 insertions(+), 47 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index baf207795b60..c8bb433ef6b5 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -55,15 +55,49 @@ def register_last_uses(n: Node, user: Node): return user_to_last_uses +def _delete_free_var_from_last_use(user_to_last_uses): + for key, value in user_to_last_uses.items(): + for n in value: + if n.op == 'placeholder': + user_to_last_uses[key].remove(n) + + +def _get_contiguous_memory(node, not_contiguous_list, delete=False): + mem = 0 + not_contiguous_ops = ['transpose', 'permute'] + + if node.op == 
'call_function' and 'matmul' in node.name: + for n in node.args: + if n in not_contiguous_list: + # matmul won't change origin tensor, but create a tmp copy + mem += _get_output_node_size(n) + elif node.op == 'call_module': + for n in node.args: + if n in not_contiguous_list: + # module will just make origin tensor to contiguous + if delete: + not_contiguous_list.remove(n) + elif node.op == 'call_method' and any(i in node.name for i in not_contiguous_ops): + if node not in not_contiguous_list: + not_contiguous_list.append(node) + elif any(i in node.args for i in not_contiguous_list): + if node not in not_contiguous_list: + not_contiguous_list.append(node) + + return mem + + def _estimate_inference_mem(gm: torch.fx.GraphModule): - act_memory = 0 + act_memory = 0.0 act_memory_peak_log = [] act_memory_after_node_log = [] + not_contiguous_list = [] user_to_last_uses = _get_last_usr(list(gm.graph.nodes)) + _delete_free_var_from_last_use(user_to_last_uses) for node in gm.graph.nodes: # if node is placeholder, just add the size of the node if node.op == 'placeholder': - act_memory += _get_meta_node_size(node) + act_memory += _get_meta_node_size(node) / (1024 ** 2) act_memory_peak_log.append(act_memory) act_memory_after_node_log.append(act_memory) # skip output @@ -72,25 +106,21 @@ def _estimate_inference_mem(gm: torch.fx.GraphModule): # node is an operation, calculate tmp, output node and delete node memory else: # forward memory - act_memory += calculate_fwd_tmp(node) - # act_memory += calculate_fwd_out(node) - act_memory += _get_output_node_size(node) + act_memory += _get_contiguous_memory(node, not_contiguous_list) / (1024 ** 2) + act_memory += _get_output_node_size(node) / (1024 ** 2) # record max act memory act_memory_peak_log.append(act_memory) # delete useless memory - act_memory -= calculate_fwd_tmp(node) - act_memory -= _get_delete_node_size(node, user_to_last_uses) + act_memory -= _get_delete_node_size(node, user_to_last_uses) / (1024 ** 2) + act_memory -= 
_get_contiguous_memory(node, not_contiguous_list, delete=True) / (1024 ** 2) act_memory_after_node_log.append(act_memory) - act_memory_peak_log = [float(i) / (1024 ** 2) for i in act_memory_peak_log] - act_memory_after_node_log = [float(i) / (1024 ** 2) for i in act_memory_after_node_log] - print("no chunk") - _print_mem_log(act_memory_peak_log, "peak") - _print_mem_log(act_memory_after_node_log, "after") + _print_mem_log(act_memory_peak_log, list(gm.graph.nodes), "peak") + _print_mem_log(act_memory_after_node_log, list(gm.graph.nodes), "after") param_memory = parameter_size(gm) - return (act_memory + param_memory) / (1024 ** 2), param_memory / (1024 ** 2) + return act_memory + param_memory, param_memory def _get_chunk_ratio(node, chunk_dim, chunk_size): @@ -111,19 +141,23 @@ def _get_chunk_delete_node_size(user, user_to_last_uses, chunk_ratio, node_list, return delete_size -def _print_mem_log(log, title=None): +def _print_mem_log(log, nodes, title=None): if title: - print("%-8s" % title, end=' ') - for i in log: - print("%.2f " % i, end='') - print("") + print(title) + for idx, (l, n) in enumerate(zip(log, nodes)): + print("%s:%.2f \t" % (n.name, l), end='') + if (idx + 1) % 3 == 0: + print("") + print("\n") def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nodes, chunk_dims, chunk_sizes): - act_memory = 0 + act_memory = 0.0 act_memory_peak_log = [] act_memory_after_node_log = [] + not_contiguous_list = [] user_to_last_uses = _get_last_usr(list(gm.graph.nodes)) + _delete_free_var_from_last_use(user_to_last_uses) within_chunk = False region_idx = 0 chunk_ratio = 1 # use it to estimate chunk mem @@ -134,11 +168,11 @@ def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nod if idx in start_nodes: within_chunk = True chunk_ratio = _get_chunk_ratio(node, chunk_dims[region_idx], chunk_sizes[region_idx]) - act_memory += _get_output_node_size(node_list[end_nodes[region_idx]]) + act_memory += 
_get_output_node_size(node_list[end_nodes[region_idx]]) / (1024 ** 2) # if node is placeholder, just add the size of the node if node.op == 'placeholder': - act_memory += _get_meta_node_size(node) * chunk_ratio + act_memory += _get_meta_node_size(node) * chunk_ratio / (1024 ** 2) act_memory_peak_log.append(act_memory) # skip output elif node.op == 'output': @@ -146,36 +180,33 @@ def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nod # node is an operation, calculate tmp, output node and delete node memory else: # forward memory - act_memory += calculate_fwd_tmp(node) * chunk_ratio - # act_memory += calculate_fwd_out(node) - act_memory += _get_output_node_size(node) * chunk_ratio + act_memory += _get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024 ** 2) + act_memory += _get_output_node_size(node) * chunk_ratio / (1024 ** 2) # record max act memory act_memory_peak_log.append(act_memory) # delete useless memory - act_memory -= calculate_fwd_tmp(node) * chunk_ratio + act_memory -= _get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio / (1024 ** 2) if within_chunk: act_memory -= _get_chunk_delete_node_size( - node, user_to_last_uses, chunk_ratio, node_list, start_nodes[region_idx], end_nodes[region_idx]) + node, user_to_last_uses, chunk_ratio, node_list, + start_nodes[region_idx], end_nodes[region_idx]) / (1024 ** 2) else: - act_memory -= _get_delete_node_size(node, user_to_last_uses) + act_memory -= _get_delete_node_size(node, user_to_last_uses) / (1024 ** 2) if idx in end_nodes: - act_memory -= _get_output_node_size(node) * chunk_ratio + act_memory -= _get_output_node_size(node) * chunk_ratio / (1024 ** 2) within_chunk = False chunk_ratio = 1 region_idx += 1 act_memory_after_node_log.append(act_memory) - act_memory_peak_log = [float(i) / (1024 ** 2) for i in act_memory_peak_log] - act_memory_after_node_log = [float(i) / (1024 ** 2) for i in act_memory_after_node_log] - print("chunk") - 
_print_mem_log(act_memory_peak_log, "peak") - _print_mem_log(act_memory_after_node_log, "after") - + _print_mem_log(act_memory_peak_log, node_list, "peak") + _print_mem_log(act_memory_after_node_log, node_list, "after") + param_memory = parameter_size(gm) - return (act_memory + param_memory) / (1024 ** 2), param_memory / (1024 ** 2) + return act_memory + param_memory, param_memory def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape): @@ -516,7 +547,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v """ # find the offload regions - chunk_regions = [(2, 6)] + chunk_regions = [(58, 62)] chunk_starts = [item[0] for item in chunk_regions] chunk_ends = [item[1] for item in chunk_regions] chunk_inputs = [] @@ -683,7 +714,9 @@ def register_last_uses(n: Node, user: Node): for node in reversed(nodes): map_arg(node.args, lambda n: register_last_uses(n, node)) map_arg(node.kwargs, lambda n: register_last_uses(n, node)) - + + _delete_free_var_from_last_use(user_to_last_uses) + # NOTE: we add a variable to distinguish body and ckpt_func def delete_unused_values(user: Node, body, to_keep=[]): """ diff --git a/chunk_codegen_run.py b/chunk_codegen_run.py index cc975f2eaf84..39363a80abcb 100644 --- a/chunk_codegen_run.py +++ b/chunk_codegen_run.py @@ -32,14 +32,14 @@ def _is_all_param_close(m: torch.nn.Module, gm: GraphModule) -> bool: def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair): - # now_mem = torch.cuda.memory_allocated() / 1024**2 - # max_mem = torch.cuda.max_memory_allocated() / 1024**2 - # print("now:%.2f max:%.2f" %(torch.cuda.memory_allocated() / 1024**2, torch.cuda.max_memory_allocated() / 1024**2)) - # with torch.no_grad(): - # fx_out = gm(node, pair) - # new_now_mem = torch.cuda.memory_allocated() / 1024**2 - # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2 - # print("now:%.2f max:%.2f" %(new_now_mem - now_mem, new_max_mem - max_mem)) + now_mem = torch.cuda.memory_allocated() / 1024**2 + 
with torch.no_grad(): + node0 = node.clone() + pair0 = pair.clone() + node1, pair1 = gm(node0, pair0) + new_now_mem = torch.cuda.memory_allocated() / 1024**2 + new_max_mem = torch.cuda.max_memory_allocated() / 1024**2 + print("now:%.2f max:%.2f" %(new_now_mem - now_mem, new_max_mem - now_mem)) # test forward with torch.no_grad(): @@ -63,8 +63,8 @@ def _run_offload_codegen(rank): # build model and input model = evoformer_base().cuda() - node = torch.randn(1, 16, 32, 256).cuda() - pair = torch.randn(1, 32, 32, 128).cuda() + node = torch.randn(1, 100, 300, 256).cuda() + pair = torch.randn(1, 300, 300, 128).cuda() # trace the module and replace codegen graph = ColoTracer().trace(model, meta_args={'node': node.to(torch.device('meta')), 'pair': pair.to(torch.device('meta'))}) From 1607d04e81530a3de96ce064b961c2b10ed7067a Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 14 Nov 2022 16:02:47 +0800 Subject: [PATCH 013/209] add part of index tracer --- chunk_codegen.py | 119 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 119 insertions(+) diff --git a/chunk_codegen.py b/chunk_codegen.py index c8bb433ef6b5..4b8882afc105 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -19,6 +19,123 @@ __all__ = ['python_code_with_activation_checkpoint'] +class NodeIndexTracer(object): + def __init__(self, gm) -> None: + self.gm = gm + self.nodes_list = list(gm.graph.nodes) + self.idx_trace_list = [{'idx': [], 'compute': []} for _ in range(len(self.nodes_list))] + self.idx_trace_equal = [] + self.idx_count = 1 + + def add_index(self): + self.idx_count += 1 + return self.idx_count - 1 + + def inherit_computation(self, node_from, node_to): + _, compute_from = self.find_trace_from_node(node_from) + idx_to, compute_to = self.find_trace_from_node(node_to) + for i in compute_from: + if i in idx_to: + compute_to.append(i) + + def mark_idx_equal(self, idx1, idx2): + self.idx_trace_equal.append((idx1, idx2)) + + def mark_computation(self, node, idx, dim): + input_node_idx_trace 
= self.find_idx_trace_from_node(node) + if isinstance(dim, int): + dim = [dim] + for d in dim: + cur_idx = input_node_idx_trace[d] + self.idx_trace_list[idx]['compute'].append(cur_idx) + + def find_trace_from_node(self, node): + node_idx = _find_idx_by_name(node.name, self.nodes_list) + node_dict = self.idx_trace_list[node_idx] + return node_dict['idx'], node_dict['compute'] + + def find_idx_trace_from_node(self, node): + node_idx = _find_idx_by_name(node.name, self.nodes_list) + node_idx_trace = self.idx_trace_list[node_idx]['idx'] + return node_idx_trace + + def assign_index_as_input(self, node, node_idx): + input_node_idx = _find_idx_by_name(node.args[0].name, self.nodes_list) + input_node_idx_trace = self.idx_trace_list[input_node_idx]['idx'] + + new_idx_trace = copy.deepcopy(input_node_idx_trace) + self.idx_trace_list[node_idx]['idx'] = new_idx_trace + + def assign_all_index(self, node, node_idx): + shape = node.meta['tensor_meta'].shape + new_trace = [] + for _ in shape: + new_trace.append(self.add_index()) + self.idx_trace_list[node_idx]['idx'] = new_trace + + def assign_transpose_index(self, node, node_idx): + tranpose_dim = node.args[1:] + input_node_idx_trace = self.find_idx_trace_from_node(node.args[0]) + + new_idx_trace = copy.deepcopy(input_node_idx_trace) + new_idx_trace[tranpose_dim[0]] = input_node_idx_trace[tranpose_dim[1]] + new_idx_trace[tranpose_dim[1]] = input_node_idx_trace[tranpose_dim[0]] + + self.idx_trace_list[node_idx]['idx'] = new_idx_trace + + def assign_linear_index(self, node, node_idx): + input_node, weight, bias = node.args + input_node_idx_trace = self.find_idx_trace_from_node(input_node) + weight_idx_trace = self.find_idx_trace_from_node(weight) + + new_idx_trace = copy.deepcopy(input_node_idx_trace) + new_idx_trace[-1] = weight_idx_trace[1] + self.idx_trace_list[node_idx]['idx'] = new_idx_trace + + self.inherit_computation(input_node, node) + self.mark_computation(node, node_idx, [-1]) + 
self.mark_idx_equal(input_node_idx_trace[-1], weight_idx_trace[0]) + + if bias: + bias_idx_trace = self.find_idx_trace_from_node(bias) + self.mark_idx_equal(input_node_idx_trace[-1], bias_idx_trace[0]) + + def assign_layernorm_index(self, node, idx): + self.assign_index_as_input(node, idx) + self.mark_computation(node, idx, [-1, -2]) + + def trace_node_idx(self): + for idx, node in enumerate(self.nodes_list): + if node.op == 'placeholder': + self.assign_all_index(node, idx) + elif node.op == 'call_method': + if 'transpose' in node.name: + self.assign_transpose_index(node, idx) + elif 'view' in node.name: + pass + elif 'permute' in node.name: + pass + else: + raise NotImplementedError(node.name, "method not implemented yet!") + elif node.op == 'call_function': + if 'linear' in node.name: + self.assign_linear_index(node, idx) + elif 'getattr' in node.name: + continue # get attr like shape + elif 'getitem' in node.name: + continue # get item in list + else: + raise NotImplementedError(node.name, "function not implemented yet!") + elif node.op == 'call_module': + if 'layernorm' in node.name: + self.assign_layernorm_index(node, idx) + else: + raise NotImplementedError(node.name, "module not implemented yet!") + elif node.op == 'get_attr': + self.assign_all_index(node, idx) # get param + else: + raise NotImplementedError(node.op, "op not implemented yet!") + def _get_meta_node_size(x): x = x.meta['tensor_meta'] x = x.numel * torch.tensor([], dtype=x.dtype).element_size() @@ -557,6 +674,8 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v node_list = list(nodes) _estimate_chunk_inference_mem(meta_graph, chunk_starts, chunk_ends, [1], [2]) _estimate_inference_mem(meta_graph) + node_index_tracer = NodeIndexTracer(meta_graph) + node_index_tracer.trace_node_idx() # find the input and output var names for each offload region for idx, (start, end) in enumerate(chunk_regions): From c36dba07defa3069ba65d5aafc53d8292e78cf60 Mon Sep 17 00:00:00 2001 
From: oahzxl Date: Mon, 14 Nov 2022 23:38:05 +0800 Subject: [PATCH 014/209] finish basic index tracer --- chunk_codegen.py | 133 +++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 124 insertions(+), 9 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 4b8882afc105..8477fe9a1702 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -25,6 +25,7 @@ def __init__(self, gm) -> None: self.nodes_list = list(gm.graph.nodes) self.idx_trace_list = [{'idx': [], 'compute': []} for _ in range(len(self.nodes_list))] self.idx_trace_equal = [] + self.idx_view_list = [] self.idx_count = 1 def add_index(self): @@ -35,7 +36,7 @@ def inherit_computation(self, node_from, node_to): _, compute_from = self.find_trace_from_node(node_from) idx_to, compute_to = self.find_trace_from_node(node_to) for i in compute_from: - if i in idx_to: + if i in idx_to and i not in compute_to: compute_to.append(i) def mark_idx_equal(self, idx1, idx2): @@ -47,7 +48,8 @@ def mark_computation(self, node, idx, dim): dim = [dim] for d in dim: cur_idx = input_node_idx_trace[d] - self.idx_trace_list[idx]['compute'].append(cur_idx) + if cur_idx not in self.idx_trace_list[idx]['compute']: + self.idx_trace_list[idx]['compute'].append(cur_idx) def find_trace_from_node(self, node): node_idx = _find_idx_by_name(node.name, self.nodes_list) @@ -56,8 +58,11 @@ def find_trace_from_node(self, node): def find_idx_trace_from_node(self, node): node_idx = _find_idx_by_name(node.name, self.nodes_list) - node_idx_trace = self.idx_trace_list[node_idx]['idx'] - return node_idx_trace + return self.idx_trace_list[node_idx]['idx'] + + def find_compute_trace_from_node(self, node): + node_idx = _find_idx_by_name(node.name, self.nodes_list) + return self.idx_trace_list[node_idx]['compute'] def assign_index_as_input(self, node, node_idx): input_node_idx = _find_idx_by_name(node.args[0].name, self.nodes_list) @@ -82,6 +87,18 @@ def assign_transpose_index(self, node, node_idx): new_idx_trace[tranpose_dim[1]] = 
input_node_idx_trace[tranpose_dim[0]] self.idx_trace_list[node_idx]['idx'] = new_idx_trace + self.inherit_computation(node.args[0], node) + + def assign_permute_index(self, node, node_idx): + permute_dim = node.args[1:] + input_node_idx_trace = self.find_idx_trace_from_node(node.args[0]) + + new_idx_trace = copy.deepcopy(input_node_idx_trace) + for idx, d in enumerate(permute_dim): + new_idx_trace[idx] = input_node_idx_trace[d] + + self.idx_trace_list[node_idx]['idx'] = new_idx_trace + self.inherit_computation(node.args[0], node) def assign_linear_index(self, node, node_idx): input_node, weight, bias = node.args @@ -100,10 +117,99 @@ def assign_linear_index(self, node, node_idx): bias_idx_trace = self.find_idx_trace_from_node(bias) self.mark_idx_equal(input_node_idx_trace[-1], bias_idx_trace[0]) + def assign_matmul_index(self, node, node_idx): + matmul_left, matmul_right = node.args + matmul_left_idx_trace = self.find_idx_trace_from_node(matmul_left) + matmul_right_idx_trace = self.find_idx_trace_from_node(matmul_right) + + assert(len(matmul_left_idx_trace) == len(matmul_right_idx_trace)) + new_idx_trace = copy.deepcopy(matmul_left_idx_trace) + new_idx_trace[-1] = matmul_right_idx_trace[-1] + self.idx_trace_list[node_idx]['idx'] = new_idx_trace + + self.inherit_computation(matmul_left, node) + self.inherit_computation(matmul_right, node) + self.mark_computation(node, node_idx, [-1]) + self.mark_idx_equal(matmul_left_idx_trace[-1], matmul_right_idx_trace[-2]) + def assign_layernorm_index(self, node, idx): self.assign_index_as_input(node, idx) + self.inherit_computation(node.args[0], node) self.mark_computation(node, idx, [-1, -2]) - + + def assign_elementwise_index(self, node, idx): + self.assign_index_as_input(node, idx) + for node_in in node.args: + if type(node_in) not in (int, float): + self.inherit_computation(node_in, node) + + def assign_softmax_index(self, node, idx): + self.assign_index_as_input(node, idx) + self.mark_computation(node, idx, 
[node.kwargs['dim']]) + + def assign_view_reshape_index(self, node, node_idx): + # get data, turn into number + origin_node = node.args[0] + origin_shape = origin_node.meta['tensor_meta'].shape + target_shape = [] + for i in range(1, len(node.args)): + if isinstance(node.args[i], int): + target_shape.append(node.args[i]) + else: + target_shape.append(node.args[i].meta['fwd_out'][0]) + + # compute the value of -1 + if -1 in target_shape: + origin_product = 1 + for i in origin_shape: + origin_product *= i + target_product = -1 + for i in target_shape: + target_product *= i + shape_idx = target_shape.index(-1) + target_shape[shape_idx] = origin_product // target_product + + # determine changed dim + len_diff = len(origin_shape) - len(target_shape) + if len_diff == 1: + # dim merge + dim_equal = [i == j for i, j in zip(origin_shape[:-1], target_shape)] + dim_to = [dim_equal.index(False)] + dim_from = [dim_equal.index(False), dim_equal.index(False) + 1] + elif len_diff == -1: + # dim expand + dim_equal = [i == j for i, j in zip(origin_shape, target_shape[:-1])] + dim_from = [dim_equal.index(False)] + dim_to = [dim_equal.index(False), dim_equal.index(False) + 1] + else: + raise NotImplementedError("shape" + str(origin_shape) + 'and' + str(target_shape) + "view not implemented") + + # get new index + origin_trace = self.find_idx_trace_from_node(origin_node) + new_trace = copy.deepcopy(origin_trace) + dim_from.reverse() + for i in dim_from: + new_trace.pop(i) + for i in dim_to: + new_trace.insert(i, self.add_index()) + self.idx_trace_list[node_idx]['idx'] = new_trace + + # inherit computation + self.inherit_computation(origin_node, node) + compute_log = self.find_compute_trace_from_node(origin_node) + for i in dim_from: + if origin_trace[i] in compute_log: + for j in dim_to: + self.mark_computation(node, node_idx, [j]) + break + + # log view + view_dict = {"idx_from": [origin_trace[i] for i in dim_from], + "dim_from": dim_from, + "idx_to": [new_trace[i] for i in dim_to], + 
"dim_to": dim_to} + self.idx_view_list.append(view_dict) + def trace_node_idx(self): for idx, node in enumerate(self.nodes_list): if node.op == 'placeholder': @@ -111,15 +217,21 @@ def trace_node_idx(self): elif node.op == 'call_method': if 'transpose' in node.name: self.assign_transpose_index(node, idx) - elif 'view' in node.name: - pass elif 'permute' in node.name: - pass + self.assign_permute_index(node, idx) + elif 'view' in node.name or 'reshape' in node.name: + self.assign_view_reshape_index(node, idx) else: raise NotImplementedError(node.name, "method not implemented yet!") elif node.op == 'call_function': if 'linear' in node.name: self.assign_linear_index(node, idx) + elif 'matmul' in node.name: + self.assign_matmul_index(node, idx) + elif 'softmax' in node.name: + self.assign_softmax_index(node, idx) + elif any(n in node.name for n in ['mul', 'add', 'sigmoid', 'relu']): + self.assign_elementwise_index(node, idx) elif 'getattr' in node.name: continue # get attr like shape elif 'getitem' in node.name: @@ -127,12 +239,14 @@ def trace_node_idx(self): else: raise NotImplementedError(node.name, "function not implemented yet!") elif node.op == 'call_module': - if 'layernorm' in node.name: + if any(n in node.name for n in ['layernorm', 'norm']): self.assign_layernorm_index(node, idx) else: raise NotImplementedError(node.name, "module not implemented yet!") elif node.op == 'get_attr': self.assign_all_index(node, idx) # get param + elif node.op == 'output': + continue else: raise NotImplementedError(node.op, "op not implemented yet!") @@ -297,6 +411,7 @@ def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nod # node is an operation, calculate tmp, output node and delete node memory else: # forward memory + # TODO: permute will create a tmp copy if not contiguous act_memory += _get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024 ** 2) act_memory += _get_output_node_size(node) * chunk_ratio / (1024 ** 2) # record max act 
memory From 70a98b8f56e690b75039561a729c5b623d175512 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 14 Nov 2022 23:49:48 +0800 Subject: [PATCH 015/209] add doc string --- chunk_codegen.py | 59 ++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 57 insertions(+), 2 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 8477fe9a1702..aa9d7ecd861f 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -26,13 +26,28 @@ def __init__(self, gm) -> None: self.idx_trace_list = [{'idx': [], 'compute': []} for _ in range(len(self.nodes_list))] self.idx_trace_equal = [] self.idx_view_list = [] - self.idx_count = 1 + self.idx_count = -1 def add_index(self): + """ + Update the count and return it. To record the idx number. + + Returns: + idx_count: int + """ self.idx_count += 1 - return self.idx_count - 1 + return self.idx_count def inherit_computation(self, node_from, node_to): + """ + Inherit computed dim from node_from to node_to. + If a dim in node_from is marked as computed and exists in node_to, + still mark it as computed in node_to. + + Args: + node_from (node): node to be inherited + node_to (node): new node to inherit + """ _, compute_from = self.find_trace_from_node(node_from) idx_to, compute_to = self.find_trace_from_node(node_to) for i in compute_from: @@ -40,9 +55,24 @@ def inherit_computation(self, node_from, node_to): compute_to.append(i) def mark_idx_equal(self, idx1, idx2): + """ + Mark 2 index to be equal. + + Args: + idx1 (int): index count. + idx2 (int): index count. + """ self.idx_trace_equal.append((idx1, idx2)) def mark_computation(self, node, idx, dim): + """ + Mark some dims of node as computed. 
+ + Args: + node (node) + idx (int): node index + dim (list or int): dims to be marked as computed + """ input_node_idx_trace = self.find_idx_trace_from_node(node) if isinstance(dim, int): dim = [dim] @@ -52,15 +82,40 @@ def mark_computation(self, node, idx, dim): self.idx_trace_list[idx]['compute'].append(cur_idx) def find_trace_from_node(self, node): + """ + Find node idx and compute trace by the node. + + Args: + node (node) + Returns: + idx (list): idx of the node + compute (list): computed idx of the node. + """ node_idx = _find_idx_by_name(node.name, self.nodes_list) node_dict = self.idx_trace_list[node_idx] return node_dict['idx'], node_dict['compute'] def find_idx_trace_from_node(self, node): + """ + Find node idx trace by the node. + + Args: + node (node) + Returns: + idx (list): idx of the node + """ node_idx = _find_idx_by_name(node.name, self.nodes_list) return self.idx_trace_list[node_idx]['idx'] def find_compute_trace_from_node(self, node): + """ + Find node compute trace by the node. + + Args: + node (node) + Returns: + compute (list): computed idx of the node. + """ node_idx = _find_idx_by_name(node.name, self.nodes_list) return self.idx_trace_list[node_idx]['compute'] From f379d1a94d5ffc7aa4a0c47ffc56cddbf99f4650 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Tue, 15 Nov 2022 10:18:00 +0800 Subject: [PATCH 016/209] add doc str --- chunk_codegen.py | 95 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) diff --git a/chunk_codegen.py b/chunk_codegen.py index aa9d7ecd861f..a14f7c134985 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -120,6 +120,13 @@ def find_compute_trace_from_node(self, node): return self.idx_trace_list[node_idx]['compute'] def assign_index_as_input(self, node, node_idx): + """ + Assign node's trace as its input node. 
+ + Args: + node (node) + node_idx (int) + """ input_node_idx = _find_idx_by_name(node.args[0].name, self.nodes_list) input_node_idx_trace = self.idx_trace_list[input_node_idx]['idx'] @@ -127,6 +134,13 @@ def assign_index_as_input(self, node, node_idx): self.idx_trace_list[node_idx]['idx'] = new_idx_trace def assign_all_index(self, node, node_idx): + """ + Add new index for all node's dims. + + Args: + node (node) + node_idx (int) + """ shape = node.meta['tensor_meta'].shape new_trace = [] for _ in shape: @@ -134,6 +148,15 @@ def assign_all_index(self, node, node_idx): self.idx_trace_list[node_idx]['idx'] = new_trace def assign_transpose_index(self, node, node_idx): + """ + Assign index for transpose op. + 1. swap input's dim according to transpose args + 2. inherit input's computation + + Args: + node (node) + node_idx (int) + """ tranpose_dim = node.args[1:] input_node_idx_trace = self.find_idx_trace_from_node(node.args[0]) @@ -145,6 +168,15 @@ def assign_transpose_index(self, node, node_idx): self.inherit_computation(node.args[0], node) def assign_permute_index(self, node, node_idx): + """ + Assign index for permute op. + 1. swap input's dim according to permute args + 2. inherit input's computation + + Args: + node (node) + node_idx (int) + """ permute_dim = node.args[1:] input_node_idx_trace = self.find_idx_trace_from_node(node.args[0]) @@ -156,6 +188,16 @@ def assign_permute_index(self, node, node_idx): self.inherit_computation(node.args[0], node) def assign_linear_index(self, node, node_idx): + """ + Assign index for linear op. + 1. copy trace from input node and change last index accroding to weight + 2. mark equal for input node last index, weight first dim and bias dim. + 3. inherit input's computation, mark computation for last dim. 
+ + Args: + node (node) + node_idx (int) + """ input_node, weight, bias = node.args input_node_idx_trace = self.find_idx_trace_from_node(input_node) weight_idx_trace = self.find_idx_trace_from_node(weight) @@ -173,6 +215,16 @@ def assign_linear_index(self, node, node_idx): self.mark_idx_equal(input_node_idx_trace[-1], bias_idx_trace[0]) def assign_matmul_index(self, node, node_idx): + """ + Assign index for matmul op. + 1. copy trace from matmul_left and change last index accroding to matmul_right. (assert they have same length) + 2. mark equal for input matmul_left -1 index and matmul_right -2 dim. + 3. inherit matmul_left and matmul_right computation, mark computation for last dim. + + Args: + node (node) + node_idx (int) + """ matmul_left, matmul_right = node.args matmul_left_idx_trace = self.find_idx_trace_from_node(matmul_left) matmul_right_idx_trace = self.find_idx_trace_from_node(matmul_right) @@ -188,21 +240,63 @@ def assign_matmul_index(self, node, node_idx): self.mark_idx_equal(matmul_left_idx_trace[-1], matmul_right_idx_trace[-2]) def assign_layernorm_index(self, node, idx): + """ + Assign index for layernorm op. + 1. assign index as input node + 2. inherit computation and mark last 2 dims as computed. + + Args: + node (node) + node_idx (int) + """ self.assign_index_as_input(node, idx) self.inherit_computation(node.args[0], node) self.mark_computation(node, idx, [-1, -2]) def assign_elementwise_index(self, node, idx): + """ + Assign index for element-wise op (eg. relu sigmoid add mul). + 1. assign index as input node + 2. inherit computation from all input nodes. + + Args: + node (node) + node_idx (int) + """ self.assign_index_as_input(node, idx) for node_in in node.args: if type(node_in) not in (int, float): self.inherit_computation(node_in, node) def assign_softmax_index(self, node, idx): + """ + Assign index for softmax op. + 1. assign index as input node + 2. inherit computation and mark softmax dim as computed. 
+ + Args: + node (node) + node_idx (int) + """ self.assign_index_as_input(node, idx) + self.inherit_computation(node.args[0], node) self.mark_computation(node, idx, [node.kwargs['dim']]) def assign_view_reshape_index(self, node, node_idx): + """ + Assign index for view and reshape op. + 1. get origin shape and target shape by meta info. + 2. compute the real value of -1 in target shape. + 3. determine changed dim, and assgin index for generated dim. + 4. log changed dim and generated dim for restore + 5. look into view list to see whether the view is associated with other, + if so assgin equal dim according to previous view. + 6. inherit computation. + + Args: + node (node) + node_idx (int) + """ # get data, turn into number origin_node = node.args[0] origin_shape = origin_node.meta['tensor_meta'].shape @@ -305,6 +399,7 @@ def trace_node_idx(self): else: raise NotImplementedError(node.op, "op not implemented yet!") + def _get_meta_node_size(x): x = x.meta['tensor_meta'] x = x.numel * torch.tensor([], dtype=x.dtype).element_size() From 7e2bd1e42892a3021b9882fb0d08f18cfcbcfe86 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Tue, 15 Nov 2022 10:36:02 +0800 Subject: [PATCH 017/209] polish code --- chunk_codegen.py | 258 ++--------------------------------------------- 1 file changed, 8 insertions(+), 250 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index a14f7c134985..9930a0570436 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -3,20 +3,11 @@ import copy from typing import List, Callable, Any, Tuple, Dict, Iterable -try: - from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name - from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin - from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, parameter_size, activation_size - CODEGEN_AVAILABLE = True -except: - from torch.fx.graph import 
_Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args, _CustomBuiltin - from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name - CODEGEN_AVAILABLE = False - -if CODEGEN_AVAILABLE: - __all__ = ['ChunkCodeGen'] -else: - __all__ = ['python_code_with_activation_checkpoint'] +from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name +from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin +from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, parameter_size, activation_size +CODEGEN_AVAILABLE = True +__all__ = ['ChunkCodeGen'] class NodeIndexTracer(object): @@ -289,9 +280,9 @@ def assign_view_reshape_index(self, node, node_idx): 2. compute the real value of -1 in target shape. 3. determine changed dim, and assgin index for generated dim. 4. log changed dim and generated dim for restore - 5. look into view list to see whether the view is associated with other, + 5. inherit computation. + 6. TODO: look into view list to see whether the view is associated with other, if so assgin equal dim according to previous view. - 6. inherit computation. 
Args: node (node) @@ -352,7 +343,7 @@ def assign_view_reshape_index(self, node, node_idx): self.mark_computation(node, node_idx, [j]) break - # log view + # log view, not used now view_dict = {"idx_from": [origin_trace[i] for i in dim_from], "dim_from": dim_from, "idx_to": [new_trace[i] for i in dim_to], @@ -680,239 +671,6 @@ def _find_idx_by_name(name, nodes_list): if node.name == name: return idx raise RuntimeError("name %s not found in node list" % name) - - -def _find_offload_regions(nodes: List[Node]): - """This function is to find the offload regions - In pofo algorithm, during annotation, we will annotate the offload region with the - list in the form of [idx, offload_input, offload_bar]. idx indicates the offload - region's index, offload_input is a bool type indicates whether we need to offload - the input, offload_bar is a bool type indicates whether we need to offload all the - intermediate x_bars of this region. - """ - offload_regions = [] - offload_labels = [] - start = -1 - end = -1 - current_region = None - - for idx, node in enumerate(nodes): - if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', None), Iterable): - act_offload_label = node.activation_offload - - if current_region == None: - current_region = act_offload_label - start = idx - offload_labels.append(act_offload_label) - - if act_offload_label != current_region: - assert start != -1 - offload_regions.append((start, idx - 1)) - offload_labels.append(act_offload_label) - current_region = act_offload_label - start = idx - end = -1 - - else: - if current_region is not None: - end = idx - 1 - assert start != -1 and end != -1 - offload_regions.append((start, end)) - start = end = -1 - current_region = None - - else: - pass - - return offload_regions, offload_labels - - -def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str: - """ - Generate the checkpoint function definition - """ - return f"def checkpoint_{label}({', '.join(['self'] + free_vars)}):" - 
- -def _gen_ckpt_output(output_vars: List[str]) -> str: - """ - Generate the return statement for checkpoint region - """ - return f"return {', '.join(output_vars)}" - - -def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reentrant=True): - """ - Generate the checkpoint function call code text - """ - outputs = ', '.join(output_vars) - inputs = ', '.join(input_vars) - return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})' - - -def _end_of_ckpt(node: Node, check_idx: int) -> bool: - """Check if the node could end the ckpt region - - Args: - node (Node): torch.fx.Node - check_idx (int): the index of checkpoint level for - nested checkpoint - - Returns: - bool - """ - if hasattr(node, "activation_checkpoint"): - if isinstance(node.activation_checkpoint, list): - return node.activation_checkpoint[check_idx] == None - else: - return False - else: - return True - - -def _find_nested_ckpt_regions(nodes, check_idx=0): - """ - Find the nested checkpoint regions given a list of consecutive nodes. The outputs - will be list of tuples, each tuple is in the form of (start_index, end_index). - """ - ckpt_regions = [] - start = -1 - end = -1 - current_region = None - - for idx, node in enumerate(nodes): - if hasattr(node, 'activation_checkpoint'): - if isinstance(getattr(node, 'activation_checkpoint'), int): - act_ckpt_label = node.activation_checkpoint - else: - act_ckpt_label = node.activation_checkpoint[check_idx] - - # this activation checkpoint label is not set yet - # meaning this is the first node of the activation ckpt region - if current_region is None: - current_region = act_ckpt_label - start = idx - - # if activation checkpoint has changed - # we restart the tracking - # e.g. 
node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2] - if act_ckpt_label != current_region: - assert start != -1 - ckpt_regions.append((start, idx - 1)) - current_region = act_ckpt_label - start = idx - end = -1 - elif current_region is not None and _end_of_ckpt(node, check_idx): - # used to check the case below - # node ckpt states = [ckpt, ckpt, non-ckpt] - end = idx - 1 - assert start != -1 and end != -1 - ckpt_regions.append((start, end)) - start = end = -1 - current_region = None - else: - pass - - if current_region is not None: - end = len(nodes) - 1 - ckpt_regions.append((start, end)) - return ckpt_regions - - -def emit_ckpt_func(body, - ckpt_func, - node_list: List[Node], - emit_node_func, - delete_unused_value_func, - level=0, - in_ckpt=False): - """Emit ckpt fuction in nested way - - Args: - body: forward code, in recursive calls, this part will be checkpoint - functions code - ckpt_func: checkpoint functions code, in recursive calls, this part - will be a buffer - node_list (List[Node]): list of torch.fx.Node - emit_node_func: function to emit a node - delete_unused_value_func: function to delete unused value - level (int, optional): checkpoint level. Defaults to 0. - in_ckpt (bool, optional): indicates wether the func is in recursive - call. Defaults to False. 
- """ - inputs, outputs = _find_input_and_output_nodes(node_list) - - # if the current checkpoint function use int as label, using old generation method - if isinstance(node_list[0].activation_checkpoint, int): - label = node_list[0].activation_checkpoint - ckpt_fn_def = _gen_ckpt_fn_def(label, inputs) - ckpt_func.append(f'{ckpt_fn_def}\n') - for node in node_list: - emit_node_func(node, ckpt_func) - ckpt_func[-1] = ' ' + ckpt_func[-1] - delete_unused_value_func(node, ckpt_func) - - ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') - activation_offload = getattr(node_list[0], "activation_offload", False) - usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) - usage += "\n" - body.append(usage) - - # use nested ckpt function codegen - else: - # label given by each layer, e.g. if you are currently at level [0, 1, 1] - # the label will be '0_1_1' - label = "_".join([str(idx) for idx in node_list[0].activation_checkpoint[:level + 1]]) - ckpt_fn_def = _gen_ckpt_fn_def(label, inputs) - ckpt_func.append(f'{ckpt_fn_def}\n') - - # if there is more level to fetch - if level + 1 < len(node_list[0].activation_checkpoint): - ckpt_regions = _find_nested_ckpt_regions(node_list, level + 1) - start_idx = [item[0] for item in ckpt_regions] - end_idx = [item[1] for item in ckpt_regions] - - # use ckpt_func_buffer to store nested checkpoint functions - ckpt_func_buffer = [] - node_idx = 0 - while 1: - if node_idx >= len(node_list): - break - - if node_idx in start_idx: - ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1] - emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func, - delete_unused_value_func, level + 1, True) - node_idx += len(ckpt_node_list) - - else: - node = node_list[node_idx] - emit_node_func(node, ckpt_func) - ckpt_func[-1] = ' ' + ckpt_func[-1] - delete_unused_value_func(node, ckpt_func) - node_idx += 1 - - ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') - ckpt_func += 
ckpt_func_buffer - activation_offload = getattr(node_list[0], "activation_offload", False) - usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n' - if in_ckpt: - usage = ' ' + usage - body.append(usage) - - # last level - else: - for node in node_list: - emit_node_func(node, ckpt_func) - ckpt_func[-1] = ' ' + ckpt_func[-1] - delete_unused_value_func(node, ckpt_func) - - ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') - activation_offload = getattr(node_list[0], "activation_offload", False) - usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n' - if in_ckpt: - usage = ' ' + usage - body.append(usage) def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func, meta_nodes, meta_graph): From fad3b6d1a65ee04d18e4826045ce3af4e3d28f10 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Tue, 15 Nov 2022 10:46:51 +0800 Subject: [PATCH 018/209] polish code --- chunk_codegen.py | 478 +++++++++++++++++++++++------------------------ 1 file changed, 239 insertions(+), 239 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 9930a0570436..c1d9e26e790a 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -10,6 +10,13 @@ __all__ = ['ChunkCodeGen'] +def _delete_free_var_from_last_use(user_to_last_uses): + for key, value in user_to_last_uses.items(): + for n in value: + if n.op == 'placeholder': + user_to_last_uses[key].remove(n) + + class NodeIndexTracer(object): def __init__(self, gm) -> None: self.gm = gm @@ -19,7 +26,7 @@ def __init__(self, gm) -> None: self.idx_view_list = [] self.idx_count = -1 - def add_index(self): + def _add_index(self): """ Update the count and return it. To record the idx number. @@ -29,7 +36,7 @@ def add_index(self): self.idx_count += 1 return self.idx_count - def inherit_computation(self, node_from, node_to): + def _inherit_computation(self, node_from, node_to): """ Inherit computed dim from node_from to node_to. 
If a dim in node_from is marked as computed and exists in node_to, @@ -39,13 +46,13 @@ def inherit_computation(self, node_from, node_to): node_from (node): node to be inherited node_to (node): new node to inherit """ - _, compute_from = self.find_trace_from_node(node_from) - idx_to, compute_to = self.find_trace_from_node(node_to) + _, compute_from = self._find_trace_from_node(node_from) + idx_to, compute_to = self._find_trace_from_node(node_to) for i in compute_from: if i in idx_to and i not in compute_to: compute_to.append(i) - def mark_idx_equal(self, idx1, idx2): + def _mark_idx_equal(self, idx1, idx2): """ Mark 2 index to be equal. @@ -55,7 +62,7 @@ def mark_idx_equal(self, idx1, idx2): """ self.idx_trace_equal.append((idx1, idx2)) - def mark_computation(self, node, idx, dim): + def _mark_computation(self, node, idx, dim): """ Mark some dims of node as computed. @@ -64,7 +71,7 @@ def mark_computation(self, node, idx, dim): idx (int): node index dim (list or int): dims to be marked as computed """ - input_node_idx_trace = self.find_idx_trace_from_node(node) + input_node_idx_trace = self._find_idx_trace_from_node(node) if isinstance(dim, int): dim = [dim] for d in dim: @@ -72,7 +79,7 @@ def mark_computation(self, node, idx, dim): if cur_idx not in self.idx_trace_list[idx]['compute']: self.idx_trace_list[idx]['compute'].append(cur_idx) - def find_trace_from_node(self, node): + def _find_trace_from_node(self, node): """ Find node idx and compute trace by the node. @@ -86,7 +93,7 @@ def find_trace_from_node(self, node): node_dict = self.idx_trace_list[node_idx] return node_dict['idx'], node_dict['compute'] - def find_idx_trace_from_node(self, node): + def _find_idx_trace_from_node(self, node): """ Find node idx trace by the node. 
@@ -98,7 +105,7 @@ def find_idx_trace_from_node(self, node): node_idx = _find_idx_by_name(node.name, self.nodes_list) return self.idx_trace_list[node_idx]['idx'] - def find_compute_trace_from_node(self, node): + def _find_compute_trace_from_node(self, node): """ Find node compute trace by the node. @@ -110,7 +117,7 @@ def find_compute_trace_from_node(self, node): node_idx = _find_idx_by_name(node.name, self.nodes_list) return self.idx_trace_list[node_idx]['compute'] - def assign_index_as_input(self, node, node_idx): + def _assign_index_as_input(self, node, node_idx): """ Assign node's trace as its input node. @@ -124,7 +131,7 @@ def assign_index_as_input(self, node, node_idx): new_idx_trace = copy.deepcopy(input_node_idx_trace) self.idx_trace_list[node_idx]['idx'] = new_idx_trace - def assign_all_index(self, node, node_idx): + def _assign_all_index(self, node, node_idx): """ Add new index for all node's dims. @@ -135,10 +142,10 @@ def assign_all_index(self, node, node_idx): shape = node.meta['tensor_meta'].shape new_trace = [] for _ in shape: - new_trace.append(self.add_index()) + new_trace.append(self._add_index()) self.idx_trace_list[node_idx]['idx'] = new_trace - def assign_transpose_index(self, node, node_idx): + def _assign_transpose_index(self, node, node_idx): """ Assign index for transpose op. 1. 
swap input's dim according to transpose args @@ -149,16 +156,16 @@ def assign_transpose_index(self, node, node_idx): node_idx (int) """ tranpose_dim = node.args[1:] - input_node_idx_trace = self.find_idx_trace_from_node(node.args[0]) + input_node_idx_trace = self._find_idx_trace_from_node(node.args[0]) new_idx_trace = copy.deepcopy(input_node_idx_trace) new_idx_trace[tranpose_dim[0]] = input_node_idx_trace[tranpose_dim[1]] new_idx_trace[tranpose_dim[1]] = input_node_idx_trace[tranpose_dim[0]] self.idx_trace_list[node_idx]['idx'] = new_idx_trace - self.inherit_computation(node.args[0], node) + self._inherit_computation(node.args[0], node) - def assign_permute_index(self, node, node_idx): + def _assign_permute_index(self, node, node_idx): """ Assign index for permute op. 1. swap input's dim according to permute args @@ -169,16 +176,16 @@ def assign_permute_index(self, node, node_idx): node_idx (int) """ permute_dim = node.args[1:] - input_node_idx_trace = self.find_idx_trace_from_node(node.args[0]) + input_node_idx_trace = self._find_idx_trace_from_node(node.args[0]) new_idx_trace = copy.deepcopy(input_node_idx_trace) for idx, d in enumerate(permute_dim): new_idx_trace[idx] = input_node_idx_trace[d] self.idx_trace_list[node_idx]['idx'] = new_idx_trace - self.inherit_computation(node.args[0], node) + self._inherit_computation(node.args[0], node) - def assign_linear_index(self, node, node_idx): + def _assign_linear_index(self, node, node_idx): """ Assign index for linear op. 1. 
copy trace from input node and change last index accroding to weight @@ -190,22 +197,22 @@ def assign_linear_index(self, node, node_idx): node_idx (int) """ input_node, weight, bias = node.args - input_node_idx_trace = self.find_idx_trace_from_node(input_node) - weight_idx_trace = self.find_idx_trace_from_node(weight) + input_node_idx_trace = self._find_idx_trace_from_node(input_node) + weight_idx_trace = self._find_idx_trace_from_node(weight) new_idx_trace = copy.deepcopy(input_node_idx_trace) new_idx_trace[-1] = weight_idx_trace[1] self.idx_trace_list[node_idx]['idx'] = new_idx_trace - self.inherit_computation(input_node, node) - self.mark_computation(node, node_idx, [-1]) - self.mark_idx_equal(input_node_idx_trace[-1], weight_idx_trace[0]) + self._inherit_computation(input_node, node) + self._mark_computation(node, node_idx, [-1]) + self._mark_idx_equal(input_node_idx_trace[-1], weight_idx_trace[0]) if bias: - bias_idx_trace = self.find_idx_trace_from_node(bias) - self.mark_idx_equal(input_node_idx_trace[-1], bias_idx_trace[0]) + bias_idx_trace = self._find_idx_trace_from_node(bias) + self._mark_idx_equal(input_node_idx_trace[-1], bias_idx_trace[0]) - def assign_matmul_index(self, node, node_idx): + def _assign_matmul_index(self, node, node_idx): """ Assign index for matmul op. 1. copy trace from matmul_left and change last index accroding to matmul_right. 
(assert they have same length) @@ -217,20 +224,20 @@ def assign_matmul_index(self, node, node_idx): node_idx (int) """ matmul_left, matmul_right = node.args - matmul_left_idx_trace = self.find_idx_trace_from_node(matmul_left) - matmul_right_idx_trace = self.find_idx_trace_from_node(matmul_right) + matmul_left_idx_trace = self._find_idx_trace_from_node(matmul_left) + matmul_right_idx_trace = self._find_idx_trace_from_node(matmul_right) assert(len(matmul_left_idx_trace) == len(matmul_right_idx_trace)) new_idx_trace = copy.deepcopy(matmul_left_idx_trace) new_idx_trace[-1] = matmul_right_idx_trace[-1] self.idx_trace_list[node_idx]['idx'] = new_idx_trace - self.inherit_computation(matmul_left, node) - self.inherit_computation(matmul_right, node) - self.mark_computation(node, node_idx, [-1]) - self.mark_idx_equal(matmul_left_idx_trace[-1], matmul_right_idx_trace[-2]) + self._inherit_computation(matmul_left, node) + self._inherit_computation(matmul_right, node) + self._mark_computation(node, node_idx, [-1]) + self._mark_idx_equal(matmul_left_idx_trace[-1], matmul_right_idx_trace[-2]) - def assign_layernorm_index(self, node, idx): + def _assign_layernorm_index(self, node, idx): """ Assign index for layernorm op. 1. assign index as input node @@ -240,11 +247,11 @@ def assign_layernorm_index(self, node, idx): node (node) node_idx (int) """ - self.assign_index_as_input(node, idx) - self.inherit_computation(node.args[0], node) - self.mark_computation(node, idx, [-1, -2]) + self._assign_index_as_input(node, idx) + self._inherit_computation(node.args[0], node) + self._mark_computation(node, idx, [-1, -2]) - def assign_elementwise_index(self, node, idx): + def _assign_elementwise_index(self, node, idx): """ Assign index for element-wise op (eg. relu sigmoid add mul). 1. 
assign index as input node @@ -254,12 +261,12 @@ def assign_elementwise_index(self, node, idx): node (node) node_idx (int) """ - self.assign_index_as_input(node, idx) + self._assign_index_as_input(node, idx) for node_in in node.args: if type(node_in) not in (int, float): - self.inherit_computation(node_in, node) + self._inherit_computation(node_in, node) - def assign_softmax_index(self, node, idx): + def _assign_softmax_index(self, node, idx): """ Assign index for softmax op. 1. assign index as input node @@ -269,11 +276,11 @@ def assign_softmax_index(self, node, idx): node (node) node_idx (int) """ - self.assign_index_as_input(node, idx) - self.inherit_computation(node.args[0], node) - self.mark_computation(node, idx, [node.kwargs['dim']]) + self._assign_index_as_input(node, idx) + self._inherit_computation(node.args[0], node) + self._mark_computation(node, idx, [node.kwargs['dim']]) - def assign_view_reshape_index(self, node, node_idx): + def _assign_view_reshape_index(self, node, node_idx): """ Assign index for view and reshape op. 1. get origin shape and target shape by meta info. 
@@ -325,22 +332,22 @@ def assign_view_reshape_index(self, node, node_idx): raise NotImplementedError("shape" + str(origin_shape) + 'and' + str(target_shape) + "view not implemented") # get new index - origin_trace = self.find_idx_trace_from_node(origin_node) + origin_trace = self._find_idx_trace_from_node(origin_node) new_trace = copy.deepcopy(origin_trace) dim_from.reverse() for i in dim_from: new_trace.pop(i) for i in dim_to: - new_trace.insert(i, self.add_index()) + new_trace.insert(i, self._add_index()) self.idx_trace_list[node_idx]['idx'] = new_trace # inherit computation - self.inherit_computation(origin_node, node) - compute_log = self.find_compute_trace_from_node(origin_node) + self._inherit_computation(origin_node, node) + compute_log = self._find_compute_trace_from_node(origin_node) for i in dim_from: if origin_trace[i] in compute_log: for j in dim_to: - self.mark_computation(node, node_idx, [j]) + self._mark_computation(node, node_idx, [j]) break # log view, not used now @@ -353,25 +360,25 @@ def assign_view_reshape_index(self, node, node_idx): def trace_node_idx(self): for idx, node in enumerate(self.nodes_list): if node.op == 'placeholder': - self.assign_all_index(node, idx) + self._assign_all_index(node, idx) elif node.op == 'call_method': if 'transpose' in node.name: - self.assign_transpose_index(node, idx) + self._assign_transpose_index(node, idx) elif 'permute' in node.name: - self.assign_permute_index(node, idx) + self._assign_permute_index(node, idx) elif 'view' in node.name or 'reshape' in node.name: - self.assign_view_reshape_index(node, idx) + self._assign_view_reshape_index(node, idx) else: raise NotImplementedError(node.name, "method not implemented yet!") elif node.op == 'call_function': if 'linear' in node.name: - self.assign_linear_index(node, idx) + self._assign_linear_index(node, idx) elif 'matmul' in node.name: - self.assign_matmul_index(node, idx) + self._assign_matmul_index(node, idx) elif 'softmax' in node.name: - 
self.assign_softmax_index(node, idx) + self._assign_softmax_index(node, idx) elif any(n in node.name for n in ['mul', 'add', 'sigmoid', 'relu']): - self.assign_elementwise_index(node, idx) + self._assign_elementwise_index(node, idx) elif 'getattr' in node.name: continue # get attr like shape elif 'getitem' in node.name: @@ -380,206 +387,198 @@ def trace_node_idx(self): raise NotImplementedError(node.name, "function not implemented yet!") elif node.op == 'call_module': if any(n in node.name for n in ['layernorm', 'norm']): - self.assign_layernorm_index(node, idx) + self._assign_layernorm_index(node, idx) else: raise NotImplementedError(node.name, "module not implemented yet!") elif node.op == 'get_attr': - self.assign_all_index(node, idx) # get param + self._assign_all_index(node, idx) # get param elif node.op == 'output': continue else: raise NotImplementedError(node.op, "op not implemented yet!") -def _get_meta_node_size(x): - x = x.meta['tensor_meta'] - x = x.numel * torch.tensor([], dtype=x.dtype).element_size() - return x - +class MemoryEstimator(object): + def __init__(self) -> None: + pass -def _get_output_node_size(n): - fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')} - return activation_size(fwd_out) + def _get_meta_node_size(self, x): + x = x.meta['tensor_meta'] + x = x.numel * torch.tensor([], dtype=x.dtype).element_size() + return x + def _get_output_node_size(self, n): + fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')} + return activation_size(fwd_out) -def _get_delete_node_size(user, user_to_last_uses): - if user.op in ('placeholder', 'output'): + def _get_delete_node_size(self, user, user_to_last_uses): + if user.op in ('placeholder', 'output'): + return 0 + nodes_to_delete = user_to_last_uses.get(user, []) + if len(nodes_to_delete): + delete_size = sum([self._get_output_node_size(i) for i in nodes_to_delete]) + return delete_size return 0 - 
nodes_to_delete = user_to_last_uses.get(user, []) - if len(nodes_to_delete): - delete_size = sum([_get_output_node_size(i) for i in nodes_to_delete]) - return delete_size - return 0 - - -def _get_last_usr(nodes): - node_to_last_use: Dict[Node, Node] = {} - user_to_last_uses: Dict[Node, List[Node]] = {} - - def register_last_uses(n: Node, user: Node): - if n not in node_to_last_use: - node_to_last_use[n] = user - user_to_last_uses.setdefault(user, []).append(n) - - for node in reversed(nodes): - map_arg(node.args, lambda n: register_last_uses(n, node)) - map_arg(node.kwargs, lambda n: register_last_uses(n, node)) - return user_to_last_uses - - -def _delete_free_var_from_last_use(user_to_last_uses): - for key, value in user_to_last_uses.items(): - for n in value: - if n.op == 'placeholder': - user_to_last_uses[key].remove(n) - - -def _get_contiguous_memory(node, not_contiguous_list, delete=False): - mem = 0 - not_contiguous_ops = ['transpose', 'permute'] - - if node.op == 'call_function' and 'matmul' in node.name: - for n in node.args: - if n in not_contiguous_list: - # matmul won't change origin tensor, but create a tmp copy - mem += _get_output_node_size(n) - elif node.op == 'call_module': - for n in node.args: - if n in not_contiguous_list: - # module will just make origin tensor to contiguous - if delete: - not_contiguous_list.remove(n) - elif node.op == 'call_method' and any(i in node.name for i in not_contiguous_ops): - if node not in not_contiguous_list: - not_contiguous_list.append(node) - elif any(i in node.args for i in not_contiguous_list): - if node not in not_contiguous_list: - not_contiguous_list.append(node) - - return mem - - -def _estimate_inference_mem(gm: torch.fx.GraphModule): - act_memory = 0.0 - act_memory_peak_log = [] - act_memory_after_node_log = [] - not_contiguous_list = [] - user_to_last_uses = _get_last_usr(list(gm.graph.nodes)) - _delete_free_var_from_last_use(user_to_last_uses) - for node in gm.graph.nodes: - # if node is placeholder, 
just add the size of the node - if node.op == 'placeholder': - act_memory += _get_meta_node_size(node) / (1024 ** 2) - act_memory_peak_log.append(act_memory) - act_memory_after_node_log.append(act_memory) - # skip output - elif node.op == 'output': - continue - # node is an operation, calculate tmp, output node and delete node memory - else: - # forward memory - act_memory += _get_contiguous_memory(node, not_contiguous_list) / (1024 ** 2) - act_memory += _get_output_node_size(node) / (1024 ** 2) - # record max act memory - act_memory_peak_log.append(act_memory) - # delete useless memory - act_memory -= _get_delete_node_size(node, user_to_last_uses) / (1024 ** 2) - act_memory -= _get_contiguous_memory(node, not_contiguous_list, delete=True) / (1024 ** 2) - act_memory_after_node_log.append(act_memory) - print("no chunk") - _print_mem_log(act_memory_peak_log, list(gm.graph.nodes), "peak") - _print_mem_log(act_memory_after_node_log, list(gm.graph.nodes), "after") - - param_memory = parameter_size(gm) - return act_memory + param_memory, param_memory + def _get_last_usr(self, nodes): + node_to_last_use: Dict[Node, Node] = {} + user_to_last_uses: Dict[Node, List[Node]] = {} + + def register_last_uses(n: Node, user: Node): + if n not in node_to_last_use: + node_to_last_use[n] = user + user_to_last_uses.setdefault(user, []).append(n) + + for node in reversed(nodes): + map_arg(node.args, lambda n: register_last_uses(n, node)) + map_arg(node.kwargs, lambda n: register_last_uses(n, node)) + return user_to_last_uses + + def _get_contiguous_memory(self, node, not_contiguous_list, delete=False): + mem = 0 + not_contiguous_ops = ['transpose', 'permute'] + + if node.op == 'call_function' and 'matmul' in node.name: + for n in node.args: + if n in not_contiguous_list: + # matmul won't change origin tensor, but create a tmp copy + mem += self._get_output_node_size(n) + elif node.op == 'call_module': + for n in node.args: + if n in not_contiguous_list: + # module will just make origin 
tensor to contiguous + if delete: + not_contiguous_list.remove(n) + elif node.op == 'call_method' and any(i in node.name for i in not_contiguous_ops): + if node not in not_contiguous_list: + not_contiguous_list.append(node) + elif any(i in node.args for i in not_contiguous_list): + if node not in not_contiguous_list: + not_contiguous_list.append(node) + + return mem + + def estimate_inference_mem(self, gm: torch.fx.GraphModule): + act_memory = 0.0 + act_memory_peak_log = [] + act_memory_after_node_log = [] + not_contiguous_list = [] + user_to_last_uses = self._get_last_usr(list(gm.graph.nodes)) + _delete_free_var_from_last_use(user_to_last_uses) + for node in gm.graph.nodes: + # if node is placeholder, just add the size of the node + if node.op == 'placeholder': + act_memory += self._get_meta_node_size(node) / (1024 ** 2) + act_memory_peak_log.append(act_memory) + act_memory_after_node_log.append(act_memory) + # skip output + elif node.op == 'output': + continue + # node is an operation, calculate tmp, output node and delete node memory + else: + # forward memory + act_memory += self._get_contiguous_memory(node, not_contiguous_list) / (1024 ** 2) + act_memory += self._get_output_node_size(node) / (1024 ** 2) + # record max act memory + act_memory_peak_log.append(act_memory) + # delete useless memory + act_memory -= self._get_delete_node_size(node, user_to_last_uses) / (1024 ** 2) + act_memory -= self._get_contiguous_memory(node, not_contiguous_list, delete=True) / (1024 ** 2) + act_memory_after_node_log.append(act_memory) + + print("no chunk") + self._print_mem_log(act_memory_peak_log, list(gm.graph.nodes), "peak") + self._print_mem_log(act_memory_after_node_log, list(gm.graph.nodes), "after") + + param_memory = parameter_size(gm) + return act_memory + param_memory, param_memory -def _get_chunk_ratio(node, chunk_dim, chunk_size): - shape = node.meta['tensor_meta'].shape - chunk_ratio = float(chunk_size) / shape[chunk_dim] - return chunk_ratio + def 
_get_chunk_ratio(self, node, chunk_dim, chunk_size): + shape = node.meta['tensor_meta'].shape + chunk_ratio = float(chunk_size) / shape[chunk_dim] + return chunk_ratio + + + def _get_chunk_delete_node_size(self, user, user_to_last_uses, chunk_ratio, node_list, start_node, end_node): + if user.op in ('placeholder', 'output'): + return 0 + nodes_to_delete = user_to_last_uses.get(user, []) + delete_size = 0 + for n in nodes_to_delete: + node_idx = _find_idx_by_name(n.name, node_list) + if start_node <= node_idx < end_node: + delete_size += self._get_output_node_size(n) * chunk_ratio + return delete_size -def _get_chunk_delete_node_size(user, user_to_last_uses, chunk_ratio, node_list, start_node, end_node): - if user.op in ('placeholder', 'output'): - return 0 - nodes_to_delete = user_to_last_uses.get(user, []) - delete_size = 0 - for n in nodes_to_delete: - node_idx = _find_idx_by_name(n.name, node_list) - if start_node <= node_idx < end_node: - delete_size += _get_output_node_size(n) * chunk_ratio - return delete_size - - -def _print_mem_log(log, nodes, title=None): - if title: - print(title) - for idx, (l, n) in enumerate(zip(log, nodes)): - print("%s:%.2f \t" % (n.name, l), end='') - if (idx + 1) % 3 == 0: - print("") - print("\n") - - -def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nodes, chunk_dims, chunk_sizes): - act_memory = 0.0 - act_memory_peak_log = [] - act_memory_after_node_log = [] - not_contiguous_list = [] - user_to_last_uses = _get_last_usr(list(gm.graph.nodes)) - _delete_free_var_from_last_use(user_to_last_uses) - within_chunk = False - region_idx = 0 - chunk_ratio = 1 # use it to estimate chunk mem - node_list = list(gm.graph.nodes) - - for idx, node in enumerate(node_list): - # if node in chunk start nodes, change chunk ratio and add chunk_tensor - if idx in start_nodes: - within_chunk = True - chunk_ratio = _get_chunk_ratio(node, chunk_dims[region_idx], chunk_sizes[region_idx]) - act_memory += 
_get_output_node_size(node_list[end_nodes[region_idx]]) / (1024 ** 2) - - # if node is placeholder, just add the size of the node - if node.op == 'placeholder': - act_memory += _get_meta_node_size(node) * chunk_ratio / (1024 ** 2) - act_memory_peak_log.append(act_memory) - # skip output - elif node.op == 'output': - continue - # node is an operation, calculate tmp, output node and delete node memory - else: - # forward memory - # TODO: permute will create a tmp copy if not contiguous - act_memory += _get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024 ** 2) - act_memory += _get_output_node_size(node) * chunk_ratio / (1024 ** 2) - # record max act memory - act_memory_peak_log.append(act_memory) - # delete useless memory - act_memory -= _get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio / (1024 ** 2) - if within_chunk: - act_memory -= _get_chunk_delete_node_size( - node, user_to_last_uses, chunk_ratio, node_list, - start_nodes[region_idx], end_nodes[region_idx]) / (1024 ** 2) + def _print_mem_log(self, log, nodes, title=None): + if title: + print(title) + for idx, (l, n) in enumerate(zip(log, nodes)): + print("%s:%.2f \t" % (n.name, l), end='') + if (idx + 1) % 3 == 0: + print("") + print("\n") + + + def estimate_chunk_inference_mem(self, gm: torch.fx.GraphModule, start_nodes, end_nodes, chunk_dims, chunk_sizes): + act_memory = 0.0 + act_memory_peak_log = [] + act_memory_after_node_log = [] + not_contiguous_list = [] + user_to_last_uses = self._get_last_usr(list(gm.graph.nodes)) + _delete_free_var_from_last_use(user_to_last_uses) + within_chunk = False + region_idx = 0 + chunk_ratio = 1 # use it to estimate chunk mem + node_list = list(gm.graph.nodes) + + for idx, node in enumerate(node_list): + # if node in chunk start nodes, change chunk ratio and add chunk_tensor + if idx in start_nodes: + within_chunk = True + chunk_ratio = self._get_chunk_ratio(node, chunk_dims[region_idx], chunk_sizes[region_idx]) + act_memory += 
self._get_output_node_size(node_list[end_nodes[region_idx]]) / (1024 ** 2) + + # if node is placeholder, just add the size of the node + if node.op == 'placeholder': + act_memory += self._get_meta_node_size(node) * chunk_ratio / (1024 ** 2) + act_memory_peak_log.append(act_memory) + # skip output + elif node.op == 'output': + continue + # node is an operation, calculate tmp, output node and delete node memory else: - act_memory -= _get_delete_node_size(node, user_to_last_uses) / (1024 ** 2) + # forward memory + # TODO: permute will create a tmp copy if not contiguous + act_memory += self._get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024 ** 2) + act_memory += self._get_output_node_size(node) * chunk_ratio / (1024 ** 2) + # record max act memory + act_memory_peak_log.append(act_memory) + # delete useless memory + act_memory -= self._get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio / (1024 ** 2) + if within_chunk: + act_memory -= self._get_chunk_delete_node_size( + node, user_to_last_uses, chunk_ratio, node_list, + start_nodes[region_idx], end_nodes[region_idx]) / (1024 ** 2) + else: + act_memory -= self._get_delete_node_size(node, user_to_last_uses) / (1024 ** 2) + + if idx in end_nodes: + act_memory -= self._get_output_node_size(node) * chunk_ratio / (1024 ** 2) + within_chunk = False + chunk_ratio = 1 + region_idx += 1 - if idx in end_nodes: - act_memory -= _get_output_node_size(node) * chunk_ratio / (1024 ** 2) - within_chunk = False - chunk_ratio = 1 - region_idx += 1 - - act_memory_after_node_log.append(act_memory) + act_memory_after_node_log.append(act_memory) - print("chunk") - _print_mem_log(act_memory_peak_log, node_list, "peak") - _print_mem_log(act_memory_after_node_log, node_list, "after") + print("chunk") + self._print_mem_log(act_memory_peak_log, node_list, "peak") + self._print_mem_log(act_memory_after_node_log, node_list, "after") - param_memory = parameter_size(gm) - return act_memory + param_memory, 
param_memory + param_memory = parameter_size(gm) + return act_memory + param_memory, param_memory def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape): @@ -695,8 +694,9 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v within_chunk_region = False node_list = list(nodes) - _estimate_chunk_inference_mem(meta_graph, chunk_starts, chunk_ends, [1], [2]) - _estimate_inference_mem(meta_graph) + memory_estimator = MemoryEstimator() + memory_estimator.estimate_chunk_inference_mem(meta_graph, chunk_starts, chunk_ends, [1], [2]) + memory_estimator.estimate_inference_mem(meta_graph) node_index_tracer = NodeIndexTracer(meta_graph) node_index_tracer.trace_node_idx() From 54a34a7e46d2f9e0234eb9295f3507e720ba21b2 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Tue, 15 Nov 2022 11:30:43 +0800 Subject: [PATCH 019/209] update active log --- chunk_codegen.py | 56 +++++++++++++++++++++++++++++++++++++----------- 1 file changed, 43 insertions(+), 13 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index c1d9e26e790a..ade986d1e343 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -407,18 +407,41 @@ def _get_meta_node_size(self, x): x = x.numel * torch.tensor([], dtype=x.dtype).element_size() return x - def _get_output_node_size(self, n): + def _get_output_node(self, n): fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')} - return activation_size(fwd_out) + out_size = activation_size(fwd_out) + out_node = [n.name] if out_size > 0 else [] + return out_size, out_node + + def _get_output_node_size(self, n): + return self._get_output_node(n)[0] + + def _add_active_node(self, n, active_list): + new_active = self._get_output_node(n)[1] + for i in new_active: + if i not in active_list: + active_list.append(i) + def _get_delete_node(self, user, user_to_last_uses): + delete_size = 0 + delete_node = [] + if user.op not in ('placeholder', 'output'): + nodes_to_delete = 
user_to_last_uses.get(user, []) + if len(nodes_to_delete): + out_node = [self._get_output_node(i) for i in nodes_to_delete] + delete_size = sum([i[0] for i in out_node]) + for i in range(len(out_node)): + if out_node[i][0] > 0: + delete_node.append(out_node[i][1][0]) + return delete_size, delete_node + def _get_delete_node_size(self, user, user_to_last_uses): - if user.op in ('placeholder', 'output'): - return 0 - nodes_to_delete = user_to_last_uses.get(user, []) - if len(nodes_to_delete): - delete_size = sum([self._get_output_node_size(i) for i in nodes_to_delete]) - return delete_size - return 0 + return self._get_delete_node(user, user_to_last_uses)[0] + + def _remove_active_node(self, user, user_to_last_uses, active_list): + delete_node = self._get_delete_node(user, user_to_last_uses)[1] + for i in delete_node: + active_list.remove(i) def _get_last_usr(self, nodes): node_to_last_use: Dict[Node, Node] = {} @@ -438,7 +461,7 @@ def _get_contiguous_memory(self, node, not_contiguous_list, delete=False): mem = 0 not_contiguous_ops = ['transpose', 'permute'] - if node.op == 'call_function' and 'matmul' in node.name: + if node.op == 'call_function' and any(n in node.name for n in ['matmul', 'reshape']): for n in node.args: if n in not_contiguous_list: # matmul won't change origin tensor, but create a tmp copy @@ -463,6 +486,8 @@ def estimate_inference_mem(self, gm: torch.fx.GraphModule): act_memory_peak_log = [] act_memory_after_node_log = [] not_contiguous_list = [] + active_node_list = [] + active_node_list_log = [] user_to_last_uses = self._get_last_usr(list(gm.graph.nodes)) _delete_free_var_from_last_use(user_to_last_uses) for node in gm.graph.nodes: @@ -470,7 +495,7 @@ def estimate_inference_mem(self, gm: torch.fx.GraphModule): if node.op == 'placeholder': act_memory += self._get_meta_node_size(node) / (1024 ** 2) act_memory_peak_log.append(act_memory) - act_memory_after_node_log.append(act_memory) + active_node_list.append(node.name) # skip output elif node.op == 
'output': continue @@ -484,8 +509,12 @@ def estimate_inference_mem(self, gm: torch.fx.GraphModule): # delete useless memory act_memory -= self._get_delete_node_size(node, user_to_last_uses) / (1024 ** 2) act_memory -= self._get_contiguous_memory(node, not_contiguous_list, delete=True) / (1024 ** 2) - act_memory_after_node_log.append(act_memory) + # log active node + self._add_active_node(node, active_node_list) + self._remove_active_node(node, user_to_last_uses, active_node_list) + act_memory_after_node_log.append(act_memory) + active_node_list_log.append(copy.deepcopy(active_node_list)) print("no chunk") self._print_mem_log(act_memory_peak_log, list(gm.graph.nodes), "peak") self._print_mem_log(act_memory_after_node_log, list(gm.graph.nodes), "after") @@ -551,7 +580,6 @@ def estimate_chunk_inference_mem(self, gm: torch.fx.GraphModule, start_nodes, en # node is an operation, calculate tmp, output node and delete node memory else: # forward memory - # TODO: permute will create a tmp copy if not contiguous act_memory += self._get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024 ** 2) act_memory += self._get_output_node_size(node) * chunk_ratio / (1024 ** 2) # record max act memory @@ -694,9 +722,11 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v within_chunk_region = False node_list = list(nodes) + memory_estimator = MemoryEstimator() memory_estimator.estimate_chunk_inference_mem(meta_graph, chunk_starts, chunk_ends, [1], [2]) memory_estimator.estimate_inference_mem(meta_graph) + node_index_tracer = NodeIndexTracer(meta_graph) node_index_tracer.trace_node_idx() From d9ca2f898d1fb2a2b76ba663ebb27b9a778bd0ed Mon Sep 17 00:00:00 2001 From: oahzxl Date: Tue, 15 Nov 2022 15:50:50 +0800 Subject: [PATCH 020/209] polish code --- chunk_codegen.py | 87 +++++++++++++++--------------------------------- 1 file changed, 27 insertions(+), 60 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 
ade986d1e343..77aca8deb81f 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -438,7 +438,7 @@ def _get_delete_node(self, user, user_to_last_uses): def _get_delete_node_size(self, user, user_to_last_uses): return self._get_delete_node(user, user_to_last_uses)[0] - def _remove_active_node(self, user, user_to_last_uses, active_list): + def _remove_deactive_node(self, user, user_to_last_uses, active_list): delete_node = self._get_delete_node(user, user_to_last_uses)[1] for i in delete_node: active_list.remove(i) @@ -481,48 +481,6 @@ def _get_contiguous_memory(self, node, not_contiguous_list, delete=False): return mem - def estimate_inference_mem(self, gm: torch.fx.GraphModule): - act_memory = 0.0 - act_memory_peak_log = [] - act_memory_after_node_log = [] - not_contiguous_list = [] - active_node_list = [] - active_node_list_log = [] - user_to_last_uses = self._get_last_usr(list(gm.graph.nodes)) - _delete_free_var_from_last_use(user_to_last_uses) - for node in gm.graph.nodes: - # if node is placeholder, just add the size of the node - if node.op == 'placeholder': - act_memory += self._get_meta_node_size(node) / (1024 ** 2) - act_memory_peak_log.append(act_memory) - active_node_list.append(node.name) - # skip output - elif node.op == 'output': - continue - # node is an operation, calculate tmp, output node and delete node memory - else: - # forward memory - act_memory += self._get_contiguous_memory(node, not_contiguous_list) / (1024 ** 2) - act_memory += self._get_output_node_size(node) / (1024 ** 2) - # record max act memory - act_memory_peak_log.append(act_memory) - # delete useless memory - act_memory -= self._get_delete_node_size(node, user_to_last_uses) / (1024 ** 2) - act_memory -= self._get_contiguous_memory(node, not_contiguous_list, delete=True) / (1024 ** 2) - # log active node - self._add_active_node(node, active_node_list) - self._remove_active_node(node, user_to_last_uses, active_node_list) - - act_memory_after_node_log.append(act_memory) - 
active_node_list_log.append(copy.deepcopy(active_node_list)) - print("no chunk") - self._print_mem_log(act_memory_peak_log, list(gm.graph.nodes), "peak") - self._print_mem_log(act_memory_after_node_log, list(gm.graph.nodes), "after") - - param_memory = parameter_size(gm) - return act_memory + param_memory, param_memory - - def _get_chunk_ratio(self, node, chunk_dim, chunk_size): shape = node.meta['tensor_meta'].shape chunk_ratio = float(chunk_size) / shape[chunk_dim] @@ -550,25 +508,28 @@ def _print_mem_log(self, log, nodes, title=None): print("") print("\n") - - def estimate_chunk_inference_mem(self, gm: torch.fx.GraphModule, start_nodes, end_nodes, chunk_dims, chunk_sizes): + def estimate_chunk_inference_mem(self, gm: torch.fx.GraphModule, start_nodes=None, end_nodes=None, chunk_dims=None, chunk_sizes=None): act_memory = 0.0 act_memory_peak_log = [] act_memory_after_node_log = [] + active_node_list = [] + active_node_list_log = [] not_contiguous_list = [] + node_list = list(gm.graph.nodes) user_to_last_uses = self._get_last_usr(list(gm.graph.nodes)) _delete_free_var_from_last_use(user_to_last_uses) - within_chunk = False - region_idx = 0 + + use_chunk = all(i is not None for i in [start_nodes, end_nodes, chunk_dims, chunk_sizes]) + chunk_within = False + chunk_region_idx = 0 chunk_ratio = 1 # use it to estimate chunk mem - node_list = list(gm.graph.nodes) for idx, node in enumerate(node_list): # if node in chunk start nodes, change chunk ratio and add chunk_tensor - if idx in start_nodes: - within_chunk = True - chunk_ratio = self._get_chunk_ratio(node, chunk_dims[region_idx], chunk_sizes[region_idx]) - act_memory += self._get_output_node_size(node_list[end_nodes[region_idx]]) / (1024 ** 2) + if use_chunk and idx in start_nodes: + chunk_within = True + chunk_ratio = self._get_chunk_ratio(node, chunk_dims[chunk_region_idx], chunk_sizes[chunk_region_idx]) + act_memory += self._get_output_node_size(node_list[end_nodes[chunk_region_idx]]) / (1024 ** 2) # if node is 
placeholder, just add the size of the node if node.op == 'placeholder': @@ -586,22 +547,28 @@ def estimate_chunk_inference_mem(self, gm: torch.fx.GraphModule, start_nodes, en act_memory_peak_log.append(act_memory) # delete useless memory act_memory -= self._get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio / (1024 ** 2) - if within_chunk: + if chunk_within: act_memory -= self._get_chunk_delete_node_size( node, user_to_last_uses, chunk_ratio, node_list, - start_nodes[region_idx], end_nodes[region_idx]) / (1024 ** 2) + start_nodes[chunk_region_idx], end_nodes[chunk_region_idx]) / (1024 ** 2) else: act_memory -= self._get_delete_node_size(node, user_to_last_uses) / (1024 ** 2) - - if idx in end_nodes: + + # log active node + self._add_active_node(node, active_node_list) + self._remove_deactive_node(node, user_to_last_uses, active_node_list) + + # if node in chunk end nodes, restore chunk settings + if use_chunk and idx in end_nodes: act_memory -= self._get_output_node_size(node) * chunk_ratio / (1024 ** 2) - within_chunk = False + chunk_within = False chunk_ratio = 1 - region_idx += 1 + chunk_region_idx += 1 act_memory_after_node_log.append(act_memory) + active_node_list_log.append(copy.deepcopy(active_node_list)) - print("chunk") + print("with chunk" if use_chunk else "without chunk") self._print_mem_log(act_memory_peak_log, node_list, "peak") self._print_mem_log(act_memory_after_node_log, node_list, "after") @@ -725,7 +692,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v memory_estimator = MemoryEstimator() memory_estimator.estimate_chunk_inference_mem(meta_graph, chunk_starts, chunk_ends, [1], [2]) - memory_estimator.estimate_inference_mem(meta_graph) + memory_estimator.estimate_chunk_inference_mem(meta_graph) node_index_tracer = NodeIndexTracer(meta_graph) node_index_tracer.trace_node_idx() From 7330d907459a220ebedaeafbbcc7c3cff3c8b1c4 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Sun, 4 Dec 2022 
17:05:28 +0800 Subject: [PATCH 021/209] add possible region search --- chunk_codegen.py | 116 ++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 109 insertions(+), 7 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 77aca8deb81f..ba83f7fec3be 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -356,7 +356,17 @@ def _assign_view_reshape_index(self, node, node_idx): "idx_to": [new_trace[i] for i in dim_to], "dim_to": dim_to} self.idx_view_list.append(view_dict) - + + def _merge_equal_idx(self): + idx_equal = copy.deepcopy(self.idx_trace_equal) + idx_equal.reverse() + for idx in idx_equal: + merge_to = min(idx) + merge_from = max(idx) + for trace in self.idx_trace_list: + if merge_from in trace['idx']: + trace['idx'] = [merge_to if i == merge_from else i for i in trace['idx']] + def trace_node_idx(self): for idx, node in enumerate(self.nodes_list): if node.op == 'placeholder': @@ -396,6 +406,7 @@ def trace_node_idx(self): continue else: raise NotImplementedError(node.op, "op not implemented yet!") + self._merge_equal_idx() class MemoryEstimator(object): @@ -433,6 +444,8 @@ def _get_delete_node(self, user, user_to_last_uses): for i in range(len(out_node)): if out_node[i][0] > 0: delete_node.append(out_node[i][1][0]) + elif nodes_to_delete[i].op == 'placeholder': + delete_node.append(nodes_to_delete[i].name) return delete_size, delete_node def _get_delete_node_size(self, user, user_to_last_uses): @@ -516,8 +529,9 @@ def estimate_chunk_inference_mem(self, gm: torch.fx.GraphModule, start_nodes=Non active_node_list_log = [] not_contiguous_list = [] node_list = list(gm.graph.nodes) - user_to_last_uses = self._get_last_usr(list(gm.graph.nodes)) - _delete_free_var_from_last_use(user_to_last_uses) + user_to_last_uses = self._get_last_usr(node_list) + user_to_last_uses_no_free_var = self._get_last_usr(node_list) + _delete_free_var_from_last_use(user_to_last_uses_no_free_var) use_chunk = all(i is not None for i in [start_nodes, end_nodes, 
chunk_dims, chunk_sizes]) chunk_within = False @@ -535,6 +549,7 @@ def estimate_chunk_inference_mem(self, gm: torch.fx.GraphModule, start_nodes=Non if node.op == 'placeholder': act_memory += self._get_meta_node_size(node) * chunk_ratio / (1024 ** 2) act_memory_peak_log.append(act_memory) + active_node_list.append(node.name) # skip output elif node.op == 'output': continue @@ -549,10 +564,10 @@ def estimate_chunk_inference_mem(self, gm: torch.fx.GraphModule, start_nodes=Non act_memory -= self._get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio / (1024 ** 2) if chunk_within: act_memory -= self._get_chunk_delete_node_size( - node, user_to_last_uses, chunk_ratio, node_list, + node, user_to_last_uses_no_free_var, chunk_ratio, node_list, start_nodes[chunk_region_idx], end_nodes[chunk_region_idx]) / (1024 ** 2) else: - act_memory -= self._get_delete_node_size(node, user_to_last_uses) / (1024 ** 2) + act_memory -= self._get_delete_node_size(node, user_to_last_uses_no_free_var) / (1024 ** 2) # log active node self._add_active_node(node, active_node_list) @@ -572,8 +587,92 @@ def estimate_chunk_inference_mem(self, gm: torch.fx.GraphModule, start_nodes=Non self._print_mem_log(act_memory_peak_log, node_list, "peak") self._print_mem_log(act_memory_after_node_log, node_list, "after") - param_memory = parameter_size(gm) - return act_memory + param_memory, param_memory + # param_memory = parameter_size(gm) + # all_memory = act_memory + param_memory + return act_memory_peak_log, act_memory_after_node_log, active_node_list_log + + +class ChunkRegionSearch(object): + def __init__(self, gm) -> None: + self.gm = gm + self.node_list = list(gm.graph.nodes) + self.memory_estimator = MemoryEstimator() + self.index_tracer = NodeIndexTracer(gm) + self.index_tracer.trace_node_idx() + + def _find_peak_node(self, mem_peak): + max_value = max(mem_peak) + max_idx = [mem_peak.index(max_value)] + return max_idx + + def _get_free_var(self): + free_var_idx = [] + for idx, n 
in enumerate(self.node_list): + if n.op == 'placeholder': + free_var_idx.append(idx) + return free_var_idx + + def _get_min_free_var(self, active_node_list, free_vars): + min_len = 999 + for idx, n in enumerate(active_node_list): + if idx in free_vars: + continue + if len(n) < min_len: + min_len = len(n) + return min_len + + def _search_max_chunk_region(self, active_node, peak_node): + free_vars = self._get_free_var() + min_var = self._get_min_free_var(active_node, free_vars) + + # from peak_node to free_var + chunk_region_start = None + for i in range(peak_node, -1, -1): + if len(active_node[i]) == min_var: + chunk_region_start = i + 1 + break + if i in free_vars or i == 0: + raise RuntimeError() + # from peak_node to len-2 + chunk_region_end = None + for i in range(peak_node, len(active_node) - 1): + if len(active_node[i]) == min_var: + chunk_region_end = i - 1 + break + if i in free_vars or i == 0: + raise RuntimeError() + return chunk_region_start, chunk_region_end + + def _search_possible_chunk_regions(self, max_chunk_region, peak_node): + possible_chunk_region = [] + for before_idx in range(max_chunk_region[0], peak_node): + for after_idx in range(peak_node, max_chunk_region[1]): + # skip non compute nodes + if any(op in ['placeholder', 'get_attr', 'output'] for op in + [self.node_list[before_idx].op, self.node_list[after_idx].op]): + continue + if any(any(i in name for i in ['getitem', 'getattr']) for name in + [self.node_list[before_idx].name, self.node_list[after_idx].name]): + continue + + # select free dim + before_trace = self.index_tracer.idx_trace_list[before_idx] + after_trace = self.index_tracer.idx_trace_list[after_idx] + free_dim = [] + for i in range(min(len(before_trace['idx']), len(after_trace['idx']))): + if (before_trace['idx'][i] == after_trace['idx'][i] and + before_trace['idx'][i] not in before_trace['compute'] and + after_trace['idx'][i] not in after_trace['compute']): + free_dim.append(i) + possible_chunk_region.append({'region': 
(before_idx, after_idx), 'dim': free_dim}) + return possible_chunk_region + + def search_region(self): + mem_peak, mem_after, active_node = self.memory_estimator.estimate_chunk_inference_mem(self.gm) + peak_nodes = self._find_peak_node(mem_peak) + for idx, peak_node in enumerate(peak_nodes): + max_chunk_region = self._search_max_chunk_region(active_node, peak_node) + possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node) def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape): @@ -696,6 +795,9 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v node_index_tracer = NodeIndexTracer(meta_graph) node_index_tracer.trace_node_idx() + + chunk_region_search = ChunkRegionSearch(meta_graph) + chunk_region_search.search_region() # find the input and output var names for each offload region for idx, (start, end) in enumerate(chunk_regions): From 3b7d6712065b65d9c93feb64a488739e4483981f Mon Sep 17 00:00:00 2001 From: oahzxl Date: Tue, 6 Dec 2022 11:08:39 +0800 Subject: [PATCH 022/209] finish region search loop --- chunk_codegen.py | 152 ++++++++++++++++++++++++++++++++----------- chunk_codegen_run.py | 4 +- 2 files changed, 116 insertions(+), 40 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index ba83f7fec3be..47cda0f8ed20 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -21,7 +21,7 @@ class NodeIndexTracer(object): def __init__(self, gm) -> None: self.gm = gm self.nodes_list = list(gm.graph.nodes) - self.idx_trace_list = [{'idx': [], 'compute': []} for _ in range(len(self.nodes_list))] + self.idx_trace_list = [{'idx': [], 'compute': {}} for _ in range(len(self.nodes_list))] self.idx_trace_equal = [] self.idx_view_list = [] self.idx_count = -1 @@ -48,9 +48,12 @@ def _inherit_computation(self, node_from, node_to): """ _, compute_from = self._find_trace_from_node(node_from) idx_to, compute_to = self._find_trace_from_node(node_to) - for i in compute_from: - if i in idx_to and i not in 
compute_to: - compute_to.append(i) + for k, v in compute_from.items(): + if k in idx_to: + if k in compute_to: + compute_to[k].extend(v) + else: + compute_to[k] = copy.deepcopy(v) def _mark_idx_equal(self, idx1, idx2): """ @@ -77,7 +80,9 @@ def _mark_computation(self, node, idx, dim): for d in dim: cur_idx = input_node_idx_trace[d] if cur_idx not in self.idx_trace_list[idx]['compute']: - self.idx_trace_list[idx]['compute'].append(cur_idx) + self.idx_trace_list[idx]['compute'][cur_idx] = [idx] + else: + self.idx_trace_list[idx]['compute'][cur_idx].append(idx) def _find_trace_from_node(self, node): """ @@ -357,6 +362,11 @@ def _assign_view_reshape_index(self, node, node_idx): "dim_to": dim_to} self.idx_view_list.append(view_dict) + def _remove_duplicate_compute(self): + for i in self.idx_trace_list: + for k, v in i['compute'].items(): + i['compute'][k] = list(set(v)) + def _merge_equal_idx(self): idx_equal = copy.deepcopy(self.idx_trace_equal) idx_equal.reverse() @@ -406,6 +416,8 @@ def trace_node_idx(self): continue else: raise NotImplementedError(node.op, "op not implemented yet!") + + self._remove_duplicate_compute() self._merge_equal_idx() @@ -521,6 +533,19 @@ def _print_mem_log(self, log, nodes, title=None): print("") print("\n") + def _print_compute_op_mem_log(self, log, nodes, title=None): + if title: + print(title) + for idx, (l, n) in enumerate(zip(log, nodes)): + if n.op in ['placeholder', 'get_attr', 'output']: + continue + if any(i in n.name for i in ['getitem', 'getattr']): + continue + print("%s:%.2f \t" % (n.name, l), end='') + if (idx + 1) % 3 == 0: + print("") + print("\n") + def estimate_chunk_inference_mem(self, gm: torch.fx.GraphModule, start_nodes=None, end_nodes=None, chunk_dims=None, chunk_sizes=None): act_memory = 0.0 act_memory_peak_log = [] @@ -584,8 +609,10 @@ def estimate_chunk_inference_mem(self, gm: torch.fx.GraphModule, start_nodes=Non active_node_list_log.append(copy.deepcopy(active_node_list)) print("with chunk" if use_chunk else 
"without chunk") - self._print_mem_log(act_memory_peak_log, node_list, "peak") - self._print_mem_log(act_memory_after_node_log, node_list, "after") + # self._print_mem_log(act_memory_peak_log, node_list, "peak") + # self._print_mem_log(act_memory_after_node_log, node_list, "after") + self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak") + self._print_compute_op_mem_log(act_memory_after_node_log, node_list, "after") # param_memory = parameter_size(gm) # all_memory = act_memory + param_memory @@ -602,7 +629,7 @@ def __init__(self, gm) -> None: def _find_peak_node(self, mem_peak): max_value = max(mem_peak) - max_idx = [mem_peak.index(max_value)] + max_idx = mem_peak.index(max_value) return max_idx def _get_free_var(self): @@ -635,18 +662,35 @@ def _search_max_chunk_region(self, active_node, peak_node): raise RuntimeError() # from peak_node to len-2 chunk_region_end = None - for i in range(peak_node, len(active_node) - 1): + for i in range(peak_node, len(active_node)): if len(active_node[i]) == min_var: - chunk_region_end = i - 1 + chunk_region_end = i break if i in free_vars or i == 0: raise RuntimeError() return chunk_region_start, chunk_region_end + def _not_compute(self, trace, chunk_range, dim_idx): + if trace['idx'][dim_idx] not in trace['compute']: + return True + if trace['idx'][dim_idx] in trace['compute'] and \ + all(i < chunk_range[0] or i > chunk_range[1] for i in trace['compute'][trace['idx'][dim_idx]]): + return True + return False + def _search_possible_chunk_regions(self, max_chunk_region, peak_node): possible_chunk_region = [] + output_trace = copy.deepcopy(self.index_tracer.idx_trace_list) + input_trace = [] + for i, n in enumerate(self.node_list): + if len(n.args) > 0 and n.op != 'output': + input_idx = _find_idx_by_name(n.args[0].name, self.node_list) + input_trace.append(output_trace[input_idx]) + else: + input_trace.append(None) + for before_idx in range(max_chunk_region[0], peak_node): - for after_idx in range(peak_node, 
max_chunk_region[1]): + for after_idx in range(peak_node, max_chunk_region[1] + 1): # skip non compute nodes if any(op in ['placeholder', 'get_attr', 'output'] for op in [self.node_list[before_idx].op, self.node_list[after_idx].op]): @@ -656,23 +700,59 @@ def _search_possible_chunk_regions(self, max_chunk_region, peak_node): continue # select free dim - before_trace = self.index_tracer.idx_trace_list[before_idx] - after_trace = self.index_tracer.idx_trace_list[after_idx] + before_trace = input_trace[before_idx] + after_trace = output_trace[after_idx] free_dim = [] for i in range(min(len(before_trace['idx']), len(after_trace['idx']))): if (before_trace['idx'][i] == after_trace['idx'][i] and - before_trace['idx'][i] not in before_trace['compute'] and - after_trace['idx'][i] not in after_trace['compute']): + self._not_compute(before_trace, (before_idx, after_idx), i) and + self._not_compute(after_trace, (before_idx, after_idx), i) and + self.node_list[after_idx].meta['tensor_meta'].shape[i] != 1): free_dim.append(i) possible_chunk_region.append({'region': (before_idx, after_idx), 'dim': free_dim}) return possible_chunk_region + def _search_best_chunk_region(self, possible_chunk_regions): + max_region_range = 0 + best_regions = None + for i in possible_chunk_regions: + if i['region'][1] - i['region'][0] > max_region_range: + best_regions = i + max_region_range = i['region'][1] - i['region'][0] + return best_regions + + def _step_search(self, peak_node, active_node): + max_chunk_region = self._search_max_chunk_region(active_node, peak_node) + possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node) + best_chunk_region = self._search_best_chunk_region(possible_chunk_regions) + return best_chunk_region + + def _stop_search(self, init_mem_peak, mem_peak): + sorted_init_mem_peak = sorted(init_mem_peak) + if max(mem_peak) < sorted_init_mem_peak[int(len(sorted_init_mem_peak) * 0.5)]: + return True + return False + def search_region(self): - 
mem_peak, mem_after, active_node = self.memory_estimator.estimate_chunk_inference_mem(self.gm) - peak_nodes = self._find_peak_node(mem_peak) - for idx, peak_node in enumerate(peak_nodes): - max_chunk_region = self._search_max_chunk_region(active_node, peak_node) - possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node) + chunk_regions = [] + init_mem_peak, _, active_node = self.memory_estimator.estimate_chunk_inference_mem(self.gm) + mem_peak = init_mem_peak + + while True: + peak_node = self._find_peak_node(mem_peak) + chunk_region = self._step_search(peak_node, active_node) + if chunk_region is None or len(chunk_region['dim']) == 0: + break + + chunk_regions.append(chunk_region) + mem_peak, _, active_node = self.memory_estimator.estimate_chunk_inference_mem( + self.gm, [i['region'][0] for i in chunk_regions], + [i['region'][1] for i in chunk_regions], [i['dim'][0] for i in chunk_regions], [1] * len(chunk_regions)) + + if self._stop_search(init_mem_peak, mem_peak): + break + + return chunk_regions def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape): @@ -696,11 +776,12 @@ def _get_first_non_single_dim(shape): raise RuntimeError("can not get first non single dim for shape", shape) -def _gen_loop_start(chunk_input_meta, chunk_output, chunk_size=2): +def _gen_loop_start(chunk_input_meta, chunk_output, chunk_dim, chunk_size=2): if len(chunk_input_meta) == 1: node = chunk_input_meta[0] node_shape = node.meta['tensor_meta'].shape - chunk_dim = _get_first_non_single_dim(node_shape) + free_shape = [node_shape[i] if i in chunk_dim else 1 for i in range(len(node_shape))] + chunk_dim = _get_first_non_single_dim(free_shape) chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", node_shape) out_shape = str(list(chunk_output.meta['tensor_meta'].shape)) @@ -713,12 +794,13 @@ def _gen_loop_start(chunk_input_meta, chunk_output, chunk_size=2): return context -def _gen_loop_end(chunk_outputs, chunk_inputs, node_list): +def 
_gen_loop_end(chunk_outputs, chunk_inputs, node_list, chunk_dim): chunk_inputs_name = chunk_inputs[0].name chunk_outputs_name = chunk_outputs.name chunk_outputs_idx = _find_idx_by_name(chunk_outputs_name, node_list) chunk_output_shape = chunk_outputs.meta['tensor_meta'].shape - chunk_dim = _get_first_non_single_dim(chunk_output_shape) + free_shape = [chunk_output_shape[i] if i in chunk_dim else 1 for i in range(len(chunk_output_shape))] + chunk_dim = _get_first_non_single_dim(free_shape) chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", chunk_output_shape) context = " chunk_result%s = %s\n" % (chunk_slice, chunk_outputs_name) @@ -780,7 +862,11 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v """ # find the offload regions - chunk_regions = [(58, 62)] + chunk_region_search = ChunkRegionSearch(meta_graph) + chunk_search = chunk_region_search.search_region() + chunk_regions = [i['region'] for i in chunk_search] + chunk_dims = [i['dim'] for i in chunk_search] + chunk_starts = [item[0] for item in chunk_regions] chunk_ends = [item[1] for item in chunk_regions] chunk_inputs = [] @@ -789,16 +875,6 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v node_list = list(nodes) - memory_estimator = MemoryEstimator() - memory_estimator.estimate_chunk_inference_mem(meta_graph, chunk_starts, chunk_ends, [1], [2]) - memory_estimator.estimate_chunk_inference_mem(meta_graph) - - node_index_tracer = NodeIndexTracer(meta_graph) - node_index_tracer.trace_node_idx() - - chunk_region_search = ChunkRegionSearch(meta_graph) - chunk_region_search.search_region() - # find the input and output var names for each offload region for idx, (start, end) in enumerate(chunk_regions): offload_node_list = node_list[start:end + 1] @@ -824,13 +900,13 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v # add for loop chunk_input_meta = [meta_nodes[i] for i in chunk_inputs_idx[region_idx]] - 
body.append(_gen_loop_start(chunk_input_meta, node_list[chunk_ends[region_idx]])) + body.append(_gen_loop_start(chunk_input_meta, node_list[chunk_ends[region_idx]], chunk_dims[region_idx])) if within_chunk_region: emit_node_func(node, body) # replace input var with chunk var if node_idx in chunk_starts: - body[-1] = body[-1].replace("("+ chunk_inputs[region_idx][0].name +")", '(chunk_tensor)') + body[-1] = body[-1].replace(chunk_inputs[region_idx][0].name, 'chunk_tensor') body[-1] = ' ' + body[-1] delete_unused_value_func(node, body, chunk_inputs_names) @@ -840,7 +916,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v delete_unused_value_func(node, body, chunk_inputs_names) if node_idx in chunk_ends: - body.append(_gen_loop_end(node, chunk_inputs[region_idx], node_list)) + body.append(_gen_loop_end(node, chunk_inputs[region_idx], node_list, chunk_dims[region_idx])) within_chunk_region = False region_idx += 1 diff --git a/chunk_codegen_run.py b/chunk_codegen_run.py index 39363a80abcb..88c734903392 100644 --- a/chunk_codegen_run.py +++ b/chunk_codegen_run.py @@ -45,8 +45,8 @@ def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair): with torch.no_grad(): non_fx_out = model(node, pair) fx_out = gm(node, pair) - assert torch.equal(non_fx_out[0], fx_out[0]), "fx_out doesn't comply with original output" - assert torch.equal(non_fx_out[1], fx_out[1]), "fx_out doesn't comply with original output" + assert torch.allclose(non_fx_out[0], fx_out[0], atol=1e-6), "fx_out doesn't comply with original output" + assert torch.allclose(non_fx_out[1], fx_out[1], atol=1e-6), "fx_out doesn't comply with original output" # test barckward # loss0 = non_fx_out[0].sum() + non_fx_out[1].sum() From f24c418bb04a1e65eaa0f6cf8aada466deca2598 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Tue, 6 Dec 2022 16:29:07 +0800 Subject: [PATCH 023/209] finish chunk define --- chunk_codegen.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 
deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 47cda0f8ed20..6740cd44ab6a 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -827,7 +827,7 @@ def _find_input_and_output_nodes(nodes: List[Node]): for node in nodes: for input_node in node._input_nodes.keys(): node_repr = repr(input_node) - if input_node not in nodes and node_repr not in input_nodes: + if input_node not in nodes and input_node not in input_nodes: input_nodes.append(input_node) # if a node has a user node which is not in the node list @@ -835,7 +835,7 @@ def _find_input_and_output_nodes(nodes: List[Node]): for node in nodes: for output_node in node.users.keys(): node_repr = repr(node) - if output_node not in nodes and node_repr not in output_nodes: + if output_node not in nodes and output_node not in output_nodes: output_nodes.append(output_node) return input_nodes, output_nodes @@ -848,6 +848,16 @@ def _find_idx_by_name(name, nodes_list): raise RuntimeError("name %s not found in node list" % name) +def _replace_name(context, name_from, name_to): + patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ",")] + for p in patterns: + source = p[0] + name_from + p[1] + target = p[0] + name_to + p[1] + if source in context: + context = context.replace(source, target) + return context + + def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func, meta_nodes, meta_graph): """Emit code with nested activation checkpoint When we detect some of the node.activation_checkpoint is a List, we will use @@ -905,8 +915,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v if within_chunk_region: emit_node_func(node, body) # replace input var with chunk var - if node_idx in chunk_starts: - body[-1] = body[-1].replace(chunk_inputs[region_idx][0].name, 'chunk_tensor') + body[-1] = _replace_name(body[-1], chunk_inputs[region_idx][0].name, 'chunk_tensor') body[-1] = ' ' + body[-1] delete_unused_value_func(node, body, 
chunk_inputs_names) From a9d64377bb237f34fdafaeec2abcfdfb6e080091 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Tue, 6 Dec 2022 17:34:24 +0800 Subject: [PATCH 024/209] support new op --- chunk_codegen.py | 63 ++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 61 insertions(+), 2 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 6740cd44ab6a..2dc44d381d85 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -200,8 +200,12 @@ def _assign_linear_index(self, node, node_idx): Args: node (node) node_idx (int) - """ - input_node, weight, bias = node.args + """ + if len(node.args) == 2: + input_node, weight = node.args + bias = None + else: + input_node, weight, bias = node.args input_node_idx_trace = self._find_idx_trace_from_node(input_node) weight_idx_trace = self._find_idx_trace_from_node(weight) @@ -284,6 +288,53 @@ def _assign_softmax_index(self, node, idx): self._assign_index_as_input(node, idx) self._inherit_computation(node.args[0], node) self._mark_computation(node, idx, [node.kwargs['dim']]) + + def _assign_unsqueeze_index(self, node, node_idx): + """ + Assign index for unsqueeze op. + 1. assign new index for unsqueeze dim + + Args: + node (node) + node_idx (int) + """ + self._assign_index_as_input(node, node_idx) + self._inherit_computation(node.args[0], node) + self.idx_trace_list[node_idx]['idx'].insert(node.args[1], self._add_index()) + + def _assign_dropout_index(self, node, node_idx): + """ + Assign index for unsqueeze op. + 1. assign new index for unsqueeze dim + + Args: + node (node) + node_idx (int) + """ + self._assign_index_as_input(node, node_idx) + + + def _assign_ones_like_index(self, node, node_idx): + """ + Assign index for oneslike op. + 1. assign new index for all dim + + Args: + node (node) + node_idx (int) + """ + self._assign_all_index(node, node_idx) + + def _assign_to_index(self, node, node_idx): + """ + Assign index for to op. + 1. 
assign new index for all dim + + Args: + node (node) + node_idx (int) + """ + self._assign_index_as_input(node, node_idx) def _assign_view_reshape_index(self, node, node_idx): """ @@ -388,6 +439,10 @@ def trace_node_idx(self): self._assign_permute_index(node, idx) elif 'view' in node.name or 'reshape' in node.name: self._assign_view_reshape_index(node, idx) + elif 'unsqueeze' in node.name: + self._assign_unsqueeze_index(node, idx) + elif 'to' in node.name: + self._assign_to_index(node, idx) else: raise NotImplementedError(node.name, "method not implemented yet!") elif node.op == 'call_function': @@ -399,6 +454,10 @@ def trace_node_idx(self): self._assign_softmax_index(node, idx) elif any(n in node.name for n in ['mul', 'add', 'sigmoid', 'relu']): self._assign_elementwise_index(node, idx) + elif 'ones_like' in node.name: + self._assign_ones_like_index(node, idx) + elif 'dropout' in node.name: + self._assign_dropout_index(node, idx) elif 'getattr' in node.name: continue # get attr like shape elif 'getitem' in node.name: From 6d99994a7afbfe290bcd798804b4e1e7e76d1281 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Tue, 6 Dec 2022 17:35:27 +0800 Subject: [PATCH 025/209] rename index tracer --- chunk_codegen.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 2dc44d381d85..0f97f94a9d21 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -17,7 +17,7 @@ def _delete_free_var_from_last_use(user_to_last_uses): user_to_last_uses[key].remove(n) -class NodeIndexTracer(object): +class IndexTracer(object): def __init__(self, gm) -> None: self.gm = gm self.nodes_list = list(gm.graph.nodes) @@ -683,7 +683,7 @@ def __init__(self, gm) -> None: self.gm = gm self.node_list = list(gm.graph.nodes) self.memory_estimator = MemoryEstimator() - self.index_tracer = NodeIndexTracer(gm) + self.index_tracer = IndexTracer(gm) self.index_tracer.trace_node_idx() def _find_peak_node(self, mem_peak): From 
2b4ebcc27839b34c015c4fb79e69abd721b83ee6 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 8 Dec 2022 15:16:10 +0800 Subject: [PATCH 026/209] finishi codegen on msa --- chunk_codegen.py | 212 +++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 188 insertions(+), 24 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 0f97f94a9d21..1e8305ba395b 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -17,6 +17,121 @@ def _delete_free_var_from_last_use(user_to_last_uses): user_to_last_uses[key].remove(n) +class FlowTracer(object): + def __init__(self, gm) -> None: + self.gm = gm + self.nodes_list = list(gm.graph.nodes) + self.flow_trace = {} + + def _add_trace(self, name): + self.flow_trace[name] = [] + + def _add_node(self, trace_name, node): + self.flow_trace[trace_name].append({'node': node, 'inside_depend': [], 'outside_depend': []}) + + def _add_inside_depend(self, flow_name, node, inside_depend_node): + for i in self.flow_trace[flow_name]: + if i['node'] == node: + i['inside_depend'].append(inside_depend_node) + return + raise RuntimeError("node not found") + + def _add_outside_depend(self, flow_name, node, outside_depend_node, outside_depend_trace): + for i in self.flow_trace[flow_name]: + if i['node'] == node: + i['outside_depend'].append({outside_depend_trace: outside_depend_node}) + return + raise RuntimeError("node not found") + + def _init_trace(self): + for i in self.nodes_list: + if i.op == 'placeholder': + self._add_trace(i.name) + self._add_node(i.name, i) + + def _is_non_compute_node(self, node): + if any(i in node.op for i in ['placeholder', 'get_attr', 'output']) or \ + any(i in node.name for i in ['getitem', 'getattr']): + return True + return False + + def _is_non_compute_node_except_placeholder(self, node): + if any(i in node.op for i in ['get_attr', 'output']) or \ + any(i in node.name for i in ['getitem', 'getattr']): + return True + return False + + def _find_flow_for_node(self, node): + if type(self.nodes_list[0]) != 
type(node): + return None + if self._is_non_compute_node_except_placeholder(node): + return None + for name, trace in self.flow_trace.items(): + for i in trace: + if node == i['node']: + return name + if any(i in node.name for i in ["ones_like"]): + self._add_trace(node.name) + self._add_node(node.name, node) + return node.name + raise RuntimeError("node not found") + + def _find_first_valid_flow(self, flow): + for i in flow: + if i is not None: + return i + raise RuntimeError("invalid flow") + + def find_node_flow(self, node): + for name, trace in self.flow_trace.items(): + for i in trace: + if node == i['node']: + return name, i + raise RuntimeError("invalid node") + + def get_flow_mix(self, node): + if self._is_non_compute_node(node): + return None + _, node_trace = self.find_node_flow(node) + if len(node_trace['outside_depend']) == 0: + return None + elif len(node_trace['outside_depend']) > 1: + raise NotImplementedError + vars = list(node_trace['outside_depend'][0].values())[0] + return vars + + def get_same_flow_node(self, node_list, node): + name, _ = self.find_node_flow(node) + result = [] + for i in self.flow_trace[name]: + if i['node'] in node_list: + result.append(i['node']) + return result + + def trace_flow(self): + # init trace + self._init_trace() + + for node in self.nodes_list: + # skip if non compute node + if all(type(arg) != type(node) or self._is_non_compute_node_except_placeholder(arg) for arg in node.args) \ + or self._is_non_compute_node(node): + continue + + node_input_flows = [self._find_flow_for_node(arg) for arg in node.args] + + node_domin_flow = self._find_first_valid_flow(node_input_flows) + self._add_node(node_domin_flow, node) + for node_input_flow, arg in zip(node_input_flows, node.args): + if node_input_flow is None: + continue + elif node_input_flow == node_domin_flow: + self._add_inside_depend(node_domin_flow, node, arg) + else: + self._add_outside_depend(node_domin_flow, node, arg, node_input_flow) + return self.flow_trace + + 
class IndexTracer(object): def __init__(self, gm) -> None: self.gm = gm @@ -428,7 +543,7 @@ def _merge_equal_idx(self): if merge_from in trace['idx']: trace['idx'] = [merge_to if i == merge_from else i for i in trace['idx']] - def trace_node_idx(self): + def trace_index(self): for idx, node in enumerate(self.nodes_list): if node.op == 'placeholder': self._assign_all_index(node, idx) @@ -684,7 +799,9 @@ def __init__(self, gm) -> None: self.node_list = list(gm.graph.nodes) self.memory_estimator = MemoryEstimator() self.index_tracer = IndexTracer(gm) - self.index_tracer.trace_node_idx() + self.index_tracer.trace_index() + self.flow_tracer = FlowTracer(gm) + self.flow_tracer.trace_flow() def _find_peak_node(self, mem_peak): max_value = max(mem_peak) @@ -729,7 +846,7 @@ def _search_max_chunk_region(self, active_node, peak_node): raise RuntimeError() return chunk_region_start, chunk_region_end - def _not_compute(self, trace, chunk_range, dim_idx): + def _is_not_compute(self, trace, chunk_range, dim_idx): if trace['idx'][dim_idx] not in trace['compute']: return True if trace['idx'][dim_idx] in trace['compute'] and \ @@ -737,6 +854,56 @@ def _not_compute(self, trace, chunk_range, dim_idx): return True return False + def _detect_flow(self, before_trace, after_trace, start_idx, end_idx, dim_idx): + inputs, outputs = _find_input_and_output_nodes(self.node_list[start_idx:end_idx + 1]) + chunk_info = {'inputs': inputs, 'outputs': outputs} + flow_flag = False + + for idx in range(start_idx, end_idx + 1): + node = self.node_list[idx] + mix_flow_var = self.flow_tracer.get_flow_mix(node) + if mix_flow_var is None: + continue + + # if there is a flow mix, op must be in [mul, add, div, matmul] + # element-wise op requires dim to be equal in every dim + if any(n in node.name for n in ['mul', 'add']): + for i in node.args: + if type(i) == type(mix_flow_var) and i != mix_flow_var: + main_flow_var = i + # if mix flow is a broadcast in chunk dim, + # TODO need to move that flow out of the 
chunk + if mix_flow_var.meta['tensor_meta'].shape[dim_idx] == 1: + flow_flag = True + for i in self.flow_tracer.get_same_flow_node(chunk_info['inputs'], mix_flow_var): + chunk_info['inputs'].remove(i) + # else, we need to chunk mix var as well + else: + # TODO chunk another value + flow_flag = False + break + else: + raise NotImplementedError("%s not implemented" % node.name) + return flow_flag, chunk_info + + def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx): + before_trace = input_trace[start_idx] + after_trace = output_trace[end_idx] + free_dim = [] + chunk_infos = [] + for i in range(min(len(before_trace['idx']), len(after_trace['idx']))): + if not (before_trace['idx'][i] == after_trace['idx'][i] and + self._is_not_compute(before_trace, (start_idx, end_idx), i) and + self._is_not_compute(after_trace, (start_idx, end_idx), i) and + self.node_list[end_idx].meta['tensor_meta'].shape[i] != 1): + continue + flow_flag, chunk_info = self._detect_flow(before_trace, after_trace, start_idx, end_idx, i) + if flow_flag == None: + continue + chunk_infos.append(chunk_info) + free_dim.append(i) + return free_dim, chunk_infos + def _search_possible_chunk_regions(self, max_chunk_region, peak_node): possible_chunk_region = [] output_trace = copy.deepcopy(self.index_tracer.idx_trace_list) @@ -748,27 +915,22 @@ def _search_possible_chunk_regions(self, max_chunk_region, peak_node): else: input_trace.append(None) - for before_idx in range(max_chunk_region[0], peak_node): - for after_idx in range(peak_node, max_chunk_region[1] + 1): + for start_idx in range(max_chunk_region[0], peak_node): + for end_idx in range(peak_node, max_chunk_region[1] + 1): # skip non compute nodes if any(op in ['placeholder', 'get_attr', 'output'] for op in - [self.node_list[before_idx].op, self.node_list[after_idx].op]): + [self.node_list[start_idx].op, self.node_list[end_idx].op]): continue if any(any(i in name for i in ['getitem', 'getattr']) for name in - 
[self.node_list[before_idx].name, self.node_list[after_idx].name]): + [self.node_list[start_idx].name, self.node_list[end_idx].name]): continue # select free dim - before_trace = input_trace[before_idx] - after_trace = output_trace[after_idx] - free_dim = [] - for i in range(min(len(before_trace['idx']), len(after_trace['idx']))): - if (before_trace['idx'][i] == after_trace['idx'][i] and - self._not_compute(before_trace, (before_idx, after_idx), i) and - self._not_compute(after_trace, (before_idx, after_idx), i) and - self.node_list[after_idx].meta['tensor_meta'].shape[i] != 1): - free_dim.append(i) - possible_chunk_region.append({'region': (before_idx, after_idx), 'dim': free_dim}) + free_dim, chunk_info = self._find_free_dim(input_trace, output_trace, start_idx, end_idx) + if len(free_dim) > 0: + free_dim = [free_dim[0]] + chunk_info = [chunk_info[0]] + possible_chunk_region.append({'region': (start_idx, end_idx), 'dim': free_dim, 'chunk_info': chunk_info}) return possible_chunk_region def _search_best_chunk_region(self, possible_chunk_regions): @@ -935,21 +1097,23 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v chunk_search = chunk_region_search.search_region() chunk_regions = [i['region'] for i in chunk_search] chunk_dims = [i['dim'] for i in chunk_search] + chunk_infos = [i['chunk_info'] for i in chunk_search] chunk_starts = [item[0] for item in chunk_regions] chunk_ends = [item[1] for item in chunk_regions] - chunk_inputs = [] - chunk_outputs = [] + chunk_inputs = [[j['inputs'][0] for j in i] for i in chunk_infos] + chunk_outputs = [[j['outputs'][0] for j in i] for i in chunk_infos] within_chunk_region = False node_list = list(nodes) # find the input and output var names for each offload region - for idx, (start, end) in enumerate(chunk_regions): - offload_node_list = node_list[start:end + 1] - inputs, outputs = _find_input_and_output_nodes(offload_node_list) - chunk_inputs.append(inputs) - chunk_outputs.append(outputs) + # 
for idx, (start, end) in enumerate(chunk_regions): + # offload_node_list = node_list[start:end + 1] + # inputs, outputs = _find_input_and_output_nodes(offload_node_list) + # chunk_inputs.append(inputs) + # chunk_outputs.append(outputs) + chunk_inputs_idx = [[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_inputs] chunk_outputs_idx = [[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_outputs] chunk_inputs_names = [] From 979e61db92a95b8bc2904c5b38264f24060be310 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 9 Dec 2022 17:39:02 +0800 Subject: [PATCH 027/209] redesign index tracer, add source and change compute --- chunk_codegen.py | 310 +++++++++++++++++++++++++++++++---------------- 1 file changed, 206 insertions(+), 104 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 1e8305ba395b..ce7d849178d1 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -16,6 +16,11 @@ def _delete_free_var_from_last_use(user_to_last_uses): if n.op == 'placeholder': user_to_last_uses[key].remove(n) +def _get_node_shape(node): + if hasattr(node.meta['tensor_meta'], "shape"): + return node.meta['tensor_meta'].shape + return None + class FlowTracer(object): def __init__(self, gm) -> None: @@ -136,11 +141,25 @@ class IndexTracer(object): def __init__(self, gm) -> None: self.gm = gm self.nodes_list = list(gm.graph.nodes) - self.idx_trace_list = [{'idx': [], 'compute': {}} for _ in range(len(self.nodes_list))] + self.idx_trace_list = self._init_idx_trace_list() self.idx_trace_equal = [] self.idx_view_list = [] self.idx_count = -1 + def _init_idx_trace_list(self): + idx_trace_list = [] + for n in self.nodes_list: + if _get_node_shape(n) != None: + cur_trace = { + 'idx': [None for _ in range(len(_get_node_shape(n)))], + 'compute': [[] for _ in range(len(_get_node_shape(n)))], + 'source': [[] for _ in range(len(_get_node_shape(n)))], + } + else: + cur_trace = {'idx': [], 'compute': [], 'source': []} + idx_trace_list.append(cur_trace) + return 
idx_trace_list + def _add_index(self): """ Update the count and return it. To record the idx number. @@ -150,35 +169,81 @@ def _add_index(self): """ self.idx_count += 1 return self.idx_count - - def _inherit_computation(self, node_from, node_to): - """ - Inherit computed dim from node_from to node_to. - If a dim in node_from is marked as computed and exists in node_to, - still mark it as computed in node_to. - - Args: - node_from (node): node to be inherited - node_to (node): new node to inherit - """ - _, compute_from = self._find_trace_from_node(node_from) - idx_to, compute_to = self._find_trace_from_node(node_to) - for k, v in compute_from.items(): - if k in idx_to: - if k in compute_to: - compute_to[k].extend(v) - else: - compute_to[k] = copy.deepcopy(v) - def _mark_idx_equal(self, idx1, idx2): + def _del_dim(self, idx, dim_idx): + self.idx_trace_list[idx]['idx'].pop(dim_idx) + self.idx_trace_list[idx]['compute'].pop(dim_idx) + self.idx_trace_list[idx]['source'].pop(dim_idx) + + def _add_dim(self, idx, dim_idx): + self.idx_trace_list[idx]['idx'].insert(dim_idx, self._add_index()) + self.idx_trace_list[idx]['compute'].insert(dim_idx, []) + self.idx_trace_list[idx]['source'].insert(dim_idx, []) + + def _transform_index(self, node, node_dim): + node_idx = self._find_idx_trace_from_node(node) + dims = list(range(len(node_idx))) + return dims[node_dim] + + def _inherit_index(self, node_from, node_from_dim, node_to, node_to_dim): + node_from_dim = self._transform_index(node_from, node_from_dim) + node_to_dim = self._transform_index(node_to, node_to_dim) + node_from_trace = self._find_trace_from_node(node_from) + node_to_trace = self._find_trace_from_node(node_to) + node_to_trace['idx'][node_to_dim] = node_from_trace['idx'][node_from_dim] + node_to_trace['compute'][node_to_dim] = copy.deepcopy(node_from_trace['compute'][node_from_dim]) + node_from_idx = _find_idx_by_name(node_from.name, self.nodes_list) + node_to_trace['source'][node_to_dim] = [] + 
node_to_trace['source'][node_to_dim].append({node_from_idx: node_from_dim}) + node_to_trace['source'][node_to_dim].extend(node_from_trace['source'][node_from_dim]) + + def _inherit_all_computation(self, node_from, node_to): + node_from_compute = self._find_compute_trace_from_node(node_from) + node_to_compute = self._find_compute_trace_from_node(node_to) + assert len(node_from_compute) == len(node_to_compute) + for i in range(len(node_from_compute)): + self._add_source(node_from, i, node_to, i) + node_to_compute[i] = copy.deepcopy(node_from_compute[i]) + + def _add_source(self, node_from, node_from_dim, node_to, node_to_dim): + node_from_dim = self._transform_index(node_from, node_from_dim) + node_from_trace = self._find_trace_from_node(node_from) + node_to_dim = self._transform_index(node_to, node_to_dim) + node_to_trace = self._find_trace_from_node(node_to) + node_from_idx = _find_idx_by_name(node_from.name, self.nodes_list) + node_to_trace['source'][node_to_dim].append({node_from_idx: node_from_dim}) + node_to_trace['source'][node_to_dim].extend(node_from_trace['source'][node_from_dim]) + + def _mark_computation_from_node(self, node_from, node_to, exclude=None): + if exclude == None: + exclude = [] + else: + exclude = [self._transform_index(node_to, i) for i in exclude] + node_from_compute = self._find_compute_trace_from_node(node_from) + node_to_compute = self._find_compute_trace_from_node(node_to) + # assert len(node_from_compute) == len(node_to_compute) + for i in range(-1, -min(len(node_from_compute), len(node_to_compute)) - 1, -1): + if self._transform_index(node_to, i) in exclude: + continue + self._add_source(node_from, i, node_to, i) + for j in node_from_compute[i]: + if j not in node_to_compute[i]: + node_to_compute[i].append(j) + + def _mark_idx_equal(self, node1, dim1, node2, dim2): """ Mark 2 index to be equal. Args: idx1 (int): index count. idx2 (int): index count. 
- """ - self.idx_trace_equal.append((idx1, idx2)) + """ + # node1_idx = _find_idx_by_name(node1.name, self.nodes_list) + # node2_idx = _find_idx_by_name(node2.name, self.nodes_list) + # if node1_idx > node2_idx: + # self._add_source(node2, dim2, node1, dim1) + # else: + # self._add_source(node1, dim1, node2, dim2) def _mark_computation(self, node, idx, dim): """ @@ -189,16 +254,14 @@ def _mark_computation(self, node, idx, dim): idx (int): node index dim (list or int): dims to be marked as computed """ - input_node_idx_trace = self._find_idx_trace_from_node(node) if isinstance(dim, int): dim = [dim] + dims = list(range(len(_get_node_shape(node)))) for d in dim: - cur_idx = input_node_idx_trace[d] - if cur_idx not in self.idx_trace_list[idx]['compute']: - self.idx_trace_list[idx]['compute'][cur_idx] = [idx] - else: - self.idx_trace_list[idx]['compute'][cur_idx].append(idx) - + cur_dim = dims[d] + if idx not in self.idx_trace_list[idx]['compute'][cur_dim]: + self.idx_trace_list[idx]['compute'][cur_dim].append(idx) + def _find_trace_from_node(self, node): """ Find node idx and compute trace by the node. @@ -211,7 +274,7 @@ def _find_trace_from_node(self, node): """ node_idx = _find_idx_by_name(node.name, self.nodes_list) node_dict = self.idx_trace_list[node_idx] - return node_dict['idx'], node_dict['compute'] + return node_dict def _find_idx_trace_from_node(self, node): """ @@ -237,19 +300,23 @@ def _find_compute_trace_from_node(self, node): node_idx = _find_idx_by_name(node.name, self.nodes_list) return self.idx_trace_list[node_idx]['compute'] - def _assign_index_as_input(self, node, node_idx): + def _assign_index_as_input(self, node, node_idx, input_node=None): """ Assign node's trace as its input node. 
Args: node (node) node_idx (int) - """ - input_node_idx = _find_idx_by_name(node.args[0].name, self.nodes_list) + """ + if input_node == None: + input_node = node.args[0] + input_node_idx = _find_idx_by_name(input_node.name, self.nodes_list) input_node_idx_trace = self.idx_trace_list[input_node_idx]['idx'] new_idx_trace = copy.deepcopy(input_node_idx_trace) self.idx_trace_list[node_idx]['idx'] = new_idx_trace + + self._inherit_all_computation(input_node, node) def _assign_all_index(self, node, node_idx): """ @@ -275,15 +342,12 @@ def _assign_transpose_index(self, node, node_idx): node (node) node_idx (int) """ + input_node = node.args[0] tranpose_dim = node.args[1:] - input_node_idx_trace = self._find_idx_trace_from_node(node.args[0]) - new_idx_trace = copy.deepcopy(input_node_idx_trace) - new_idx_trace[tranpose_dim[0]] = input_node_idx_trace[tranpose_dim[1]] - new_idx_trace[tranpose_dim[1]] = input_node_idx_trace[tranpose_dim[0]] - - self.idx_trace_list[node_idx]['idx'] = new_idx_trace - self._inherit_computation(node.args[0], node) + self._assign_index_as_input(node, node_idx, input_node) + self._inherit_index(input_node, tranpose_dim[1], node, tranpose_dim[0]) + self._inherit_index(input_node, tranpose_dim[0], node, tranpose_dim[1]) def _assign_permute_index(self, node, node_idx): """ @@ -296,14 +360,11 @@ def _assign_permute_index(self, node, node_idx): node_idx (int) """ permute_dim = node.args[1:] - input_node_idx_trace = self._find_idx_trace_from_node(node.args[0]) + input_node = node.args[0] - new_idx_trace = copy.deepcopy(input_node_idx_trace) + self._assign_index_as_input(node, node_idx, input_node) for idx, d in enumerate(permute_dim): - new_idx_trace[idx] = input_node_idx_trace[d] - - self.idx_trace_list[node_idx]['idx'] = new_idx_trace - self._inherit_computation(node.args[0], node) + self._inherit_index(input_node, d, node, idx) def _assign_linear_index(self, node, node_idx): """ @@ -321,20 +382,15 @@ def _assign_linear_index(self, node, node_idx): 
bias = None else: input_node, weight, bias = node.args - input_node_idx_trace = self._find_idx_trace_from_node(input_node) - weight_idx_trace = self._find_idx_trace_from_node(weight) - new_idx_trace = copy.deepcopy(input_node_idx_trace) - new_idx_trace[-1] = weight_idx_trace[1] - self.idx_trace_list[node_idx]['idx'] = new_idx_trace + self._assign_index_as_input(node, node_idx) + self._inherit_index(weight, 1, node, -1) - self._inherit_computation(input_node, node) self._mark_computation(node, node_idx, [-1]) - self._mark_idx_equal(input_node_idx_trace[-1], weight_idx_trace[0]) + self._mark_idx_equal(input_node, -1, weight, 0) if bias: - bias_idx_trace = self._find_idx_trace_from_node(bias) - self._mark_idx_equal(input_node_idx_trace[-1], bias_idx_trace[0]) + self._mark_idx_equal(input_node, -1, bias, 0) def _assign_matmul_index(self, node, node_idx): """ @@ -348,18 +404,14 @@ def _assign_matmul_index(self, node, node_idx): node_idx (int) """ matmul_left, matmul_right = node.args - matmul_left_idx_trace = self._find_idx_trace_from_node(matmul_left) - matmul_right_idx_trace = self._find_idx_trace_from_node(matmul_right) - assert(len(matmul_left_idx_trace) == len(matmul_right_idx_trace)) - new_idx_trace = copy.deepcopy(matmul_left_idx_trace) - new_idx_trace[-1] = matmul_right_idx_trace[-1] - self.idx_trace_list[node_idx]['idx'] = new_idx_trace + assert(len(_get_node_shape(matmul_left)) == len(_get_node_shape(matmul_right))) + self._assign_index_as_input(node, node_idx, matmul_left) + self._inherit_index(matmul_right, -1, node, -1) - self._inherit_computation(matmul_left, node) - self._inherit_computation(matmul_right, node) + self._mark_computation_from_node(matmul_right, node, [-1, -2]) self._mark_computation(node, node_idx, [-1]) - self._mark_idx_equal(matmul_left_idx_trace[-1], matmul_right_idx_trace[-2]) + self._mark_idx_equal(matmul_left, -1, matmul_right, -2) def _assign_layernorm_index(self, node, idx): """ @@ -372,7 +424,6 @@ def _assign_layernorm_index(self, 
node, idx): node_idx (int) """ self._assign_index_as_input(node, idx) - self._inherit_computation(node.args[0], node) self._mark_computation(node, idx, [-1, -2]) def _assign_elementwise_index(self, node, idx): @@ -386,9 +437,59 @@ def _assign_elementwise_index(self, node, idx): node_idx (int) """ self._assign_index_as_input(node, idx) + nodes_in = [] for node_in in node.args: - if type(node_in) not in (int, float): - self._inherit_computation(node_in, node) + if type(node_in) == type(node): + nodes_in.append(node_in) + self._mark_computation_from_node(node_in, node) + assert len(nodes_in) <= 2 + if len(nodes_in) == 2: + node_in0_shape = _get_node_shape(nodes_in[0]) + node_in1_shape = _get_node_shape(nodes_in[1]) + for i in range(-1, -min(len(node_in0_shape), len(node_in1_shape)) - 1, -1): + if node_in0_shape[i] == node_in1_shape[i]: + self._mark_idx_equal(nodes_in[0], i, nodes_in[1], i) + + def _assgin_no_change_index(self, node, idx): + self._assign_index_as_input(node, idx) + for node_in in node.args: + if type(node_in) == type(node): + self._mark_computation_from_node(node_in, node) + + def _assign_einsum_index(self, node, idx): + """ + Assign index for einsum op. 
+ + Args: + node (node) + node_idx (int) + """ + patterns = node.args[0] + input_nodes = node.args[1:] + + patterns = patterns.replace(" ", "") + left, right = patterns.split("->") + left = left.split(",") + + all_index = [] + for i in left: + for c in i: + all_index.append(c) + all_index = set(all_index) + free_index = set([i for i in right]) + sum_index = all_index - free_index + + for right_idx, right_indice in enumerate(right): + for left_idx, left_str in enumerate(left): + if right_indice in left_str: + source_idx = left_str.index(right_indice) + self._inherit_index(input_nodes[left_idx], source_idx, node, right_idx) + + for i in sum_index: + for left_idx, left_str in enumerate(left): + if i in left_str: + self._mark_computation(node, idx, left_str.index(i)) + break def _assign_softmax_index(self, node, idx): """ @@ -401,7 +502,6 @@ def _assign_softmax_index(self, node, idx): node_idx (int) """ self._assign_index_as_input(node, idx) - self._inherit_computation(node.args[0], node) self._mark_computation(node, idx, [node.kwargs['dim']]) def _assign_unsqueeze_index(self, node, node_idx): @@ -412,10 +512,12 @@ def _assign_unsqueeze_index(self, node, node_idx): Args: node (node) node_idx (int) - """ + """ + self._del_dim(node_idx, -1) self._assign_index_as_input(node, node_idx) - self._inherit_computation(node.args[0], node) self.idx_trace_list[node_idx]['idx'].insert(node.args[1], self._add_index()) + self.idx_trace_list[node_idx]['compute'].insert(node.args[1], []) + self.idx_trace_list[node_idx]['source'].insert(node.args[1], []) def _assign_dropout_index(self, node, node_idx): """ @@ -427,7 +529,6 @@ def _assign_dropout_index(self, node, node_idx): node_idx (int) """ self._assign_index_as_input(node, node_idx) - def _assign_ones_like_index(self, node, node_idx): """ @@ -439,17 +540,6 @@ def _assign_ones_like_index(self, node, node_idx): node_idx (int) """ self._assign_all_index(node, node_idx) - - def _assign_to_index(self, node, node_idx): - """ - Assign index 
for to op. - 1. assign new index for all dim - - Args: - node (node) - node_idx (int) - """ - self._assign_index_as_input(node, node_idx) def _assign_view_reshape_index(self, node, node_idx): """ @@ -494,26 +584,26 @@ def _assign_view_reshape_index(self, node, node_idx): dim_equal = [i == j for i, j in zip(origin_shape[:-1], target_shape)] dim_to = [dim_equal.index(False)] dim_from = [dim_equal.index(False), dim_equal.index(False) + 1] + self._add_dim(node_idx, -1) elif len_diff == -1: # dim expand dim_equal = [i == j for i, j in zip(origin_shape, target_shape[:-1])] dim_from = [dim_equal.index(False)] dim_to = [dim_equal.index(False), dim_equal.index(False) + 1] + self._del_dim(node_idx, -1) else: raise NotImplementedError("shape" + str(origin_shape) + 'and' + str(target_shape) + "view not implemented") # get new index origin_trace = self._find_idx_trace_from_node(origin_node) - new_trace = copy.deepcopy(origin_trace) + self._assign_index_as_input(node, node_idx, origin_node) dim_from.reverse() for i in dim_from: - new_trace.pop(i) + self._del_dim(node_idx, i) for i in dim_to: - new_trace.insert(i, self._add_index()) - self.idx_trace_list[node_idx]['idx'] = new_trace + self._add_dim(node_idx, i) # inherit computation - self._inherit_computation(origin_node, node) compute_log = self._find_compute_trace_from_node(origin_node) for i in dim_from: if origin_trace[i] in compute_log: @@ -524,15 +614,10 @@ def _assign_view_reshape_index(self, node, node_idx): # log view, not used now view_dict = {"idx_from": [origin_trace[i] for i in dim_from], "dim_from": dim_from, - "idx_to": [new_trace[i] for i in dim_to], + "idx_to": [self.idx_trace_list[node_idx]['idx'][i] for i in dim_to], "dim_to": dim_to} self.idx_view_list.append(view_dict) - - def _remove_duplicate_compute(self): - for i in self.idx_trace_list: - for k, v in i['compute'].items(): - i['compute'][k] = list(set(v)) - + def _merge_equal_idx(self): idx_equal = copy.deepcopy(self.idx_trace_equal) idx_equal.reverse() 
@@ -556,8 +641,8 @@ def trace_index(self): self._assign_view_reshape_index(node, idx) elif 'unsqueeze' in node.name: self._assign_unsqueeze_index(node, idx) - elif 'to' in node.name: - self._assign_to_index(node, idx) + elif any(i in node.name for i in ['to', 'contiguous']): + self._assgin_no_change_index(node, idx) else: raise NotImplementedError(node.name, "method not implemented yet!") elif node.op == 'call_function': @@ -573,6 +658,8 @@ def trace_index(self): self._assign_ones_like_index(node, idx) elif 'dropout' in node.name: self._assign_dropout_index(node, idx) + elif 'einsum' in node.name: + self._assign_einsum_index(node, idx) elif 'getattr' in node.name: continue # get attr like shape elif 'getitem' in node.name: @@ -590,10 +677,20 @@ def trace_index(self): continue else: raise NotImplementedError(node.op, "op not implemented yet!") - - self._remove_duplicate_compute() - self._merge_equal_idx() - + # self._merge_equal_idx() + + def check_index(self, trace_idx, start_idx, end_idx): + for i in range(start_idx, end_idx + 1): + cur_idx = self.idx_trace_list[i]['idx'] + cur_compute = self.idx_trace_list[i]['compute'] + if trace_idx in cur_compute: + for j in cur_compute[trace_idx]: + if j < start_idx or j > end_idx: + return False + # same_idx = [1 if j == trace_idx else 0 for j in cur_idx] + # if sum(same_idx) > 1: + # return False + return True class MemoryEstimator(object): def __init__(self) -> None: @@ -897,6 +994,8 @@ def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx): self._is_not_compute(after_trace, (start_idx, end_idx), i) and self.node_list[end_idx].meta['tensor_meta'].shape[i] != 1): continue + if not self.index_tracer.check_index(before_trace['idx'][i], start_idx, end_idx): + continue flow_flag, chunk_info = self._detect_flow(before_trace, after_trace, start_idx, end_idx, i) if flow_flag == None: continue @@ -910,7 +1009,10 @@ def _search_possible_chunk_regions(self, max_chunk_region, peak_node): input_trace = [] for i, n in 
enumerate(self.node_list): if len(n.args) > 0 and n.op != 'output': - input_idx = _find_idx_by_name(n.args[0].name, self.node_list) + if isinstance(n.args[0], str): + input_idx = _find_idx_by_name(n.args[1].name, self.node_list) + else: + input_idx = _find_idx_by_name(n.args[0].name, self.node_list) input_trace.append(output_trace[input_idx]) else: input_trace.append(None) @@ -930,7 +1032,7 @@ def _search_possible_chunk_regions(self, max_chunk_region, peak_node): if len(free_dim) > 0: free_dim = [free_dim[0]] chunk_info = [chunk_info[0]] - possible_chunk_region.append({'region': (start_idx, end_idx), 'dim': free_dim, 'chunk_info': chunk_info}) + possible_chunk_region.append({'region': (start_idx, end_idx), 'dim': free_dim, 'chunk_info': chunk_info}) return possible_chunk_region def _search_best_chunk_region(self, possible_chunk_regions): @@ -1130,6 +1232,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v if node_idx in chunk_starts: within_chunk_region = True + region_idx = chunk_starts.index(node_idx) # add for loop chunk_input_meta = [meta_nodes[i] for i in chunk_inputs_idx[region_idx]] @@ -1150,7 +1253,6 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v if node_idx in chunk_ends: body.append(_gen_loop_end(node, chunk_inputs[region_idx], node_list, chunk_dims[region_idx])) within_chunk_region = False - region_idx += 1 node_idx += 1 From 929445116a14d30ebbd50c5978a8f4db52ab3cd6 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Sat, 10 Dec 2022 17:29:51 +0800 Subject: [PATCH 028/209] pass outproduct mean --- chunk_codegen.py | 317 +++++++++++++++++++++++++++++++---------------- 1 file changed, 212 insertions(+), 105 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index ce7d849178d1..fc3c88cf91f6 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -16,16 +16,31 @@ def _delete_free_var_from_last_use(user_to_last_uses): if n.op == 'placeholder': user_to_last_uses[key].remove(n) + def 
_get_node_shape(node): if hasattr(node.meta['tensor_meta'], "shape"): return node.meta['tensor_meta'].shape return None +def _is_non_compute_node(node): + if any(i in node.op for i in ['placeholder', 'get_attr', 'output']) or \ + any(i in node.name for i in ['getitem', 'getattr']): + return True + return False + + +def _is_non_compute_node_except_placeholder(node): + if any(i in node.op for i in ['get_attr', 'output']) or \ + any(i in node.name for i in ['getitem', 'getattr']): + return True + return False + + class FlowTracer(object): def __init__(self, gm) -> None: self.gm = gm - self.nodes_list = list(gm.graph.nodes) + self.node_list = list(gm.graph.nodes) self.flow_trace = {} def _add_trace(self, name): @@ -49,7 +64,7 @@ def _add_outside_depend(self, flow_name, node, outside_depend_node, outside_depe raise RuntimeError("node not found") def _init_trace(self): - for i in self.nodes_list: + for i in self.node_list: if i.op == 'placeholder': self._add_trace(i.name) self._add_node(i.name, i) @@ -67,7 +82,7 @@ def _is_non_compute_node_except_placeholder(self, node): return False def _find_flow_for_node(self, node): - if type(self.nodes_list[0]) != type(node): + if type(self.node_list[0]) != type(node): return None if self._is_non_compute_node_except_placeholder(node): return None @@ -117,7 +132,7 @@ def trace_flow(self): # init trace self._init_trace() - for node in self.nodes_list: + for node in self.node_list: # skip if non compute node if all(type(arg) != type(node) or self._is_non_compute_node_except_placeholder(arg) for arg in node.args) \ or self._is_non_compute_node(node): @@ -135,6 +150,41 @@ def trace_flow(self): else: self._add_outside_depend(node_domin_flow, node, arg, node_input_flow) return self.flow_trace + + def _detect_flow(self, start_idx, start_dim, end_idx, end_dim): + inputs, outputs = _find_chunk_input_and_output_nodes(self.node_list[start_idx:end_idx + 1]) + chunk_info = {'region': (start_idx, end_idx), + 'inputs': inputs, 'inputs_dim': 
start_dim, + 'outputs': outputs, 'outputs_dim': end_dim, + 'args': {}} + flow_flag = False + + for idx in range(start_idx, end_idx + 1): + node = self.node_list[idx] + mix_flow_var = self.get_flow_mix(node) + if mix_flow_var is None: + continue + + # if there is a flow mix, op must be in [mul, add, div, matmul] + # element-wise op requires dim to be equal in every dim + if any(n in node.name for n in ['mul', 'add']): + for i in node.args: + if type(i) == type(mix_flow_var) and i != mix_flow_var: + main_flow_var = i + # if mix flow is a broadcast in chunk dim, + # TODO need to move that flow out of the chunk + if mix_flow_var.meta['tensor_meta'].shape[dim_idx] == 1: + flow_flag = True + for i in self.get_same_flow_node(chunk_info['inputs'], mix_flow_var): + chunk_info['inputs'].remove(i) + # else, we need to chunk mix var as well + else: + # TODO chunk another value + flow_flag = False + break + else: + raise NotImplementedError("%s not implemented" % node.name) + return flow_flag, chunk_info class IndexTracer(object): @@ -153,7 +203,7 @@ def _init_idx_trace_list(self): cur_trace = { 'idx': [None for _ in range(len(_get_node_shape(n)))], 'compute': [[] for _ in range(len(_get_node_shape(n)))], - 'source': [[] for _ in range(len(_get_node_shape(n)))], + 'source': [{} for _ in range(len(_get_node_shape(n)))], } else: cur_trace = {'idx': [], 'compute': [], 'source': []} @@ -178,7 +228,7 @@ def _del_dim(self, idx, dim_idx): def _add_dim(self, idx, dim_idx): self.idx_trace_list[idx]['idx'].insert(dim_idx, self._add_index()) self.idx_trace_list[idx]['compute'].insert(dim_idx, []) - self.idx_trace_list[idx]['source'].insert(dim_idx, []) + self.idx_trace_list[idx]['source'].insert(dim_idx, {}) def _transform_index(self, node, node_dim): node_idx = self._find_idx_trace_from_node(node) @@ -192,10 +242,7 @@ def _inherit_index(self, node_from, node_from_dim, node_to, node_to_dim): node_to_trace = self._find_trace_from_node(node_to) node_to_trace['idx'][node_to_dim] = 
node_from_trace['idx'][node_from_dim] node_to_trace['compute'][node_to_dim] = copy.deepcopy(node_from_trace['compute'][node_from_dim]) - node_from_idx = _find_idx_by_name(node_from.name, self.nodes_list) - node_to_trace['source'][node_to_dim] = [] - node_to_trace['source'][node_to_dim].append({node_from_idx: node_from_dim}) - node_to_trace['source'][node_to_dim].extend(node_from_trace['source'][node_from_dim]) + self._add_source(node_from, node_from_dim, node_to, node_to_dim, init=True) def _inherit_all_computation(self, node_from, node_to): node_from_compute = self._find_compute_trace_from_node(node_from) @@ -205,14 +252,16 @@ def _inherit_all_computation(self, node_from, node_to): self._add_source(node_from, i, node_to, i) node_to_compute[i] = copy.deepcopy(node_from_compute[i]) - def _add_source(self, node_from, node_from_dim, node_to, node_to_dim): + def _add_source(self, node_from, node_from_dim, node_to, node_to_dim, init=False): node_from_dim = self._transform_index(node_from, node_from_dim) node_from_trace = self._find_trace_from_node(node_from) node_to_dim = self._transform_index(node_to, node_to_dim) node_to_trace = self._find_trace_from_node(node_to) node_from_idx = _find_idx_by_name(node_from.name, self.nodes_list) - node_to_trace['source'][node_to_dim].append({node_from_idx: node_from_dim}) - node_to_trace['source'][node_to_dim].extend(node_from_trace['source'][node_from_dim]) + if init: + node_to_trace['source'][node_to_dim] = {} + node_to_trace['source'][node_to_dim][node_from_idx] = node_from_dim + node_to_trace['source'][node_to_dim].update(node_from_trace['source'][node_from_dim]) def _mark_computation_from_node(self, node_from, node_to, exclude=None): if exclude == None: @@ -485,11 +534,11 @@ def _assign_einsum_index(self, node, idx): source_idx = left_str.index(right_indice) self._inherit_index(input_nodes[left_idx], source_idx, node, right_idx) - for i in sum_index: - for left_idx, left_str in enumerate(left): - if i in left_str: - 
self._mark_computation(node, idx, left_str.index(i)) - break + # for i in sum_index: + # for left_idx, left_str in enumerate(left): + # if i in left_str: + # self._mark_computation(node, idx, left_str.index(i)) + # break def _assign_softmax_index(self, node, idx): """ @@ -679,18 +728,56 @@ def trace_index(self): raise NotImplementedError(node.op, "op not implemented yet!") # self._merge_equal_idx() - def check_index(self, trace_idx, start_idx, end_idx): - for i in range(start_idx, end_idx + 1): - cur_idx = self.idx_trace_list[i]['idx'] - cur_compute = self.idx_trace_list[i]['compute'] - if trace_idx in cur_compute: - for j in cur_compute[trace_idx]: - if j < start_idx or j > end_idx: - return False - # same_idx = [1 if j == trace_idx else 0 for j in cur_idx] - # if sum(same_idx) > 1: - # return False + def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node): + """ + Check 2 given index: one index should be source of the other + Args: + start_idx(int): start node chunk dim + start_node(node): start node + end_idx(int): end node chunk dim + end_node(node): end node + + Returns: + bool: True if check pass + """ + start_node_idx = _find_idx_by_name(start_node.name, self.nodes_list) + end_node_trace = self._find_trace_from_node(end_node) + end_node_trace_source = end_node_trace['source'][end_dim] + sorted_source = sorted(end_node_trace_source.items(), key=lambda d:d[0], reverse=True) + for node_idx, node_dim in sorted_source: + if node_idx == start_node_idx and node_dim == start_dim: + return True + # it means we meet a node outside the loop, and the node is not input node + if node_idx < start_idx: + return False + return False + + def check_index_compute(self, start_idx, end_dim, end_node, end_idx): + """ + Check 2 given index: check they haven't been computed in the source trace. 
+ Args: + start_idx(int): start node chunk dim + start_node(node): start node + end_idx(int): end node chunk dim + end_node(node): end node + + Returns: + bool: True if check pass + """ + end_node_trace = self._find_trace_from_node(end_node) + end_node_compute = end_node_trace['compute'][end_dim] + if any(start_idx <= i <= end_idx for i in end_node_compute): + return False return True + # end_node_trace_source = end_node_trace['source'][end_dim] + # for node_idx, node_dim in end_node_trace_source.items(): + # if node_idx < start_node_idx or node_idx > end_node_idx: + # continue + # compute_list = self.idx_trace_list[node_idx]['compute'][node_dim] + # if any(start_node_idx <= i <= end_node_idx for i in compute_list): + # return False + # return True + class MemoryEstimator(object): def __init__(self) -> None: @@ -951,88 +1038,81 @@ def _is_not_compute(self, trace, chunk_range, dim_idx): return True return False - def _detect_flow(self, before_trace, after_trace, start_idx, end_idx, dim_idx): - inputs, outputs = _find_input_and_output_nodes(self.node_list[start_idx:end_idx + 1]) - chunk_info = {'inputs': inputs, 'outputs': outputs} - flow_flag = False - - for idx in range(start_idx, end_idx + 1): - node = self.node_list[idx] - mix_flow_var = self.flow_tracer.get_flow_mix(node) - if mix_flow_var is None: - continue - - # if there is a flow mix, op must be in [mul, add, div, matmul] - # element-wise op requires dim to be equal in every dim - if any(n in node.name for n in ['mul', 'add']): - for i in node.args: - if type(i) == type(mix_flow_var) and i != mix_flow_var: - main_flow_var = i - # if mix flow is a broadcast in chunk dim, - # TODO need to move that flow out of the chunk - if mix_flow_var.meta['tensor_meta'].shape[dim_idx] == 1: - flow_flag = True - for i in self.flow_tracer.get_same_flow_node(chunk_info['inputs'], mix_flow_var): - chunk_info['inputs'].remove(i) - # else, we need to chunk mix var as well - else: - # TODO chunk another value - flow_flag = False 
- break - else: - raise NotImplementedError("%s not implemented" % node.name) - return flow_flag, chunk_info + def _check_duplicate_map(self, chunk_infos): + dim_map = [(i['inputs_dim'], i['outputs_dim']) for i in chunk_infos] + remove_list = [] + for idx1, (input_dim1, output_dim1) in enumerate(dim_map): + for idx2, (input_dim2, output_dim2) in enumerate(dim_map): + if idx1 == idx2: + continue + # it means an index create 2 copy of itself + # eg. a = torch.matmul(x, x.transpose(-1, -2)) + # TODO currently remove it, deal with this in future + if input_dim1 == input_dim2 and output_dim1 != output_dim2: + remove_list.append(chunk_infos[idx1]) + remove_list.append(chunk_infos[idx2]) + for i in remove_list: + if i in chunk_infos: + chunk_infos.remove(i) + return chunk_infos def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx): - before_trace = input_trace[start_idx] - after_trace = output_trace[end_idx] - free_dim = [] + start_traces = input_trace[start_idx] + end_trace = output_trace[end_idx] + end_node = self.node_list[end_idx] chunk_infos = [] - for i in range(min(len(before_trace['idx']), len(after_trace['idx']))): - if not (before_trace['idx'][i] == after_trace['idx'][i] and - self._is_not_compute(before_trace, (start_idx, end_idx), i) and - self._is_not_compute(after_trace, (start_idx, end_idx), i) and - self.node_list[end_idx].meta['tensor_meta'].shape[i] != 1): - continue - if not self.index_tracer.check_index(before_trace['idx'][i], start_idx, end_idx): + for end_dim, end_trace_idx in enumerate(end_trace['idx']): + if len(start_traces) > 1: + # TODO implement multi input chunk continue - flow_flag, chunk_info = self._detect_flow(before_trace, after_trace, start_idx, end_idx, i) - if flow_flag == None: - continue - chunk_infos.append(chunk_info) - free_dim.append(i) - return free_dim, chunk_infos + for start_node, start_trace in start_traces.items(): + for start_dim, start_trace_idx in enumerate(start_trace['idx']): + # must be same trace 
idx + if start_trace_idx != end_trace_idx: + continue + # dim size cannot be 1 + if _get_node_shape(end_node)[end_dim] == 1 or \ + _get_node_shape(start_node)[start_dim] == 1: + continue + # check index source align + if not self.index_tracer.check_index_source( + start_dim, start_node, start_idx, end_dim, end_node): + continue + # check index copmute + if not self.index_tracer.check_index_compute( + start_idx, end_dim, end_node, end_idx): + continue + # detect flow meet + flow_flag, chunk_info = self.flow_tracer._detect_flow( + start_idx, start_dim, end_idx, end_dim) + if flow_flag: + continue + chunk_infos.append(chunk_info) + chunk_infos = self._check_duplicate_map(chunk_infos) + return chunk_infos def _search_possible_chunk_regions(self, max_chunk_region, peak_node): possible_chunk_region = [] output_trace = copy.deepcopy(self.index_tracer.idx_trace_list) - input_trace = [] - for i, n in enumerate(self.node_list): - if len(n.args) > 0 and n.op != 'output': - if isinstance(n.args[0], str): - input_idx = _find_idx_by_name(n.args[1].name, self.node_list) - else: - input_idx = _find_idx_by_name(n.args[0].name, self.node_list) - input_trace.append(output_trace[input_idx]) - else: - input_trace.append(None) - - for start_idx in range(max_chunk_region[0], peak_node): + input_trace = [] # trace of a node's input nodes + for _, n in enumerate(self.node_list): + cur_trace = {} + for arg in n.args: + if type(arg) == type(n) and not _is_non_compute_node_except_placeholder(arg): + cur_trace[arg] = self.index_tracer._find_trace_from_node(arg) + input_trace.append(cur_trace) + + for start_idx in range(max_chunk_region[0], peak_node + 1): for end_idx in range(peak_node, max_chunk_region[1] + 1): # skip non compute nodes - if any(op in ['placeholder', 'get_attr', 'output'] for op in - [self.node_list[start_idx].op, self.node_list[end_idx].op]): - continue - if any(any(i in name for i in ['getitem', 'getattr']) for name in - [self.node_list[start_idx].name, 
self.node_list[end_idx].name]): + if _is_non_compute_node(self.node_list[start_idx]) or \ + _is_non_compute_node(self.node_list[end_idx]): continue # select free dim - free_dim, chunk_info = self._find_free_dim(input_trace, output_trace, start_idx, end_idx) - if len(free_dim) > 0: - free_dim = [free_dim[0]] - chunk_info = [chunk_info[0]] - possible_chunk_region.append({'region': (start_idx, end_idx), 'dim': free_dim, 'chunk_info': chunk_info}) + chunk_info = self._find_free_dim(input_trace, output_trace, start_idx, end_idx) + if len(chunk_info) > 0: + possible_chunk_region.extend(chunk_info) return possible_chunk_region def _search_best_chunk_region(self, possible_chunk_regions): @@ -1044,7 +1124,8 @@ def _search_best_chunk_region(self, possible_chunk_regions): max_region_range = i['region'][1] - i['region'][0] return best_regions - def _step_search(self, peak_node, active_node): + def _step_search(self, mem_peak, active_node): + peak_node = self._find_peak_node(mem_peak) max_chunk_region = self._search_max_chunk_region(active_node, peak_node) possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node) best_chunk_region = self._search_best_chunk_region(possible_chunk_regions) @@ -1062,19 +1143,16 @@ def search_region(self): mem_peak = init_mem_peak while True: - peak_node = self._find_peak_node(mem_peak) - chunk_region = self._step_search(peak_node, active_node) - if chunk_region is None or len(chunk_region['dim']) == 0: + chunk_region = self._step_search(mem_peak, active_node) + if chunk_region is None: break chunk_regions.append(chunk_region) mem_peak, _, active_node = self.memory_estimator.estimate_chunk_inference_mem( self.gm, [i['region'][0] for i in chunk_regions], - [i['region'][1] for i in chunk_regions], [i['dim'][0] for i in chunk_regions], [1] * len(chunk_regions)) - + [i['region'][1] for i in chunk_regions], [i['inputs_dim'] for i in chunk_regions], [1] * len(chunk_regions)) if self._stop_search(init_mem_peak, mem_peak): 
break - return chunk_regions @@ -1164,6 +1242,35 @@ def _find_input_and_output_nodes(nodes: List[Node]): return input_nodes, output_nodes +def _find_chunk_input_and_output_nodes(nodes: List[Node]): + """ + Find non-compute input and output node names. + input nodes are nodes used in the list + output nodes are nodes will use nodes in the list + """ + input_nodes = [] + output_nodes = [] + + # if a node has an input node which is not in the node list + # we treat that input node as the input of the checkpoint function + for node in nodes: + for input_node in node._input_nodes.keys(): + if input_node not in nodes and input_node not in input_nodes \ + and not _is_non_compute_node_except_placeholder(input_node): + input_nodes.append(input_node) + + # if a node has a user node which is not in the node list + # we treat that user node as the node receiving the current node output + # TODO it is unsafe to remove non compute node here + for node in nodes: + for output_node in node.users.keys(): + if output_node not in nodes and node not in output_nodes \ + and not _is_non_compute_node_except_placeholder(input_node): + output_nodes.append(node) + + return input_nodes, output_nodes + + def _find_idx_by_name(name, nodes_list): for idx, node in enumerate(nodes_list): if node.name == name: From d31e146687ebd4cefdc67500e84b7414b5760dd4 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Sat, 10 Dec 2022 17:34:40 +0800 Subject: [PATCH 029/209] code format --- chunk_codegen.py | 908 +++++++++++++++++++++++++++++------------------ 1 file changed, 560 insertions(+), 348 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index fc3c88cf91f6..e8cf0d22f157 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -4,35 +4,52 @@ from typing import List, Callable, Any, Tuple, Dict, Iterable from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name -from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, 
CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin -from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, parameter_size, activation_size +from torch.fx.graph import ( + _Namespace, + PythonCode, + _custom_builtins, + _is_from_torch, + _format_target, + magic_methods, + CodeGen, + _origin_type_map, + inplace_methods, + _CustomBuiltin, +) +from colossalai.fx.profiler import ( + calculate_fwd_out, + calculate_fwd_tmp, + parameter_size, + activation_size, +) + CODEGEN_AVAILABLE = True -__all__ = ['ChunkCodeGen'] +__all__ = ["ChunkCodeGen"] def _delete_free_var_from_last_use(user_to_last_uses): for key, value in user_to_last_uses.items(): for n in value: - if n.op == 'placeholder': + if n.op == "placeholder": user_to_last_uses[key].remove(n) def _get_node_shape(node): - if hasattr(node.meta['tensor_meta'], "shape"): - return node.meta['tensor_meta'].shape + if hasattr(node.meta["tensor_meta"], "shape"): + return node.meta["tensor_meta"].shape return None def _is_non_compute_node(node): - if any(i in node.op for i in ['placeholder', 'get_attr', 'output']) or \ - any(i in node.name for i in ['getitem', 'getattr']): + if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any( + i in node.name for i in ["getitem", "getattr"] + ): return True return False - - + + def _is_non_compute_node_except_placeholder(node): - if any(i in node.op for i in ['get_attr', 'output']) or \ - any(i in node.name for i in ['getitem', 'getattr']): + if (any(i in node.op for i in ["get_attr", "output"]) or any(i in node.name for i in ["getitem", "getattr"])): return True return False @@ -45,42 +62,48 @@ def __init__(self, gm) -> None: def _add_trace(self, name): self.flow_trace[name] = [] - + def _add_node(self, trace_name, node): - self.flow_trace[trace_name].append({'node': node, 'inside_depend': [], 'outside_depend': []}) - + self.flow_trace[trace_name].append( + {"node": node, "inside_depend": [], "outside_depend": []} + ) + def 
_add_inside_depend(self, flow_name, node, inside_depend_node): for i in self.flow_trace[flow_name]: - if i['node'] == node: - i['inside_depend'].append(inside_depend_node) + if i["node"] == node: + i["inside_depend"].append(inside_depend_node) return raise RuntimeError("node not found") - - def _add_outside_depend(self, flow_name, node, outside_depend_node, outside_depend_trace): + + def _add_outside_depend( + self, flow_name, node, outside_depend_node, outside_depend_trace + ): for i in self.flow_trace[flow_name]: - if i['node'] == node: - i['outside_depend'].append({outside_depend_trace: outside_depend_node}) + if i["node"] == node: + i["outside_depend"].append({outside_depend_trace: outside_depend_node}) return raise RuntimeError("node not found") def _init_trace(self): for i in self.node_list: - if i.op == 'placeholder': + if i.op == "placeholder": self._add_trace(i.name) self._add_node(i.name, i) def _is_non_compute_node(self, node): - if any(i in node.op for i in ['placeholder', 'get_attr', 'output']) or \ - any(i in node.name for i in ['getitem', 'getattr']): + if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any( + i in node.name for i in ["getitem", "getattr"] + ): return True return False - + def _is_non_compute_node_except_placeholder(self, node): - if any(i in node.op for i in ['get_attr', 'output']) or \ - any(i in node.name for i in ['getitem', 'getattr']): + if any(i in node.op for i in ["get_attr", "output"]) or any( + i in node.name for i in ["getitem", "getattr"] + ): return True return False - + def _find_flow_for_node(self, node): if type(self.node_list[0]) != type(node): return None @@ -88,54 +111,57 @@ def _find_flow_for_node(self, node): return None for name, trace in self.flow_trace.items(): for i in trace: - if node == i['node']: + if node == i["node"]: return name if any(i in node.name for i in ["ones_like"]): self._add_trace(node.name) self._add_node(node.name, node) return node.name raise RuntimeError("node not 
found") - + def _find_first_valid_flow(self, flow): for i in flow: if i is not None: return i raise RuntimeError("invalid flow") - + def find_node_flow(self, node): for name, trace in self.flow_trace.items(): for i in trace: - if node == i['node']: + if node == i["node"]: return name, i raise RuntimeError("invalid node") - + def get_flow_mix(self, node): if self._is_non_compute_node(node): return None _, node_trace = self.find_node_flow(node) - if len(node_trace['outside_depend']) == 0: + if len(node_trace["outside_depend"]) == 0: return None - elif len(node_trace['outside_depend']) > 1: + elif len(node_trace["outside_depend"]) > 1: raise NotImplementedError - vars = list(node_trace['outside_depend'][0].values())[0] + vars = list(node_trace["outside_depend"][0].values())[0] return vars - + def get_same_flow_node(self, node_list, node): name, _ = self.find_node_flow(node) result = [] for i in self.flow_trace[name]: - if i['node'] in node_list: - result.append(i['node']) + if i["node"] in node_list: + result.append(i["node"]) return result - - def trace_flow(self): + + def trace_flow(self): # init trace self._init_trace() for node in self.node_list: # skip if non compute node - if all(type(arg) != type(node) or self._is_non_compute_node_except_placeholder(arg) for arg in node.args) \ - or self._is_non_compute_node(node): + if all( + type(arg) != type(node) + or self._is_non_compute_node_except_placeholder(arg) + for arg in node.args + ) or self._is_non_compute_node(node): continue node_input_flows = [self._find_flow_for_node(arg) for arg in node.args] @@ -148,35 +174,45 @@ def trace_flow(self): elif node_input_flow == node_domin_flow: self._add_inside_depend(node_domin_flow, node, arg) else: - self._add_outside_depend(node_domin_flow, node, arg, node_input_flow) + self._add_outside_depend( + node_domin_flow, node, arg, node_input_flow + ) return self.flow_trace - + def _detect_flow(self, start_idx, start_dim, end_idx, end_dim): - inputs, outputs = 
_find_chunk_input_and_output_nodes(self.node_list[start_idx:end_idx + 1]) - chunk_info = {'region': (start_idx, end_idx), - 'inputs': inputs, 'inputs_dim': start_dim, - 'outputs': outputs, 'outputs_dim': end_dim, - 'args': {}} + inputs, outputs = _find_chunk_input_and_output_nodes( + self.node_list[start_idx : end_idx + 1] + ) + chunk_info = { + "region": (start_idx, end_idx), + "inputs": inputs, + "inputs_dim": start_dim, + "outputs": outputs, + "outputs_dim": end_dim, + "args": {}, + } flow_flag = False - + for idx in range(start_idx, end_idx + 1): node = self.node_list[idx] mix_flow_var = self.get_flow_mix(node) if mix_flow_var is None: continue - + # if there is a flow mix, op must be in [mul, add, div, matmul] # element-wise op requires dim to be equal in every dim - if any(n in node.name for n in ['mul', 'add']): + if any(n in node.name for n in ["mul", "add"]): for i in node.args: if type(i) == type(mix_flow_var) and i != mix_flow_var: main_flow_var = i - # if mix flow is a broadcast in chunk dim, + # if mix flow is a broadcast in chunk dim, # TODO need to move that flow out of the chunk - if mix_flow_var.meta['tensor_meta'].shape[dim_idx] == 1: + if mix_flow_var.meta["tensor_meta"].shape[dim_idx] == 1: flow_flag = True - for i in self.get_same_flow_node(chunk_info['inputs'], mix_flow_var): - chunk_info['inputs'].remove(i) + for i in self.get_same_flow_node( + chunk_info["inputs"], mix_flow_var + ): + chunk_info["inputs"].remove(i) # else, we need to chunk mix var as well else: # TODO chunk another value @@ -199,51 +235,53 @@ def __init__(self, gm) -> None: def _init_idx_trace_list(self): idx_trace_list = [] for n in self.nodes_list: - if _get_node_shape(n) != None: + if _get_node_shape(n) != None: cur_trace = { - 'idx': [None for _ in range(len(_get_node_shape(n)))], - 'compute': [[] for _ in range(len(_get_node_shape(n)))], - 'source': [{} for _ in range(len(_get_node_shape(n)))], + "idx": [None for _ in range(len(_get_node_shape(n)))], + "compute": [[] 
for _ in range(len(_get_node_shape(n)))], + "source": [{} for _ in range(len(_get_node_shape(n)))], } else: - cur_trace = {'idx': [], 'compute': [], 'source': []} + cur_trace = {"idx": [], "compute": [], "source": []} idx_trace_list.append(cur_trace) return idx_trace_list - + def _add_index(self): """ Update the count and return it. To record the idx number. - + Returns: idx_count: int - """ + """ self.idx_count += 1 return self.idx_count - + def _del_dim(self, idx, dim_idx): - self.idx_trace_list[idx]['idx'].pop(dim_idx) - self.idx_trace_list[idx]['compute'].pop(dim_idx) - self.idx_trace_list[idx]['source'].pop(dim_idx) - + self.idx_trace_list[idx]["idx"].pop(dim_idx) + self.idx_trace_list[idx]["compute"].pop(dim_idx) + self.idx_trace_list[idx]["source"].pop(dim_idx) + def _add_dim(self, idx, dim_idx): - self.idx_trace_list[idx]['idx'].insert(dim_idx, self._add_index()) - self.idx_trace_list[idx]['compute'].insert(dim_idx, []) - self.idx_trace_list[idx]['source'].insert(dim_idx, {}) - + self.idx_trace_list[idx]["idx"].insert(dim_idx, self._add_index()) + self.idx_trace_list[idx]["compute"].insert(dim_idx, []) + self.idx_trace_list[idx]["source"].insert(dim_idx, {}) + def _transform_index(self, node, node_dim): node_idx = self._find_idx_trace_from_node(node) dims = list(range(len(node_idx))) return dims[node_dim] - + def _inherit_index(self, node_from, node_from_dim, node_to, node_to_dim): node_from_dim = self._transform_index(node_from, node_from_dim) node_to_dim = self._transform_index(node_to, node_to_dim) node_from_trace = self._find_trace_from_node(node_from) node_to_trace = self._find_trace_from_node(node_to) - node_to_trace['idx'][node_to_dim] = node_from_trace['idx'][node_from_dim] - node_to_trace['compute'][node_to_dim] = copy.deepcopy(node_from_trace['compute'][node_from_dim]) + node_to_trace["idx"][node_to_dim] = node_from_trace["idx"][node_from_dim] + node_to_trace["compute"][node_to_dim] = copy.deepcopy( + node_from_trace["compute"][node_from_dim] + ) 
self._add_source(node_from, node_from_dim, node_to, node_to_dim, init=True) - + def _inherit_all_computation(self, node_from, node_to): node_from_compute = self._find_compute_trace_from_node(node_from) node_to_compute = self._find_compute_trace_from_node(node_to) @@ -251,7 +289,7 @@ def _inherit_all_computation(self, node_from, node_to): for i in range(len(node_from_compute)): self._add_source(node_from, i, node_to, i) node_to_compute[i] = copy.deepcopy(node_from_compute[i]) - + def _add_source(self, node_from, node_from_dim, node_to, node_to_dim, init=False): node_from_dim = self._transform_index(node_from, node_from_dim) node_from_trace = self._find_trace_from_node(node_from) @@ -259,10 +297,12 @@ def _add_source(self, node_from, node_from_dim, node_to, node_to_dim, init=False node_to_trace = self._find_trace_from_node(node_to) node_from_idx = _find_idx_by_name(node_from.name, self.nodes_list) if init: - node_to_trace['source'][node_to_dim] = {} - node_to_trace['source'][node_to_dim][node_from_idx] = node_from_dim - node_to_trace['source'][node_to_dim].update(node_from_trace['source'][node_from_dim]) - + node_to_trace["source"][node_to_dim] = {} + node_to_trace["source"][node_to_dim][node_from_idx] = node_from_dim + node_to_trace["source"][node_to_dim].update( + node_from_trace["source"][node_from_dim] + ) + def _mark_computation_from_node(self, node_from, node_to, exclude=None): if exclude == None: exclude = [] @@ -278,7 +318,7 @@ def _mark_computation_from_node(self, node_from, node_to, exclude=None): for j in node_from_compute[i]: if j not in node_to_compute[i]: node_to_compute[i].append(j) - + def _mark_idx_equal(self, node1, dim1, node2, dim2): """ Mark 2 index to be equal. @@ -293,7 +333,7 @@ def _mark_idx_equal(self, node1, dim1, node2, dim2): # self._add_source(node2, dim2, node1, dim1) # else: # self._add_source(node1, dim1, node2, dim2) - + def _mark_computation(self, node, idx, dim): """ Mark some dims of node as computed. 
@@ -302,14 +342,14 @@ def _mark_computation(self, node, idx, dim): node (node) idx (int): node index dim (list or int): dims to be marked as computed - """ + """ if isinstance(dim, int): dim = [dim] dims = list(range(len(_get_node_shape(node)))) for d in dim: cur_dim = dims[d] - if idx not in self.idx_trace_list[idx]['compute'][cur_dim]: - self.idx_trace_list[idx]['compute'][cur_dim].append(idx) + if idx not in self.idx_trace_list[idx]["compute"][cur_dim]: + self.idx_trace_list[idx]["compute"][cur_dim].append(idx) def _find_trace_from_node(self, node): """ @@ -320,11 +360,11 @@ def _find_trace_from_node(self, node): Returns: idx (list): idx of the node compute (list): computed idx of the node. - """ + """ node_idx = _find_idx_by_name(node.name, self.nodes_list) node_dict = self.idx_trace_list[node_idx] return node_dict - + def _find_idx_trace_from_node(self, node): """ Find node idx trace by the node. @@ -333,10 +373,10 @@ def _find_idx_trace_from_node(self, node): node (node) Returns: idx (list): idx of the node - """ + """ node_idx = _find_idx_by_name(node.name, self.nodes_list) - return self.idx_trace_list[node_idx]['idx'] - + return self.idx_trace_list[node_idx]["idx"] + def _find_compute_trace_from_node(self, node): """ Find node compute trace by the node. @@ -345,10 +385,10 @@ def _find_compute_trace_from_node(self, node): node (node) Returns: compute (list): computed idx of the node. - """ + """ node_idx = _find_idx_by_name(node.name, self.nodes_list) - return self.idx_trace_list[node_idx]['compute'] - + return self.idx_trace_list[node_idx]["compute"] + def _assign_index_as_input(self, node, node_idx, input_node=None): """ Assign node's trace as its input node. 
@@ -360,13 +400,13 @@ def _assign_index_as_input(self, node, node_idx, input_node=None): if input_node == None: input_node = node.args[0] input_node_idx = _find_idx_by_name(input_node.name, self.nodes_list) - input_node_idx_trace = self.idx_trace_list[input_node_idx]['idx'] - + input_node_idx_trace = self.idx_trace_list[input_node_idx]["idx"] + new_idx_trace = copy.deepcopy(input_node_idx_trace) - self.idx_trace_list[node_idx]['idx'] = new_idx_trace - + self.idx_trace_list[node_idx]["idx"] = new_idx_trace + self._inherit_all_computation(input_node, node) - + def _assign_all_index(self, node, node_idx): """ Add new index for all node's dims. @@ -374,12 +414,12 @@ def _assign_all_index(self, node, node_idx): Args: node (node) node_idx (int) - """ - shape = node.meta['tensor_meta'].shape + """ + shape = node.meta["tensor_meta"].shape new_trace = [] for _ in shape: new_trace.append(self._add_index()) - self.idx_trace_list[node_idx]['idx'] = new_trace + self.idx_trace_list[node_idx]["idx"] = new_trace def _assign_transpose_index(self, node, node_idx): """ @@ -390,14 +430,14 @@ def _assign_transpose_index(self, node, node_idx): Args: node (node) node_idx (int) - """ + """ input_node = node.args[0] tranpose_dim = node.args[1:] - + self._assign_index_as_input(node, node_idx, input_node) self._inherit_index(input_node, tranpose_dim[1], node, tranpose_dim[0]) self._inherit_index(input_node, tranpose_dim[0], node, tranpose_dim[1]) - + def _assign_permute_index(self, node, node_idx): """ Assign index for permute op. @@ -407,14 +447,14 @@ def _assign_permute_index(self, node, node_idx): Args: node (node) node_idx (int) - """ + """ permute_dim = node.args[1:] input_node = node.args[0] - + self._assign_index_as_input(node, node_idx, input_node) for idx, d in enumerate(permute_dim): self._inherit_index(input_node, d, node, idx) - + def _assign_linear_index(self, node, node_idx): """ Assign index for linear op. 
@@ -431,13 +471,13 @@ def _assign_linear_index(self, node, node_idx): bias = None else: input_node, weight, bias = node.args - + self._assign_index_as_input(node, node_idx) self._inherit_index(weight, 1, node, -1) self._mark_computation(node, node_idx, [-1]) self._mark_idx_equal(input_node, -1, weight, 0) - + if bias: self._mark_idx_equal(input_node, -1, bias, 0) @@ -451,10 +491,10 @@ def _assign_matmul_index(self, node, node_idx): Args: node (node) node_idx (int) - """ + """ matmul_left, matmul_right = node.args - - assert(len(_get_node_shape(matmul_left)) == len(_get_node_shape(matmul_right))) + + assert len(_get_node_shape(matmul_left)) == len(_get_node_shape(matmul_right)) self._assign_index_as_input(node, node_idx, matmul_left) self._inherit_index(matmul_right, -1, node, -1) @@ -474,7 +514,7 @@ def _assign_layernorm_index(self, node, idx): """ self._assign_index_as_input(node, idx) self._mark_computation(node, idx, [-1, -2]) - + def _assign_elementwise_index(self, node, idx): """ Assign index for element-wise op (eg. relu sigmoid add mul). @@ -484,7 +524,7 @@ def _assign_elementwise_index(self, node, idx): Args: node (node) node_idx (int) - """ + """ self._assign_index_as_input(node, idx) nodes_in = [] for node_in in node.args: @@ -498,13 +538,13 @@ def _assign_elementwise_index(self, node, idx): for i in range(-1, -min(len(node_in0_shape), len(node_in1_shape)) - 1, -1): if node_in0_shape[i] == node_in1_shape[i]: self._mark_idx_equal(nodes_in[0], i, nodes_in[1], i) - + def _assgin_no_change_index(self, node, idx): self._assign_index_as_input(node, idx) for node_in in node.args: if type(node_in) == type(node): self._mark_computation_from_node(node_in, node) - + def _assign_einsum_index(self, node, idx): """ Assign index for einsum op. 
@@ -515,11 +555,11 @@ def _assign_einsum_index(self, node, idx): """ patterns = node.args[0] input_nodes = node.args[1:] - + patterns = patterns.replace(" ", "") left, right = patterns.split("->") left = left.split(",") - + all_index = [] for i in left: for c in i: @@ -527,19 +567,21 @@ def _assign_einsum_index(self, node, idx): all_index = set(all_index) free_index = set([i for i in right]) sum_index = all_index - free_index - + for right_idx, right_indice in enumerate(right): for left_idx, left_str in enumerate(left): if right_indice in left_str: source_idx = left_str.index(right_indice) - self._inherit_index(input_nodes[left_idx], source_idx, node, right_idx) - + self._inherit_index( + input_nodes[left_idx], source_idx, node, right_idx + ) + # for i in sum_index: # for left_idx, left_str in enumerate(left): # if i in left_str: # self._mark_computation(node, idx, left_str.index(i)) # break - + def _assign_softmax_index(self, node, idx): """ Assign index for softmax op. @@ -549,10 +591,10 @@ def _assign_softmax_index(self, node, idx): Args: node (node) node_idx (int) - """ + """ self._assign_index_as_input(node, idx) - self._mark_computation(node, idx, [node.kwargs['dim']]) - + self._mark_computation(node, idx, [node.kwargs["dim"]]) + def _assign_unsqueeze_index(self, node, node_idx): """ Assign index for unsqueeze op. 
@@ -564,10 +606,10 @@ def _assign_unsqueeze_index(self, node, node_idx): """ self._del_dim(node_idx, -1) self._assign_index_as_input(node, node_idx) - self.idx_trace_list[node_idx]['idx'].insert(node.args[1], self._add_index()) - self.idx_trace_list[node_idx]['compute'].insert(node.args[1], []) - self.idx_trace_list[node_idx]['source'].insert(node.args[1], []) - + self.idx_trace_list[node_idx]["idx"].insert(node.args[1], self._add_index()) + self.idx_trace_list[node_idx]["compute"].insert(node.args[1], []) + self.idx_trace_list[node_idx]["source"].insert(node.args[1], []) + def _assign_dropout_index(self, node, node_idx): """ Assign index for unsqueeze op. @@ -576,9 +618,9 @@ def _assign_dropout_index(self, node, node_idx): Args: node (node) node_idx (int) - """ + """ self._assign_index_as_input(node, node_idx) - + def _assign_ones_like_index(self, node, node_idx): """ Assign index for oneslike op. @@ -587,7 +629,7 @@ def _assign_ones_like_index(self, node, node_idx): Args: node (node) node_idx (int) - """ + """ self._assign_all_index(node, node_idx) def _assign_view_reshape_index(self, node, node_idx): @@ -604,16 +646,16 @@ def _assign_view_reshape_index(self, node, node_idx): Args: node (node) node_idx (int) - """ + """ # get data, turn into number origin_node = node.args[0] - origin_shape = origin_node.meta['tensor_meta'].shape + origin_shape = origin_node.meta["tensor_meta"].shape target_shape = [] for i in range(1, len(node.args)): if isinstance(node.args[i], int): target_shape.append(node.args[i]) else: - target_shape.append(node.args[i].meta['fwd_out'][0]) + target_shape.append(node.args[i].meta["fwd_out"][0]) # compute the value of -1 if -1 in target_shape: @@ -641,7 +683,13 @@ def _assign_view_reshape_index(self, node, node_idx): dim_to = [dim_equal.index(False), dim_equal.index(False) + 1] self._del_dim(node_idx, -1) else: - raise NotImplementedError("shape" + str(origin_shape) + 'and' + str(target_shape) + "view not implemented") + raise 
NotImplementedError( + "shape" + + str(origin_shape) + + "and" + + str(target_shape) + + "view not implemented" + ) # get new index origin_trace = self._find_idx_trace_from_node(origin_node) @@ -651,7 +699,7 @@ def _assign_view_reshape_index(self, node, node_idx): self._del_dim(node_idx, i) for i in dim_to: self._add_dim(node_idx, i) - + # inherit computation compute_log = self._find_compute_trace_from_node(origin_node) for i in dim_from: @@ -659,13 +707,15 @@ def _assign_view_reshape_index(self, node, node_idx): for j in dim_to: self._mark_computation(node, node_idx, [j]) break - + # log view, not used now - view_dict = {"idx_from": [origin_trace[i] for i in dim_from], - "dim_from": dim_from, - "idx_to": [self.idx_trace_list[node_idx]['idx'][i] for i in dim_to], - "dim_to": dim_to} - self.idx_view_list.append(view_dict) + view_dict = { + "idx_from": [origin_trace[i] for i in dim_from], + "dim_from": dim_from, + "idx_to": [self.idx_trace_list[node_idx]["idx"][i] for i in dim_to], + "dim_to": dim_to, + } + self.idx_view_list.append(view_dict) def _merge_equal_idx(self): idx_equal = copy.deepcopy(self.idx_trace_equal) @@ -674,60 +724,64 @@ def _merge_equal_idx(self): merge_to = min(idx) merge_from = max(idx) for trace in self.idx_trace_list: - if merge_from in trace['idx']: - trace['idx'] = [merge_to if i == merge_from else i for i in trace['idx']] - + if merge_from in trace["idx"]: + trace["idx"] = [ + merge_to if i == merge_from else i for i in trace["idx"] + ] + def trace_index(self): for idx, node in enumerate(self.nodes_list): - if node.op == 'placeholder': + if node.op == "placeholder": self._assign_all_index(node, idx) - elif node.op == 'call_method': - if 'transpose' in node.name: + elif node.op == "call_method": + if "transpose" in node.name: self._assign_transpose_index(node, idx) - elif 'permute' in node.name: + elif "permute" in node.name: self._assign_permute_index(node, idx) - elif 'view' in node.name or 'reshape' in node.name: + elif "view" in 
node.name or "reshape" in node.name: self._assign_view_reshape_index(node, idx) - elif 'unsqueeze' in node.name: + elif "unsqueeze" in node.name: self._assign_unsqueeze_index(node, idx) - elif any(i in node.name for i in ['to', 'contiguous']): + elif any(i in node.name for i in ["to", "contiguous"]): self._assgin_no_change_index(node, idx) else: raise NotImplementedError(node.name, "method not implemented yet!") - elif node.op == 'call_function': - if 'linear' in node.name: + elif node.op == "call_function": + if "linear" in node.name: self._assign_linear_index(node, idx) - elif 'matmul' in node.name: + elif "matmul" in node.name: self._assign_matmul_index(node, idx) - elif 'softmax' in node.name: + elif "softmax" in node.name: self._assign_softmax_index(node, idx) - elif any(n in node.name for n in ['mul', 'add', 'sigmoid', 'relu']): + elif any(n in node.name for n in ["mul", "add", "sigmoid", "relu"]): self._assign_elementwise_index(node, idx) - elif 'ones_like' in node.name: + elif "ones_like" in node.name: self._assign_ones_like_index(node, idx) - elif 'dropout' in node.name: + elif "dropout" in node.name: self._assign_dropout_index(node, idx) - elif 'einsum' in node.name: + elif "einsum" in node.name: self._assign_einsum_index(node, idx) - elif 'getattr' in node.name: - continue # get attr like shape - elif 'getitem' in node.name: - continue # get item in list + elif "getattr" in node.name: + continue # get attr like shape + elif "getitem" in node.name: + continue # get item in list else: - raise NotImplementedError(node.name, "function not implemented yet!") - elif node.op == 'call_module': - if any(n in node.name for n in ['layernorm', 'norm']): + raise NotImplementedError( + node.name, "function not implemented yet!" 
+ ) + elif node.op == "call_module": + if any(n in node.name for n in ["layernorm", "norm"]): self._assign_layernorm_index(node, idx) else: raise NotImplementedError(node.name, "module not implemented yet!") - elif node.op == 'get_attr': - self._assign_all_index(node, idx) # get param - elif node.op == 'output': + elif node.op == "get_attr": + self._assign_all_index(node, idx) # get param + elif node.op == "output": continue else: raise NotImplementedError(node.op, "op not implemented yet!") # self._merge_equal_idx() - + def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node): """ Check 2 given index: one index should be source of the other @@ -742,8 +796,10 @@ def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node """ start_node_idx = _find_idx_by_name(start_node.name, self.nodes_list) end_node_trace = self._find_trace_from_node(end_node) - end_node_trace_source = end_node_trace['source'][end_dim] - sorted_source = sorted(end_node_trace_source.items(), key=lambda d:d[0], reverse=True) + end_node_trace_source = end_node_trace["source"][end_dim] + sorted_source = sorted( + end_node_trace_source.items(), key=lambda d: d[0], reverse=True + ) for node_idx, node_dim in sorted_source: if node_idx == start_node_idx and node_dim == start_dim: return True @@ -765,7 +821,7 @@ def check_index_compute(self, start_idx, end_dim, end_node, end_idx): bool: True if check pass """ end_node_trace = self._find_trace_from_node(end_node) - end_node_compute = end_node_trace['compute'][end_dim] + end_node_compute = end_node_trace["compute"][end_dim] if any(start_idx <= i <= end_idx for i in end_node_compute): return False return True @@ -784,19 +840,23 @@ def __init__(self) -> None: pass def _get_meta_node_size(self, x): - x = x.meta['tensor_meta'] + x = x.meta["tensor_meta"] x = x.numel * torch.tensor([], dtype=x.dtype).element_size() return x def _get_output_node(self, n): - fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if 
isinstance(x, torch.Tensor) and hasattr(x, 'uuid')} + fwd_out = { + x.uuid: x + for x in n.meta["fwd_out"] + if isinstance(x, torch.Tensor) and hasattr(x, "uuid") + } out_size = activation_size(fwd_out) out_node = [n.name] if out_size > 0 else [] return out_size, out_node - + def _get_output_node_size(self, n): return self._get_output_node(n)[0] - + def _add_active_node(self, n, active_list): new_active = self._get_output_node(n)[1] for i in new_active: @@ -806,7 +866,7 @@ def _add_active_node(self, n, active_list): def _get_delete_node(self, user, user_to_last_uses): delete_size = 0 delete_node = [] - if user.op not in ('placeholder', 'output'): + if user.op not in ("placeholder", "output"): nodes_to_delete = user_to_last_uses.get(user, []) if len(nodes_to_delete): out_node = [self._get_output_node(i) for i in nodes_to_delete] @@ -814,13 +874,13 @@ def _get_delete_node(self, user, user_to_last_uses): for i in range(len(out_node)): if out_node[i][0] > 0: delete_node.append(out_node[i][1][0]) - elif nodes_to_delete[i].op == 'placeholder': + elif nodes_to_delete[i].op == "placeholder": delete_node.append(nodes_to_delete[i].name) return delete_size, delete_node - + def _get_delete_node_size(self, user, user_to_last_uses): return self._get_delete_node(user, user_to_last_uses)[0] - + def _remove_deactive_node(self, user, user_to_last_uses, active_list): delete_node = self._get_delete_node(user, user_to_last_uses)[1] for i in delete_node: @@ -842,20 +902,24 @@ def register_last_uses(n: Node, user: Node): def _get_contiguous_memory(self, node, not_contiguous_list, delete=False): mem = 0 - not_contiguous_ops = ['transpose', 'permute'] + not_contiguous_ops = ["transpose", "permute"] - if node.op == 'call_function' and any(n in node.name for n in ['matmul', 'reshape']): + if node.op == "call_function" and any( + n in node.name for n in ["matmul", "reshape"] + ): for n in node.args: if n in not_contiguous_list: # matmul won't change origin tensor, but create a tmp copy mem += 
self._get_output_node_size(n) - elif node.op == 'call_module': + elif node.op == "call_module": for n in node.args: if n in not_contiguous_list: # module will just make origin tensor to contiguous if delete: not_contiguous_list.remove(n) - elif node.op == 'call_method' and any(i in node.name for i in not_contiguous_ops): + elif node.op == "call_method" and any( + i in node.name for i in not_contiguous_ops + ): if node not in not_contiguous_list: not_contiguous_list.append(node) elif any(i in node.args for i in not_contiguous_list): @@ -865,13 +929,14 @@ def _get_contiguous_memory(self, node, not_contiguous_list, delete=False): return mem def _get_chunk_ratio(self, node, chunk_dim, chunk_size): - shape = node.meta['tensor_meta'].shape + shape = node.meta["tensor_meta"].shape chunk_ratio = float(chunk_size) / shape[chunk_dim] return chunk_ratio - - def _get_chunk_delete_node_size(self, user, user_to_last_uses, chunk_ratio, node_list, start_node, end_node): - if user.op in ('placeholder', 'output'): + def _get_chunk_delete_node_size( + self, user, user_to_last_uses, chunk_ratio, node_list, start_node, end_node + ): + if user.op in ("placeholder", "output"): return 0 nodes_to_delete = user_to_last_uses.get(user, []) delete_size = 0 @@ -881,12 +946,11 @@ def _get_chunk_delete_node_size(self, user, user_to_last_uses, chunk_ratio, node delete_size += self._get_output_node_size(n) * chunk_ratio return delete_size - def _print_mem_log(self, log, nodes, title=None): if title: print(title) for idx, (l, n) in enumerate(zip(log, nodes)): - print("%s:%.2f \t" % (n.name, l), end='') + print("%s:%.2f \t" % (n.name, l), end="") if (idx + 1) % 3 == 0: print("") print("\n") @@ -895,16 +959,23 @@ def _print_compute_op_mem_log(self, log, nodes, title=None): if title: print(title) for idx, (l, n) in enumerate(zip(log, nodes)): - if n.op in ['placeholder', 'get_attr', 'output']: + if n.op in ["placeholder", "get_attr", "output"]: continue - if any(i in n.name for i in ['getitem', 
'getattr']): + if any(i in n.name for i in ["getitem", "getattr"]): continue - print("%s:%.2f \t" % (n.name, l), end='') + print("%s:%.2f \t" % (n.name, l), end="") if (idx + 1) % 3 == 0: print("") print("\n") - - def estimate_chunk_inference_mem(self, gm: torch.fx.GraphModule, start_nodes=None, end_nodes=None, chunk_dims=None, chunk_sizes=None): + + def estimate_chunk_inference_mem( + self, + gm: torch.fx.GraphModule, + start_nodes=None, + end_nodes=None, + chunk_dims=None, + chunk_sizes=None, + ): act_memory = 0.0 act_memory_peak_log = [] act_memory_after_node_log = [] @@ -915,42 +986,65 @@ def estimate_chunk_inference_mem(self, gm: torch.fx.GraphModule, start_nodes=Non user_to_last_uses = self._get_last_usr(node_list) user_to_last_uses_no_free_var = self._get_last_usr(node_list) _delete_free_var_from_last_use(user_to_last_uses_no_free_var) - - use_chunk = all(i is not None for i in [start_nodes, end_nodes, chunk_dims, chunk_sizes]) + + use_chunk = all( + i is not None for i in [start_nodes, end_nodes, chunk_dims, chunk_sizes] + ) chunk_within = False chunk_region_idx = 0 - chunk_ratio = 1 # use it to estimate chunk mem + chunk_ratio = 1 # use it to estimate chunk mem for idx, node in enumerate(node_list): # if node in chunk start nodes, change chunk ratio and add chunk_tensor if use_chunk and idx in start_nodes: chunk_within = True - chunk_ratio = self._get_chunk_ratio(node, chunk_dims[chunk_region_idx], chunk_sizes[chunk_region_idx]) - act_memory += self._get_output_node_size(node_list[end_nodes[chunk_region_idx]]) / (1024 ** 2) - + chunk_ratio = self._get_chunk_ratio( + node, chunk_dims[chunk_region_idx], chunk_sizes[chunk_region_idx] + ) + act_memory += self._get_output_node_size( + node_list[end_nodes[chunk_region_idx]] + ) / (1024**2) + # if node is placeholder, just add the size of the node - if node.op == 'placeholder': - act_memory += self._get_meta_node_size(node) * chunk_ratio / (1024 ** 2) + if node.op == "placeholder": + act_memory += 
self._get_meta_node_size(node) * chunk_ratio / (1024**2) act_memory_peak_log.append(act_memory) active_node_list.append(node.name) # skip output - elif node.op == 'output': + elif node.op == "output": continue # node is an operation, calculate tmp, output node and delete node memory else: # forward memory - act_memory += self._get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024 ** 2) - act_memory += self._get_output_node_size(node) * chunk_ratio / (1024 ** 2) + act_memory += ( + self._get_contiguous_memory(node, not_contiguous_list) + * chunk_ratio + / (1024**2) + ) + act_memory += ( + self._get_output_node_size(node) * chunk_ratio / (1024**2) + ) # record max act memory act_memory_peak_log.append(act_memory) # delete useless memory - act_memory -= self._get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio / (1024 ** 2) + act_memory -= ( + self._get_contiguous_memory(node, not_contiguous_list, delete=True) + * chunk_ratio + / (1024**2) + ) if chunk_within: act_memory -= self._get_chunk_delete_node_size( - node, user_to_last_uses_no_free_var, chunk_ratio, node_list, - start_nodes[chunk_region_idx], end_nodes[chunk_region_idx]) / (1024 ** 2) + node, + user_to_last_uses_no_free_var, + chunk_ratio, + node_list, + start_nodes[chunk_region_idx], + end_nodes[chunk_region_idx], + ) / (1024**2) else: - act_memory -= self._get_delete_node_size(node, user_to_last_uses_no_free_var) / (1024 ** 2) + act_memory -= self._get_delete_node_size( + node, user_to_last_uses_no_free_var + ) / (1024**2) # log active node self._add_active_node(node, active_node_list) @@ -958,11 +1052,13 @@ def estimate_chunk_inference_mem(self, gm: torch.fx.GraphModule, start_nodes=Non # if node in chunk end nodes, restore chunk settings if use_chunk and idx in end_nodes: - act_memory -= self._get_output_node_size(node) * chunk_ratio / (1024 ** 2) + act_memory -= ( + self._get_output_node_size(node) * chunk_ratio / (1024**2) + ) chunk_within = False chunk_ratio = 1 
chunk_region_idx += 1 - + act_memory_after_node_log.append(act_memory) active_node_list_log.append(copy.deepcopy(active_node_list)) @@ -991,14 +1087,14 @@ def _find_peak_node(self, mem_peak): max_value = max(mem_peak) max_idx = mem_peak.index(max_value) return max_idx - + def _get_free_var(self): free_var_idx = [] for idx, n in enumerate(self.node_list): - if n.op == 'placeholder': + if n.op == "placeholder": free_var_idx.append(idx) return free_var_idx - + def _get_min_free_var(self, active_node_list, free_vars): min_len = 999 for idx, n in enumerate(active_node_list): @@ -1007,11 +1103,11 @@ def _get_min_free_var(self, active_node_list, free_vars): if len(n) < min_len: min_len = len(n) return min_len - + def _search_max_chunk_region(self, active_node, peak_node): free_vars = self._get_free_var() min_var = self._get_min_free_var(active_node, free_vars) - + # from peak_node to free_var chunk_region_start = None for i in range(peak_node, -1, -1): @@ -1029,17 +1125,19 @@ def _search_max_chunk_region(self, active_node, peak_node): if i in free_vars or i == 0: raise RuntimeError() return chunk_region_start, chunk_region_end - + def _is_not_compute(self, trace, chunk_range, dim_idx): - if trace['idx'][dim_idx] not in trace['compute']: + if trace["idx"][dim_idx] not in trace["compute"]: return True - if trace['idx'][dim_idx] in trace['compute'] and \ - all(i < chunk_range[0] or i > chunk_range[1] for i in trace['compute'][trace['idx'][dim_idx]]): + if trace["idx"][dim_idx] in trace["compute"] and all( + i < chunk_range[0] or i > chunk_range[1] + for i in trace["compute"][trace["idx"][dim_idx]] + ): return True return False - + def _check_duplicate_map(self, chunk_infos): - dim_map = [(i['inputs_dim'], i['outputs_dim']) for i in chunk_infos] + dim_map = [(i["inputs_dim"], i["outputs_dim"]) for i in chunk_infos] remove_list = [] for idx1, (input_dim1, output_dim1) in enumerate(dim_map): for idx2, (input_dim2, output_dim2) in enumerate(dim_map): @@ -1055,36 +1153,41 @@ def 
_check_duplicate_map(self, chunk_infos): if i in chunk_infos: chunk_infos.remove(i) return chunk_infos - + def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx): start_traces = input_trace[start_idx] end_trace = output_trace[end_idx] end_node = self.node_list[end_idx] chunk_infos = [] - for end_dim, end_trace_idx in enumerate(end_trace['idx']): + for end_dim, end_trace_idx in enumerate(end_trace["idx"]): if len(start_traces) > 1: # TODO implement multi input chunk continue for start_node, start_trace in start_traces.items(): - for start_dim, start_trace_idx in enumerate(start_trace['idx']): + for start_dim, start_trace_idx in enumerate(start_trace["idx"]): # must be same trace idx if start_trace_idx != end_trace_idx: continue # dim size cannot be 1 - if _get_node_shape(end_node)[end_dim] == 1 or \ - _get_node_shape(start_node)[start_dim] == 1: + if ( + _get_node_shape(end_node)[end_dim] == 1 + or _get_node_shape(start_node)[start_dim] == 1 + ): continue # check index source align if not self.index_tracer.check_index_source( - start_dim, start_node, start_idx, end_dim, end_node): + start_dim, start_node, start_idx, end_dim, end_node + ): continue # check index copmute if not self.index_tracer.check_index_compute( - start_idx, end_dim, end_node, end_idx): + start_idx, end_dim, end_node, end_idx + ): continue # detect flow meet flow_flag, chunk_info = self.flow_tracer._detect_flow( - start_idx, start_dim, end_idx, end_dim) + start_idx, start_dim, end_idx, end_dim + ) if flow_flag: continue chunk_infos.append(chunk_info) @@ -1098,59 +1201,78 @@ def _search_possible_chunk_regions(self, max_chunk_region, peak_node): for _, n in enumerate(self.node_list): cur_trace = {} for arg in n.args: - if type(arg) == type(n) and not _is_non_compute_node_except_placeholder(arg): + if type(arg) == type(n) and not _is_non_compute_node_except_placeholder( + arg + ): cur_trace[arg] = self.index_tracer._find_trace_from_node(arg) input_trace.append(cur_trace) for 
start_idx in range(max_chunk_region[0], peak_node + 1): for end_idx in range(peak_node, max_chunk_region[1] + 1): # skip non compute nodes - if _is_non_compute_node(self.node_list[start_idx]) or \ - _is_non_compute_node(self.node_list[end_idx]): + if _is_non_compute_node( + self.node_list[start_idx] + ) or _is_non_compute_node(self.node_list[end_idx]): continue - + # select free dim - chunk_info = self._find_free_dim(input_trace, output_trace, start_idx, end_idx) + chunk_info = self._find_free_dim( + input_trace, output_trace, start_idx, end_idx + ) if len(chunk_info) > 0: possible_chunk_region.extend(chunk_info) return possible_chunk_region - + def _search_best_chunk_region(self, possible_chunk_regions): max_region_range = 0 best_regions = None for i in possible_chunk_regions: - if i['region'][1] - i['region'][0] > max_region_range: + if i["region"][1] - i["region"][0] > max_region_range: best_regions = i - max_region_range = i['region'][1] - i['region'][0] + max_region_range = i["region"][1] - i["region"][0] return best_regions - + def _step_search(self, mem_peak, active_node): peak_node = self._find_peak_node(mem_peak) max_chunk_region = self._search_max_chunk_region(active_node, peak_node) - possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node) + possible_chunk_regions = self._search_possible_chunk_regions( + max_chunk_region, peak_node + ) best_chunk_region = self._search_best_chunk_region(possible_chunk_regions) return best_chunk_region - + def _stop_search(self, init_mem_peak, mem_peak): sorted_init_mem_peak = sorted(init_mem_peak) if max(mem_peak) < sorted_init_mem_peak[int(len(sorted_init_mem_peak) * 0.5)]: return True return False - + def search_region(self): chunk_regions = [] - init_mem_peak, _, active_node = self.memory_estimator.estimate_chunk_inference_mem(self.gm) + ( + init_mem_peak, + _, + active_node, + ) = self.memory_estimator.estimate_chunk_inference_mem(self.gm) mem_peak = init_mem_peak - + while True: 
chunk_region = self._step_search(mem_peak, active_node) if chunk_region is None: break - + chunk_regions.append(chunk_region) - mem_peak, _, active_node = self.memory_estimator.estimate_chunk_inference_mem( - self.gm, [i['region'][0] for i in chunk_regions], - [i['region'][1] for i in chunk_regions], [i['inputs_dim'] for i in chunk_regions], [1] * len(chunk_regions)) + ( + mem_peak, + _, + active_node, + ) = self.memory_estimator.estimate_chunk_inference_mem( + self.gm, + [i["region"][0] for i in chunk_regions], + [i["region"][1] for i in chunk_regions], + [i["inputs_dim"] for i in chunk_regions], + [1] * len(chunk_regions), + ) if self._stop_search(init_mem_peak, mem_peak): break return chunk_regions @@ -1180,18 +1302,24 @@ def _get_first_non_single_dim(shape): def _gen_loop_start(chunk_input_meta, chunk_output, chunk_dim, chunk_size=2): if len(chunk_input_meta) == 1: node = chunk_input_meta[0] - node_shape = node.meta['tensor_meta'].shape - free_shape = [node_shape[i] if i in chunk_dim else 1 for i in range(len(node_shape))] + node_shape = node.meta["tensor_meta"].shape + free_shape = [ + node_shape[i] if i in chunk_dim else 1 for i in range(len(node_shape)) + ] chunk_dim = _get_first_non_single_dim(free_shape) chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", node_shape) - out_shape = str(list(chunk_output.meta['tensor_meta'].shape)) - - context = "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor gen_chunk_idx in range" % ( - out_shape, node.name, node.name, chunk_size) + out_shape = str(list(chunk_output.meta["tensor_meta"].shape)) + + context = ( + "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor gen_chunk_idx in range" + % (out_shape, node.name, node.name, chunk_size) + ) context += "(0, %s.shape[%d], chunk_size):\n" % (node.name, chunk_dim) context += " chunk_tensor = %s%s\n" % (node.name, chunk_slice) else: - raise NotImplementedError("input with size %d not 
implemented" % len(chunk_input_meta)) + raise NotImplementedError( + "input with size %d not implemented" % len(chunk_input_meta) + ) return context @@ -1199,17 +1327,27 @@ def _gen_loop_end(chunk_outputs, chunk_inputs, node_list, chunk_dim): chunk_inputs_name = chunk_inputs[0].name chunk_outputs_name = chunk_outputs.name chunk_outputs_idx = _find_idx_by_name(chunk_outputs_name, node_list) - chunk_output_shape = chunk_outputs.meta['tensor_meta'].shape - free_shape = [chunk_output_shape[i] if i in chunk_dim else 1 for i in range(len(chunk_output_shape))] + chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape + free_shape = [ + chunk_output_shape[i] if i in chunk_dim else 1 + for i in range(len(chunk_output_shape)) + ] chunk_dim = _get_first_non_single_dim(free_shape) chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", chunk_output_shape) context = " chunk_result%s = %s\n" % (chunk_slice, chunk_outputs_name) - context += chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None" - + context += ( + chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None" + ) + # determine if its the last use for chunk input users_name = list(chunk_inputs[0].users.keys()) - if all([_find_idx_by_name(user.name, node_list) <= chunk_outputs_idx for user in users_name]): + if all( + [ + _find_idx_by_name(user.name, node_list) <= chunk_outputs_idx + for user in users_name + ] + ): context += "; %s = None" % chunk_inputs_name context += "\n" @@ -1255,8 +1393,11 @@ def _find_chunk_input_and_output_nodes(nodes: List[Node]): # we treat that input node as the input of the checkpoint function for node in nodes: for input_node in node._input_nodes.keys(): - if input_node not in nodes and input_node not in input_nodes \ - and not _is_non_compute_node_except_placeholder(input_node): + if ( + input_node not in nodes + and input_node not in input_nodes + and not _is_non_compute_node_except_placeholder(input_node) + ): 
input_nodes.append(input_node) # if a node has a user node which is not in the node list @@ -1264,8 +1405,11 @@ def _find_chunk_input_and_output_nodes(nodes: List[Node]): # TODO it is unsafe to remove non compute node here for node in nodes: for output_node in node.users.keys(): - if output_node not in nodes and node not in output_nodes \ - and not _is_non_compute_node_except_placeholder(input_node): + if ( + output_node not in nodes + and node not in output_nodes + and not _is_non_compute_node_except_placeholder(input_node) + ): output_nodes.append(node) return input_nodes, output_nodes @@ -1288,7 +1432,15 @@ def _replace_name(context, name_from, name_to): return context -def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func, meta_nodes, meta_graph): +def emit_code_with_chunk( + body, + ckpt_func, + nodes, + emit_node_func, + delete_unused_value_func, + meta_nodes, + meta_graph, +): """Emit code with nested activation checkpoint When we detect some of the node.activation_checkpoint is a List, we will use this function to emit the activation checkpoint codes. 
@@ -1304,14 +1456,14 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v # find the offload regions chunk_region_search = ChunkRegionSearch(meta_graph) chunk_search = chunk_region_search.search_region() - chunk_regions = [i['region'] for i in chunk_search] - chunk_dims = [i['dim'] for i in chunk_search] - chunk_infos = [i['chunk_info'] for i in chunk_search] - + chunk_regions = [i["region"] for i in chunk_search] + chunk_dims = [i["dim"] for i in chunk_search] + chunk_infos = [i["chunk_info"] for i in chunk_search] + chunk_starts = [item[0] for item in chunk_regions] chunk_ends = [item[1] for item in chunk_regions] - chunk_inputs = [[j['inputs'][0] for j in i] for i in chunk_infos] - chunk_outputs = [[j['outputs'][0] for j in i] for i in chunk_infos] + chunk_inputs = [[j["inputs"][0] for j in i] for i in chunk_infos] + chunk_outputs = [[j["outputs"][0] for j in i] for i in chunk_infos] within_chunk_region = False node_list = list(nodes) @@ -1322,14 +1474,18 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v # inputs, outputs = _find_input_and_output_nodes(offload_node_list) # chunk_inputs.append(inputs) # chunk_outputs.append(outputs) - - chunk_inputs_idx = [[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_inputs] - chunk_outputs_idx = [[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_outputs] + + chunk_inputs_idx = [ + [_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_inputs + ] + chunk_outputs_idx = [ + [_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_outputs + ] chunk_inputs_names = [] for i in chunk_inputs: for j in i: chunk_inputs_names.append(j.name) - + # this flag is to prevent repeated insert of save tensors # hooks definition in ckpt_func node_idx = 0 @@ -1340,16 +1496,24 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v if node_idx in chunk_starts: within_chunk_region = True region_idx = 
chunk_starts.index(node_idx) - + # add for loop chunk_input_meta = [meta_nodes[i] for i in chunk_inputs_idx[region_idx]] - body.append(_gen_loop_start(chunk_input_meta, node_list[chunk_ends[region_idx]], chunk_dims[region_idx])) + body.append( + _gen_loop_start( + chunk_input_meta, + node_list[chunk_ends[region_idx]], + chunk_dims[region_idx], + ) + ) if within_chunk_region: emit_node_func(node, body) # replace input var with chunk var - body[-1] = _replace_name(body[-1], chunk_inputs[region_idx][0].name, 'chunk_tensor') - body[-1] = ' ' + body[-1] + body[-1] = _replace_name( + body[-1], chunk_inputs[region_idx][0].name, "chunk_tensor" + ) + body[-1] = " " + body[-1] delete_unused_value_func(node, body, chunk_inputs_names) else: @@ -1358,7 +1522,11 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v delete_unused_value_func(node, body, chunk_inputs_names) if node_idx in chunk_ends: - body.append(_gen_loop_end(node, chunk_inputs[region_idx], node_list, chunk_dims[region_idx])) + body.append( + _gen_loop_end( + node, chunk_inputs[region_idx], node_list, chunk_dims[region_idx] + ) + ) within_chunk_region = False node_idx += 1 @@ -1372,14 +1540,16 @@ def __init__(self, meta_graph): self.meta_graph = meta_graph self.meta_node = list(meta_graph.graph.nodes) - def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode: + def _gen_python_code( + self, nodes, root_module: str, namespace: _Namespace + ) -> PythonCode: free_vars: List[str] = [] body: List[str] = [] globals_: Dict[str, Any] = {} wrapped_fns: Dict[str, None] = {} # Wrap string in list to pass by reference - maybe_return_annotation: List[str] = [''] + maybe_return_annotation: List[str] = [""] def add_global(name_hint: str, obj: Any): """Add an obj to be tracked as a global. @@ -1389,7 +1559,9 @@ def add_global(name_hint: str, obj: Any): Returns: the global name that should be used to reference 'obj' in generated source. 
""" - if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device + if ( + _is_from_torch(obj) and obj != torch.device + ): # to support registering torch.device # HACK: workaround for how torch custom ops are registered. We # can't import them like normal modules so they must retain their # fully qualified name. @@ -1405,7 +1577,9 @@ def add_global(name_hint: str, obj: Any): return global_name # set _custom_builtins here so that we needn't import colossalai in forward - _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai) + _custom_builtins["colossalai"] = _CustomBuiltin( + "import colossalai", colossalai + ) # Pre-fill the globals table with registered builtins. for name, (_, obj) in _custom_builtins.items(): @@ -1414,16 +1588,16 @@ def add_global(name_hint: str, obj: Any): def type_repr(o: Any): if o == (): # Empty tuple is used for empty tuple type annotation Tuple[()] - return '()' + return "()" typename = _type_repr(o) - if hasattr(o, '__origin__'): + if hasattr(o, "__origin__"): # This is a generic type, e.g. typing.List[torch.Tensor] origin_type = _origin_type_map.get(o.__origin__, o.__origin__) origin_typename = add_global(_type_repr(origin_type), origin_type) - if hasattr(o, '__args__'): + if hasattr(o, "__args__"): # Assign global names for each of the inner type variables. args = [type_repr(arg) for arg in o.__args__] @@ -1441,20 +1615,21 @@ def type_repr(o: Any): # Common case: this is a regular module name like 'foo.bar.baz' return add_global(typename, o) - def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str: - + def _format_args( + args: Tuple[Argument, ...], kwargs: Dict[str, Argument] + ) -> str: def _get_repr(arg): # Handle NamedTuples (if it has `_fields`) via add_global. 
- if isinstance(arg, tuple) and hasattr(arg, '_fields'): + if isinstance(arg, tuple) and hasattr(arg, "_fields"): qualified_name = _get_qualified_name(type(arg)) global_name = add_global(qualified_name, type(arg)) return f"{global_name}{repr(tuple(arg))}" return repr(arg) - args_s = ', '.join(_get_repr(a) for a in args) - kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items()) + args_s = ", ".join(_get_repr(a) for a in args) + kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items()) if args_s and kwargs_s: - return f'{args_s}, {kwargs_s}' + return f"{args_s}, {kwargs_s}" return args_s or kwargs_s # Run through reverse nodes and record the first instance of a use @@ -1472,9 +1647,9 @@ def register_last_uses(n: Node, user: Node): for node in reversed(nodes): map_arg(node.args, lambda n: register_last_uses(n, node)) map_arg(node.kwargs, lambda n: register_last_uses(n, node)) - + _delete_free_var_from_last_use(user_to_last_uses) - + # NOTE: we add a variable to distinguish body and ckpt_func def delete_unused_values(user: Node, body, to_keep=[]): """ @@ -1482,103 +1657,140 @@ def delete_unused_values(user: Node, body, to_keep=[]): not used in the remainder of the code are freed and the memory usage of the code is optimal. 
""" - if user.op == 'placeholder': + if user.op == "placeholder": return - if user.op == 'output': - body.append('\n') + if user.op == "output": + body.append("\n") return nodes_to_delete = user_to_last_uses.get(user, []) nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep] if len(nodes_to_delete): - to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None']) - body.append(f'; {to_delete_str}\n') + to_delete_str = " = ".join( + [repr(n) for n in nodes_to_delete] + ["None"] + ) + body.append(f"; {to_delete_str}\n") else: - body.append('\n') + body.append("\n") # NOTE: we add a variable to distinguish body and ckpt_func def emit_node(node: Node, body): - maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' - if node.op == 'placeholder': + maybe_type_annotation = ( + "" if node.type is None else f" : {type_repr(node.type)}" + ) + if node.op == "placeholder": assert isinstance(node.target, str) - maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}' - free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}') - raw_name = node.target.replace('*', '') + maybe_default_arg = ( + "" if not node.args else f" = {repr(node.args[0])}" + ) + free_vars.append( + f"{node.target}{maybe_type_annotation}{maybe_default_arg}" + ) + raw_name = node.target.replace("*", "") if raw_name != repr(node): - body.append(f'{repr(node)} = {raw_name}\n') + body.append(f"{repr(node)} = {raw_name}\n") return - elif node.op == 'call_method': + elif node.op == "call_method": assert isinstance(node.target, str) body.append( - f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}' - f'({_format_args(node.args[1:], node.kwargs)})') + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}" + f"({_format_args(node.args[1:], node.kwargs)})" + ) return - elif node.op == 'call_function': + elif node.op == "call_function": assert 
callable(node.target) # pretty print operators - if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods: + if ( + node.target.__module__ == "_operator" + and node.target.__name__ in magic_methods + ): assert isinstance(node.args, tuple) - body.append(f'{repr(node)}{maybe_type_annotation} = ' - f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}') + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}" + ) return # pretty print inplace operators; required for jit.script to work properly # not currently supported in normal FX graphs, but generated by torchdynamo - if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods: - body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; ' - f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}') + if ( + node.target.__module__ == "_operator" + and node.target.__name__ in inplace_methods + ): + body.append( + f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; " + f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}" + ) return qualified_name = _get_qualified_name(node.target) global_name = add_global(qualified_name, node.target) # special case for getattr: node.args could be 2-argument or 3-argument # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value - if global_name == 'getattr' and \ - isinstance(node.args, tuple) and \ - isinstance(node.args[1], str) and \ - node.args[1].isidentifier() and \ - len(node.args) == 2: + if ( + global_name == "getattr" + and isinstance(node.args, tuple) + and isinstance(node.args[1], str) + and node.args[1].isidentifier() + and len(node.args) == 2 + ): body.append( - f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}') + f"{repr(node)}{maybe_type_annotation} = 
{_format_target(repr(node.args[0]), node.args[1])}" + ) return body.append( - f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})') - if node.meta.get('is_wrapped', False): + f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})" + ) + if node.meta.get("is_wrapped", False): wrapped_fns.setdefault(global_name) return - elif node.op == 'call_module': + elif node.op == "call_module": assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = ' - f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})') + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})" + ) return - elif node.op == 'get_attr': + elif node.op == "get_attr": assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}') + body.append( + f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}" + ) return - elif node.op == 'output': + elif node.op == "output": if node.type is not None: maybe_return_annotation[0] = f" -> {type_repr(node.type)}" body.append(self.generate_output(node.args[0])) return - raise NotImplementedError(f'node: {node.op} {node.target}') + raise NotImplementedError(f"node: {node.op} {node.target}") # Modified for activation checkpointing ckpt_func = [] # if any node has a list of labels for activation_checkpoint, we # will use nested type of activation checkpoint codegen - emit_code_with_chunk(body, ckpt_func, nodes, emit_node, delete_unused_values, self.meta_node, self.meta_graph) + emit_code_with_chunk( + body, + ckpt_func, + nodes, + emit_node, + delete_unused_values, + self.meta_node, + self.meta_graph, + ) if len(body) == 0: # If the Graph has no non-placeholder nodes, no lines for the body # have been emitted. 
To continue to have valid Python code, emit a # single pass statement - body.append('pass\n') + body.append("pass\n") if len(wrapped_fns) > 0: - wrap_name = add_global('wrap', torch.fx.wrap) - wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns]) + wrap_name = add_global("wrap", torch.fx.wrap) + wrap_stmts = "\n".join( + [f'{wrap_name}("{name}")' for name in wrapped_fns] + ) else: - wrap_stmts = '' + wrap_stmts = "" if self._body_transformer: body = self._body_transformer(body) @@ -1589,15 +1801,15 @@ def emit_node(node: Node, body): # as we need colossalai.utils.checkpoint, we need to import colossalai # in forward function prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) - prologue = ''.join(ckpt_func) + prologue + prologue = "".join(ckpt_func) + prologue prologue = prologue - code = ''.join(body) - code = '\n'.join(' ' + line for line in code.split('\n')) + code = "".join(body) + code = "\n".join(" " + line for line in code.split("\n")) fn_code = f""" {wrap_stmts} {prologue} -{code}""" +{code}""" print(fn_code) return PythonCode(fn_code, globals_) From 5de9e46381f35a40ffff3675c2170a987b6fd9b9 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Sat, 10 Dec 2022 17:34:48 +0800 Subject: [PATCH 030/209] code format --- chunk_codegen.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index e8cf0d22f157..9147aa9fcc20 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -49,7 +49,9 @@ def _is_non_compute_node(node): def _is_non_compute_node_except_placeholder(node): - if (any(i in node.op for i in ["get_attr", "output"]) or any(i in node.name for i in ["getitem", "getattr"])): + if any(i in node.op for i in ["get_attr", "output"]) or any( + i in node.name for i in ["getitem", "getattr"] + ): return True return False From 31a2c5d09fb5496c90f740b3e7cac787ef489e91 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 12 Dec 2022 17:24:06 +0800 Subject: [PATCH 031/209] work with 
outerproductmean and msa --- chunk_codegen.py | 258 ++++++++++++++++++++++++++++++----------------- 1 file changed, 168 insertions(+), 90 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 9147aa9fcc20..191eab564853 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -134,7 +134,7 @@ def find_node_flow(self, node): return name, i raise RuntimeError("invalid node") - def get_flow_mix(self, node): + def _get_flow_mix_node(self, node): if self._is_non_compute_node(node): return None _, node_trace = self.find_node_flow(node) @@ -145,7 +145,7 @@ def get_flow_mix(self, node): vars = list(node_trace["outside_depend"][0].values())[0] return vars - def get_same_flow_node(self, node_list, node): + def _get_same_flow_node(self, node_list, node): name, _ = self.find_node_flow(node) result = [] for i in self.flow_trace[name]: @@ -181,13 +181,14 @@ def trace_flow(self): ) return self.flow_trace - def _detect_flow(self, start_idx, start_dim, end_idx, end_dim): - inputs, outputs = _find_chunk_input_and_output_nodes( + def _detect_flow(self, start_idx, start_dim, end_idx, end_dim, index_tracer): + inputs, outputs = _find_chunk_compute_input_and_output_nodes( self.node_list[start_idx : end_idx + 1] ) chunk_info = { "region": (start_idx, end_idx), "inputs": inputs, + "inputs_non_chunk": [], "inputs_dim": start_dim, "outputs": outputs, "outputs_dim": end_dim, @@ -197,31 +198,71 @@ def _detect_flow(self, start_idx, start_dim, end_idx, end_dim): for idx in range(start_idx, end_idx + 1): node = self.node_list[idx] - mix_flow_var = self.get_flow_mix(node) - if mix_flow_var is None: + mix_flow_node = self._get_flow_mix_node(node) + if mix_flow_node is None: continue - # if there is a flow mix, op must be in [mul, add, div, matmul] + # if there is a flow mix, op must be in [mul, add, matmul] # element-wise op requires dim to be equal in every dim if any(n in node.name for n in ["mul", "add"]): for i in node.args: - if type(i) == type(mix_flow_var) and i != mix_flow_var: 
+ if type(i) == type(mix_flow_node) and i != mix_flow_node: main_flow_var = i # if mix flow is a broadcast in chunk dim, # TODO need to move that flow out of the chunk - if mix_flow_var.meta["tensor_meta"].shape[dim_idx] == 1: + mix_flow_node_dim = index_tracer._get_node_chunk_dim( + self.node_list[end_idx], end_dim, node + ) + if mix_flow_node_dim is None: flow_flag = True - for i in self.get_same_flow_node( - chunk_info["inputs"], mix_flow_var + break + if _get_node_shape(mix_flow_node)[mix_flow_node_dim] == 1: + flow_flag = False + for i in self._get_same_flow_node( + chunk_info["inputs"], mix_flow_node ): chunk_info["inputs"].remove(i) # else, we need to chunk mix var as well else: # TODO chunk another value - flow_flag = False + flow_flag = True break else: raise NotImplementedError("%s not implemented" % node.name) + + inputs_dim = [] + remove_inputs = [] + for input_node in chunk_info['inputs']: + input_dict = {} + for user in input_node.users.keys(): + if _is_non_compute_node(user): + continue + user_idx = _find_idx_by_name(user.name, self.node_list) + dim = None + if start_dim <= user_idx < end_idx: + dim = index_tracer._get_node_chunk_dim( + self.node_list[end_idx], end_dim, input_node + ) + elif user_idx == end_idx: + dim = end_dim + # n has relation with chunk dim + if dim is not None and _get_node_shape(user)[dim] != 1: + input_dict[user_idx] = dim + if len(input_dict) == 0: + remove_inputs.append(input_node) + else: + inputs_dim.append(input_dict) + chunk_info['inputs_dim'] = inputs_dim + for i in remove_inputs: + if i in chunk_info['inputs']: + chunk_info['inputs'].remove(i) + + # we need to log input nodes to avoid deleteing them in the loop + non_chunk_inputs = _find_chunk_all_input_nodes(self.node_list[start_idx : end_idx + 1]) + for i in non_chunk_inputs: + if i not in chunk_info['inputs']: + chunk_info["inputs_non_chunk"].append(i) + return flow_flag, chunk_info @@ -367,6 +408,20 @@ def _find_trace_from_node(self, node): node_dict = 
self.idx_trace_list[node_idx] return node_dict + def _find_source_trace_from_node(self, node): + """ + Find node source trace by the node. + + Args: + node (node) + Returns: + idx (list): idx of the node + compute (list): computed idx of the node. + """ + node_idx = _find_idx_by_name(node.name, self.nodes_list) + node_dict = self.idx_trace_list[node_idx] + return node_dict["source"] + def _find_idx_trace_from_node(self, node): """ Find node idx trace by the node. @@ -836,6 +891,15 @@ def check_index_compute(self, start_idx, end_dim, end_node, end_idx): # return False # return True + def _get_node_chunk_dim(self, node_from, node_from_dim, node_to): + node_from_source = self._find_source_trace_from_node(node_from) + dim_source = node_from_source[node_from_dim] + node_to_idx = _find_idx_by_name(node_to.name, self.nodes_list) + for k, v in dim_source.items(): + if k == node_to_idx: + return v + return None + class MemoryEstimator(object): def __init__(self) -> None: @@ -931,8 +995,10 @@ def _get_contiguous_memory(self, node, not_contiguous_list, delete=False): return mem def _get_chunk_ratio(self, node, chunk_dim, chunk_size): + sorted_dim = sorted(chunk_dim, key=lambda x: list(x.keys())[0]) + dim = list(sorted_dim[-1].values())[0] shape = node.meta["tensor_meta"].shape - chunk_ratio = float(chunk_size) / shape[chunk_dim] + chunk_ratio = float(chunk_size) / shape[dim] return chunk_ratio def _get_chunk_delete_node_size( @@ -1157,6 +1223,8 @@ def _check_duplicate_map(self, chunk_infos): return chunk_infos def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx): + if start_idx == 71 and end_idx == 126: + print(1) start_traces = input_trace[start_idx] end_trace = output_trace[end_idx] end_node = self.node_list[end_idx] @@ -1188,7 +1256,7 @@ def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx): continue # detect flow meet flow_flag, chunk_info = self.flow_tracer._detect_flow( - start_idx, start_dim, end_idx, end_dim + start_idx, start_dim, 
end_idx, end_dim, self.index_tracer ) if flow_flag: continue @@ -1301,56 +1369,53 @@ def _get_first_non_single_dim(shape): raise RuntimeError("can not get first non single dim for shape", shape) -def _gen_loop_start(chunk_input_meta, chunk_output, chunk_dim, chunk_size=2): - if len(chunk_input_meta) == 1: - node = chunk_input_meta[0] - node_shape = node.meta["tensor_meta"].shape - free_shape = [ - node_shape[i] if i in chunk_dim else 1 for i in range(len(node_shape)) - ] - chunk_dim = _get_first_non_single_dim(free_shape) - chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", node_shape) - out_shape = str(list(chunk_output.meta["tensor_meta"].shape)) - - context = ( - "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor gen_chunk_idx in range" - % (out_shape, node.name, node.name, chunk_size) - ) - context += "(0, %s.shape[%d], chunk_size):\n" % (node.name, chunk_dim) - context += " chunk_tensor = %s%s\n" % (node.name, chunk_slice) - else: - raise NotImplementedError( - "input with size %d not implemented" % len(chunk_input_meta) - ) +def _gen_loop_start(chunk_input, chunk_output, chunk_ouput_dim, chunk_size=2): + input_node = chunk_input[0] + + out_shape = _get_node_shape(chunk_output) + out_str = str(list(out_shape)) + + context = ( + "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range" + % (out_str, input_node.name, input_node.name, chunk_size) + ) + context += "(0, %d, chunk_size):\n" % (out_shape[chunk_ouput_dim]) + + # node = chunk_input[0] + # node_shape = node.meta["tensor_meta"].shape + # free_shape = [ + # node_shape[i] if i in chunk_dim else 1 for i in range(len(node_shape)) + # ] + # chunk_dim = _get_first_non_single_dim(free_shape) + # chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", node_shape) + # out_shape = str(list(chunk_output.meta["tensor_meta"].shape)) + + # context = ( + # "chunk_result = torch.empty(%s, dtype=%s.dtype, 
device=%s.device); chunk_size = %d\nfor gen_chunk_idx in range" + # % (out_shape, node.name, node.name, chunk_size) + # ) + # context += "(0, %s.shape[%d], chunk_size):\n" % (node.name, chunk_dim) + # context += " chunk_tensor = %s%s\n" % (node.name, chunk_slice) return context -def _gen_loop_end(chunk_outputs, chunk_inputs, node_list, chunk_dim): - chunk_inputs_name = chunk_inputs[0].name +def _gen_loop_end(chunk_inputs, chunk_non_compute_inputs, chunk_outputs, chunk_outputs_dim, node_list): chunk_outputs_name = chunk_outputs.name chunk_outputs_idx = _find_idx_by_name(chunk_outputs_name, node_list) chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape - free_shape = [ - chunk_output_shape[i] if i in chunk_dim else 1 - for i in range(len(chunk_output_shape)) - ] - chunk_dim = _get_first_non_single_dim(free_shape) - chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", chunk_output_shape) + chunk_slice = _gen_chunk_slice_dim(chunk_outputs_dim, "chunk_idx", chunk_output_shape) context = " chunk_result%s = %s\n" % (chunk_slice, chunk_outputs_name) - - context += ( - chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None" - ) + context += (chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None") # determine if its the last use for chunk input - users_name = list(chunk_inputs[0].users.keys()) - if all( - [ - _find_idx_by_name(user.name, node_list) <= chunk_outputs_idx - for user in users_name - ] - ): - context += "; %s = None" % chunk_inputs_name + for chunk_input in (chunk_inputs + chunk_non_compute_inputs): + if all( + [ + _find_idx_by_name(user.name, node_list) <= chunk_outputs_idx + for user in chunk_input.users.keys() + ] + ): + context += "; %s = None" % chunk_input.name context += "\n" return context @@ -1382,7 +1447,24 @@ def _find_input_and_output_nodes(nodes: List[Node]): return input_nodes, output_nodes -def _find_chunk_input_and_output_nodes(nodes: List[Node]): +def 
_find_chunk_all_input_nodes(nodes: List[Node]): + """ + Find non-compute input and output node names. + input nodes are nodes used in the list + output nodes are nodes will use nodes in the list + """ + input_nodes = [] + for node in nodes: + for input_node in node._input_nodes.keys(): + if ( + input_node not in nodes + and input_node not in input_nodes + ): + input_nodes.append(input_node) + return input_nodes + + +def _find_chunk_compute_input_and_output_nodes(nodes: List[Node]): """ Find non-compute input and output node names. input nodes are nodes used in the list @@ -1410,7 +1492,7 @@ def _find_chunk_input_and_output_nodes(nodes: List[Node]): if ( output_node not in nodes and node not in output_nodes - and not _is_non_compute_node_except_placeholder(input_node) + and not _is_non_compute_node_except_placeholder(output_node) ): output_nodes.append(node) @@ -1454,44 +1536,34 @@ def emit_code_with_chunk( emit_node_func: function to emit node delete_unused_value_func: function to remove the unused value """ + node_list = list(nodes) - # find the offload regions + # find the chunk regions chunk_region_search = ChunkRegionSearch(meta_graph) chunk_search = chunk_region_search.search_region() - chunk_regions = [i["region"] for i in chunk_search] - chunk_dims = [i["dim"] for i in chunk_search] - chunk_infos = [i["chunk_info"] for i in chunk_search] - - chunk_starts = [item[0] for item in chunk_regions] - chunk_ends = [item[1] for item in chunk_regions] - chunk_inputs = [[j["inputs"][0] for j in i] for i in chunk_infos] - chunk_outputs = [[j["outputs"][0] for j in i] for i in chunk_infos] - within_chunk_region = False - - node_list = list(nodes) - # find the input and output var names for each offload region - # for idx, (start, end) in enumerate(chunk_regions): - # offload_node_list = node_list[start:end + 1] - # inputs, outputs = _find_input_and_output_nodes(offload_node_list) - # chunk_inputs.append(inputs) - # chunk_outputs.append(outputs) + chunk_regions = 
[i["region"] for i in chunk_search] + chunk_starts = [i[0] for i in chunk_regions] + chunk_ends = [i[1] for i in chunk_regions] + chunk_inputs = [i["inputs"] for i in chunk_search] + chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_search] + chunk_inputs_dim = [i["inputs_dim"] for i in chunk_search] chunk_inputs_idx = [ [_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_inputs ] + chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i] + + chunk_outputs = [i["outputs"][0] for i in chunk_search] + chunk_outputs_dim = [i["outputs_dim"] for i in chunk_search] chunk_outputs_idx = [ - [_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_outputs + _find_idx_by_name(i.name, node_list) for i in chunk_outputs ] - chunk_inputs_names = [] - for i in chunk_inputs: - for j in i: - chunk_inputs_names.append(j.name) - # this flag is to prevent repeated insert of save tensors - # hooks definition in ckpt_func node_idx = 0 region_idx = 0 + within_chunk_region = False + while node_idx < len(node_list): node = node_list[node_idx] @@ -1500,21 +1572,24 @@ def emit_code_with_chunk( region_idx = chunk_starts.index(node_idx) # add for loop - chunk_input_meta = [meta_nodes[i] for i in chunk_inputs_idx[region_idx]] body.append( _gen_loop_start( - chunk_input_meta, - node_list[chunk_ends[region_idx]], - chunk_dims[region_idx], + chunk_inputs[region_idx], + chunk_outputs[region_idx], + chunk_outputs_dim[region_idx], ) ) if within_chunk_region: emit_node_func(node, body) # replace input var with chunk var - body[-1] = _replace_name( - body[-1], chunk_inputs[region_idx][0].name, "chunk_tensor" - ) + for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]): + for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items(): + if idx == node_idx: + chunk_slice = _gen_chunk_slice_dim(dim, "chunk_idx", _get_node_shape(input_node)) + body[-1] = _replace_name( + body[-1], 
input_node.name, input_node.name + chunk_slice + ) body[-1] = " " + body[-1] delete_unused_value_func(node, body, chunk_inputs_names) @@ -1526,7 +1601,10 @@ def emit_code_with_chunk( if node_idx in chunk_ends: body.append( _gen_loop_end( - node, chunk_inputs[region_idx], node_list, chunk_dims[region_idx] + chunk_inputs[region_idx], + chunk_inputs_non_chunk[region_idx], + chunk_outputs[region_idx], + chunk_outputs_dim[region_idx], node_list ) ) within_chunk_region = False From b7b67c32ad79c4e81775b32fc4a36ec733915f56 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 12 Dec 2022 17:25:38 +0800 Subject: [PATCH 032/209] code style --- chunk_codegen.py | 70 +++++++++++++++++++----------------------------- 1 file changed, 28 insertions(+), 42 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 191eab564853..3bea84faeabb 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -229,10 +229,10 @@ def _detect_flow(self, start_idx, start_dim, end_idx, end_dim, index_tracer): break else: raise NotImplementedError("%s not implemented" % node.name) - + inputs_dim = [] remove_inputs = [] - for input_node in chunk_info['inputs']: + for input_node in chunk_info["inputs"]: input_dict = {} for user in input_node.users.keys(): if _is_non_compute_node(user): @@ -252,15 +252,17 @@ def _detect_flow(self, start_idx, start_dim, end_idx, end_dim, index_tracer): remove_inputs.append(input_node) else: inputs_dim.append(input_dict) - chunk_info['inputs_dim'] = inputs_dim + chunk_info["inputs_dim"] = inputs_dim for i in remove_inputs: - if i in chunk_info['inputs']: - chunk_info['inputs'].remove(i) - + if i in chunk_info["inputs"]: + chunk_info["inputs"].remove(i) + # we need to log input nodes to avoid deleteing them in the loop - non_chunk_inputs = _find_chunk_all_input_nodes(self.node_list[start_idx : end_idx + 1]) + non_chunk_inputs = _find_chunk_all_input_nodes( + self.node_list[start_idx : end_idx + 1] + ) for i in non_chunk_inputs: - if i not in chunk_info['inputs']: + if 
i not in chunk_info["inputs"]: chunk_info["inputs_non_chunk"].append(i) return flow_flag, chunk_info @@ -1371,44 +1373,32 @@ def _get_first_non_single_dim(shape): def _gen_loop_start(chunk_input, chunk_output, chunk_ouput_dim, chunk_size=2): input_node = chunk_input[0] - out_shape = _get_node_shape(chunk_output) out_str = str(list(out_shape)) - context = ( "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range" % (out_str, input_node.name, input_node.name, chunk_size) ) context += "(0, %d, chunk_size):\n" % (out_shape[chunk_ouput_dim]) - - # node = chunk_input[0] - # node_shape = node.meta["tensor_meta"].shape - # free_shape = [ - # node_shape[i] if i in chunk_dim else 1 for i in range(len(node_shape)) - # ] - # chunk_dim = _get_first_non_single_dim(free_shape) - # chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", node_shape) - # out_shape = str(list(chunk_output.meta["tensor_meta"].shape)) - - # context = ( - # "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor gen_chunk_idx in range" - # % (out_shape, node.name, node.name, chunk_size) - # ) - # context += "(0, %s.shape[%d], chunk_size):\n" % (node.name, chunk_dim) - # context += " chunk_tensor = %s%s\n" % (node.name, chunk_slice) return context -def _gen_loop_end(chunk_inputs, chunk_non_compute_inputs, chunk_outputs, chunk_outputs_dim, node_list): +def _gen_loop_end( + chunk_inputs, chunk_non_compute_inputs, chunk_outputs, chunk_outputs_dim, node_list +): chunk_outputs_name = chunk_outputs.name chunk_outputs_idx = _find_idx_by_name(chunk_outputs_name, node_list) chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape - chunk_slice = _gen_chunk_slice_dim(chunk_outputs_dim, "chunk_idx", chunk_output_shape) + chunk_slice = _gen_chunk_slice_dim( + chunk_outputs_dim, "chunk_idx", chunk_output_shape + ) context = " chunk_result%s = %s\n" % (chunk_slice, chunk_outputs_name) - context += (chunk_outputs_name + " = 
chunk_result; chunk_result = None; chunk_size = None") + context += ( + chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None" + ) # determine if its the last use for chunk input - for chunk_input in (chunk_inputs + chunk_non_compute_inputs): + for chunk_input in chunk_inputs + chunk_non_compute_inputs: if all( [ _find_idx_by_name(user.name, node_list) <= chunk_outputs_idx @@ -1456,10 +1446,7 @@ def _find_chunk_all_input_nodes(nodes: List[Node]): input_nodes = [] for node in nodes: for input_node in node._input_nodes.keys(): - if ( - input_node not in nodes - and input_node not in input_nodes - ): + if input_node not in nodes and input_node not in input_nodes: input_nodes.append(input_node) return input_nodes @@ -1549,16 +1536,12 @@ def emit_code_with_chunk( chunk_inputs = [i["inputs"] for i in chunk_search] chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_search] chunk_inputs_dim = [i["inputs_dim"] for i in chunk_search] - chunk_inputs_idx = [ - [_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_inputs + chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [ + j.name for i in chunk_inputs_non_chunk for j in i ] - chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i] chunk_outputs = [i["outputs"][0] for i in chunk_search] chunk_outputs_dim = [i["outputs_dim"] for i in chunk_search] - chunk_outputs_idx = [ - _find_idx_by_name(i.name, node_list) for i in chunk_outputs - ] node_idx = 0 region_idx = 0 @@ -1586,7 +1569,9 @@ def emit_code_with_chunk( for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]): for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items(): if idx == node_idx: - chunk_slice = _gen_chunk_slice_dim(dim, "chunk_idx", _get_node_shape(input_node)) + chunk_slice = _gen_chunk_slice_dim( + dim, "chunk_idx", _get_node_shape(input_node) + ) body[-1] = _replace_name( body[-1], input_node.name, input_node.name 
+ chunk_slice ) @@ -1604,7 +1589,8 @@ def emit_code_with_chunk( chunk_inputs[region_idx], chunk_inputs_non_chunk[region_idx], chunk_outputs[region_idx], - chunk_outputs_dim[region_idx], node_list + chunk_outputs_dim[region_idx], + node_list, ) ) within_chunk_region = False From 5cdfcfe1d168e39d39a741112c036fa1455f0d06 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 12 Dec 2022 17:29:07 +0800 Subject: [PATCH 033/209] code style --- chunk_codegen.py | 49 ++++-------------------------------------------- 1 file changed, 4 insertions(+), 45 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 3bea84faeabb..96dcbfc0f79d 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -92,24 +92,10 @@ def _init_trace(self): self._add_trace(i.name) self._add_node(i.name, i) - def _is_non_compute_node(self, node): - if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any( - i in node.name for i in ["getitem", "getattr"] - ): - return True - return False - - def _is_non_compute_node_except_placeholder(self, node): - if any(i in node.op for i in ["get_attr", "output"]) or any( - i in node.name for i in ["getitem", "getattr"] - ): - return True - return False - def _find_flow_for_node(self, node): if type(self.node_list[0]) != type(node): return None - if self._is_non_compute_node_except_placeholder(node): + if _is_non_compute_node_except_placeholder(node): return None for name, trace in self.flow_trace.items(): for i in trace: @@ -135,7 +121,7 @@ def find_node_flow(self, node): raise RuntimeError("invalid node") def _get_flow_mix_node(self, node): - if self._is_non_compute_node(node): + if _is_non_compute_node(node): return None _, node_trace = self.find_node_flow(node) if len(node_trace["outside_depend"]) == 0: @@ -160,10 +146,9 @@ def trace_flow(self): for node in self.node_list: # skip if non compute node if all( - type(arg) != type(node) - or self._is_non_compute_node_except_placeholder(arg) + type(arg) != type(node) or 
_is_non_compute_node_except_placeholder(arg) for arg in node.args - ) or self._is_non_compute_node(node): + ) or _is_non_compute_node(node): continue node_input_flows = [self._find_flow_for_node(arg) for arg in node.args] @@ -1411,32 +1396,6 @@ def _gen_loop_end( return context -def _find_input_and_output_nodes(nodes: List[Node]): - """ - Find the input and output node names which are not found in the given list of nodes. - """ - input_nodes = [] - output_nodes = [] - - # if a node has an input node which is not in the node list - # we treat that input node as the input of the checkpoint function - for node in nodes: - for input_node in node._input_nodes.keys(): - node_repr = repr(input_node) - if input_node not in nodes and input_node not in input_nodes: - input_nodes.append(input_node) - - # if a node has a user node which is not in the node list - # we treat that user node as the node receiving the current node output - for node in nodes: - for output_node in node.users.keys(): - node_repr = repr(node) - if output_node not in nodes and output_node not in output_nodes: - output_nodes.append(output_node) - - return input_nodes, output_nodes - - def _find_chunk_all_input_nodes(nodes: List[Node]): """ Find non-compute input and output node names. 
From 8511d900a88638cb04ced2db35b171a96f6f310c Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 12 Dec 2022 17:36:17 +0800 Subject: [PATCH 034/209] code style --- chunk_codegen.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 96dcbfc0f79d..88d9178091b7 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -1210,8 +1210,6 @@ def _check_duplicate_map(self, chunk_infos): return chunk_infos def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx): - if start_idx == 71 and end_idx == 126: - print(1) start_traces = input_trace[start_idx] end_trace = output_trace[end_idx] end_node = self.node_list[end_idx] @@ -1347,15 +1345,6 @@ def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape): return new_shape -def _get_first_non_single_dim(shape): - for idx, i in enumerate(shape): - if i == 1: - continue - else: - return idx - raise RuntimeError("can not get first non single dim for shape", shape) - - def _gen_loop_start(chunk_input, chunk_output, chunk_ouput_dim, chunk_size=2): input_node = chunk_input[0] out_shape = _get_node_shape(chunk_output) From 98f9728e29f463692cea1533c998f0e7f2381e59 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 12 Dec 2022 18:15:47 +0800 Subject: [PATCH 035/209] code style --- chunk_codegen.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 88d9178091b7..22d48f5d661a 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -194,7 +194,7 @@ def _detect_flow(self, start_idx, start_dim, end_idx, end_dim, index_tracer): if type(i) == type(mix_flow_node) and i != mix_flow_node: main_flow_var = i # if mix flow is a broadcast in chunk dim, - # TODO need to move that flow out of the chunk + # TODO: need to move that flow out of the chunk mix_flow_node_dim = index_tracer._get_node_chunk_dim( self.node_list[end_idx], end_dim, node ) @@ -1200,7 +1200,7 @@ def _check_duplicate_map(self, chunk_infos): continue # it means 
an index create 2 copy of itself # eg. a = torch.matmul(x, x.transpose(-1, -2)) - # TODO currently remove it, deal with this in future + # TODO: currently remove it, deal with this in future if input_dim1 == input_dim2 and output_dim1 != output_dim2: remove_list.append(chunk_infos[idx1]) remove_list.append(chunk_infos[idx2]) @@ -1216,7 +1216,7 @@ def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx): chunk_infos = [] for end_dim, end_trace_idx in enumerate(end_trace["idx"]): if len(start_traces) > 1: - # TODO implement multi input chunk + # TODO: implement multi input chunk continue for start_node, start_trace in start_traces.items(): for start_dim, start_trace_idx in enumerate(start_trace["idx"]): @@ -1421,7 +1421,7 @@ def _find_chunk_compute_input_and_output_nodes(nodes: List[Node]): # if a node has a user node which is not in the node list # we treat that user node as the node receiving the current node output - # TODO it is unsafe to remove non compute node here + # TODO: it is unsafe to remove non compute node here for node in nodes: for output_node in node.users.keys(): if ( From 8754fa255376055c01aab4a3fab385454b8b7930 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 12 Dec 2022 18:25:47 +0800 Subject: [PATCH 036/209] change threshold --- chunk_codegen_run.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/chunk_codegen_run.py b/chunk_codegen_run.py index 88c734903392..99700e1af9d8 100644 --- a/chunk_codegen_run.py +++ b/chunk_codegen_run.py @@ -45,8 +45,9 @@ def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair): with torch.no_grad(): non_fx_out = model(node, pair) fx_out = gm(node, pair) - assert torch.allclose(non_fx_out[0], fx_out[0], atol=1e-6), "fx_out doesn't comply with original output" - assert torch.allclose(non_fx_out[1], fx_out[1], atol=1e-6), "fx_out doesn't comply with original output" + + assert torch.allclose(non_fx_out[0], fx_out[0], atol=1e-4), "fx_out doesn't comply with 
original output" + assert torch.allclose(non_fx_out[1], fx_out[1], atol=1e-4), "fx_out doesn't comply with original output" # test barckward # loss0 = non_fx_out[0].sum() + non_fx_out[1].sum() From 1e0fd11bc1773ca47cbd95fb19b86517265390ce Mon Sep 17 00:00:00 2001 From: oahzxl Date: Tue, 13 Dec 2022 10:01:30 +0800 Subject: [PATCH 037/209] support check_index_duplicate --- chunk_codegen.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 22d48f5d661a..64bff4a801a1 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -179,7 +179,12 @@ def _detect_flow(self, start_idx, start_dim, end_idx, end_dim, index_tracer): "outputs_dim": end_dim, "args": {}, } - flow_flag = False + flow_block = False + + # TODO don't allow multi outputs now + if len(outputs) > 1: + flow_block = True + return flow_block, chunk_info for idx in range(start_idx, end_idx + 1): node = self.node_list[idx] @@ -199,10 +204,10 @@ def _detect_flow(self, start_idx, start_dim, end_idx, end_dim, index_tracer): self.node_list[end_idx], end_dim, node ) if mix_flow_node_dim is None: - flow_flag = True + flow_block = True break if _get_node_shape(mix_flow_node)[mix_flow_node_dim] == 1: - flow_flag = False + flow_block = False for i in self._get_same_flow_node( chunk_info["inputs"], mix_flow_node ): @@ -210,11 +215,15 @@ def _detect_flow(self, start_idx, start_dim, end_idx, end_dim, index_tracer): # else, we need to chunk mix var as well else: # TODO chunk another value - flow_flag = True + flow_block = True break else: raise NotImplementedError("%s not implemented" % node.name) + if flow_block: + flow_block = True + return flow_block, chunk_info + inputs_dim = [] remove_inputs = [] for input_node in chunk_info["inputs"]: @@ -250,7 +259,7 @@ def _detect_flow(self, start_idx, start_dim, end_idx, end_dim, index_tracer): if i not in chunk_info["inputs"]: chunk_info["inputs_non_chunk"].append(i) - return flow_flag, 
chunk_info + return flow_block, chunk_info class IndexTracer(object): @@ -869,14 +878,6 @@ def check_index_compute(self, start_idx, end_dim, end_node, end_idx): if any(start_idx <= i <= end_idx for i in end_node_compute): return False return True - # end_node_trace_source = end_node_trace['source'][end_dim] - # for node_idx, node_dim in end_node_trace_source.items(): - # if node_idx < start_node_idx or node_idx > end_node_idx: - # continue - # compute_list = self.idx_trace_list[node_idx]['compute'][node_dim] - # if any(start_node_idx <= i <= end_node_idx for i in compute_list): - # return False - # return True def _get_node_chunk_dim(self, node_from, node_from_dim, node_to): node_from_source = self._find_source_trace_from_node(node_from) @@ -1240,10 +1241,10 @@ def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx): ): continue # detect flow meet - flow_flag, chunk_info = self.flow_tracer._detect_flow( + flow_block, chunk_info = self.flow_tracer._detect_flow( start_idx, start_dim, end_idx, end_dim, self.index_tracer ) - if flow_flag: + if flow_block: continue chunk_infos.append(chunk_info) chunk_infos = self._check_duplicate_map(chunk_infos) From cda3e8572a8ab1f0c48342ad305fadbf892d62b2 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Tue, 13 Dec 2022 10:02:26 +0800 Subject: [PATCH 038/209] support index dupilictae and update loop --- chunk_codegen.py | 109 +++++++++++++++++++++++++++++-------------- chunk_codegen_run.py | 4 +- 2 files changed, 76 insertions(+), 37 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 64bff4a801a1..b5bb8f18560a 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -180,7 +180,7 @@ def _detect_flow(self, start_idx, start_dim, end_idx, end_dim, index_tracer): "args": {}, } flow_block = False - + # TODO don't allow multi outputs now if len(outputs) > 1: flow_block = True @@ -200,7 +200,7 @@ def _detect_flow(self, start_idx, start_dim, end_idx, end_dim, index_tracer): main_flow_var = i # if mix flow is a 
broadcast in chunk dim, # TODO: need to move that flow out of the chunk - mix_flow_node_dim = index_tracer._get_node_chunk_dim( + mix_flow_node_dim = index_tracer.get_node_chunk_dim( self.node_list[end_idx], end_dim, node ) if mix_flow_node_dim is None: @@ -223,7 +223,7 @@ def _detect_flow(self, start_idx, start_dim, end_idx, end_dim, index_tracer): if flow_block: flow_block = True return flow_block, chunk_info - + inputs_dim = [] remove_inputs = [] for input_node in chunk_info["inputs"]: @@ -234,7 +234,7 @@ def _detect_flow(self, start_idx, start_dim, end_idx, end_dim, index_tracer): user_idx = _find_idx_by_name(user.name, self.node_list) dim = None if start_dim <= user_idx < end_idx: - dim = index_tracer._get_node_chunk_dim( + dim = index_tracer.get_node_chunk_dim( self.node_list[end_idx], end_dim, input_node ) elif user_idx == end_idx: @@ -300,10 +300,10 @@ def _del_dim(self, idx, dim_idx): self.idx_trace_list[idx]["compute"].pop(dim_idx) self.idx_trace_list[idx]["source"].pop(dim_idx) - def _add_dim(self, idx, dim_idx): - self.idx_trace_list[idx]["idx"].insert(dim_idx, self._add_index()) - self.idx_trace_list[idx]["compute"].insert(dim_idx, []) - self.idx_trace_list[idx]["source"].insert(dim_idx, {}) + def _add_dim(self, node_idx, dim_idx): + self.idx_trace_list[node_idx]["idx"].insert(dim_idx, self._add_index()) + self.idx_trace_list[node_idx]["compute"].insert(dim_idx, []) + self.idx_trace_list[node_idx]["source"].insert(dim_idx, {}) def _transform_index(self, node, node_dim): node_idx = self._find_idx_trace_from_node(node) @@ -659,9 +659,7 @@ def _assign_unsqueeze_index(self, node, node_idx): """ self._del_dim(node_idx, -1) self._assign_index_as_input(node, node_idx) - self.idx_trace_list[node_idx]["idx"].insert(node.args[1], self._add_index()) - self.idx_trace_list[node_idx]["compute"].insert(node.args[1], []) - self.idx_trace_list[node_idx]["source"].insert(node.args[1], []) + self._add_dim(node_idx, node.args[1]) def _assign_dropout_index(self, node, 
node_idx): """ @@ -879,7 +877,7 @@ def check_index_compute(self, start_idx, end_dim, end_node, end_idx): return False return True - def _get_node_chunk_dim(self, node_from, node_from_dim, node_to): + def get_node_chunk_dim(self, node_from, node_from_dim, node_to): node_from_source = self._find_source_trace_from_node(node_from) dim_source = node_from_source[node_from_dim] node_to_idx = _find_idx_by_name(node_to.name, self.nodes_list) @@ -888,6 +886,44 @@ def _get_node_chunk_dim(self, node_from, node_from_dim, node_to): return v return None + def _find_inherit_dim(self, input_node, input_dim, node): + input_node_idx = _find_idx_by_name(input_node.name, self.nodes_list) + node_idx = _find_idx_by_name(node.name, self.nodes_list) + node_trace_source = self._find_source_trace_from_node(node) + for node_dim in range(len(_get_node_shape(node))): + if ( + input_node_idx in node_trace_source[node_dim] + and node_trace_source[node_dim][input_node_idx] == input_dim + ): + return {node_idx: node_dim} + return {} + + def check_index_duplicate(self, chunk_infos): + input_dim_after_node = {} + for input_node_idx, input_node in enumerate(chunk_infos["inputs"]): + for k, v in chunk_infos["inputs_dim"][input_node_idx].items(): + input_dim_after_node.update( + self._find_inherit_dim(input_node, v, self.nodes_list[k]) + ) + + for node in self.nodes_list[ + chunk_infos["region"][0] : chunk_infos["region"][1] + 1 + ]: + if _is_non_compute_node_except_placeholder(node): + continue + count = 0 + node_trace_source = self._find_source_trace_from_node(node) + for node_dim in range(len(_get_node_shape(node))): + dim_source = node_trace_source[node_dim] + for k, v in dim_source.items(): + if chunk_infos["region"][0] <= k <= chunk_infos["region"][1]: + if k in input_dim_after_node and input_dim_after_node[k] == v: + count += 1 + break + if count > 1: + return False + return True + class MemoryEstimator(object): def __init__(self) -> None: @@ -1160,7 +1196,7 @@ def _get_min_free_var(self, 
active_node_list, free_vars): min_len = len(n) return min_len - def _search_max_chunk_region(self, active_node, peak_node): + def _search_max_chunk_region(self, active_node, peak_node, chunk_regions): free_vars = self._get_free_var() min_var = self._get_min_free_var(active_node, free_vars) @@ -1180,6 +1216,21 @@ def _search_max_chunk_region(self, active_node, peak_node): break if i in free_vars or i == 0: raise RuntimeError() + + for i in chunk_regions: + region = i["region"] + if chunk_region_start >= region[0] and chunk_region_end <= region[1]: + return None + elif ( + region[0] <= chunk_region_start <= region[1] + and chunk_region_end > region[1] + ): + chunk_region_start = region[1] + 1 + elif ( + region[0] <= chunk_region_end <= region[1] + and chunk_region_start < region[0] + ): + chunk_region_end = region[0] - 1 return chunk_region_start, chunk_region_end def _is_not_compute(self, trace, chunk_range, dim_idx): @@ -1192,24 +1243,6 @@ def _is_not_compute(self, trace, chunk_range, dim_idx): return True return False - def _check_duplicate_map(self, chunk_infos): - dim_map = [(i["inputs_dim"], i["outputs_dim"]) for i in chunk_infos] - remove_list = [] - for idx1, (input_dim1, output_dim1) in enumerate(dim_map): - for idx2, (input_dim2, output_dim2) in enumerate(dim_map): - if idx1 == idx2: - continue - # it means an index create 2 copy of itself - # eg. 
a = torch.matmul(x, x.transpose(-1, -2)) - # TODO: currently remove it, deal with this in future - if input_dim1 == input_dim2 and output_dim1 != output_dim2: - remove_list.append(chunk_infos[idx1]) - remove_list.append(chunk_infos[idx2]) - for i in remove_list: - if i in chunk_infos: - chunk_infos.remove(i) - return chunk_infos - def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx): start_traces = input_trace[start_idx] end_trace = output_trace[end_idx] @@ -1246,8 +1279,10 @@ def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx): ) if flow_block: continue + # check index copmute + if not self.index_tracer.check_index_duplicate(chunk_info): + continue chunk_infos.append(chunk_info) - chunk_infos = self._check_duplicate_map(chunk_infos) return chunk_infos def _search_possible_chunk_regions(self, max_chunk_region, peak_node): @@ -1288,9 +1323,13 @@ def _search_best_chunk_region(self, possible_chunk_regions): max_region_range = i["region"][1] - i["region"][0] return best_regions - def _step_search(self, mem_peak, active_node): + def _step_search(self, mem_peak, active_node, chunk_regions): peak_node = self._find_peak_node(mem_peak) - max_chunk_region = self._search_max_chunk_region(active_node, peak_node) + max_chunk_region = self._search_max_chunk_region( + active_node, peak_node, chunk_regions + ) + if max_chunk_region == None: + return None possible_chunk_regions = self._search_possible_chunk_regions( max_chunk_region, peak_node ) @@ -1313,7 +1352,7 @@ def search_region(self): mem_peak = init_mem_peak while True: - chunk_region = self._step_search(mem_peak, active_node) + chunk_region = self._step_search(mem_peak, active_node, chunk_regions) if chunk_region is None: break diff --git a/chunk_codegen_run.py b/chunk_codegen_run.py index 99700e1af9d8..ae4653d6545b 100644 --- a/chunk_codegen_run.py +++ b/chunk_codegen_run.py @@ -46,8 +46,8 @@ def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair): non_fx_out = 
model(node, pair) fx_out = gm(node, pair) - assert torch.allclose(non_fx_out[0], fx_out[0], atol=1e-4), "fx_out doesn't comply with original output" - assert torch.allclose(non_fx_out[1], fx_out[1], atol=1e-4), "fx_out doesn't comply with original output" + assert torch.allclose(non_fx_out[0], fx_out[0], atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(torch.abs(non_fx_out[0] - fx_out[0])) + assert torch.allclose(non_fx_out[1], fx_out[1], atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(torch.abs(non_fx_out[1] - fx_out[1])) # test barckward # loss0 = non_fx_out[0].sum() + non_fx_out[1].sum() From de65e6c3e88bc1b217b894bf20a4769748145605 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Tue, 13 Dec 2022 11:00:51 +0800 Subject: [PATCH 039/209] support output --- chunk_codegen.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index b5bb8f18560a..79cefddf07d2 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -56,6 +56,14 @@ def _is_non_compute_node_except_placeholder(node): return False +def _is_non_compute_node_except_placeholder_output(node): + if any(i in node.op for i in ["get_attr"]) or any( + i in node.name for i in ["getitem", "getattr"] + ): + return True + return False + + class FlowTracer(object): def __init__(self, gm) -> None: self.gm = gm @@ -1083,13 +1091,14 @@ def estimate_chunk_inference_mem( i is not None for i in [start_nodes, end_nodes, chunk_dims, chunk_sizes] ) chunk_within = False - chunk_region_idx = 0 + chunk_region_idx = None chunk_ratio = 1 # use it to estimate chunk mem for idx, node in enumerate(node_list): # if node in chunk start nodes, change chunk ratio and add chunk_tensor if use_chunk and idx in start_nodes: chunk_within = True + chunk_region_idx = start_nodes.index(idx) chunk_ratio = self._get_chunk_ratio( node, chunk_dims[chunk_region_idx], chunk_sizes[chunk_region_idx] ) @@ -1149,7 +1158,7 @@ 
def estimate_chunk_inference_mem( ) chunk_within = False chunk_ratio = 1 - chunk_region_idx += 1 + chunk_region_idx = None act_memory_after_node_log.append(act_memory) active_node_list_log.append(copy.deepcopy(active_node_list)) @@ -1467,7 +1476,7 @@ def _find_chunk_compute_input_and_output_nodes(nodes: List[Node]): if ( output_node not in nodes and node not in output_nodes - and not _is_non_compute_node_except_placeholder(output_node) + and not _is_non_compute_node_except_placeholder_output(output_node) ): output_nodes.append(node) From e83e3c615452c5f8ab04f558880c378256d95802 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 16 Dec 2022 11:09:35 +0800 Subject: [PATCH 040/209] update memory estimate --- chunk_codegen.py | 177 +++++++++++++++++++++++++++++------------------ 1 file changed, 111 insertions(+), 66 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 79cefddf07d2..18d9a0c8d764 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -896,23 +896,22 @@ def get_node_chunk_dim(self, node_from, node_from_dim, node_to): def _find_inherit_dim(self, input_node, input_dim, node): input_node_idx = _find_idx_by_name(input_node.name, self.nodes_list) - node_idx = _find_idx_by_name(node.name, self.nodes_list) node_trace_source = self._find_source_trace_from_node(node) for node_dim in range(len(_get_node_shape(node))): if ( input_node_idx in node_trace_source[node_dim] and node_trace_source[node_dim][input_node_idx] == input_dim ): - return {node_idx: node_dim} - return {} + return node_dim + return None def check_index_duplicate(self, chunk_infos): input_dim_after_node = {} for input_node_idx, input_node in enumerate(chunk_infos["inputs"]): for k, v in chunk_infos["inputs_dim"][input_node_idx].items(): - input_dim_after_node.update( - self._find_inherit_dim(input_node, v, self.nodes_list[k]) - ) + inherit_dim = self._find_inherit_dim(input_node, v, self.nodes_list[k]) + if inherit_dim: + input_dim_after_node[k] = inherit_dim for node in 
self.nodes_list[ chunk_infos["region"][0] : chunk_infos["region"][1] + 1 @@ -934,8 +933,8 @@ def check_index_duplicate(self, chunk_infos): class MemoryEstimator(object): - def __init__(self) -> None: - pass + def __init__(self, index_tracer: IndexTracer) -> None: + self.index_tracer = index_tracer def _get_meta_node_size(self, x): x = x.meta["tensor_meta"] @@ -950,6 +949,8 @@ def _get_output_node(self, n): } out_size = activation_size(fwd_out) out_node = [n.name] if out_size > 0 else [] + # if any(i in n.name for i in ['transpose', 'permute', 'view']): + # out_size = 0 return out_size, out_node def _get_output_node_size(self, n): @@ -961,11 +962,19 @@ def _add_active_node(self, n, active_list): if i not in active_list: active_list.append(i) - def _get_delete_node(self, user, user_to_last_uses): + def _get_delete_node(self, user, user_to_last_uses, to_keep=None): delete_size = 0 delete_node = [] if user.op not in ("placeholder", "output"): nodes_to_delete = user_to_last_uses.get(user, []) + if to_keep is not None: + keep_list = [] + for n in nodes_to_delete: + if n.name in to_keep: + keep_list.append(n) + for n in keep_list: + if n in nodes_to_delete: + nodes_to_delete.remove(n) if len(nodes_to_delete): out_node = [self._get_output_node(i) for i in nodes_to_delete] delete_size = sum([i[0] for i in out_node]) @@ -974,15 +983,30 @@ def _get_delete_node(self, user, user_to_last_uses): delete_node.append(out_node[i][1][0]) elif nodes_to_delete[i].op == "placeholder": delete_node.append(nodes_to_delete[i].name) + # elif any(j in nodes_to_delete[i].name for j in ['transpose', 'permute', 'view']): + # delete_node.append(nodes_to_delete[i].name) return delete_size, delete_node - def _get_delete_node_size(self, user, user_to_last_uses): - return self._get_delete_node(user, user_to_last_uses)[0] + def _get_delete_node_size(self, user, user_to_last_uses, to_keep): + return self._get_delete_node(user, user_to_last_uses, to_keep)[0] def _remove_deactive_node(self, user, 
user_to_last_uses, active_list): delete_node = self._get_delete_node(user, user_to_last_uses)[1] for i in delete_node: - active_list.remove(i) + if i in active_list: + active_list.remove(i) + + def _get_chunk_inputs_size(self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx): + nodes_to_delete = [] + for chunk_input in chunk_inputs + chunk_inputs_non_chunk: + chunk_input_users = chunk_input.users.keys() + chunk_input_users_idx = [_find_idx_by_name(i.name, node_list) for i in chunk_input_users] + if all(i <= chunk_end_idx for i in chunk_input_users_idx): + if chunk_input not in nodes_to_delete: + nodes_to_delete.append(chunk_input) + out_node = [self._get_output_node(i) for i in nodes_to_delete] + delete_size = sum([i[0] for i in out_node]) + return delete_size def _get_last_usr(self, nodes): node_to_last_use: Dict[Node, Node] = {} @@ -1000,7 +1024,8 @@ def register_last_uses(n: Node, user: Node): def _get_contiguous_memory(self, node, not_contiguous_list, delete=False): mem = 0 - not_contiguous_ops = ["transpose", "permute"] + not_contiguous_ops = ["permute"] + inherit_contiguous_ops = ["transpose", "view"] if node.op == "call_function" and any( n in node.name for n in ["matmul", "reshape"] @@ -1020,30 +1045,36 @@ def _get_contiguous_memory(self, node, not_contiguous_list, delete=False): ): if node not in not_contiguous_list: not_contiguous_list.append(node) - elif any(i in node.args for i in not_contiguous_list): - if node not in not_contiguous_list: - not_contiguous_list.append(node) - return mem - def _get_chunk_ratio(self, node, chunk_dim, chunk_size): - sorted_dim = sorted(chunk_dim, key=lambda x: list(x.keys())[0]) - dim = list(sorted_dim[-1].values())[0] - shape = node.meta["tensor_meta"].shape - chunk_ratio = float(chunk_size) / shape[dim] - return chunk_ratio + def _get_chunk_ratio(self, node, chunk_inputs, chunk_inputs_dim, chunk_size): + node_shape = _get_node_shape(node) + node_source = self.index_tracer._find_source_trace_from_node(node) 
+ for (input_node, input_node_dim) in zip(chunk_inputs, chunk_inputs_dim): + for k, v in input_node_dim.items(): + inherit_dim = self.index_tracer._find_inherit_dim(input_node, v, self.index_tracer.nodes_list[k]) + if k == _find_idx_by_name(node.name, self.index_tracer.nodes_list): + chunk_ratio = float(chunk_size) / node_shape[inherit_dim] + return chunk_ratio + for dim, source in enumerate(node_source): + if k in source and source[k] == inherit_dim: + chunk_ratio = float(chunk_size) / node_shape[dim] + return chunk_ratio + return 1. def _get_chunk_delete_node_size( - self, user, user_to_last_uses, chunk_ratio, node_list, start_node, end_node + self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names ): + # if any(j in user.name for j in ['transpose', 'permute', 'view']): + # return 0 if user.op in ("placeholder", "output"): return 0 nodes_to_delete = user_to_last_uses.get(user, []) delete_size = 0 for n in nodes_to_delete: - node_idx = _find_idx_by_name(n.name, node_list) - if start_node <= node_idx < end_node: - delete_size += self._get_output_node_size(n) * chunk_ratio + if n.name in chunk_inputs_names: + continue + delete_size += self._get_output_node_size(n) * chunk_ratio return delete_size def _print_mem_log(self, log, nodes, title=None): @@ -1071,10 +1102,7 @@ def _print_compute_op_mem_log(self, log, nodes, title=None): def estimate_chunk_inference_mem( self, gm: torch.fx.GraphModule, - start_nodes=None, - end_nodes=None, - chunk_dims=None, - chunk_sizes=None, + chunk_infos=None, ): act_memory = 0.0 act_memory_peak_log = [] @@ -1087,36 +1115,53 @@ def estimate_chunk_inference_mem( user_to_last_uses_no_free_var = self._get_last_usr(node_list) _delete_free_var_from_last_use(user_to_last_uses_no_free_var) - use_chunk = all( - i is not None for i in [start_nodes, end_nodes, chunk_dims, chunk_sizes] - ) + use_chunk = True if chunk_infos is not None else False chunk_within = False chunk_region_idx = None chunk_ratio = 1 # use it to estimate chunk mem + 
chunk_size = 1 + chunk_inputs_names = [] + + if use_chunk: + chunk_regions = [i["region"] for i in chunk_infos] + chunk_starts = [i[0] for i in chunk_regions] + chunk_ends = [i[1] for i in chunk_regions] + chunk_inputs = [i["inputs"] for i in chunk_infos] + chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] + chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] + chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [ + j.name for i in chunk_inputs_non_chunk for j in i + ] + chunk_outputs = [i["outputs"][0] for i in chunk_infos] for idx, node in enumerate(node_list): # if node in chunk start nodes, change chunk ratio and add chunk_tensor - if use_chunk and idx in start_nodes: + if use_chunk and idx in chunk_starts: chunk_within = True - chunk_region_idx = start_nodes.index(idx) + chunk_region_idx = chunk_starts.index(idx) + act_memory += self._get_output_node_size(chunk_outputs[chunk_region_idx]) / (1024**2) + + # determine chunk ratio for current node + if chunk_within: chunk_ratio = self._get_chunk_ratio( - node, chunk_dims[chunk_region_idx], chunk_sizes[chunk_region_idx] + node, chunk_inputs[chunk_region_idx], chunk_inputs_dim[chunk_region_idx], chunk_size ) - act_memory += self._get_output_node_size( - node_list[end_nodes[chunk_region_idx]] - ) / (1024**2) # if node is placeholder, just add the size of the node if node.op == "placeholder": act_memory += self._get_meta_node_size(node) * chunk_ratio / (1024**2) act_memory_peak_log.append(act_memory) - active_node_list.append(node.name) # skip output elif node.op == "output": continue - # node is an operation, calculate tmp, output node and delete node memory + # no change for non compute node + elif _is_non_compute_node_except_placeholder(node): + act_memory_peak_log.append(act_memory) + # node is a compute op + # calculate tmp, output node and delete node memory else: # forward memory + # TODO: contiguous_memory still not accurate for matmul, view, reshape and transpose 
act_memory += ( self._get_contiguous_memory(node, not_contiguous_list) * chunk_ratio @@ -1133,29 +1178,35 @@ def estimate_chunk_inference_mem( * chunk_ratio / (1024**2) ) + # delete unused vars not in chunk_input_list + # we can't delete input nodes until chunk ends if chunk_within: act_memory -= self._get_chunk_delete_node_size( node, user_to_last_uses_no_free_var, chunk_ratio, - node_list, - start_nodes[chunk_region_idx], - end_nodes[chunk_region_idx], + chunk_inputs_names ) / (1024**2) else: - act_memory -= self._get_delete_node_size( - node, user_to_last_uses_no_free_var - ) / (1024**2) + act_memory -= (self._get_delete_node_size( + node, user_to_last_uses_no_free_var, chunk_inputs_names + ) / (1024**2)) - # log active node + # log active node, only effective without chunk self._add_active_node(node, active_node_list) self._remove_deactive_node(node, user_to_last_uses, active_node_list) # if node in chunk end nodes, restore chunk settings - if use_chunk and idx in end_nodes: + if use_chunk and idx in chunk_ends: act_memory -= ( self._get_output_node_size(node) * chunk_ratio / (1024**2) ) + act_memory -= self._get_chunk_inputs_size( + chunk_inputs[chunk_region_idx], + chunk_inputs_non_chunk[chunk_region_idx], + node_list, + chunk_regions[chunk_region_idx][1] + ) / (1024**2) chunk_within = False chunk_ratio = 1 chunk_region_idx = None @@ -1178,11 +1229,11 @@ class ChunkRegionSearch(object): def __init__(self, gm) -> None: self.gm = gm self.node_list = list(gm.graph.nodes) - self.memory_estimator = MemoryEstimator() self.index_tracer = IndexTracer(gm) self.index_tracer.trace_index() self.flow_tracer = FlowTracer(gm) self.flow_tracer.trace_flow() + self.memory_estimator = MemoryEstimator(self.index_tracer) def _find_peak_node(self, mem_peak): max_value = max(mem_peak) @@ -1210,7 +1261,7 @@ def _search_max_chunk_region(self, active_node, peak_node, chunk_regions): min_var = self._get_min_free_var(active_node, free_vars) # from peak_node to free_var - 
chunk_region_start = None + chunk_region_start = len(free_vars) for i in range(peak_node, -1, -1): if len(active_node[i]) == min_var: chunk_region_start = i + 1 @@ -1218,7 +1269,7 @@ def _search_max_chunk_region(self, active_node, peak_node, chunk_regions): if i in free_vars or i == 0: raise RuntimeError() # from peak_node to len-2 - chunk_region_end = None + chunk_region_end = len(active_node) - 1 for i in range(peak_node, len(active_node)): if len(active_node[i]) == min_var: chunk_region_end = i @@ -1352,7 +1403,7 @@ def _stop_search(self, init_mem_peak, mem_peak): return False def search_region(self): - chunk_regions = [] + chunk_infos = [] ( init_mem_peak, _, @@ -1361,25 +1412,19 @@ def search_region(self): mem_peak = init_mem_peak while True: - chunk_region = self._step_search(mem_peak, active_node, chunk_regions) - if chunk_region is None: + chunk_info = self._step_search(mem_peak, active_node, chunk_infos) + if chunk_info is None: break - chunk_regions.append(chunk_region) + chunk_infos.append(chunk_info) ( mem_peak, _, active_node, - ) = self.memory_estimator.estimate_chunk_inference_mem( - self.gm, - [i["region"][0] for i in chunk_regions], - [i["region"][1] for i in chunk_regions], - [i["inputs_dim"] for i in chunk_regions], - [1] * len(chunk_regions), - ) + ) = self.memory_estimator.estimate_chunk_inference_mem(self.gm, chunk_infos) if self._stop_search(init_mem_peak, mem_peak): break - return chunk_regions + return chunk_infos def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape): @@ -1415,7 +1460,7 @@ def _gen_loop_end( chunk_slice = _gen_chunk_slice_dim( chunk_outputs_dim, "chunk_idx", chunk_output_shape ) - context = " chunk_result%s = %s\n" % (chunk_slice, chunk_outputs_name) + context = " chunk_result%s = %s; %s = None\n" % (chunk_slice, chunk_outputs_name, chunk_outputs_name) context += ( chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None" ) From e66a18a0bfaa87767d5869ab21a76c48af8b81cf Mon Sep 17 00:00:00 2001 
From: oahzxl Date: Fri, 16 Dec 2022 15:06:39 +0800 Subject: [PATCH 041/209] optimise search --- chunk_codegen.py | 67 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 47 insertions(+), 20 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 18d9a0c8d764..5e2130ee76f4 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -958,6 +958,8 @@ def _get_output_node_size(self, n): def _add_active_node(self, n, active_list): new_active = self._get_output_node(n)[1] + if n.op == 'placeholder': + new_active.append(n.name) for i in new_active: if i not in active_list: active_list.append(i) @@ -965,7 +967,7 @@ def _add_active_node(self, n, active_list): def _get_delete_node(self, user, user_to_last_uses, to_keep=None): delete_size = 0 delete_node = [] - if user.op not in ("placeholder", "output"): + if user.op not in ("output",): nodes_to_delete = user_to_last_uses.get(user, []) if to_keep is not None: keep_list = [] @@ -1258,24 +1260,30 @@ def _get_min_free_var(self, active_node_list, free_vars): def _search_max_chunk_region(self, active_node, peak_node, chunk_regions): free_vars = self._get_free_var() - min_var = self._get_min_free_var(active_node, free_vars) - + free_var_num = len(free_vars) + active_node_num = [len(i) for i in active_node] + min_active_node_num = min(active_node_num[free_var_num:]) + threshold = max(free_var_num, min_active_node_num) + # from peak_node to free_var - chunk_region_start = len(free_vars) + inside_flag = False + chunk_region_start = free_var_num for i in range(peak_node, -1, -1): - if len(active_node[i]) == min_var: + if active_node_num[i] <= threshold: + inside_flag = True + if inside_flag and active_node_num[i] > threshold: chunk_region_start = i + 1 break - if i in free_vars or i == 0: - raise RuntimeError() + # from peak_node to len-2 + inside_flag = False chunk_region_end = len(active_node) - 1 for i in range(peak_node, len(active_node)): - if len(active_node[i]) == min_var: + if active_node_num[i] <= 
threshold: + inside_flag = True + if inside_flag and active_node_num[i] > threshold: chunk_region_end = i break - if i in free_vars or i == 0: - raise RuntimeError() for i in chunk_regions: region = i["region"] @@ -1374,15 +1382,34 @@ def _search_possible_chunk_regions(self, max_chunk_region, peak_node): possible_chunk_region.extend(chunk_info) return possible_chunk_region - def _search_best_chunk_region(self, possible_chunk_regions): + def _search_best_chunk_region(self, possible_chunk_regions, chunk_infos): max_region_range = 0 - best_regions = None - for i in possible_chunk_regions: - if i["region"][1] - i["region"][0] > max_region_range: - best_regions = i - max_region_range = i["region"][1] - i["region"][0] - return best_regions - + best_region = None + while len(possible_chunk_regions) > 0: + for i in possible_chunk_regions: + if i["region"][1] - i["region"][0] > max_region_range: + best_region = i + max_region_range = i["region"][1] - i["region"][0] + if self._is_legal_region(best_region, chunk_infos): + break + possible_chunk_regions.remove(i) + max_region_range = 0 + best_region = None + return best_region + + def _is_legal_region(self, cur_chunk_info, chunk_infos): + (chunk_region_start, chunk_region_end) = cur_chunk_info["region"] + if cur_chunk_info in chunk_infos: + return False + if chunk_region_end < chunk_region_start: + return False + for i in chunk_infos: + region = i["region"] + if not ((chunk_region_start > region[1] and chunk_region_end > region[1]) + or (chunk_region_start < region[0] and chunk_region_end < region[0])): + return False + return True + def _step_search(self, mem_peak, active_node, chunk_regions): peak_node = self._find_peak_node(mem_peak) max_chunk_region = self._search_max_chunk_region( @@ -1393,7 +1420,7 @@ def _step_search(self, mem_peak, active_node, chunk_regions): possible_chunk_regions = self._search_possible_chunk_regions( max_chunk_region, peak_node ) - best_chunk_region = 
self._search_best_chunk_region(possible_chunk_regions) + best_chunk_region = self._search_best_chunk_region(possible_chunk_regions, chunk_regions) return best_chunk_region def _stop_search(self, init_mem_peak, mem_peak): @@ -1919,5 +1946,5 @@ def emit_node(node: Node, body): {prologue} {code}""" - print(fn_code) + # print(fn_code) return PythonCode(fn_code, globals_) From 9d516fa68f4e029d63b53d78803667bfa71e86d6 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Sun, 18 Dec 2022 20:37:55 +0800 Subject: [PATCH 042/209] fix layernorm --- chunk_codegen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 5e2130ee76f4..77c28fd32c88 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -574,7 +574,7 @@ def _assign_layernorm_index(self, node, idx): node_idx (int) """ self._assign_index_as_input(node, idx) - self._mark_computation(node, idx, [-1, -2]) + self._mark_computation(node, idx, [-1]) def _assign_elementwise_index(self, node, idx): """ From d734529a390087f1366b7573410eca5775735b14 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Wed, 21 Dec 2022 15:00:24 +0800 Subject: [PATCH 043/209] move flow tracer --- chunk_codegen.py | 413 ++++++++++++++++++++++++----------------------- 1 file changed, 207 insertions(+), 206 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 77c28fd32c88..2c1c09ae5238 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -64,212 +64,6 @@ def _is_non_compute_node_except_placeholder_output(node): return False -class FlowTracer(object): - def __init__(self, gm) -> None: - self.gm = gm - self.node_list = list(gm.graph.nodes) - self.flow_trace = {} - - def _add_trace(self, name): - self.flow_trace[name] = [] - - def _add_node(self, trace_name, node): - self.flow_trace[trace_name].append( - {"node": node, "inside_depend": [], "outside_depend": []} - ) - - def _add_inside_depend(self, flow_name, node, inside_depend_node): - for i in self.flow_trace[flow_name]: - if i["node"] == node: 
- i["inside_depend"].append(inside_depend_node) - return - raise RuntimeError("node not found") - - def _add_outside_depend( - self, flow_name, node, outside_depend_node, outside_depend_trace - ): - for i in self.flow_trace[flow_name]: - if i["node"] == node: - i["outside_depend"].append({outside_depend_trace: outside_depend_node}) - return - raise RuntimeError("node not found") - - def _init_trace(self): - for i in self.node_list: - if i.op == "placeholder": - self._add_trace(i.name) - self._add_node(i.name, i) - - def _find_flow_for_node(self, node): - if type(self.node_list[0]) != type(node): - return None - if _is_non_compute_node_except_placeholder(node): - return None - for name, trace in self.flow_trace.items(): - for i in trace: - if node == i["node"]: - return name - if any(i in node.name for i in ["ones_like"]): - self._add_trace(node.name) - self._add_node(node.name, node) - return node.name - raise RuntimeError("node not found") - - def _find_first_valid_flow(self, flow): - for i in flow: - if i is not None: - return i - raise RuntimeError("invalid flow") - - def find_node_flow(self, node): - for name, trace in self.flow_trace.items(): - for i in trace: - if node == i["node"]: - return name, i - raise RuntimeError("invalid node") - - def _get_flow_mix_node(self, node): - if _is_non_compute_node(node): - return None - _, node_trace = self.find_node_flow(node) - if len(node_trace["outside_depend"]) == 0: - return None - elif len(node_trace["outside_depend"]) > 1: - raise NotImplementedError - vars = list(node_trace["outside_depend"][0].values())[0] - return vars - - def _get_same_flow_node(self, node_list, node): - name, _ = self.find_node_flow(node) - result = [] - for i in self.flow_trace[name]: - if i["node"] in node_list: - result.append(i["node"]) - return result - - def trace_flow(self): - # init trace - self._init_trace() - - for node in self.node_list: - # skip if non compute node - if all( - type(arg) != type(node) or 
_is_non_compute_node_except_placeholder(arg) - for arg in node.args - ) or _is_non_compute_node(node): - continue - - node_input_flows = [self._find_flow_for_node(arg) for arg in node.args] - - node_domin_flow = self._find_first_valid_flow(node_input_flows) - self._add_node(node_domin_flow, node) - for node_input_flow, arg in zip(node_input_flows, node.args): - if node_input_flow is None: - continue - elif node_input_flow == node_domin_flow: - self._add_inside_depend(node_domin_flow, node, arg) - else: - self._add_outside_depend( - node_domin_flow, node, arg, node_input_flow - ) - return self.flow_trace - - def _detect_flow(self, start_idx, start_dim, end_idx, end_dim, index_tracer): - inputs, outputs = _find_chunk_compute_input_and_output_nodes( - self.node_list[start_idx : end_idx + 1] - ) - chunk_info = { - "region": (start_idx, end_idx), - "inputs": inputs, - "inputs_non_chunk": [], - "inputs_dim": start_dim, - "outputs": outputs, - "outputs_dim": end_dim, - "args": {}, - } - flow_block = False - - # TODO don't allow multi outputs now - if len(outputs) > 1: - flow_block = True - return flow_block, chunk_info - - for idx in range(start_idx, end_idx + 1): - node = self.node_list[idx] - mix_flow_node = self._get_flow_mix_node(node) - if mix_flow_node is None: - continue - - # if there is a flow mix, op must be in [mul, add, matmul] - # element-wise op requires dim to be equal in every dim - if any(n in node.name for n in ["mul", "add"]): - for i in node.args: - if type(i) == type(mix_flow_node) and i != mix_flow_node: - main_flow_var = i - # if mix flow is a broadcast in chunk dim, - # TODO: need to move that flow out of the chunk - mix_flow_node_dim = index_tracer.get_node_chunk_dim( - self.node_list[end_idx], end_dim, node - ) - if mix_flow_node_dim is None: - flow_block = True - break - if _get_node_shape(mix_flow_node)[mix_flow_node_dim] == 1: - flow_block = False - for i in self._get_same_flow_node( - chunk_info["inputs"], mix_flow_node - ): - 
chunk_info["inputs"].remove(i) - # else, we need to chunk mix var as well - else: - # TODO chunk another value - flow_block = True - break - else: - raise NotImplementedError("%s not implemented" % node.name) - - if flow_block: - flow_block = True - return flow_block, chunk_info - - inputs_dim = [] - remove_inputs = [] - for input_node in chunk_info["inputs"]: - input_dict = {} - for user in input_node.users.keys(): - if _is_non_compute_node(user): - continue - user_idx = _find_idx_by_name(user.name, self.node_list) - dim = None - if start_dim <= user_idx < end_idx: - dim = index_tracer.get_node_chunk_dim( - self.node_list[end_idx], end_dim, input_node - ) - elif user_idx == end_idx: - dim = end_dim - # n has relation with chunk dim - if dim is not None and _get_node_shape(user)[dim] != 1: - input_dict[user_idx] = dim - if len(input_dict) == 0: - remove_inputs.append(input_node) - else: - inputs_dim.append(input_dict) - chunk_info["inputs_dim"] = inputs_dim - for i in remove_inputs: - if i in chunk_info["inputs"]: - chunk_info["inputs"].remove(i) - - # we need to log input nodes to avoid deleteing them in the loop - non_chunk_inputs = _find_chunk_all_input_nodes( - self.node_list[start_idx : end_idx + 1] - ) - for i in non_chunk_inputs: - if i not in chunk_info["inputs"]: - chunk_info["inputs_non_chunk"].append(i) - - return flow_block, chunk_info - - class IndexTracer(object): def __init__(self, gm) -> None: self.gm = gm @@ -932,6 +726,213 @@ def check_index_duplicate(self, chunk_infos): return True + +class FlowTracer(object): + def __init__(self, gm) -> None: + self.gm = gm + self.node_list = list(gm.graph.nodes) + self.flow_trace = {} + + def _add_trace(self, name): + self.flow_trace[name] = [] + + def _add_node(self, trace_name, node): + self.flow_trace[trace_name].append( + {"node": node, "inside_depend": [], "outside_depend": []} + ) + + def _add_inside_depend(self, flow_name, node, inside_depend_node): + for i in self.flow_trace[flow_name]: + if i["node"] 
== node: + i["inside_depend"].append(inside_depend_node) + return + raise RuntimeError("node not found") + + def _add_outside_depend( + self, flow_name, node, outside_depend_node, outside_depend_trace + ): + for i in self.flow_trace[flow_name]: + if i["node"] == node: + i["outside_depend"].append({outside_depend_trace: outside_depend_node}) + return + raise RuntimeError("node not found") + + def _init_trace(self): + for i in self.node_list: + if i.op == "placeholder": + self._add_trace(i.name) + self._add_node(i.name, i) + + def _find_flow_for_node(self, node): + if type(self.node_list[0]) != type(node): + return None + if _is_non_compute_node_except_placeholder(node): + return None + for name, trace in self.flow_trace.items(): + for i in trace: + if node == i["node"]: + return name + if any(i in node.name for i in ["ones_like"]): + self._add_trace(node.name) + self._add_node(node.name, node) + return node.name + raise RuntimeError("node not found") + + def _find_first_valid_flow(self, flow): + for i in flow: + if i is not None: + return i + raise RuntimeError("invalid flow") + + def find_node_flow(self, node): + for name, trace in self.flow_trace.items(): + for i in trace: + if node == i["node"]: + return name, i + raise RuntimeError("invalid node") + + def _get_flow_mix_node(self, node): + if _is_non_compute_node(node): + return None + _, node_trace = self.find_node_flow(node) + if len(node_trace["outside_depend"]) == 0: + return None + elif len(node_trace["outside_depend"]) > 1: + raise NotImplementedError + vars = list(node_trace["outside_depend"][0].values())[0] + return vars + + def _get_same_flow_node(self, node_list, node): + name, _ = self.find_node_flow(node) + result = [] + for i in self.flow_trace[name]: + if i["node"] in node_list: + result.append(i["node"]) + return result + + def trace_flow(self): + # init trace + self._init_trace() + + for node in self.node_list: + # skip if non compute node + if all( + type(arg) != type(node) or 
_is_non_compute_node_except_placeholder(arg) + for arg in node.args + ) or _is_non_compute_node(node): + continue + + node_input_flows = [self._find_flow_for_node(arg) for arg in node.args] + + node_domin_flow = self._find_first_valid_flow(node_input_flows) + self._add_node(node_domin_flow, node) + for node_input_flow, arg in zip(node_input_flows, node.args): + if node_input_flow is None: + continue + elif node_input_flow == node_domin_flow: + self._add_inside_depend(node_domin_flow, node, arg) + else: + self._add_outside_depend( + node_domin_flow, node, arg, node_input_flow + ) + return self.flow_trace + + def _detect_flow(self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer): + inputs, outputs = _find_chunk_compute_input_and_output_nodes( + self.node_list[start_idx : end_idx + 1] + ) + chunk_info = { + "region": (start_idx, end_idx), + "inputs": inputs, + "inputs_non_chunk": [], + "inputs_dim": start_dim, + "outputs": outputs, + "outputs_dim": end_dim, + "args": {}, + } + flow_block = False + + # TODO don't allow multi outputs now + if len(outputs) > 1: + flow_block = True + return flow_block, chunk_info + + for idx in range(start_idx, end_idx + 1): + node = self.node_list[idx] + mix_flow_node = self._get_flow_mix_node(node) + if mix_flow_node is None: + continue + + # if there is a flow mix, op must be in [mul, add, matmul] + # element-wise op requires dim to be equal in every dim + if any(n in node.name for n in ["mul", "add"]): + for i in node.args: + if type(i) == type(mix_flow_node) and i != mix_flow_node: + main_flow_var = i + # if mix flow is a broadcast in chunk dim, + # TODO: need to move that flow out of the chunk + mix_flow_node_dim = index_tracer.get_node_chunk_dim( + self.node_list[end_idx], end_dim, node + ) + if mix_flow_node_dim is None: + flow_block = True + break + if _get_node_shape(mix_flow_node)[mix_flow_node_dim] == 1: + flow_block = False + for i in self._get_same_flow_node( + chunk_info["inputs"], mix_flow_node + ): + 
chunk_info["inputs"].remove(i) + # else, we need to chunk mix var as well + else: + # TODO chunk another value + flow_block = True + break + else: + raise NotImplementedError("%s not implemented" % node.name) + + if flow_block: + flow_block = True + return flow_block, chunk_info + + inputs_dim = [] + remove_inputs = [] + for input_node in chunk_info["inputs"]: + input_dict = {} + for user in input_node.users.keys(): + if _is_non_compute_node(user): + continue + user_idx = _find_idx_by_name(user.name, self.node_list) + dim = None + if start_dim <= user_idx < end_idx: + dim = index_tracer.get_node_chunk_dim( + self.node_list[end_idx], end_dim, input_node + ) + elif user_idx == end_idx: + dim = end_dim + # n has relation with chunk dim + if dim is not None and _get_node_shape(user)[dim] != 1: + input_dict[user_idx] = dim + if len(input_dict) == 0: + remove_inputs.append(input_node) + else: + inputs_dim.append(input_dict) + chunk_info["inputs_dim"] = inputs_dim + for i in remove_inputs: + if i in chunk_info["inputs"]: + chunk_info["inputs"].remove(i) + + # we need to log input nodes to avoid deleteing them in the loop + non_chunk_inputs = _find_chunk_all_input_nodes( + self.node_list[start_idx : end_idx + 1] + ) + for i in non_chunk_inputs: + if i not in chunk_info["inputs"]: + chunk_info["inputs_non_chunk"].append(i) + + return flow_block, chunk_info + + class MemoryEstimator(object): def __init__(self, index_tracer: IndexTracer) -> None: self.index_tracer = index_tracer From d361d533e8e7773d2009cc4ff5a82633401ab44a Mon Sep 17 00:00:00 2001 From: oahzxl Date: Wed, 21 Dec 2022 15:01:03 +0800 Subject: [PATCH 044/209] refactor flow tracer --- chunk_codegen.py | 283 +++++++++++++++++++++++++++++++++-------- evoformer/evoformer.py | 11 +- 2 files changed, 240 insertions(+), 54 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 2c1c09ae5238..3ba082ceb845 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -139,7 +139,13 @@ def _add_source(self, 
node_from, node_from_dim, node_to, node_to_dim, init=False node_from_idx = _find_idx_by_name(node_from.name, self.nodes_list) if init: node_to_trace["source"][node_to_dim] = {} - node_to_trace["source"][node_to_dim][node_from_idx] = node_from_dim + # add dim to cur new source + if node_from_idx not in node_to_trace["source"][node_to_dim]: + node_to_trace["source"][node_to_dim][node_from_idx] = [node_from_dim] + else: + if node_from_dim not in node_to_trace["source"][node_to_dim][node_from_idx]: + node_to_trace["source"][node_to_dim][node_from_idx].append(node_from_dim) + # update inputs source node_to_trace["source"][node_to_dim].update( node_from_trace["source"][node_from_dim] ) @@ -654,7 +660,7 @@ def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node end_node_trace_source.items(), key=lambda d: d[0], reverse=True ) for node_idx, node_dim in sorted_source: - if node_idx == start_node_idx and node_dim == start_dim: + if node_idx == start_node_idx and start_dim in node_dim: return True # it means we meet a node outside the loop, and the node is not input node if node_idx < start_idx: @@ -694,12 +700,12 @@ def _find_inherit_dim(self, input_node, input_dim, node): for node_dim in range(len(_get_node_shape(node))): if ( input_node_idx in node_trace_source[node_dim] - and node_trace_source[node_dim][input_node_idx] == input_dim + and input_dim in node_trace_source[node_dim][input_node_idx] ): return node_dim return None - def check_index_duplicate(self, chunk_infos): + def check_index_duplicate(self, chunk_infos, return_dim=False): input_dim_after_node = {} for input_node_idx, input_node in enumerate(chunk_infos["inputs"]): for k, v in chunk_infos["inputs_dim"][input_node_idx].items(): @@ -713,17 +719,30 @@ def check_index_duplicate(self, chunk_infos): if _is_non_compute_node_except_placeholder(node): continue count = 0 + duplicate_dims = [] node_trace_source = self._find_source_trace_from_node(node) for node_dim in 
range(len(_get_node_shape(node))): + duplicate_dim = [] + duplicate_flag = False dim_source = node_trace_source[node_dim] for k, v in dim_source.items(): if chunk_infos["region"][0] <= k <= chunk_infos["region"][1]: - if k in input_dim_after_node and input_dim_after_node[k] == v: - count += 1 - break + if k in input_dim_after_node and input_dim_after_node[k] in v: + duplicate_flag = True + duplicate_dim.append((k, v)) + duplicate_dims.append(duplicate_dim) + if duplicate_flag: + count += 1 + if count > 1: - return False - return True + if return_dim: + return False, duplicate_dims + else: + return False + if return_dim: + return True, None + else: + return True @@ -857,43 +876,45 @@ def _detect_flow(self, start_idx, start_dim, end_idx, end_dim, index_tracer: Ind flow_block = True return flow_block, chunk_info - for idx in range(start_idx, end_idx + 1): - node = self.node_list[idx] - mix_flow_node = self._get_flow_mix_node(node) - if mix_flow_node is None: - continue - - # if there is a flow mix, op must be in [mul, add, matmul] - # element-wise op requires dim to be equal in every dim - if any(n in node.name for n in ["mul", "add"]): - for i in node.args: - if type(i) == type(mix_flow_node) and i != mix_flow_node: - main_flow_var = i - # if mix flow is a broadcast in chunk dim, - # TODO: need to move that flow out of the chunk - mix_flow_node_dim = index_tracer.get_node_chunk_dim( - self.node_list[end_idx], end_dim, node - ) - if mix_flow_node_dim is None: - flow_block = True - break - if _get_node_shape(mix_flow_node)[mix_flow_node_dim] == 1: - flow_block = False - for i in self._get_same_flow_node( - chunk_info["inputs"], mix_flow_node - ): - chunk_info["inputs"].remove(i) - # else, we need to chunk mix var as well - else: - # TODO chunk another value - flow_block = True - break - else: - raise NotImplementedError("%s not implemented" % node.name) - - if flow_block: - flow_block = True - return flow_block, chunk_info + # for idx in range(start_idx, end_idx + 1): 
+ # node = self.node_list[idx] + # mix_flow_node = self._get_flow_mix_node(node) + # if mix_flow_node is None: + # continue + + # # if there is a flow mix, op must be in [mul, add, matmul] + # # element-wise op requires dim to be equal in every dim + # if any(n in node.name for n in ["mul", "add"]): + # for i in node.args: + # if type(i) == type(mix_flow_node) and i != mix_flow_node: + # main_flow_var = i + # # if mix flow is a broadcast in chunk dim, + # # TODO: need to move that flow out of the chunk + # mix_flow_node_dim = index_tracer.get_node_chunk_dim( + # self.node_list[end_idx], end_dim, node + # ) + # # TODO: we need to loop every dim + # if isinstance(mix_flow_node_dim, list): + # mix_flow_node_dim = mix_flow_node_dim[0] + # if mix_flow_node_dim is None: + # flow_block = True + # break + # if _get_node_shape(mix_flow_node)[mix_flow_node_dim] == 1: + # flow_block = False + # for i in self._get_same_flow_node( + # chunk_info["inputs"], mix_flow_node + # ): + # chunk_info["inputs"].remove(i) + # # else, we need to chunk mix var as well + # else: + # # TODO chunk another value + # flow_block = True + # break + # else: + # raise NotImplementedError("%s not implemented" % node.name) + # if flow_block: + # flow_block = True + # return flow_block, chunk_info inputs_dim = [] remove_inputs = [] @@ -908,6 +929,9 @@ def _detect_flow(self, start_idx, start_dim, end_idx, end_dim, index_tracer: Ind dim = index_tracer.get_node_chunk_dim( self.node_list[end_idx], end_dim, input_node ) + # TODO: we need to loop every dim + if isinstance(dim, list): + dim = dim[0] elif user_idx == end_idx: dim = end_dim # n has relation with chunk dim @@ -921,6 +945,8 @@ def _detect_flow(self, start_idx, start_dim, end_idx, end_dim, index_tracer: Ind for i in remove_inputs: if i in chunk_info["inputs"]: chunk_info["inputs"].remove(i) + + duplicate_result, duplicate_dim = index_tracer.check_index_duplicate(chunk_info, return_dim=True) # we need to log input nodes to avoid deleteing them in 
the loop non_chunk_inputs = _find_chunk_all_input_nodes( @@ -932,6 +958,150 @@ def _detect_flow(self, start_idx, start_dim, end_idx, end_dim, index_tracer: Ind return flow_block, chunk_info + def _assgin_single_node_flow(self, arg_node, start_idx, end_idx, + inputs, index_tracer, cur_node_dim, + cur_node_compute, cur_node_source, cur_node_fix_dim, all_node_info, + next_node_list): + arg_idx = _find_idx_by_name(arg_node.name, index_tracer.nodes_list) + # arg in chunk range or be inputs + if not (start_idx <= arg_idx < end_idx): + return True + + # find arg dim + if cur_node_dim is not None: + # dim is computed + if arg_idx in cur_node_compute[cur_node_dim]: + return False + if arg_idx not in cur_node_source[cur_node_dim]: + arg_dim = None + else: + arg_dim = cur_node_source[cur_node_dim][arg_idx][0] + else: + arg_dim = None + + # get fix dim + arg_fix_dim = [] + if cur_node_dim is not None: + for i in cur_node_fix_dim: + fix_dim_source = cur_node_source[i] + if arg_idx in fix_dim_source: + arg_fix_dim.append(fix_dim_source[arg_idx][0]) + + # if already in node_info, arg dim must be same + if arg_node in all_node_info: + if all_node_info[arg_node] != arg_dim: + return False + all_node_info[arg_node]['fix_dim'] = list(set(all_node_info[arg_node]['fix_dim'] + arg_fix_dim)) + # else add it to list + else: + all_node_info[arg_node] = {'chunk_dim': arg_dim, 'fix_dim': arg_fix_dim} + + next_node_list.append(arg_node) + return True + + def flow_search(self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer): + inputs, outputs = _find_chunk_compute_input_and_output_nodes( + self.node_list[start_idx : end_idx + 1] + ) + # only single ouput + if len(outputs) > 1: + return None + + cur_node_list = [index_tracer.nodes_list[end_idx]] # start from the last node + all_node_info = {cur_node_list[0]: {'chunk_dim': end_dim, 'fix_dim': []}} + + while len(cur_node_list) > 0: + next_node_list = [] + + for cur_node in cur_node_list: + # get cur node info + 
cur_node_chunk_dim = all_node_info[cur_node]['chunk_dim'] + cur_node_fix_dim = all_node_info[cur_node]['fix_dim'] + cur_node_idx = _find_idx_by_name(cur_node.name, index_tracer.nodes_list) + if cur_node_chunk_dim: + cur_node_compute = index_tracer._find_compute_trace_from_node(cur_node) + cur_node_source = index_tracer._find_source_trace_from_node(cur_node) + else: + cur_node_compute = cur_node_source = None + + # get all valid args + arg_list = [] + for arg in cur_node.args: + if type(arg) != type(cur_node): + continue + if _is_non_compute_node(arg): + continue + arg_list.append(arg) + flow_flag = self._assgin_single_node_flow(arg, start_idx, end_idx, + inputs, index_tracer, cur_node_chunk_dim, + cur_node_compute, cur_node_source, cur_node_fix_dim, all_node_info, + next_node_list) + if flow_flag == False: + return None + + if len(arg_list) == 2: + if any(i in cur_node.name for i in ["add", "mul"]): + for arg in arg_list: + if not (start_idx <= _find_idx_by_name(arg.name, index_tracer.nodes_list) < end_idx): + continue + arg_chunk_dim = all_node_info[arg]['chunk_dim'] + arg_fix_dim = all_node_info[arg]['fix_dim'] + arg_shape = _get_node_shape(arg) + # add all dim as fix dim except chunk dim + for i, shape in enumerate(arg_shape): + if shape != 1 and i != cur_node_chunk_dim: + if i == arg_chunk_dim: + return None + if i not in arg_fix_dim: + arg_fix_dim.append(i) + elif "einsum" in cur_node.name: + pass + elif "matmul" in cur_node.name: + pass + else: + raise NotImplementedError() + cur_node_list = next_node_list + + inputs_dim = [] + remove_inputs = [] + for input_node in inputs: + input_dict = {} + for user in input_node.users.keys(): + if _is_non_compute_node(user): + continue + user_idx = _find_idx_by_name(user.name, self.node_list) + if start_idx <= user_idx <= end_idx: + chunk_dim = all_node_info[user]['chunk_dim'] + if chunk_dim is not None: + input_dict[user_idx] = chunk_dim + if len(input_dict) == 0: + remove_inputs.append(input_node) + else: + 
inputs_dim.append(input_dict) + for i in remove_inputs: + if i in inputs: + inputs.remove(i) + + chunk_info = { + "region": (start_idx, end_idx), + "inputs": inputs, + "inputs_non_chunk": [], + "inputs_dim": inputs_dim, + "outputs": outputs, + "outputs_dim": end_dim, + "args": {}, + } + + # we need to log input nodes to avoid deleteing them in the loop + non_chunk_inputs = _find_chunk_all_input_nodes( + self.node_list[start_idx : end_idx + 1] + ) + for i in non_chunk_inputs: + if i not in chunk_info["inputs"]: + chunk_info["inputs_non_chunk"].append(i) + + return chunk_info + class MemoryEstimator(object): def __init__(self, index_tracer: IndexTracer) -> None: @@ -1055,12 +1225,13 @@ def _get_chunk_ratio(self, node, chunk_inputs, chunk_inputs_dim, chunk_size): node_source = self.index_tracer._find_source_trace_from_node(node) for (input_node, input_node_dim) in zip(chunk_inputs, chunk_inputs_dim): for k, v in input_node_dim.items(): + # TODO: inherit dim should be list too, int now inherit_dim = self.index_tracer._find_inherit_dim(input_node, v, self.index_tracer.nodes_list[k]) if k == _find_idx_by_name(node.name, self.index_tracer.nodes_list): chunk_ratio = float(chunk_size) / node_shape[inherit_dim] return chunk_ratio for dim, source in enumerate(node_source): - if k in source and source[k] == inherit_dim: + if k in source and inherit_dim in source[k]: chunk_ratio = float(chunk_size) / node_shape[dim] return chunk_ratio return 1. 
@@ -1323,9 +1494,11 @@ def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx): continue for start_node, start_trace in start_traces.items(): for start_dim, start_trace_idx in enumerate(start_trace["idx"]): - # must be same trace idx - if start_trace_idx != end_trace_idx: - continue + if start_idx == 199 and end_idx == 229 and start_dim == 2 and end_dim == 2: + print(1) + self.flow_tracer.flow_search( + start_idx, start_dim, end_idx, end_dim, self.index_tracer + ) # dim size cannot be 1 if ( _get_node_shape(end_node)[end_dim] == 1 @@ -1343,10 +1516,16 @@ def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx): ): continue # detect flow meet - flow_block, chunk_info = self.flow_tracer._detect_flow( + # flow_block, chunk_info = self.flow_tracer._detect_flow( + # start_idx, start_dim, end_idx, end_dim, self.index_tracer + # ) + # if flow_block: + # continue + # flow search + chunk_info = self.flow_tracer.flow_search( start_idx, start_dim, end_idx, end_dim, self.index_tracer ) - if flow_block: + if chunk_info is None: continue # check index copmute if not self.index_tracer.check_index_duplicate(chunk_info): diff --git a/evoformer/evoformer.py b/evoformer/evoformer.py index 0c5ab952a779..cfd2bb2a2529 100644 --- a/evoformer/evoformer.py +++ b/evoformer/evoformer.py @@ -6,6 +6,13 @@ from .triangle import PairStack +def print_memory(init_mem, text=None): + now_mem = torch.cuda.memory_allocated() / 1024 ** 2 - init_mem + max_mem = torch.cuda.max_memory_allocated() / 1024 ** 2 - init_mem + print("%s now:%.2f max:%.2f" % ("" if text is None else text, now_mem, max_mem)) + torch.cuda.reset_peak_memory_stats() + + class EvoformerBlock(nn.Module): def __init__(self, d_node, d_pair): @@ -16,9 +23,9 @@ def __init__(self, d_node, d_pair): self.pair_stack = PairStack(d_pair=d_pair) def forward(self, node, pair): - node = node + self.msa_stack(node, pair) + node = self.msa_stack(node, pair) pair = pair + self.communication(node) - pair = pair + 
self.pair_stack(pair) + pair = self.pair_stack(pair) return node, pair From ded1005667402ee9458afa53852ce2018b1ccb10 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Wed, 21 Dec 2022 15:03:08 +0800 Subject: [PATCH 045/209] format code --- chunk_codegen.py | 184 +++++++++++++++++++++++++++++++---------------- 1 file changed, 122 insertions(+), 62 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 3ba082ceb845..eb16361c04fc 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -144,7 +144,9 @@ def _add_source(self, node_from, node_from_dim, node_to, node_to_dim, init=False node_to_trace["source"][node_to_dim][node_from_idx] = [node_from_dim] else: if node_from_dim not in node_to_trace["source"][node_to_dim][node_from_idx]: - node_to_trace["source"][node_to_dim][node_from_idx].append(node_from_dim) + node_to_trace["source"][node_to_dim][node_from_idx].append( + node_from_dim + ) # update inputs source node_to_trace["source"][node_to_dim].update( node_from_trace["source"][node_from_dim] @@ -745,7 +747,6 @@ def check_index_duplicate(self, chunk_infos, return_dim=False): return True - class FlowTracer(object): def __init__(self, gm) -> None: self.gm = gm @@ -856,7 +857,9 @@ def trace_flow(self): ) return self.flow_trace - def _detect_flow(self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer): + def _detect_flow( + self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer + ): inputs, outputs = _find_chunk_compute_input_and_output_nodes( self.node_list[start_idx : end_idx + 1] ) @@ -945,8 +948,10 @@ def _detect_flow(self, start_idx, start_dim, end_idx, end_dim, index_tracer: Ind for i in remove_inputs: if i in chunk_info["inputs"]: chunk_info["inputs"].remove(i) - - duplicate_result, duplicate_dim = index_tracer.check_index_duplicate(chunk_info, return_dim=True) + + duplicate_result, duplicate_dim = index_tracer.check_index_duplicate( + chunk_info, return_dim=True + ) # we need to log input nodes to avoid deleteing them in 
the loop non_chunk_inputs = _find_chunk_all_input_nodes( @@ -958,15 +963,25 @@ def _detect_flow(self, start_idx, start_dim, end_idx, end_dim, index_tracer: Ind return flow_block, chunk_info - def _assgin_single_node_flow(self, arg_node, start_idx, end_idx, - inputs, index_tracer, cur_node_dim, - cur_node_compute, cur_node_source, cur_node_fix_dim, all_node_info, - next_node_list): + def _assgin_single_node_flow( + self, + arg_node, + start_idx, + end_idx, + inputs, + index_tracer, + cur_node_dim, + cur_node_compute, + cur_node_source, + cur_node_fix_dim, + all_node_info, + next_node_list, + ): arg_idx = _find_idx_by_name(arg_node.name, index_tracer.nodes_list) # arg in chunk range or be inputs if not (start_idx <= arg_idx < end_idx): return True - + # find arg dim if cur_node_dim is not None: # dim is computed @@ -978,7 +993,7 @@ def _assgin_single_node_flow(self, arg_node, start_idx, end_idx, arg_dim = cur_node_source[cur_node_dim][arg_idx][0] else: arg_dim = None - + # get fix dim arg_fix_dim = [] if cur_node_dim is not None: @@ -986,44 +1001,52 @@ def _assgin_single_node_flow(self, arg_node, start_idx, end_idx, fix_dim_source = cur_node_source[i] if arg_idx in fix_dim_source: arg_fix_dim.append(fix_dim_source[arg_idx][0]) - + # if already in node_info, arg dim must be same if arg_node in all_node_info: if all_node_info[arg_node] != arg_dim: return False - all_node_info[arg_node]['fix_dim'] = list(set(all_node_info[arg_node]['fix_dim'] + arg_fix_dim)) + all_node_info[arg_node]["fix_dim"] = list( + set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim) + ) # else add it to list else: - all_node_info[arg_node] = {'chunk_dim': arg_dim, 'fix_dim': arg_fix_dim} - + all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim} + next_node_list.append(arg_node) return True - - def flow_search(self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer): + + def flow_search( + self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer + 
): inputs, outputs = _find_chunk_compute_input_and_output_nodes( self.node_list[start_idx : end_idx + 1] ) # only single ouput if len(outputs) > 1: return None - + cur_node_list = [index_tracer.nodes_list[end_idx]] # start from the last node - all_node_info = {cur_node_list[0]: {'chunk_dim': end_dim, 'fix_dim': []}} - + all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}} + while len(cur_node_list) > 0: next_node_list = [] for cur_node in cur_node_list: # get cur node info - cur_node_chunk_dim = all_node_info[cur_node]['chunk_dim'] - cur_node_fix_dim = all_node_info[cur_node]['fix_dim'] + cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"] + cur_node_fix_dim = all_node_info[cur_node]["fix_dim"] cur_node_idx = _find_idx_by_name(cur_node.name, index_tracer.nodes_list) if cur_node_chunk_dim: - cur_node_compute = index_tracer._find_compute_trace_from_node(cur_node) - cur_node_source = index_tracer._find_source_trace_from_node(cur_node) + cur_node_compute = index_tracer._find_compute_trace_from_node( + cur_node + ) + cur_node_source = index_tracer._find_source_trace_from_node( + cur_node + ) else: cur_node_compute = cur_node_source = None - + # get all valid args arg_list = [] for arg in cur_node.args: @@ -1032,20 +1055,33 @@ def flow_search(self, start_idx, start_dim, end_idx, end_dim, index_tracer: Inde if _is_non_compute_node(arg): continue arg_list.append(arg) - flow_flag = self._assgin_single_node_flow(arg, start_idx, end_idx, - inputs, index_tracer, cur_node_chunk_dim, - cur_node_compute, cur_node_source, cur_node_fix_dim, all_node_info, - next_node_list) + flow_flag = self._assgin_single_node_flow( + arg, + start_idx, + end_idx, + inputs, + index_tracer, + cur_node_chunk_dim, + cur_node_compute, + cur_node_source, + cur_node_fix_dim, + all_node_info, + next_node_list, + ) if flow_flag == False: return None - + if len(arg_list) == 2: if any(i in cur_node.name for i in ["add", "mul"]): for arg in arg_list: - if not (start_idx <= 
_find_idx_by_name(arg.name, index_tracer.nodes_list) < end_idx): + if not ( + start_idx + <= _find_idx_by_name(arg.name, index_tracer.nodes_list) + < end_idx + ): continue - arg_chunk_dim = all_node_info[arg]['chunk_dim'] - arg_fix_dim = all_node_info[arg]['fix_dim'] + arg_chunk_dim = all_node_info[arg]["chunk_dim"] + arg_fix_dim = all_node_info[arg]["fix_dim"] arg_shape = _get_node_shape(arg) # add all dim as fix dim except chunk dim for i, shape in enumerate(arg_shape): @@ -1061,7 +1097,7 @@ def flow_search(self, start_idx, start_dim, end_idx, end_dim, index_tracer: Inde else: raise NotImplementedError() cur_node_list = next_node_list - + inputs_dim = [] remove_inputs = [] for input_node in inputs: @@ -1071,7 +1107,7 @@ def flow_search(self, start_idx, start_dim, end_idx, end_dim, index_tracer: Inde continue user_idx = _find_idx_by_name(user.name, self.node_list) if start_idx <= user_idx <= end_idx: - chunk_dim = all_node_info[user]['chunk_dim'] + chunk_dim = all_node_info[user]["chunk_dim"] if chunk_dim is not None: input_dict[user_idx] = chunk_dim if len(input_dict) == 0: @@ -1081,7 +1117,7 @@ def flow_search(self, start_idx, start_dim, end_idx, end_dim, index_tracer: Inde for i in remove_inputs: if i in inputs: inputs.remove(i) - + chunk_info = { "region": (start_idx, end_idx), "inputs": inputs, @@ -1091,7 +1127,7 @@ def flow_search(self, start_idx, start_dim, end_idx, end_dim, index_tracer: Inde "outputs_dim": end_dim, "args": {}, } - + # we need to log input nodes to avoid deleteing them in the loop non_chunk_inputs = _find_chunk_all_input_nodes( self.node_list[start_idx : end_idx + 1] @@ -1129,7 +1165,7 @@ def _get_output_node_size(self, n): def _add_active_node(self, n, active_list): new_active = self._get_output_node(n)[1] - if n.op == 'placeholder': + if n.op == "placeholder": new_active.append(n.name) for i in new_active: if i not in active_list: @@ -1168,12 +1204,16 @@ def _remove_deactive_node(self, user, user_to_last_uses, active_list): for i in 
delete_node: if i in active_list: active_list.remove(i) - - def _get_chunk_inputs_size(self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx): + + def _get_chunk_inputs_size( + self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx + ): nodes_to_delete = [] for chunk_input in chunk_inputs + chunk_inputs_non_chunk: chunk_input_users = chunk_input.users.keys() - chunk_input_users_idx = [_find_idx_by_name(i.name, node_list) for i in chunk_input_users] + chunk_input_users_idx = [ + _find_idx_by_name(i.name, node_list) for i in chunk_input_users + ] if all(i <= chunk_end_idx for i in chunk_input_users_idx): if chunk_input not in nodes_to_delete: nodes_to_delete.append(chunk_input) @@ -1226,7 +1266,9 @@ def _get_chunk_ratio(self, node, chunk_inputs, chunk_inputs_dim, chunk_size): for (input_node, input_node_dim) in zip(chunk_inputs, chunk_inputs_dim): for k, v in input_node_dim.items(): # TODO: inherit dim should be list too, int now - inherit_dim = self.index_tracer._find_inherit_dim(input_node, v, self.index_tracer.nodes_list[k]) + inherit_dim = self.index_tracer._find_inherit_dim( + input_node, v, self.index_tracer.nodes_list[k] + ) if k == _find_idx_by_name(node.name, self.index_tracer.nodes_list): chunk_ratio = float(chunk_size) / node_shape[inherit_dim] return chunk_ratio @@ -1234,7 +1276,7 @@ def _get_chunk_ratio(self, node, chunk_inputs, chunk_inputs_dim, chunk_size): if k in source and inherit_dim in source[k]: chunk_ratio = float(chunk_size) / node_shape[dim] return chunk_ratio - return 1. 
+ return 1.0 def _get_chunk_delete_node_size( self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names @@ -1295,7 +1337,7 @@ def estimate_chunk_inference_mem( chunk_ratio = 1 # use it to estimate chunk mem chunk_size = 1 chunk_inputs_names = [] - + if use_chunk: chunk_regions = [i["region"] for i in chunk_infos] chunk_starts = [i[0] for i in chunk_regions] @@ -1313,12 +1355,17 @@ def estimate_chunk_inference_mem( if use_chunk and idx in chunk_starts: chunk_within = True chunk_region_idx = chunk_starts.index(idx) - act_memory += self._get_output_node_size(chunk_outputs[chunk_region_idx]) / (1024**2) + act_memory += self._get_output_node_size( + chunk_outputs[chunk_region_idx] + ) / (1024**2) # determine chunk ratio for current node if chunk_within: chunk_ratio = self._get_chunk_ratio( - node, chunk_inputs[chunk_region_idx], chunk_inputs_dim[chunk_region_idx], chunk_size + node, + chunk_inputs[chunk_region_idx], + chunk_inputs_dim[chunk_region_idx], + chunk_size, ) # if node is placeholder, just add the size of the node @@ -1353,18 +1400,18 @@ def estimate_chunk_inference_mem( / (1024**2) ) # delete unused vars not in chunk_input_list - # we can't delete input nodes until chunk ends + # we can't delete input nodes until chunk ends if chunk_within: act_memory -= self._get_chunk_delete_node_size( node, user_to_last_uses_no_free_var, chunk_ratio, - chunk_inputs_names + chunk_inputs_names, ) / (1024**2) else: - act_memory -= (self._get_delete_node_size( + act_memory -= self._get_delete_node_size( node, user_to_last_uses_no_free_var, chunk_inputs_names - ) / (1024**2)) + ) / (1024**2) # log active node, only effective without chunk self._add_active_node(node, active_node_list) @@ -1376,11 +1423,11 @@ def estimate_chunk_inference_mem( self._get_output_node_size(node) * chunk_ratio / (1024**2) ) act_memory -= self._get_chunk_inputs_size( - chunk_inputs[chunk_region_idx], - chunk_inputs_non_chunk[chunk_region_idx], + chunk_inputs[chunk_region_idx], + 
chunk_inputs_non_chunk[chunk_region_idx], node_list, - chunk_regions[chunk_region_idx][1] - ) / (1024**2) + chunk_regions[chunk_region_idx][1], + ) / (1024**2) chunk_within = False chunk_ratio = 1 chunk_region_idx = None @@ -1436,7 +1483,7 @@ def _search_max_chunk_region(self, active_node, peak_node, chunk_regions): active_node_num = [len(i) for i in active_node] min_active_node_num = min(active_node_num[free_var_num:]) threshold = max(free_var_num, min_active_node_num) - + # from peak_node to free_var inside_flag = False chunk_region_start = free_var_num @@ -1494,7 +1541,12 @@ def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx): continue for start_node, start_trace in start_traces.items(): for start_dim, start_trace_idx in enumerate(start_trace["idx"]): - if start_idx == 199 and end_idx == 229 and start_dim == 2 and end_dim == 2: + if ( + start_idx == 199 + and end_idx == 229 + and start_dim == 2 + and end_dim == 2 + ): print(1) self.flow_tracer.flow_search( start_idx, start_dim, end_idx, end_dim, self.index_tracer @@ -1576,7 +1628,7 @@ def _search_best_chunk_region(self, possible_chunk_regions, chunk_infos): max_region_range = 0 best_region = None return best_region - + def _is_legal_region(self, cur_chunk_info, chunk_infos): (chunk_region_start, chunk_region_end) = cur_chunk_info["region"] if cur_chunk_info in chunk_infos: @@ -1585,11 +1637,13 @@ def _is_legal_region(self, cur_chunk_info, chunk_infos): return False for i in chunk_infos: region = i["region"] - if not ((chunk_region_start > region[1] and chunk_region_end > region[1]) - or (chunk_region_start < region[0] and chunk_region_end < region[0])): + if not ( + (chunk_region_start > region[1] and chunk_region_end > region[1]) + or (chunk_region_start < region[0] and chunk_region_end < region[0]) + ): return False return True - + def _step_search(self, mem_peak, active_node, chunk_regions): peak_node = self._find_peak_node(mem_peak) max_chunk_region = self._search_max_chunk_region( @@ 
-1600,7 +1654,9 @@ def _step_search(self, mem_peak, active_node, chunk_regions): possible_chunk_regions = self._search_possible_chunk_regions( max_chunk_region, peak_node ) - best_chunk_region = self._search_best_chunk_region(possible_chunk_regions, chunk_regions) + best_chunk_region = self._search_best_chunk_region( + possible_chunk_regions, chunk_regions + ) return best_chunk_region def _stop_search(self, init_mem_peak, mem_peak): @@ -1667,7 +1723,11 @@ def _gen_loop_end( chunk_slice = _gen_chunk_slice_dim( chunk_outputs_dim, "chunk_idx", chunk_output_shape ) - context = " chunk_result%s = %s; %s = None\n" % (chunk_slice, chunk_outputs_name, chunk_outputs_name) + context = " chunk_result%s = %s; %s = None\n" % ( + chunk_slice, + chunk_outputs_name, + chunk_outputs_name, + ) context += ( chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None" ) From 774d34f1aa2f9534557dd4a0ca866392a496e448 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 23 Dec 2022 13:41:10 +0800 Subject: [PATCH 046/209] refactor flow search --- chunk_codegen.py | 78 +++++++++++++++++++++++++++++++++++------------- 1 file changed, 58 insertions(+), 20 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index eb16361c04fc..0b0a164fe999 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -1004,7 +1004,7 @@ def _assgin_single_node_flow( # if already in node_info, arg dim must be same if arg_node in all_node_info: - if all_node_info[arg_node] != arg_dim: + if all_node_info[arg_node]['chunk_dim'] != arg_dim: return False all_node_info[arg_node]["fix_dim"] = list( set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim) @@ -1128,14 +1128,68 @@ def flow_search( "args": {}, } + # move useless nodes ahead of loop + # get all possible prepose nodes + maybe_prepose_nodes = [] + for node, node_info in all_node_info.items(): + if node_info['chunk_dim'] is None: + maybe_prepose_nodes.append(node) + maybe_prepose_nodes.sort(key=lambda x: _find_idx_by_name(x.name, 
index_tracer.nodes_list), reverse=True) # from last node to first node + prepose_nodes = [] + # set every node as root, search its args, if all legal, turn root and args as prepose nodes + while len(maybe_prepose_nodes) > 0: + tmp_cur_prepose_nodes = [maybe_prepose_nodes[0]] + tmp_cur_related_prepose_nodes = [] + prepose_flag = True + + # loop cur node's all arg until out of chunk + while len(tmp_cur_prepose_nodes) > 0: + tmp_next_prepose_nodes = [] + tmp_cur_related_prepose_nodes.extend(tmp_cur_prepose_nodes) + for cur_prepose_node in tmp_cur_prepose_nodes: + for cur_prepose_node_arg in cur_prepose_node.args: + if type(cur_prepose_node_arg) != type(cur_prepose_node): + continue + # out of loop + if not (start_idx <= _find_idx_by_name(cur_prepose_node_arg.name, self.node_list) < end_idx): + continue + # compute op in loop + elif cur_prepose_node_arg in all_node_info: + if all_node_info[cur_prepose_node_arg]['chunk_dim'] is None: + tmp_next_prepose_nodes.append(cur_prepose_node_arg) + else: + prepose_flag = False + break; break; break + # non compute op + else: + tmp_next_prepose_nodes.append(cur_prepose_node_arg) + tmp_cur_prepose_nodes = tmp_next_prepose_nodes + + if prepose_flag == False: + maybe_prepose_nodes.remove(maybe_prepose_nodes[0]) + continue + else: + for n in tmp_cur_related_prepose_nodes: + if n not in prepose_nodes: + prepose_nodes.append(n) + if n in maybe_prepose_nodes: + maybe_prepose_nodes.remove(n) + # sort by index + prepose_nodes.sort(key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list)) + chunk_info["args"]["prepose_nodes"] = prepose_nodes + # we need to log input nodes to avoid deleteing them in the loop + chunk_node_list = self.node_list[start_idx : end_idx + 1] + # also need to get some prepose node's arg out of non_chunk_inputs + for n in prepose_nodes: + chunk_node_list.remove(n) non_chunk_inputs = _find_chunk_all_input_nodes( - self.node_list[start_idx : end_idx + 1] + chunk_node_list ) for i in non_chunk_inputs: - if i not 
in chunk_info["inputs"]: + if i not in chunk_info["inputs"] and i not in prepose_nodes: chunk_info["inputs_non_chunk"].append(i) - + return chunk_info @@ -1541,16 +1595,6 @@ def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx): continue for start_node, start_trace in start_traces.items(): for start_dim, start_trace_idx in enumerate(start_trace["idx"]): - if ( - start_idx == 199 - and end_idx == 229 - and start_dim == 2 - and end_dim == 2 - ): - print(1) - self.flow_tracer.flow_search( - start_idx, start_dim, end_idx, end_dim, self.index_tracer - ) # dim size cannot be 1 if ( _get_node_shape(end_node)[end_dim] == 1 @@ -1567,12 +1611,6 @@ def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx): start_idx, end_dim, end_node, end_idx ): continue - # detect flow meet - # flow_block, chunk_info = self.flow_tracer._detect_flow( - # start_idx, start_dim, end_idx, end_dim, self.index_tracer - # ) - # if flow_block: - # continue # flow search chunk_info = self.flow_tracer.flow_search( start_idx, start_dim, end_idx, end_dim, self.index_tracer From 522f01741864f3565f8e97837ecc7289774ee127 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 23 Dec 2022 13:41:51 +0800 Subject: [PATCH 047/209] code style --- chunk_codegen.py | 39 +++++++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 0b0a164fe999..a8b970116d1d 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -1004,7 +1004,7 @@ def _assgin_single_node_flow( # if already in node_info, arg dim must be same if arg_node in all_node_info: - if all_node_info[arg_node]['chunk_dim'] != arg_dim: + if all_node_info[arg_node]["chunk_dim"] != arg_dim: return False all_node_info[arg_node]["fix_dim"] = list( set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim) @@ -1132,16 +1132,19 @@ def flow_search( # get all possible prepose nodes maybe_prepose_nodes = [] for node, node_info in all_node_info.items(): - if 
node_info['chunk_dim'] is None: + if node_info["chunk_dim"] is None: maybe_prepose_nodes.append(node) - maybe_prepose_nodes.sort(key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list), reverse=True) # from last node to first node + maybe_prepose_nodes.sort( + key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list), + reverse=True, + ) # from last node to first node prepose_nodes = [] # set every node as root, search its args, if all legal, turn root and args as prepose nodes while len(maybe_prepose_nodes) > 0: tmp_cur_prepose_nodes = [maybe_prepose_nodes[0]] tmp_cur_related_prepose_nodes = [] prepose_flag = True - + # loop cur node's all arg until out of chunk while len(tmp_cur_prepose_nodes) > 0: tmp_next_prepose_nodes = [] @@ -1151,20 +1154,28 @@ def flow_search( if type(cur_prepose_node_arg) != type(cur_prepose_node): continue # out of loop - if not (start_idx <= _find_idx_by_name(cur_prepose_node_arg.name, self.node_list) < end_idx): + if not ( + start_idx + <= _find_idx_by_name( + cur_prepose_node_arg.name, self.node_list + ) + < end_idx + ): continue # compute op in loop elif cur_prepose_node_arg in all_node_info: - if all_node_info[cur_prepose_node_arg]['chunk_dim'] is None: + if all_node_info[cur_prepose_node_arg]["chunk_dim"] is None: tmp_next_prepose_nodes.append(cur_prepose_node_arg) else: prepose_flag = False - break; break; break + break + break + break # non compute op else: tmp_next_prepose_nodes.append(cur_prepose_node_arg) tmp_cur_prepose_nodes = tmp_next_prepose_nodes - + if prepose_flag == False: maybe_prepose_nodes.remove(maybe_prepose_nodes[0]) continue @@ -1175,21 +1186,21 @@ def flow_search( if n in maybe_prepose_nodes: maybe_prepose_nodes.remove(n) # sort by index - prepose_nodes.sort(key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list)) + prepose_nodes.sort( + key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list) + ) chunk_info["args"]["prepose_nodes"] = prepose_nodes - + # we need to log input 
nodes to avoid deleteing them in the loop chunk_node_list = self.node_list[start_idx : end_idx + 1] # also need to get some prepose node's arg out of non_chunk_inputs for n in prepose_nodes: chunk_node_list.remove(n) - non_chunk_inputs = _find_chunk_all_input_nodes( - chunk_node_list - ) + non_chunk_inputs = _find_chunk_all_input_nodes(chunk_node_list) for i in non_chunk_inputs: if i not in chunk_info["inputs"] and i not in prepose_nodes: chunk_info["inputs_non_chunk"].append(i) - + return chunk_info From d309e9338bde716ca356af8a27e0c484e97abbd9 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 23 Dec 2022 14:26:12 +0800 Subject: [PATCH 048/209] adapt codegen to prepose node --- chunk_codegen.py | 43 +++++++++++++++++++++++++------------------ 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index a8b970116d1d..e3a7643d7499 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -1198,7 +1198,7 @@ def flow_search( chunk_node_list.remove(n) non_chunk_inputs = _find_chunk_all_input_nodes(chunk_node_list) for i in non_chunk_inputs: - if i not in chunk_info["inputs"] and i not in prepose_nodes: + if i not in chunk_info["inputs"]: chunk_info["inputs_non_chunk"].append(i) return chunk_info @@ -1425,6 +1425,7 @@ def estimate_chunk_inference_mem( ) / (1024**2) # determine chunk ratio for current node + # TODO: adapt to prepose node memory if chunk_within: chunk_ratio = self._get_chunk_ratio( node, @@ -1602,7 +1603,6 @@ def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx): chunk_infos = [] for end_dim, end_trace_idx in enumerate(end_trace["idx"]): if len(start_traces) > 1: - # TODO: implement multi input chunk continue for start_node, start_trace in start_traces.items(): for start_dim, start_trace_idx in enumerate(start_trace["idx"]): @@ -1831,7 +1831,6 @@ def _find_chunk_compute_input_and_output_nodes(nodes: List[Node]): # if a node has a user node which is not in the node list # we treat that user node as 
the node receiving the current node output - # TODO: it is unsafe to remove non compute node here for node in nodes: for output_node in node.users.keys(): if ( @@ -1900,6 +1899,8 @@ def emit_code_with_chunk( chunk_outputs = [i["outputs"][0] for i in chunk_search] chunk_outputs_dim = [i["outputs_dim"] for i in chunk_search] + + chunk_prepose_nodes = [i["args"]["prepose_nodes"] for i in chunk_search] node_idx = 0 region_idx = 0 @@ -1911,7 +1912,11 @@ def emit_code_with_chunk( if node_idx in chunk_starts: within_chunk_region = True region_idx = chunk_starts.index(node_idx) - + # add prepose nodes + for i in chunk_prepose_nodes[region_idx]: + prepose_node = node_list[_find_idx_by_name(i.name, node_list)] + emit_node_func(prepose_node, body) + delete_unused_value_func(prepose_node, body, chunk_inputs_names) # add for loop body.append( _gen_loop_start( @@ -1922,20 +1927,22 @@ def emit_code_with_chunk( ) if within_chunk_region: - emit_node_func(node, body) - # replace input var with chunk var - for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]): - for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items(): - if idx == node_idx: - chunk_slice = _gen_chunk_slice_dim( - dim, "chunk_idx", _get_node_shape(input_node) - ) - body[-1] = _replace_name( - body[-1], input_node.name, input_node.name + chunk_slice - ) - body[-1] = " " + body[-1] - delete_unused_value_func(node, body, chunk_inputs_names) - + if any(node.name == i.name for i in chunk_prepose_nodes[region_idx]): + pass + else: + emit_node_func(node, body) + # replace input var with chunk var + for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]): + for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items(): + if idx == node_idx: + chunk_slice = _gen_chunk_slice_dim( + dim, "chunk_idx", _get_node_shape(input_node) + ) + body[-1] = _replace_name( + body[-1], input_node.name, input_node.name + chunk_slice + ) + body[-1] = " " + body[-1] + delete_unused_value_func(node, 
body, chunk_inputs_names) else: emit_node_func(node, body) if node_idx not in chunk_inputs: From 49ba619085c33eef372e73b6a45aecdc3d37937f Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 23 Dec 2022 14:26:43 +0800 Subject: [PATCH 049/209] code style --- chunk_codegen.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index e3a7643d7499..40196285ec8c 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -1899,7 +1899,7 @@ def emit_code_with_chunk( chunk_outputs = [i["outputs"][0] for i in chunk_search] chunk_outputs_dim = [i["outputs_dim"] for i in chunk_search] - + chunk_prepose_nodes = [i["args"]["prepose_nodes"] for i in chunk_search] node_idx = 0 @@ -1933,7 +1933,9 @@ def emit_code_with_chunk( emit_node_func(node, body) # replace input var with chunk var for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]): - for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items(): + for idx, dim in chunk_inputs_dim[region_idx][ + input_node_idx + ].items(): if idx == node_idx: chunk_slice = _gen_chunk_slice_dim( dim, "chunk_idx", _get_node_shape(input_node) From 4d89525fc2f828c9c65bf4077b677db9a78c8466 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 23 Dec 2022 14:28:49 +0800 Subject: [PATCH 050/209] remove abandoned function --- chunk_codegen.py | 106 ----------------------------------------------- 1 file changed, 106 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 40196285ec8c..e2786d5e244f 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -857,112 +857,6 @@ def trace_flow(self): ) return self.flow_trace - def _detect_flow( - self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer - ): - inputs, outputs = _find_chunk_compute_input_and_output_nodes( - self.node_list[start_idx : end_idx + 1] - ) - chunk_info = { - "region": (start_idx, end_idx), - "inputs": inputs, - "inputs_non_chunk": [], - "inputs_dim": start_dim, - "outputs": outputs, - 
"outputs_dim": end_dim, - "args": {}, - } - flow_block = False - - # TODO don't allow multi outputs now - if len(outputs) > 1: - flow_block = True - return flow_block, chunk_info - - # for idx in range(start_idx, end_idx + 1): - # node = self.node_list[idx] - # mix_flow_node = self._get_flow_mix_node(node) - # if mix_flow_node is None: - # continue - - # # if there is a flow mix, op must be in [mul, add, matmul] - # # element-wise op requires dim to be equal in every dim - # if any(n in node.name for n in ["mul", "add"]): - # for i in node.args: - # if type(i) == type(mix_flow_node) and i != mix_flow_node: - # main_flow_var = i - # # if mix flow is a broadcast in chunk dim, - # # TODO: need to move that flow out of the chunk - # mix_flow_node_dim = index_tracer.get_node_chunk_dim( - # self.node_list[end_idx], end_dim, node - # ) - # # TODO: we need to loop every dim - # if isinstance(mix_flow_node_dim, list): - # mix_flow_node_dim = mix_flow_node_dim[0] - # if mix_flow_node_dim is None: - # flow_block = True - # break - # if _get_node_shape(mix_flow_node)[mix_flow_node_dim] == 1: - # flow_block = False - # for i in self._get_same_flow_node( - # chunk_info["inputs"], mix_flow_node - # ): - # chunk_info["inputs"].remove(i) - # # else, we need to chunk mix var as well - # else: - # # TODO chunk another value - # flow_block = True - # break - # else: - # raise NotImplementedError("%s not implemented" % node.name) - # if flow_block: - # flow_block = True - # return flow_block, chunk_info - - inputs_dim = [] - remove_inputs = [] - for input_node in chunk_info["inputs"]: - input_dict = {} - for user in input_node.users.keys(): - if _is_non_compute_node(user): - continue - user_idx = _find_idx_by_name(user.name, self.node_list) - dim = None - if start_dim <= user_idx < end_idx: - dim = index_tracer.get_node_chunk_dim( - self.node_list[end_idx], end_dim, input_node - ) - # TODO: we need to loop every dim - if isinstance(dim, list): - dim = dim[0] - elif user_idx == end_idx: 
- dim = end_dim - # n has relation with chunk dim - if dim is not None and _get_node_shape(user)[dim] != 1: - input_dict[user_idx] = dim - if len(input_dict) == 0: - remove_inputs.append(input_node) - else: - inputs_dim.append(input_dict) - chunk_info["inputs_dim"] = inputs_dim - for i in remove_inputs: - if i in chunk_info["inputs"]: - chunk_info["inputs"].remove(i) - - duplicate_result, duplicate_dim = index_tracer.check_index_duplicate( - chunk_info, return_dim=True - ) - - # we need to log input nodes to avoid deleteing them in the loop - non_chunk_inputs = _find_chunk_all_input_nodes( - self.node_list[start_idx : end_idx + 1] - ) - for i in non_chunk_inputs: - if i not in chunk_info["inputs"]: - chunk_info["inputs_non_chunk"].append(i) - - return flow_block, chunk_info - def _assgin_single_node_flow( self, arg_node, From 4f5e105af30fccb4b0595edd341bdd7a4b226aa9 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 23 Dec 2022 15:34:41 +0800 Subject: [PATCH 051/209] remove flow tracer --- chunk_codegen.py | 171 ++++++++--------------------------------------- 1 file changed, 27 insertions(+), 144 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index e2786d5e244f..838f53949de7 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -67,7 +67,7 @@ def _is_non_compute_node_except_placeholder_output(node): class IndexTracer(object): def __init__(self, gm) -> None: self.gm = gm - self.nodes_list = list(gm.graph.nodes) + self.node_list = list(gm.graph.nodes) self.idx_trace_list = self._init_idx_trace_list() self.idx_trace_equal = [] self.idx_view_list = [] @@ -75,7 +75,7 @@ def __init__(self, gm) -> None: def _init_idx_trace_list(self): idx_trace_list = [] - for n in self.nodes_list: + for n in self.node_list: if _get_node_shape(n) != None: cur_trace = { "idx": [None for _ in range(len(_get_node_shape(n)))], @@ -136,7 +136,7 @@ def _add_source(self, node_from, node_from_dim, node_to, node_to_dim, init=False node_from_trace = 
self._find_trace_from_node(node_from) node_to_dim = self._transform_index(node_to, node_to_dim) node_to_trace = self._find_trace_from_node(node_to) - node_from_idx = _find_idx_by_name(node_from.name, self.nodes_list) + node_from_idx = _find_idx_by_name(node_from.name, self.node_list) if init: node_to_trace["source"][node_to_dim] = {} # add dim to cur new source @@ -210,7 +210,7 @@ def _find_trace_from_node(self, node): idx (list): idx of the node compute (list): computed idx of the node. """ - node_idx = _find_idx_by_name(node.name, self.nodes_list) + node_idx = _find_idx_by_name(node.name, self.node_list) node_dict = self.idx_trace_list[node_idx] return node_dict @@ -224,7 +224,7 @@ def _find_source_trace_from_node(self, node): idx (list): idx of the node compute (list): computed idx of the node. """ - node_idx = _find_idx_by_name(node.name, self.nodes_list) + node_idx = _find_idx_by_name(node.name, self.node_list) node_dict = self.idx_trace_list[node_idx] return node_dict["source"] @@ -237,7 +237,7 @@ def _find_idx_trace_from_node(self, node): Returns: idx (list): idx of the node """ - node_idx = _find_idx_by_name(node.name, self.nodes_list) + node_idx = _find_idx_by_name(node.name, self.node_list) return self.idx_trace_list[node_idx]["idx"] def _find_compute_trace_from_node(self, node): @@ -249,7 +249,7 @@ def _find_compute_trace_from_node(self, node): Returns: compute (list): computed idx of the node. 
""" - node_idx = _find_idx_by_name(node.name, self.nodes_list) + node_idx = _find_idx_by_name(node.name, self.node_list) return self.idx_trace_list[node_idx]["compute"] def _assign_index_as_input(self, node, node_idx, input_node=None): @@ -262,7 +262,7 @@ def _assign_index_as_input(self, node, node_idx, input_node=None): """ if input_node == None: input_node = node.args[0] - input_node_idx = _find_idx_by_name(input_node.name, self.nodes_list) + input_node_idx = _find_idx_by_name(input_node.name, self.node_list) input_node_idx_trace = self.idx_trace_list[input_node_idx]["idx"] new_idx_trace = copy.deepcopy(input_node_idx_trace) @@ -591,7 +591,7 @@ def _merge_equal_idx(self): ] def trace_index(self): - for idx, node in enumerate(self.nodes_list): + for idx, node in enumerate(self.node_list): if node.op == "placeholder": self._assign_all_index(node, idx) elif node.op == "call_method": @@ -655,7 +655,7 @@ def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node Returns: bool: True if check pass """ - start_node_idx = _find_idx_by_name(start_node.name, self.nodes_list) + start_node_idx = _find_idx_by_name(start_node.name, self.node_list) end_node_trace = self._find_trace_from_node(end_node) end_node_trace_source = end_node_trace["source"][end_dim] sorted_source = sorted( @@ -690,14 +690,14 @@ def check_index_compute(self, start_idx, end_dim, end_node, end_idx): def get_node_chunk_dim(self, node_from, node_from_dim, node_to): node_from_source = self._find_source_trace_from_node(node_from) dim_source = node_from_source[node_from_dim] - node_to_idx = _find_idx_by_name(node_to.name, self.nodes_list) + node_to_idx = _find_idx_by_name(node_to.name, self.node_list) for k, v in dim_source.items(): if k == node_to_idx: return v return None def _find_inherit_dim(self, input_node, input_dim, node): - input_node_idx = _find_idx_by_name(input_node.name, self.nodes_list) + input_node_idx = _find_idx_by_name(input_node.name, self.node_list) node_trace_source = 
self._find_source_trace_from_node(node) for node_dim in range(len(_get_node_shape(node))): if ( @@ -711,11 +711,11 @@ def check_index_duplicate(self, chunk_infos, return_dim=False): input_dim_after_node = {} for input_node_idx, input_node in enumerate(chunk_infos["inputs"]): for k, v in chunk_infos["inputs_dim"][input_node_idx].items(): - inherit_dim = self._find_inherit_dim(input_node, v, self.nodes_list[k]) + inherit_dim = self._find_inherit_dim(input_node, v, self.node_list[k]) if inherit_dim: input_dim_after_node[k] = inherit_dim - for node in self.nodes_list[ + for node in self.node_list[ chunk_infos["region"][0] : chunk_infos["region"][1] + 1 ]: if _is_non_compute_node_except_placeholder(node): @@ -746,124 +746,11 @@ def check_index_duplicate(self, chunk_infos, return_dim=False): else: return True - -class FlowTracer(object): - def __init__(self, gm) -> None: - self.gm = gm - self.node_list = list(gm.graph.nodes) - self.flow_trace = {} - - def _add_trace(self, name): - self.flow_trace[name] = [] - - def _add_node(self, trace_name, node): - self.flow_trace[trace_name].append( - {"node": node, "inside_depend": [], "outside_depend": []} - ) - - def _add_inside_depend(self, flow_name, node, inside_depend_node): - for i in self.flow_trace[flow_name]: - if i["node"] == node: - i["inside_depend"].append(inside_depend_node) - return - raise RuntimeError("node not found") - - def _add_outside_depend( - self, flow_name, node, outside_depend_node, outside_depend_trace - ): - for i in self.flow_trace[flow_name]: - if i["node"] == node: - i["outside_depend"].append({outside_depend_trace: outside_depend_node}) - return - raise RuntimeError("node not found") - - def _init_trace(self): - for i in self.node_list: - if i.op == "placeholder": - self._add_trace(i.name) - self._add_node(i.name, i) - - def _find_flow_for_node(self, node): - if type(self.node_list[0]) != type(node): - return None - if _is_non_compute_node_except_placeholder(node): - return None - for name, trace in 
self.flow_trace.items(): - for i in trace: - if node == i["node"]: - return name - if any(i in node.name for i in ["ones_like"]): - self._add_trace(node.name) - self._add_node(node.name, node) - return node.name - raise RuntimeError("node not found") - - def _find_first_valid_flow(self, flow): - for i in flow: - if i is not None: - return i - raise RuntimeError("invalid flow") - - def find_node_flow(self, node): - for name, trace in self.flow_trace.items(): - for i in trace: - if node == i["node"]: - return name, i - raise RuntimeError("invalid node") - - def _get_flow_mix_node(self, node): - if _is_non_compute_node(node): - return None - _, node_trace = self.find_node_flow(node) - if len(node_trace["outside_depend"]) == 0: - return None - elif len(node_trace["outside_depend"]) > 1: - raise NotImplementedError - vars = list(node_trace["outside_depend"][0].values())[0] - return vars - - def _get_same_flow_node(self, node_list, node): - name, _ = self.find_node_flow(node) - result = [] - for i in self.flow_trace[name]: - if i["node"] in node_list: - result.append(i["node"]) - return result - - def trace_flow(self): - # init trace - self._init_trace() - - for node in self.node_list: - # skip if non compute node - if all( - type(arg) != type(node) or _is_non_compute_node_except_placeholder(arg) - for arg in node.args - ) or _is_non_compute_node(node): - continue - - node_input_flows = [self._find_flow_for_node(arg) for arg in node.args] - - node_domin_flow = self._find_first_valid_flow(node_input_flows) - self._add_node(node_domin_flow, node) - for node_input_flow, arg in zip(node_input_flows, node.args): - if node_input_flow is None: - continue - elif node_input_flow == node_domin_flow: - self._add_inside_depend(node_domin_flow, node, arg) - else: - self._add_outside_depend( - node_domin_flow, node, arg, node_input_flow - ) - return self.flow_trace - def _assgin_single_node_flow( self, arg_node, start_idx, end_idx, - inputs, - index_tracer, cur_node_dim, 
cur_node_compute, cur_node_source, @@ -871,7 +758,7 @@ def _assgin_single_node_flow( all_node_info, next_node_list, ): - arg_idx = _find_idx_by_name(arg_node.name, index_tracer.nodes_list) + arg_idx = _find_idx_by_name(arg_node.name, self.node_list) # arg in chunk range or be inputs if not (start_idx <= arg_idx < end_idx): return True @@ -911,7 +798,7 @@ def _assgin_single_node_flow( return True def flow_search( - self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer + self, start_idx, start_dim, end_idx, end_dim ): inputs, outputs = _find_chunk_compute_input_and_output_nodes( self.node_list[start_idx : end_idx + 1] @@ -920,7 +807,7 @@ def flow_search( if len(outputs) > 1: return None - cur_node_list = [index_tracer.nodes_list[end_idx]] # start from the last node + cur_node_list = [self.node_list[end_idx]] # start from the last node all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}} while len(cur_node_list) > 0: @@ -930,12 +817,12 @@ def flow_search( # get cur node info cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"] cur_node_fix_dim = all_node_info[cur_node]["fix_dim"] - cur_node_idx = _find_idx_by_name(cur_node.name, index_tracer.nodes_list) + cur_node_idx = _find_idx_by_name(cur_node.name, self.node_list) if cur_node_chunk_dim: - cur_node_compute = index_tracer._find_compute_trace_from_node( + cur_node_compute = self._find_compute_trace_from_node( cur_node ) - cur_node_source = index_tracer._find_source_trace_from_node( + cur_node_source = self._find_source_trace_from_node( cur_node ) else: @@ -953,8 +840,6 @@ def flow_search( arg, start_idx, end_idx, - inputs, - index_tracer, cur_node_chunk_dim, cur_node_compute, cur_node_source, @@ -970,7 +855,7 @@ def flow_search( for arg in arg_list: if not ( start_idx - <= _find_idx_by_name(arg.name, index_tracer.nodes_list) + <= _find_idx_by_name(arg.name, self.node_list) < end_idx ): continue @@ -1029,7 +914,7 @@ def flow_search( if node_info["chunk_dim"] is None: 
maybe_prepose_nodes.append(node) maybe_prepose_nodes.sort( - key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list), + key=lambda x: _find_idx_by_name(x.name, self.node_list), reverse=True, ) # from last node to first node prepose_nodes = [] @@ -1081,7 +966,7 @@ def flow_search( maybe_prepose_nodes.remove(n) # sort by index prepose_nodes.sort( - key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list) + key=lambda x: _find_idx_by_name(x.name, self.node_list) ) chunk_info["args"]["prepose_nodes"] = prepose_nodes @@ -1226,9 +1111,9 @@ def _get_chunk_ratio(self, node, chunk_inputs, chunk_inputs_dim, chunk_size): for k, v in input_node_dim.items(): # TODO: inherit dim should be list too, int now inherit_dim = self.index_tracer._find_inherit_dim( - input_node, v, self.index_tracer.nodes_list[k] + input_node, v, self.index_tracer.node_list[k] ) - if k == _find_idx_by_name(node.name, self.index_tracer.nodes_list): + if k == _find_idx_by_name(node.name, self.index_tracer.node_list): chunk_ratio = float(chunk_size) / node_shape[inherit_dim] return chunk_ratio for dim, source in enumerate(node_source): @@ -1412,8 +1297,6 @@ def __init__(self, gm) -> None: self.node_list = list(gm.graph.nodes) self.index_tracer = IndexTracer(gm) self.index_tracer.trace_index() - self.flow_tracer = FlowTracer(gm) - self.flow_tracer.trace_flow() self.memory_estimator = MemoryEstimator(self.index_tracer) def _find_peak_node(self, mem_peak): @@ -1517,8 +1400,8 @@ def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx): ): continue # flow search - chunk_info = self.flow_tracer.flow_search( - start_idx, start_dim, end_idx, end_dim, self.index_tracer + chunk_info = self.index_tracer.flow_search( + start_idx, start_dim, end_idx, end_dim ) if chunk_info is None: continue From fa5e6fbf96448ebff1dc682e749a3f73a5a9c2b5 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 23 Dec 2022 15:38:37 +0800 Subject: [PATCH 052/209] code style --- chunk_codegen.py | 25 
+++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 838f53949de7..e80b0fd9be77 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -65,9 +65,8 @@ def _is_non_compute_node_except_placeholder_output(node): class IndexTracer(object): - def __init__(self, gm) -> None: - self.gm = gm - self.node_list = list(gm.graph.nodes) + def __init__(self, node_list) -> None: + self.node_list = node_list self.idx_trace_list = self._init_idx_trace_list() self.idx_trace_equal = [] self.idx_view_list = [] @@ -797,9 +796,7 @@ def _assgin_single_node_flow( next_node_list.append(arg_node) return True - def flow_search( - self, start_idx, start_dim, end_idx, end_dim - ): + def flow_search(self, start_idx, start_dim, end_idx, end_dim): inputs, outputs = _find_chunk_compute_input_and_output_nodes( self.node_list[start_idx : end_idx + 1] ) @@ -819,12 +816,8 @@ def flow_search( cur_node_fix_dim = all_node_info[cur_node]["fix_dim"] cur_node_idx = _find_idx_by_name(cur_node.name, self.node_list) if cur_node_chunk_dim: - cur_node_compute = self._find_compute_trace_from_node( - cur_node - ) - cur_node_source = self._find_source_trace_from_node( - cur_node - ) + cur_node_compute = self._find_compute_trace_from_node(cur_node) + cur_node_source = self._find_source_trace_from_node(cur_node) else: cur_node_compute = cur_node_source = None @@ -965,9 +958,7 @@ def flow_search( if n in maybe_prepose_nodes: maybe_prepose_nodes.remove(n) # sort by index - prepose_nodes.sort( - key=lambda x: _find_idx_by_name(x.name, self.node_list) - ) + prepose_nodes.sort(key=lambda x: _find_idx_by_name(x.name, self.node_list)) chunk_info["args"]["prepose_nodes"] = prepose_nodes # we need to log input nodes to avoid deleteing them in the loop @@ -1295,7 +1286,9 @@ class ChunkRegionSearch(object): def __init__(self, gm) -> None: self.gm = gm self.node_list = list(gm.graph.nodes) - self.index_tracer = IndexTracer(gm) + self.index_tracer = 
IndexTracer( + self.node_list + ) # node list shared in index tracer self.index_tracer.trace_index() self.memory_estimator = MemoryEstimator(self.index_tracer) From e0ae68e736cb56015fd1316113d52affaaf27749 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 23 Dec 2022 15:49:04 +0800 Subject: [PATCH 053/209] code style --- chunk_codegen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index e80b0fd9be77..6e772aa8a56a 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -1497,8 +1497,8 @@ def search_region(self): chunk_info = self._step_search(mem_peak, active_node, chunk_infos) if chunk_info is None: break - chunk_infos.append(chunk_info) + ( mem_peak, _, From 884a228ea674b02998575776b0069b15de0b7a10 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 23 Dec 2022 17:06:07 +0800 Subject: [PATCH 054/209] reorder nodes --- chunk_codegen.py | 127 +++++++++++++++++++++++++++++++++++++---------- 1 file changed, 101 insertions(+), 26 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 6e772aa8a56a..4b3b04d93b91 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -71,6 +71,7 @@ def __init__(self, node_list) -> None: self.idx_trace_equal = [] self.idx_view_list = [] self.idx_count = -1 + self.all_reorder_map = {i: i for i in range(len(self.idx_trace_list))} def _init_idx_trace_list(self): idx_trace_list = [] @@ -973,6 +974,91 @@ def flow_search(self, start_idx, start_dim, end_idx, end_dim): return chunk_info + def _get_reorder_map(self, chunk_info): + reorder_map = {i: i for i in range(len(self.node_list))} + + chunk_region_start = chunk_info["region"][0] + chunk_region_end = chunk_info["region"][1] + chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"] + chunk_prepose_nodes_idx = [ + _find_idx_by_name(i.name, self.node_list) for i in chunk_prepose_nodes + ] + # put prepose nodes ahead + for idx, n in enumerate(chunk_prepose_nodes): + n_idx = chunk_prepose_nodes_idx[idx] + reorder_map[n_idx] = 
chunk_region_start + idx + # put other nodes after prepose nodes + for n in self.node_list[chunk_region_start : chunk_region_end + 1]: + if n in chunk_prepose_nodes: + continue + n_idx = _find_idx_by_name(n.name, self.node_list) + pos = sum([n_idx < i for i in chunk_prepose_nodes_idx]) + reorder_map[n_idx] = n_idx + pos + + return reorder_map + + def _reorder_chunk_info(self, chunk_info, reorder_map): + # update chunk info + chunk_info["region"] = ( + chunk_info["region"][0] + len(chunk_info["args"]["prepose_nodes"]), + chunk_info["region"][1], + ) + for idx, input_dim in enumerate(chunk_info["inputs_dim"]): + new_input_dim = {} + for k, v in input_dim.items(): + new_input_dim[reorder_map[k]] = v + chunk_info["inputs_dim"][idx] = new_input_dim + return chunk_info + + def _update_all_reorder_map(self, reorder_map): + for origin_idx, map_idx in self.all_reorder_map.items(): + self.all_reorder_map[origin_idx] = reorder_map[map_idx] + + def _reorder_self_node_list(self, reorder_map): + new_node_list = [None for _ in range(len(self.node_list))] + for old_idx, new_idx in reorder_map.items(): + new_node_list[new_idx] = self.node_list[old_idx] + self.node_list = new_node_list + + def _reorder_idx_trace(self, reorder_map): + # reorder list + new_idx_trace_list = [None for _ in range(len(self.idx_trace_list))] + for old_idx, new_idx in reorder_map.items(): + new_idx_trace_list[new_idx] = self.idx_trace_list[old_idx] + self.idx_trace_list = new_idx_trace_list + # update compute + for idx_trace in self.idx_trace_list: + compute = idx_trace["compute"] + for dim_compute in compute: + for idx, i in enumerate(dim_compute): + dim_compute[idx] = reorder_map[i] + # update source + for idx_trace in self.idx_trace_list: + source = idx_trace["source"] + for dim_idx, dim_source in enumerate(source): + new_dim_source = {} + for k, v in dim_source.items(): + new_dim_source[reorder_map[k]] = v + source[dim_idx] = new_dim_source + + def reorder_all(self, chunk_info): + if chunk_info is None: 
+ return chunk_info + if len(chunk_info["args"]["prepose_nodes"]) == 0: + return chunk_info + reorder_map = self._get_reorder_map(chunk_info) + self._update_all_reorder_map(reorder_map) + self._reorder_idx_trace(reorder_map) + self._reorder_self_node_list(reorder_map) + chunk_info = self._reorder_chunk_info(chunk_info, reorder_map) + return chunk_info + + def reorder_node_list(self, node_list): + new_node_list = [None for _ in range(len(node_list))] + for old_idx, new_idx in self.all_reorder_map.items(): + new_node_list[new_idx] = node_list[old_idx] + return new_node_list + class MemoryEstimator(object): def __init__(self, index_tracer: IndexTracer) -> None: @@ -1476,6 +1562,7 @@ def _step_search(self, mem_peak, active_node, chunk_regions): best_chunk_region = self._search_best_chunk_region( possible_chunk_regions, chunk_regions ) + best_chunk_region = self.index_tracer.reorder_all(best_chunk_region) return best_chunk_region def _stop_search(self, init_mem_peak, mem_peak): @@ -1670,8 +1757,7 @@ def emit_code_with_chunk( chunk_outputs = [i["outputs"][0] for i in chunk_search] chunk_outputs_dim = [i["outputs_dim"] for i in chunk_search] - chunk_prepose_nodes = [i["args"]["prepose_nodes"] for i in chunk_search] - + node_list = chunk_region_search.index_tracer.reorder_node_list(node_list) node_idx = 0 region_idx = 0 within_chunk_region = False @@ -1682,12 +1768,6 @@ def emit_code_with_chunk( if node_idx in chunk_starts: within_chunk_region = True region_idx = chunk_starts.index(node_idx) - # add prepose nodes - for i in chunk_prepose_nodes[region_idx]: - prepose_node = node_list[_find_idx_by_name(i.name, node_list)] - emit_node_func(prepose_node, body) - delete_unused_value_func(prepose_node, body, chunk_inputs_names) - # add for loop body.append( _gen_loop_start( chunk_inputs[region_idx], @@ -1697,24 +1777,19 @@ def emit_code_with_chunk( ) if within_chunk_region: - if any(node.name == i.name for i in chunk_prepose_nodes[region_idx]): - pass - else: - 
emit_node_func(node, body) - # replace input var with chunk var - for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]): - for idx, dim in chunk_inputs_dim[region_idx][ - input_node_idx - ].items(): - if idx == node_idx: - chunk_slice = _gen_chunk_slice_dim( - dim, "chunk_idx", _get_node_shape(input_node) - ) - body[-1] = _replace_name( - body[-1], input_node.name, input_node.name + chunk_slice - ) - body[-1] = " " + body[-1] - delete_unused_value_func(node, body, chunk_inputs_names) + emit_node_func(node, body) + # replace input var with chunk var + for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]): + for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items(): + if idx == node_idx: + chunk_slice = _gen_chunk_slice_dim( + dim, "chunk_idx", _get_node_shape(input_node) + ) + body[-1] = _replace_name( + body[-1], input_node.name, input_node.name + chunk_slice + ) + body[-1] = " " + body[-1] + delete_unused_value_func(node, body, chunk_inputs_names) else: emit_node_func(node, body) if node_idx not in chunk_inputs: From 51ef8384c153f46dcbb74c26eec523ad7cd0d51c Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 23 Dec 2022 17:25:36 +0800 Subject: [PATCH 055/209] finish node reorder --- chunk_codegen.py | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 4b3b04d93b91..9623a9d9bbe2 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -1238,7 +1238,7 @@ def _print_compute_op_mem_log(self, log, nodes, title=None): def estimate_chunk_inference_mem( self, - gm: torch.fx.GraphModule, + node_list, chunk_infos=None, ): act_memory = 0.0 @@ -1247,7 +1247,6 @@ def estimate_chunk_inference_mem( active_node_list = [] active_node_list_log = [] not_contiguous_list = [] - node_list = list(gm.graph.nodes) user_to_last_uses = self._get_last_usr(node_list) user_to_last_uses_no_free_var = self._get_last_usr(node_list) 
_delete_free_var_from_last_use(user_to_last_uses_no_free_var) @@ -1281,7 +1280,6 @@ def estimate_chunk_inference_mem( ) / (1024**2) # determine chunk ratio for current node - # TODO: adapt to prepose node memory if chunk_within: chunk_ratio = self._get_chunk_ratio( node, @@ -1371,10 +1369,7 @@ def estimate_chunk_inference_mem( class ChunkRegionSearch(object): def __init__(self, gm) -> None: self.gm = gm - self.node_list = list(gm.graph.nodes) - self.index_tracer = IndexTracer( - self.node_list - ) # node list shared in index tracer + self.index_tracer = IndexTracer(list(gm.graph.nodes)) self.index_tracer.trace_index() self.memory_estimator = MemoryEstimator(self.index_tracer) @@ -1385,7 +1380,7 @@ def _find_peak_node(self, mem_peak): def _get_free_var(self): free_var_idx = [] - for idx, n in enumerate(self.node_list): + for idx, n in enumerate(self.index_tracer.node_list): if n.op == "placeholder": free_var_idx.append(idx) return free_var_idx @@ -1455,13 +1450,13 @@ def _is_not_compute(self, trace, chunk_range, dim_idx): def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx): start_traces = input_trace[start_idx] end_trace = output_trace[end_idx] - end_node = self.node_list[end_idx] + end_node = self.index_tracer.node_list[end_idx] chunk_infos = [] - for end_dim, end_trace_idx in enumerate(end_trace["idx"]): + for end_dim, _ in enumerate(end_trace["idx"]): if len(start_traces) > 1: continue for start_node, start_trace in start_traces.items(): - for start_dim, start_trace_idx in enumerate(start_trace["idx"]): + for start_dim, _ in enumerate(start_trace["idx"]): # dim size cannot be 1 if ( _get_node_shape(end_node)[end_dim] == 1 @@ -1494,7 +1489,7 @@ def _search_possible_chunk_regions(self, max_chunk_region, peak_node): possible_chunk_region = [] output_trace = copy.deepcopy(self.index_tracer.idx_trace_list) input_trace = [] # trace of a node's input nodes - for _, n in enumerate(self.node_list): + for _, n in enumerate(self.index_tracer.node_list): 
cur_trace = {} for arg in n.args: if type(arg) == type(n) and not _is_non_compute_node_except_placeholder( @@ -1507,8 +1502,8 @@ def _search_possible_chunk_regions(self, max_chunk_region, peak_node): for end_idx in range(peak_node, max_chunk_region[1] + 1): # skip non compute nodes if _is_non_compute_node( - self.node_list[start_idx] - ) or _is_non_compute_node(self.node_list[end_idx]): + self.index_tracer.node_list[start_idx] + ) or _is_non_compute_node(self.index_tracer.node_list[end_idx]): continue # select free dim @@ -1577,7 +1572,9 @@ def search_region(self): init_mem_peak, _, active_node, - ) = self.memory_estimator.estimate_chunk_inference_mem(self.gm) + ) = self.memory_estimator.estimate_chunk_inference_mem( + self.index_tracer.node_list + ) mem_peak = init_mem_peak while True: @@ -1590,7 +1587,9 @@ def search_region(self): mem_peak, _, active_node, - ) = self.memory_estimator.estimate_chunk_inference_mem(self.gm, chunk_infos) + ) = self.memory_estimator.estimate_chunk_inference_mem( + self.index_tracer.node_list, chunk_infos + ) if self._stop_search(init_mem_peak, mem_peak): break return chunk_infos From 9b1b890347f345f1c4de2a0991e250dcaf94365a Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 23 Dec 2022 17:32:11 +0800 Subject: [PATCH 056/209] update run --- chunk_codegen_run.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/chunk_codegen_run.py b/chunk_codegen_run.py index ae4653d6545b..3a3b3c599e3e 100644 --- a/chunk_codegen_run.py +++ b/chunk_codegen_run.py @@ -32,15 +32,25 @@ def _is_all_param_close(m: torch.nn.Module, gm: GraphModule) -> bool: def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair): + # now_mem = torch.cuda.memory_allocated() / 1024**2 + # with torch.no_grad(): + # node0 = node.clone() + # pair0 = pair.clone() + # model.graph(node0, pair0, now_mem) + # new_now_mem = torch.cuda.memory_allocated() / 1024**2 + # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2 + # 
print("\ncode now:%.2f max:%.2f" %(new_now_mem - now_mem, new_max_mem - now_mem)) + + torch.cuda.reset_peak_memory_stats() now_mem = torch.cuda.memory_allocated() / 1024**2 with torch.no_grad(): - node0 = node.clone() - pair0 = pair.clone() - node1, pair1 = gm(node0, pair0) + node1 = node.clone() + pair1 = pair.clone() + gm(node1, pair1) new_now_mem = torch.cuda.memory_allocated() / 1024**2 new_max_mem = torch.cuda.max_memory_allocated() / 1024**2 - print("now:%.2f max:%.2f" %(new_now_mem - now_mem, new_max_mem - now_mem)) - + print("gm now:%.2f max:%.2f" %(new_now_mem - now_mem, new_max_mem - now_mem)) + # test forward with torch.no_grad(): non_fx_out = model(node, pair) From 786a398a6bdea395e2ca8ddde87c87c8470d971b Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 23 Dec 2022 17:42:51 +0800 Subject: [PATCH 057/209] code style --- chunk_codegen.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 9623a9d9bbe2..f87a3a132e78 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -920,9 +920,13 @@ def flow_search(self, start_idx, start_dim, end_idx, end_dim): # loop cur node's all arg until out of chunk while len(tmp_cur_prepose_nodes) > 0: + if prepose_flag == False: + break tmp_next_prepose_nodes = [] tmp_cur_related_prepose_nodes.extend(tmp_cur_prepose_nodes) for cur_prepose_node in tmp_cur_prepose_nodes: + if prepose_flag == False: + break for cur_prepose_node_arg in cur_prepose_node.args: if type(cur_prepose_node_arg) != type(cur_prepose_node): continue @@ -942,8 +946,6 @@ def flow_search(self, start_idx, start_dim, end_idx, end_dim): else: prepose_flag = False break - break - break # non compute op else: tmp_next_prepose_nodes.append(cur_prepose_node_arg) From 1b8a066592821870bb8f7a6fce338481efd5140b Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 26 Dec 2022 15:28:01 +0800 Subject: [PATCH 058/209] add chunk select class --- chunk_codegen.py | 80 +++++++++++++++++++++++++++++------------------- 1 
file changed, 49 insertions(+), 31 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index f87a3a132e78..cdd0b1077487 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -1368,12 +1368,60 @@ def estimate_chunk_inference_mem( return act_memory_peak_log, act_memory_after_node_log, active_node_list_log +class ChunkSelector(object): + def __init__(self, index_tracer: IndexTracer, stratge) -> None: + self.index_tracer = index_tracer + assert stratge in ['min_memory', 'fit_memory'] + self.stratge = stratge + self.max_memory = 800 # MB + + def _select_best_chunk_region(self, possible_chunk_regions, chunk_infos): + if self.stratge == 'min_memory': + best_region = self._select_min_memory_chunk_region(possible_chunk_regions, chunk_infos) + elif self.stratge == 'fit_memory': + pass + else: + raise RuntimeError() + return best_region + + def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos): + max_region_range = 0 + best_region = None + while len(possible_chunk_regions) > 0: + for i in possible_chunk_regions: + if i["region"][1] - i["region"][0] > max_region_range: + best_region = i + max_region_range = i["region"][1] - i["region"][0] + if self._is_legal_region(best_region, chunk_infos): + break + possible_chunk_regions.remove(i) + max_region_range = 0 + best_region = None + return best_region + + def _is_legal_region(self, cur_chunk_info, chunk_infos): + (chunk_region_start, chunk_region_end) = cur_chunk_info["region"] + if cur_chunk_info in chunk_infos: + return False + if chunk_region_end < chunk_region_start: + return False + for i in chunk_infos: + region = i["region"] + if not ( + (chunk_region_start > region[1] and chunk_region_end > region[1]) + or (chunk_region_start < region[0] and chunk_region_end < region[0]) + ): + return False + return True + + class ChunkRegionSearch(object): def __init__(self, gm) -> None: self.gm = gm self.index_tracer = IndexTracer(list(gm.graph.nodes)) self.index_tracer.trace_index() 
self.memory_estimator = MemoryEstimator(self.index_tracer) + self.chunk_selector = ChunkSelector(self.index_tracer, stratge="min_memory") def _find_peak_node(self, mem_peak): max_value = max(mem_peak) @@ -1516,36 +1564,6 @@ def _search_possible_chunk_regions(self, max_chunk_region, peak_node): possible_chunk_region.extend(chunk_info) return possible_chunk_region - def _search_best_chunk_region(self, possible_chunk_regions, chunk_infos): - max_region_range = 0 - best_region = None - while len(possible_chunk_regions) > 0: - for i in possible_chunk_regions: - if i["region"][1] - i["region"][0] > max_region_range: - best_region = i - max_region_range = i["region"][1] - i["region"][0] - if self._is_legal_region(best_region, chunk_infos): - break - possible_chunk_regions.remove(i) - max_region_range = 0 - best_region = None - return best_region - - def _is_legal_region(self, cur_chunk_info, chunk_infos): - (chunk_region_start, chunk_region_end) = cur_chunk_info["region"] - if cur_chunk_info in chunk_infos: - return False - if chunk_region_end < chunk_region_start: - return False - for i in chunk_infos: - region = i["region"] - if not ( - (chunk_region_start > region[1] and chunk_region_end > region[1]) - or (chunk_region_start < region[0] and chunk_region_end < region[0]) - ): - return False - return True - def _step_search(self, mem_peak, active_node, chunk_regions): peak_node = self._find_peak_node(mem_peak) max_chunk_region = self._search_max_chunk_region( @@ -1556,7 +1574,7 @@ def _step_search(self, mem_peak, active_node, chunk_regions): possible_chunk_regions = self._search_possible_chunk_regions( max_chunk_region, peak_node ) - best_chunk_region = self._search_best_chunk_region( + best_chunk_region = self.chunk_selector._select_best_chunk_region( possible_chunk_regions, chunk_regions ) best_chunk_region = self.index_tracer.reorder_all(best_chunk_region) From 8f5a0edfab3d9c4636333cba2dcdbb7f2fa74181 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 26 Dec 2022 
23:08:49 +0800 Subject: [PATCH 059/209] add chunk select --- chunk_codegen.py | 147 ++++++++++++++++++++++++++++++++++++----------- 1 file changed, 112 insertions(+), 35 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index cdd0b1077487..330f3dec611c 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -69,7 +69,7 @@ def __init__(self, node_list) -> None: self.node_list = node_list self.idx_trace_list = self._init_idx_trace_list() self.idx_trace_equal = [] - self.idx_view_list = [] + self.idx_view_list = {} self.idx_count = -1 self.all_reorder_map = {i: i for i in range(len(self.idx_trace_list))} @@ -576,7 +576,7 @@ def _assign_view_reshape_index(self, node, node_idx): "idx_to": [self.idx_trace_list[node_idx]["idx"][i] for i in dim_to], "dim_to": dim_to, } - self.idx_view_list.append(view_dict) + self.idx_view_list[node] = view_dict def _merge_equal_idx(self): idx_equal = copy.deepcopy(self.idx_trace_equal) @@ -702,7 +702,7 @@ def _find_inherit_dim(self, input_node, input_dim, node): for node_dim in range(len(_get_node_shape(node))): if ( input_node_idx in node_trace_source[node_dim] - and input_dim in node_trace_source[node_dim][input_node_idx] + and input_dim[0] in node_trace_source[node_dim][input_node_idx] ): return node_dim return None @@ -875,6 +875,7 @@ def flow_search(self, start_idx, start_dim, end_idx, end_dim): remove_inputs = [] for input_node in inputs: input_dict = {} + input_node_idx = _find_idx_by_name(input_node.name, self.node_list) for user in input_node.users.keys(): if _is_non_compute_node(user): continue @@ -882,7 +883,11 @@ def flow_search(self, start_idx, start_dim, end_idx, end_dim): if start_idx <= user_idx <= end_idx: chunk_dim = all_node_info[user]["chunk_dim"] if chunk_dim is not None: - input_dict[user_idx] = chunk_dim + user_source = self._find_source_trace_from_node(user)[chunk_dim] + if input_node_idx in user_source: + input_dict[user_idx] = user_source[input_node_idx] + else: + return None if len(input_dict) == 0: 
remove_inputs.append(input_node) else: @@ -898,6 +903,7 @@ def flow_search(self, start_idx, start_dim, end_idx, end_dim): "inputs_dim": inputs_dim, "outputs": outputs, "outputs_dim": end_dim, + "node_chunk_dim": all_node_info, "args": {}, } @@ -974,6 +980,26 @@ def flow_search(self, start_idx, start_dim, end_idx, end_dim): if i not in chunk_info["inputs"]: chunk_info["inputs_non_chunk"].append(i) + # reassgin reshape size, some size may have changed due to chunk + chunk_info = self._reassgin_reshape_size(chunk_info) + + return chunk_info + + def _reassgin_reshape_size(self, chunk_info): + chunk_region = chunk_info['region'] + reshape_size = {} + for node in self.node_list[chunk_region[0]: chunk_region[1] + 1]: + if any(i in node.name for i in ['reshape', 'view']): + reshape_args = node.args[1:] + reshape_log = self.idx_view_list[node] + chunk_dim = chunk_info['node_chunk_dim'][node]['chunk_dim'] + reshape_size[node.name] = {} + for reshape_arg_dim, reshape_arg in enumerate(reshape_args): + if reshape_arg_dim in reshape_log['dim_to']: + continue + if reshape_arg_dim == chunk_dim: + reshape_size[node.name][reshape_arg.name] = "chunk_size" + chunk_info['reshape_size'] = reshape_size return chunk_info def _get_reorder_map(self, chunk_info): @@ -1183,23 +1209,15 @@ def _get_contiguous_memory(self, node, not_contiguous_list, delete=False): not_contiguous_list.append(node) return mem - def _get_chunk_ratio(self, node, chunk_inputs, chunk_inputs_dim, chunk_size): + def _get_chunk_ratio(self, node, chunk_node_dim, chunk_size): + if node not in chunk_node_dim: + return 1.0 node_shape = _get_node_shape(node) - node_source = self.index_tracer._find_source_trace_from_node(node) - for (input_node, input_node_dim) in zip(chunk_inputs, chunk_inputs_dim): - for k, v in input_node_dim.items(): - # TODO: inherit dim should be list too, int now - inherit_dim = self.index_tracer._find_inherit_dim( - input_node, v, self.index_tracer.node_list[k] - ) - if k == 
_find_idx_by_name(node.name, self.index_tracer.node_list): - chunk_ratio = float(chunk_size) / node_shape[inherit_dim] - return chunk_ratio - for dim, source in enumerate(node_source): - if k in source and inherit_dim in source[k]: - chunk_ratio = float(chunk_size) / node_shape[dim] - return chunk_ratio - return 1.0 + chunk_dim = chunk_node_dim[node]['chunk_dim'] + if chunk_dim is None: + return 1.0 + else: + return float(chunk_size) / node_shape[chunk_dim] def _get_chunk_delete_node_size( self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names @@ -1242,6 +1260,7 @@ def estimate_chunk_inference_mem( self, node_list, chunk_infos=None, + print_mem=False, ): act_memory = 0.0 act_memory_peak_log = [] @@ -1271,6 +1290,7 @@ def estimate_chunk_inference_mem( j.name for i in chunk_inputs_non_chunk for j in i ] chunk_outputs = [i["outputs"][0] for i in chunk_infos] + chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos] for idx, node in enumerate(node_list): # if node in chunk start nodes, change chunk ratio and add chunk_tensor @@ -1285,8 +1305,7 @@ def estimate_chunk_inference_mem( if chunk_within: chunk_ratio = self._get_chunk_ratio( node, - chunk_inputs[chunk_region_idx], - chunk_inputs_dim[chunk_region_idx], + chunk_node_dim[chunk_region_idx], chunk_size, ) @@ -1357,11 +1376,12 @@ def estimate_chunk_inference_mem( act_memory_after_node_log.append(act_memory) active_node_list_log.append(copy.deepcopy(active_node_list)) - print("with chunk" if use_chunk else "without chunk") - # self._print_mem_log(act_memory_peak_log, node_list, "peak") - # self._print_mem_log(act_memory_after_node_log, node_list, "after") - self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak") - self._print_compute_op_mem_log(act_memory_after_node_log, node_list, "after") + if print_mem: + print("with chunk" if use_chunk else "without chunk") + # self._print_mem_log(act_memory_peak_log, node_list, "peak") + # self._print_mem_log(act_memory_after_node_log, node_list, 
"after") + self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak") + self._print_compute_op_mem_log(act_memory_after_node_log, node_list, "after") # param_memory = parameter_size(gm) # all_memory = act_memory + param_memory @@ -1369,21 +1389,70 @@ def estimate_chunk_inference_mem( class ChunkSelector(object): - def __init__(self, index_tracer: IndexTracer, stratge) -> None: + def __init__(self, index_tracer: IndexTracer, memory_estimator: MemoryEstimator, stratge): self.index_tracer = index_tracer + self.memory_estimator = memory_estimator assert stratge in ['min_memory', 'fit_memory'] self.stratge = stratge - self.max_memory = 800 # MB + self.max_memory = 600 # MB - def _select_best_chunk_region(self, possible_chunk_regions, chunk_infos): + def _select_best_chunk_region(self, possible_chunk_regions, + chunk_infos, peak_node, max_chunk_region, mem_peak): if self.stratge == 'min_memory': best_region = self._select_min_memory_chunk_region(possible_chunk_regions, chunk_infos) elif self.stratge == 'fit_memory': - pass + best_region = self._select_fit_memory_chunk_region( + possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak) else: raise RuntimeError() return best_region + def _select_fit_memory_chunk_region(self, possible_chunk_regions, + chunk_infos, peak_node, max_chunk_region, mem_peak): + # stop chunk if max memory satisfy memory limit + if max(mem_peak) < self.max_memory: + return None + + # remove illegal regions + illegal_regions = [] + for i in possible_chunk_regions: + if not self._is_legal_region(i, chunk_infos): + illegal_regions.append(i) + for i in illegal_regions: + if i in possible_chunk_regions: + possible_chunk_regions.remove(i) + + # get mem for chunk region + regions_dict = [] + for region in possible_chunk_regions: + cur_chunk_infos = chunk_infos + [region] + cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( + self.index_tracer.node_list, cur_chunk_infos)[0] + cur_chunk_region_peak = 
cur_mem_peak[max_chunk_region[0]: max_chunk_region[1] + 1] + cur_chunk_region_max_peak = max(cur_chunk_region_peak) + if cur_chunk_region_max_peak < self.max_memory: + regions_dict.append({ + "chunk_info": region, + "chunk_max_mem": cur_chunk_region_max_peak, + "chunk_len": self._get_compute_node_num(region['region'][0], region['region'][1]), + }) + # no region found + if len(regions_dict) == 0: + return None + + # select the min chunk len + chunk_len = [i["chunk_len"] for i in regions_dict] + best_region_idx = chunk_len.index(min(chunk_len)) + best_region = regions_dict[best_region_idx]["chunk_info"] + return best_region + + def _get_compute_node_num(self, start, end): + count = 0 + for i in self.index_tracer.node_list[start: end+1]: + if _is_non_compute_node(i): + count += 1 + return count + def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos): max_region_range = 0 best_region = None @@ -1421,7 +1490,7 @@ def __init__(self, gm) -> None: self.index_tracer = IndexTracer(list(gm.graph.nodes)) self.index_tracer.trace_index() self.memory_estimator = MemoryEstimator(self.index_tracer) - self.chunk_selector = ChunkSelector(self.index_tracer, stratge="min_memory") + self.chunk_selector = ChunkSelector(self.index_tracer, self.memory_estimator, stratge="fit_memory") def _find_peak_node(self, mem_peak): max_value = max(mem_peak) @@ -1575,7 +1644,7 @@ def _step_search(self, mem_peak, active_node, chunk_regions): max_chunk_region, peak_node ) best_chunk_region = self.chunk_selector._select_best_chunk_region( - possible_chunk_regions, chunk_regions + possible_chunk_regions, chunk_regions, peak_node, max_chunk_region, mem_peak ) best_chunk_region = self.index_tracer.reorder_all(best_chunk_region) return best_chunk_region @@ -1608,7 +1677,7 @@ def search_region(self): _, active_node, ) = self.memory_estimator.estimate_chunk_inference_mem( - self.index_tracer.node_list, chunk_infos + self.index_tracer.node_list, chunk_infos, print_mem=True ) if 
self._stop_search(init_mem_peak, mem_peak): break @@ -1736,6 +1805,13 @@ def _replace_name(context, name_from, name_to): return context +def _replace_reshape_size(context, node_name, reshape_size_dict): + if node_name not in reshape_size_dict: + return context + for size_name, size_value in reshape_size_dict[node_name].items(): + context = context.replace(size_name, size_value) + return context + def emit_code_with_chunk( body, ckpt_func, @@ -1802,11 +1878,12 @@ def emit_code_with_chunk( for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items(): if idx == node_idx: chunk_slice = _gen_chunk_slice_dim( - dim, "chunk_idx", _get_node_shape(input_node) + dim[0], "chunk_idx", _get_node_shape(input_node) ) body[-1] = _replace_name( body[-1], input_node.name, input_node.name + chunk_slice ) + body[-1] = _replace_reshape_size(body[-1], node.name, chunk_search[region_idx]['reshape_size']) body[-1] = " " + body[-1] delete_unused_value_func(node, body, chunk_inputs_names) else: From 378a49dc6c259773cdc198841a75137f7c6edc7f Mon Sep 17 00:00:00 2001 From: oahzxl Date: Tue, 27 Dec 2022 09:48:59 +0800 Subject: [PATCH 060/209] code style --- chunk_codegen.py | 101 +++++++++++++++++++++++++++++------------------ 1 file changed, 63 insertions(+), 38 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 330f3dec611c..1255852d777d 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -982,24 +982,24 @@ def flow_search(self, start_idx, start_dim, end_idx, end_dim): # reassgin reshape size, some size may have changed due to chunk chunk_info = self._reassgin_reshape_size(chunk_info) - + return chunk_info - + def _reassgin_reshape_size(self, chunk_info): - chunk_region = chunk_info['region'] + chunk_region = chunk_info["region"] reshape_size = {} - for node in self.node_list[chunk_region[0]: chunk_region[1] + 1]: - if any(i in node.name for i in ['reshape', 'view']): + for node in self.node_list[chunk_region[0] : chunk_region[1] + 1]: + if any(i in node.name for 
i in ["reshape", "view"]): reshape_args = node.args[1:] reshape_log = self.idx_view_list[node] - chunk_dim = chunk_info['node_chunk_dim'][node]['chunk_dim'] + chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"] reshape_size[node.name] = {} for reshape_arg_dim, reshape_arg in enumerate(reshape_args): - if reshape_arg_dim in reshape_log['dim_to']: + if reshape_arg_dim in reshape_log["dim_to"]: continue if reshape_arg_dim == chunk_dim: reshape_size[node.name][reshape_arg.name] = "chunk_size" - chunk_info['reshape_size'] = reshape_size + chunk_info["reshape_size"] = reshape_size return chunk_info def _get_reorder_map(self, chunk_info): @@ -1213,7 +1213,7 @@ def _get_chunk_ratio(self, node, chunk_node_dim, chunk_size): if node not in chunk_node_dim: return 1.0 node_shape = _get_node_shape(node) - chunk_dim = chunk_node_dim[node]['chunk_dim'] + chunk_dim = chunk_node_dim[node]["chunk_dim"] if chunk_dim is None: return 1.0 else: @@ -1381,7 +1381,9 @@ def estimate_chunk_inference_mem( # self._print_mem_log(act_memory_peak_log, node_list, "peak") # self._print_mem_log(act_memory_after_node_log, node_list, "after") self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak") - self._print_compute_op_mem_log(act_memory_after_node_log, node_list, "after") + self._print_compute_op_mem_log( + act_memory_after_node_log, node_list, "after" + ) # param_memory = parameter_size(gm) # all_memory = act_memory + param_memory @@ -1389,30 +1391,41 @@ def estimate_chunk_inference_mem( class ChunkSelector(object): - def __init__(self, index_tracer: IndexTracer, memory_estimator: MemoryEstimator, stratge): + def __init__( + self, index_tracer: IndexTracer, memory_estimator: MemoryEstimator, stratge + ): self.index_tracer = index_tracer self.memory_estimator = memory_estimator - assert stratge in ['min_memory', 'fit_memory'] + assert stratge in ["min_memory", "fit_memory"] self.stratge = stratge self.max_memory = 600 # MB - - def _select_best_chunk_region(self, 
possible_chunk_regions, - chunk_infos, peak_node, max_chunk_region, mem_peak): - if self.stratge == 'min_memory': - best_region = self._select_min_memory_chunk_region(possible_chunk_regions, chunk_infos) - elif self.stratge == 'fit_memory': + + def _select_best_chunk_region( + self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak + ): + if self.stratge == "min_memory": + best_region = self._select_min_memory_chunk_region( + possible_chunk_regions, chunk_infos + ) + elif self.stratge == "fit_memory": best_region = self._select_fit_memory_chunk_region( - possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak) + possible_chunk_regions, + chunk_infos, + peak_node, + max_chunk_region, + mem_peak, + ) else: raise RuntimeError() return best_region - - def _select_fit_memory_chunk_region(self, possible_chunk_regions, - chunk_infos, peak_node, max_chunk_region, mem_peak): + + def _select_fit_memory_chunk_region( + self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak + ): # stop chunk if max memory satisfy memory limit if max(mem_peak) < self.max_memory: return None - + # remove illegal regions illegal_regions = [] for i in possible_chunk_regions: @@ -1421,38 +1434,45 @@ def _select_fit_memory_chunk_region(self, possible_chunk_regions, for i in illegal_regions: if i in possible_chunk_regions: possible_chunk_regions.remove(i) - + # get mem for chunk region regions_dict = [] for region in possible_chunk_regions: cur_chunk_infos = chunk_infos + [region] cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( - self.index_tracer.node_list, cur_chunk_infos)[0] - cur_chunk_region_peak = cur_mem_peak[max_chunk_region[0]: max_chunk_region[1] + 1] + self.index_tracer.node_list, cur_chunk_infos + )[0] + cur_chunk_region_peak = cur_mem_peak[ + max_chunk_region[0] : max_chunk_region[1] + 1 + ] cur_chunk_region_max_peak = max(cur_chunk_region_peak) if cur_chunk_region_max_peak < self.max_memory: - 
regions_dict.append({ - "chunk_info": region, - "chunk_max_mem": cur_chunk_region_max_peak, - "chunk_len": self._get_compute_node_num(region['region'][0], region['region'][1]), - }) + regions_dict.append( + { + "chunk_info": region, + "chunk_max_mem": cur_chunk_region_max_peak, + "chunk_len": self._get_compute_node_num( + region["region"][0], region["region"][1] + ), + } + ) # no region found if len(regions_dict) == 0: return None - + # select the min chunk len chunk_len = [i["chunk_len"] for i in regions_dict] best_region_idx = chunk_len.index(min(chunk_len)) best_region = regions_dict[best_region_idx]["chunk_info"] return best_region - + def _get_compute_node_num(self, start, end): count = 0 - for i in self.index_tracer.node_list[start: end+1]: + for i in self.index_tracer.node_list[start : end + 1]: if _is_non_compute_node(i): count += 1 return count - + def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos): max_region_range = 0 best_region = None @@ -1490,7 +1510,9 @@ def __init__(self, gm) -> None: self.index_tracer = IndexTracer(list(gm.graph.nodes)) self.index_tracer.trace_index() self.memory_estimator = MemoryEstimator(self.index_tracer) - self.chunk_selector = ChunkSelector(self.index_tracer, self.memory_estimator, stratge="fit_memory") + self.chunk_selector = ChunkSelector( + self.index_tracer, self.memory_estimator, stratge="fit_memory" + ) def _find_peak_node(self, mem_peak): max_value = max(mem_peak) @@ -1808,10 +1830,11 @@ def _replace_name(context, name_from, name_to): def _replace_reshape_size(context, node_name, reshape_size_dict): if node_name not in reshape_size_dict: return context - for size_name, size_value in reshape_size_dict[node_name].items(): + for size_name, size_value in reshape_size_dict[node_name].items(): context = context.replace(size_name, size_value) return context + def emit_code_with_chunk( body, ckpt_func, @@ -1883,7 +1906,9 @@ def emit_code_with_chunk( body[-1] = _replace_name( body[-1], 
input_node.name, input_node.name + chunk_slice ) - body[-1] = _replace_reshape_size(body[-1], node.name, chunk_search[region_idx]['reshape_size']) + body[-1] = _replace_reshape_size( + body[-1], node.name, chunk_search[region_idx]["reshape_size"] + ) body[-1] = " " + body[-1] delete_unused_value_func(node, body, chunk_inputs_names) else: From 6be89a3b82d370be152c93dd7277e234e68eaea6 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Tue, 27 Dec 2022 14:48:25 +0800 Subject: [PATCH 061/209] add chunksize in emit, fix bug in reassgin shape --- chunk_codegen.py | 56 ++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 52 insertions(+), 4 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 1255852d777d..470768855779 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -988,6 +988,7 @@ def flow_search(self, start_idx, start_dim, end_idx, end_dim): def _reassgin_reshape_size(self, chunk_info): chunk_region = chunk_info["region"] reshape_size = {} + chunk_shape = _get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]] for node in self.node_list[chunk_region[0] : chunk_region[1] + 1]: if any(i in node.name for i in ["reshape", "view"]): reshape_args = node.args[1:] @@ -998,7 +999,7 @@ def _reassgin_reshape_size(self, chunk_info): if reshape_arg_dim in reshape_log["dim_to"]: continue if reshape_arg_dim == chunk_dim: - reshape_size[node.name][reshape_arg.name] = "chunk_size" + reshape_size[node.name][reshape_arg.name] = "min(chunk_size, %d - chunk_idx)" % chunk_shape chunk_info["reshape_size"] = reshape_size return chunk_info @@ -1276,7 +1277,6 @@ def estimate_chunk_inference_mem( chunk_within = False chunk_region_idx = None chunk_ratio = 1 # use it to estimate chunk mem - chunk_size = 1 chunk_inputs_names = [] if use_chunk: @@ -1285,12 +1285,14 @@ def estimate_chunk_inference_mem( chunk_ends = [i[1] for i in chunk_regions] chunk_inputs = [i["inputs"] for i in chunk_infos] chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in 
chunk_infos] - chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [ j.name for i in chunk_inputs_non_chunk for j in i ] chunk_outputs = [i["outputs"][0] for i in chunk_infos] chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos] + chunk_sizes = [ + i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos + ] for idx, node in enumerate(node_list): # if node in chunk start nodes, change chunk ratio and add chunk_tensor @@ -1306,7 +1308,7 @@ def estimate_chunk_inference_mem( chunk_ratio = self._get_chunk_ratio( node, chunk_node_dim[chunk_region_idx], - chunk_size, + chunk_sizes[chunk_region_idx], ) # if node is placeholder, just add the size of the node @@ -1464,8 +1466,53 @@ def _select_fit_memory_chunk_region( chunk_len = [i["chunk_len"] for i in regions_dict] best_region_idx = chunk_len.index(min(chunk_len)) best_region = regions_dict[best_region_idx]["chunk_info"] + + # get max chunk size + best_region = self._get_fit_chunk_size(best_region, chunk_infos) return best_region + def _get_fit_chunk_size(self, chunk_info, chunk_infos): + chunk_size = 1 + chunk_info["chunk_size"] = chunk_size + cur_chunk_max_mem = 0 + # search a region + while cur_chunk_max_mem < self.max_memory: + chunk_size *= 2 + chunk_info["chunk_size"] = chunk_size + cur_chunk_infos = chunk_infos + [chunk_info] + cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( + self.index_tracer.node_list, cur_chunk_infos + )[0] + cur_chunk_max_mem = max( + cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1] + ) + # search exact size + chunk_info["chunk_size"] = self._chunk_size_binary_search( + chunk_size // 2, chunk_size, chunk_info, chunk_infos + ) + return chunk_info + + def _chunk_size_binary_search(self, l, r, chunk_info, chunk_infos): + if l >= 16: + gap = 4 + else: + gap = 1 + while r >= l + gap: + mid = int(l + (r - l)/2) + chunk_info["chunk_size"] = mid + cur_chunk_infos = chunk_infos + 
[chunk_info] + cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( + self.index_tracer.node_list, cur_chunk_infos + )[0] + cur_chunk_max_mem = max( + cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1] + ) + if cur_chunk_max_mem >= self.max_memory: + r = mid - gap + else: + l = mid + gap + return l + def _get_compute_node_num(self, start, end): count = 0 for i in self.index_tracer.node_list[start : end + 1]: @@ -1891,6 +1938,7 @@ def emit_code_with_chunk( chunk_inputs[region_idx], chunk_outputs[region_idx], chunk_outputs_dim[region_idx], + chunk_size=chunk_search[region_idx]["chunk_size"] ) ) From a2b4755ce96e2e8dea100bafd7790e22426aa548 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Tue, 27 Dec 2022 14:49:52 +0800 Subject: [PATCH 062/209] code style --- chunk_codegen.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 470768855779..3cd10350eaba 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -988,7 +988,9 @@ def flow_search(self, start_idx, start_dim, end_idx, end_dim): def _reassgin_reshape_size(self, chunk_info): chunk_region = chunk_info["region"] reshape_size = {} - chunk_shape = _get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]] + chunk_shape = _get_node_shape(chunk_info["outputs"][0])[ + chunk_info["outputs_dim"] + ] for node in self.node_list[chunk_region[0] : chunk_region[1] + 1]: if any(i in node.name for i in ["reshape", "view"]): reshape_args = node.args[1:] @@ -999,7 +1001,9 @@ def _reassgin_reshape_size(self, chunk_info): if reshape_arg_dim in reshape_log["dim_to"]: continue if reshape_arg_dim == chunk_dim: - reshape_size[node.name][reshape_arg.name] = "min(chunk_size, %d - chunk_idx)" % chunk_shape + reshape_size[node.name][reshape_arg.name] = ( + "min(chunk_size, %d - chunk_idx)" % chunk_shape + ) chunk_info["reshape_size"] = reshape_size return chunk_info @@ -1498,7 +1502,7 @@ def _chunk_size_binary_search(self, l, r, 
chunk_info, chunk_infos): else: gap = 1 while r >= l + gap: - mid = int(l + (r - l)/2) + mid = int(l + (r - l) / 2) chunk_info["chunk_size"] = mid cur_chunk_infos = chunk_infos + [chunk_info] cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( @@ -1938,7 +1942,7 @@ def emit_code_with_chunk( chunk_inputs[region_idx], chunk_outputs[region_idx], chunk_outputs_dim[region_idx], - chunk_size=chunk_search[region_idx]["chunk_size"] + chunk_search[region_idx]["chunk_size"], ) ) From cb2dd1a10614c21ca78e1c0cea2f6f7aa882e712 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Tue, 27 Dec 2022 15:01:58 +0800 Subject: [PATCH 063/209] turn off print mem --- chunk_codegen.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 3cd10350eaba..6caed88d84d2 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -1750,10 +1750,13 @@ def search_region(self): _, active_node, ) = self.memory_estimator.estimate_chunk_inference_mem( - self.index_tracer.node_list, chunk_infos, print_mem=True + self.index_tracer.node_list, chunk_infos ) if self._stop_search(init_mem_peak, mem_peak): break + # self.memory_estimator.estimate_chunk_inference_mem( + # self.index_tracer.node_list, chunk_infos, print_mem=True + # ) return chunk_infos From 69af93107f09db3fb90116144296ebc20adc7b52 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 29 Dec 2022 11:28:25 +0800 Subject: [PATCH 064/209] add evoformer openfold init --- evoformer_openfold/evoformer.py | 59 +++++++++ evoformer_openfold/initializer.py | 29 +++++ evoformer_openfold/kernel.py | 19 +++ evoformer_openfold/msa.py | 95 +++++++++++++++ evoformer_openfold/ops.py | 176 +++++++++++++++++++++++++++ evoformer_openfold/triangle.py | 192 ++++++++++++++++++++++++++++++ 6 files changed, 570 insertions(+) create mode 100644 evoformer_openfold/evoformer.py create mode 100755 evoformer_openfold/initializer.py create mode 100644 evoformer_openfold/kernel.py create mode 100644 
evoformer_openfold/msa.py create mode 100755 evoformer_openfold/ops.py create mode 100644 evoformer_openfold/triangle.py diff --git a/evoformer_openfold/evoformer.py b/evoformer_openfold/evoformer.py new file mode 100644 index 000000000000..cfd2bb2a2529 --- /dev/null +++ b/evoformer_openfold/evoformer.py @@ -0,0 +1,59 @@ +import torch +import torch.nn as nn + +from .msa import MSAStack +from .ops import OutProductMean +from .triangle import PairStack + + +def print_memory(init_mem, text=None): + now_mem = torch.cuda.memory_allocated() / 1024 ** 2 - init_mem + max_mem = torch.cuda.max_memory_allocated() / 1024 ** 2 - init_mem + print("%s now:%.2f max:%.2f" % ("" if text is None else text, now_mem, max_mem)) + torch.cuda.reset_peak_memory_stats() + + +class EvoformerBlock(nn.Module): + + def __init__(self, d_node, d_pair): + super(EvoformerBlock, self).__init__() + + self.msa_stack = MSAStack(d_node, d_pair, p_drop=0.15) + self.communication = OutProductMean(n_feat=d_node, n_feat_out=d_pair, n_feat_proj=32) + self.pair_stack = PairStack(d_pair=d_pair) + + def forward(self, node, pair): + node = self.msa_stack(node, pair) + pair = pair + self.communication(node) + pair = self.pair_stack(pair) + return node, pair + + +class Evoformer(nn.Module): + + def __init__(self, d_node, d_pair): + super(Evoformer, self).__init__() + + self.blocks = nn.ModuleList() + for _ in range(1): + self.blocks.append(EvoformerBlock(d_node, d_pair)) + + def forward(self, node, pair): + for b in self.blocks: + node, pair = b(node, pair) + return node, pair + + +def evoformer_tiny(): + return Evoformer(d_node=64, d_pair=32) + + +def evoformer_base(): + return Evoformer(d_node=256, d_pair=128) + + +def evoformer_large(): + return Evoformer(d_node=512, d_pair=256) + + +__all__ = ['Evoformer', 'evoformer_base', 'evoformer_large'] diff --git a/evoformer_openfold/initializer.py b/evoformer_openfold/initializer.py new file mode 100755 index 000000000000..c6ce0659e597 --- /dev/null +++ 
b/evoformer_openfold/initializer.py @@ -0,0 +1,29 @@ +import math + +import numpy as np +import torch.nn as nn + + +def glorot_uniform_af(x, gain=1.0): + """ + initialize tensors the same as xavier_initializer in PyTorch, but the dimensions are different: + In PyTorch: + [feature_out, feature_in, n_head ...] + In Jax: + [... n_head, feature_in, feature_out] + However, there is a feature in original Alphafold2 code that they use the Jax version initializer to initialize tensors like: + [feature_in, n_head, feature_out] + + In this function, we keep this feature to initialize [feature_in, n_head, ..., feature_out] tensors + """ + fan_in, fan_out = x.shape[-2:] + if len(x.shape) > 2: + receptive_field_size = np.prod(x.shape[:-2]) + fan_in *= receptive_field_size + fan_out *= receptive_field_size + std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) + dev = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation + + nn.init.uniform_(x, -dev, dev) + + return x diff --git a/evoformer_openfold/kernel.py b/evoformer_openfold/kernel.py new file mode 100644 index 000000000000..26ab5dc53261 --- /dev/null +++ b/evoformer_openfold/kernel.py @@ -0,0 +1,19 @@ +import torch +import torch.nn.functional as F + + +def bias_sigmod_ele(y, bias, z): + return torch.sigmoid(y + bias) * z + + +def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor, + residual: torch.Tensor, prob: float) -> torch.Tensor: + out = (x + bias) * F.dropout(dropmask, p=prob, training=False) + out = residual + out + return out + + +def bias_ele_dropout_residual(ab: torch.Tensor, b: torch.Tensor, g: torch.Tensor, + dropout_mask: torch.Tensor, Z_raw: torch.Tensor, + prob: float) -> torch.Tensor: + return Z_raw + F.dropout(dropout_mask, p=prob, training=True) * (g * (ab + b)) \ No newline at end of file diff --git a/evoformer_openfold/msa.py b/evoformer_openfold/msa.py new file mode 100644 index 000000000000..cac456638a55 --- /dev/null +++ b/evoformer_openfold/msa.py @@ 
-0,0 +1,95 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn import LayerNorm + +from .kernel import bias_dropout_add +from .ops import SelfAttention, Transition + + +class MSARowAttentionWithPairBias(nn.Module): + + def __init__(self, d_node, d_pair, c=32, n_head=8, p_drop=0.15): + super(MSARowAttentionWithPairBias, self).__init__() + self.d_node = d_node + self.d_pair = d_pair + self.c = c + self.n_head = n_head + self.p_drop = p_drop + + self.layernormM = LayerNorm(d_node) + self.layernormZ = LayerNorm(d_pair) + + _init_weights = torch.nn.init.normal_(torch.zeros([n_head, d_pair]), + std=1.0 / math.sqrt(d_pair)) + self.linear_b_weights = nn.parameter.Parameter(data=_init_weights, requires_grad=True) + + self.attention = SelfAttention(qkv_dim=d_node, + c=c, + n_head=n_head, + out_dim=d_node, + gating=True, + last_bias_fuse=True) + + self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_node,)), requires_grad=True) + + def forward(self, M_raw, Z): + ## Input projections + M = self.layernormM(M_raw) + Z = self.layernormZ(Z) + b = F.linear(Z, self.linear_b_weights) + b = b.permute(0, 3, 1, 2) + # b = rearrange(b, 'b q k h -> b h q k') + + M = self.attention(M, b) + dropout_mask = torch.ones_like(M[:, 0:1, :, :]).to(M.device).to(M.dtype) + + return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop) + + +class MSAColumnAttention(nn.Module): + + def __init__(self, d_node, c=32, n_head=8): + super(MSAColumnAttention, self).__init__() + self.d_node = d_node + self.c = c + self.n_head = n_head + + self.layernormM = LayerNorm(d_node) + self.attention = SelfAttention(qkv_dim=d_node, + c=c, + n_head=n_head, + out_dim=d_node, + gating=True) + + def forward(self, M_raw): + M = M_raw.transpose(-2, -3) + M = self.layernormM(M) + + M = self.attention(M) + + M = M.transpose(-2, -3) + return M_raw + M + + +class MSAStack(nn.Module): + + def __init__(self, d_node, d_pair, 
p_drop=0.15): + super(MSAStack, self).__init__() + + self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(d_node=d_node, + d_pair=d_pair, + p_drop=p_drop) + + self.MSAColumnAttention = MSAColumnAttention(d_node=d_node) + self.MSATransition = Transition(d=d_node) + + def forward(self, node, pair): + node = self.MSARowAttentionWithPairBias(node, pair) + node = self.MSAColumnAttention(node) + node = self.MSATransition(node) + + return node diff --git a/evoformer_openfold/ops.py b/evoformer_openfold/ops.py new file mode 100755 index 000000000000..611b7b0fe777 --- /dev/null +++ b/evoformer_openfold/ops.py @@ -0,0 +1,176 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn import LayerNorm + +from .initializer import glorot_uniform_af +from .kernel import bias_sigmod_ele + + +class DropoutRowwise(nn.Module): + + def __init__(self, p): + super(DropoutRowwise, self).__init__() + self.p = p + self.dropout = nn.Dropout(p=p) + + def forward(self, x): + dropout_mask = torch.ones_like(x[:, 0:1, :, :]) + dropout_mask = self.dropout(dropout_mask) + return dropout_mask * x + + +class DropoutColumnwise(nn.Module): + + def __init__(self, p): + super(DropoutColumnwise, self).__init__() + self.p = p + self.dropout = nn.Dropout(p=p) + + def forward(self, x): + dropout_mask = torch.ones_like(x[:, :, 0:1, :]) + dropout_mask = self.dropout(dropout_mask) + return dropout_mask * x + + +class Transition(nn.Module): + + def __init__(self, d, n=4): + super(Transition, self).__init__() + self.norm = LayerNorm(d) + self.linear1 = Linear(d, n * d, initializer='relu') + self.linear2 = Linear(n * d, d, initializer='zeros') + + def forward(self, src): + x = self.norm(src) + x = self.linear2(F.relu(self.linear1(x))) + return src + x + + +class OutProductMean(nn.Module): + + def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32): + super(OutProductMean, self).__init__() + + self.layernormM = LayerNorm(n_feat) + 
self.linear_a = Linear(n_feat, n_feat_proj) + self.linear_b = Linear(n_feat, n_feat_proj) + + self.o_linear = Linear(n_feat_proj * n_feat_proj, + n_feat_out, + initializer='zero', + use_bias=True) + + def forward(self, M): + M = self.layernormM(M) + left_act = self.linear_a(M) + right_act = self.linear_b(M) + + O = torch.einsum('bsid,bsje->bijde', left_act, right_act).contiguous() + # O = rearrange(O, 'b i j d e -> b i j (d e)') + O = O.reshape(O.shape[0], O.shape[1], O.shape[2], -1) + Z = self.o_linear(O) + + return Z + + +class Linear(nn.Linear): + """ + A Linear layer with built-in nonstandard initializations. Called just + like torch.nn.Linear. + Implements the initializers in 1.11.4, plus some additional ones found + in the code. + """ + + def __init__( + self, + feature_in: int, + feature_out: int, + initializer: str = 'linear', + use_bias: bool = True, + bias_init: float = 0., + ): + super(Linear, self).__init__(feature_in, feature_out, bias=use_bias) + + self.use_bias = use_bias + if initializer == 'linear': + glorot_uniform_af(self.weight, gain=1.0) + elif initializer == 'relu': + glorot_uniform_af(self.weight, gain=2.0) + elif initializer == 'zeros': + nn.init.zeros_(self.weight) + if self.use_bias: + with torch.no_grad(): + self.bias.fill_(bias_init) + + +class SelfAttention(nn.Module): + """ + Multi-Head SelfAttention dealing with [batch_size1, batch_size2, len, dim] tensors + """ + + def __init__(self, qkv_dim, c, n_head, out_dim, gating=True, last_bias_fuse=False): + super(SelfAttention, self).__init__() + self.qkv_dim = qkv_dim + self.c = c + self.n_head = n_head + self.out_dim = out_dim + self.gating = gating + self.last_bias_fuse = last_bias_fuse + + self.scaling = self.c**(-0.5) + + # self.to_qkv = Linear(qkv_dim, 3 * n_head * c, initializer='linear') + self.to_q = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False) + self.to_k = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False) + self.to_v = Linear(qkv_dim, n_head * 
c, initializer='linear', use_bias=False) + + if gating: + self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,))) + self.gating_linear = Linear(qkv_dim, n_head * c, initializer='zero', use_bias=False) + + self.o_linear = Linear(n_head * c, + out_dim, + initializer='zero', + use_bias=(not last_bias_fuse)) + + def forward(self, in_data, nonbatched_bias=None): + """ + :param in_data: [batch_size1, batch_size2, len_qkv, qkv_dim] + :param bias: None or [batch_size1, batch_size2, n_head, len_q, len_kv] + :param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv] + """ + + # qkv = self.to_qkv(in_data).chunk(3, dim=-1) + # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv) + + q = self.to_q(in_data) + k = self.to_k(in_data) + v = self.to_v(in_data) + + # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), + # [q, k, v]) + q, k, v = map(lambda t: t.view(t.shape[0], t.shape[1], t.shape[2], self.n_head, -1).permute(0, 1, 3, 2, 4), + [q, k, v]) + + q = q * self.scaling + + logits = torch.matmul(q, k.transpose(-1, -2)) + + if nonbatched_bias is not None: + logits += nonbatched_bias.unsqueeze(1) + weights = torch.softmax(logits, dim=-1) + # weights = softmax(logits) + + weighted_avg = torch.matmul(weights, v) + # weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)') + weighted_avg = weighted_avg.permute(0, 1, 3, 2, 4) + weighted_avg = weighted_avg.reshape(weighted_avg.shape[0], weighted_avg.shape[1], weighted_avg.shape[2], -1) + + if self.gating: + gate_values = self.gating_linear(in_data) + weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg) + + output = self.o_linear(weighted_avg) + return output diff --git a/evoformer_openfold/triangle.py b/evoformer_openfold/triangle.py new file mode 100644 index 000000000000..f479469c3836 --- /dev/null +++ b/evoformer_openfold/triangle.py @@ -0,0 +1,192 @@ +import math + +import torch +import torch.nn 
as nn +from torch.nn import LayerNorm + +from .kernel import bias_dropout_add, bias_ele_dropout_residual +from .ops import Linear, SelfAttention, Transition + + +def permute_final_dims(tensor, inds): + zero_index = -1 * len(inds) + first_inds = list(range(len(tensor.shape[:zero_index]))) + return tensor.permute(first_inds + [zero_index + i for i in inds]) + + +class TriangleMultiplicationOutgoing(nn.Module): + + def __init__(self, d_pair, p_drop, c=128): + super(TriangleMultiplicationOutgoing, self).__init__() + self.d_pair = d_pair + self.c = c + + self.layernorm1 = LayerNorm(d_pair) + self.left_projection = Linear(d_pair, c) + self.right_projection = Linear(d_pair, c) + self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.) + self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.) + + self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.) + self.layernorm2 = LayerNorm(c) + self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False) + self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True) + + self.p_drop = p_drop + + def forward(self, Z_raw): + Z = self.layernorm1(Z_raw) + left_proj_act = self.left_projection(Z) + right_proj_act = self.right_projection(Z) + + left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z)) + right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z)) + + g = torch.sigmoid(self.output_gate(Z)) + # p = torch.matmul( + # permute_final_dims(left_proj_act, (2, 0, 1)), + # permute_final_dims(right_proj_act, (2, 1, 0)), + # ) + # ab = permute_final_dims(p, (1, 2, 0)) + + ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act) + ab = self.output_projection(self.layernorm2(ab)) + dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype) + return bias_ele_dropout_residual(ab, + self.output_bias, + g, + dropout_mask, + Z_raw, + prob=self.p_drop) + + +class 
TriangleMultiplicationIncoming(nn.Module): + + def __init__(self, d_pair, p_drop, c=128): + super(TriangleMultiplicationIncoming, self).__init__() + self.d_pair = d_pair + self.c = c + + self.layernorm1 = LayerNorm(d_pair) + self.left_projection = Linear(d_pair, c) + self.right_projection = Linear(d_pair, c) + self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.) + self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.) + + self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.) + self.layernorm2 = LayerNorm(c) + self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False) + self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True) + + self.p_drop = p_drop + + def forward(self, Z_raw): + Z = self.layernorm1(Z_raw) + left_proj_act = self.left_projection(Z) + right_proj_act = self.right_projection(Z) + + left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z)) + right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z)) + + g = torch.sigmoid(self.output_gate(Z)) + # p = torch.matmul( + # permute_final_dims(left_proj_act, (2, 1, 0)), + # permute_final_dims(right_proj_act, (2, 0, 1)), + # ) + # ab = permute_final_dims(p, (1, 2, 0)) + + ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act) + ab = self.output_projection(self.layernorm2(ab)) + dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype) + return bias_ele_dropout_residual(ab, + self.output_bias, + g, + dropout_mask, + Z_raw, + prob=self.p_drop) + + +class TriangleAttentionStartingNode(nn.Module): + + def __init__(self, d_pair, p_drop, c=32, n_head=4): + super(TriangleAttentionStartingNode, self).__init__() + self.d_pair = d_pair + self.c = c + self.n_head = n_head + self.p_drop = p_drop + + self.layernorm1 = LayerNorm(d_pair) + _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]), + std=1.0 / math.sqrt(d_pair)) + self.linear_b_weights = 
nn.parameter.Parameter(data=_init_weights) + self.attention = SelfAttention(qkv_dim=d_pair, + c=c, + n_head=n_head, + out_dim=d_pair, + gating=True, + last_bias_fuse=True) + + self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True) + + def forward(self, Z_raw): + Z = self.layernorm1(Z_raw) + b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights) + + Z = self.attention(Z, b) + + dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype) + return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop) + + +class TriangleAttentionEndingNode(nn.Module): + + def __init__(self, d_pair, p_drop, c=32, n_head=4): + super(TriangleAttentionEndingNode, self).__init__() + self.d_pair = d_pair + self.c = c + self.n_head = n_head + self.p_drop = p_drop + + self.layernorm1 = LayerNorm(d_pair) + _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]), + std=1.0 / math.sqrt(d_pair)) + self.linear_b_weights = nn.parameter.Parameter(data=_init_weights) + self.attention = SelfAttention(qkv_dim=d_pair, + c=c, + n_head=n_head, + out_dim=d_pair, + gating=True, + last_bias_fuse=True) + + self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True) + + def forward(self, Z_raw): + Z = Z_raw.transpose(-2, -3) + Z = self.layernorm1(Z) + b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights) + + Z = self.attention(Z, b) + + Z = Z.transpose(-2, -3) + dropout_mask = torch.ones_like(Z[:, :, 0:1, :]).to(Z.device).to(Z.dtype) + return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop) + + +class PairStack(nn.Module): + + def __init__(self, d_pair, p_drop=0.25): + super(PairStack, self).__init__() + + self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(d_pair, p_drop=p_drop) + self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(d_pair, p_drop=p_drop) + self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(d_pair, 
p_drop=p_drop) + self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(d_pair, p_drop=p_drop) + self.PairTransition = Transition(d=d_pair) + + def forward(self, pair): + pair = self.TriangleMultiplicationOutgoing(pair) + pair = self.TriangleMultiplicationIncoming(pair) + pair = self.TriangleAttentionStartingNode(pair) + pair = self.TriangleAttentionEndingNode(pair) + pair = self.PairTransition(pair) + return pair From fff493c2021a55754d574cc1457cb4c695e30354 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 29 Dec 2022 11:48:11 +0800 Subject: [PATCH 065/209] init openfold --- evoformer_openfold/evoformer.py | 59 -- evoformer_openfold/initializer.py | 29 - evoformer_openfold/kernel.py | 19 - evoformer_openfold/msa.py | 95 --- evoformer_openfold/ops.py | 176 ----- evoformer_openfold/triangle.py | 192 ------ openfold/checkpointing.py | 84 +++ openfold/dropout.py | 78 +++ openfold/evoformer.py | 636 +++++++++++++++++++ openfold/msa.py | 392 ++++++++++++ openfold/outer_product_mean.py | 129 ++++ openfold/pair_transition.py | 99 +++ openfold/primitives.py | 529 +++++++++++++++ openfold/tensor_utils.py | 408 ++++++++++++ openfold/triangular_attention.py | 139 ++++ openfold/triangular_multiplicative_update.py | 127 ++++ 16 files changed, 2621 insertions(+), 570 deletions(-) delete mode 100644 evoformer_openfold/evoformer.py delete mode 100755 evoformer_openfold/initializer.py delete mode 100644 evoformer_openfold/kernel.py delete mode 100644 evoformer_openfold/msa.py delete mode 100755 evoformer_openfold/ops.py delete mode 100644 evoformer_openfold/triangle.py create mode 100644 openfold/checkpointing.py create mode 100644 openfold/dropout.py create mode 100644 openfold/evoformer.py create mode 100644 openfold/msa.py create mode 100644 openfold/outer_product_mean.py create mode 100644 openfold/pair_transition.py create mode 100644 openfold/primitives.py create mode 100644 openfold/tensor_utils.py create mode 100644 openfold/triangular_attention.py create mode 
100644 openfold/triangular_multiplicative_update.py diff --git a/evoformer_openfold/evoformer.py b/evoformer_openfold/evoformer.py deleted file mode 100644 index cfd2bb2a2529..000000000000 --- a/evoformer_openfold/evoformer.py +++ /dev/null @@ -1,59 +0,0 @@ -import torch -import torch.nn as nn - -from .msa import MSAStack -from .ops import OutProductMean -from .triangle import PairStack - - -def print_memory(init_mem, text=None): - now_mem = torch.cuda.memory_allocated() / 1024 ** 2 - init_mem - max_mem = torch.cuda.max_memory_allocated() / 1024 ** 2 - init_mem - print("%s now:%.2f max:%.2f" % ("" if text is None else text, now_mem, max_mem)) - torch.cuda.reset_peak_memory_stats() - - -class EvoformerBlock(nn.Module): - - def __init__(self, d_node, d_pair): - super(EvoformerBlock, self).__init__() - - self.msa_stack = MSAStack(d_node, d_pair, p_drop=0.15) - self.communication = OutProductMean(n_feat=d_node, n_feat_out=d_pair, n_feat_proj=32) - self.pair_stack = PairStack(d_pair=d_pair) - - def forward(self, node, pair): - node = self.msa_stack(node, pair) - pair = pair + self.communication(node) - pair = self.pair_stack(pair) - return node, pair - - -class Evoformer(nn.Module): - - def __init__(self, d_node, d_pair): - super(Evoformer, self).__init__() - - self.blocks = nn.ModuleList() - for _ in range(1): - self.blocks.append(EvoformerBlock(d_node, d_pair)) - - def forward(self, node, pair): - for b in self.blocks: - node, pair = b(node, pair) - return node, pair - - -def evoformer_tiny(): - return Evoformer(d_node=64, d_pair=32) - - -def evoformer_base(): - return Evoformer(d_node=256, d_pair=128) - - -def evoformer_large(): - return Evoformer(d_node=512, d_pair=256) - - -__all__ = ['Evoformer', 'evoformer_base', 'evoformer_large'] diff --git a/evoformer_openfold/initializer.py b/evoformer_openfold/initializer.py deleted file mode 100755 index c6ce0659e597..000000000000 --- a/evoformer_openfold/initializer.py +++ /dev/null @@ -1,29 +0,0 @@ -import math - -import 
numpy as np -import torch.nn as nn - - -def glorot_uniform_af(x, gain=1.0): - """ - initialize tensors the same as xavier_initializer in PyTorch, but the dimensions are different: - In PyTorch: - [feature_out, feature_in, n_head ...] - In Jax: - [... n_head, feature_in, feature_out] - However, there is a feature in original Alphafold2 code that they use the Jax version initializer to initialize tensors like: - [feature_in, n_head, feature_out] - - In this function, we keep this feature to initialize [feature_in, n_head, ..., feature_out] tensors - """ - fan_in, fan_out = x.shape[-2:] - if len(x.shape) > 2: - receptive_field_size = np.prod(x.shape[:-2]) - fan_in *= receptive_field_size - fan_out *= receptive_field_size - std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) - dev = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation - - nn.init.uniform_(x, -dev, dev) - - return x diff --git a/evoformer_openfold/kernel.py b/evoformer_openfold/kernel.py deleted file mode 100644 index 26ab5dc53261..000000000000 --- a/evoformer_openfold/kernel.py +++ /dev/null @@ -1,19 +0,0 @@ -import torch -import torch.nn.functional as F - - -def bias_sigmod_ele(y, bias, z): - return torch.sigmoid(y + bias) * z - - -def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor, - residual: torch.Tensor, prob: float) -> torch.Tensor: - out = (x + bias) * F.dropout(dropmask, p=prob, training=False) - out = residual + out - return out - - -def bias_ele_dropout_residual(ab: torch.Tensor, b: torch.Tensor, g: torch.Tensor, - dropout_mask: torch.Tensor, Z_raw: torch.Tensor, - prob: float) -> torch.Tensor: - return Z_raw + F.dropout(dropout_mask, p=prob, training=True) * (g * (ab + b)) \ No newline at end of file diff --git a/evoformer_openfold/msa.py b/evoformer_openfold/msa.py deleted file mode 100644 index cac456638a55..000000000000 --- a/evoformer_openfold/msa.py +++ /dev/null @@ -1,95 +0,0 @@ -import math - -import torch -import torch.nn as nn 
-import torch.nn.functional as F -from einops import rearrange -from torch.nn import LayerNorm - -from .kernel import bias_dropout_add -from .ops import SelfAttention, Transition - - -class MSARowAttentionWithPairBias(nn.Module): - - def __init__(self, d_node, d_pair, c=32, n_head=8, p_drop=0.15): - super(MSARowAttentionWithPairBias, self).__init__() - self.d_node = d_node - self.d_pair = d_pair - self.c = c - self.n_head = n_head - self.p_drop = p_drop - - self.layernormM = LayerNorm(d_node) - self.layernormZ = LayerNorm(d_pair) - - _init_weights = torch.nn.init.normal_(torch.zeros([n_head, d_pair]), - std=1.0 / math.sqrt(d_pair)) - self.linear_b_weights = nn.parameter.Parameter(data=_init_weights, requires_grad=True) - - self.attention = SelfAttention(qkv_dim=d_node, - c=c, - n_head=n_head, - out_dim=d_node, - gating=True, - last_bias_fuse=True) - - self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_node,)), requires_grad=True) - - def forward(self, M_raw, Z): - ## Input projections - M = self.layernormM(M_raw) - Z = self.layernormZ(Z) - b = F.linear(Z, self.linear_b_weights) - b = b.permute(0, 3, 1, 2) - # b = rearrange(b, 'b q k h -> b h q k') - - M = self.attention(M, b) - dropout_mask = torch.ones_like(M[:, 0:1, :, :]).to(M.device).to(M.dtype) - - return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop) - - -class MSAColumnAttention(nn.Module): - - def __init__(self, d_node, c=32, n_head=8): - super(MSAColumnAttention, self).__init__() - self.d_node = d_node - self.c = c - self.n_head = n_head - - self.layernormM = LayerNorm(d_node) - self.attention = SelfAttention(qkv_dim=d_node, - c=c, - n_head=n_head, - out_dim=d_node, - gating=True) - - def forward(self, M_raw): - M = M_raw.transpose(-2, -3) - M = self.layernormM(M) - - M = self.attention(M) - - M = M.transpose(-2, -3) - return M_raw + M - - -class MSAStack(nn.Module): - - def __init__(self, d_node, d_pair, p_drop=0.15): - super(MSAStack, self).__init__() - - 
self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(d_node=d_node, - d_pair=d_pair, - p_drop=p_drop) - - self.MSAColumnAttention = MSAColumnAttention(d_node=d_node) - self.MSATransition = Transition(d=d_node) - - def forward(self, node, pair): - node = self.MSARowAttentionWithPairBias(node, pair) - node = self.MSAColumnAttention(node) - node = self.MSATransition(node) - - return node diff --git a/evoformer_openfold/ops.py b/evoformer_openfold/ops.py deleted file mode 100755 index 611b7b0fe777..000000000000 --- a/evoformer_openfold/ops.py +++ /dev/null @@ -1,176 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from torch.nn import LayerNorm - -from .initializer import glorot_uniform_af -from .kernel import bias_sigmod_ele - - -class DropoutRowwise(nn.Module): - - def __init__(self, p): - super(DropoutRowwise, self).__init__() - self.p = p - self.dropout = nn.Dropout(p=p) - - def forward(self, x): - dropout_mask = torch.ones_like(x[:, 0:1, :, :]) - dropout_mask = self.dropout(dropout_mask) - return dropout_mask * x - - -class DropoutColumnwise(nn.Module): - - def __init__(self, p): - super(DropoutColumnwise, self).__init__() - self.p = p - self.dropout = nn.Dropout(p=p) - - def forward(self, x): - dropout_mask = torch.ones_like(x[:, :, 0:1, :]) - dropout_mask = self.dropout(dropout_mask) - return dropout_mask * x - - -class Transition(nn.Module): - - def __init__(self, d, n=4): - super(Transition, self).__init__() - self.norm = LayerNorm(d) - self.linear1 = Linear(d, n * d, initializer='relu') - self.linear2 = Linear(n * d, d, initializer='zeros') - - def forward(self, src): - x = self.norm(src) - x = self.linear2(F.relu(self.linear1(x))) - return src + x - - -class OutProductMean(nn.Module): - - def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32): - super(OutProductMean, self).__init__() - - self.layernormM = LayerNorm(n_feat) - self.linear_a = Linear(n_feat, n_feat_proj) - 
self.linear_b = Linear(n_feat, n_feat_proj) - - self.o_linear = Linear(n_feat_proj * n_feat_proj, - n_feat_out, - initializer='zero', - use_bias=True) - - def forward(self, M): - M = self.layernormM(M) - left_act = self.linear_a(M) - right_act = self.linear_b(M) - - O = torch.einsum('bsid,bsje->bijde', left_act, right_act).contiguous() - # O = rearrange(O, 'b i j d e -> b i j (d e)') - O = O.reshape(O.shape[0], O.shape[1], O.shape[2], -1) - Z = self.o_linear(O) - - return Z - - -class Linear(nn.Linear): - """ - A Linear layer with built-in nonstandard initializations. Called just - like torch.nn.Linear. - Implements the initializers in 1.11.4, plus some additional ones found - in the code. - """ - - def __init__( - self, - feature_in: int, - feature_out: int, - initializer: str = 'linear', - use_bias: bool = True, - bias_init: float = 0., - ): - super(Linear, self).__init__(feature_in, feature_out, bias=use_bias) - - self.use_bias = use_bias - if initializer == 'linear': - glorot_uniform_af(self.weight, gain=1.0) - elif initializer == 'relu': - glorot_uniform_af(self.weight, gain=2.0) - elif initializer == 'zeros': - nn.init.zeros_(self.weight) - if self.use_bias: - with torch.no_grad(): - self.bias.fill_(bias_init) - - -class SelfAttention(nn.Module): - """ - Multi-Head SelfAttention dealing with [batch_size1, batch_size2, len, dim] tensors - """ - - def __init__(self, qkv_dim, c, n_head, out_dim, gating=True, last_bias_fuse=False): - super(SelfAttention, self).__init__() - self.qkv_dim = qkv_dim - self.c = c - self.n_head = n_head - self.out_dim = out_dim - self.gating = gating - self.last_bias_fuse = last_bias_fuse - - self.scaling = self.c**(-0.5) - - # self.to_qkv = Linear(qkv_dim, 3 * n_head * c, initializer='linear') - self.to_q = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False) - self.to_k = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False) - self.to_v = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False) - - 
if gating: - self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,))) - self.gating_linear = Linear(qkv_dim, n_head * c, initializer='zero', use_bias=False) - - self.o_linear = Linear(n_head * c, - out_dim, - initializer='zero', - use_bias=(not last_bias_fuse)) - - def forward(self, in_data, nonbatched_bias=None): - """ - :param in_data: [batch_size1, batch_size2, len_qkv, qkv_dim] - :param bias: None or [batch_size1, batch_size2, n_head, len_q, len_kv] - :param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv] - """ - - # qkv = self.to_qkv(in_data).chunk(3, dim=-1) - # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv) - - q = self.to_q(in_data) - k = self.to_k(in_data) - v = self.to_v(in_data) - - # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), - # [q, k, v]) - q, k, v = map(lambda t: t.view(t.shape[0], t.shape[1], t.shape[2], self.n_head, -1).permute(0, 1, 3, 2, 4), - [q, k, v]) - - q = q * self.scaling - - logits = torch.matmul(q, k.transpose(-1, -2)) - - if nonbatched_bias is not None: - logits += nonbatched_bias.unsqueeze(1) - weights = torch.softmax(logits, dim=-1) - # weights = softmax(logits) - - weighted_avg = torch.matmul(weights, v) - # weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)') - weighted_avg = weighted_avg.permute(0, 1, 3, 2, 4) - weighted_avg = weighted_avg.reshape(weighted_avg.shape[0], weighted_avg.shape[1], weighted_avg.shape[2], -1) - - if self.gating: - gate_values = self.gating_linear(in_data) - weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg) - - output = self.o_linear(weighted_avg) - return output diff --git a/evoformer_openfold/triangle.py b/evoformer_openfold/triangle.py deleted file mode 100644 index f479469c3836..000000000000 --- a/evoformer_openfold/triangle.py +++ /dev/null @@ -1,192 +0,0 @@ -import math - -import torch -import torch.nn as nn -from torch.nn import LayerNorm - 
-from .kernel import bias_dropout_add, bias_ele_dropout_residual -from .ops import Linear, SelfAttention, Transition - - -def permute_final_dims(tensor, inds): - zero_index = -1 * len(inds) - first_inds = list(range(len(tensor.shape[:zero_index]))) - return tensor.permute(first_inds + [zero_index + i for i in inds]) - - -class TriangleMultiplicationOutgoing(nn.Module): - - def __init__(self, d_pair, p_drop, c=128): - super(TriangleMultiplicationOutgoing, self).__init__() - self.d_pair = d_pair - self.c = c - - self.layernorm1 = LayerNorm(d_pair) - self.left_projection = Linear(d_pair, c) - self.right_projection = Linear(d_pair, c) - self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.) - self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.) - - self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.) - self.layernorm2 = LayerNorm(c) - self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False) - self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True) - - self.p_drop = p_drop - - def forward(self, Z_raw): - Z = self.layernorm1(Z_raw) - left_proj_act = self.left_projection(Z) - right_proj_act = self.right_projection(Z) - - left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z)) - right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z)) - - g = torch.sigmoid(self.output_gate(Z)) - # p = torch.matmul( - # permute_final_dims(left_proj_act, (2, 0, 1)), - # permute_final_dims(right_proj_act, (2, 1, 0)), - # ) - # ab = permute_final_dims(p, (1, 2, 0)) - - ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act) - ab = self.output_projection(self.layernorm2(ab)) - dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype) - return bias_ele_dropout_residual(ab, - self.output_bias, - g, - dropout_mask, - Z_raw, - prob=self.p_drop) - - -class TriangleMultiplicationIncoming(nn.Module): - - def __init__(self, d_pair, 
p_drop, c=128): - super(TriangleMultiplicationIncoming, self).__init__() - self.d_pair = d_pair - self.c = c - - self.layernorm1 = LayerNorm(d_pair) - self.left_projection = Linear(d_pair, c) - self.right_projection = Linear(d_pair, c) - self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.) - self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.) - - self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.) - self.layernorm2 = LayerNorm(c) - self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False) - self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True) - - self.p_drop = p_drop - - def forward(self, Z_raw): - Z = self.layernorm1(Z_raw) - left_proj_act = self.left_projection(Z) - right_proj_act = self.right_projection(Z) - - left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z)) - right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z)) - - g = torch.sigmoid(self.output_gate(Z)) - # p = torch.matmul( - # permute_final_dims(left_proj_act, (2, 1, 0)), - # permute_final_dims(right_proj_act, (2, 0, 1)), - # ) - # ab = permute_final_dims(p, (1, 2, 0)) - - ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act) - ab = self.output_projection(self.layernorm2(ab)) - dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype) - return bias_ele_dropout_residual(ab, - self.output_bias, - g, - dropout_mask, - Z_raw, - prob=self.p_drop) - - -class TriangleAttentionStartingNode(nn.Module): - - def __init__(self, d_pair, p_drop, c=32, n_head=4): - super(TriangleAttentionStartingNode, self).__init__() - self.d_pair = d_pair - self.c = c - self.n_head = n_head - self.p_drop = p_drop - - self.layernorm1 = LayerNorm(d_pair) - _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]), - std=1.0 / math.sqrt(d_pair)) - self.linear_b_weights = nn.parameter.Parameter(data=_init_weights) - self.attention = 
SelfAttention(qkv_dim=d_pair, - c=c, - n_head=n_head, - out_dim=d_pair, - gating=True, - last_bias_fuse=True) - - self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True) - - def forward(self, Z_raw): - Z = self.layernorm1(Z_raw) - b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights) - - Z = self.attention(Z, b) - - dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype) - return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop) - - -class TriangleAttentionEndingNode(nn.Module): - - def __init__(self, d_pair, p_drop, c=32, n_head=4): - super(TriangleAttentionEndingNode, self).__init__() - self.d_pair = d_pair - self.c = c - self.n_head = n_head - self.p_drop = p_drop - - self.layernorm1 = LayerNorm(d_pair) - _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]), - std=1.0 / math.sqrt(d_pair)) - self.linear_b_weights = nn.parameter.Parameter(data=_init_weights) - self.attention = SelfAttention(qkv_dim=d_pair, - c=c, - n_head=n_head, - out_dim=d_pair, - gating=True, - last_bias_fuse=True) - - self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True) - - def forward(self, Z_raw): - Z = Z_raw.transpose(-2, -3) - Z = self.layernorm1(Z) - b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights) - - Z = self.attention(Z, b) - - Z = Z.transpose(-2, -3) - dropout_mask = torch.ones_like(Z[:, :, 0:1, :]).to(Z.device).to(Z.dtype) - return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop) - - -class PairStack(nn.Module): - - def __init__(self, d_pair, p_drop=0.25): - super(PairStack, self).__init__() - - self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(d_pair, p_drop=p_drop) - self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(d_pair, p_drop=p_drop) - self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(d_pair, p_drop=p_drop) - self.TriangleAttentionEndingNode = 
TriangleAttentionEndingNode(d_pair, p_drop=p_drop) - self.PairTransition = Transition(d=d_pair) - - def forward(self, pair): - pair = self.TriangleMultiplicationOutgoing(pair) - pair = self.TriangleMultiplicationIncoming(pair) - pair = self.TriangleAttentionStartingNode(pair) - pair = self.TriangleAttentionEndingNode(pair) - pair = self.PairTransition(pair) - return pair diff --git a/openfold/checkpointing.py b/openfold/checkpointing.py new file mode 100644 index 000000000000..83e77c638ec1 --- /dev/null +++ b/openfold/checkpointing.py @@ -0,0 +1,84 @@ +# Copyright 2021 AlQuraishi Laboratory +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.utils.checkpoint +from typing import Any, Tuple, List, Callable, Optional + + +BLOCK_ARG = Any +BLOCK_ARGS = List[BLOCK_ARG] + + +def get_checkpoint_fn(): + checkpoint = torch.utils.checkpoint.checkpoint + + return checkpoint + + +@torch.jit.ignore +def checkpoint_blocks( + blocks: List[Callable], + args: BLOCK_ARGS, + blocks_per_ckpt: Optional[int], +) -> BLOCK_ARGS: + """ + Chunk a list of blocks and run each chunk with activation + checkpointing. We define a "block" as a callable whose only inputs are + the outputs of the previous block. + + Implements Subsection 1.11.8 + + Args: + blocks: + List of blocks + args: + Tuple of arguments for the first block. + blocks_per_ckpt: + Size of each chunk. A higher value corresponds to fewer + checkpoints, and trades memory for speed. 
If None, no checkpointing + is performed. + Returns: + The output of the final block + """ + def wrap(a): + return (a,) if type(a) is not tuple else a + + def exec(b, a): + for block in b: + a = wrap(block(*a)) + return a + + def chunker(s, e): + def exec_sliced(*a): + return exec(blocks[s:e], a) + + return exec_sliced + + # Avoids mishaps when the blocks take just one argument + args = wrap(args) + + if blocks_per_ckpt is None: + return exec(blocks, args) + elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks): + raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)") + + checkpoint = get_checkpoint_fn() + + for s in range(0, len(blocks), blocks_per_ckpt): + e = s + blocks_per_ckpt + args = checkpoint(chunker(s, e), *args) + args = wrap(args) + + return args diff --git a/openfold/dropout.py b/openfold/dropout.py new file mode 100644 index 000000000000..651b9775ef44 --- /dev/null +++ b/openfold/dropout.py @@ -0,0 +1,78 @@ +# Copyright 2021 AlQuraishi Laboratory +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn +from functools import partialmethod +from typing import Union, List + + +class Dropout(nn.Module): + """ + Implementation of dropout with the ability to share the dropout mask + along a particular dimension. + + If not in training mode, this module computes the identity function. 
+ """ + + def __init__(self, r: float, batch_dim: Union[int, List[int]]): + """ + Args: + r: + Dropout rate + batch_dim: + Dimension(s) along which the dropout mask is shared + """ + super(Dropout, self).__init__() + + self.r = r + if type(batch_dim) == int: + batch_dim = [batch_dim] + self.batch_dim = batch_dim + self.dropout = nn.Dropout(self.r) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + Tensor to which dropout is applied. Can have any shape + compatible with self.batch_dim + """ + shape = list(x.shape) + if self.batch_dim is not None: + for bd in self.batch_dim: + shape[bd] = 1 + mask = x.new_ones(shape) + mask = self.dropout(mask) + x *= mask + return x + + +class DropoutRowwise(Dropout): + """ + Convenience class for rowwise dropout as described in subsection + 1.11.6. + """ + + __init__ = partialmethod(Dropout.__init__, batch_dim=-3) + + +class DropoutColumnwise(Dropout): + """ + Convenience class for columnwise dropout as described in subsection + 1.11.6. + """ + + __init__ = partialmethod(Dropout.__init__, batch_dim=-2) diff --git a/openfold/evoformer.py b/openfold/evoformer.py new file mode 100644 index 000000000000..21e422b04764 --- /dev/null +++ b/openfold/evoformer.py @@ -0,0 +1,636 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class MSATransition(nn.Module):
    """
    Feed-forward network applied to MSA activations after attention.

    Implements Algorithm 9
    """

    def __init__(self, c_m, n):
        """
        Args:
            c_m:
                MSA channel dimension
            n:
                Factor multiplied to c_m to obtain the hidden channel
                dimension
        """
        super(MSATransition, self).__init__()

        self.c_m = c_m
        self.n = n

        self.layer_norm = LayerNorm(self.c_m)
        self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu")
        self.relu = nn.ReLU()
        self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final")

    def _transition(self, m, mask):
        # Two-layer MLP; the mask zeroes padded positions on the way out.
        hidden = self.relu(self.linear_1(m))
        return self.linear_2(hidden) * mask

    @torch.jit.ignore
    def _chunk(self,
        m: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int,
    ) -> torch.Tensor:
        # Bound peak memory by running _transition over batch-dim chunks.
        return chunk_layer(
            self._transition,
            {"m": m, "mask": mask},
            chunk_size=chunk_size,
            no_batch_dims=len(m.shape[:-2]),
        )

    def forward(
        self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA activation
            mask:
                [*, N_seq, N_res, C_m] MSA mask
            chunk_size:
                Optional chunk size for memory-bounded execution
        Returns:
            m:
                [*, N_seq, N_res, C_m] MSA activation update
        """
        # DISCREPANCY: DeepMind forgets to apply the MSA mask here.
        mask = m.new_ones(m.shape[:-1]) if mask is None else mask

        # [*, N_seq, N_res, 1]
        mask = mask.unsqueeze(-1)

        m = self.layer_norm(m)

        if chunk_size is None:
            return self._transition(m, mask)
        return self._chunk(m, mask, chunk_size)
class EvoformerBlock(nn.Module):
    """
    One block of the Evoformer trunk: row- and column-wise MSA attention
    followed by the shared EvoformerBlockCore (MSA transition, outer
    product mean, and triangular pair updates).
    """

    def __init__(self,
        c_m: int,
        c_z: int,
        c_hidden_msa_att: int,
        c_hidden_opm: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        no_heads_msa: int,
        no_heads_pair: int,
        transition_n: int,
        msa_dropout: float,
        pair_dropout: float,
        inf: float,
        eps: float,
        is_multimer: bool,
    ):
        """
        Args:
            c_m: MSA channel dimension
            c_z: Pair channel dimension
            c_hidden_msa_att: Hidden dimension in MSA attention
            c_hidden_opm: Hidden dimension in outer product mean module
            c_hidden_mul: Hidden dimension in multiplicative updates
            c_hidden_pair_att: Hidden dimension in triangular attention
            no_heads_msa: Number of heads used for MSA attention
            no_heads_pair: Number of heads used for pair attention
            transition_n: Transition hidden-dimension factor
            msa_dropout: Dropout rate for MSA activations
            pair_dropout: Dropout rate for pair activations
            inf: Large number used in computing attention masks
            eps: Small constant forwarded to the core
            is_multimer: Multimer-mode flag (stored; not read in forward)
        """
        super(EvoformerBlock, self).__init__()

        # Row-wise MSA self-attention, biased by the pair embedding.
        self.msa_att_row = MSARowAttentionWithPairBias(
            c_m=c_m,
            c_z=c_z,
            c_hidden=c_hidden_msa_att,
            no_heads=no_heads_msa,
            inf=inf,
        )

        # Column-wise MSA self-attention (no pair bias).
        self.msa_att_col = MSAColumnAttention(
            c_m,
            c_hidden_msa_att,
            no_heads_msa,
            inf=inf,
        )

        self.msa_dropout_layer = DropoutRowwise(msa_dropout)

        # Shared tail of the block: MSA transition, outer product mean,
        # and all pair-stack updates.
        self.core = EvoformerBlockCore(
            c_m=c_m,
            c_z=c_z,
            c_hidden_opm=c_hidden_opm,
            c_hidden_mul=c_hidden_mul,
            c_hidden_pair_att=c_hidden_pair_att,
            no_heads_msa=no_heads_msa,
            no_heads_pair=no_heads_pair,
            transition_n=transition_n,
            pair_dropout=pair_dropout,
            inf=inf,
            eps=eps,
        )

        # NOTE(review): this OuterProductMean is never used in forward()
        # below — the core constructs its own. Looks like dead parameters;
        # confirm (removing it would change state_dict keys).
        self.outer_product_mean = OuterProductMean(
            c_m,
            c_z,
            c_hidden_opm,
        )
        self.is_multimer = is_multimer

    def forward(self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        chunk_size: Optional[int] = None,
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            m: [*, N_seq, N_res, C_m] MSA embedding
            z: [*, N_res, N_res, C_z] pair embedding
            msa_mask: [*, N_seq, N_res] MSA mask
            pair_mask: [*, N_res, N_res] pair mask
            chunk_size: Optional chunk size for memory-bounded attention
            _mask_trans: Whether to mask the transition modules in the core
        Returns:
            Updated (m, z) pair.
        """
        # Residual row attention with row-shared dropout.
        m = m + self.msa_dropout_layer(
            self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
        )
        # Residual column attention (no dropout on this path).
        m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
        # Remaining updates (transitions, OPM, triangular ops) happen in
        # the shared core.
        m, z = self.core(
            m,
            z,
            msa_mask=msa_mask,
            pair_mask=pair_mask,
            chunk_size=chunk_size,
            _mask_trans=_mask_trans,
        )

        return m, z
class EvoformerStack(nn.Module):
    """
    Main Evoformer trunk.

    Implements Algorithm 6.
    """

    def __init__(
        self,
        c_m: int,
        c_z: int,
        c_hidden_msa_att: int,
        c_hidden_opm: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        c_s: int,
        no_heads_msa: int,
        no_heads_pair: int,
        no_blocks: int,
        transition_n: int,
        msa_dropout: float,
        pair_dropout: float,
        blocks_per_ckpt: int,
        inf: float,
        eps: float,
        clear_cache_between_blocks: bool = False,
        is_multimer: bool = False,
        **kwargs,
    ):
        """
        Args:
            c_m:
                MSA channel dimension
            c_z:
                Pair channel dimension
            c_hidden_msa_att:
                Hidden dimension in MSA attention
            c_hidden_opm:
                Hidden dimension in outer product mean module
            c_hidden_mul:
                Hidden dimension in multiplicative updates
            c_hidden_pair_att:
                Hidden dimension in triangular attention
            c_s:
                Channel dimension of the output "single" embedding
            no_heads_msa:
                Number of heads used for MSA attention
            no_heads_pair:
                Number of heads used for pair attention
            no_blocks:
                Number of Evoformer blocks in the stack
            transition_n:
                Factor by which to multiply c_m to obtain the MSATransition
                hidden dimension
            msa_dropout:
                Dropout rate for MSA activations
            pair_dropout:
                Dropout used for pair activations
            blocks_per_ckpt:
                Number of Evoformer blocks in each activation checkpoint
            inf:
                Large number used in computing attention masks
            eps:
                Small constant forwarded to each block
            clear_cache_between_blocks:
                Whether to clear CUDA's GPU memory cache between blocks of the
                stack. Slows down each block but can reduce fragmentation
            is_multimer:
                Multimer-mode flag forwarded to each block
        """
        super(EvoformerStack, self).__init__()

        self.blocks_per_ckpt = blocks_per_ckpt
        self.clear_cache_between_blocks = clear_cache_between_blocks

        self.blocks = nn.ModuleList()

        # Identical hyperparameters per block; weights are independent.
        for _ in range(no_blocks):
            block = EvoformerBlock(
                c_m=c_m,
                c_z=c_z,
                c_hidden_msa_att=c_hidden_msa_att,
                c_hidden_opm=c_hidden_opm,
                c_hidden_mul=c_hidden_mul,
                c_hidden_pair_att=c_hidden_pair_att,
                no_heads_msa=no_heads_msa,
                no_heads_pair=no_heads_pair,
                transition_n=transition_n,
                msa_dropout=msa_dropout,
                pair_dropout=pair_dropout,
                inf=inf,
                eps=eps,
                is_multimer=is_multimer,
            )
            self.blocks.append(block)

        # Projects the final MSA embedding to the "single" embedding.
        self.linear = Linear(c_m, c_s)

    def forward(self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        chunk_size: int,
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            z:
                [*, N_res, N_res, C_z] pair embedding
            msa_mask:
                [*, N_seq, N_res] MSA mask
            pair_mask:
                [*, N_res, N_res] pair mask
            chunk_size:
                Chunk size for memory-bounded attention/transitions
            _mask_trans:
                Whether to mask the transition modules
        Returns:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            z:
                [*, N_res, N_res, C_z] pair embedding
            s:
                [*, N_res, C_s] single embedding (or None if extra MSA stack)
        """
        # Bind everything except (m, z) so checkpoint_blocks only has to
        # thread the two checkpointed tensors through the stack.
        blocks = [
            partial(
                b,
                msa_mask=msa_mask,
                pair_mask=pair_mask,
                chunk_size=chunk_size,
                _mask_trans=_mask_trans,
            )
            for b in self.blocks
        ]

        if(self.clear_cache_between_blocks):
            # Wrap each block so the CUDA allocator cache is dropped right
            # before the block runs (trades speed for less fragmentation).
            def block_with_cache_clear(block, *args):
                torch.cuda.empty_cache()
                return block(*args)

            blocks = [partial(block_with_cache_clear, b) for b in blocks]

        # blocks_per_ckpt=None during eval disables activation
        # checkpointing entirely (checkpoint_blocks runs the blocks plain).
        m, z = checkpoint_blocks(
            blocks,
            args=(m, z),
            blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
        )

        # Single embedding from the first MSA row (presumably the target
        # sequence row — confirm against the embedder).
        s = self.linear(m[..., 0, :, :])

        return m, z, s
+ """ + + def __init__(self, + c_m: int, + c_z: int, + c_hidden_msa_att: int, + c_hidden_opm: int, + c_hidden_mul: int, + c_hidden_pair_att: int, + no_heads_msa: int, + no_heads_pair: int, + no_blocks: int, + transition_n: int, + msa_dropout: float, + pair_dropout: float, + inf: float, + eps: float, + ckpt: bool, + clear_cache_between_blocks: bool = False, + is_multimer: bool = False, + **kwargs, + ): + super(ExtraMSAStack, self).__init__() + + self.clear_cache_between_blocks = clear_cache_between_blocks + self.blocks = nn.ModuleList() + for _ in range(no_blocks): + block = ExtraMSABlock( + c_m=c_m, + c_z=c_z, + c_hidden_msa_att=c_hidden_msa_att, + c_hidden_opm=c_hidden_opm, + c_hidden_mul=c_hidden_mul, + c_hidden_pair_att=c_hidden_pair_att, + no_heads_msa=no_heads_msa, + no_heads_pair=no_heads_pair, + transition_n=transition_n, + msa_dropout=msa_dropout, + pair_dropout=pair_dropout, + inf=inf, + eps=eps, + ckpt=ckpt, + is_multimer=is_multimer, + ) + self.blocks.append(block) + + def forward(self, + m: torch.Tensor, + z: torch.Tensor, + chunk_size: int, + msa_mask: Optional[torch.Tensor] = None, + pair_mask: Optional[torch.Tensor] = None, + _mask_trans: bool = True, + ) -> torch.Tensor: + """ + Args: + m: + [*, N_extra, N_res, C_m] extra MSA embedding + z: + [*, N_res, N_res, C_z] pair embedding + msa_mask: + Optional [*, N_extra, N_res] MSA mask + pair_mask: + Optional [*, N_res, N_res] pair mask + Returns: + [*, N_res, N_res, C_z] pair update + """ + #checkpoint_fn = get_checkpoint_fn() + #blocks = [ + # partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks + #] + + #def dodo(b, *args): + # torch.cuda.empty_cache() + # return b(*args) + + #blocks = [partial(dodo, b) for b in blocks] + + #for b in blocks: + # if(torch.is_grad_enabled()): + # m, z = checkpoint_fn(b, *(m, z)) + # else: + # m, z = b(m, z) + + for b in self.blocks: + m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size) + + 
if(self.clear_cache_between_blocks): + torch.cuda.empty_cache() + + return z \ No newline at end of file diff --git a/openfold/msa.py b/openfold/msa.py new file mode 100644 index 000000000000..172b26def5f1 --- /dev/null +++ b/openfold/msa.py @@ -0,0 +1,392 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import torch +import torch.nn as nn +from typing import Optional, List, Tuple + +from openfold.primitives import ( + Linear, + LayerNorm, + Attention, + GlobalAttention, + _attention_chunked_trainable, +) +from openfold.checkpointing import get_checkpoint_fn +from openfold.tensor_utils import ( + chunk_layer, + permute_final_dims, + flatten_final_dims, +) + + +class MSAAttention(nn.Module): + def __init__( + self, + c_in, + c_hidden, + no_heads, + pair_bias=False, + c_z=None, + inf=1e9, + ): + """ + Args: + c_in: + Input channel dimension + c_hidden: + Per-head hidden channel dimension + no_heads: + Number of attention heads + pair_bias: + Whether to use pair embedding bias + c_z: + Pair embedding channel dimension. 
Ignored unless pair_bias + is true + inf: + A large number to be used in computing the attention mask + """ + super(MSAAttention, self).__init__() + + self.c_in = c_in + self.c_hidden = c_hidden + self.no_heads = no_heads + self.pair_bias = pair_bias + self.c_z = c_z + self.inf = inf + + self.layer_norm_m = LayerNorm(self.c_in) + + self.layer_norm_z = None + self.linear_z = None + if self.pair_bias: + self.layer_norm_z = LayerNorm(self.c_z) + self.linear_z = Linear( + self.c_z, self.no_heads, bias=False, init="normal" + ) + + self.mha = Attention( + self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads + ) + + @torch.jit.ignore + def _chunk(self, + m: torch.Tensor, + biases: List[torch.Tensor], + chunk_size: int, + ) -> torch.Tensor: + return chunk_layer( + self.mha, + {"q_x": m, "kv_x": m, "biases": biases}, + chunk_size=chunk_size, + no_batch_dims=len(m.shape[:-2]), + ) + + def _prep_inputs(self, + m: torch.Tensor, + z: Optional[torch.Tensor], + mask: Optional[torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # [*, N_seq, N_res, C_m] + m = self.layer_norm_m(m) + + n_seq, n_res = m.shape[-3:-1] + if mask is None: + # [*, N_seq, N_res] + mask = m.new_ones( + m.shape[:-3] + (n_seq, n_res), + ) + + # [*, N_seq, 1, 1, N_res] + mask_bias = (self.inf * (mask - 1))[..., :, None, None, :] + + # This step simply returns a larger view of the bias, and does not + # consume additional memory. 
+ # [*, N_seq, no_heads, N_res, N_res] + #bias = bias.expand( + # ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1) + #) + + if (self.pair_bias and + z is not None and # For the + self.layer_norm_z is not None and # benefit of + self.linear_z is not None # TorchScript + ): + # [*, N_res, N_res, C_z] + z = self.layer_norm_z(z) + + # [*, N_res, N_res, no_heads] + z = self.linear_z(z) + + # [*, 1, no_heads, N_res, N_res] + z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4) + + return m, mask_bias, z + + @torch.jit.ignore + def _chunked_msa_attn(self, + m: torch.Tensor, + z: Optional[torch.Tensor], + mask: Optional[torch.Tensor], + chunk_logits: int, + checkpoint: bool, + ) -> torch.Tensor: + MSA_DIM = -4 + + def _get_qkv(m, z): + m, mask_bias, z = self._prep_inputs(m, z, mask) + q, k, v = self.mha._prep_qkv(m, m) + return m, q, k, v, mask_bias, z + + checkpoint_fn = get_checkpoint_fn() + + if(torch.is_grad_enabled() and checkpoint): + m, q, k, v, mask_bias, z = checkpoint_fn(_get_qkv, m, z) + else: + m, q, k, v, mask_bias, z = _get_qkv(m, z) + + o = _attention_chunked_trainable( + query=q, + key=k, + value=v, + biases=[mask_bias, z], + chunk_size=chunk_logits, + chunk_dim=MSA_DIM, + checkpoint=checkpoint, + ) + + if(torch.is_grad_enabled() and checkpoint): + # Storing an additional m here is far from ideal + m = checkpoint_fn(self.mha._wrap_up, o, m) + else: + m = self.mha._wrap_up(o, m) + + return m + + def forward(self, + m: torch.Tensor, + z: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + _chunk_logits: Optional[int] = None, + _checkpoint_chunks: Optional[bool] = None, + ) -> torch.Tensor: + """ + Args: + m: + [*, N_seq, N_res, C_m] MSA embedding + z: + [*, N_res, N_res, C_z] pair embedding. Required only if + pair_bias is True + mask: + [*, N_seq, N_res] MSA mask + chunk_size: + Size of chunks into which the inputs are split along their + batch dimensions. 
class MSAColumnAttention(nn.Module):
    """
    Implements Algorithm 8.

    Column-wise MSA self-attention, realized by transposing the MSA and
    delegating to a plain MSAAttention. By rights, this should also be a
    subclass of MSAAttention; alas, most inheritance isn't supported by
    TorchScript.
    """

    def __init__(self, c_m, c_hidden, no_heads, inf=1e9):
        """
        Args:
            c_m:
                MSA channel dimension
            c_hidden:
                Per-head hidden channel dimension
            no_heads:
                Number of attention heads
            inf:
                Large number used to construct attention masks
        """
        super(MSAColumnAttention, self).__init__()

        self.c_m = c_m
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.inf = inf

        # Delegate that performs ordinary (row) attention on the
        # transposed input; no pair bias for column attention.
        self._msa_att = MSAAttention(
            c_in=c_m,
            c_hidden=c_hidden,
            no_heads=no_heads,
            pair_bias=False,
            c_z=None,
            inf=inf,
        )

    def forward(self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            mask:
                [*, N_seq, N_res] MSA mask
            chunk_size:
                Size of chunks into which the inputs are split along their
                batch dimensions. A low value decreases memory overhead at
                the cost of slower execution. Chunking is not performed by
                default.
        """
        # [*, N_res, N_seq, C_in] — swapping sequence and residue dims
        # turns column attention into row attention on the transpose.
        m_t = m.transpose(-2, -3)
        mask_t = None if mask is None else mask.transpose(-1, -2)

        m_t = self._msa_att(m_t, mask=mask_t, chunk_size=chunk_size)

        # [*, N_seq, N_res, C_in] — restore the original layout.
        return m_t.transpose(-2, -3)
class OuterProductMean(nn.Module):
    """
    Implements Algorithm 10.
    """

    def __init__(self, c_m, c_z, c_hidden, eps=1e-3):
        """
        Args:
            c_m:
                MSA embedding channel dimension
            c_z:
                Pair embedding channel dimension
            c_hidden:
                Hidden channel dimension
            eps:
                Small constant added to the mask normalization to avoid
                division by zero
        """
        super(OuterProductMean, self).__init__()

        self.c_m = c_m
        self.c_z = c_z
        self.c_hidden = c_hidden
        self.eps = eps

        self.layer_norm = nn.LayerNorm(c_m)
        self.linear_1 = Linear(c_m, c_hidden)
        self.linear_2 = Linear(c_m, c_hidden)
        self.linear_out = Linear(c_hidden ** 2, c_z, init="final")

    def _opm(self, a, b):
        # Outer product over the hidden channels of the two projections,
        # summed over the sequence dimension (index "a" is contracted).
        # [*, N_res, N_res, C, C]
        outer = torch.einsum("...bac,...dae->...bdce", a, b)

        # [*, N_res, N_res, C * C]
        outer = outer.reshape(outer.shape[:-2] + (-1,))

        # [*, N_res, N_res, C_z]
        outer = self.linear_out(outer)

        return outer

    @torch.jit.ignore
    def _chunk(self,
        a: torch.Tensor,
        b: torch.Tensor,
        chunk_size: int
    ) -> torch.Tensor:
        # Since the "batch dim" in this case is not a true batch dimension
        # (in that the shape of the output depends on it), we need to
        # iterate over it ourselves
        a_reshape = a.reshape((-1,) + a.shape[-3:])
        b_reshape = b.reshape((-1,) + b.shape[-3:])
        out = []
        for a_prime, b_prime in zip(a_reshape, b_reshape):
            # Chunk only over a's residue dim; b is closed over whole.
            outer = chunk_layer(
                partial(self._opm, b=b_prime),
                {"a": a_prime},
                chunk_size=chunk_size,
                no_batch_dims=1,
            )
            out.append(outer)
        outer = torch.stack(out, dim=0)
        # Restore the original leading batch dims.
        outer = outer.reshape(a.shape[:-3] + outer.shape[1:])

        return outer

    def forward(self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            mask:
                [*, N_seq, N_res] MSA mask
            chunk_size:
                Optional chunk size for memory-bounded execution
        Returns:
            [*, N_res, N_res, C_z] pair embedding update
        """
        if mask is None:
            mask = m.new_ones(m.shape[:-1])

        # [*, N_seq, N_res, C_m]
        m = self.layer_norm(m)

        # [*, N_seq, N_res, C]
        mask = mask.unsqueeze(-1)
        a = self.linear_1(m) * mask
        b = self.linear_2(m) * mask

        # Put the residue dim first so _opm contracts over sequences.
        a = a.transpose(-2, -3)
        b = b.transpose(-2, -3)

        if chunk_size is not None:
            outer = self._chunk(a, b, chunk_size)
        else:
            outer = self._opm(a, b)

        # [*, N_res, N_res, 1] — pairwise count of valid sequence entries,
        # turning the masked sum into a masked mean.
        norm = torch.einsum("...abc,...adc->...bdc", mask, mask)

        # [*, N_res, N_res, C_z]
        outer = outer / (self.eps + norm)

        return outer
+ """ + + def __init__(self, c_z, n): + """ + Args: + c_z: + Pair transition channel dimension + n: + Factor by which c_z is multiplied to obtain hidden channel + dimension + """ + super(PairTransition, self).__init__() + + self.c_z = c_z + self.n = n + + self.layer_norm = LayerNorm(self.c_z) + self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu") + self.relu = nn.ReLU() + self.linear_2 = Linear(self.n * self.c_z, c_z, init="final") + + def _transition(self, z, mask): + # [*, N_res, N_res, C_hidden] + z = self.linear_1(z) + z = self.relu(z) + + # [*, N_res, N_res, C_z] + z = self.linear_2(z) * mask + + return z + + @torch.jit.ignore + def _chunk(self, + z: torch.Tensor, + mask: torch.Tensor, + chunk_size: int, + ) -> torch.Tensor: + return chunk_layer( + self._transition, + {"z": z, "mask": mask}, + chunk_size=chunk_size, + no_batch_dims=len(z.shape[:-2]), + ) + + + def forward(self, + z: torch.Tensor, + mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + """ + Args: + z: + [*, N_res, N_res, C_z] pair embedding + Returns: + [*, N_res, N_res, C_z] pair embedding update + """ + # DISCREPANCY: DeepMind forgets to apply the mask in this module. + if mask is None: + mask = z.new_ones(z.shape[:-1]) + + # [*, N_res, N_res, 1] + mask = mask.unsqueeze(-1) + + # [*, N_res, N_res, C_z] + z = self.layer_norm(z) + + if chunk_size is not None: + z = self._chunk(z, mask, chunk_size) + else: + z = self._transition(z=z, mask=mask) + + return z diff --git a/openfold/primitives.py b/openfold/primitives.py new file mode 100644 index 000000000000..bbc156f21d4a --- /dev/null +++ b/openfold/primitives.py @@ -0,0 +1,529 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
def _prod(nums):
    """Product of an iterable of numbers (1 for the empty iterable)."""
    result = 1
    for value in nums:
        result = result * value
    return result


def _calculate_fan(linear_weight_shape, fan="fan_in"):
    """Return fan-in, fan-out, or their average for a (out, in) shape."""
    fan_out, fan_in = linear_weight_shape

    if fan == "fan_in":
        return fan_in
    if fan == "fan_out":
        return fan_out
    if fan == "fan_avg":
        return (fan_in + fan_out) / 2
    raise ValueError("Invalid fan option")


def glorot_uniform_init_(weights):
    """Xavier/Glorot uniform initialization (gain 1)."""
    nn.init.xavier_uniform_(weights, gain=1)


def final_init_(weights):
    """Zero initialization, used for 'final' output projections."""
    with torch.no_grad():
        weights.fill_(0.0)


def gating_init_(weights):
    """Zero initialization for gating weights (bias is set separately)."""
    with torch.no_grad():
        weights.fill_(0.0)


def normal_init_(weights):
    """Kaiming-normal initialization with a linear nonlinearity."""
    torch.nn.init.kaiming_normal_(weights, nonlinearity="linear")


def ipa_point_weights_init_(weights):
    """Initialize so that softplus(weight) == 1."""
    with torch.no_grad():
        softplus_inverse_1 = 0.541324854612918
        weights.fill_(softplus_inverse_1)


class Linear(nn.Linear):
    """
    A Linear layer with built-in nonstandard initializations. Called just
    like torch.nn.Linear.

    Implements the initializers in 1.11.4, plus some additional ones found
    in the code.

    NOTE(review): "default" and "relu" both dispatch to normal_init_
    (kaiming normal), not the truncated-normal schemes the init docstring
    names — confirm whether that is intentional.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        bias: bool = True,
        init: str = "default",
        init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
    ):
        """
        Args:
            in_dim:
                The final dimension of inputs to the layer
            out_dim:
                The final dimension of layer outputs
            bias:
                Whether to learn an additive bias. True by default
            init:
                The initializer to use. Choose from:

                "default": LeCun fan-in truncated normal initialization
                "relu": He initialization w/ truncated normal distribution
                "glorot": Fan-average Glorot uniform initialization
                "gating": Weights=0, Bias=1
                "normal": Normal initialization with std=1/sqrt(fan_in)
                "final": Weights=0, Bias=0

                Overridden by init_fn if the latter is not None.
            init_fn:
                A custom initializer taking weight and bias as inputs.
                Overrides init if not None.
        """
        super(Linear, self).__init__(in_dim, out_dim, bias=bias)

        # Bias starts at zero regardless of the chosen weight initializer
        # ("gating" overwrites it with ones below).
        if bias:
            with torch.no_grad():
                self.bias.fill_(0)

        if init_fn is not None:
            # A custom initializer takes full control of weight and bias.
            init_fn(self.weight, self.bias)
            return

        if init in ("default", "relu", "normal"):
            normal_init_(self.weight)
        elif init == "glorot":
            glorot_uniform_init_(self.weight)
        elif init == "gating":
            gating_init_(self.weight)
            if bias:
                with torch.no_grad():
                    self.bias.fill_(1.0)
        elif init == "final":
            final_init_(self.weight)
        else:
            raise ValueError("Invalid init string.")


class LayerNorm(nn.Module):
    """Layer normalization over the final c_in channels with learned
    affine parameters (weight init 1, bias init 0)."""

    def __init__(self, c_in, eps=1e-5):
        super(LayerNorm, self).__init__()

        self.c_in = (c_in,)
        self.eps = eps

        self.weight = nn.Parameter(torch.ones(c_in))
        self.bias = nn.Parameter(torch.zeros(c_in))

    def forward(self, x):
        return nn.functional.layer_norm(
            x,
            self.c_in,
            self.weight,
            self.bias,
            self.eps,
        )
fp32 when the input is of + type bfloat16 + """ + s = torch.nn.functional.softmax(t, dim=dim) + + return s + + +#@torch.jit.script +def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + biases: List[torch.Tensor]) -> torch.Tensor: + # [*, H, Q, C_hidden] + query = permute_final_dims(query, (1, 0, 2)) + + # [*, H, C_hidden, K] + key = permute_final_dims(key, (1, 2, 0)) + + # [*, H, V, C_hidden] + value = permute_final_dims(value, (1, 0, 2)) + + # [*, H, Q, K] + a = torch.matmul(query, key) + + for b in biases: + a += b + + a = softmax(a, -1) + + # [*, H, Q, C_hidden] + a = torch.matmul(a, value) + + # [*, Q, H, C_hidden] + a = a.transpose(-2, -3) + + return a + + +@torch.jit.ignore +def _attention_chunked_trainable( + query, + key, + value, + biases, + chunk_size, + chunk_dim, + checkpoint, +): + if (checkpoint and len(biases) > 2): + raise ValueError("Checkpointed version permits only permits two bias terms") + + def _checkpointable_attention(q, k, v, b1, b2): + bs = [b for b in [b1, b2] if b is not None] + return _attention(q, k, v, bs) + + o_chunks = [] + checkpoint_fn = get_checkpoint_fn() + count = query.shape[chunk_dim] + for start in range(0, count, chunk_size): + end = start + chunk_size + idx = [slice(None)] * len(query.shape) + idx[chunk_dim] = slice(start, end) + idx_tup = tuple(idx) + q_chunk = query[idx_tup] + k_chunk = key[idx_tup] + v_chunk = value[idx_tup] + + def _slice_bias(b): + idx[chunk_dim] = (slice(start, end) if b.shape[chunk_dim] != 1 else slice(None)) + return b[tuple(idx)] + + if (checkpoint): + bias_1_chunk, bias_2_chunk = [ + _slice_bias(b) if b is not None else None for b in (biases + [None, None])[:2] + ] + + o_chunk = checkpoint_fn(_checkpointable_attention, q_chunk, k_chunk, v_chunk, + bias_1_chunk, bias_2_chunk) + else: + bias_chunks = [_slice_bias(b) for b in biases] + + o_chunk = _attention(q_chunk, k_chunk, v_chunk, bias_chunks) + + o_chunks.append(o_chunk) + + o = torch.cat(o_chunks, dim=chunk_dim) + 
return o + + +class Attention(nn.Module): + """ + Standard multi-head attention using AlphaFold's default layer + initialization. Allows multiple bias vectors. + """ + + def __init__( + self, + c_q: int, + c_k: int, + c_v: int, + c_hidden: int, + no_heads: int, + gating: bool = True, + ): + """ + Args: + c_q: + Input dimension of query data + c_k: + Input dimension of key data + c_v: + Input dimension of value data + c_hidden: + Per-head hidden dimension + no_heads: + Number of attention heads + gating: + Whether the output should be gated using query data + """ + super(Attention, self).__init__() + + self.c_q = c_q + self.c_k = c_k + self.c_v = c_v + self.c_hidden = c_hidden + self.no_heads = no_heads + self.gating = gating + + # DISCREPANCY: c_hidden is not the per-head channel dimension, as + # stated in the supplement, but the overall channel dimension. + + self.linear_q = Linear(self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot") + self.linear_k = Linear(self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot") + self.linear_v = Linear(self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot") + self.linear_o = Linear(self.c_hidden * self.no_heads, self.c_q, init="final") + + self.linear_g = None + if self.gating: + self.linear_g = Linear(self.c_q, self.c_hidden * self.no_heads, init="gating") + + self.sigmoid = nn.Sigmoid() + + def _prep_qkv(self, q_x: torch.Tensor, + kv_x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # [*, Q/K/V, H * C_hidden] + q = self.linear_q(q_x) + k = self.linear_k(kv_x) + v = self.linear_v(kv_x) + + # [*, Q/K, H, C_hidden] + q = q.view(q.shape[:-1] + (self.no_heads, -1)) + k = k.view(k.shape[:-1] + (self.no_heads, -1)) + v = v.view(v.shape[:-1] + (self.no_heads, -1)) + + q /= math.sqrt(self.c_hidden) + + return q, k, v + + def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor: + if (self.linear_g is not None): + g = self.sigmoid(self.linear_g(q_x)) + + # [*, Q, 
H, C_hidden] + g = g.view(g.shape[:-1] + (self.no_heads, -1)) + o = o * g + + # [*, Q, H * C_hidden] + o = flatten_final_dims(o, 2) + + # [*, Q, C_q] + o = self.linear_o(o) + + return o + + def forward( + self, + q_x: torch.Tensor, + kv_x: torch.Tensor, + biases: Optional[List[torch.Tensor]] = None, + use_lma: bool = False, + q_chunk_size: Optional[int] = None, + kv_chunk_size: Optional[int] = None, + ) -> torch.Tensor: + """ + Args: + q_x: + [*, Q, C_q] query data + kv_x: + [*, K, C_k] key data + biases: + List of biases that broadcast to [*, H, Q, K] + use_lma: + Whether to use low-memory attention + q_chunk_size: + Query chunk size (for LMA) + kv_chunk_size: + Key/Value chunk size (for LMA) + Returns + [*, Q, C_q] attention update + """ + if (biases is None): + biases = [] + if (use_lma and (q_chunk_size is None or kv_chunk_size is None)): + raise ValueError("If use_lma is specified, q_chunk_size and kv_chunk_size must " + "be provided") + + q, k, v = self._prep_qkv(q_x, kv_x) + + if (use_lma): + biases = [b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],)) for b in biases] + + o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size) + else: + o = _attention(q, k, v, biases) + + o = self._wrap_up(o, q_x) + + return o + + +class GlobalAttention(nn.Module): + + def __init__(self, c_in, c_hidden, no_heads, inf, eps): + super(GlobalAttention, self).__init__() + + self.c_in = c_in + self.c_hidden = c_hidden + self.no_heads = no_heads + self.inf = inf + self.eps = eps + + self.linear_q = Linear(c_in, c_hidden * no_heads, bias=False, init="glorot") + + self.linear_k = Linear( + c_in, + c_hidden, + bias=False, + init="glorot", + ) + self.linear_v = Linear( + c_in, + c_hidden, + bias=False, + init="glorot", + ) + self.linear_g = Linear(c_in, c_hidden * no_heads, init="gating") + self.linear_o = Linear(c_hidden * no_heads, c_in, init="final") + + self.sigmoid = nn.Sigmoid() + + def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + # [*, N_res, 
C_in] + q = torch.sum(m * mask.unsqueeze(-1), + dim=-2) / (torch.sum(mask, dim=-1)[..., None] + self.eps) + + # [*, N_res, H * C_hidden] + q = self.linear_q(q) + q *= (self.c_hidden**(-0.5)) + + # [*, N_res, H, C_hidden] + q = q.view(q.shape[:-1] + (self.no_heads, -1)) + + # [*, N_res, N_seq, C_hidden] + k = self.linear_k(m) + v = self.linear_v(m) + + # [*, N_res, H, N_seq] + a = torch.matmul( + q, + k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq] + ) + bias = (self.inf * (mask - 1))[..., :, None, :] + a += bias + a = softmax(a) + + # [*, N_res, H, C_hidden] + o = torch.matmul( + a, + v, + ) + + # [*, N_res, N_seq, C_hidden] + g = self.sigmoid(self.linear_g(m)) + + # [*, N_res, N_seq, H, C_hidden] + g = g.view(g.shape[:-1] + (self.no_heads, -1)) + + # [*, N_res, N_seq, H, C_hidden] + o = o.unsqueeze(-3) * g + + # [*, N_res, N_seq, H * C_hidden] + o = o.reshape(o.shape[:-2] + (-1,)) + + # [*, N_res, N_seq, C_in] + m = self.linear_o(o) + + return m + + +def _lma( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + biases: List[torch.Tensor], + q_chunk_size: int, + kv_chunk_size: int, +): + no_q, no_kv = q.shape[-3], k.shape[-3] + + # [*, Q, H, C_hidden] + o = q.new_zeros(q.shape) + for q_s in range(0, no_q, q_chunk_size): + q_chunk = q[..., q_s:q_s + q_chunk_size, :, :] + large_bias_chunks = [b[..., q_s:q_s + q_chunk_size, :] for b in biases] + + maxes = [] + weights = [] + values = [] + for kv_s in range(0, no_kv, kv_chunk_size): + k_chunk = k[..., kv_s:kv_s + kv_chunk_size, :, :] + v_chunk = v[..., kv_s:kv_s + kv_chunk_size, :, :] + small_bias_chunks = [b[..., kv_s:kv_s + kv_chunk_size] for b in large_bias_chunks] + + a = torch.einsum( + "...qhd,...khd->...hqk", + q_chunk, + k_chunk, + ) + + for b in small_bias_chunks: + a += b + + a = a.transpose(-2, -3) + + max_a = torch.max(a, dim=-1, keepdim=True)[0] + exp_a = torch.exp(a - max_a) + exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a) + + maxes.append(max_a.detach().squeeze(-1)) + 
weights.append(torch.sum(exp_a, dim=-1)) + values.append(exp_v) + + chunk_max = torch.stack(maxes, dim=-3) + chunk_weights = torch.stack(weights, dim=-3) + chunk_values = torch.stack(values, dim=-4) + + global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0] + max_diffs = torch.exp(chunk_max - global_max) + chunk_values *= max_diffs.unsqueeze(-1) + chunk_weights *= max_diffs + + all_values = torch.sum(chunk_values, dim=-4) + all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4) + + q_chunk_out = all_values / all_weights + + o[..., q_s:q_s + q_chunk_size, :, :] = q_chunk_out + + return o diff --git a/openfold/tensor_utils.py b/openfold/tensor_utils.py new file mode 100644 index 000000000000..7e5e8e4b6b5e --- /dev/null +++ b/openfold/tensor_utils.py @@ -0,0 +1,408 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from functools import partial +import torch +import torch.nn as nn +from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional + + +def permute_final_dims(tensor: torch.Tensor, inds: List[int]): + zero_index = -1 * len(inds) + first_inds = list(range(len(tensor.shape[:zero_index]))) + return tensor.permute(first_inds + [zero_index + i for i in inds]) + + +def flatten_final_dims(t: torch.Tensor, no_dims: int): + return t.reshape(t.shape[:-no_dims] + (-1,)) + + +def masked_mean(mask, value, dim, eps=1e-4): + mask = mask.expand(*value.shape) + return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim)) + + +def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64): + boundaries = torch.linspace( + min_bin, max_bin, no_bins - 1, device=pts.device + ) + dists = torch.sqrt( + torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1) + ) + return torch.bucketize(dists, boundaries) + + +def dict_multimap(fn, dicts): + first = dicts[0] + new_dict = {} + for k, v in first.items(): + all_v = [d[k] for d in dicts] + if type(v) is dict: + new_dict[k] = dict_multimap(fn, all_v) + else: + new_dict[k] = fn(all_v) + + return new_dict + + +def one_hot(x, v_bins): + reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),)) + diffs = x[..., None] - reshaped_bins + am = torch.argmin(torch.abs(diffs), dim=-1) + return nn.functional.one_hot(am, num_classes=len(v_bins)).float() + + +def batched_gather(data, inds, dim=0, no_batch_dims=0): + ranges = [] + for i, s in enumerate(data.shape[:no_batch_dims]): + r = torch.arange(s) + r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1)))) + ranges.append(r) + + remaining_dims = [ + slice(None) for _ in range(len(data.shape) - no_batch_dims) + ] + remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds + ranges.extend(remaining_dims) + return data[ranges] + + +# With tree_map, a poor man's JAX tree_map +def dict_map(fn, dic, leaf_type): + new_dict = {} + for k, v in 
dic.items(): + if type(v) is dict: + new_dict[k] = dict_map(fn, v, leaf_type) + else: + new_dict[k] = tree_map(fn, v, leaf_type) + + return new_dict + + +def tree_map(fn, tree, leaf_type): + if isinstance(tree, dict): + return dict_map(fn, tree, leaf_type) + elif isinstance(tree, list): + return [tree_map(fn, x, leaf_type) for x in tree] + elif isinstance(tree, tuple): + return tuple([tree_map(fn, x, leaf_type) for x in tree]) + elif isinstance(tree, leaf_type): + return fn(tree) + else: + print(type(tree)) + raise ValueError("Not supported") + + +tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor) + +def _fetch_dims(tree): + shapes = [] + tree_type = type(tree) + if tree_type is dict: + for v in tree.values(): + shapes.extend(_fetch_dims(v)) + elif tree_type is list or tree_type is tuple: + for t in tree: + shapes.extend(_fetch_dims(t)) + elif tree_type is torch.Tensor: + shapes.append(tree.shape) + else: + raise ValueError("Not supported") + + return shapes + + +@torch.jit.ignore +def _flat_idx_to_idx( + flat_idx: int, + dims: Tuple[int], +) -> Tuple[int]: + idx = [] + for d in reversed(dims): + idx.append(flat_idx % d) + flat_idx = flat_idx // d + + return tuple(reversed(idx)) + + +@torch.jit.ignore +def _get_minimal_slice_set( + start: Sequence[int], + end: Sequence[int], + dims: int, + start_edges: Optional[Sequence[bool]] = None, + end_edges: Optional[Sequence[bool]] = None, +) -> Sequence[Tuple[int]]: + """ + Produces an ordered sequence of tensor slices that, when used in + sequence on a tensor with shape dims, yields tensors that contain every + leaf in the contiguous range [start, end]. Care is taken to yield a + short sequence of slices, and perhaps even the shortest possible (I'm + pretty sure it's the latter). + + end is INCLUSIVE. 
+ """ + # start_edges and end_edges both indicate whether, starting from any given + # dimension, the start/end index is at the top/bottom edge of the + # corresponding tensor, modeled as a tree + def reduce_edge_list(l): + tally = 1 + for i in range(len(l)): + reversed_idx = -1 * (i + 1) + l[reversed_idx] *= tally + tally = l[reversed_idx] + + if(start_edges is None): + start_edges = [s == 0 for s in start] + reduce_edge_list(start_edges) + if(end_edges is None): + end_edges = [e == (d - 1) for e,d in zip(end, dims)] + reduce_edge_list(end_edges) + + # Base cases. Either start/end are empty and we're done, or the final, + # one-dimensional tensor can be simply sliced + if(len(start) == 0): + return [tuple()] + elif(len(start) == 1): + return [(slice(start[0], end[0] + 1),)] + + slices = [] + path = [] + + # Dimensions common to start and end can be selected directly + for s,e in zip(start, end): + if(s == e): + path.append(slice(s, s + 1)) + else: + break + + path = tuple(path) + divergence_idx = len(path) + + # start == end, and we're done + if(divergence_idx == len(dims)): + return [tuple(path)] + + def upper(): + sdi = start[divergence_idx] + return [ + path + (slice(sdi, sdi + 1),) + s for s in + _get_minimal_slice_set( + start[divergence_idx + 1:], + [d - 1 for d in dims[divergence_idx + 1:]], + dims[divergence_idx + 1:], + start_edges=start_edges[divergence_idx + 1:], + end_edges=[1 for _ in end_edges[divergence_idx + 1:]] + ) + ] + + def lower(): + edi = end[divergence_idx] + return [ + path + (slice(edi, edi + 1),) + s for s in + _get_minimal_slice_set( + [0 for _ in start[divergence_idx + 1:]], + end[divergence_idx + 1:], + dims[divergence_idx + 1:], + start_edges=[1 for _ in start_edges[divergence_idx + 1:]], + end_edges=end_edges[divergence_idx + 1:], + ) + ] + + # If both start and end are at the edges of the subtree rooted at + # divergence_idx, we can just select the whole subtree at once + if(start_edges[divergence_idx] and 
end_edges[divergence_idx]): + slices.append( + path + (slice(start[divergence_idx], end[divergence_idx] + 1),) + ) + # If just start is at the edge, we can grab almost all of the subtree, + # treating only the ragged bottom edge as an edge case + elif(start_edges[divergence_idx]): + slices.append( + path + (slice(start[divergence_idx], end[divergence_idx]),) + ) + slices.extend(lower()) + # Analogous to the previous case, but the top is ragged this time + elif(end_edges[divergence_idx]): + slices.extend(upper()) + slices.append( + path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),) + ) + # If both sides of the range are ragged, we need to handle both sides + # separately. If there's contiguous meat in between them, we can index it + # in one big chunk + else: + slices.extend(upper()) + middle_ground = end[divergence_idx] - start[divergence_idx] + if(middle_ground > 1): + slices.append( + path + (slice(start[divergence_idx] + 1, end[divergence_idx]),) + ) + slices.extend(lower()) + + return [tuple(s) for s in slices] + + +@torch.jit.ignore +def _chunk_slice( + t: torch.Tensor, + flat_start: int, + flat_end: int, + no_batch_dims: int, +) -> torch.Tensor: + """ + Equivalent to + + t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end] + + but without the need for the initial reshape call, which can be + memory-intensive in certain situations. The only reshape operations + in this function are performed on sub-tensors that scale with + (flat_end - flat_start), the chunk size. 
+ """ + + batch_dims = t.shape[:no_batch_dims] + start_idx = list(_flat_idx_to_idx(flat_start, batch_dims)) + # _get_minimal_slice_set is inclusive + end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims)) + + # Get an ordered list of slices to perform + slices = _get_minimal_slice_set( + start_idx, + end_idx, + batch_dims, + ) + + sliced_tensors = [t[s] for s in slices] + + return torch.cat( + [s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors] + ) + + +def chunk_layer( + layer: Callable, + inputs: Dict[str, Any], + chunk_size: int, + no_batch_dims: int, + low_mem: bool = False, +) -> Any: + """ + Implements the "chunking" procedure described in section 1.11.8. + + Layer outputs and inputs are assumed to be simple "pytrees," + consisting only of (arbitrarily nested) lists, tuples, and dicts with + torch.Tensor leaves. + + Args: + layer: + The layer to be applied chunk-wise + inputs: + A (non-nested) dictionary of keyworded inputs. All leaves must + be tensors and must share the same batch dimensions. + chunk_size: + The number of sub-batches per chunk. If multiple batch + dimensions are specified, a "sub-batch" is defined as a single + indexing of all batch dimensions simultaneously (s.t. the + number of sub-batches is the product of the batch dimensions). + no_batch_dims: + How many of the initial dimensions of each input tensor can + be considered batch dimensions. + low_mem: + Avoids flattening potentially large input tensors. Unnecessary + in most cases, and is ever so slightly slower than the default + setting. + Returns: + The reassembled output of the layer on the inputs. + """ + if not (len(inputs) > 0): + raise ValueError("Must provide at least one input") + + initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)] + orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)]) + + def _prep_inputs(t): + # TODO: make this more memory efficient. 
This sucks + if(not low_mem): + if not sum(t.shape[:no_batch_dims]) == no_batch_dims: + t = t.expand(orig_batch_dims + t.shape[no_batch_dims:]) + t = t.reshape(-1, *t.shape[no_batch_dims:]) + else: + t = t.expand(orig_batch_dims + t.shape[no_batch_dims:]) + return t + + prepped_inputs = tensor_tree_map(_prep_inputs, inputs) + + flat_batch_dim = 1 + for d in orig_batch_dims: + flat_batch_dim *= d + + no_chunks = flat_batch_dim // chunk_size + ( + flat_batch_dim % chunk_size != 0 + ) + + i = 0 + out = None + for _ in range(no_chunks): + # Chunk the input + if(not low_mem): + select_chunk = ( + lambda t: t[i : i + chunk_size] if t.shape[0] != 1 else t + ) + else: + select_chunk = ( + partial( + _chunk_slice, + flat_start=i, + flat_end=min(flat_batch_dim, i + chunk_size), + no_batch_dims=len(orig_batch_dims) + ) + ) + + chunks = tensor_tree_map(select_chunk, prepped_inputs) + + # Run the layer on the chunk + output_chunk = layer(**chunks) + + # Allocate space for the output + if out is None: + allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:]) + out = tensor_tree_map(allocate, output_chunk) + + # Put the chunk in its pre-allocated space + out_type = type(output_chunk) + if out_type is dict: + def assign(d1, d2): + for k, v in d1.items(): + if type(v) is dict: + assign(v, d2[k]) + else: + v[i : i + chunk_size] = d2[k] + + assign(out, output_chunk) + elif out_type is tuple: + for x1, x2 in zip(out, output_chunk): + x1[i : i + chunk_size] = x2 + elif out_type is torch.Tensor: + out[i : i + chunk_size] = output_chunk + else: + raise ValueError("Not supported") + + i += chunk_size + + reshape = lambda t: t.view(orig_batch_dims + t.shape[1:]) + out = tensor_tree_map(reshape, out) + + return out diff --git a/openfold/triangular_attention.py b/openfold/triangular_attention.py new file mode 100644 index 000000000000..6d3e37f4c681 --- /dev/null +++ b/openfold/triangular_attention.py @@ -0,0 +1,139 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partialmethod, partial
import math
from typing import Optional, List

import torch
import torch.nn as nn

from openfold.primitives import Linear, LayerNorm, Attention
from openfold.tensor_utils import (
    chunk_layer,
    permute_final_dims,
    flatten_final_dims,
)


class TriangleAttention(nn.Module):
    """Triangular self-attention over rows (starting) or columns (ending) of
    the pair representation."""

    def __init__(self, c_in, c_hidden, no_heads, starting, inf=1e9):
        """
        Args:
            c_in:
                Input channel dimension
            c_hidden:
                Overall hidden channel dimension (not per-head)
            no_heads:
                Number of attention heads
            starting:
                Whether to attend over rows (True) or columns (False)
            inf:
                Large constant used to mask out attention logits
        """
        super(TriangleAttention, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.starting = starting
        self.inf = inf

        self.layer_norm = LayerNorm(self.c_in)

        # Per-head triangle bias projection.
        self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")

        self.mha = Attention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads)

    @torch.jit.ignore
    def _chunk(self,
               x: torch.Tensor,
               biases: List[torch.Tensor],
               chunk_size: int,
               ) -> torch.Tensor:
        """Run the MHA in sub-batches of ``chunk_size`` to bound memory."""
        mha_inputs = {
            "q_x": x,
            "kv_x": x,
            "biases": biases,
        }
        return chunk_layer(
            partial(self.mha),
            mha_inputs,
            chunk_size=chunk_size,
            no_batch_dims=len(x.shape[:-2]),
        )

    def forward(self,
                x: torch.Tensor,
                mask: Optional[torch.Tensor] = None,
                chunk_size: Optional[int] = None,
                ) -> torch.Tensor:
        """
        Args:
            x:
                [*, I, J, C_in] input tensor (e.g. the pair representation)
        Returns:
            [*, I, J, C_in] output tensor
        """
        if mask is None:
            # [*, I, J] -- all positions valid by default
            mask = x.new_ones(x.shape[:-1],)

        # Shape annotations assume self.starting. Else, I and J are flipped
        if not self.starting:
            x = x.transpose(-2, -3)
            mask = mask.transpose(-1, -2)

        # [*, I, J, C_in]
        x = self.layer_norm(x)

        # [*, I, 1, 1, J]
        mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

        # [*, H, I, J]
        triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))

        # [*, 1, H, I, J]
        triangle_bias = triangle_bias.unsqueeze(-4)

        biases = [mask_bias, triangle_bias]

        if chunk_size is not None:
            x = self._chunk(x, biases, chunk_size)
        else:
            x = self.mha(q_x=x, kv_x=x, biases=biases)

        # Undo the transpose applied for the "ending node" variant.
        if not self.starting:
            x = x.transpose(-2, -3)

        return x


class TriangleAttentionStartingNode(TriangleAttention):
    """
    Implements Algorithm 13.
    """

    __init__ = partialmethod(TriangleAttention.__init__, starting=True)


class TriangleAttentionEndingNode(TriangleAttention):
    """
    Implements Algorithm 14.
    """

    __init__ = partialmethod(TriangleAttention.__init__, starting=False)


# ---------------------------------------------------------------------------
# openfold/triangular_multiplicative_update.py
#
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partialmethod
from typing import Optional

import torch
import torch.nn as nn

from openfold.primitives import Linear, LayerNorm
from openfold.tensor_utils import permute_final_dims


class TriangleMultiplicativeUpdate(nn.Module):
    """
    Implements Algorithms 11 and 12.
    """

    def __init__(self, c_z, c_hidden, _outgoing=True):
        """
        Args:
            c_z:
                Input channel dimension
            c_hidden:
                Hidden channel dimension
        """
        super(TriangleMultiplicativeUpdate, self).__init__()
        self.c_z = c_z
        self.c_hidden = c_hidden
        self._outgoing = _outgoing

        # Gated "a"/"b" projections into the hidden dimension.
        self.linear_a_p = Linear(self.c_z, self.c_hidden)
        self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
        self.linear_b_p = Linear(self.c_z, self.c_hidden)
        self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating")
        # Output gate and projection back to c_z.
        self.linear_g = Linear(self.c_z, self.c_z, init="gating")
        self.linear_z = Linear(self.c_hidden, self.c_z, init="final")

        self.layer_norm_in = LayerNorm(self.c_z)
        self.layer_norm_out = LayerNorm(self.c_hidden)

        self.sigmoid = nn.Sigmoid()

    def _combine_projections(self,
                             a: torch.Tensor,
                             b: torch.Tensor,
                             ) -> torch.Tensor:
        # Subclasses define the direction of the triangle multiplication.
        raise NotImplementedError("This method needs to be overridden")

    def forward(self,
                z: torch.Tensor,
                mask: Optional[torch.Tensor] = None
                ) -> torch.Tensor:
        """
        Args:
            z:
                [*, N_res, N_res, C_z] input tensor
            mask:
                [*, N_res, N_res] input mask
        Returns:
            [*, N_res, N_res, C_z] output tensor
        """
        if mask is None:
            mask = z.new_ones(z.shape[:-1])

        # [*, N_res, N_res, 1] so the mask broadcasts over channels.
        mask = mask.unsqueeze(-1)

        z = self.layer_norm_in(z)

        # Gated, masked projections (Algorithm 11/12, lines 1-2).
        a = self.linear_a_p(z) * self.sigmoid(self.linear_a_g(z))
        a = a * mask
        b = self.linear_b_p(z) * self.sigmoid(self.linear_b_g(z))
        b = b * mask

        # Direction-specific triangle product, then normalize and project back.
        x = self._combine_projections(a, b)
        x = self.layer_norm_out(x)
        x = self.linear_z(x)

        # Output gating on the (normalized) input.
        g = self.sigmoid(self.linear_g(z))
        z = x * g

        return z


class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
    """
    Implements Algorithm 11.
    """

    def _combine_projections(self,
                             a: torch.Tensor,    # [*, N_i, N_k, C]
                             b: torch.Tensor,    # [*, N_j, N_k, C]
                             ):
        # [*, C, N_i, N_j]
        p = torch.matmul(
            permute_final_dims(a, (2, 0, 1)),
            permute_final_dims(b, (2, 1, 0)),
        )

        # [*, N_i, N_j, C]
        return permute_final_dims(p, (1, 2, 0))


class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
    """
    Implements Algorithm 12.
    """

    def _combine_projections(self,
                             a: torch.Tensor,    # [*, N_k, N_i, C]
                             b: torch.Tensor,    # [*, N_k, N_j, C]
                             ):
        # [*, C, N_i, N_j]
        p = torch.matmul(
            permute_final_dims(a, (2, 1, 0)),
            permute_final_dims(b, (2, 0, 1)),
        )

        # [*, N_i, N_j, C]
        return permute_final_dims(p, (1, 2, 0))
torch.nn.Module, node, pair): + loop = 10 + with torch.no_grad(): + for _ in range(loop // 4): + model(node, pair) + torch.cuda.synchronize() + time1 = time.time() + for _ in range(loop): + model(node, pair) + torch.cuda.synchronize() + time2 = time.time() + return (time2 - time1) / loop + + +def benchmark_evoformer(): + # data + msa_len = 300 + pair_len = 800 + node = torch.randn(1, msa_len, pair_len, 256).cuda() + pair = torch.randn(1, pair_len, pair_len, 128).cuda() + + # build gm model + max_memory = 3000 # MB + model = evoformer_base().cuda() + # trace the module and replace codegen + graph = ColoTracer().trace( + model, + meta_args={ + "node": node.to(torch.device("meta")), + "pair": pair.to(torch.device("meta")), + }, + ) + gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace + interp = MetaInfoProp(gm_prop) + interp.propagate( + MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0") + ) + # now run it twice to get meta info in graph module, not necessary + gm = torch.fx.GraphModule(model, graph) + interp = MetaInfoProp(gm) + interp.propagate( + MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0") + ) + # set code_gen + codegen = ChunkCodeGen(gm_prop, max_memory) + graph.set_codegen(codegen) + gm = ColoGraphModule(model, graph) + gm.recompile() + # print + code = graph.python_code("self").src + print(code) + + time_gm = _benchmark_evoformer(gm, node, pair) + print("gm %.4fs" % time_gm) + time_openfold = _benchmark_evoformer(model, node, pair) + print("openfold %.4fs" % time_openfold) + + +if __name__ == "__main__": + benchmark_evoformer() diff --git a/chunk_codegen.py b/chunk_codegen.py index 6caed88d84d2..033db50dbccb 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -1398,13 +1398,14 @@ def estimate_chunk_inference_mem( class ChunkSelector(object): def __init__( - self, index_tracer: IndexTracer, memory_estimator: MemoryEstimator, stratge + self, index_tracer: IndexTracer, 
memory_estimator: MemoryEstimator, stratge, max_memory=None ): self.index_tracer = index_tracer self.memory_estimator = memory_estimator assert stratge in ["min_memory", "fit_memory"] + assert (stratge == "fit_memory" and max_memory is not None) or stratge != "fit_memory" self.stratge = stratge - self.max_memory = 600 # MB + self.max_memory = max_memory # MB def _select_best_chunk_region( self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak @@ -1556,13 +1557,13 @@ def _is_legal_region(self, cur_chunk_info, chunk_infos): class ChunkRegionSearch(object): - def __init__(self, gm) -> None: + def __init__(self, gm, max_memory=None) -> None: self.gm = gm self.index_tracer = IndexTracer(list(gm.graph.nodes)) self.index_tracer.trace_index() self.memory_estimator = MemoryEstimator(self.index_tracer) self.chunk_selector = ChunkSelector( - self.index_tracer, self.memory_estimator, stratge="fit_memory" + self.index_tracer, self.memory_estimator, stratge="fit_memory", max_memory=max_memory ) def _find_peak_node(self, mem_peak): @@ -1897,6 +1898,7 @@ def emit_code_with_chunk( delete_unused_value_func, meta_nodes, meta_graph, + max_memory=None, ): """Emit code with nested activation checkpoint When we detect some of the node.activation_checkpoint is a List, we will use @@ -1912,7 +1914,7 @@ def emit_code_with_chunk( node_list = list(nodes) # find the chunk regions - chunk_region_search = ChunkRegionSearch(meta_graph) + chunk_region_search = ChunkRegionSearch(meta_graph, max_memory) chunk_search = chunk_region_search.search_region() chunk_regions = [i["region"] for i in chunk_search] @@ -1989,9 +1991,10 @@ def emit_code_with_chunk( if CODEGEN_AVAILABLE: class ChunkCodeGen(CodeGen): - def __init__(self, meta_graph): + def __init__(self, meta_graph, max_memory=None): super().__init__() self.meta_graph = meta_graph + self.max_memory = max_memory self.meta_node = list(meta_graph.graph.nodes) def _gen_python_code( @@ -2230,6 +2233,7 @@ def emit_node(node: 
Node, body): delete_unused_values, self.meta_node, self.meta_graph, + self.max_memory ) if len(body) == 0: From 5a916c0adb320b4a1cfc96e8a40364fb62a0a463 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 29 Dec 2022 14:42:29 +0800 Subject: [PATCH 067/209] add print --- autochunk_benchmark.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/autochunk_benchmark.py b/autochunk_benchmark.py index a34464212e02..0c55a3a8848c 100644 --- a/autochunk_benchmark.py +++ b/autochunk_benchmark.py @@ -1,24 +1,21 @@ -import copy +import time + import torch -import torch.nn.functional as F -import pytest import torch.fx -import torch.multiprocessing as mp -from torch.fx import GraphModule + +from chunk_codegen import ChunkCodeGen from colossalai.fx import ColoTracer -import colossalai -from colossalai.utils import free_port -from colossalai.core import global_context as gpc from colossalai.fx.graph_module import ColoGraphModule -from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata +from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.profiler import MetaTensor from evoformer.evoformer import evoformer_base -from chunk_codegen import ChunkCodeGen -import time -def _benchmark_evoformer(model: torch.nn.Module, node, pair): - loop = 10 +def _benchmark_evoformer(model: torch.nn.Module, node, pair, title): + torch.cuda.reset_peak_memory_stats() + now_mem = torch.cuda.memory_allocated() / 1024**2 + + loop = 16 with torch.no_grad(): for _ in range(loop // 4): model(node, pair) @@ -28,7 +25,12 @@ def _benchmark_evoformer(model: torch.nn.Module, node, pair): model(node, pair) torch.cuda.synchronize() time2 = time.time() - return (time2 - time1) / loop + + new_max_mem = torch.cuda.max_memory_allocated() / 1024**2 + print( + "%s: time %.4fs, mem %dMB" + % (title, (time2 - time1) / loop, new_max_mem - now_mem) + ) def benchmark_evoformer(): @@ -69,10 +71,8 @@ def benchmark_evoformer(): code 
= graph.python_code("self").src print(code) - time_gm = _benchmark_evoformer(gm, node, pair) - print("gm %.4fs" % time_gm) - time_openfold = _benchmark_evoformer(model, node, pair) - print("openfold %.4fs" % time_openfold) + _benchmark_evoformer(gm, node, pair, "autochunk") + _benchmark_evoformer(model, node, pair, "openfold") if __name__ == "__main__": From 7a23deb58455b112cf187776857e2a262d0b737e Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 29 Dec 2022 14:47:16 +0800 Subject: [PATCH 068/209] code style --- autochunk_benchmark.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/autochunk_benchmark.py b/autochunk_benchmark.py index 0c55a3a8848c..f8e603f4ee63 100644 --- a/autochunk_benchmark.py +++ b/autochunk_benchmark.py @@ -34,15 +34,23 @@ def _benchmark_evoformer(model: torch.nn.Module, node, pair, title): def benchmark_evoformer(): - # data + # init data and model msa_len = 300 pair_len = 800 node = torch.randn(1, msa_len, pair_len, 256).cuda() pair = torch.randn(1, pair_len, pair_len, 128).cuda() + model = evoformer_base().cuda() - # build gm model + # build autochunk model max_memory = 3000 # MB - model = evoformer_base().cuda() + autochunk = _build_autochunk(model, max_memory, node, pair) + + # benchmark + _benchmark_evoformer(model, node, pair, "openfold") + _benchmark_evoformer(autochunk, node, pair, "autochunk") + + +def _build_autochunk(model, max_memory, node, pair): # trace the module and replace codegen graph = ColoTracer().trace( model, @@ -70,9 +78,7 @@ def benchmark_evoformer(): # print code = graph.python_code("self").src print(code) - - _benchmark_evoformer(gm, node, pair, "autochunk") - _benchmark_evoformer(model, node, pair, "openfold") + return gm if __name__ == "__main__": From efe6fe3a33c4b8c50c2e964188fef72d1f269cfd Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 29 Dec 2022 14:47:47 +0800 Subject: [PATCH 069/209] code style --- autochunk_benchmark.py | 34 +++++++++++++++++----------------- 1 file 
changed, 17 insertions(+), 17 deletions(-) diff --git a/autochunk_benchmark.py b/autochunk_benchmark.py index f8e603f4ee63..20f615b216f7 100644 --- a/autochunk_benchmark.py +++ b/autochunk_benchmark.py @@ -33,23 +33,6 @@ def _benchmark_evoformer(model: torch.nn.Module, node, pair, title): ) -def benchmark_evoformer(): - # init data and model - msa_len = 300 - pair_len = 800 - node = torch.randn(1, msa_len, pair_len, 256).cuda() - pair = torch.randn(1, pair_len, pair_len, 128).cuda() - model = evoformer_base().cuda() - - # build autochunk model - max_memory = 3000 # MB - autochunk = _build_autochunk(model, max_memory, node, pair) - - # benchmark - _benchmark_evoformer(model, node, pair, "openfold") - _benchmark_evoformer(autochunk, node, pair, "autochunk") - - def _build_autochunk(model, max_memory, node, pair): # trace the module and replace codegen graph = ColoTracer().trace( @@ -81,5 +64,22 @@ def _build_autochunk(model, max_memory, node, pair): return gm +def benchmark_evoformer(): + # init data and model + msa_len = 300 + pair_len = 800 + node = torch.randn(1, msa_len, pair_len, 256).cuda() + pair = torch.randn(1, pair_len, pair_len, 128).cuda() + model = evoformer_base().cuda() + + # build autochunk model + max_memory = 3000 # MB + autochunk = _build_autochunk(model, max_memory, node, pair) + + # benchmark + _benchmark_evoformer(model, node, pair, "openfold") + _benchmark_evoformer(autochunk, node, pair, "autochunk") + + if __name__ == "__main__": benchmark_evoformer() From 289f3a45c24233fec28af6d5651b3099b55ace8b Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 29 Dec 2022 15:01:15 +0800 Subject: [PATCH 070/209] init openfold --- evoformer_openfold/evoformer.py | 59 +++++++++ evoformer_openfold/initializer.py | 29 +++++ evoformer_openfold/kernel.py | 19 +++ evoformer_openfold/msa.py | 95 +++++++++++++++ evoformer_openfold/ops.py | 176 +++++++++++++++++++++++++++ evoformer_openfold/triangle.py | 192 ++++++++++++++++++++++++++++++ 6 files changed, 570 
insertions(+) create mode 100644 evoformer_openfold/evoformer.py create mode 100755 evoformer_openfold/initializer.py create mode 100644 evoformer_openfold/kernel.py create mode 100644 evoformer_openfold/msa.py create mode 100755 evoformer_openfold/ops.py create mode 100644 evoformer_openfold/triangle.py diff --git a/evoformer_openfold/evoformer.py b/evoformer_openfold/evoformer.py new file mode 100644 index 000000000000..cfd2bb2a2529 --- /dev/null +++ b/evoformer_openfold/evoformer.py @@ -0,0 +1,59 @@ +import torch +import torch.nn as nn + +from .msa import MSAStack +from .ops import OutProductMean +from .triangle import PairStack + + +def print_memory(init_mem, text=None): + now_mem = torch.cuda.memory_allocated() / 1024 ** 2 - init_mem + max_mem = torch.cuda.max_memory_allocated() / 1024 ** 2 - init_mem + print("%s now:%.2f max:%.2f" % ("" if text is None else text, now_mem, max_mem)) + torch.cuda.reset_peak_memory_stats() + + +class EvoformerBlock(nn.Module): + + def __init__(self, d_node, d_pair): + super(EvoformerBlock, self).__init__() + + self.msa_stack = MSAStack(d_node, d_pair, p_drop=0.15) + self.communication = OutProductMean(n_feat=d_node, n_feat_out=d_pair, n_feat_proj=32) + self.pair_stack = PairStack(d_pair=d_pair) + + def forward(self, node, pair): + node = self.msa_stack(node, pair) + pair = pair + self.communication(node) + pair = self.pair_stack(pair) + return node, pair + + +class Evoformer(nn.Module): + + def __init__(self, d_node, d_pair): + super(Evoformer, self).__init__() + + self.blocks = nn.ModuleList() + for _ in range(1): + self.blocks.append(EvoformerBlock(d_node, d_pair)) + + def forward(self, node, pair): + for b in self.blocks: + node, pair = b(node, pair) + return node, pair + + +def evoformer_tiny(): + return Evoformer(d_node=64, d_pair=32) + + +def evoformer_base(): + return Evoformer(d_node=256, d_pair=128) + + +def evoformer_large(): + return Evoformer(d_node=512, d_pair=256) + + +__all__ = ['Evoformer', 'evoformer_base', 
'evoformer_large'] diff --git a/evoformer_openfold/initializer.py b/evoformer_openfold/initializer.py new file mode 100755 index 000000000000..c6ce0659e597 --- /dev/null +++ b/evoformer_openfold/initializer.py @@ -0,0 +1,29 @@ +import math + +import numpy as np +import torch.nn as nn + + +def glorot_uniform_af(x, gain=1.0): + """ + initialize tensors the same as xavier_initializer in PyTorch, but the dimensions are different: + In PyTorch: + [feature_out, feature_in, n_head ...] + In Jax: + [... n_head, feature_in, feature_out] + However, there is a feature in original Alphafold2 code that they use the Jax version initializer to initialize tensors like: + [feature_in, n_head, feature_out] + + In this function, we keep this feature to initialize [feature_in, n_head, ..., feature_out] tensors + """ + fan_in, fan_out = x.shape[-2:] + if len(x.shape) > 2: + receptive_field_size = np.prod(x.shape[:-2]) + fan_in *= receptive_field_size + fan_out *= receptive_field_size + std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) + dev = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation + + nn.init.uniform_(x, -dev, dev) + + return x diff --git a/evoformer_openfold/kernel.py b/evoformer_openfold/kernel.py new file mode 100644 index 000000000000..26ab5dc53261 --- /dev/null +++ b/evoformer_openfold/kernel.py @@ -0,0 +1,19 @@ +import torch +import torch.nn.functional as F + + +def bias_sigmod_ele(y, bias, z): + return torch.sigmoid(y + bias) * z + + +def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor, + residual: torch.Tensor, prob: float) -> torch.Tensor: + out = (x + bias) * F.dropout(dropmask, p=prob, training=False) + out = residual + out + return out + + +def bias_ele_dropout_residual(ab: torch.Tensor, b: torch.Tensor, g: torch.Tensor, + dropout_mask: torch.Tensor, Z_raw: torch.Tensor, + prob: float) -> torch.Tensor: + return Z_raw + F.dropout(dropout_mask, p=prob, training=True) * (g * (ab + b)) \ No newline at end of file 
diff --git a/evoformer_openfold/msa.py b/evoformer_openfold/msa.py new file mode 100644 index 000000000000..cac456638a55 --- /dev/null +++ b/evoformer_openfold/msa.py @@ -0,0 +1,95 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn import LayerNorm + +from .kernel import bias_dropout_add +from .ops import SelfAttention, Transition + + +class MSARowAttentionWithPairBias(nn.Module): + + def __init__(self, d_node, d_pair, c=32, n_head=8, p_drop=0.15): + super(MSARowAttentionWithPairBias, self).__init__() + self.d_node = d_node + self.d_pair = d_pair + self.c = c + self.n_head = n_head + self.p_drop = p_drop + + self.layernormM = LayerNorm(d_node) + self.layernormZ = LayerNorm(d_pair) + + _init_weights = torch.nn.init.normal_(torch.zeros([n_head, d_pair]), + std=1.0 / math.sqrt(d_pair)) + self.linear_b_weights = nn.parameter.Parameter(data=_init_weights, requires_grad=True) + + self.attention = SelfAttention(qkv_dim=d_node, + c=c, + n_head=n_head, + out_dim=d_node, + gating=True, + last_bias_fuse=True) + + self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_node,)), requires_grad=True) + + def forward(self, M_raw, Z): + ## Input projections + M = self.layernormM(M_raw) + Z = self.layernormZ(Z) + b = F.linear(Z, self.linear_b_weights) + b = b.permute(0, 3, 1, 2) + # b = rearrange(b, 'b q k h -> b h q k') + + M = self.attention(M, b) + dropout_mask = torch.ones_like(M[:, 0:1, :, :]).to(M.device).to(M.dtype) + + return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop) + + +class MSAColumnAttention(nn.Module): + + def __init__(self, d_node, c=32, n_head=8): + super(MSAColumnAttention, self).__init__() + self.d_node = d_node + self.c = c + self.n_head = n_head + + self.layernormM = LayerNorm(d_node) + self.attention = SelfAttention(qkv_dim=d_node, + c=c, + n_head=n_head, + out_dim=d_node, + gating=True) + + def forward(self, M_raw): + M = M_raw.transpose(-2, -3) 
+ M = self.layernormM(M) + + M = self.attention(M) + + M = M.transpose(-2, -3) + return M_raw + M + + +class MSAStack(nn.Module): + + def __init__(self, d_node, d_pair, p_drop=0.15): + super(MSAStack, self).__init__() + + self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(d_node=d_node, + d_pair=d_pair, + p_drop=p_drop) + + self.MSAColumnAttention = MSAColumnAttention(d_node=d_node) + self.MSATransition = Transition(d=d_node) + + def forward(self, node, pair): + node = self.MSARowAttentionWithPairBias(node, pair) + node = self.MSAColumnAttention(node) + node = self.MSATransition(node) + + return node diff --git a/evoformer_openfold/ops.py b/evoformer_openfold/ops.py new file mode 100755 index 000000000000..611b7b0fe777 --- /dev/null +++ b/evoformer_openfold/ops.py @@ -0,0 +1,176 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn import LayerNorm + +from .initializer import glorot_uniform_af +from .kernel import bias_sigmod_ele + + +class DropoutRowwise(nn.Module): + + def __init__(self, p): + super(DropoutRowwise, self).__init__() + self.p = p + self.dropout = nn.Dropout(p=p) + + def forward(self, x): + dropout_mask = torch.ones_like(x[:, 0:1, :, :]) + dropout_mask = self.dropout(dropout_mask) + return dropout_mask * x + + +class DropoutColumnwise(nn.Module): + + def __init__(self, p): + super(DropoutColumnwise, self).__init__() + self.p = p + self.dropout = nn.Dropout(p=p) + + def forward(self, x): + dropout_mask = torch.ones_like(x[:, :, 0:1, :]) + dropout_mask = self.dropout(dropout_mask) + return dropout_mask * x + + +class Transition(nn.Module): + + def __init__(self, d, n=4): + super(Transition, self).__init__() + self.norm = LayerNorm(d) + self.linear1 = Linear(d, n * d, initializer='relu') + self.linear2 = Linear(n * d, d, initializer='zeros') + + def forward(self, src): + x = self.norm(src) + x = self.linear2(F.relu(self.linear1(x))) + return src + x + + +class 
OutProductMean(nn.Module): + + def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32): + super(OutProductMean, self).__init__() + + self.layernormM = LayerNorm(n_feat) + self.linear_a = Linear(n_feat, n_feat_proj) + self.linear_b = Linear(n_feat, n_feat_proj) + + self.o_linear = Linear(n_feat_proj * n_feat_proj, + n_feat_out, + initializer='zero', + use_bias=True) + + def forward(self, M): + M = self.layernormM(M) + left_act = self.linear_a(M) + right_act = self.linear_b(M) + + O = torch.einsum('bsid,bsje->bijde', left_act, right_act).contiguous() + # O = rearrange(O, 'b i j d e -> b i j (d e)') + O = O.reshape(O.shape[0], O.shape[1], O.shape[2], -1) + Z = self.o_linear(O) + + return Z + + +class Linear(nn.Linear): + """ + A Linear layer with built-in nonstandard initializations. Called just + like torch.nn.Linear. + Implements the initializers in 1.11.4, plus some additional ones found + in the code. + """ + + def __init__( + self, + feature_in: int, + feature_out: int, + initializer: str = 'linear', + use_bias: bool = True, + bias_init: float = 0., + ): + super(Linear, self).__init__(feature_in, feature_out, bias=use_bias) + + self.use_bias = use_bias + if initializer == 'linear': + glorot_uniform_af(self.weight, gain=1.0) + elif initializer == 'relu': + glorot_uniform_af(self.weight, gain=2.0) + elif initializer == 'zeros': + nn.init.zeros_(self.weight) + if self.use_bias: + with torch.no_grad(): + self.bias.fill_(bias_init) + + +class SelfAttention(nn.Module): + """ + Multi-Head SelfAttention dealing with [batch_size1, batch_size2, len, dim] tensors + """ + + def __init__(self, qkv_dim, c, n_head, out_dim, gating=True, last_bias_fuse=False): + super(SelfAttention, self).__init__() + self.qkv_dim = qkv_dim + self.c = c + self.n_head = n_head + self.out_dim = out_dim + self.gating = gating + self.last_bias_fuse = last_bias_fuse + + self.scaling = self.c**(-0.5) + + # self.to_qkv = Linear(qkv_dim, 3 * n_head * c, initializer='linear') + self.to_q = 
Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False) + self.to_k = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False) + self.to_v = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False) + + if gating: + self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,))) + self.gating_linear = Linear(qkv_dim, n_head * c, initializer='zero', use_bias=False) + + self.o_linear = Linear(n_head * c, + out_dim, + initializer='zero', + use_bias=(not last_bias_fuse)) + + def forward(self, in_data, nonbatched_bias=None): + """ + :param in_data: [batch_size1, batch_size2, len_qkv, qkv_dim] + :param bias: None or [batch_size1, batch_size2, n_head, len_q, len_kv] + :param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv] + """ + + # qkv = self.to_qkv(in_data).chunk(3, dim=-1) + # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv) + + q = self.to_q(in_data) + k = self.to_k(in_data) + v = self.to_v(in_data) + + # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), + # [q, k, v]) + q, k, v = map(lambda t: t.view(t.shape[0], t.shape[1], t.shape[2], self.n_head, -1).permute(0, 1, 3, 2, 4), + [q, k, v]) + + q = q * self.scaling + + logits = torch.matmul(q, k.transpose(-1, -2)) + + if nonbatched_bias is not None: + logits += nonbatched_bias.unsqueeze(1) + weights = torch.softmax(logits, dim=-1) + # weights = softmax(logits) + + weighted_avg = torch.matmul(weights, v) + # weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)') + weighted_avg = weighted_avg.permute(0, 1, 3, 2, 4) + weighted_avg = weighted_avg.reshape(weighted_avg.shape[0], weighted_avg.shape[1], weighted_avg.shape[2], -1) + + if self.gating: + gate_values = self.gating_linear(in_data) + weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg) + + output = self.o_linear(weighted_avg) + return output diff --git a/evoformer_openfold/triangle.py 
b/evoformer_openfold/triangle.py new file mode 100644 index 000000000000..f479469c3836 --- /dev/null +++ b/evoformer_openfold/triangle.py @@ -0,0 +1,192 @@ +import math + +import torch +import torch.nn as nn +from torch.nn import LayerNorm + +from .kernel import bias_dropout_add, bias_ele_dropout_residual +from .ops import Linear, SelfAttention, Transition + + +def permute_final_dims(tensor, inds): + zero_index = -1 * len(inds) + first_inds = list(range(len(tensor.shape[:zero_index]))) + return tensor.permute(first_inds + [zero_index + i for i in inds]) + + +class TriangleMultiplicationOutgoing(nn.Module): + + def __init__(self, d_pair, p_drop, c=128): + super(TriangleMultiplicationOutgoing, self).__init__() + self.d_pair = d_pair + self.c = c + + self.layernorm1 = LayerNorm(d_pair) + self.left_projection = Linear(d_pair, c) + self.right_projection = Linear(d_pair, c) + self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.) + self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.) + + self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.) 
+ self.layernorm2 = LayerNorm(c) + self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False) + self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True) + + self.p_drop = p_drop + + def forward(self, Z_raw): + Z = self.layernorm1(Z_raw) + left_proj_act = self.left_projection(Z) + right_proj_act = self.right_projection(Z) + + left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z)) + right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z)) + + g = torch.sigmoid(self.output_gate(Z)) + # p = torch.matmul( + # permute_final_dims(left_proj_act, (2, 0, 1)), + # permute_final_dims(right_proj_act, (2, 1, 0)), + # ) + # ab = permute_final_dims(p, (1, 2, 0)) + + ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act) + ab = self.output_projection(self.layernorm2(ab)) + dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype) + return bias_ele_dropout_residual(ab, + self.output_bias, + g, + dropout_mask, + Z_raw, + prob=self.p_drop) + + +class TriangleMultiplicationIncoming(nn.Module): + + def __init__(self, d_pair, p_drop, c=128): + super(TriangleMultiplicationIncoming, self).__init__() + self.d_pair = d_pair + self.c = c + + self.layernorm1 = LayerNorm(d_pair) + self.left_projection = Linear(d_pair, c) + self.right_projection = Linear(d_pair, c) + self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.) + self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.) + + self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.) 
+ self.layernorm2 = LayerNorm(c) + self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False) + self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True) + + self.p_drop = p_drop + + def forward(self, Z_raw): + Z = self.layernorm1(Z_raw) + left_proj_act = self.left_projection(Z) + right_proj_act = self.right_projection(Z) + + left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z)) + right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z)) + + g = torch.sigmoid(self.output_gate(Z)) + # p = torch.matmul( + # permute_final_dims(left_proj_act, (2, 1, 0)), + # permute_final_dims(right_proj_act, (2, 0, 1)), + # ) + # ab = permute_final_dims(p, (1, 2, 0)) + + ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act) + ab = self.output_projection(self.layernorm2(ab)) + dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype) + return bias_ele_dropout_residual(ab, + self.output_bias, + g, + dropout_mask, + Z_raw, + prob=self.p_drop) + + +class TriangleAttentionStartingNode(nn.Module): + + def __init__(self, d_pair, p_drop, c=32, n_head=4): + super(TriangleAttentionStartingNode, self).__init__() + self.d_pair = d_pair + self.c = c + self.n_head = n_head + self.p_drop = p_drop + + self.layernorm1 = LayerNorm(d_pair) + _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]), + std=1.0 / math.sqrt(d_pair)) + self.linear_b_weights = nn.parameter.Parameter(data=_init_weights) + self.attention = SelfAttention(qkv_dim=d_pair, + c=c, + n_head=n_head, + out_dim=d_pair, + gating=True, + last_bias_fuse=True) + + self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True) + + def forward(self, Z_raw): + Z = self.layernorm1(Z_raw) + b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights) + + Z = self.attention(Z, b) + + dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype) + return bias_dropout_add(Z, self.out_bias, 
dropout_mask, Z_raw, prob=self.p_drop) + + +class TriangleAttentionEndingNode(nn.Module): + + def __init__(self, d_pair, p_drop, c=32, n_head=4): + super(TriangleAttentionEndingNode, self).__init__() + self.d_pair = d_pair + self.c = c + self.n_head = n_head + self.p_drop = p_drop + + self.layernorm1 = LayerNorm(d_pair) + _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]), + std=1.0 / math.sqrt(d_pair)) + self.linear_b_weights = nn.parameter.Parameter(data=_init_weights) + self.attention = SelfAttention(qkv_dim=d_pair, + c=c, + n_head=n_head, + out_dim=d_pair, + gating=True, + last_bias_fuse=True) + + self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True) + + def forward(self, Z_raw): + Z = Z_raw.transpose(-2, -3) + Z = self.layernorm1(Z) + b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights) + + Z = self.attention(Z, b) + + Z = Z.transpose(-2, -3) + dropout_mask = torch.ones_like(Z[:, :, 0:1, :]).to(Z.device).to(Z.dtype) + return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop) + + +class PairStack(nn.Module): + + def __init__(self, d_pair, p_drop=0.25): + super(PairStack, self).__init__() + + self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(d_pair, p_drop=p_drop) + self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(d_pair, p_drop=p_drop) + self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(d_pair, p_drop=p_drop) + self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(d_pair, p_drop=p_drop) + self.PairTransition = Transition(d=d_pair) + + def forward(self, pair): + pair = self.TriangleMultiplicationOutgoing(pair) + pair = self.TriangleMultiplicationIncoming(pair) + pair = self.TriangleAttentionStartingNode(pair) + pair = self.TriangleAttentionEndingNode(pair) + pair = self.PairTransition(pair) + return pair From 5c4df01af3076069867a66c5fc7a8086e6c55c0a Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 29 Dec 2022 15:54:08 
+0800 Subject: [PATCH 071/209] update openfold --- openfold/evoformer.py | 29 ++++++------------- openfold/msa.py | 67 ++----------------------------------------- 2 files changed, 12 insertions(+), 84 deletions(-) diff --git a/openfold/evoformer.py b/openfold/evoformer.py index 21e422b04764..7fbcd8a76b4d 100644 --- a/openfold/evoformer.py +++ b/openfold/evoformer.py @@ -182,33 +182,28 @@ def forward( self, m: torch.Tensor, z: torch.Tensor, - msa_mask: torch.Tensor, - pair_mask: torch.Tensor, chunk_size: Optional[int] = None, - _mask_trans: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: # DeepMind doesn't mask these transitions in the source, so _mask_trans # should be disabled to better approximate the exact activations of # the original. - msa_trans_mask = msa_mask if _mask_trans else None - pair_trans_mask = pair_mask if _mask_trans else None m = m + self.msa_transition( - m, mask=msa_trans_mask, chunk_size=chunk_size + m, chunk_size=chunk_size ) z = z + self.outer_product_mean( - m, mask=msa_mask, chunk_size=chunk_size + m, chunk_size=chunk_size ) - z = z + self.ps_dropout_row_layer(self.tri_mul_out(z, mask=pair_mask)) - z = z + self.ps_dropout_row_layer(self.tri_mul_in(z, mask=pair_mask)) + z = z + self.ps_dropout_row_layer(self.tri_mul_out(z)) + z = z + self.ps_dropout_row_layer(self.tri_mul_in(z)) z = z + self.ps_dropout_row_layer( - self.tri_att_start(z, mask=pair_mask, chunk_size=chunk_size) + self.tri_att_start(z, chunk_size=chunk_size) ) z = z + self.ps_dropout_col_layer( - self.tri_att_end(z, mask=pair_mask, chunk_size=chunk_size) + self.tri_att_end(z, chunk_size=chunk_size) ) z = z + self.pair_transition( - z, mask=pair_trans_mask, chunk_size=chunk_size + z, chunk_size=chunk_size ) return m, z @@ -274,22 +269,16 @@ def __init__(self, def forward(self, m: torch.Tensor, z: torch.Tensor, - msa_mask: torch.Tensor, - pair_mask: torch.Tensor, chunk_size: Optional[int] = None, - _mask_trans: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: m = m + 
self.msa_dropout_layer( - self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size) + self.msa_att_row(m, z=z, chunk_size=chunk_size) ) - m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size) + m = m + self.msa_att_col(m, chunk_size=chunk_size) m, z = self.core( m, z, - msa_mask=msa_mask, - pair_mask=pair_mask, chunk_size=chunk_size, - _mask_trans=_mask_trans, ) return m, z diff --git a/openfold/msa.py b/openfold/msa.py index 172b26def5f1..00b822e7f390 100644 --- a/openfold/msa.py +++ b/openfold/msa.py @@ -136,45 +136,6 @@ def _prep_inputs(self, return m, mask_bias, z - @torch.jit.ignore - def _chunked_msa_attn(self, - m: torch.Tensor, - z: Optional[torch.Tensor], - mask: Optional[torch.Tensor], - chunk_logits: int, - checkpoint: bool, - ) -> torch.Tensor: - MSA_DIM = -4 - - def _get_qkv(m, z): - m, mask_bias, z = self._prep_inputs(m, z, mask) - q, k, v = self.mha._prep_qkv(m, m) - return m, q, k, v, mask_bias, z - - checkpoint_fn = get_checkpoint_fn() - - if(torch.is_grad_enabled() and checkpoint): - m, q, k, v, mask_bias, z = checkpoint_fn(_get_qkv, m, z) - else: - m, q, k, v, mask_bias, z = _get_qkv(m, z) - - o = _attention_chunked_trainable( - query=q, - key=k, - value=v, - biases=[mask_bias, z], - chunk_size=chunk_logits, - chunk_dim=MSA_DIM, - checkpoint=checkpoint, - ) - - if(torch.is_grad_enabled() and checkpoint): - # Storing an additional m here is far from ideal - m = checkpoint_fn(self.mha._wrap_up, o, m) - else: - m = self.mha._wrap_up(o, m) - - return m def forward(self, m: torch.Tensor, @@ -199,12 +160,6 @@ def forward(self, cost of slower execution. Chunking is not performed by default. 
""" - if(_chunk_logits is not None): - return self._chunked_msa_attn( - m=m, z=z, mask=mask, - chunk_logits=_chunk_logits, checkpoint=_checkpoint_chunks - ) - m, mask_bias, z = self._prep_inputs(m, z, mask) biases = [mask_bias] @@ -306,15 +261,11 @@ def forward(self, """ # [*, N_res, N_seq, C_in] m = m.transpose(-2, -3) - if mask is not None: - mask = mask.transpose(-1, -2) - m = self._msa_att(m, mask=mask, chunk_size=chunk_size) + m = self._msa_att(m, chunk_size=chunk_size) # [*, N_seq, N_res, C_in] m = m.transpose(-2, -3) - if mask is not None: - mask = mask.transpose(-1, -2) return m @@ -344,12 +295,10 @@ def __init__( @torch.jit.ignore def _chunk(self, m: torch.Tensor, - mask: torch.Tensor, chunk_size: int, ) -> torch.Tensor: mha_input = { "m": m, - "mask": mask, } return chunk_layer( self.global_attention, @@ -361,30 +310,20 @@ def _chunk(self, def forward( self, m: torch.Tensor, - mask: Optional[torch.Tensor] = None, chunk_size: Optional[int] = None, ) -> torch.Tensor: n_seq, n_res, c_in = m.shape[-3:] - if mask is None: - # [*, N_seq, N_res] - mask = torch.ones( - m.shape[:-1], - dtype=m.dtype, - device=m.device, - ).detach() - # [*, N_res, N_seq, C_in] m = m.transpose(-2, -3) - mask = mask.transpose(-1, -2) # [*, N_res, N_seq, C_in] m = self.layer_norm_m(m) if chunk_size is not None: - m = self._chunk(m, mask, chunk_size) + m = self._chunk(m, chunk_size) else: - m = self.global_attention(m=m, mask=mask) + m = self.global_attention(m=m) # [*, N_seq, N_res, C_in] m = m.transpose(-2, -3) From f7d8092c84eef1a5dfd976f883a6d38d5b11bd68 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 29 Dec 2022 16:01:05 +0800 Subject: [PATCH 072/209] align openfold --- autochunk_benchmark.py | 41 ++++++- evoformer_openfold/evoformer.py | 59 --------- evoformer_openfold/initializer.py | 29 ----- evoformer_openfold/kernel.py | 19 --- evoformer_openfold/msa.py | 95 --------------- evoformer_openfold/ops.py | 176 --------------------------- evoformer_openfold/triangle.py | 192 
----------------------------- openfold/evoformer.py | 194 ------------------------------ 8 files changed, 36 insertions(+), 769 deletions(-) delete mode 100644 evoformer_openfold/evoformer.py delete mode 100755 evoformer_openfold/initializer.py delete mode 100644 evoformer_openfold/kernel.py delete mode 100644 evoformer_openfold/msa.py delete mode 100755 evoformer_openfold/ops.py delete mode 100644 evoformer_openfold/triangle.py diff --git a/autochunk_benchmark.py b/autochunk_benchmark.py index 20f615b216f7..679016438c59 100644 --- a/autochunk_benchmark.py +++ b/autochunk_benchmark.py @@ -9,20 +9,27 @@ from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.profiler import MetaTensor from evoformer.evoformer import evoformer_base +from openfold.evoformer import EvoformerBlock -def _benchmark_evoformer(model: torch.nn.Module, node, pair, title): +def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None): torch.cuda.reset_peak_memory_stats() now_mem = torch.cuda.memory_allocated() / 1024**2 loop = 16 with torch.no_grad(): for _ in range(loop // 4): - model(node, pair) + if chunk_size: + model(node, pair, chunk_size) + else: + model(node, pair) torch.cuda.synchronize() time1 = time.time() for _ in range(loop): - model(node, pair) + if chunk_size: + model(node, pair, chunk_size) + else: + model(node, pair) torch.cuda.synchronize() time2 = time.time() @@ -64,6 +71,26 @@ def _build_autochunk(model, max_memory, node, pair): return gm +def _build_openfold(): + model = EvoformerBlock( + c_m=256, + c_z=128, + c_hidden_msa_att=32, + c_hidden_opm=32, + c_hidden_mul=128, + c_hidden_pair_att=32, + no_heads_msa=8, + no_heads_pair=4, + transition_n=4, + msa_dropout=0.15, + pair_dropout=0.15, + inf=1e4, + eps=1e-4, + is_multimer=False, + ).cuda() + return model + + def benchmark_evoformer(): # init data and model msa_len = 300 @@ -74,10 +101,14 @@ def benchmark_evoformer(): # build autochunk model max_memory = 3000 # MB - autochunk 
= _build_autochunk(model, max_memory, node, pair) + autochunk = _build_autochunk(evoformer_base().cuda(), max_memory, node, pair) + + # build openfold + openfold = _build_openfold() # benchmark - _benchmark_evoformer(model, node, pair, "openfold") + _benchmark_evoformer(model, node, pair, "base") + _benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=4) _benchmark_evoformer(autochunk, node, pair, "autochunk") diff --git a/evoformer_openfold/evoformer.py b/evoformer_openfold/evoformer.py deleted file mode 100644 index cfd2bb2a2529..000000000000 --- a/evoformer_openfold/evoformer.py +++ /dev/null @@ -1,59 +0,0 @@ -import torch -import torch.nn as nn - -from .msa import MSAStack -from .ops import OutProductMean -from .triangle import PairStack - - -def print_memory(init_mem, text=None): - now_mem = torch.cuda.memory_allocated() / 1024 ** 2 - init_mem - max_mem = torch.cuda.max_memory_allocated() / 1024 ** 2 - init_mem - print("%s now:%.2f max:%.2f" % ("" if text is None else text, now_mem, max_mem)) - torch.cuda.reset_peak_memory_stats() - - -class EvoformerBlock(nn.Module): - - def __init__(self, d_node, d_pair): - super(EvoformerBlock, self).__init__() - - self.msa_stack = MSAStack(d_node, d_pair, p_drop=0.15) - self.communication = OutProductMean(n_feat=d_node, n_feat_out=d_pair, n_feat_proj=32) - self.pair_stack = PairStack(d_pair=d_pair) - - def forward(self, node, pair): - node = self.msa_stack(node, pair) - pair = pair + self.communication(node) - pair = self.pair_stack(pair) - return node, pair - - -class Evoformer(nn.Module): - - def __init__(self, d_node, d_pair): - super(Evoformer, self).__init__() - - self.blocks = nn.ModuleList() - for _ in range(1): - self.blocks.append(EvoformerBlock(d_node, d_pair)) - - def forward(self, node, pair): - for b in self.blocks: - node, pair = b(node, pair) - return node, pair - - -def evoformer_tiny(): - return Evoformer(d_node=64, d_pair=32) - - -def evoformer_base(): - return Evoformer(d_node=256, 
d_pair=128) - - -def evoformer_large(): - return Evoformer(d_node=512, d_pair=256) - - -__all__ = ['Evoformer', 'evoformer_base', 'evoformer_large'] diff --git a/evoformer_openfold/initializer.py b/evoformer_openfold/initializer.py deleted file mode 100755 index c6ce0659e597..000000000000 --- a/evoformer_openfold/initializer.py +++ /dev/null @@ -1,29 +0,0 @@ -import math - -import numpy as np -import torch.nn as nn - - -def glorot_uniform_af(x, gain=1.0): - """ - initialize tensors the same as xavier_initializer in PyTorch, but the dimensions are different: - In PyTorch: - [feature_out, feature_in, n_head ...] - In Jax: - [... n_head, feature_in, feature_out] - However, there is a feature in original Alphafold2 code that they use the Jax version initializer to initialize tensors like: - [feature_in, n_head, feature_out] - - In this function, we keep this feature to initialize [feature_in, n_head, ..., feature_out] tensors - """ - fan_in, fan_out = x.shape[-2:] - if len(x.shape) > 2: - receptive_field_size = np.prod(x.shape[:-2]) - fan_in *= receptive_field_size - fan_out *= receptive_field_size - std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) - dev = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation - - nn.init.uniform_(x, -dev, dev) - - return x diff --git a/evoformer_openfold/kernel.py b/evoformer_openfold/kernel.py deleted file mode 100644 index 26ab5dc53261..000000000000 --- a/evoformer_openfold/kernel.py +++ /dev/null @@ -1,19 +0,0 @@ -import torch -import torch.nn.functional as F - - -def bias_sigmod_ele(y, bias, z): - return torch.sigmoid(y + bias) * z - - -def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor, - residual: torch.Tensor, prob: float) -> torch.Tensor: - out = (x + bias) * F.dropout(dropmask, p=prob, training=False) - out = residual + out - return out - - -def bias_ele_dropout_residual(ab: torch.Tensor, b: torch.Tensor, g: torch.Tensor, - dropout_mask: torch.Tensor, Z_raw: torch.Tensor, - 
prob: float) -> torch.Tensor: - return Z_raw + F.dropout(dropout_mask, p=prob, training=True) * (g * (ab + b)) \ No newline at end of file diff --git a/evoformer_openfold/msa.py b/evoformer_openfold/msa.py deleted file mode 100644 index cac456638a55..000000000000 --- a/evoformer_openfold/msa.py +++ /dev/null @@ -1,95 +0,0 @@ -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from torch.nn import LayerNorm - -from .kernel import bias_dropout_add -from .ops import SelfAttention, Transition - - -class MSARowAttentionWithPairBias(nn.Module): - - def __init__(self, d_node, d_pair, c=32, n_head=8, p_drop=0.15): - super(MSARowAttentionWithPairBias, self).__init__() - self.d_node = d_node - self.d_pair = d_pair - self.c = c - self.n_head = n_head - self.p_drop = p_drop - - self.layernormM = LayerNorm(d_node) - self.layernormZ = LayerNorm(d_pair) - - _init_weights = torch.nn.init.normal_(torch.zeros([n_head, d_pair]), - std=1.0 / math.sqrt(d_pair)) - self.linear_b_weights = nn.parameter.Parameter(data=_init_weights, requires_grad=True) - - self.attention = SelfAttention(qkv_dim=d_node, - c=c, - n_head=n_head, - out_dim=d_node, - gating=True, - last_bias_fuse=True) - - self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_node,)), requires_grad=True) - - def forward(self, M_raw, Z): - ## Input projections - M = self.layernormM(M_raw) - Z = self.layernormZ(Z) - b = F.linear(Z, self.linear_b_weights) - b = b.permute(0, 3, 1, 2) - # b = rearrange(b, 'b q k h -> b h q k') - - M = self.attention(M, b) - dropout_mask = torch.ones_like(M[:, 0:1, :, :]).to(M.device).to(M.dtype) - - return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop) - - -class MSAColumnAttention(nn.Module): - - def __init__(self, d_node, c=32, n_head=8): - super(MSAColumnAttention, self).__init__() - self.d_node = d_node - self.c = c - self.n_head = n_head - - self.layernormM = LayerNorm(d_node) - self.attention = 
SelfAttention(qkv_dim=d_node, - c=c, - n_head=n_head, - out_dim=d_node, - gating=True) - - def forward(self, M_raw): - M = M_raw.transpose(-2, -3) - M = self.layernormM(M) - - M = self.attention(M) - - M = M.transpose(-2, -3) - return M_raw + M - - -class MSAStack(nn.Module): - - def __init__(self, d_node, d_pair, p_drop=0.15): - super(MSAStack, self).__init__() - - self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(d_node=d_node, - d_pair=d_pair, - p_drop=p_drop) - - self.MSAColumnAttention = MSAColumnAttention(d_node=d_node) - self.MSATransition = Transition(d=d_node) - - def forward(self, node, pair): - node = self.MSARowAttentionWithPairBias(node, pair) - node = self.MSAColumnAttention(node) - node = self.MSATransition(node) - - return node diff --git a/evoformer_openfold/ops.py b/evoformer_openfold/ops.py deleted file mode 100755 index 611b7b0fe777..000000000000 --- a/evoformer_openfold/ops.py +++ /dev/null @@ -1,176 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from torch.nn import LayerNorm - -from .initializer import glorot_uniform_af -from .kernel import bias_sigmod_ele - - -class DropoutRowwise(nn.Module): - - def __init__(self, p): - super(DropoutRowwise, self).__init__() - self.p = p - self.dropout = nn.Dropout(p=p) - - def forward(self, x): - dropout_mask = torch.ones_like(x[:, 0:1, :, :]) - dropout_mask = self.dropout(dropout_mask) - return dropout_mask * x - - -class DropoutColumnwise(nn.Module): - - def __init__(self, p): - super(DropoutColumnwise, self).__init__() - self.p = p - self.dropout = nn.Dropout(p=p) - - def forward(self, x): - dropout_mask = torch.ones_like(x[:, :, 0:1, :]) - dropout_mask = self.dropout(dropout_mask) - return dropout_mask * x - - -class Transition(nn.Module): - - def __init__(self, d, n=4): - super(Transition, self).__init__() - self.norm = LayerNorm(d) - self.linear1 = Linear(d, n * d, initializer='relu') - self.linear2 = Linear(n * d, d, 
initializer='zeros') - - def forward(self, src): - x = self.norm(src) - x = self.linear2(F.relu(self.linear1(x))) - return src + x - - -class OutProductMean(nn.Module): - - def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32): - super(OutProductMean, self).__init__() - - self.layernormM = LayerNorm(n_feat) - self.linear_a = Linear(n_feat, n_feat_proj) - self.linear_b = Linear(n_feat, n_feat_proj) - - self.o_linear = Linear(n_feat_proj * n_feat_proj, - n_feat_out, - initializer='zero', - use_bias=True) - - def forward(self, M): - M = self.layernormM(M) - left_act = self.linear_a(M) - right_act = self.linear_b(M) - - O = torch.einsum('bsid,bsje->bijde', left_act, right_act).contiguous() - # O = rearrange(O, 'b i j d e -> b i j (d e)') - O = O.reshape(O.shape[0], O.shape[1], O.shape[2], -1) - Z = self.o_linear(O) - - return Z - - -class Linear(nn.Linear): - """ - A Linear layer with built-in nonstandard initializations. Called just - like torch.nn.Linear. - Implements the initializers in 1.11.4, plus some additional ones found - in the code. 
- """ - - def __init__( - self, - feature_in: int, - feature_out: int, - initializer: str = 'linear', - use_bias: bool = True, - bias_init: float = 0., - ): - super(Linear, self).__init__(feature_in, feature_out, bias=use_bias) - - self.use_bias = use_bias - if initializer == 'linear': - glorot_uniform_af(self.weight, gain=1.0) - elif initializer == 'relu': - glorot_uniform_af(self.weight, gain=2.0) - elif initializer == 'zeros': - nn.init.zeros_(self.weight) - if self.use_bias: - with torch.no_grad(): - self.bias.fill_(bias_init) - - -class SelfAttention(nn.Module): - """ - Multi-Head SelfAttention dealing with [batch_size1, batch_size2, len, dim] tensors - """ - - def __init__(self, qkv_dim, c, n_head, out_dim, gating=True, last_bias_fuse=False): - super(SelfAttention, self).__init__() - self.qkv_dim = qkv_dim - self.c = c - self.n_head = n_head - self.out_dim = out_dim - self.gating = gating - self.last_bias_fuse = last_bias_fuse - - self.scaling = self.c**(-0.5) - - # self.to_qkv = Linear(qkv_dim, 3 * n_head * c, initializer='linear') - self.to_q = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False) - self.to_k = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False) - self.to_v = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False) - - if gating: - self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,))) - self.gating_linear = Linear(qkv_dim, n_head * c, initializer='zero', use_bias=False) - - self.o_linear = Linear(n_head * c, - out_dim, - initializer='zero', - use_bias=(not last_bias_fuse)) - - def forward(self, in_data, nonbatched_bias=None): - """ - :param in_data: [batch_size1, batch_size2, len_qkv, qkv_dim] - :param bias: None or [batch_size1, batch_size2, n_head, len_q, len_kv] - :param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv] - """ - - # qkv = self.to_qkv(in_data).chunk(3, dim=-1) - # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv) - - 
q = self.to_q(in_data) - k = self.to_k(in_data) - v = self.to_v(in_data) - - # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), - # [q, k, v]) - q, k, v = map(lambda t: t.view(t.shape[0], t.shape[1], t.shape[2], self.n_head, -1).permute(0, 1, 3, 2, 4), - [q, k, v]) - - q = q * self.scaling - - logits = torch.matmul(q, k.transpose(-1, -2)) - - if nonbatched_bias is not None: - logits += nonbatched_bias.unsqueeze(1) - weights = torch.softmax(logits, dim=-1) - # weights = softmax(logits) - - weighted_avg = torch.matmul(weights, v) - # weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)') - weighted_avg = weighted_avg.permute(0, 1, 3, 2, 4) - weighted_avg = weighted_avg.reshape(weighted_avg.shape[0], weighted_avg.shape[1], weighted_avg.shape[2], -1) - - if self.gating: - gate_values = self.gating_linear(in_data) - weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg) - - output = self.o_linear(weighted_avg) - return output diff --git a/evoformer_openfold/triangle.py b/evoformer_openfold/triangle.py deleted file mode 100644 index f479469c3836..000000000000 --- a/evoformer_openfold/triangle.py +++ /dev/null @@ -1,192 +0,0 @@ -import math - -import torch -import torch.nn as nn -from torch.nn import LayerNorm - -from .kernel import bias_dropout_add, bias_ele_dropout_residual -from .ops import Linear, SelfAttention, Transition - - -def permute_final_dims(tensor, inds): - zero_index = -1 * len(inds) - first_inds = list(range(len(tensor.shape[:zero_index]))) - return tensor.permute(first_inds + [zero_index + i for i in inds]) - - -class TriangleMultiplicationOutgoing(nn.Module): - - def __init__(self, d_pair, p_drop, c=128): - super(TriangleMultiplicationOutgoing, self).__init__() - self.d_pair = d_pair - self.c = c - - self.layernorm1 = LayerNorm(d_pair) - self.left_projection = Linear(d_pair, c) - self.right_projection = Linear(d_pair, c) - self.left_gate = Linear(d_pair, c, initializer='zeros', 
bias_init=1.) - self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.) - - self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.) - self.layernorm2 = LayerNorm(c) - self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False) - self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True) - - self.p_drop = p_drop - - def forward(self, Z_raw): - Z = self.layernorm1(Z_raw) - left_proj_act = self.left_projection(Z) - right_proj_act = self.right_projection(Z) - - left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z)) - right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z)) - - g = torch.sigmoid(self.output_gate(Z)) - # p = torch.matmul( - # permute_final_dims(left_proj_act, (2, 0, 1)), - # permute_final_dims(right_proj_act, (2, 1, 0)), - # ) - # ab = permute_final_dims(p, (1, 2, 0)) - - ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act) - ab = self.output_projection(self.layernorm2(ab)) - dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype) - return bias_ele_dropout_residual(ab, - self.output_bias, - g, - dropout_mask, - Z_raw, - prob=self.p_drop) - - -class TriangleMultiplicationIncoming(nn.Module): - - def __init__(self, d_pair, p_drop, c=128): - super(TriangleMultiplicationIncoming, self).__init__() - self.d_pair = d_pair - self.c = c - - self.layernorm1 = LayerNorm(d_pair) - self.left_projection = Linear(d_pair, c) - self.right_projection = Linear(d_pair, c) - self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.) - self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.) - - self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.) 
- self.layernorm2 = LayerNorm(c) - self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False) - self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True) - - self.p_drop = p_drop - - def forward(self, Z_raw): - Z = self.layernorm1(Z_raw) - left_proj_act = self.left_projection(Z) - right_proj_act = self.right_projection(Z) - - left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z)) - right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z)) - - g = torch.sigmoid(self.output_gate(Z)) - # p = torch.matmul( - # permute_final_dims(left_proj_act, (2, 1, 0)), - # permute_final_dims(right_proj_act, (2, 0, 1)), - # ) - # ab = permute_final_dims(p, (1, 2, 0)) - - ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act) - ab = self.output_projection(self.layernorm2(ab)) - dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype) - return bias_ele_dropout_residual(ab, - self.output_bias, - g, - dropout_mask, - Z_raw, - prob=self.p_drop) - - -class TriangleAttentionStartingNode(nn.Module): - - def __init__(self, d_pair, p_drop, c=32, n_head=4): - super(TriangleAttentionStartingNode, self).__init__() - self.d_pair = d_pair - self.c = c - self.n_head = n_head - self.p_drop = p_drop - - self.layernorm1 = LayerNorm(d_pair) - _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]), - std=1.0 / math.sqrt(d_pair)) - self.linear_b_weights = nn.parameter.Parameter(data=_init_weights) - self.attention = SelfAttention(qkv_dim=d_pair, - c=c, - n_head=n_head, - out_dim=d_pair, - gating=True, - last_bias_fuse=True) - - self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True) - - def forward(self, Z_raw): - Z = self.layernorm1(Z_raw) - b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights) - - Z = self.attention(Z, b) - - dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype) - return bias_dropout_add(Z, self.out_bias, 
dropout_mask, Z_raw, prob=self.p_drop) - - -class TriangleAttentionEndingNode(nn.Module): - - def __init__(self, d_pair, p_drop, c=32, n_head=4): - super(TriangleAttentionEndingNode, self).__init__() - self.d_pair = d_pair - self.c = c - self.n_head = n_head - self.p_drop = p_drop - - self.layernorm1 = LayerNorm(d_pair) - _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]), - std=1.0 / math.sqrt(d_pair)) - self.linear_b_weights = nn.parameter.Parameter(data=_init_weights) - self.attention = SelfAttention(qkv_dim=d_pair, - c=c, - n_head=n_head, - out_dim=d_pair, - gating=True, - last_bias_fuse=True) - - self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True) - - def forward(self, Z_raw): - Z = Z_raw.transpose(-2, -3) - Z = self.layernorm1(Z) - b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights) - - Z = self.attention(Z, b) - - Z = Z.transpose(-2, -3) - dropout_mask = torch.ones_like(Z[:, :, 0:1, :]).to(Z.device).to(Z.dtype) - return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop) - - -class PairStack(nn.Module): - - def __init__(self, d_pair, p_drop=0.25): - super(PairStack, self).__init__() - - self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(d_pair, p_drop=p_drop) - self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(d_pair, p_drop=p_drop) - self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(d_pair, p_drop=p_drop) - self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(d_pair, p_drop=p_drop) - self.PairTransition = Transition(d=d_pair) - - def forward(self, pair): - pair = self.TriangleMultiplicationOutgoing(pair) - pair = self.TriangleMultiplicationIncoming(pair) - pair = self.TriangleAttentionStartingNode(pair) - pair = self.TriangleAttentionEndingNode(pair) - pair = self.PairTransition(pair) - return pair diff --git a/openfold/evoformer.py b/openfold/evoformer.py index 7fbcd8a76b4d..ffd4c982987a 100644 --- 
a/openfold/evoformer.py +++ b/openfold/evoformer.py @@ -284,104 +284,6 @@ def forward(self, return m, z -class ExtraMSABlock(nn.Module): - """ - Almost identical to the standard EvoformerBlock, except in that the - ExtraMSABlock uses GlobalAttention for MSA column attention and - requires more fine-grained control over checkpointing. Separated from - its twin to preserve the TorchScript-ability of the latter. - """ - def __init__(self, - c_m: int, - c_z: int, - c_hidden_msa_att: int, - c_hidden_opm: int, - c_hidden_mul: int, - c_hidden_pair_att: int, - no_heads_msa: int, - no_heads_pair: int, - transition_n: int, - msa_dropout: float, - pair_dropout: float, - inf: float, - eps: float, - ckpt: bool, - is_multimer: bool, - ): - super(ExtraMSABlock, self).__init__() - - self.ckpt = ckpt - - self.msa_att_row = MSARowAttentionWithPairBias( - c_m=c_m, - c_z=c_z, - c_hidden=c_hidden_msa_att, - no_heads=no_heads_msa, - inf=inf, - ) - - self.msa_att_col = MSAColumnGlobalAttention( - c_in=c_m, - c_hidden=c_hidden_msa_att, - no_heads=no_heads_msa, - inf=inf, - eps=eps, - ) - - self.msa_dropout_layer = DropoutRowwise(msa_dropout) - - self.core = EvoformerBlockCore( - c_m=c_m, - c_z=c_z, - c_hidden_opm=c_hidden_opm, - c_hidden_mul=c_hidden_mul, - c_hidden_pair_att=c_hidden_pair_att, - no_heads_msa=no_heads_msa, - no_heads_pair=no_heads_pair, - transition_n=transition_n, - pair_dropout=pair_dropout, - inf=inf, - eps=eps, - ) - self.is_multimer = is_multimer - - def forward(self, - m: torch.Tensor, - z: torch.Tensor, - msa_mask: torch.Tensor, - pair_mask: torch.Tensor, - chunk_size: Optional[int] = None, - _chunk_logits: Optional[int] = 1024, - ) -> Tuple[torch.Tensor, torch.Tensor]: - m = m + self.msa_dropout_layer( - self.msa_att_row( - m.clone(), - z=z.clone(), - mask=msa_mask, - chunk_size=chunk_size, - _chunk_logits=_chunk_logits if torch.is_grad_enabled() else None, - _checkpoint_chunks= - self.ckpt if torch.is_grad_enabled() else False, - ) - ) - - def fn(m, z): - m = m + 
self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size) - m, z = self.core( - m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size - ) - - return m, z - - if(torch.is_grad_enabled() and self.ckpt): - checkpoint_fn = get_checkpoint_fn() - m, z = checkpoint_fn(fn, m, z) - else: - m, z = fn(m, z) - - return m, z - - class EvoformerStack(nn.Module): """ Main Evoformer trunk. @@ -527,99 +429,3 @@ def block_with_cache_clear(block, *args): s = self.linear(m[..., 0, :, :]) return m, z, s - - -class ExtraMSAStack(nn.Module): - """ - Implements Algorithm 18. - """ - - def __init__(self, - c_m: int, - c_z: int, - c_hidden_msa_att: int, - c_hidden_opm: int, - c_hidden_mul: int, - c_hidden_pair_att: int, - no_heads_msa: int, - no_heads_pair: int, - no_blocks: int, - transition_n: int, - msa_dropout: float, - pair_dropout: float, - inf: float, - eps: float, - ckpt: bool, - clear_cache_between_blocks: bool = False, - is_multimer: bool = False, - **kwargs, - ): - super(ExtraMSAStack, self).__init__() - - self.clear_cache_between_blocks = clear_cache_between_blocks - self.blocks = nn.ModuleList() - for _ in range(no_blocks): - block = ExtraMSABlock( - c_m=c_m, - c_z=c_z, - c_hidden_msa_att=c_hidden_msa_att, - c_hidden_opm=c_hidden_opm, - c_hidden_mul=c_hidden_mul, - c_hidden_pair_att=c_hidden_pair_att, - no_heads_msa=no_heads_msa, - no_heads_pair=no_heads_pair, - transition_n=transition_n, - msa_dropout=msa_dropout, - pair_dropout=pair_dropout, - inf=inf, - eps=eps, - ckpt=ckpt, - is_multimer=is_multimer, - ) - self.blocks.append(block) - - def forward(self, - m: torch.Tensor, - z: torch.Tensor, - chunk_size: int, - msa_mask: Optional[torch.Tensor] = None, - pair_mask: Optional[torch.Tensor] = None, - _mask_trans: bool = True, - ) -> torch.Tensor: - """ - Args: - m: - [*, N_extra, N_res, C_m] extra MSA embedding - z: - [*, N_res, N_res, C_z] pair embedding - msa_mask: - Optional [*, N_extra, N_res] MSA mask - pair_mask: - Optional [*, N_res, N_res] pair mask - 
Returns: - [*, N_res, N_res, C_z] pair update - """ - #checkpoint_fn = get_checkpoint_fn() - #blocks = [ - # partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks - #] - - #def dodo(b, *args): - # torch.cuda.empty_cache() - # return b(*args) - - #blocks = [partial(dodo, b) for b in blocks] - - #for b in blocks: - # if(torch.is_grad_enabled()): - # m, z = checkpoint_fn(b, *(m, z)) - # else: - # m, z = b(m, z) - - for b in self.blocks: - m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size) - - if(self.clear_cache_between_blocks): - torch.cuda.empty_cache() - - return z \ No newline at end of file From f5515e9978564bddc0ff97c06c7a6933668e7cef Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 29 Dec 2022 16:55:47 +0800 Subject: [PATCH 073/209] use max_mem to control stratge --- chunk_codegen.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 033db50dbccb..1c8be65d490a 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -1398,14 +1398,18 @@ def estimate_chunk_inference_mem( class ChunkSelector(object): def __init__( - self, index_tracer: IndexTracer, memory_estimator: MemoryEstimator, stratge, max_memory=None + self, + index_tracer: IndexTracer, + memory_estimator: MemoryEstimator, + max_memory=None, ): self.index_tracer = index_tracer self.memory_estimator = memory_estimator - assert stratge in ["min_memory", "fit_memory"] - assert (stratge == "fit_memory" and max_memory is not None) or stratge != "fit_memory" - self.stratge = stratge - self.max_memory = max_memory # MB + if max_memory is not None: + self.stratge = "fit_memory" + self.max_memory = max_memory # MB + else: + self.stratge = "min_memory" def _select_best_chunk_region( self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak @@ -1538,6 +1542,8 @@ def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos): 
possible_chunk_regions.remove(i) max_region_range = 0 best_region = None + if best_region is not None: + best_region["chunk_size"] = 2 return best_region def _is_legal_region(self, cur_chunk_info, chunk_infos): @@ -1563,7 +1569,7 @@ def __init__(self, gm, max_memory=None) -> None: self.index_tracer.trace_index() self.memory_estimator = MemoryEstimator(self.index_tracer) self.chunk_selector = ChunkSelector( - self.index_tracer, self.memory_estimator, stratge="fit_memory", max_memory=max_memory + self.index_tracer, self.memory_estimator, max_memory=max_memory ) def _find_peak_node(self, mem_peak): @@ -2233,7 +2239,7 @@ def emit_node(node: Node, body): delete_unused_values, self.meta_node, self.meta_graph, - self.max_memory + self.max_memory, ) if len(body) == 0: From e5a5fbb8a94313722542b72f601b8433eef1e5dc Mon Sep 17 00:00:00 2001 From: oahzxl Date: Sat, 31 Dec 2022 01:00:06 +0800 Subject: [PATCH 074/209] update source add --- chunk_codegen.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 1c8be65d490a..de58a61b943b 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -133,24 +133,28 @@ def _inherit_all_computation(self, node_from, node_to): def _add_source(self, node_from, node_from_dim, node_to, node_to_dim, init=False): node_from_dim = self._transform_index(node_from, node_from_dim) - node_from_trace = self._find_trace_from_node(node_from) + node_from_trace_source = self._find_source_trace_from_node(node_from) node_to_dim = self._transform_index(node_to, node_to_dim) - node_to_trace = self._find_trace_from_node(node_to) + node_to_trace_source = self._find_source_trace_from_node(node_to) node_from_idx = _find_idx_by_name(node_from.name, self.node_list) if init: - node_to_trace["source"][node_to_dim] = {} + node_to_trace_source[node_to_dim] = {} # add dim to cur new source - if node_from_idx not in node_to_trace["source"][node_to_dim]: - 
node_to_trace["source"][node_to_dim][node_from_idx] = [node_from_dim] + if node_from_idx not in node_to_trace_source[node_to_dim]: + node_to_trace_source[node_to_dim][node_from_idx] = [node_from_dim] else: - if node_from_dim not in node_to_trace["source"][node_to_dim][node_from_idx]: - node_to_trace["source"][node_to_dim][node_from_idx].append( + if node_from_dim not in node_to_trace_source[node_to_dim][node_from_idx]: + node_to_trace_source[node_to_dim][node_from_idx].append( node_from_dim ) # update inputs source - node_to_trace["source"][node_to_dim].update( - node_from_trace["source"][node_from_dim] - ) + for node_idx, node_dim in node_from_trace_source[node_from_dim].items(): + if node_idx not in node_to_trace_source[node_to_dim]: + node_to_trace_source[node_to_dim][node_idx] = copy.deepcopy(node_dim) + else: + for d in node_dim: + if d not in node_to_trace_source[node_to_dim][node_idx]: + node_to_trace_source[node_to_dim][node_idx].append(d) def _mark_computation_from_node(self, node_from, node_to, exclude=None): if exclude == None: @@ -1761,9 +1765,9 @@ def search_region(self): ) if self._stop_search(init_mem_peak, mem_peak): break - # self.memory_estimator.estimate_chunk_inference_mem( - # self.index_tracer.node_list, chunk_infos, print_mem=True - # ) + self.memory_estimator.estimate_chunk_inference_mem( + self.index_tracer.node_list, chunk_infos, print_mem=True + ) return chunk_infos From 966e4ea0cbf1cd17696aa90b6b9bd4a6999cfba4 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Sat, 31 Dec 2022 02:20:07 +0800 Subject: [PATCH 075/209] add reorder in mem estimator --- chunk_codegen.py | 43 ++++++++++++++++++++++++++++++++----------- 1 file changed, 32 insertions(+), 11 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index de58a61b943b..e20d151da1fb 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -1040,11 +1040,13 @@ def _reorder_chunk_info(self, chunk_info, reorder_map): chunk_info["region"][0] + len(chunk_info["args"]["prepose_nodes"]), 
chunk_info["region"][1], ) + new_inputs_dim = [] for idx, input_dim in enumerate(chunk_info["inputs_dim"]): new_input_dim = {} for k, v in input_dim.items(): new_input_dim[reorder_map[k]] = v - chunk_info["inputs_dim"][idx] = new_input_dim + new_inputs_dim.append(new_input_dim) + chunk_info["inputs_dim"] = new_inputs_dim return chunk_info def _update_all_reorder_map(self, reorder_map): @@ -1095,11 +1097,24 @@ def reorder_node_list(self, node_list): for old_idx, new_idx in self.all_reorder_map.items(): new_node_list[new_idx] = node_list[old_idx] return new_node_list + + def tmp_reorder(self, node_list, chunk_info): + if len(chunk_info["args"]["prepose_nodes"]) == 0: + return node_list, chunk_info + reorder_map = self._get_reorder_map(chunk_info) + + # new tmp node list + new_node_list = [None for _ in range(len(node_list))] + for old_idx, new_idx in reorder_map.items(): + new_node_list[new_idx] = node_list[old_idx] + + chunk_info = self._reorder_chunk_info(chunk_info, reorder_map) + return new_node_list, chunk_info class MemoryEstimator(object): def __init__(self, index_tracer: IndexTracer) -> None: - self.index_tracer = index_tracer + pass def _get_meta_node_size(self, x): x = x.meta["tensor_meta"] @@ -1453,9 +1468,11 @@ def _select_fit_memory_chunk_region( # get mem for chunk region regions_dict = [] for region in possible_chunk_regions: - cur_chunk_infos = chunk_infos + [region] + cur_region = region.copy() + cur_node_list, cur_region = self.index_tracer.tmp_reorder(self.index_tracer.node_list, cur_region) + cur_chunk_infos = chunk_infos + [cur_region] cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( - self.index_tracer.node_list, cur_chunk_infos + cur_node_list, cur_chunk_infos )[0] cur_chunk_region_peak = cur_mem_peak[ max_chunk_region[0] : max_chunk_region[1] + 1 @@ -1492,9 +1509,11 @@ def _get_fit_chunk_size(self, chunk_info, chunk_infos): while cur_chunk_max_mem < self.max_memory: chunk_size *= 2 chunk_info["chunk_size"] = chunk_size - 
cur_chunk_infos = chunk_infos + [chunk_info] + cur_chunk_info = chunk_info.copy() + cur_node_list, cur_chunk_info = self.index_tracer.tmp_reorder(self.index_tracer.node_list, cur_chunk_info) + cur_chunk_infos = chunk_infos + [cur_chunk_info] cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( - self.index_tracer.node_list, cur_chunk_infos + cur_node_list, cur_chunk_infos )[0] cur_chunk_max_mem = max( cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1] @@ -1511,11 +1530,13 @@ def _chunk_size_binary_search(self, l, r, chunk_info, chunk_infos): else: gap = 1 while r >= l + gap: - mid = int(l + (r - l) / 2) + mid = int((l + r) / 2 + 0.5) chunk_info["chunk_size"] = mid - cur_chunk_infos = chunk_infos + [chunk_info] + cur_chunk_info = chunk_info.copy() + cur_node_list, cur_chunk_info = self.index_tracer.tmp_reorder(self.index_tracer.node_list, cur_chunk_info) + cur_chunk_infos = chunk_infos + [cur_chunk_info] cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( - self.index_tracer.node_list, cur_chunk_infos + cur_node_list, cur_chunk_infos )[0] cur_chunk_max_mem = max( cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1] @@ -1529,7 +1550,7 @@ def _chunk_size_binary_search(self, l, r, chunk_info, chunk_infos): def _get_compute_node_num(self, start, end): count = 0 for i in self.index_tracer.node_list[start : end + 1]: - if _is_non_compute_node(i): + if not _is_non_compute_node(i): count += 1 return count @@ -1547,7 +1568,7 @@ def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos): max_region_range = 0 best_region = None if best_region is not None: - best_region["chunk_size"] = 2 + best_region["chunk_size"] = 1 return best_region def _is_legal_region(self, cur_chunk_info, chunk_infos): From 80efd70c725b00c236b80b68393c0d13ec457b0b Mon Sep 17 00:00:00 2001 From: oahzxl Date: Sat, 31 Dec 2022 13:44:46 +0800 Subject: [PATCH 076/209] improve reorder efficeincy --- chunk_codegen.py | 33 
+++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index e20d151da1fb..7c334c617c7b 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -1486,6 +1486,8 @@ def _select_fit_memory_chunk_region( "chunk_len": self._get_compute_node_num( region["region"][0], region["region"][1] ), + "reorder_chunk_info": cur_region, + "reorder_node_list": cur_node_list } ) # no region found @@ -1495,48 +1497,47 @@ def _select_fit_memory_chunk_region( # select the min chunk len chunk_len = [i["chunk_len"] for i in regions_dict] best_region_idx = chunk_len.index(min(chunk_len)) - best_region = regions_dict[best_region_idx]["chunk_info"] + best_region = regions_dict[best_region_idx] # get max chunk size best_region = self._get_fit_chunk_size(best_region, chunk_infos) return best_region - def _get_fit_chunk_size(self, chunk_info, chunk_infos): + def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos): chunk_size = 1 - chunk_info["chunk_size"] = chunk_size + reorder_chunk_info = chunk_region_dict['reorder_chunk_info'] + reorder_chunk_info["chunk_size"] = chunk_size cur_chunk_max_mem = 0 # search a region while cur_chunk_max_mem < self.max_memory: chunk_size *= 2 - chunk_info["chunk_size"] = chunk_size - cur_chunk_info = chunk_info.copy() - cur_node_list, cur_chunk_info = self.index_tracer.tmp_reorder(self.index_tracer.node_list, cur_chunk_info) - cur_chunk_infos = chunk_infos + [cur_chunk_info] + reorder_chunk_info["chunk_size"] = chunk_size + cur_chunk_infos = chunk_infos + [reorder_chunk_info] cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( - cur_node_list, cur_chunk_infos + chunk_region_dict['reorder_node_list'], cur_chunk_infos )[0] cur_chunk_max_mem = max( - cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1] + cur_mem_peak[reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1] + 1] ) # search exact size + chunk_info = chunk_region_dict["chunk_info"] 
chunk_info["chunk_size"] = self._chunk_size_binary_search( - chunk_size // 2, chunk_size, chunk_info, chunk_infos + chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos ) return chunk_info - def _chunk_size_binary_search(self, l, r, chunk_info, chunk_infos): + def _chunk_size_binary_search(self, l, r, chunk_region_dict, chunk_infos): if l >= 16: gap = 4 else: gap = 1 + chunk_info = chunk_region_dict['reorder_chunk_info'] while r >= l + gap: mid = int((l + r) / 2 + 0.5) chunk_info["chunk_size"] = mid - cur_chunk_info = chunk_info.copy() - cur_node_list, cur_chunk_info = self.index_tracer.tmp_reorder(self.index_tracer.node_list, cur_chunk_info) - cur_chunk_infos = chunk_infos + [cur_chunk_info] + cur_chunk_infos = chunk_infos + [chunk_info] cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( - cur_node_list, cur_chunk_infos + chunk_region_dict['reorder_node_list'], cur_chunk_infos )[0] cur_chunk_max_mem = max( cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1] @@ -1904,7 +1905,7 @@ def _find_idx_by_name(name, nodes_list): def _replace_name(context, name_from, name_to): - patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ",")] + patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")")] for p in patterns: source = p[0] + name_from + p[1] target = p[0] + name_to + p[1] From 5f24f4fd55956904d024d8835029ffcd0cc203a5 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Sat, 31 Dec 2022 16:29:43 +0800 Subject: [PATCH 077/209] support ones_like, add prompt if fit mode search fail --- chunk_codegen.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 7c334c617c7b..6f8ff2b23ff0 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -1406,9 +1406,9 @@ def estimate_chunk_inference_mem( # self._print_mem_log(act_memory_peak_log, node_list, "peak") # self._print_mem_log(act_memory_after_node_log, node_list, "after") 
self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak") - self._print_compute_op_mem_log( - act_memory_after_node_log, node_list, "after" - ) + # self._print_compute_op_mem_log( + # act_memory_after_node_log, node_list, "after" + # ) # param_memory = parameter_size(gm) # all_memory = act_memory + param_memory @@ -1465,6 +1465,9 @@ def _select_fit_memory_chunk_region( if i in possible_chunk_regions: possible_chunk_regions.remove(i) + if len(possible_chunk_regions) == 0: + return None + # get mem for chunk region regions_dict = [] for region in possible_chunk_regions: @@ -1492,7 +1495,7 @@ def _select_fit_memory_chunk_region( ) # no region found if len(regions_dict) == 0: - return None + raise RuntimeError("Search failed. Try a larger memory threshold.") # select the min chunk len chunk_len = [i["chunk_len"] for i in regions_dict] @@ -1995,6 +1998,14 @@ def emit_code_with_chunk( body[-1] = _replace_name( body[-1], input_node.name, input_node.name + chunk_slice ) + # ones like + if "ones_like" in node.name: + chunk_slice = _gen_chunk_slice_dim( + chunk_search[region_idx]["node_chunk_dim"][chunk_region_search.index_tracer.node_list[node_idx]]["chunk_dim"], "chunk_idx", _get_node_shape(node) + ) + body[-1] = _replace_name( + body[-1], node.args[0].name, node.args[0].name + chunk_slice + ) body[-1] = _replace_reshape_size( body[-1], node.name, chunk_search[region_idx]["reshape_size"] ) From 7fd3b45af21345cff9334682e277d7669c730814 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 2 Jan 2023 00:04:47 +0800 Subject: [PATCH 078/209] fix a bug in ones like, dont gen chunk if dim size is 1 --- autochunk_benchmark.py | 4 ++-- chunk_codegen.py | 41 +++++++++++++++++++++++++++-------------- 2 files changed, 29 insertions(+), 16 deletions(-) diff --git a/autochunk_benchmark.py b/autochunk_benchmark.py index 679016438c59..3b48d7e461fe 100644 --- a/autochunk_benchmark.py +++ b/autochunk_benchmark.py @@ -16,9 +16,9 @@ def _benchmark_evoformer(model: torch.nn.Module, 
node, pair, title, chunk_size=N torch.cuda.reset_peak_memory_stats() now_mem = torch.cuda.memory_allocated() / 1024**2 - loop = 16 + loop = 3 with torch.no_grad(): - for _ in range(loop // 4): + for _ in range(loop // 2 + 1): if chunk_size: model(node, pair, chunk_size) else: diff --git a/chunk_codegen.py b/chunk_codegen.py index 6f8ff2b23ff0..6f21f26f37e1 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -144,9 +144,7 @@ def _add_source(self, node_from, node_from_dim, node_to, node_to_dim, init=False node_to_trace_source[node_to_dim][node_from_idx] = [node_from_dim] else: if node_from_dim not in node_to_trace_source[node_to_dim][node_from_idx]: - node_to_trace_source[node_to_dim][node_from_idx].append( - node_from_dim - ) + node_to_trace_source[node_to_dim][node_from_idx].append(node_from_dim) # update inputs source for node_idx, node_dim in node_from_trace_source[node_from_dim].items(): if node_idx not in node_to_trace_source[node_to_dim]: @@ -1097,17 +1095,17 @@ def reorder_node_list(self, node_list): for old_idx, new_idx in self.all_reorder_map.items(): new_node_list[new_idx] = node_list[old_idx] return new_node_list - + def tmp_reorder(self, node_list, chunk_info): if len(chunk_info["args"]["prepose_nodes"]) == 0: return node_list, chunk_info reorder_map = self._get_reorder_map(chunk_info) - + # new tmp node list new_node_list = [None for _ in range(len(node_list))] for old_idx, new_idx in reorder_map.items(): new_node_list[new_idx] = node_list[old_idx] - + chunk_info = self._reorder_chunk_info(chunk_info, reorder_map) return new_node_list, chunk_info @@ -1472,7 +1470,9 @@ def _select_fit_memory_chunk_region( regions_dict = [] for region in possible_chunk_regions: cur_region = region.copy() - cur_node_list, cur_region = self.index_tracer.tmp_reorder(self.index_tracer.node_list, cur_region) + cur_node_list, cur_region = self.index_tracer.tmp_reorder( + self.index_tracer.node_list, cur_region + ) cur_chunk_infos = chunk_infos + [cur_region] cur_mem_peak = 
self.memory_estimator.estimate_chunk_inference_mem( cur_node_list, cur_chunk_infos @@ -1490,7 +1490,7 @@ def _select_fit_memory_chunk_region( region["region"][0], region["region"][1] ), "reorder_chunk_info": cur_region, - "reorder_node_list": cur_node_list + "reorder_node_list": cur_node_list, } ) # no region found @@ -1508,7 +1508,7 @@ def _select_fit_memory_chunk_region( def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos): chunk_size = 1 - reorder_chunk_info = chunk_region_dict['reorder_chunk_info'] + reorder_chunk_info = chunk_region_dict["reorder_chunk_info"] reorder_chunk_info["chunk_size"] = chunk_size cur_chunk_max_mem = 0 # search a region @@ -1517,10 +1517,13 @@ def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos): reorder_chunk_info["chunk_size"] = chunk_size cur_chunk_infos = chunk_infos + [reorder_chunk_info] cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( - chunk_region_dict['reorder_node_list'], cur_chunk_infos + chunk_region_dict["reorder_node_list"], cur_chunk_infos )[0] cur_chunk_max_mem = max( - cur_mem_peak[reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1] + 1] + cur_mem_peak[ + reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1] + + 1 + ] ) # search exact size chunk_info = chunk_region_dict["chunk_info"] @@ -1534,13 +1537,13 @@ def _chunk_size_binary_search(self, l, r, chunk_region_dict, chunk_infos): gap = 4 else: gap = 1 - chunk_info = chunk_region_dict['reorder_chunk_info'] + chunk_info = chunk_region_dict["reorder_chunk_info"] while r >= l + gap: mid = int((l + r) / 2 + 0.5) chunk_info["chunk_size"] = mid cur_chunk_infos = chunk_infos + [chunk_info] cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( - chunk_region_dict['reorder_node_list'], cur_chunk_infos + chunk_region_dict["reorder_node_list"], cur_chunk_infos )[0] cur_chunk_max_mem = max( cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1] @@ -2000,8 +2003,18 @@ def emit_code_with_chunk( ) 
# ones like if "ones_like" in node.name: + chunk_dim = chunk_search[region_idx]["node_chunk_dim"][ + chunk_region_search.index_tracer.node_list[node_idx] + ]["chunk_dim"] + if ( + _get_node_shape( + chunk_region_search.index_tracer.node_list[node_idx] + )[chunk_dim] + == 1 + ): + continue chunk_slice = _gen_chunk_slice_dim( - chunk_search[region_idx]["node_chunk_dim"][chunk_region_search.index_tracer.node_list[node_idx]]["chunk_dim"], "chunk_idx", _get_node_shape(node) + chunk_dim, "chunk_idx", _get_node_shape(node) ) body[-1] = _replace_name( body[-1], node.args[0].name, node.args[0].name + chunk_slice From 9c5e028a62b003136d2402b99b728eaefcc528cd Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 2 Jan 2023 00:27:11 +0800 Subject: [PATCH 079/209] fix bug again --- chunk_codegen.py | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 6f21f26f37e1..21ecc343a959 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -2003,22 +2003,25 @@ def emit_code_with_chunk( ) # ones like if "ones_like" in node.name: - chunk_dim = chunk_search[region_idx]["node_chunk_dim"][ - chunk_region_search.index_tracer.node_list[node_idx] - ]["chunk_dim"] - if ( - _get_node_shape( - chunk_region_search.index_tracer.node_list[node_idx] - )[chunk_dim] - == 1 - ): - continue - chunk_slice = _gen_chunk_slice_dim( - chunk_dim, "chunk_idx", _get_node_shape(node) - ) - body[-1] = _replace_name( - body[-1], node.args[0].name, node.args[0].name + chunk_slice - ) + meta_node = chunk_region_search.index_tracer.node_list[node_idx] + chunk_dim = chunk_search[region_idx]["node_chunk_dim"][meta_node][ + "chunk_dim" + ] + if _get_node_shape(meta_node)[chunk_dim] != 1: + source_node = meta_node.args[0].args[0] + if ( + source_node not in chunk_search[region_idx]["node_chunk_dim"] + or chunk_search[region_idx]["node_chunk_dim"][source_node][ + "chunk_dim" + ] + is None + ): + chunk_slice = _gen_chunk_slice_dim( + 
chunk_dim, "chunk_idx", _get_node_shape(node) + ) + body[-1] = _replace_name( + body[-1], node.args[0].name, node.args[0].name + chunk_slice + ) body[-1] = _replace_reshape_size( body[-1], node.name, chunk_search[region_idx]["reshape_size"] ) From 55cb713f36e8080313225577dde97e4d35e18108 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 5 Jan 2023 11:29:22 +0800 Subject: [PATCH 080/209] update min memory stratege, reduce mem usage by 30% --- chunk_codegen.py | 65 +++++++++++++++++++++++++++++++++++++----------- 1 file changed, 51 insertions(+), 14 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 21ecc343a959..41fcb5a3c2f4 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -1433,7 +1433,11 @@ def _select_best_chunk_region( ): if self.stratge == "min_memory": best_region = self._select_min_memory_chunk_region( - possible_chunk_regions, chunk_infos + possible_chunk_regions, + chunk_infos, + peak_node, + max_chunk_region, + mem_peak, ) elif self.stratge == "fit_memory": best_region = self._select_fit_memory_chunk_region( @@ -1561,19 +1565,52 @@ def _get_compute_node_num(self, start, end): count += 1 return count - def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos): - max_region_range = 0 - best_region = None - while len(possible_chunk_regions) > 0: - for i in possible_chunk_regions: - if i["region"][1] - i["region"][0] > max_region_range: - best_region = i - max_region_range = i["region"][1] - i["region"][0] - if self._is_legal_region(best_region, chunk_infos): - break - possible_chunk_regions.remove(i) - max_region_range = 0 - best_region = None + def _select_min_memory_chunk_region( + self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak + ): + # remove illegal regions + illegal_regions = [] + for i in possible_chunk_regions: + if not self._is_legal_region(i, chunk_infos): + illegal_regions.append(i) + for i in illegal_regions: + if i in possible_chunk_regions: + 
possible_chunk_regions.remove(i) + + if len(possible_chunk_regions) == 0: + return None + + # get mem for chunk region + regions_dict = [] + for region in possible_chunk_regions: + cur_region = region.copy() + cur_node_list, cur_region = self.index_tracer.tmp_reorder( + self.index_tracer.node_list, cur_region + ) + cur_chunk_infos = chunk_infos + [cur_region] + cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( + cur_node_list, cur_chunk_infos + )[0] + cur_chunk_region_peak = cur_mem_peak[ + max_chunk_region[0] : max_chunk_region[1] + 1 + ] + cur_chunk_region_max_peak = max(cur_chunk_region_peak) + regions_dict.append( + { + "chunk_info": region, + "chunk_max_mem": cur_chunk_region_max_peak, + "chunk_len": self._get_compute_node_num( + region["region"][0], region["region"][1] + ), + "reorder_chunk_info": cur_region, + "reorder_node_list": cur_node_list, + } + ) + + # select the min mem + chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict] + best_region_idx = chunk_max_mem.index(min(chunk_max_mem)) + best_region = regions_dict[best_region_idx]["chunk_info"] if best_region is not None: best_region["chunk_size"] = 1 return best_region From 71e72c48907195096ef02be73e1c5b0feea2653d Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 5 Jan 2023 17:54:25 +0800 Subject: [PATCH 081/209] last version of benchmark --- autochunk_benchmark.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/autochunk_benchmark.py b/autochunk_benchmark.py index 3b48d7e461fe..c938485efc05 100644 --- a/autochunk_benchmark.py +++ b/autochunk_benchmark.py @@ -93,22 +93,24 @@ def _build_openfold(): def benchmark_evoformer(): # init data and model - msa_len = 300 - pair_len = 800 + msa_len = 256 + pair_len = 2048 node = torch.randn(1, msa_len, pair_len, 256).cuda() pair = torch.randn(1, pair_len, pair_len, 128).cuda() model = evoformer_base().cuda() # build autochunk model - max_memory = 3000 # MB + max_memory = 10000 # MB fit memory mode + # max_memory 
= None # min memory mode autochunk = _build_autochunk(evoformer_base().cuda(), max_memory, node, pair) # build openfold + chunk_size = 64 openfold = _build_openfold() # benchmark _benchmark_evoformer(model, node, pair, "base") - _benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=4) + _benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size) _benchmark_evoformer(autochunk, node, pair, "autochunk") From 27ab5240965fc9cc0ec74ff48356abcbf098bd74 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 6 Jan 2023 11:07:57 +0800 Subject: [PATCH 082/209] refactor structure --- .../chunk_codegen.py | 41 ++++++++----------- .../evoformer}/evoformer.py | 0 .../evoformer}/initializer.py | 0 {evoformer => autochunk/evoformer}/kernel.py | 0 {evoformer => autochunk/evoformer}/msa.py | 0 {evoformer => autochunk/evoformer}/ops.py | 0 .../evoformer}/triangle.py | 0 .../openfold}/checkpointing.py | 0 {openfold => autochunk/openfold}/dropout.py | 0 {openfold => autochunk/openfold}/evoformer.py | 0 {openfold => autochunk/openfold}/msa.py | 0 .../openfold}/outer_product_mean.py | 0 .../openfold}/pair_transition.py | 0 .../openfold}/primitives.py | 0 .../openfold}/tensor_utils.py | 0 .../openfold}/triangular_attention.py | 0 .../triangular_multiplicative_update.py | 0 autochunk_benchmark.py | 18 ++++---- chunk_codegen_run.py => autochunk_test.py | 4 +- 19 files changed, 29 insertions(+), 34 deletions(-) rename chunk_codegen.py => autochunk/chunk_codegen.py (98%) rename {evoformer => autochunk/evoformer}/evoformer.py (100%) rename {evoformer => autochunk/evoformer}/initializer.py (100%) rename {evoformer => autochunk/evoformer}/kernel.py (100%) rename {evoformer => autochunk/evoformer}/msa.py (100%) rename {evoformer => autochunk/evoformer}/ops.py (100%) rename {evoformer => autochunk/evoformer}/triangle.py (100%) rename {openfold => autochunk/openfold}/checkpointing.py (100%) rename {openfold => autochunk/openfold}/dropout.py (100%) rename {openfold => 
autochunk/openfold}/evoformer.py (100%) rename {openfold => autochunk/openfold}/msa.py (100%) rename {openfold => autochunk/openfold}/outer_product_mean.py (100%) rename {openfold => autochunk/openfold}/pair_transition.py (100%) rename {openfold => autochunk/openfold}/primitives.py (100%) rename {openfold => autochunk/openfold}/tensor_utils.py (100%) rename {openfold => autochunk/openfold}/triangular_attention.py (100%) rename {openfold => autochunk/openfold}/triangular_multiplicative_update.py (100%) rename chunk_codegen_run.py => autochunk_test.py (97%) diff --git a/chunk_codegen.py b/autochunk/chunk_codegen.py similarity index 98% rename from chunk_codegen.py rename to autochunk/chunk_codegen.py index 41fcb5a3c2f4..7a5d06689247 100644 --- a/chunk_codegen.py +++ b/autochunk/chunk_codegen.py @@ -1967,13 +1967,11 @@ def _replace_reshape_size(context, node_name, reshape_size_dict): def emit_code_with_chunk( body, - ckpt_func, nodes, emit_node_func, delete_unused_value_func, - meta_nodes, - meta_graph, - max_memory=None, + chunk_region_search, + chunk_infos ): """Emit code with nested activation checkpoint When we detect some of the node.activation_checkpoint is a List, we will use @@ -1988,23 +1986,19 @@ def emit_code_with_chunk( """ node_list = list(nodes) - # find the chunk regions - chunk_region_search = ChunkRegionSearch(meta_graph, max_memory) - chunk_search = chunk_region_search.search_region() - - chunk_regions = [i["region"] for i in chunk_search] + chunk_regions = [i["region"] for i in chunk_infos] chunk_starts = [i[0] for i in chunk_regions] chunk_ends = [i[1] for i in chunk_regions] - chunk_inputs = [i["inputs"] for i in chunk_search] - chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_search] - chunk_inputs_dim = [i["inputs_dim"] for i in chunk_search] + chunk_inputs = [i["inputs"] for i in chunk_infos] + chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] + chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] 
chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [ j.name for i in chunk_inputs_non_chunk for j in i ] - chunk_outputs = [i["outputs"][0] for i in chunk_search] - chunk_outputs_dim = [i["outputs_dim"] for i in chunk_search] + chunk_outputs = [i["outputs"][0] for i in chunk_infos] + chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos] node_list = chunk_region_search.index_tracer.reorder_node_list(node_list) node_idx = 0 @@ -2022,7 +2016,7 @@ def emit_code_with_chunk( chunk_inputs[region_idx], chunk_outputs[region_idx], chunk_outputs_dim[region_idx], - chunk_search[region_idx]["chunk_size"], + chunk_infos[region_idx]["chunk_size"], ) ) @@ -2041,14 +2035,14 @@ def emit_code_with_chunk( # ones like if "ones_like" in node.name: meta_node = chunk_region_search.index_tracer.node_list[node_idx] - chunk_dim = chunk_search[region_idx]["node_chunk_dim"][meta_node][ + chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node][ "chunk_dim" ] if _get_node_shape(meta_node)[chunk_dim] != 1: source_node = meta_node.args[0].args[0] if ( - source_node not in chunk_search[region_idx]["node_chunk_dim"] - or chunk_search[region_idx]["node_chunk_dim"][source_node][ + source_node not in chunk_infos[region_idx]["node_chunk_dim"] + or chunk_infos[region_idx]["node_chunk_dim"][source_node][ "chunk_dim" ] is None @@ -2060,7 +2054,7 @@ def emit_code_with_chunk( body[-1], node.args[0].name, node.args[0].name + chunk_slice ) body[-1] = _replace_reshape_size( - body[-1], node.name, chunk_search[region_idx]["reshape_size"] + body[-1], node.name, chunk_infos[region_idx]["reshape_size"] ) body[-1] = " " + body[-1] delete_unused_value_func(node, body, chunk_inputs_names) @@ -2092,6 +2086,9 @@ def __init__(self, meta_graph, max_memory=None): self.meta_graph = meta_graph self.max_memory = max_memory self.meta_node = list(meta_graph.graph.nodes) + # find the chunk regions + self.chunk_region_search = ChunkRegionSearch(meta_graph, max_memory) + self.chunk_infos = 
self.chunk_region_search.search_region() def _gen_python_code( self, nodes, root_module: str, namespace: _Namespace @@ -2323,13 +2320,11 @@ def emit_node(node: Node, body): # will use nested type of activation checkpoint codegen emit_code_with_chunk( body, - ckpt_func, nodes, emit_node, delete_unused_values, - self.meta_node, - self.meta_graph, - self.max_memory, + self.chunk_region_search, + self.chunk_infos ) if len(body) == 0: diff --git a/evoformer/evoformer.py b/autochunk/evoformer/evoformer.py similarity index 100% rename from evoformer/evoformer.py rename to autochunk/evoformer/evoformer.py diff --git a/evoformer/initializer.py b/autochunk/evoformer/initializer.py similarity index 100% rename from evoformer/initializer.py rename to autochunk/evoformer/initializer.py diff --git a/evoformer/kernel.py b/autochunk/evoformer/kernel.py similarity index 100% rename from evoformer/kernel.py rename to autochunk/evoformer/kernel.py diff --git a/evoformer/msa.py b/autochunk/evoformer/msa.py similarity index 100% rename from evoformer/msa.py rename to autochunk/evoformer/msa.py diff --git a/evoformer/ops.py b/autochunk/evoformer/ops.py similarity index 100% rename from evoformer/ops.py rename to autochunk/evoformer/ops.py diff --git a/evoformer/triangle.py b/autochunk/evoformer/triangle.py similarity index 100% rename from evoformer/triangle.py rename to autochunk/evoformer/triangle.py diff --git a/openfold/checkpointing.py b/autochunk/openfold/checkpointing.py similarity index 100% rename from openfold/checkpointing.py rename to autochunk/openfold/checkpointing.py diff --git a/openfold/dropout.py b/autochunk/openfold/dropout.py similarity index 100% rename from openfold/dropout.py rename to autochunk/openfold/dropout.py diff --git a/openfold/evoformer.py b/autochunk/openfold/evoformer.py similarity index 100% rename from openfold/evoformer.py rename to autochunk/openfold/evoformer.py diff --git a/openfold/msa.py b/autochunk/openfold/msa.py similarity index 100% rename 
from openfold/msa.py rename to autochunk/openfold/msa.py diff --git a/openfold/outer_product_mean.py b/autochunk/openfold/outer_product_mean.py similarity index 100% rename from openfold/outer_product_mean.py rename to autochunk/openfold/outer_product_mean.py diff --git a/openfold/pair_transition.py b/autochunk/openfold/pair_transition.py similarity index 100% rename from openfold/pair_transition.py rename to autochunk/openfold/pair_transition.py diff --git a/openfold/primitives.py b/autochunk/openfold/primitives.py similarity index 100% rename from openfold/primitives.py rename to autochunk/openfold/primitives.py diff --git a/openfold/tensor_utils.py b/autochunk/openfold/tensor_utils.py similarity index 100% rename from openfold/tensor_utils.py rename to autochunk/openfold/tensor_utils.py diff --git a/openfold/triangular_attention.py b/autochunk/openfold/triangular_attention.py similarity index 100% rename from openfold/triangular_attention.py rename to autochunk/openfold/triangular_attention.py diff --git a/openfold/triangular_multiplicative_update.py b/autochunk/openfold/triangular_multiplicative_update.py similarity index 100% rename from openfold/triangular_multiplicative_update.py rename to autochunk/openfold/triangular_multiplicative_update.py diff --git a/autochunk_benchmark.py b/autochunk_benchmark.py index c938485efc05..c34b5217e5d4 100644 --- a/autochunk_benchmark.py +++ b/autochunk_benchmark.py @@ -3,13 +3,13 @@ import torch import torch.fx -from chunk_codegen import ChunkCodeGen +from autochunk.chunk_codegen import ChunkCodeGen from colossalai.fx import ColoTracer from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.profiler import MetaTensor -from evoformer.evoformer import evoformer_base -from openfold.evoformer import EvoformerBlock +from autochunk.evoformer.evoformer import evoformer_base +from autochunk.openfold.evoformer import EvoformerBlock def 
_benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None): @@ -94,23 +94,23 @@ def _build_openfold(): def benchmark_evoformer(): # init data and model msa_len = 256 - pair_len = 2048 + pair_len = 1024 node = torch.randn(1, msa_len, pair_len, 256).cuda() pair = torch.randn(1, pair_len, pair_len, 128).cuda() model = evoformer_base().cuda() # build autochunk model - max_memory = 10000 # MB fit memory mode - # max_memory = None # min memory mode + # max_memory = 10000 # MB fit memory mode + max_memory = None # min memory mode autochunk = _build_autochunk(evoformer_base().cuda(), max_memory, node, pair) # build openfold chunk_size = 64 - openfold = _build_openfold() + # openfold = _build_openfold() # benchmark - _benchmark_evoformer(model, node, pair, "base") - _benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size) + # _benchmark_evoformer(model, node, pair, "base") + # _benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size) _benchmark_evoformer(autochunk, node, pair, "autochunk") diff --git a/chunk_codegen_run.py b/autochunk_test.py similarity index 97% rename from chunk_codegen_run.py rename to autochunk_test.py index 3a3b3c599e3e..63f393531d5c 100644 --- a/chunk_codegen_run.py +++ b/autochunk_test.py @@ -12,8 +12,8 @@ from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata from colossalai.fx.profiler import MetaTensor -from evoformer.evoformer import evoformer_base -from chunk_codegen import ChunkCodeGen +from autochunk.evoformer.evoformer import evoformer_base +from autochunk.chunk_codegen import ChunkCodeGen with_codegen = True From efb1c64c30cf2ee35dad03bfd3829f014d204a8d Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 6 Jan 2023 11:39:26 +0800 Subject: [PATCH 083/209] restruct dir --- .../autochunk}/chunk_codegen.py | 0 .../test_autochunk/autochunk_benchmark.py | 14 +++++++------- .../test_autochunk/autochunk_test.py | 4 
++-- .../test_autochunk}/evoformer/evoformer.py | 0 .../test_autochunk}/evoformer/initializer.py | 0 .../test_autochunk}/evoformer/kernel.py | 0 .../test_autochunk}/evoformer/msa.py | 0 .../test_autochunk}/evoformer/ops.py | 0 .../test_autochunk}/evoformer/triangle.py | 0 .../test_autochunk}/openfold/checkpointing.py | 0 .../test_autochunk}/openfold/dropout.py | 0 .../test_autochunk}/openfold/evoformer.py | 18 +++++++++--------- .../test_autochunk}/openfold/msa.py | 6 +++--- .../openfold/outer_product_mean.py | 4 ++-- .../openfold/pair_transition.py | 4 ++-- .../test_autochunk}/openfold/primitives.py | 4 ++-- .../test_autochunk}/openfold/tensor_utils.py | 0 .../openfold/triangular_attention.py | 4 ++-- .../triangular_multiplicative_update.py | 4 ++-- 19 files changed, 31 insertions(+), 31 deletions(-) rename {autochunk => colossalai/autochunk}/chunk_codegen.py (100%) rename autochunk_benchmark.py => tests/test_autochunk/autochunk_benchmark.py (89%) rename autochunk_test.py => tests/test_autochunk/autochunk_test.py (96%) rename {autochunk => tests/test_autochunk}/evoformer/evoformer.py (100%) rename {autochunk => tests/test_autochunk}/evoformer/initializer.py (100%) rename {autochunk => tests/test_autochunk}/evoformer/kernel.py (100%) rename {autochunk => tests/test_autochunk}/evoformer/msa.py (100%) rename {autochunk => tests/test_autochunk}/evoformer/ops.py (100%) rename {autochunk => tests/test_autochunk}/evoformer/triangle.py (100%) rename {autochunk => tests/test_autochunk}/openfold/checkpointing.py (100%) rename {autochunk => tests/test_autochunk}/openfold/dropout.py (100%) rename {autochunk => tests/test_autochunk}/openfold/evoformer.py (96%) rename {autochunk => tests/test_autochunk}/openfold/msa.py (98%) rename {autochunk => tests/test_autochunk}/openfold/outer_product_mean.py (97%) rename {autochunk => tests/test_autochunk}/openfold/pair_transition.py (96%) rename {autochunk => tests/test_autochunk}/openfold/primitives.py (99%) rename {autochunk => 
tests/test_autochunk}/openfold/tensor_utils.py (100%) rename {autochunk => tests/test_autochunk}/openfold/triangular_attention.py (97%) rename {autochunk => tests/test_autochunk}/openfold/triangular_multiplicative_update.py (97%) diff --git a/autochunk/chunk_codegen.py b/colossalai/autochunk/chunk_codegen.py similarity index 100% rename from autochunk/chunk_codegen.py rename to colossalai/autochunk/chunk_codegen.py diff --git a/autochunk_benchmark.py b/tests/test_autochunk/autochunk_benchmark.py similarity index 89% rename from autochunk_benchmark.py rename to tests/test_autochunk/autochunk_benchmark.py index c34b5217e5d4..8df6d9ff4564 100644 --- a/autochunk_benchmark.py +++ b/tests/test_autochunk/autochunk_benchmark.py @@ -3,13 +3,13 @@ import torch import torch.fx -from autochunk.chunk_codegen import ChunkCodeGen +from colossalai.autochunk.chunk_codegen import ChunkCodeGen from colossalai.fx import ColoTracer from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.profiler import MetaTensor -from autochunk.evoformer.evoformer import evoformer_base -from autochunk.openfold.evoformer import EvoformerBlock +from tests.test_autochunk.evoformer.evoformer import evoformer_base +from tests.test_autochunk.openfold.evoformer import EvoformerBlock def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None): @@ -94,7 +94,7 @@ def _build_openfold(): def benchmark_evoformer(): # init data and model msa_len = 256 - pair_len = 1024 + pair_len = 256 node = torch.randn(1, msa_len, pair_len, 256).cuda() pair = torch.randn(1, pair_len, pair_len, 128).cuda() model = evoformer_base().cuda() @@ -106,11 +106,11 @@ def benchmark_evoformer(): # build openfold chunk_size = 64 - # openfold = _build_openfold() + openfold = _build_openfold() # benchmark - # _benchmark_evoformer(model, node, pair, "base") - # _benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size) + 
_benchmark_evoformer(model, node, pair, "base") + _benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size) _benchmark_evoformer(autochunk, node, pair, "autochunk") diff --git a/autochunk_test.py b/tests/test_autochunk/autochunk_test.py similarity index 96% rename from autochunk_test.py rename to tests/test_autochunk/autochunk_test.py index 63f393531d5c..5e9aaca15f9f 100644 --- a/autochunk_test.py +++ b/tests/test_autochunk/autochunk_test.py @@ -12,8 +12,8 @@ from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata from colossalai.fx.profiler import MetaTensor -from autochunk.evoformer.evoformer import evoformer_base -from autochunk.chunk_codegen import ChunkCodeGen +from tests.test_autochunk.evoformer.evoformer import evoformer_base +from ...colossalai.autochunk.chunk_codegen import ChunkCodeGen with_codegen = True diff --git a/autochunk/evoformer/evoformer.py b/tests/test_autochunk/evoformer/evoformer.py similarity index 100% rename from autochunk/evoformer/evoformer.py rename to tests/test_autochunk/evoformer/evoformer.py diff --git a/autochunk/evoformer/initializer.py b/tests/test_autochunk/evoformer/initializer.py similarity index 100% rename from autochunk/evoformer/initializer.py rename to tests/test_autochunk/evoformer/initializer.py diff --git a/autochunk/evoformer/kernel.py b/tests/test_autochunk/evoformer/kernel.py similarity index 100% rename from autochunk/evoformer/kernel.py rename to tests/test_autochunk/evoformer/kernel.py diff --git a/autochunk/evoformer/msa.py b/tests/test_autochunk/evoformer/msa.py similarity index 100% rename from autochunk/evoformer/msa.py rename to tests/test_autochunk/evoformer/msa.py diff --git a/autochunk/evoformer/ops.py b/tests/test_autochunk/evoformer/ops.py similarity index 100% rename from autochunk/evoformer/ops.py rename to tests/test_autochunk/evoformer/ops.py diff --git a/autochunk/evoformer/triangle.py 
b/tests/test_autochunk/evoformer/triangle.py similarity index 100% rename from autochunk/evoformer/triangle.py rename to tests/test_autochunk/evoformer/triangle.py diff --git a/autochunk/openfold/checkpointing.py b/tests/test_autochunk/openfold/checkpointing.py similarity index 100% rename from autochunk/openfold/checkpointing.py rename to tests/test_autochunk/openfold/checkpointing.py diff --git a/autochunk/openfold/dropout.py b/tests/test_autochunk/openfold/dropout.py similarity index 100% rename from autochunk/openfold/dropout.py rename to tests/test_autochunk/openfold/dropout.py diff --git a/autochunk/openfold/evoformer.py b/tests/test_autochunk/openfold/evoformer.py similarity index 96% rename from autochunk/openfold/evoformer.py rename to tests/test_autochunk/openfold/evoformer.py index ffd4c982987a..b53ec1aa51e5 100644 --- a/autochunk/openfold/evoformer.py +++ b/tests/test_autochunk/openfold/evoformer.py @@ -19,25 +19,25 @@ from typing import Tuple, Optional from functools import partial -from openfold.primitives import Linear, LayerNorm -from openfold.dropout import DropoutRowwise, DropoutColumnwise -from openfold.msa import ( +from .primitives import Linear, LayerNorm +from .dropout import DropoutRowwise, DropoutColumnwise +from .msa import ( MSARowAttentionWithPairBias, MSAColumnAttention, MSAColumnGlobalAttention, ) -from openfold.outer_product_mean import OuterProductMean -from openfold.pair_transition import PairTransition -from openfold.triangular_attention import ( +from .outer_product_mean import OuterProductMean +from .pair_transition import PairTransition +from .triangular_attention import ( TriangleAttentionStartingNode, TriangleAttentionEndingNode, ) -from openfold.triangular_multiplicative_update import ( +from .triangular_multiplicative_update import ( TriangleMultiplicationOutgoing, TriangleMultiplicationIncoming, ) -from openfold.checkpointing import checkpoint_blocks, get_checkpoint_fn -from openfold.tensor_utils import chunk_layer +from 
.checkpointing import checkpoint_blocks, get_checkpoint_fn +from .tensor_utils import chunk_layer class MSATransition(nn.Module): diff --git a/autochunk/openfold/msa.py b/tests/test_autochunk/openfold/msa.py similarity index 98% rename from autochunk/openfold/msa.py rename to tests/test_autochunk/openfold/msa.py index 00b822e7f390..7c137286feab 100644 --- a/autochunk/openfold/msa.py +++ b/tests/test_autochunk/openfold/msa.py @@ -18,15 +18,15 @@ import torch.nn as nn from typing import Optional, List, Tuple -from openfold.primitives import ( +from .primitives import ( Linear, LayerNorm, Attention, GlobalAttention, _attention_chunked_trainable, ) -from openfold.checkpointing import get_checkpoint_fn -from openfold.tensor_utils import ( +from .checkpointing import get_checkpoint_fn +from .tensor_utils import ( chunk_layer, permute_final_dims, flatten_final_dims, diff --git a/autochunk/openfold/outer_product_mean.py b/tests/test_autochunk/openfold/outer_product_mean.py similarity index 97% rename from autochunk/openfold/outer_product_mean.py rename to tests/test_autochunk/openfold/outer_product_mean.py index 43d853833c66..daadf1c272cf 100644 --- a/autochunk/openfold/outer_product_mean.py +++ b/tests/test_autochunk/openfold/outer_product_mean.py @@ -19,8 +19,8 @@ import torch import torch.nn as nn -from openfold.primitives import Linear -from openfold.tensor_utils import chunk_layer +from .primitives import Linear +from .tensor_utils import chunk_layer class OuterProductMean(nn.Module): diff --git a/autochunk/openfold/pair_transition.py b/tests/test_autochunk/openfold/pair_transition.py similarity index 96% rename from autochunk/openfold/pair_transition.py rename to tests/test_autochunk/openfold/pair_transition.py index de76306418ee..7d09914dc3cc 100644 --- a/autochunk/openfold/pair_transition.py +++ b/tests/test_autochunk/openfold/pair_transition.py @@ -17,8 +17,8 @@ import torch import torch.nn as nn -from openfold.primitives import Linear, LayerNorm -from 
openfold.tensor_utils import chunk_layer +from .primitives import Linear, LayerNorm +from .tensor_utils import chunk_layer class PairTransition(nn.Module): diff --git a/autochunk/openfold/primitives.py b/tests/test_autochunk/openfold/primitives.py similarity index 99% rename from autochunk/openfold/primitives.py rename to tests/test_autochunk/openfold/primitives.py index bbc156f21d4a..32a9d487c441 100644 --- a/autochunk/openfold/primitives.py +++ b/tests/test_autochunk/openfold/primitives.py @@ -21,8 +21,8 @@ import torch import torch.nn as nn -from openfold.checkpointing import get_checkpoint_fn -from openfold.tensor_utils import ( +from .checkpointing import get_checkpoint_fn +from .tensor_utils import ( permute_final_dims, flatten_final_dims, _chunk_slice, diff --git a/autochunk/openfold/tensor_utils.py b/tests/test_autochunk/openfold/tensor_utils.py similarity index 100% rename from autochunk/openfold/tensor_utils.py rename to tests/test_autochunk/openfold/tensor_utils.py diff --git a/autochunk/openfold/triangular_attention.py b/tests/test_autochunk/openfold/triangular_attention.py similarity index 97% rename from autochunk/openfold/triangular_attention.py rename to tests/test_autochunk/openfold/triangular_attention.py index 6d3e37f4c681..12d09c502daf 100644 --- a/autochunk/openfold/triangular_attention.py +++ b/tests/test_autochunk/openfold/triangular_attention.py @@ -20,8 +20,8 @@ import torch import torch.nn as nn -from openfold.primitives import Linear, LayerNorm, Attention -from openfold.tensor_utils import ( +from .primitives import Linear, LayerNorm, Attention +from .tensor_utils import ( chunk_layer, permute_final_dims, flatten_final_dims, diff --git a/autochunk/openfold/triangular_multiplicative_update.py b/tests/test_autochunk/openfold/triangular_multiplicative_update.py similarity index 97% rename from autochunk/openfold/triangular_multiplicative_update.py rename to tests/test_autochunk/openfold/triangular_multiplicative_update.py index 
2406e2bac2cf..29f7062c3212 100644 --- a/autochunk/openfold/triangular_multiplicative_update.py +++ b/tests/test_autochunk/openfold/triangular_multiplicative_update.py @@ -19,8 +19,8 @@ import torch import torch.nn as nn -from openfold.primitives import Linear, LayerNorm -from openfold.tensor_utils import permute_final_dims +from .primitives import Linear, LayerNorm +from .tensor_utils import permute_final_dims class TriangleMultiplicativeUpdate(nn.Module): From 06a5355d98c0069e3305679a04846637917078e9 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 6 Jan 2023 11:44:01 +0800 Subject: [PATCH 084/209] update test --- tests/test_autochunk/autochunk_test.py | 111 ++++++++++++------------- 1 file changed, 52 insertions(+), 59 deletions(-) diff --git a/tests/test_autochunk/autochunk_test.py b/tests/test_autochunk/autochunk_test.py index 5e9aaca15f9f..caa2d9a80254 100644 --- a/tests/test_autochunk/autochunk_test.py +++ b/tests/test_autochunk/autochunk_test.py @@ -1,76 +1,60 @@ -import copy -import torch -import torch.nn.functional as F import pytest +import torch import torch.fx import torch.multiprocessing as mp -from torch.fx import GraphModule -from colossalai.fx import ColoTracer + import colossalai -from colossalai.utils import free_port +from colossalai.autochunk.chunk_codegen import ChunkCodeGen from colossalai.core import global_context as gpc +from colossalai.fx import ColoTracer from colossalai.fx.graph_module import ColoGraphModule -from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata +from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.profiler import MetaTensor +from colossalai.utils import free_port from tests.test_autochunk.evoformer.evoformer import evoformer_base -from ...colossalai.autochunk.chunk_codegen import ChunkCodeGen -with_codegen = True - - -def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule) -> bool: - for m_p, gm_p in zip(m.parameters(), gm.parameters()): - if m_p.grad is not None 
and not torch.allclose(m_p.grad, gm_p.grad): - return False - return True - - -def _is_all_param_close(m: torch.nn.Module, gm: GraphModule) -> bool: - for m_p, gm_p in zip(m.parameters(), gm.parameters()): - if m_p.grad is not None and not torch.allclose(m_p.data, gm_p.data): - return False - return True - - -def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair): - # now_mem = torch.cuda.memory_allocated() / 1024**2 - # with torch.no_grad(): - # node0 = node.clone() - # pair0 = pair.clone() - # model.graph(node0, pair0, now_mem) - # new_now_mem = torch.cuda.memory_allocated() / 1024**2 - # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2 - # print("\ncode now:%.2f max:%.2f" %(new_now_mem - now_mem, new_max_mem - now_mem)) - + + +def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair): torch.cuda.reset_peak_memory_stats() now_mem = torch.cuda.memory_allocated() / 1024**2 with torch.no_grad(): node1 = node.clone() pair1 = pair.clone() - gm(node1, pair1) + gm(node1, pair1) new_now_mem = torch.cuda.memory_allocated() / 1024**2 new_max_mem = torch.cuda.max_memory_allocated() / 1024**2 - print("gm now:%.2f max:%.2f" %(new_now_mem - now_mem, new_max_mem - now_mem)) - + print( + "autochunk now mem:%.2f max mem:%.2f" + % (new_now_mem - now_mem, new_max_mem - now_mem) + ) + # test forward with torch.no_grad(): non_fx_out = model(node, pair) fx_out = gm(node, pair) - assert torch.allclose(non_fx_out[0], fx_out[0], atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(torch.abs(non_fx_out[0] - fx_out[0])) - assert torch.allclose(non_fx_out[1], fx_out[1], atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(torch.abs(non_fx_out[1] - fx_out[1])) - - # test barckward - # loss0 = non_fx_out[0].sum() + non_fx_out[1].sum() - # loss0.backward() - # loss1 = fx_out[0].sum() + fx_out[1].sum() - # loss1.backward() - # assert _is_all_param_close(model, gm) - # assert 
_is_all_gradient_close(model, gm), "gm doesn't have the same gradient as original one" + assert torch.allclose( + non_fx_out[0], fx_out[0], atol=1e-4 + ), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( + torch.abs(non_fx_out[0] - fx_out[0]) + ) + assert torch.allclose( + non_fx_out[1], fx_out[1], atol=1e-4 + ), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( + torch.abs(non_fx_out[1] - fx_out[1]) + ) def _run_offload_codegen(rank): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly - colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') + colossalai.launch( + config={}, + rank=rank, + world_size=1, + host="localhost", + port=free_port(), + backend="nccl", + ) # build model and input model = evoformer_base().cuda() @@ -78,15 +62,25 @@ def _run_offload_codegen(rank): pair = torch.randn(1, 300, 300, 128).cuda() # trace the module and replace codegen - graph = ColoTracer().trace(model, meta_args={'node': node.to(torch.device('meta')), 'pair': pair.to(torch.device('meta'))}) - gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace - interp = MetaInfoProp(gm_prop) - interp.propagate(MetaTensor(node, fake_device='cuda:0'), MetaTensor(pair, fake_device='cuda:0')) + graph = ColoTracer().trace( + model, + meta_args={ + "node": node.to(torch.device("meta")), + "pair": pair.to(torch.device("meta")), + }, + ) + gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace + interp = MetaInfoProp(gm_prop) + interp.propagate( + MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0") + ) # now run it twice to get meta info in graph module, not necessary gm = torch.fx.GraphModule(model, graph) interp = MetaInfoProp(gm) - interp.propagate(MetaTensor(node, fake_device='cuda:0'), MetaTensor(pair, fake_device='cuda:0')) + interp.propagate( + MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, 
fake_device="cuda:0") + ) codegen = ChunkCodeGen(gm_prop) graph.set_codegen(codegen) @@ -94,15 +88,14 @@ def _run_offload_codegen(rank): gm.recompile() # assert we have all the components - code = graph.python_code("self").src - print(code) + # code = graph.python_code("self").src + # print(code) - _test_fwd_and_bwd(model, gm, node, pair) + _test_fwd(model, gm, node, pair) gpc.destroy() -@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') -def test_act_ckpt_codegen(): +def test_autochunk(): mp.spawn(_run_offload_codegen, nprocs=1) From d1f07731824c425c26197c7c82425445c8c3df3e Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 6 Jan 2023 11:48:33 +0800 Subject: [PATCH 085/209] rename --- .../{autochunk_benchmark.py => benchmark_autochunk.py} | 0 tests/test_autochunk/{autochunk_test.py => test_autochunk.py} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename tests/test_autochunk/{autochunk_benchmark.py => benchmark_autochunk.py} (100%) rename tests/test_autochunk/{autochunk_test.py => test_autochunk.py} (100%) diff --git a/tests/test_autochunk/autochunk_benchmark.py b/tests/test_autochunk/benchmark_autochunk.py similarity index 100% rename from tests/test_autochunk/autochunk_benchmark.py rename to tests/test_autochunk/benchmark_autochunk.py diff --git a/tests/test_autochunk/autochunk_test.py b/tests/test_autochunk/test_autochunk.py similarity index 100% rename from tests/test_autochunk/autochunk_test.py rename to tests/test_autochunk/test_autochunk.py From 1a6d2a740be33d769111ed03104bb5fa73b2ad50 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 6 Jan 2023 14:14:45 +0800 Subject: [PATCH 086/209] take apart chunk code gen --- colossalai/autochunk/autochunk_codegen.py | 497 ++++ colossalai/autochunk/chunk_codegen.py | 2364 ------------------- colossalai/autochunk/chunk_region_search.py | 211 ++ colossalai/autochunk/chunk_selector.py | 221 ++ colossalai/autochunk/index_tracer.py | 1056 +++++++++ 
colossalai/autochunk/memory_estiamtor.py | 318 +++ colossalai/autochunk/utils.py | 95 + tests/test_autochunk/benchmark_autochunk.py | 12 +- tests/test_autochunk/test_autochunk.py | 4 +- 9 files changed, 2408 insertions(+), 2370 deletions(-) create mode 100644 colossalai/autochunk/autochunk_codegen.py delete mode 100644 colossalai/autochunk/chunk_codegen.py create mode 100644 colossalai/autochunk/chunk_region_search.py create mode 100644 colossalai/autochunk/chunk_selector.py create mode 100644 colossalai/autochunk/index_tracer.py create mode 100644 colossalai/autochunk/memory_estiamtor.py create mode 100644 colossalai/autochunk/utils.py diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py new file mode 100644 index 000000000000..58a8c375136e --- /dev/null +++ b/colossalai/autochunk/autochunk_codegen.py @@ -0,0 +1,497 @@ +from typing import Any, Callable, Dict, Iterable, List, Tuple + +import torch +from torch.fx.graph import ( + CodeGen, + PythonCode, + _custom_builtins, + _CustomBuiltin, + _format_target, + _is_from_torch, + _Namespace, + _origin_type_map, + inplace_methods, + magic_methods, +) +from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg + +import colossalai + +from .chunk_region_search import ChunkRegionSearch +from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape + +CODEGEN_AVAILABLE = True +__all__ = ["AutoChunkCodeGen"] + + +def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape): + new_shape = "[" + for idx, i in enumerate(shape): + if idx == chunk_dim: + new_shape += "%s:%s + chunk_size" % (chunk_idx_name, chunk_idx_name) + else: + new_shape += ":" + new_shape += ", " + new_shape = new_shape[:-2] + "]" + return new_shape + + +def _gen_loop_start(chunk_input, chunk_output, chunk_ouput_dim, chunk_size=2): + input_node = chunk_input[0] + out_shape = get_node_shape(chunk_output) + out_str = str(list(out_shape)) + context = ( + "chunk_result = 
torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range" + % (out_str, input_node.name, input_node.name, chunk_size) + ) + context += "(0, %d, chunk_size):\n" % (out_shape[chunk_ouput_dim]) + return context + + +def _gen_loop_end( + chunk_inputs, chunk_non_compute_inputs, chunk_outputs, chunk_outputs_dim, node_list +): + chunk_outputs_name = chunk_outputs.name + chunk_outputs_idx = find_idx_by_name(chunk_outputs_name, node_list) + chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape + chunk_slice = _gen_chunk_slice_dim( + chunk_outputs_dim, "chunk_idx", chunk_output_shape + ) + context = " chunk_result%s = %s; %s = None\n" % ( + chunk_slice, + chunk_outputs_name, + chunk_outputs_name, + ) + context += ( + chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None" + ) + + # determine if its the last use for chunk input + for chunk_input in chunk_inputs + chunk_non_compute_inputs: + if all( + [ + find_idx_by_name(user.name, node_list) <= chunk_outputs_idx + for user in chunk_input.users.keys() + ] + ): + context += "; %s = None" % chunk_input.name + + context += "\n" + return context + + +def _replace_name(context, name_from, name_to): + patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")")] + for p in patterns: + source = p[0] + name_from + p[1] + target = p[0] + name_to + p[1] + if source in context: + context = context.replace(source, target) + return context + + +def _replace_reshape_size(context, node_name, reshape_size_dict): + if node_name not in reshape_size_dict: + return context + for size_name, size_value in reshape_size_dict[node_name].items(): + context = context.replace(size_name, size_value) + return context + + +def emit_code_with_chunk( + body, + nodes, + emit_node_func, + delete_unused_value_func, + chunk_region_search, + chunk_infos, +): + """Emit code with nested activation checkpoint + When we detect some of the node.activation_checkpoint is a List, we 
will use + this function to emit the activation checkpoint codes. + + Args: + body: forward code + ckpt_func: checkpoint functions code + nodes: graph.nodes + emit_node_func: function to emit node + delete_unused_value_func: function to remove the unused value + """ + node_list = list(nodes) + + chunk_regions = [i["region"] for i in chunk_infos] + chunk_starts = [i[0] for i in chunk_regions] + chunk_ends = [i[1] for i in chunk_regions] + + chunk_inputs = [i["inputs"] for i in chunk_infos] + chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] + chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] + chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [ + j.name for i in chunk_inputs_non_chunk for j in i + ] + + chunk_outputs = [i["outputs"][0] for i in chunk_infos] + chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos] + + node_list = chunk_region_search.index_tracer.reorder_node_list(node_list) + node_idx = 0 + region_idx = 0 + within_chunk_region = False + + while node_idx < len(node_list): + node = node_list[node_idx] + + if node_idx in chunk_starts: + within_chunk_region = True + region_idx = chunk_starts.index(node_idx) + body.append( + _gen_loop_start( + chunk_inputs[region_idx], + chunk_outputs[region_idx], + chunk_outputs_dim[region_idx], + chunk_infos[region_idx]["chunk_size"], + ) + ) + + if within_chunk_region: + emit_node_func(node, body) + # replace input var with chunk var + for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]): + for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items(): + if idx == node_idx: + chunk_slice = _gen_chunk_slice_dim( + dim[0], "chunk_idx", get_node_shape(input_node) + ) + body[-1] = _replace_name( + body[-1], input_node.name, input_node.name + chunk_slice + ) + # ones like + if "ones_like" in node.name: + meta_node = chunk_region_search.index_tracer.node_list[node_idx] + chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node][ + "chunk_dim" + ] 
+ if get_node_shape(meta_node)[chunk_dim] != 1: + source_node = meta_node.args[0].args[0] + if ( + source_node not in chunk_infos[region_idx]["node_chunk_dim"] + or chunk_infos[region_idx]["node_chunk_dim"][source_node][ + "chunk_dim" + ] + is None + ): + chunk_slice = _gen_chunk_slice_dim( + chunk_dim, "chunk_idx", get_node_shape(node) + ) + body[-1] = _replace_name( + body[-1], node.args[0].name, node.args[0].name + chunk_slice + ) + body[-1] = _replace_reshape_size( + body[-1], node.name, chunk_infos[region_idx]["reshape_size"] + ) + body[-1] = " " + body[-1] + delete_unused_value_func(node, body, chunk_inputs_names) + else: + emit_node_func(node, body) + if node_idx not in chunk_inputs: + delete_unused_value_func(node, body, chunk_inputs_names) + + if node_idx in chunk_ends: + body.append( + _gen_loop_end( + chunk_inputs[region_idx], + chunk_inputs_non_chunk[region_idx], + chunk_outputs[region_idx], + chunk_outputs_dim[region_idx], + node_list, + ) + ) + within_chunk_region = False + + node_idx += 1 + + +if CODEGEN_AVAILABLE: + + class AutoChunkCodeGen(CodeGen): + def __init__(self, meta_graph, max_memory=None): + super().__init__() + self.meta_graph = meta_graph + self.max_memory = max_memory + self.meta_node = list(meta_graph.graph.nodes) + # find the chunk regions + self.chunk_region_search = ChunkRegionSearch(meta_graph, max_memory) + self.chunk_infos = self.chunk_region_search.search_region() + + def _gen_python_code( + self, nodes, root_module: str, namespace: _Namespace + ) -> PythonCode: + free_vars: List[str] = [] + body: List[str] = [] + globals_: Dict[str, Any] = {} + wrapped_fns: Dict[str, None] = {} + + # Wrap string in list to pass by reference + maybe_return_annotation: List[str] = [""] + + def add_global(name_hint: str, obj: Any): + """Add an obj to be tracked as a global. + + We call this for names that reference objects external to the + Graph, like functions or types. 
+ + Returns: the global name that should be used to reference 'obj' in generated source. + """ + if ( + _is_from_torch(obj) and obj != torch.device + ): # to support registering torch.device + # HACK: workaround for how torch custom ops are registered. We + # can't import them like normal modules so they must retain their + # fully qualified name. + return _get_qualified_name(obj) + + # normalize the name hint to get a proper identifier + global_name = namespace.create_name(name_hint, obj) + + if global_name in globals_: + assert globals_[global_name] is obj + return global_name + globals_[global_name] = obj + return global_name + + # set _custom_builtins here so that we needn't import colossalai in forward + _custom_builtins["colossalai"] = _CustomBuiltin( + "import colossalai", colossalai + ) + + # Pre-fill the globals table with registered builtins. + for name, (_, obj) in _custom_builtins.items(): + add_global(name, obj) + + def type_repr(o: Any): + if o == (): + # Empty tuple is used for empty tuple type annotation Tuple[()] + return "()" + + typename = _type_repr(o) + + if hasattr(o, "__origin__"): + # This is a generic type, e.g. typing.List[torch.Tensor] + origin_type = _origin_type_map.get(o.__origin__, o.__origin__) + origin_typename = add_global(_type_repr(origin_type), origin_type) + + if hasattr(o, "__args__"): + # Assign global names for each of the inner type variables. 
+ args = [type_repr(arg) for arg in o.__args__] + + if len(args) == 0: + # Bare type, such as `typing.Tuple` with no subscript + # This code-path used in Python < 3.9 + return origin_typename + + return f'{origin_typename}[{",".join(args)}]' + else: + # Bare type, such as `typing.Tuple` with no subscript + # This code-path used in Python 3.9+ + return origin_typename + + # Common case: this is a regular module name like 'foo.bar.baz' + return add_global(typename, o) + + def _format_args( + args: Tuple[Argument, ...], kwargs: Dict[str, Argument] + ) -> str: + def _get_repr(arg): + # Handle NamedTuples (if it has `_fields`) via add_global. + if isinstance(arg, tuple) and hasattr(arg, "_fields"): + qualified_name = _get_qualified_name(type(arg)) + global_name = add_global(qualified_name, type(arg)) + return f"{global_name}{repr(tuple(arg))}" + return repr(arg) + + args_s = ", ".join(_get_repr(a) for a in args) + kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items()) + if args_s and kwargs_s: + return f"{args_s}, {kwargs_s}" + return args_s or kwargs_s + + # Run through reverse nodes and record the first instance of a use + # of a given node. This represents the *last* use of the node in the + # execution order of the program, which we will use to free unused + # values + node_to_last_use: Dict[Node, Node] = {} + user_to_last_uses: Dict[Node, List[Node]] = {} + + def register_last_uses(n: Node, user: Node): + if n not in node_to_last_use: + node_to_last_use[n] = user + user_to_last_uses.setdefault(user, []).append(n) + + for node in reversed(nodes): + map_arg(node.args, lambda n: register_last_uses(n, node)) + map_arg(node.kwargs, lambda n: register_last_uses(n, node)) + + delete_free_var_from_last_use(user_to_last_uses) + + # NOTE: we add a variable to distinguish body and ckpt_func + def delete_unused_values(user: Node, body, to_keep=[]): + """ + Delete values after their last use. 
This ensures that values that are + not used in the remainder of the code are freed and the memory usage + of the code is optimal. + """ + if user.op == "placeholder": + return + if user.op == "output": + body.append("\n") + return + nodes_to_delete = user_to_last_uses.get(user, []) + nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep] + if len(nodes_to_delete): + to_delete_str = " = ".join( + [repr(n) for n in nodes_to_delete] + ["None"] + ) + body.append(f"; {to_delete_str}\n") + else: + body.append("\n") + + # NOTE: we add a variable to distinguish body and ckpt_func + def emit_node(node: Node, body): + maybe_type_annotation = ( + "" if node.type is None else f" : {type_repr(node.type)}" + ) + if node.op == "placeholder": + assert isinstance(node.target, str) + maybe_default_arg = ( + "" if not node.args else f" = {repr(node.args[0])}" + ) + free_vars.append( + f"{node.target}{maybe_type_annotation}{maybe_default_arg}" + ) + raw_name = node.target.replace("*", "") + if raw_name != repr(node): + body.append(f"{repr(node)} = {raw_name}\n") + return + elif node.op == "call_method": + assert isinstance(node.target, str) + body.append( + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}" + f"({_format_args(node.args[1:], node.kwargs)})" + ) + return + elif node.op == "call_function": + assert callable(node.target) + # pretty print operators + if ( + node.target.__module__ == "_operator" + and node.target.__name__ in magic_methods + ): + assert isinstance(node.args, tuple) + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}" + ) + return + + # pretty print inplace operators; required for jit.script to work properly + # not currently supported in normal FX graphs, but generated by torchdynamo + if ( + node.target.__module__ == "_operator" + and node.target.__name__ in inplace_methods + ): + body.append( + 
f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; " + f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}" + ) + return + + qualified_name = _get_qualified_name(node.target) + global_name = add_global(qualified_name, node.target) + # special case for getattr: node.args could be 2-argument or 3-argument + # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value + if ( + global_name == "getattr" + and isinstance(node.args, tuple) + and isinstance(node.args[1], str) + and node.args[1].isidentifier() + and len(node.args) == 2 + ): + body.append( + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}" + ) + return + body.append( + f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})" + ) + if node.meta.get("is_wrapped", False): + wrapped_fns.setdefault(global_name) + return + elif node.op == "call_module": + assert isinstance(node.target, str) + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})" + ) + return + elif node.op == "get_attr": + assert isinstance(node.target, str) + body.append( + f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}" + ) + return + elif node.op == "output": + if node.type is not None: + maybe_return_annotation[0] = f" -> {type_repr(node.type)}" + body.append(self.generate_output(node.args[0])) + return + raise NotImplementedError(f"node: {node.op} {node.target}") + + # Modified for activation checkpointing + ckpt_func = [] + + # if any node has a list of labels for activation_checkpoint, we + # will use nested type of activation checkpoint codegen + emit_code_with_chunk( + body, + nodes, + emit_node, + delete_unused_values, + self.chunk_region_search, + self.chunk_infos, + ) + + if len(body) == 0: + # If the Graph has no non-placeholder nodes, no lines for the 
body + # have been emitted. To continue to have valid Python code, emit a + # single pass statement + body.append("pass\n") + + if len(wrapped_fns) > 0: + wrap_name = add_global("wrap", torch.fx.wrap) + wrap_stmts = "\n".join( + [f'{wrap_name}("{name}")' for name in wrapped_fns] + ) + else: + wrap_stmts = "" + + if self._body_transformer: + body = self._body_transformer(body) + + for name, value in self.additional_globals(): + add_global(name, value) + + # as we need colossalai.utils.checkpoint, we need to import colossalai + # in forward function + prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) + prologue = "".join(ckpt_func) + prologue + prologue = prologue + + code = "".join(body) + code = "\n".join(" " + line for line in code.split("\n")) + fn_code = f""" +{wrap_stmts} + +{prologue} +{code}""" + # print(fn_code) + return PythonCode(fn_code, globals_) diff --git a/colossalai/autochunk/chunk_codegen.py b/colossalai/autochunk/chunk_codegen.py deleted file mode 100644 index 7a5d06689247..000000000000 --- a/colossalai/autochunk/chunk_codegen.py +++ /dev/null @@ -1,2364 +0,0 @@ -import colossalai -import torch -import copy -from typing import List, Callable, Any, Tuple, Dict, Iterable - -from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name -from torch.fx.graph import ( - _Namespace, - PythonCode, - _custom_builtins, - _is_from_torch, - _format_target, - magic_methods, - CodeGen, - _origin_type_map, - inplace_methods, - _CustomBuiltin, -) -from colossalai.fx.profiler import ( - calculate_fwd_out, - calculate_fwd_tmp, - parameter_size, - activation_size, -) - -CODEGEN_AVAILABLE = True -__all__ = ["ChunkCodeGen"] - - -def _delete_free_var_from_last_use(user_to_last_uses): - for key, value in user_to_last_uses.items(): - for n in value: - if n.op == "placeholder": - user_to_last_uses[key].remove(n) - - -def _get_node_shape(node): - if hasattr(node.meta["tensor_meta"], "shape"): - return node.meta["tensor_meta"].shape - 
return None - - -def _is_non_compute_node(node): - if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any( - i in node.name for i in ["getitem", "getattr"] - ): - return True - return False - - -def _is_non_compute_node_except_placeholder(node): - if any(i in node.op for i in ["get_attr", "output"]) or any( - i in node.name for i in ["getitem", "getattr"] - ): - return True - return False - - -def _is_non_compute_node_except_placeholder_output(node): - if any(i in node.op for i in ["get_attr"]) or any( - i in node.name for i in ["getitem", "getattr"] - ): - return True - return False - - -class IndexTracer(object): - def __init__(self, node_list) -> None: - self.node_list = node_list - self.idx_trace_list = self._init_idx_trace_list() - self.idx_trace_equal = [] - self.idx_view_list = {} - self.idx_count = -1 - self.all_reorder_map = {i: i for i in range(len(self.idx_trace_list))} - - def _init_idx_trace_list(self): - idx_trace_list = [] - for n in self.node_list: - if _get_node_shape(n) != None: - cur_trace = { - "idx": [None for _ in range(len(_get_node_shape(n)))], - "compute": [[] for _ in range(len(_get_node_shape(n)))], - "source": [{} for _ in range(len(_get_node_shape(n)))], - } - else: - cur_trace = {"idx": [], "compute": [], "source": []} - idx_trace_list.append(cur_trace) - return idx_trace_list - - def _add_index(self): - """ - Update the count and return it. To record the idx number. 
- - Returns: - idx_count: int - """ - self.idx_count += 1 - return self.idx_count - - def _del_dim(self, idx, dim_idx): - self.idx_trace_list[idx]["idx"].pop(dim_idx) - self.idx_trace_list[idx]["compute"].pop(dim_idx) - self.idx_trace_list[idx]["source"].pop(dim_idx) - - def _add_dim(self, node_idx, dim_idx): - self.idx_trace_list[node_idx]["idx"].insert(dim_idx, self._add_index()) - self.idx_trace_list[node_idx]["compute"].insert(dim_idx, []) - self.idx_trace_list[node_idx]["source"].insert(dim_idx, {}) - - def _transform_index(self, node, node_dim): - node_idx = self._find_idx_trace_from_node(node) - dims = list(range(len(node_idx))) - return dims[node_dim] - - def _inherit_index(self, node_from, node_from_dim, node_to, node_to_dim): - node_from_dim = self._transform_index(node_from, node_from_dim) - node_to_dim = self._transform_index(node_to, node_to_dim) - node_from_trace = self._find_trace_from_node(node_from) - node_to_trace = self._find_trace_from_node(node_to) - node_to_trace["idx"][node_to_dim] = node_from_trace["idx"][node_from_dim] - node_to_trace["compute"][node_to_dim] = copy.deepcopy( - node_from_trace["compute"][node_from_dim] - ) - self._add_source(node_from, node_from_dim, node_to, node_to_dim, init=True) - - def _inherit_all_computation(self, node_from, node_to): - node_from_compute = self._find_compute_trace_from_node(node_from) - node_to_compute = self._find_compute_trace_from_node(node_to) - assert len(node_from_compute) == len(node_to_compute) - for i in range(len(node_from_compute)): - self._add_source(node_from, i, node_to, i) - node_to_compute[i] = copy.deepcopy(node_from_compute[i]) - - def _add_source(self, node_from, node_from_dim, node_to, node_to_dim, init=False): - node_from_dim = self._transform_index(node_from, node_from_dim) - node_from_trace_source = self._find_source_trace_from_node(node_from) - node_to_dim = self._transform_index(node_to, node_to_dim) - node_to_trace_source = self._find_source_trace_from_node(node_to) - 
node_from_idx = _find_idx_by_name(node_from.name, self.node_list) - if init: - node_to_trace_source[node_to_dim] = {} - # add dim to cur new source - if node_from_idx not in node_to_trace_source[node_to_dim]: - node_to_trace_source[node_to_dim][node_from_idx] = [node_from_dim] - else: - if node_from_dim not in node_to_trace_source[node_to_dim][node_from_idx]: - node_to_trace_source[node_to_dim][node_from_idx].append(node_from_dim) - # update inputs source - for node_idx, node_dim in node_from_trace_source[node_from_dim].items(): - if node_idx not in node_to_trace_source[node_to_dim]: - node_to_trace_source[node_to_dim][node_idx] = copy.deepcopy(node_dim) - else: - for d in node_dim: - if d not in node_to_trace_source[node_to_dim][node_idx]: - node_to_trace_source[node_to_dim][node_idx].append(d) - - def _mark_computation_from_node(self, node_from, node_to, exclude=None): - if exclude == None: - exclude = [] - else: - exclude = [self._transform_index(node_to, i) for i in exclude] - node_from_compute = self._find_compute_trace_from_node(node_from) - node_to_compute = self._find_compute_trace_from_node(node_to) - # assert len(node_from_compute) == len(node_to_compute) - for i in range(-1, -min(len(node_from_compute), len(node_to_compute)) - 1, -1): - if self._transform_index(node_to, i) in exclude: - continue - self._add_source(node_from, i, node_to, i) - for j in node_from_compute[i]: - if j not in node_to_compute[i]: - node_to_compute[i].append(j) - - def _mark_idx_equal(self, node1, dim1, node2, dim2): - """ - Mark 2 index to be equal. - - Args: - idx1 (int): index count. - idx2 (int): index count. - """ - # node1_idx = _find_idx_by_name(node1.name, self.nodes_list) - # node2_idx = _find_idx_by_name(node2.name, self.nodes_list) - # if node1_idx > node2_idx: - # self._add_source(node2, dim2, node1, dim1) - # else: - # self._add_source(node1, dim1, node2, dim2) - - def _mark_computation(self, node, idx, dim): - """ - Mark some dims of node as computed. 
- - Args: - node (node) - idx (int): node index - dim (list or int): dims to be marked as computed - """ - if isinstance(dim, int): - dim = [dim] - dims = list(range(len(_get_node_shape(node)))) - for d in dim: - cur_dim = dims[d] - if idx not in self.idx_trace_list[idx]["compute"][cur_dim]: - self.idx_trace_list[idx]["compute"][cur_dim].append(idx) - - def _find_trace_from_node(self, node): - """ - Find node idx and compute trace by the node. - - Args: - node (node) - Returns: - idx (list): idx of the node - compute (list): computed idx of the node. - """ - node_idx = _find_idx_by_name(node.name, self.node_list) - node_dict = self.idx_trace_list[node_idx] - return node_dict - - def _find_source_trace_from_node(self, node): - """ - Find node source trace by the node. - - Args: - node (node) - Returns: - idx (list): idx of the node - compute (list): computed idx of the node. - """ - node_idx = _find_idx_by_name(node.name, self.node_list) - node_dict = self.idx_trace_list[node_idx] - return node_dict["source"] - - def _find_idx_trace_from_node(self, node): - """ - Find node idx trace by the node. - - Args: - node (node) - Returns: - idx (list): idx of the node - """ - node_idx = _find_idx_by_name(node.name, self.node_list) - return self.idx_trace_list[node_idx]["idx"] - - def _find_compute_trace_from_node(self, node): - """ - Find node compute trace by the node. - - Args: - node (node) - Returns: - compute (list): computed idx of the node. - """ - node_idx = _find_idx_by_name(node.name, self.node_list) - return self.idx_trace_list[node_idx]["compute"] - - def _assign_index_as_input(self, node, node_idx, input_node=None): - """ - Assign node's trace as its input node. 
- - Args: - node (node) - node_idx (int) - """ - if input_node == None: - input_node = node.args[0] - input_node_idx = _find_idx_by_name(input_node.name, self.node_list) - input_node_idx_trace = self.idx_trace_list[input_node_idx]["idx"] - - new_idx_trace = copy.deepcopy(input_node_idx_trace) - self.idx_trace_list[node_idx]["idx"] = new_idx_trace - - self._inherit_all_computation(input_node, node) - - def _assign_all_index(self, node, node_idx): - """ - Add new index for all node's dims. - - Args: - node (node) - node_idx (int) - """ - shape = node.meta["tensor_meta"].shape - new_trace = [] - for _ in shape: - new_trace.append(self._add_index()) - self.idx_trace_list[node_idx]["idx"] = new_trace - - def _assign_transpose_index(self, node, node_idx): - """ - Assign index for transpose op. - 1. swap input's dim according to transpose args - 2. inherit input's computation - - Args: - node (node) - node_idx (int) - """ - input_node = node.args[0] - tranpose_dim = node.args[1:] - - self._assign_index_as_input(node, node_idx, input_node) - self._inherit_index(input_node, tranpose_dim[1], node, tranpose_dim[0]) - self._inherit_index(input_node, tranpose_dim[0], node, tranpose_dim[1]) - - def _assign_permute_index(self, node, node_idx): - """ - Assign index for permute op. - 1. swap input's dim according to permute args - 2. inherit input's computation - - Args: - node (node) - node_idx (int) - """ - permute_dim = node.args[1:] - input_node = node.args[0] - - self._assign_index_as_input(node, node_idx, input_node) - for idx, d in enumerate(permute_dim): - self._inherit_index(input_node, d, node, idx) - - def _assign_linear_index(self, node, node_idx): - """ - Assign index for linear op. - 1. copy trace from input node and change last index accroding to weight - 2. mark equal for input node last index, weight first dim and bias dim. - 3. inherit input's computation, mark computation for last dim. 
- - Args: - node (node) - node_idx (int) - """ - if len(node.args) == 2: - input_node, weight = node.args - bias = None - else: - input_node, weight, bias = node.args - - self._assign_index_as_input(node, node_idx) - self._inherit_index(weight, 1, node, -1) - - self._mark_computation(node, node_idx, [-1]) - self._mark_idx_equal(input_node, -1, weight, 0) - - if bias: - self._mark_idx_equal(input_node, -1, bias, 0) - - def _assign_matmul_index(self, node, node_idx): - """ - Assign index for matmul op. - 1. copy trace from matmul_left and change last index accroding to matmul_right. (assert they have same length) - 2. mark equal for input matmul_left -1 index and matmul_right -2 dim. - 3. inherit matmul_left and matmul_right computation, mark computation for last dim. - - Args: - node (node) - node_idx (int) - """ - matmul_left, matmul_right = node.args - - assert len(_get_node_shape(matmul_left)) == len(_get_node_shape(matmul_right)) - self._assign_index_as_input(node, node_idx, matmul_left) - self._inherit_index(matmul_right, -1, node, -1) - - self._mark_computation_from_node(matmul_right, node, [-1, -2]) - self._mark_computation(node, node_idx, [-1]) - self._mark_idx_equal(matmul_left, -1, matmul_right, -2) - - def _assign_layernorm_index(self, node, idx): - """ - Assign index for layernorm op. - 1. assign index as input node - 2. inherit computation and mark last 2 dims as computed. - - Args: - node (node) - node_idx (int) - """ - self._assign_index_as_input(node, idx) - self._mark_computation(node, idx, [-1]) - - def _assign_elementwise_index(self, node, idx): - """ - Assign index for element-wise op (eg. relu sigmoid add mul). - 1. assign index as input node - 2. inherit computation from all input nodes. 
- - Args: - node (node) - node_idx (int) - """ - self._assign_index_as_input(node, idx) - nodes_in = [] - for node_in in node.args: - if type(node_in) == type(node): - nodes_in.append(node_in) - self._mark_computation_from_node(node_in, node) - assert len(nodes_in) <= 2 - if len(nodes_in) == 2: - node_in0_shape = _get_node_shape(nodes_in[0]) - node_in1_shape = _get_node_shape(nodes_in[1]) - for i in range(-1, -min(len(node_in0_shape), len(node_in1_shape)) - 1, -1): - if node_in0_shape[i] == node_in1_shape[i]: - self._mark_idx_equal(nodes_in[0], i, nodes_in[1], i) - - def _assgin_no_change_index(self, node, idx): - self._assign_index_as_input(node, idx) - for node_in in node.args: - if type(node_in) == type(node): - self._mark_computation_from_node(node_in, node) - - def _assign_einsum_index(self, node, idx): - """ - Assign index for einsum op. - - Args: - node (node) - node_idx (int) - """ - patterns = node.args[0] - input_nodes = node.args[1:] - - patterns = patterns.replace(" ", "") - left, right = patterns.split("->") - left = left.split(",") - - all_index = [] - for i in left: - for c in i: - all_index.append(c) - all_index = set(all_index) - free_index = set([i for i in right]) - sum_index = all_index - free_index - - for right_idx, right_indice in enumerate(right): - for left_idx, left_str in enumerate(left): - if right_indice in left_str: - source_idx = left_str.index(right_indice) - self._inherit_index( - input_nodes[left_idx], source_idx, node, right_idx - ) - - # for i in sum_index: - # for left_idx, left_str in enumerate(left): - # if i in left_str: - # self._mark_computation(node, idx, left_str.index(i)) - # break - - def _assign_softmax_index(self, node, idx): - """ - Assign index for softmax op. - 1. assign index as input node - 2. inherit computation and mark softmax dim as computed. 
- - Args: - node (node) - node_idx (int) - """ - self._assign_index_as_input(node, idx) - self._mark_computation(node, idx, [node.kwargs["dim"]]) - - def _assign_unsqueeze_index(self, node, node_idx): - """ - Assign index for unsqueeze op. - 1. assign new index for unsqueeze dim - - Args: - node (node) - node_idx (int) - """ - self._del_dim(node_idx, -1) - self._assign_index_as_input(node, node_idx) - self._add_dim(node_idx, node.args[1]) - - def _assign_dropout_index(self, node, node_idx): - """ - Assign index for unsqueeze op. - 1. assign new index for unsqueeze dim - - Args: - node (node) - node_idx (int) - """ - self._assign_index_as_input(node, node_idx) - - def _assign_ones_like_index(self, node, node_idx): - """ - Assign index for oneslike op. - 1. assign new index for all dim - - Args: - node (node) - node_idx (int) - """ - self._assign_all_index(node, node_idx) - - def _assign_view_reshape_index(self, node, node_idx): - """ - Assign index for view and reshape op. - 1. get origin shape and target shape by meta info. - 2. compute the real value of -1 in target shape. - 3. determine changed dim, and assgin index for generated dim. - 4. log changed dim and generated dim for restore - 5. inherit computation. - 6. TODO: look into view list to see whether the view is associated with other, - if so assgin equal dim according to previous view. 
- - Args: - node (node) - node_idx (int) - """ - # get data, turn into number - origin_node = node.args[0] - origin_shape = origin_node.meta["tensor_meta"].shape - target_shape = [] - for i in range(1, len(node.args)): - if isinstance(node.args[i], int): - target_shape.append(node.args[i]) - else: - target_shape.append(node.args[i].meta["fwd_out"][0]) - - # compute the value of -1 - if -1 in target_shape: - origin_product = 1 - for i in origin_shape: - origin_product *= i - target_product = -1 - for i in target_shape: - target_product *= i - shape_idx = target_shape.index(-1) - target_shape[shape_idx] = origin_product // target_product - - # determine changed dim - len_diff = len(origin_shape) - len(target_shape) - if len_diff == 1: - # dim merge - dim_equal = [i == j for i, j in zip(origin_shape[:-1], target_shape)] - dim_to = [dim_equal.index(False)] - dim_from = [dim_equal.index(False), dim_equal.index(False) + 1] - self._add_dim(node_idx, -1) - elif len_diff == -1: - # dim expand - dim_equal = [i == j for i, j in zip(origin_shape, target_shape[:-1])] - dim_from = [dim_equal.index(False)] - dim_to = [dim_equal.index(False), dim_equal.index(False) + 1] - self._del_dim(node_idx, -1) - else: - raise NotImplementedError( - "shape" - + str(origin_shape) - + "and" - + str(target_shape) - + "view not implemented" - ) - - # get new index - origin_trace = self._find_idx_trace_from_node(origin_node) - self._assign_index_as_input(node, node_idx, origin_node) - dim_from.reverse() - for i in dim_from: - self._del_dim(node_idx, i) - for i in dim_to: - self._add_dim(node_idx, i) - - # inherit computation - compute_log = self._find_compute_trace_from_node(origin_node) - for i in dim_from: - if origin_trace[i] in compute_log: - for j in dim_to: - self._mark_computation(node, node_idx, [j]) - break - - # log view, not used now - view_dict = { - "idx_from": [origin_trace[i] for i in dim_from], - "dim_from": dim_from, - "idx_to": [self.idx_trace_list[node_idx]["idx"][i] for i in 
dim_to], - "dim_to": dim_to, - } - self.idx_view_list[node] = view_dict - - def _merge_equal_idx(self): - idx_equal = copy.deepcopy(self.idx_trace_equal) - idx_equal.reverse() - for idx in idx_equal: - merge_to = min(idx) - merge_from = max(idx) - for trace in self.idx_trace_list: - if merge_from in trace["idx"]: - trace["idx"] = [ - merge_to if i == merge_from else i for i in trace["idx"] - ] - - def trace_index(self): - for idx, node in enumerate(self.node_list): - if node.op == "placeholder": - self._assign_all_index(node, idx) - elif node.op == "call_method": - if "transpose" in node.name: - self._assign_transpose_index(node, idx) - elif "permute" in node.name: - self._assign_permute_index(node, idx) - elif "view" in node.name or "reshape" in node.name: - self._assign_view_reshape_index(node, idx) - elif "unsqueeze" in node.name: - self._assign_unsqueeze_index(node, idx) - elif any(i in node.name for i in ["to", "contiguous"]): - self._assgin_no_change_index(node, idx) - else: - raise NotImplementedError(node.name, "method not implemented yet!") - elif node.op == "call_function": - if "linear" in node.name: - self._assign_linear_index(node, idx) - elif "matmul" in node.name: - self._assign_matmul_index(node, idx) - elif "softmax" in node.name: - self._assign_softmax_index(node, idx) - elif any(n in node.name for n in ["mul", "add", "sigmoid", "relu"]): - self._assign_elementwise_index(node, idx) - elif "ones_like" in node.name: - self._assign_ones_like_index(node, idx) - elif "dropout" in node.name: - self._assign_dropout_index(node, idx) - elif "einsum" in node.name: - self._assign_einsum_index(node, idx) - elif "getattr" in node.name: - continue # get attr like shape - elif "getitem" in node.name: - continue # get item in list - else: - raise NotImplementedError( - node.name, "function not implemented yet!" 
- ) - elif node.op == "call_module": - if any(n in node.name for n in ["layernorm", "norm"]): - self._assign_layernorm_index(node, idx) - else: - raise NotImplementedError(node.name, "module not implemented yet!") - elif node.op == "get_attr": - self._assign_all_index(node, idx) # get param - elif node.op == "output": - continue - else: - raise NotImplementedError(node.op, "op not implemented yet!") - # self._merge_equal_idx() - - def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node): - """ - Check 2 given index: one index should be source of the other - Args: - start_idx(int): start node chunk dim - start_node(node): start node - end_idx(int): end node chunk dim - end_node(node): end node - - Returns: - bool: True if check pass - """ - start_node_idx = _find_idx_by_name(start_node.name, self.node_list) - end_node_trace = self._find_trace_from_node(end_node) - end_node_trace_source = end_node_trace["source"][end_dim] - sorted_source = sorted( - end_node_trace_source.items(), key=lambda d: d[0], reverse=True - ) - for node_idx, node_dim in sorted_source: - if node_idx == start_node_idx and start_dim in node_dim: - return True - # it means we meet a node outside the loop, and the node is not input node - if node_idx < start_idx: - return False - return False - - def check_index_compute(self, start_idx, end_dim, end_node, end_idx): - """ - Check 2 given index: check they haven't been computed in the source trace. 
- Args: - start_idx(int): start node chunk dim - start_node(node): start node - end_idx(int): end node chunk dim - end_node(node): end node - - Returns: - bool: True if check pass - """ - end_node_trace = self._find_trace_from_node(end_node) - end_node_compute = end_node_trace["compute"][end_dim] - if any(start_idx <= i <= end_idx for i in end_node_compute): - return False - return True - - def get_node_chunk_dim(self, node_from, node_from_dim, node_to): - node_from_source = self._find_source_trace_from_node(node_from) - dim_source = node_from_source[node_from_dim] - node_to_idx = _find_idx_by_name(node_to.name, self.node_list) - for k, v in dim_source.items(): - if k == node_to_idx: - return v - return None - - def _find_inherit_dim(self, input_node, input_dim, node): - input_node_idx = _find_idx_by_name(input_node.name, self.node_list) - node_trace_source = self._find_source_trace_from_node(node) - for node_dim in range(len(_get_node_shape(node))): - if ( - input_node_idx in node_trace_source[node_dim] - and input_dim[0] in node_trace_source[node_dim][input_node_idx] - ): - return node_dim - return None - - def check_index_duplicate(self, chunk_infos, return_dim=False): - input_dim_after_node = {} - for input_node_idx, input_node in enumerate(chunk_infos["inputs"]): - for k, v in chunk_infos["inputs_dim"][input_node_idx].items(): - inherit_dim = self._find_inherit_dim(input_node, v, self.node_list[k]) - if inherit_dim: - input_dim_after_node[k] = inherit_dim - - for node in self.node_list[ - chunk_infos["region"][0] : chunk_infos["region"][1] + 1 - ]: - if _is_non_compute_node_except_placeholder(node): - continue - count = 0 - duplicate_dims = [] - node_trace_source = self._find_source_trace_from_node(node) - for node_dim in range(len(_get_node_shape(node))): - duplicate_dim = [] - duplicate_flag = False - dim_source = node_trace_source[node_dim] - for k, v in dim_source.items(): - if chunk_infos["region"][0] <= k <= chunk_infos["region"][1]: - if k in 
input_dim_after_node and input_dim_after_node[k] in v: - duplicate_flag = True - duplicate_dim.append((k, v)) - duplicate_dims.append(duplicate_dim) - if duplicate_flag: - count += 1 - - if count > 1: - if return_dim: - return False, duplicate_dims - else: - return False - if return_dim: - return True, None - else: - return True - - def _assgin_single_node_flow( - self, - arg_node, - start_idx, - end_idx, - cur_node_dim, - cur_node_compute, - cur_node_source, - cur_node_fix_dim, - all_node_info, - next_node_list, - ): - arg_idx = _find_idx_by_name(arg_node.name, self.node_list) - # arg in chunk range or be inputs - if not (start_idx <= arg_idx < end_idx): - return True - - # find arg dim - if cur_node_dim is not None: - # dim is computed - if arg_idx in cur_node_compute[cur_node_dim]: - return False - if arg_idx not in cur_node_source[cur_node_dim]: - arg_dim = None - else: - arg_dim = cur_node_source[cur_node_dim][arg_idx][0] - else: - arg_dim = None - - # get fix dim - arg_fix_dim = [] - if cur_node_dim is not None: - for i in cur_node_fix_dim: - fix_dim_source = cur_node_source[i] - if arg_idx in fix_dim_source: - arg_fix_dim.append(fix_dim_source[arg_idx][0]) - - # if already in node_info, arg dim must be same - if arg_node in all_node_info: - if all_node_info[arg_node]["chunk_dim"] != arg_dim: - return False - all_node_info[arg_node]["fix_dim"] = list( - set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim) - ) - # else add it to list - else: - all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim} - - next_node_list.append(arg_node) - return True - - def flow_search(self, start_idx, start_dim, end_idx, end_dim): - inputs, outputs = _find_chunk_compute_input_and_output_nodes( - self.node_list[start_idx : end_idx + 1] - ) - # only single ouput - if len(outputs) > 1: - return None - - cur_node_list = [self.node_list[end_idx]] # start from the last node - all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}} - - while 
len(cur_node_list) > 0: - next_node_list = [] - - for cur_node in cur_node_list: - # get cur node info - cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"] - cur_node_fix_dim = all_node_info[cur_node]["fix_dim"] - cur_node_idx = _find_idx_by_name(cur_node.name, self.node_list) - if cur_node_chunk_dim: - cur_node_compute = self._find_compute_trace_from_node(cur_node) - cur_node_source = self._find_source_trace_from_node(cur_node) - else: - cur_node_compute = cur_node_source = None - - # get all valid args - arg_list = [] - for arg in cur_node.args: - if type(arg) != type(cur_node): - continue - if _is_non_compute_node(arg): - continue - arg_list.append(arg) - flow_flag = self._assgin_single_node_flow( - arg, - start_idx, - end_idx, - cur_node_chunk_dim, - cur_node_compute, - cur_node_source, - cur_node_fix_dim, - all_node_info, - next_node_list, - ) - if flow_flag == False: - return None - - if len(arg_list) == 2: - if any(i in cur_node.name for i in ["add", "mul"]): - for arg in arg_list: - if not ( - start_idx - <= _find_idx_by_name(arg.name, self.node_list) - < end_idx - ): - continue - arg_chunk_dim = all_node_info[arg]["chunk_dim"] - arg_fix_dim = all_node_info[arg]["fix_dim"] - arg_shape = _get_node_shape(arg) - # add all dim as fix dim except chunk dim - for i, shape in enumerate(arg_shape): - if shape != 1 and i != cur_node_chunk_dim: - if i == arg_chunk_dim: - return None - if i not in arg_fix_dim: - arg_fix_dim.append(i) - elif "einsum" in cur_node.name: - pass - elif "matmul" in cur_node.name: - pass - else: - raise NotImplementedError() - cur_node_list = next_node_list - - inputs_dim = [] - remove_inputs = [] - for input_node in inputs: - input_dict = {} - input_node_idx = _find_idx_by_name(input_node.name, self.node_list) - for user in input_node.users.keys(): - if _is_non_compute_node(user): - continue - user_idx = _find_idx_by_name(user.name, self.node_list) - if start_idx <= user_idx <= end_idx: - chunk_dim = all_node_info[user]["chunk_dim"] - 
if chunk_dim is not None: - user_source = self._find_source_trace_from_node(user)[chunk_dim] - if input_node_idx in user_source: - input_dict[user_idx] = user_source[input_node_idx] - else: - return None - if len(input_dict) == 0: - remove_inputs.append(input_node) - else: - inputs_dim.append(input_dict) - for i in remove_inputs: - if i in inputs: - inputs.remove(i) - - chunk_info = { - "region": (start_idx, end_idx), - "inputs": inputs, - "inputs_non_chunk": [], - "inputs_dim": inputs_dim, - "outputs": outputs, - "outputs_dim": end_dim, - "node_chunk_dim": all_node_info, - "args": {}, - } - - # move useless nodes ahead of loop - # get all possible prepose nodes - maybe_prepose_nodes = [] - for node, node_info in all_node_info.items(): - if node_info["chunk_dim"] is None: - maybe_prepose_nodes.append(node) - maybe_prepose_nodes.sort( - key=lambda x: _find_idx_by_name(x.name, self.node_list), - reverse=True, - ) # from last node to first node - prepose_nodes = [] - # set every node as root, search its args, if all legal, turn root and args as prepose nodes - while len(maybe_prepose_nodes) > 0: - tmp_cur_prepose_nodes = [maybe_prepose_nodes[0]] - tmp_cur_related_prepose_nodes = [] - prepose_flag = True - - # loop cur node's all arg until out of chunk - while len(tmp_cur_prepose_nodes) > 0: - if prepose_flag == False: - break - tmp_next_prepose_nodes = [] - tmp_cur_related_prepose_nodes.extend(tmp_cur_prepose_nodes) - for cur_prepose_node in tmp_cur_prepose_nodes: - if prepose_flag == False: - break - for cur_prepose_node_arg in cur_prepose_node.args: - if type(cur_prepose_node_arg) != type(cur_prepose_node): - continue - # out of loop - if not ( - start_idx - <= _find_idx_by_name( - cur_prepose_node_arg.name, self.node_list - ) - < end_idx - ): - continue - # compute op in loop - elif cur_prepose_node_arg in all_node_info: - if all_node_info[cur_prepose_node_arg]["chunk_dim"] is None: - tmp_next_prepose_nodes.append(cur_prepose_node_arg) - else: - prepose_flag = 
False - break - # non compute op - else: - tmp_next_prepose_nodes.append(cur_prepose_node_arg) - tmp_cur_prepose_nodes = tmp_next_prepose_nodes - - if prepose_flag == False: - maybe_prepose_nodes.remove(maybe_prepose_nodes[0]) - continue - else: - for n in tmp_cur_related_prepose_nodes: - if n not in prepose_nodes: - prepose_nodes.append(n) - if n in maybe_prepose_nodes: - maybe_prepose_nodes.remove(n) - # sort by index - prepose_nodes.sort(key=lambda x: _find_idx_by_name(x.name, self.node_list)) - chunk_info["args"]["prepose_nodes"] = prepose_nodes - - # we need to log input nodes to avoid deleteing them in the loop - chunk_node_list = self.node_list[start_idx : end_idx + 1] - # also need to get some prepose node's arg out of non_chunk_inputs - for n in prepose_nodes: - chunk_node_list.remove(n) - non_chunk_inputs = _find_chunk_all_input_nodes(chunk_node_list) - for i in non_chunk_inputs: - if i not in chunk_info["inputs"]: - chunk_info["inputs_non_chunk"].append(i) - - # reassgin reshape size, some size may have changed due to chunk - chunk_info = self._reassgin_reshape_size(chunk_info) - - return chunk_info - - def _reassgin_reshape_size(self, chunk_info): - chunk_region = chunk_info["region"] - reshape_size = {} - chunk_shape = _get_node_shape(chunk_info["outputs"][0])[ - chunk_info["outputs_dim"] - ] - for node in self.node_list[chunk_region[0] : chunk_region[1] + 1]: - if any(i in node.name for i in ["reshape", "view"]): - reshape_args = node.args[1:] - reshape_log = self.idx_view_list[node] - chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"] - reshape_size[node.name] = {} - for reshape_arg_dim, reshape_arg in enumerate(reshape_args): - if reshape_arg_dim in reshape_log["dim_to"]: - continue - if reshape_arg_dim == chunk_dim: - reshape_size[node.name][reshape_arg.name] = ( - "min(chunk_size, %d - chunk_idx)" % chunk_shape - ) - chunk_info["reshape_size"] = reshape_size - return chunk_info - - def _get_reorder_map(self, chunk_info): - reorder_map = 
{i: i for i in range(len(self.node_list))} - - chunk_region_start = chunk_info["region"][0] - chunk_region_end = chunk_info["region"][1] - chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"] - chunk_prepose_nodes_idx = [ - _find_idx_by_name(i.name, self.node_list) for i in chunk_prepose_nodes - ] - # put prepose nodes ahead - for idx, n in enumerate(chunk_prepose_nodes): - n_idx = chunk_prepose_nodes_idx[idx] - reorder_map[n_idx] = chunk_region_start + idx - # put other nodes after prepose nodes - for n in self.node_list[chunk_region_start : chunk_region_end + 1]: - if n in chunk_prepose_nodes: - continue - n_idx = _find_idx_by_name(n.name, self.node_list) - pos = sum([n_idx < i for i in chunk_prepose_nodes_idx]) - reorder_map[n_idx] = n_idx + pos - - return reorder_map - - def _reorder_chunk_info(self, chunk_info, reorder_map): - # update chunk info - chunk_info["region"] = ( - chunk_info["region"][0] + len(chunk_info["args"]["prepose_nodes"]), - chunk_info["region"][1], - ) - new_inputs_dim = [] - for idx, input_dim in enumerate(chunk_info["inputs_dim"]): - new_input_dim = {} - for k, v in input_dim.items(): - new_input_dim[reorder_map[k]] = v - new_inputs_dim.append(new_input_dim) - chunk_info["inputs_dim"] = new_inputs_dim - return chunk_info - - def _update_all_reorder_map(self, reorder_map): - for origin_idx, map_idx in self.all_reorder_map.items(): - self.all_reorder_map[origin_idx] = reorder_map[map_idx] - - def _reorder_self_node_list(self, reorder_map): - new_node_list = [None for _ in range(len(self.node_list))] - for old_idx, new_idx in reorder_map.items(): - new_node_list[new_idx] = self.node_list[old_idx] - self.node_list = new_node_list - - def _reorder_idx_trace(self, reorder_map): - # reorder list - new_idx_trace_list = [None for _ in range(len(self.idx_trace_list))] - for old_idx, new_idx in reorder_map.items(): - new_idx_trace_list[new_idx] = self.idx_trace_list[old_idx] - self.idx_trace_list = new_idx_trace_list - # update compute - for 
idx_trace in self.idx_trace_list: - compute = idx_trace["compute"] - for dim_compute in compute: - for idx, i in enumerate(dim_compute): - dim_compute[idx] = reorder_map[i] - # update source - for idx_trace in self.idx_trace_list: - source = idx_trace["source"] - for dim_idx, dim_source in enumerate(source): - new_dim_source = {} - for k, v in dim_source.items(): - new_dim_source[reorder_map[k]] = v - source[dim_idx] = new_dim_source - - def reorder_all(self, chunk_info): - if chunk_info is None: - return chunk_info - if len(chunk_info["args"]["prepose_nodes"]) == 0: - return chunk_info - reorder_map = self._get_reorder_map(chunk_info) - self._update_all_reorder_map(reorder_map) - self._reorder_idx_trace(reorder_map) - self._reorder_self_node_list(reorder_map) - chunk_info = self._reorder_chunk_info(chunk_info, reorder_map) - return chunk_info - - def reorder_node_list(self, node_list): - new_node_list = [None for _ in range(len(node_list))] - for old_idx, new_idx in self.all_reorder_map.items(): - new_node_list[new_idx] = node_list[old_idx] - return new_node_list - - def tmp_reorder(self, node_list, chunk_info): - if len(chunk_info["args"]["prepose_nodes"]) == 0: - return node_list, chunk_info - reorder_map = self._get_reorder_map(chunk_info) - - # new tmp node list - new_node_list = [None for _ in range(len(node_list))] - for old_idx, new_idx in reorder_map.items(): - new_node_list[new_idx] = node_list[old_idx] - - chunk_info = self._reorder_chunk_info(chunk_info, reorder_map) - return new_node_list, chunk_info - - -class MemoryEstimator(object): - def __init__(self, index_tracer: IndexTracer) -> None: - pass - - def _get_meta_node_size(self, x): - x = x.meta["tensor_meta"] - x = x.numel * torch.tensor([], dtype=x.dtype).element_size() - return x - - def _get_output_node(self, n): - fwd_out = { - x.uuid: x - for x in n.meta["fwd_out"] - if isinstance(x, torch.Tensor) and hasattr(x, "uuid") - } - out_size = activation_size(fwd_out) - out_node = [n.name] if 
out_size > 0 else [] - # if any(i in n.name for i in ['transpose', 'permute', 'view']): - # out_size = 0 - return out_size, out_node - - def _get_output_node_size(self, n): - return self._get_output_node(n)[0] - - def _add_active_node(self, n, active_list): - new_active = self._get_output_node(n)[1] - if n.op == "placeholder": - new_active.append(n.name) - for i in new_active: - if i not in active_list: - active_list.append(i) - - def _get_delete_node(self, user, user_to_last_uses, to_keep=None): - delete_size = 0 - delete_node = [] - if user.op not in ("output",): - nodes_to_delete = user_to_last_uses.get(user, []) - if to_keep is not None: - keep_list = [] - for n in nodes_to_delete: - if n.name in to_keep: - keep_list.append(n) - for n in keep_list: - if n in nodes_to_delete: - nodes_to_delete.remove(n) - if len(nodes_to_delete): - out_node = [self._get_output_node(i) for i in nodes_to_delete] - delete_size = sum([i[0] for i in out_node]) - for i in range(len(out_node)): - if out_node[i][0] > 0: - delete_node.append(out_node[i][1][0]) - elif nodes_to_delete[i].op == "placeholder": - delete_node.append(nodes_to_delete[i].name) - # elif any(j in nodes_to_delete[i].name for j in ['transpose', 'permute', 'view']): - # delete_node.append(nodes_to_delete[i].name) - return delete_size, delete_node - - def _get_delete_node_size(self, user, user_to_last_uses, to_keep): - return self._get_delete_node(user, user_to_last_uses, to_keep)[0] - - def _remove_deactive_node(self, user, user_to_last_uses, active_list): - delete_node = self._get_delete_node(user, user_to_last_uses)[1] - for i in delete_node: - if i in active_list: - active_list.remove(i) - - def _get_chunk_inputs_size( - self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx - ): - nodes_to_delete = [] - for chunk_input in chunk_inputs + chunk_inputs_non_chunk: - chunk_input_users = chunk_input.users.keys() - chunk_input_users_idx = [ - _find_idx_by_name(i.name, node_list) for i in chunk_input_users - 
] - if all(i <= chunk_end_idx for i in chunk_input_users_idx): - if chunk_input not in nodes_to_delete: - nodes_to_delete.append(chunk_input) - out_node = [self._get_output_node(i) for i in nodes_to_delete] - delete_size = sum([i[0] for i in out_node]) - return delete_size - - def _get_last_usr(self, nodes): - node_to_last_use: Dict[Node, Node] = {} - user_to_last_uses: Dict[Node, List[Node]] = {} - - def register_last_uses(n: Node, user: Node): - if n not in node_to_last_use: - node_to_last_use[n] = user - user_to_last_uses.setdefault(user, []).append(n) - - for node in reversed(nodes): - map_arg(node.args, lambda n: register_last_uses(n, node)) - map_arg(node.kwargs, lambda n: register_last_uses(n, node)) - return user_to_last_uses - - def _get_contiguous_memory(self, node, not_contiguous_list, delete=False): - mem = 0 - not_contiguous_ops = ["permute"] - inherit_contiguous_ops = ["transpose", "view"] - - if node.op == "call_function" and any( - n in node.name for n in ["matmul", "reshape"] - ): - for n in node.args: - if n in not_contiguous_list: - # matmul won't change origin tensor, but create a tmp copy - mem += self._get_output_node_size(n) - elif node.op == "call_module": - for n in node.args: - if n in not_contiguous_list: - # module will just make origin tensor to contiguous - if delete: - not_contiguous_list.remove(n) - elif node.op == "call_method" and any( - i in node.name for i in not_contiguous_ops - ): - if node not in not_contiguous_list: - not_contiguous_list.append(node) - return mem - - def _get_chunk_ratio(self, node, chunk_node_dim, chunk_size): - if node not in chunk_node_dim: - return 1.0 - node_shape = _get_node_shape(node) - chunk_dim = chunk_node_dim[node]["chunk_dim"] - if chunk_dim is None: - return 1.0 - else: - return float(chunk_size) / node_shape[chunk_dim] - - def _get_chunk_delete_node_size( - self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names - ): - # if any(j in user.name for j in ['transpose', 'permute', 'view']): - 
# return 0 - if user.op in ("placeholder", "output"): - return 0 - nodes_to_delete = user_to_last_uses.get(user, []) - delete_size = 0 - for n in nodes_to_delete: - if n.name in chunk_inputs_names: - continue - delete_size += self._get_output_node_size(n) * chunk_ratio - return delete_size - - def _print_mem_log(self, log, nodes, title=None): - if title: - print(title) - for idx, (l, n) in enumerate(zip(log, nodes)): - print("%s:%.2f \t" % (n.name, l), end="") - if (idx + 1) % 3 == 0: - print("") - print("\n") - - def _print_compute_op_mem_log(self, log, nodes, title=None): - if title: - print(title) - for idx, (l, n) in enumerate(zip(log, nodes)): - if n.op in ["placeholder", "get_attr", "output"]: - continue - if any(i in n.name for i in ["getitem", "getattr"]): - continue - print("%s:%.2f \t" % (n.name, l), end="") - if (idx + 1) % 3 == 0: - print("") - print("\n") - - def estimate_chunk_inference_mem( - self, - node_list, - chunk_infos=None, - print_mem=False, - ): - act_memory = 0.0 - act_memory_peak_log = [] - act_memory_after_node_log = [] - active_node_list = [] - active_node_list_log = [] - not_contiguous_list = [] - user_to_last_uses = self._get_last_usr(node_list) - user_to_last_uses_no_free_var = self._get_last_usr(node_list) - _delete_free_var_from_last_use(user_to_last_uses_no_free_var) - - use_chunk = True if chunk_infos is not None else False - chunk_within = False - chunk_region_idx = None - chunk_ratio = 1 # use it to estimate chunk mem - chunk_inputs_names = [] - - if use_chunk: - chunk_regions = [i["region"] for i in chunk_infos] - chunk_starts = [i[0] for i in chunk_regions] - chunk_ends = [i[1] for i in chunk_regions] - chunk_inputs = [i["inputs"] for i in chunk_infos] - chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] - chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [ - j.name for i in chunk_inputs_non_chunk for j in i - ] - chunk_outputs = [i["outputs"][0] for i in chunk_infos] - chunk_node_dim = 
[i["node_chunk_dim"] for i in chunk_infos] - chunk_sizes = [ - i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos - ] - - for idx, node in enumerate(node_list): - # if node in chunk start nodes, change chunk ratio and add chunk_tensor - if use_chunk and idx in chunk_starts: - chunk_within = True - chunk_region_idx = chunk_starts.index(idx) - act_memory += self._get_output_node_size( - chunk_outputs[chunk_region_idx] - ) / (1024**2) - - # determine chunk ratio for current node - if chunk_within: - chunk_ratio = self._get_chunk_ratio( - node, - chunk_node_dim[chunk_region_idx], - chunk_sizes[chunk_region_idx], - ) - - # if node is placeholder, just add the size of the node - if node.op == "placeholder": - act_memory += self._get_meta_node_size(node) * chunk_ratio / (1024**2) - act_memory_peak_log.append(act_memory) - # skip output - elif node.op == "output": - continue - # no change for non compute node - elif _is_non_compute_node_except_placeholder(node): - act_memory_peak_log.append(act_memory) - # node is a compute op - # calculate tmp, output node and delete node memory - else: - # forward memory - # TODO: contiguous_memory still not accurate for matmul, view, reshape and transpose - act_memory += ( - self._get_contiguous_memory(node, not_contiguous_list) - * chunk_ratio - / (1024**2) - ) - act_memory += ( - self._get_output_node_size(node) * chunk_ratio / (1024**2) - ) - # record max act memory - act_memory_peak_log.append(act_memory) - # delete useless memory - act_memory -= ( - self._get_contiguous_memory(node, not_contiguous_list, delete=True) - * chunk_ratio - / (1024**2) - ) - # delete unused vars not in chunk_input_list - # we can't delete input nodes until chunk ends - if chunk_within: - act_memory -= self._get_chunk_delete_node_size( - node, - user_to_last_uses_no_free_var, - chunk_ratio, - chunk_inputs_names, - ) / (1024**2) - else: - act_memory -= self._get_delete_node_size( - node, user_to_last_uses_no_free_var, chunk_inputs_names - ) / 
(1024**2) - - # log active node, only effective without chunk - self._add_active_node(node, active_node_list) - self._remove_deactive_node(node, user_to_last_uses, active_node_list) - - # if node in chunk end nodes, restore chunk settings - if use_chunk and idx in chunk_ends: - act_memory -= ( - self._get_output_node_size(node) * chunk_ratio / (1024**2) - ) - act_memory -= self._get_chunk_inputs_size( - chunk_inputs[chunk_region_idx], - chunk_inputs_non_chunk[chunk_region_idx], - node_list, - chunk_regions[chunk_region_idx][1], - ) / (1024**2) - chunk_within = False - chunk_ratio = 1 - chunk_region_idx = None - - act_memory_after_node_log.append(act_memory) - active_node_list_log.append(copy.deepcopy(active_node_list)) - - if print_mem: - print("with chunk" if use_chunk else "without chunk") - # self._print_mem_log(act_memory_peak_log, node_list, "peak") - # self._print_mem_log(act_memory_after_node_log, node_list, "after") - self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak") - # self._print_compute_op_mem_log( - # act_memory_after_node_log, node_list, "after" - # ) - - # param_memory = parameter_size(gm) - # all_memory = act_memory + param_memory - return act_memory_peak_log, act_memory_after_node_log, active_node_list_log - - -class ChunkSelector(object): - def __init__( - self, - index_tracer: IndexTracer, - memory_estimator: MemoryEstimator, - max_memory=None, - ): - self.index_tracer = index_tracer - self.memory_estimator = memory_estimator - if max_memory is not None: - self.stratge = "fit_memory" - self.max_memory = max_memory # MB - else: - self.stratge = "min_memory" - - def _select_best_chunk_region( - self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak - ): - if self.stratge == "min_memory": - best_region = self._select_min_memory_chunk_region( - possible_chunk_regions, - chunk_infos, - peak_node, - max_chunk_region, - mem_peak, - ) - elif self.stratge == "fit_memory": - best_region = 
self._select_fit_memory_chunk_region( - possible_chunk_regions, - chunk_infos, - peak_node, - max_chunk_region, - mem_peak, - ) - else: - raise RuntimeError() - return best_region - - def _select_fit_memory_chunk_region( - self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak - ): - # stop chunk if max memory satisfy memory limit - if max(mem_peak) < self.max_memory: - return None - - # remove illegal regions - illegal_regions = [] - for i in possible_chunk_regions: - if not self._is_legal_region(i, chunk_infos): - illegal_regions.append(i) - for i in illegal_regions: - if i in possible_chunk_regions: - possible_chunk_regions.remove(i) - - if len(possible_chunk_regions) == 0: - return None - - # get mem for chunk region - regions_dict = [] - for region in possible_chunk_regions: - cur_region = region.copy() - cur_node_list, cur_region = self.index_tracer.tmp_reorder( - self.index_tracer.node_list, cur_region - ) - cur_chunk_infos = chunk_infos + [cur_region] - cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( - cur_node_list, cur_chunk_infos - )[0] - cur_chunk_region_peak = cur_mem_peak[ - max_chunk_region[0] : max_chunk_region[1] + 1 - ] - cur_chunk_region_max_peak = max(cur_chunk_region_peak) - if cur_chunk_region_max_peak < self.max_memory: - regions_dict.append( - { - "chunk_info": region, - "chunk_max_mem": cur_chunk_region_max_peak, - "chunk_len": self._get_compute_node_num( - region["region"][0], region["region"][1] - ), - "reorder_chunk_info": cur_region, - "reorder_node_list": cur_node_list, - } - ) - # no region found - if len(regions_dict) == 0: - raise RuntimeError("Search failed. 
Try a larger memory threshold.") - - # select the min chunk len - chunk_len = [i["chunk_len"] for i in regions_dict] - best_region_idx = chunk_len.index(min(chunk_len)) - best_region = regions_dict[best_region_idx] - - # get max chunk size - best_region = self._get_fit_chunk_size(best_region, chunk_infos) - return best_region - - def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos): - chunk_size = 1 - reorder_chunk_info = chunk_region_dict["reorder_chunk_info"] - reorder_chunk_info["chunk_size"] = chunk_size - cur_chunk_max_mem = 0 - # search a region - while cur_chunk_max_mem < self.max_memory: - chunk_size *= 2 - reorder_chunk_info["chunk_size"] = chunk_size - cur_chunk_infos = chunk_infos + [reorder_chunk_info] - cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( - chunk_region_dict["reorder_node_list"], cur_chunk_infos - )[0] - cur_chunk_max_mem = max( - cur_mem_peak[ - reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1] - + 1 - ] - ) - # search exact size - chunk_info = chunk_region_dict["chunk_info"] - chunk_info["chunk_size"] = self._chunk_size_binary_search( - chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos - ) - return chunk_info - - def _chunk_size_binary_search(self, l, r, chunk_region_dict, chunk_infos): - if l >= 16: - gap = 4 - else: - gap = 1 - chunk_info = chunk_region_dict["reorder_chunk_info"] - while r >= l + gap: - mid = int((l + r) / 2 + 0.5) - chunk_info["chunk_size"] = mid - cur_chunk_infos = chunk_infos + [chunk_info] - cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( - chunk_region_dict["reorder_node_list"], cur_chunk_infos - )[0] - cur_chunk_max_mem = max( - cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1] - ) - if cur_chunk_max_mem >= self.max_memory: - r = mid - gap - else: - l = mid + gap - return l - - def _get_compute_node_num(self, start, end): - count = 0 - for i in self.index_tracer.node_list[start : end + 1]: - if not _is_non_compute_node(i): - 
count += 1 - return count - - def _select_min_memory_chunk_region( - self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak - ): - # remove illegal regions - illegal_regions = [] - for i in possible_chunk_regions: - if not self._is_legal_region(i, chunk_infos): - illegal_regions.append(i) - for i in illegal_regions: - if i in possible_chunk_regions: - possible_chunk_regions.remove(i) - - if len(possible_chunk_regions) == 0: - return None - - # get mem for chunk region - regions_dict = [] - for region in possible_chunk_regions: - cur_region = region.copy() - cur_node_list, cur_region = self.index_tracer.tmp_reorder( - self.index_tracer.node_list, cur_region - ) - cur_chunk_infos = chunk_infos + [cur_region] - cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( - cur_node_list, cur_chunk_infos - )[0] - cur_chunk_region_peak = cur_mem_peak[ - max_chunk_region[0] : max_chunk_region[1] + 1 - ] - cur_chunk_region_max_peak = max(cur_chunk_region_peak) - regions_dict.append( - { - "chunk_info": region, - "chunk_max_mem": cur_chunk_region_max_peak, - "chunk_len": self._get_compute_node_num( - region["region"][0], region["region"][1] - ), - "reorder_chunk_info": cur_region, - "reorder_node_list": cur_node_list, - } - ) - - # select the min mem - chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict] - best_region_idx = chunk_max_mem.index(min(chunk_max_mem)) - best_region = regions_dict[best_region_idx]["chunk_info"] - if best_region is not None: - best_region["chunk_size"] = 1 - return best_region - - def _is_legal_region(self, cur_chunk_info, chunk_infos): - (chunk_region_start, chunk_region_end) = cur_chunk_info["region"] - if cur_chunk_info in chunk_infos: - return False - if chunk_region_end < chunk_region_start: - return False - for i in chunk_infos: - region = i["region"] - if not ( - (chunk_region_start > region[1] and chunk_region_end > region[1]) - or (chunk_region_start < region[0] and chunk_region_end < region[0]) - ): 
- return False - return True - - -class ChunkRegionSearch(object): - def __init__(self, gm, max_memory=None) -> None: - self.gm = gm - self.index_tracer = IndexTracer(list(gm.graph.nodes)) - self.index_tracer.trace_index() - self.memory_estimator = MemoryEstimator(self.index_tracer) - self.chunk_selector = ChunkSelector( - self.index_tracer, self.memory_estimator, max_memory=max_memory - ) - - def _find_peak_node(self, mem_peak): - max_value = max(mem_peak) - max_idx = mem_peak.index(max_value) - return max_idx - - def _get_free_var(self): - free_var_idx = [] - for idx, n in enumerate(self.index_tracer.node_list): - if n.op == "placeholder": - free_var_idx.append(idx) - return free_var_idx - - def _get_min_free_var(self, active_node_list, free_vars): - min_len = 999 - for idx, n in enumerate(active_node_list): - if idx in free_vars: - continue - if len(n) < min_len: - min_len = len(n) - return min_len - - def _search_max_chunk_region(self, active_node, peak_node, chunk_regions): - free_vars = self._get_free_var() - free_var_num = len(free_vars) - active_node_num = [len(i) for i in active_node] - min_active_node_num = min(active_node_num[free_var_num:]) - threshold = max(free_var_num, min_active_node_num) - - # from peak_node to free_var - inside_flag = False - chunk_region_start = free_var_num - for i in range(peak_node, -1, -1): - if active_node_num[i] <= threshold: - inside_flag = True - if inside_flag and active_node_num[i] > threshold: - chunk_region_start = i + 1 - break - - # from peak_node to len-2 - inside_flag = False - chunk_region_end = len(active_node) - 1 - for i in range(peak_node, len(active_node)): - if active_node_num[i] <= threshold: - inside_flag = True - if inside_flag and active_node_num[i] > threshold: - chunk_region_end = i - break - - for i in chunk_regions: - region = i["region"] - if chunk_region_start >= region[0] and chunk_region_end <= region[1]: - return None - elif ( - region[0] <= chunk_region_start <= region[1] - and 
chunk_region_end > region[1] - ): - chunk_region_start = region[1] + 1 - elif ( - region[0] <= chunk_region_end <= region[1] - and chunk_region_start < region[0] - ): - chunk_region_end = region[0] - 1 - return chunk_region_start, chunk_region_end - - def _is_not_compute(self, trace, chunk_range, dim_idx): - if trace["idx"][dim_idx] not in trace["compute"]: - return True - if trace["idx"][dim_idx] in trace["compute"] and all( - i < chunk_range[0] or i > chunk_range[1] - for i in trace["compute"][trace["idx"][dim_idx]] - ): - return True - return False - - def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx): - start_traces = input_trace[start_idx] - end_trace = output_trace[end_idx] - end_node = self.index_tracer.node_list[end_idx] - chunk_infos = [] - for end_dim, _ in enumerate(end_trace["idx"]): - if len(start_traces) > 1: - continue - for start_node, start_trace in start_traces.items(): - for start_dim, _ in enumerate(start_trace["idx"]): - # dim size cannot be 1 - if ( - _get_node_shape(end_node)[end_dim] == 1 - or _get_node_shape(start_node)[start_dim] == 1 - ): - continue - # check index source align - if not self.index_tracer.check_index_source( - start_dim, start_node, start_idx, end_dim, end_node - ): - continue - # check index copmute - if not self.index_tracer.check_index_compute( - start_idx, end_dim, end_node, end_idx - ): - continue - # flow search - chunk_info = self.index_tracer.flow_search( - start_idx, start_dim, end_idx, end_dim - ) - if chunk_info is None: - continue - # check index copmute - if not self.index_tracer.check_index_duplicate(chunk_info): - continue - chunk_infos.append(chunk_info) - return chunk_infos - - def _search_possible_chunk_regions(self, max_chunk_region, peak_node): - possible_chunk_region = [] - output_trace = copy.deepcopy(self.index_tracer.idx_trace_list) - input_trace = [] # trace of a node's input nodes - for _, n in enumerate(self.index_tracer.node_list): - cur_trace = {} - for arg in n.args: - if 
type(arg) == type(n) and not _is_non_compute_node_except_placeholder( - arg - ): - cur_trace[arg] = self.index_tracer._find_trace_from_node(arg) - input_trace.append(cur_trace) - - for start_idx in range(max_chunk_region[0], peak_node + 1): - for end_idx in range(peak_node, max_chunk_region[1] + 1): - # skip non compute nodes - if _is_non_compute_node( - self.index_tracer.node_list[start_idx] - ) or _is_non_compute_node(self.index_tracer.node_list[end_idx]): - continue - - # select free dim - chunk_info = self._find_free_dim( - input_trace, output_trace, start_idx, end_idx - ) - if len(chunk_info) > 0: - possible_chunk_region.extend(chunk_info) - return possible_chunk_region - - def _step_search(self, mem_peak, active_node, chunk_regions): - peak_node = self._find_peak_node(mem_peak) - max_chunk_region = self._search_max_chunk_region( - active_node, peak_node, chunk_regions - ) - if max_chunk_region == None: - return None - possible_chunk_regions = self._search_possible_chunk_regions( - max_chunk_region, peak_node - ) - best_chunk_region = self.chunk_selector._select_best_chunk_region( - possible_chunk_regions, chunk_regions, peak_node, max_chunk_region, mem_peak - ) - best_chunk_region = self.index_tracer.reorder_all(best_chunk_region) - return best_chunk_region - - def _stop_search(self, init_mem_peak, mem_peak): - sorted_init_mem_peak = sorted(init_mem_peak) - if max(mem_peak) < sorted_init_mem_peak[int(len(sorted_init_mem_peak) * 0.5)]: - return True - return False - - def search_region(self): - chunk_infos = [] - ( - init_mem_peak, - _, - active_node, - ) = self.memory_estimator.estimate_chunk_inference_mem( - self.index_tracer.node_list - ) - mem_peak = init_mem_peak - - while True: - chunk_info = self._step_search(mem_peak, active_node, chunk_infos) - if chunk_info is None: - break - chunk_infos.append(chunk_info) - - ( - mem_peak, - _, - active_node, - ) = self.memory_estimator.estimate_chunk_inference_mem( - self.index_tracer.node_list, chunk_infos - ) - 
if self._stop_search(init_mem_peak, mem_peak): - break - self.memory_estimator.estimate_chunk_inference_mem( - self.index_tracer.node_list, chunk_infos, print_mem=True - ) - return chunk_infos - - -def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape): - new_shape = "[" - for idx, i in enumerate(shape): - if idx == chunk_dim: - new_shape += "%s:%s + chunk_size" % (chunk_idx_name, chunk_idx_name) - else: - new_shape += ":" - new_shape += ", " - new_shape = new_shape[:-2] + "]" - return new_shape - - -def _gen_loop_start(chunk_input, chunk_output, chunk_ouput_dim, chunk_size=2): - input_node = chunk_input[0] - out_shape = _get_node_shape(chunk_output) - out_str = str(list(out_shape)) - context = ( - "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range" - % (out_str, input_node.name, input_node.name, chunk_size) - ) - context += "(0, %d, chunk_size):\n" % (out_shape[chunk_ouput_dim]) - return context - - -def _gen_loop_end( - chunk_inputs, chunk_non_compute_inputs, chunk_outputs, chunk_outputs_dim, node_list -): - chunk_outputs_name = chunk_outputs.name - chunk_outputs_idx = _find_idx_by_name(chunk_outputs_name, node_list) - chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape - chunk_slice = _gen_chunk_slice_dim( - chunk_outputs_dim, "chunk_idx", chunk_output_shape - ) - context = " chunk_result%s = %s; %s = None\n" % ( - chunk_slice, - chunk_outputs_name, - chunk_outputs_name, - ) - context += ( - chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None" - ) - - # determine if its the last use for chunk input - for chunk_input in chunk_inputs + chunk_non_compute_inputs: - if all( - [ - _find_idx_by_name(user.name, node_list) <= chunk_outputs_idx - for user in chunk_input.users.keys() - ] - ): - context += "; %s = None" % chunk_input.name - - context += "\n" - return context - - -def _find_chunk_all_input_nodes(nodes: List[Node]): - """ - Find non-compute input and output node 
names. - input nodes are nodes used in the list - output nodes are nodes will use nodes in the list - """ - input_nodes = [] - for node in nodes: - for input_node in node._input_nodes.keys(): - if input_node not in nodes and input_node not in input_nodes: - input_nodes.append(input_node) - return input_nodes - - -def _find_chunk_compute_input_and_output_nodes(nodes: List[Node]): - """ - Find non-compute input and output node names. - input nodes are nodes used in the list - output nodes are nodes will use nodes in the list - """ - input_nodes = [] - output_nodes = [] - - # if a node has an input node which is not in the node list - # we treat that input node as the input of the checkpoint function - for node in nodes: - for input_node in node._input_nodes.keys(): - if ( - input_node not in nodes - and input_node not in input_nodes - and not _is_non_compute_node_except_placeholder(input_node) - ): - input_nodes.append(input_node) - - # if a node has a user node which is not in the node list - # we treat that user node as the node receiving the current node output - for node in nodes: - for output_node in node.users.keys(): - if ( - output_node not in nodes - and node not in output_nodes - and not _is_non_compute_node_except_placeholder_output(output_node) - ): - output_nodes.append(node) - - return input_nodes, output_nodes - - -def _find_idx_by_name(name, nodes_list): - for idx, node in enumerate(nodes_list): - if node.name == name: - return idx - raise RuntimeError("name %s not found in node list" % name) - - -def _replace_name(context, name_from, name_to): - patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")")] - for p in patterns: - source = p[0] + name_from + p[1] - target = p[0] + name_to + p[1] - if source in context: - context = context.replace(source, target) - return context - - -def _replace_reshape_size(context, node_name, reshape_size_dict): - if node_name not in reshape_size_dict: - return context - for size_name, 
size_value in reshape_size_dict[node_name].items(): - context = context.replace(size_name, size_value) - return context - - -def emit_code_with_chunk( - body, - nodes, - emit_node_func, - delete_unused_value_func, - chunk_region_search, - chunk_infos -): - """Emit code with nested activation checkpoint - When we detect some of the node.activation_checkpoint is a List, we will use - this function to emit the activation checkpoint codes. - - Args: - body: forward code - ckpt_func: checkpoint functions code - nodes: graph.nodes - emit_node_func: function to emit node - delete_unused_value_func: function to remove the unused value - """ - node_list = list(nodes) - - chunk_regions = [i["region"] for i in chunk_infos] - chunk_starts = [i[0] for i in chunk_regions] - chunk_ends = [i[1] for i in chunk_regions] - - chunk_inputs = [i["inputs"] for i in chunk_infos] - chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] - chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] - chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [ - j.name for i in chunk_inputs_non_chunk for j in i - ] - - chunk_outputs = [i["outputs"][0] for i in chunk_infos] - chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos] - - node_list = chunk_region_search.index_tracer.reorder_node_list(node_list) - node_idx = 0 - region_idx = 0 - within_chunk_region = False - - while node_idx < len(node_list): - node = node_list[node_idx] - - if node_idx in chunk_starts: - within_chunk_region = True - region_idx = chunk_starts.index(node_idx) - body.append( - _gen_loop_start( - chunk_inputs[region_idx], - chunk_outputs[region_idx], - chunk_outputs_dim[region_idx], - chunk_infos[region_idx]["chunk_size"], - ) - ) - - if within_chunk_region: - emit_node_func(node, body) - # replace input var with chunk var - for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]): - for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items(): - if idx == node_idx: - 
chunk_slice = _gen_chunk_slice_dim( - dim[0], "chunk_idx", _get_node_shape(input_node) - ) - body[-1] = _replace_name( - body[-1], input_node.name, input_node.name + chunk_slice - ) - # ones like - if "ones_like" in node.name: - meta_node = chunk_region_search.index_tracer.node_list[node_idx] - chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node][ - "chunk_dim" - ] - if _get_node_shape(meta_node)[chunk_dim] != 1: - source_node = meta_node.args[0].args[0] - if ( - source_node not in chunk_infos[region_idx]["node_chunk_dim"] - or chunk_infos[region_idx]["node_chunk_dim"][source_node][ - "chunk_dim" - ] - is None - ): - chunk_slice = _gen_chunk_slice_dim( - chunk_dim, "chunk_idx", _get_node_shape(node) - ) - body[-1] = _replace_name( - body[-1], node.args[0].name, node.args[0].name + chunk_slice - ) - body[-1] = _replace_reshape_size( - body[-1], node.name, chunk_infos[region_idx]["reshape_size"] - ) - body[-1] = " " + body[-1] - delete_unused_value_func(node, body, chunk_inputs_names) - else: - emit_node_func(node, body) - if node_idx not in chunk_inputs: - delete_unused_value_func(node, body, chunk_inputs_names) - - if node_idx in chunk_ends: - body.append( - _gen_loop_end( - chunk_inputs[region_idx], - chunk_inputs_non_chunk[region_idx], - chunk_outputs[region_idx], - chunk_outputs_dim[region_idx], - node_list, - ) - ) - within_chunk_region = False - - node_idx += 1 - - -if CODEGEN_AVAILABLE: - - class ChunkCodeGen(CodeGen): - def __init__(self, meta_graph, max_memory=None): - super().__init__() - self.meta_graph = meta_graph - self.max_memory = max_memory - self.meta_node = list(meta_graph.graph.nodes) - # find the chunk regions - self.chunk_region_search = ChunkRegionSearch(meta_graph, max_memory) - self.chunk_infos = self.chunk_region_search.search_region() - - def _gen_python_code( - self, nodes, root_module: str, namespace: _Namespace - ) -> PythonCode: - free_vars: List[str] = [] - body: List[str] = [] - globals_: Dict[str, Any] = {} - 
wrapped_fns: Dict[str, None] = {} - - # Wrap string in list to pass by reference - maybe_return_annotation: List[str] = [""] - - def add_global(name_hint: str, obj: Any): - """Add an obj to be tracked as a global. - - We call this for names that reference objects external to the - Graph, like functions or types. - - Returns: the global name that should be used to reference 'obj' in generated source. - """ - if ( - _is_from_torch(obj) and obj != torch.device - ): # to support registering torch.device - # HACK: workaround for how torch custom ops are registered. We - # can't import them like normal modules so they must retain their - # fully qualified name. - return _get_qualified_name(obj) - - # normalize the name hint to get a proper identifier - global_name = namespace.create_name(name_hint, obj) - - if global_name in globals_: - assert globals_[global_name] is obj - return global_name - globals_[global_name] = obj - return global_name - - # set _custom_builtins here so that we needn't import colossalai in forward - _custom_builtins["colossalai"] = _CustomBuiltin( - "import colossalai", colossalai - ) - - # Pre-fill the globals table with registered builtins. - for name, (_, obj) in _custom_builtins.items(): - add_global(name, obj) - - def type_repr(o: Any): - if o == (): - # Empty tuple is used for empty tuple type annotation Tuple[()] - return "()" - - typename = _type_repr(o) - - if hasattr(o, "__origin__"): - # This is a generic type, e.g. typing.List[torch.Tensor] - origin_type = _origin_type_map.get(o.__origin__, o.__origin__) - origin_typename = add_global(_type_repr(origin_type), origin_type) - - if hasattr(o, "__args__"): - # Assign global names for each of the inner type variables. 
- args = [type_repr(arg) for arg in o.__args__] - - if len(args) == 0: - # Bare type, such as `typing.Tuple` with no subscript - # This code-path used in Python < 3.9 - return origin_typename - - return f'{origin_typename}[{",".join(args)}]' - else: - # Bare type, such as `typing.Tuple` with no subscript - # This code-path used in Python 3.9+ - return origin_typename - - # Common case: this is a regular module name like 'foo.bar.baz' - return add_global(typename, o) - - def _format_args( - args: Tuple[Argument, ...], kwargs: Dict[str, Argument] - ) -> str: - def _get_repr(arg): - # Handle NamedTuples (if it has `_fields`) via add_global. - if isinstance(arg, tuple) and hasattr(arg, "_fields"): - qualified_name = _get_qualified_name(type(arg)) - global_name = add_global(qualified_name, type(arg)) - return f"{global_name}{repr(tuple(arg))}" - return repr(arg) - - args_s = ", ".join(_get_repr(a) for a in args) - kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items()) - if args_s and kwargs_s: - return f"{args_s}, {kwargs_s}" - return args_s or kwargs_s - - # Run through reverse nodes and record the first instance of a use - # of a given node. This represents the *last* use of the node in the - # execution order of the program, which we will use to free unused - # values - node_to_last_use: Dict[Node, Node] = {} - user_to_last_uses: Dict[Node, List[Node]] = {} - - def register_last_uses(n: Node, user: Node): - if n not in node_to_last_use: - node_to_last_use[n] = user - user_to_last_uses.setdefault(user, []).append(n) - - for node in reversed(nodes): - map_arg(node.args, lambda n: register_last_uses(n, node)) - map_arg(node.kwargs, lambda n: register_last_uses(n, node)) - - _delete_free_var_from_last_use(user_to_last_uses) - - # NOTE: we add a variable to distinguish body and ckpt_func - def delete_unused_values(user: Node, body, to_keep=[]): - """ - Delete values after their last use. 
This ensures that values that are - not used in the remainder of the code are freed and the memory usage - of the code is optimal. - """ - if user.op == "placeholder": - return - if user.op == "output": - body.append("\n") - return - nodes_to_delete = user_to_last_uses.get(user, []) - nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep] - if len(nodes_to_delete): - to_delete_str = " = ".join( - [repr(n) for n in nodes_to_delete] + ["None"] - ) - body.append(f"; {to_delete_str}\n") - else: - body.append("\n") - - # NOTE: we add a variable to distinguish body and ckpt_func - def emit_node(node: Node, body): - maybe_type_annotation = ( - "" if node.type is None else f" : {type_repr(node.type)}" - ) - if node.op == "placeholder": - assert isinstance(node.target, str) - maybe_default_arg = ( - "" if not node.args else f" = {repr(node.args[0])}" - ) - free_vars.append( - f"{node.target}{maybe_type_annotation}{maybe_default_arg}" - ) - raw_name = node.target.replace("*", "") - if raw_name != repr(node): - body.append(f"{repr(node)} = {raw_name}\n") - return - elif node.op == "call_method": - assert isinstance(node.target, str) - body.append( - f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}" - f"({_format_args(node.args[1:], node.kwargs)})" - ) - return - elif node.op == "call_function": - assert callable(node.target) - # pretty print operators - if ( - node.target.__module__ == "_operator" - and node.target.__name__ in magic_methods - ): - assert isinstance(node.args, tuple) - body.append( - f"{repr(node)}{maybe_type_annotation} = " - f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}" - ) - return - - # pretty print inplace operators; required for jit.script to work properly - # not currently supported in normal FX graphs, but generated by torchdynamo - if ( - node.target.__module__ == "_operator" - and node.target.__name__ in inplace_methods - ): - body.append( - 
f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; " - f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}" - ) - return - - qualified_name = _get_qualified_name(node.target) - global_name = add_global(qualified_name, node.target) - # special case for getattr: node.args could be 2-argument or 3-argument - # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value - if ( - global_name == "getattr" - and isinstance(node.args, tuple) - and isinstance(node.args[1], str) - and node.args[1].isidentifier() - and len(node.args) == 2 - ): - body.append( - f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}" - ) - return - body.append( - f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})" - ) - if node.meta.get("is_wrapped", False): - wrapped_fns.setdefault(global_name) - return - elif node.op == "call_module": - assert isinstance(node.target, str) - body.append( - f"{repr(node)}{maybe_type_annotation} = " - f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})" - ) - return - elif node.op == "get_attr": - assert isinstance(node.target, str) - body.append( - f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}" - ) - return - elif node.op == "output": - if node.type is not None: - maybe_return_annotation[0] = f" -> {type_repr(node.type)}" - body.append(self.generate_output(node.args[0])) - return - raise NotImplementedError(f"node: {node.op} {node.target}") - - # Modified for activation checkpointing - ckpt_func = [] - - # if any node has a list of labels for activation_checkpoint, we - # will use nested type of activation checkpoint codegen - emit_code_with_chunk( - body, - nodes, - emit_node, - delete_unused_values, - self.chunk_region_search, - self.chunk_infos - ) - - if len(body) == 0: - # If the Graph has no non-placeholder nodes, no lines for the 
body - # have been emitted. To continue to have valid Python code, emit a - # single pass statement - body.append("pass\n") - - if len(wrapped_fns) > 0: - wrap_name = add_global("wrap", torch.fx.wrap) - wrap_stmts = "\n".join( - [f'{wrap_name}("{name}")' for name in wrapped_fns] - ) - else: - wrap_stmts = "" - - if self._body_transformer: - body = self._body_transformer(body) - - for name, value in self.additional_globals(): - add_global(name, value) - - # as we need colossalai.utils.checkpoint, we need to import colossalai - # in forward function - prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) - prologue = "".join(ckpt_func) + prologue - prologue = prologue - - code = "".join(body) - code = "\n".join(" " + line for line in code.split("\n")) - fn_code = f""" -{wrap_stmts} - -{prologue} -{code}""" - # print(fn_code) - return PythonCode(fn_code, globals_) diff --git a/colossalai/autochunk/chunk_region_search.py b/colossalai/autochunk/chunk_region_search.py new file mode 100644 index 000000000000..0d0825f2584e --- /dev/null +++ b/colossalai/autochunk/chunk_region_search.py @@ -0,0 +1,211 @@ +from .index_tracer import IndexTracer +from .memory_estiamtor import MemoryEstimator +from .chunk_selector import ChunkSelector +import copy +from .utils import is_non_compute_node, is_non_compute_node_except_placeholder, get_node_shape + + +class ChunkRegionSearch(object): + def __init__(self, gm, max_memory=None) -> None: + self.gm = gm + self.index_tracer = IndexTracer(list(gm.graph.nodes)) + self.index_tracer.trace_index() + self.memory_estimator = MemoryEstimator(self.index_tracer) + self.chunk_selector = ChunkSelector( + self.index_tracer, self.memory_estimator, max_memory=max_memory + ) + + def _find_peak_node(self, mem_peak): + max_value = max(mem_peak) + max_idx = mem_peak.index(max_value) + return max_idx + + def _get_free_var(self): + free_var_idx = [] + for idx, n in enumerate(self.index_tracer.node_list): + if n.op == "placeholder": + 
class ChunkRegionSearch(object):
    """Search a traced FX graph for node regions whose activation-memory peak
    can be reduced by chunked (tiled) execution.

    The search loop repeatedly: finds the current memory-peak node, expands the
    widest legal region around it, enumerates chunkable (start, end, dim)
    candidates inside that region, and lets the ChunkSelector pick the best
    one — until no candidate helps or the peak is low enough.
    """

    def __init__(self, gm, max_memory=None) -> None:
        # gm: a torch.fx GraphModule. max_memory: optional budget in MB that
        # switches the selector into "fit_memory" mode.
        self.gm = gm
        self.index_tracer = IndexTracer(list(gm.graph.nodes))
        self.index_tracer.trace_index()
        self.memory_estimator = MemoryEstimator(self.index_tracer)
        self.chunk_selector = ChunkSelector(
            self.index_tracer, self.memory_estimator, max_memory=max_memory
        )

    def _find_peak_node(self, mem_peak):
        """Return the index of the (first) maximum of the per-node memory curve."""
        return mem_peak.index(max(mem_peak))

    def _get_free_var(self):
        """Return the indices of all placeholder (graph-input) nodes."""
        return [
            idx
            for idx, n in enumerate(self.index_tracer.node_list)
            if n.op == "placeholder"
        ]

    def _get_min_free_var(self, active_node_list, free_vars):
        """Return the smallest active-node count, ignoring free-var positions.

        Fixed: uses float("inf") instead of the magic sentinel 999, so the
        result is correct for arbitrarily long activation lists.
        """
        min_len = float("inf")
        for idx, actives in enumerate(active_node_list):
            if idx in free_vars:
                continue
            min_len = min(min_len, len(actives))
        return min_len

    def _search_max_chunk_region(self, active_node, peak_node, chunk_regions):
        """Expand the widest region around peak_node whose activation count
        stays at/below a threshold, then clip it against already-chunked
        regions.

        Returns:
            (start, end) node indices, or None when the peak already lies
            entirely inside an existing chunk region.
        """
        free_vars = self._get_free_var()
        free_var_num = len(free_vars)
        active_node_num = [len(i) for i in active_node]
        min_active_node_num = min(active_node_num[free_var_num:])
        threshold = max(free_var_num, min_active_node_num)

        # walk backwards from the peak; the region starts right after the
        # activation count rises above the threshold again
        inside_flag = False
        chunk_region_start = free_var_num
        for i in range(peak_node, -1, -1):
            if active_node_num[i] <= threshold:
                inside_flag = True
            if inside_flag and active_node_num[i] > threshold:
                chunk_region_start = i + 1
                break

        # walk forwards from the peak symmetrically to find the region end
        inside_flag = False
        chunk_region_end = len(active_node) - 1
        for i in range(peak_node, len(active_node)):
            if active_node_num[i] <= threshold:
                inside_flag = True
            if inside_flag and active_node_num[i] > threshold:
                chunk_region_end = i
                break

        # clip against regions that were already chunked in earlier steps
        for i in chunk_regions:
            region = i["region"]
            if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
                return None
            elif (
                region[0] <= chunk_region_start <= region[1]
                and chunk_region_end > region[1]
            ):
                chunk_region_start = region[1] + 1
            elif (
                region[0] <= chunk_region_end <= region[1]
                and chunk_region_start < region[0]
            ):
                chunk_region_end = region[0] - 1
        return chunk_region_start, chunk_region_end

    def _is_not_compute(self, trace, chunk_range, dim_idx):
        """True if the given dim of `trace` is never computed inside chunk_range."""
        if trace["idx"][dim_idx] not in trace["compute"]:
            return True
        # the redundant `in trace["compute"]` re-check of the original is
        # dropped: the previous branch already returned when it was absent
        if all(
            i < chunk_range[0] or i > chunk_range[1]
            for i in trace["compute"][trace["idx"][dim_idx]]
        ):
            return True
        return False

    def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx):
        """Enumerate chunkable dims between node start_idx and node end_idx.

        Returns:
            list: chunk_info dicts produced by the index tracer's flow search;
            empty when no legal (start_dim, end_dim) pair exists.
        """
        start_traces = input_trace[start_idx]
        # only single-input starts are supported. Fixed: this check is
        # loop-invariant; the original re-tested it once per end_dim.
        if len(start_traces) > 1:
            return []
        end_trace = output_trace[end_idx]
        end_node = self.index_tracer.node_list[end_idx]
        chunk_infos = []
        for end_dim, _ in enumerate(end_trace["idx"]):
            for start_node, start_trace in start_traces.items():
                for start_dim, _ in enumerate(start_trace["idx"]):
                    # chunking a size-1 dim is pointless
                    if (
                        get_node_shape(end_node)[end_dim] == 1
                        or get_node_shape(start_node)[start_dim] == 1
                    ):
                        continue
                    # the end dim must originate from the start dim
                    if not self.index_tracer.check_index_source(
                        start_dim, start_node, start_idx, end_dim, end_node
                    ):
                        continue
                    # the end dim must not be computed inside the region
                    if not self.index_tracer.check_index_compute(
                        start_idx, end_dim, end_node, end_idx
                    ):
                        continue
                    # flow search builds the full chunk description
                    chunk_info = self.index_tracer.flow_search(
                        start_idx, start_dim, end_idx, end_dim
                    )
                    if chunk_info is None:
                        continue
                    # reject chunks that would duplicate an index
                    if not self.index_tracer.check_index_duplicate(chunk_info):
                        continue
                    chunk_infos.append(chunk_info)
        return chunk_infos

    def _search_possible_chunk_regions(self, max_chunk_region, peak_node):
        """Collect every chunk candidate with start <= peak_node <= end inside
        max_chunk_region."""
        possible_chunk_region = []
        output_trace = copy.deepcopy(self.index_tracer.idx_trace_list)
        input_trace = []  # per node: {input Node -> its index trace}
        for _, n in enumerate(self.index_tracer.node_list):
            cur_trace = {}
            for arg in n.args:
                if type(arg) == type(n) and not is_non_compute_node_except_placeholder(
                    arg
                ):
                    cur_trace[arg] = self.index_tracer._find_trace_from_node(arg)
            input_trace.append(cur_trace)

        for start_idx in range(max_chunk_region[0], peak_node + 1):
            for end_idx in range(peak_node, max_chunk_region[1] + 1):
                # skip non-compute nodes at either boundary
                if is_non_compute_node(
                    self.index_tracer.node_list[start_idx]
                ) or is_non_compute_node(self.index_tracer.node_list[end_idx]):
                    continue
                chunk_info = self._find_free_dim(
                    input_trace, output_trace, start_idx, end_idx
                )
                if len(chunk_info) > 0:
                    possible_chunk_region.extend(chunk_info)
        return possible_chunk_region

    def _step_search(self, mem_peak, active_node, chunk_regions):
        """One search iteration: find the peak, bound the region, select the
        best candidate. Returns a chunk_info dict or None when done."""
        peak_node = self._find_peak_node(mem_peak)
        max_chunk_region = self._search_max_chunk_region(
            active_node, peak_node, chunk_regions
        )
        # fixed: was `== None`; identity comparison is the correct idiom
        if max_chunk_region is None:
            return None
        possible_chunk_regions = self._search_possible_chunk_regions(
            max_chunk_region, peak_node
        )
        best_chunk_region = self.chunk_selector._select_best_chunk_region(
            possible_chunk_regions, chunk_regions, peak_node, max_chunk_region, mem_peak
        )
        best_chunk_region = self.index_tracer.reorder_all(best_chunk_region)
        return best_chunk_region

    def _stop_search(self, init_mem_peak, mem_peak):
        """Stop once the current peak drops below the median of the initial
        per-node peaks."""
        sorted_init_mem_peak = sorted(init_mem_peak)
        median = sorted_init_mem_peak[int(len(sorted_init_mem_peak) * 0.5)]
        return max(mem_peak) < median

    def search_region(self):
        """Run the full search.

        Returns:
            list: the chosen chunk_info dicts, in the order they were found.
        """
        chunk_infos = []
        (
            init_mem_peak,
            _,
            active_node,
        ) = self.memory_estimator.estimate_chunk_inference_mem(
            self.index_tracer.node_list
        )
        mem_peak = init_mem_peak

        while True:
            chunk_info = self._step_search(mem_peak, active_node, chunk_infos)
            if chunk_info is None:
                break
            chunk_infos.append(chunk_info)

            (
                mem_peak,
                _,
                active_node,
            ) = self.memory_estimator.estimate_chunk_inference_mem(
                self.index_tracer.node_list, chunk_infos
            )
            if self._stop_search(init_mem_peak, mem_peak):
                break
        # final pass only for its printed memory report
        self.memory_estimator.estimate_chunk_inference_mem(
            self.index_tracer.node_list, chunk_infos, print_mem=True
        )
        return chunk_infos
class ChunkSelector(object):
    """Pick the best chunk region among candidates found by ChunkRegionSearch.

    Two strategies:
      * "min_memory" (default): choose the candidate with the smallest
        resulting peak memory; chunk size is fixed to 1.
      * "fit_memory" (when max_memory is given, in MB): choose the shortest
        candidate that fits the budget, then grow its chunk size as far as
        the budget allows.
    """

    def __init__(
        self,
        index_tracer: "IndexTracer",
        memory_estimator: "MemoryEstimator",
        max_memory=None,
    ):
        # annotations are quoted so the class can be defined even when the
        # annotated types are imported lazily
        self.index_tracer = index_tracer
        self.memory_estimator = memory_estimator
        # NOTE: "stratge" is a historical misspelling of "strategy"; the name
        # is kept because external code may read this attribute.
        if max_memory is not None:
            self.stratge = "fit_memory"
            self.max_memory = max_memory  # MB
        else:
            self.stratge = "min_memory"

    def _select_best_chunk_region(
        self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
    ):
        """Dispatch to the configured selection strategy."""
        if self.stratge == "min_memory":
            return self._select_min_memory_chunk_region(
                possible_chunk_regions,
                chunk_infos,
                peak_node,
                max_chunk_region,
                mem_peak,
            )
        elif self.stratge == "fit_memory":
            return self._select_fit_memory_chunk_region(
                possible_chunk_regions,
                chunk_infos,
                peak_node,
                max_chunk_region,
                mem_peak,
            )
        # fixed: the original raised a bare RuntimeError() with no message
        raise RuntimeError("unknown chunk selection strategy: %s" % self.stratge)

    def _remove_illegal_regions(self, possible_chunk_regions, chunk_infos):
        """Drop candidates that overlap existing regions or are malformed.

        Factored out of the two strategy methods, which contained this code
        duplicated verbatim. Mutates possible_chunk_regions in place (same
        semantics as the original) and returns it.
        """
        illegal_regions = [
            r
            for r in possible_chunk_regions
            if not self._is_legal_region(r, chunk_infos)
        ]
        for r in illegal_regions:
            if r in possible_chunk_regions:
                possible_chunk_regions.remove(r)
        return possible_chunk_regions

    def _build_region_dict(self, region, chunk_infos, max_chunk_region):
        """Reorder the node list for one candidate and measure its peak memory
        inside max_chunk_region. Shared by both strategies (previously
        duplicated)."""
        cur_region = region.copy()
        cur_node_list, cur_region = self.index_tracer.tmp_reorder(
            self.index_tracer.node_list, cur_region
        )
        cur_chunk_infos = chunk_infos + [cur_region]
        cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
            cur_node_list, cur_chunk_infos
        )[0]
        cur_chunk_region_peak = cur_mem_peak[
            max_chunk_region[0] : max_chunk_region[1] + 1
        ]
        return {
            "chunk_info": region,
            "chunk_max_mem": max(cur_chunk_region_peak),
            "chunk_len": self._get_compute_node_num(
                region["region"][0], region["region"][1]
            ),
            "reorder_chunk_info": cur_region,
            "reorder_node_list": cur_node_list,
        }

    def _select_fit_memory_chunk_region(
        self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
    ):
        """Pick the shortest legal candidate that fits self.max_memory, then
        maximise its chunk size.

        Returns:
            dict or None: None when memory already fits or no legal candidate
            exists.
        Raises:
            RuntimeError: when candidates exist but none fits the budget.
        """
        # nothing to do if the current peak already satisfies the budget
        if max(mem_peak) < self.max_memory:
            return None

        possible_chunk_regions = self._remove_illegal_regions(
            possible_chunk_regions, chunk_infos
        )
        if len(possible_chunk_regions) == 0:
            return None

        # keep only candidates whose regional peak fits the budget
        regions_dict = []
        for region in possible_chunk_regions:
            region_dict = self._build_region_dict(
                region, chunk_infos, max_chunk_region
            )
            if region_dict["chunk_max_mem"] < self.max_memory:
                regions_dict.append(region_dict)
        if len(regions_dict) == 0:
            raise RuntimeError("Search failed. Try a larger memory threshold.")

        # shortest region (fewest compute nodes) wins
        chunk_len = [i["chunk_len"] for i in regions_dict]
        best_region = regions_dict[chunk_len.index(min(chunk_len))]

        # then grow the chunk size as far as the budget allows
        best_region = self._get_fit_chunk_size(best_region, chunk_infos)
        return best_region

    def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos):
        """Double the chunk size until memory overflows, then binary-search the
        largest size that still fits. Returns the candidate's chunk_info with
        "chunk_size" set."""
        chunk_size = 1
        reorder_chunk_info = chunk_region_dict["reorder_chunk_info"]
        reorder_chunk_info["chunk_size"] = chunk_size
        cur_chunk_max_mem = 0
        # exponential growth phase: find the first overflowing size
        while cur_chunk_max_mem < self.max_memory:
            chunk_size *= 2
            reorder_chunk_info["chunk_size"] = chunk_size
            cur_chunk_infos = chunk_infos + [reorder_chunk_info]
            cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
                chunk_region_dict["reorder_node_list"], cur_chunk_infos
            )[0]
            cur_chunk_max_mem = max(
                cur_mem_peak[
                    reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1]
                    + 1
                ]
            )
        # refine between the last fitting and the first overflowing size
        chunk_info = chunk_region_dict["chunk_info"]
        chunk_info["chunk_size"] = self._chunk_size_binary_search(
            chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos
        )
        return chunk_info

    def _chunk_size_binary_search(self, l, r, chunk_region_dict, chunk_infos):
        """Binary-search [l, r] for the largest chunk size under the budget.

        A coarser step (gap=4) is used for larger sizes to save memory
        evaluations. Returns the found size (l when the interval is empty).
        """
        gap = 4 if l >= 16 else 1
        chunk_info = chunk_region_dict["reorder_chunk_info"]
        while r >= l + gap:
            mid = int((l + r) / 2 + 0.5)
            chunk_info["chunk_size"] = mid
            cur_chunk_infos = chunk_infos + [chunk_info]
            cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
                chunk_region_dict["reorder_node_list"], cur_chunk_infos
            )[0]
            cur_chunk_max_mem = max(
                cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]
            )
            if cur_chunk_max_mem >= self.max_memory:
                r = mid - gap
            else:
                l = mid + gap
        return l

    def _get_compute_node_num(self, start, end):
        """Count compute nodes in node_list[start..end] inclusive."""
        return sum(
            1
            for i in self.index_tracer.node_list[start : end + 1]
            if not is_non_compute_node(i)
        )

    def _select_min_memory_chunk_region(
        self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
    ):
        """Pick the legal candidate with the smallest resulting peak memory;
        chunk size is fixed to 1. Returns None when no legal candidate exists."""
        possible_chunk_regions = self._remove_illegal_regions(
            possible_chunk_regions, chunk_infos
        )
        if len(possible_chunk_regions) == 0:
            return None

        regions_dict = [
            self._build_region_dict(region, chunk_infos, max_chunk_region)
            for region in possible_chunk_regions
        ]

        # select the candidate with the minimum regional peak
        chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict]
        best_region = regions_dict[chunk_max_mem.index(min(chunk_max_mem))][
            "chunk_info"
        ]
        if best_region is not None:
            best_region["chunk_size"] = 1
        return best_region

    def _is_legal_region(self, cur_chunk_info, chunk_infos):
        """A candidate is legal iff it is new, non-empty, and disjoint from
        every existing chunk region."""
        (chunk_region_start, chunk_region_end) = cur_chunk_info["region"]
        if cur_chunk_info in chunk_infos:
            return False
        if chunk_region_end < chunk_region_start:
            return False
        for i in chunk_infos:
            region = i["region"]
            if not (
                (chunk_region_start > region[1] and chunk_region_end > region[1])
                or (chunk_region_start < region[0] and chunk_region_end < region[0])
            ):
                return False
        return True
+ + Returns: + idx_count: int + """ + self.idx_count += 1 + return self.idx_count + + def _del_dim(self, idx, dim_idx): + self.idx_trace_list[idx]["idx"].pop(dim_idx) + self.idx_trace_list[idx]["compute"].pop(dim_idx) + self.idx_trace_list[idx]["source"].pop(dim_idx) + + def _add_dim(self, node_idx, dim_idx): + self.idx_trace_list[node_idx]["idx"].insert(dim_idx, self._add_index()) + self.idx_trace_list[node_idx]["compute"].insert(dim_idx, []) + self.idx_trace_list[node_idx]["source"].insert(dim_idx, {}) + + def _transform_index(self, node, node_dim): + node_idx = self._find_idx_trace_from_node(node) + dims = list(range(len(node_idx))) + return dims[node_dim] + + def _inherit_index(self, node_from, node_from_dim, node_to, node_to_dim): + node_from_dim = self._transform_index(node_from, node_from_dim) + node_to_dim = self._transform_index(node_to, node_to_dim) + node_from_trace = self._find_trace_from_node(node_from) + node_to_trace = self._find_trace_from_node(node_to) + node_to_trace["idx"][node_to_dim] = node_from_trace["idx"][node_from_dim] + node_to_trace["compute"][node_to_dim] = copy.deepcopy( + node_from_trace["compute"][node_from_dim] + ) + self._add_source(node_from, node_from_dim, node_to, node_to_dim, init=True) + + def _inherit_all_computation(self, node_from, node_to): + node_from_compute = self._find_compute_trace_from_node(node_from) + node_to_compute = self._find_compute_trace_from_node(node_to) + assert len(node_from_compute) == len(node_to_compute) + for i in range(len(node_from_compute)): + self._add_source(node_from, i, node_to, i) + node_to_compute[i] = copy.deepcopy(node_from_compute[i]) + + def _add_source(self, node_from, node_from_dim, node_to, node_to_dim, init=False): + node_from_dim = self._transform_index(node_from, node_from_dim) + node_from_trace_source = self._find_source_trace_from_node(node_from) + node_to_dim = self._transform_index(node_to, node_to_dim) + node_to_trace_source = self._find_source_trace_from_node(node_to) + 
node_from_idx = find_idx_by_name(node_from.name, self.node_list) + if init: + node_to_trace_source[node_to_dim] = {} + # add dim to cur new source + if node_from_idx not in node_to_trace_source[node_to_dim]: + node_to_trace_source[node_to_dim][node_from_idx] = [node_from_dim] + else: + if node_from_dim not in node_to_trace_source[node_to_dim][node_from_idx]: + node_to_trace_source[node_to_dim][node_from_idx].append(node_from_dim) + # update inputs source + for node_idx, node_dim in node_from_trace_source[node_from_dim].items(): + if node_idx not in node_to_trace_source[node_to_dim]: + node_to_trace_source[node_to_dim][node_idx] = copy.deepcopy(node_dim) + else: + for d in node_dim: + if d not in node_to_trace_source[node_to_dim][node_idx]: + node_to_trace_source[node_to_dim][node_idx].append(d) + + def _mark_computation_from_node(self, node_from, node_to, exclude=None): + if exclude == None: + exclude = [] + else: + exclude = [self._transform_index(node_to, i) for i in exclude] + node_from_compute = self._find_compute_trace_from_node(node_from) + node_to_compute = self._find_compute_trace_from_node(node_to) + # assert len(node_from_compute) == len(node_to_compute) + for i in range(-1, -min(len(node_from_compute), len(node_to_compute)) - 1, -1): + if self._transform_index(node_to, i) in exclude: + continue + self._add_source(node_from, i, node_to, i) + for j in node_from_compute[i]: + if j not in node_to_compute[i]: + node_to_compute[i].append(j) + + def _mark_idx_equal(self, node1, dim1, node2, dim2): + """ + Mark 2 index to be equal. + + Args: + idx1 (int): index count. + idx2 (int): index count. + """ + # node1_idx = _find_idx_by_name(node1.name, self.nodes_list) + # node2_idx = _find_idx_by_name(node2.name, self.nodes_list) + # if node1_idx > node2_idx: + # self._add_source(node2, dim2, node1, dim1) + # else: + # self._add_source(node1, dim1, node2, dim2) + + def _mark_computation(self, node, idx, dim): + """ + Mark some dims of node as computed. 
+ + Args: + node (node) + idx (int): node index + dim (list or int): dims to be marked as computed + """ + if isinstance(dim, int): + dim = [dim] + dims = list(range(len(get_node_shape(node)))) + for d in dim: + cur_dim = dims[d] + if idx not in self.idx_trace_list[idx]["compute"][cur_dim]: + self.idx_trace_list[idx]["compute"][cur_dim].append(idx) + + def _find_trace_from_node(self, node): + """ + Find node idx and compute trace by the node. + + Args: + node (node) + Returns: + idx (list): idx of the node + compute (list): computed idx of the node. + """ + node_idx = find_idx_by_name(node.name, self.node_list) + node_dict = self.idx_trace_list[node_idx] + return node_dict + + def _find_source_trace_from_node(self, node): + """ + Find node source trace by the node. + + Args: + node (node) + Returns: + idx (list): idx of the node + compute (list): computed idx of the node. + """ + node_idx = find_idx_by_name(node.name, self.node_list) + node_dict = self.idx_trace_list[node_idx] + return node_dict["source"] + + def _find_idx_trace_from_node(self, node): + """ + Find node idx trace by the node. + + Args: + node (node) + Returns: + idx (list): idx of the node + """ + node_idx = find_idx_by_name(node.name, self.node_list) + return self.idx_trace_list[node_idx]["idx"] + + def _find_compute_trace_from_node(self, node): + """ + Find node compute trace by the node. + + Args: + node (node) + Returns: + compute (list): computed idx of the node. + """ + node_idx = find_idx_by_name(node.name, self.node_list) + return self.idx_trace_list[node_idx]["compute"] + + def _assign_index_as_input(self, node, node_idx, input_node=None): + """ + Assign node's trace as its input node. 
+ + Args: + node (node) + node_idx (int) + """ + if input_node == None: + input_node = node.args[0] + input_node_idx = find_idx_by_name(input_node.name, self.node_list) + input_node_idx_trace = self.idx_trace_list[input_node_idx]["idx"] + + new_idx_trace = copy.deepcopy(input_node_idx_trace) + self.idx_trace_list[node_idx]["idx"] = new_idx_trace + + self._inherit_all_computation(input_node, node) + + def _assign_all_index(self, node, node_idx): + """ + Add new index for all node's dims. + + Args: + node (node) + node_idx (int) + """ + shape = node.meta["tensor_meta"].shape + new_trace = [] + for _ in shape: + new_trace.append(self._add_index()) + self.idx_trace_list[node_idx]["idx"] = new_trace + + def _assign_transpose_index(self, node, node_idx): + """ + Assign index for transpose op. + 1. swap input's dim according to transpose args + 2. inherit input's computation + + Args: + node (node) + node_idx (int) + """ + input_node = node.args[0] + tranpose_dim = node.args[1:] + + self._assign_index_as_input(node, node_idx, input_node) + self._inherit_index(input_node, tranpose_dim[1], node, tranpose_dim[0]) + self._inherit_index(input_node, tranpose_dim[0], node, tranpose_dim[1]) + + def _assign_permute_index(self, node, node_idx): + """ + Assign index for permute op. + 1. swap input's dim according to permute args + 2. inherit input's computation + + Args: + node (node) + node_idx (int) + """ + permute_dim = node.args[1:] + input_node = node.args[0] + + self._assign_index_as_input(node, node_idx, input_node) + for idx, d in enumerate(permute_dim): + self._inherit_index(input_node, d, node, idx) + + def _assign_linear_index(self, node, node_idx): + """ + Assign index for linear op. + 1. copy trace from input node and change last index accroding to weight + 2. mark equal for input node last index, weight first dim and bias dim. + 3. inherit input's computation, mark computation for last dim. 
+ + Args: + node (node) + node_idx (int) + """ + if len(node.args) == 2: + input_node, weight = node.args + bias = None + else: + input_node, weight, bias = node.args + + self._assign_index_as_input(node, node_idx) + self._inherit_index(weight, 1, node, -1) + + self._mark_computation(node, node_idx, [-1]) + self._mark_idx_equal(input_node, -1, weight, 0) + + if bias: + self._mark_idx_equal(input_node, -1, bias, 0) + + def _assign_matmul_index(self, node, node_idx): + """ + Assign index for matmul op. + 1. copy trace from matmul_left and change last index accroding to matmul_right. (assert they have same length) + 2. mark equal for input matmul_left -1 index and matmul_right -2 dim. + 3. inherit matmul_left and matmul_right computation, mark computation for last dim. + + Args: + node (node) + node_idx (int) + """ + matmul_left, matmul_right = node.args + + assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right)) + self._assign_index_as_input(node, node_idx, matmul_left) + self._inherit_index(matmul_right, -1, node, -1) + + self._mark_computation_from_node(matmul_right, node, [-1, -2]) + self._mark_computation(node, node_idx, [-1]) + self._mark_idx_equal(matmul_left, -1, matmul_right, -2) + + def _assign_layernorm_index(self, node, idx): + """ + Assign index for layernorm op. + 1. assign index as input node + 2. inherit computation and mark last 2 dims as computed. + + Args: + node (node) + node_idx (int) + """ + self._assign_index_as_input(node, idx) + self._mark_computation(node, idx, [-1]) + + def _assign_elementwise_index(self, node, idx): + """ + Assign index for element-wise op (eg. relu sigmoid add mul). + 1. assign index as input node + 2. inherit computation from all input nodes. 
+ + Args: + node (node) + node_idx (int) + """ + self._assign_index_as_input(node, idx) + nodes_in = [] + for node_in in node.args: + if type(node_in) == type(node): + nodes_in.append(node_in) + self._mark_computation_from_node(node_in, node) + assert len(nodes_in) <= 2 + if len(nodes_in) == 2: + node_in0_shape = get_node_shape(nodes_in[0]) + node_in1_shape = get_node_shape(nodes_in[1]) + for i in range(-1, -min(len(node_in0_shape), len(node_in1_shape)) - 1, -1): + if node_in0_shape[i] == node_in1_shape[i]: + self._mark_idx_equal(nodes_in[0], i, nodes_in[1], i) + + def _assgin_no_change_index(self, node, idx): + self._assign_index_as_input(node, idx) + for node_in in node.args: + if type(node_in) == type(node): + self._mark_computation_from_node(node_in, node) + + def _assign_einsum_index(self, node, idx): + """ + Assign index for einsum op. + + Args: + node (node) + node_idx (int) + """ + patterns = node.args[0] + input_nodes = node.args[1:] + + patterns = patterns.replace(" ", "") + left, right = patterns.split("->") + left = left.split(",") + + all_index = [] + for i in left: + for c in i: + all_index.append(c) + all_index = set(all_index) + free_index = set([i for i in right]) + sum_index = all_index - free_index + + for right_idx, right_indice in enumerate(right): + for left_idx, left_str in enumerate(left): + if right_indice in left_str: + source_idx = left_str.index(right_indice) + self._inherit_index( + input_nodes[left_idx], source_idx, node, right_idx + ) + + # for i in sum_index: + # for left_idx, left_str in enumerate(left): + # if i in left_str: + # self._mark_computation(node, idx, left_str.index(i)) + # break + + def _assign_softmax_index(self, node, idx): + """ + Assign index for softmax op. + 1. assign index as input node + 2. inherit computation and mark softmax dim as computed. 
+ + Args: + node (node) + node_idx (int) + """ + self._assign_index_as_input(node, idx) + self._mark_computation(node, idx, [node.kwargs["dim"]]) + + def _assign_unsqueeze_index(self, node, node_idx): + """ + Assign index for unsqueeze op. + 1. assign new index for unsqueeze dim + + Args: + node (node) + node_idx (int) + """ + self._del_dim(node_idx, -1) + self._assign_index_as_input(node, node_idx) + self._add_dim(node_idx, node.args[1]) + + def _assign_dropout_index(self, node, node_idx): + """ + Assign index for unsqueeze op. + 1. assign new index for unsqueeze dim + + Args: + node (node) + node_idx (int) + """ + self._assign_index_as_input(node, node_idx) + + def _assign_ones_like_index(self, node, node_idx): + """ + Assign index for oneslike op. + 1. assign new index for all dim + + Args: + node (node) + node_idx (int) + """ + self._assign_all_index(node, node_idx) + + def _assign_view_reshape_index(self, node, node_idx): + """ + Assign index for view and reshape op. + 1. get origin shape and target shape by meta info. + 2. compute the real value of -1 in target shape. + 3. determine changed dim, and assgin index for generated dim. + 4. log changed dim and generated dim for restore + 5. inherit computation. + 6. TODO: look into view list to see whether the view is associated with other, + if so assgin equal dim according to previous view. 
+ + Args: + node (node) + node_idx (int) + """ + # get data, turn into number + origin_node = node.args[0] + origin_shape = origin_node.meta["tensor_meta"].shape + target_shape = [] + for i in range(1, len(node.args)): + if isinstance(node.args[i], int): + target_shape.append(node.args[i]) + else: + target_shape.append(node.args[i].meta["fwd_out"][0]) + + # compute the value of -1 + if -1 in target_shape: + origin_product = 1 + for i in origin_shape: + origin_product *= i + target_product = -1 + for i in target_shape: + target_product *= i + shape_idx = target_shape.index(-1) + target_shape[shape_idx] = origin_product // target_product + + # determine changed dim + len_diff = len(origin_shape) - len(target_shape) + if len_diff == 1: + # dim merge + dim_equal = [i == j for i, j in zip(origin_shape[:-1], target_shape)] + dim_to = [dim_equal.index(False)] + dim_from = [dim_equal.index(False), dim_equal.index(False) + 1] + self._add_dim(node_idx, -1) + elif len_diff == -1: + # dim expand + dim_equal = [i == j for i, j in zip(origin_shape, target_shape[:-1])] + dim_from = [dim_equal.index(False)] + dim_to = [dim_equal.index(False), dim_equal.index(False) + 1] + self._del_dim(node_idx, -1) + else: + raise NotImplementedError( + "shape" + + str(origin_shape) + + "and" + + str(target_shape) + + "view not implemented" + ) + + # get new index + origin_trace = self._find_idx_trace_from_node(origin_node) + self._assign_index_as_input(node, node_idx, origin_node) + dim_from.reverse() + for i in dim_from: + self._del_dim(node_idx, i) + for i in dim_to: + self._add_dim(node_idx, i) + + # inherit computation + compute_log = self._find_compute_trace_from_node(origin_node) + for i in dim_from: + if origin_trace[i] in compute_log: + for j in dim_to: + self._mark_computation(node, node_idx, [j]) + break + + # log view, not used now + view_dict = { + "idx_from": [origin_trace[i] for i in dim_from], + "dim_from": dim_from, + "idx_to": [self.idx_trace_list[node_idx]["idx"][i] for i in 
dim_to], + "dim_to": dim_to, + } + self.idx_view_list[node] = view_dict + + def _merge_equal_idx(self): + idx_equal = copy.deepcopy(self.idx_trace_equal) + idx_equal.reverse() + for idx in idx_equal: + merge_to = min(idx) + merge_from = max(idx) + for trace in self.idx_trace_list: + if merge_from in trace["idx"]: + trace["idx"] = [ + merge_to if i == merge_from else i for i in trace["idx"] + ] + + def trace_index(self): + for idx, node in enumerate(self.node_list): + if node.op == "placeholder": + self._assign_all_index(node, idx) + elif node.op == "call_method": + if "transpose" in node.name: + self._assign_transpose_index(node, idx) + elif "permute" in node.name: + self._assign_permute_index(node, idx) + elif "view" in node.name or "reshape" in node.name: + self._assign_view_reshape_index(node, idx) + elif "unsqueeze" in node.name: + self._assign_unsqueeze_index(node, idx) + elif any(i in node.name for i in ["to", "contiguous"]): + self._assgin_no_change_index(node, idx) + else: + raise NotImplementedError(node.name, "method not implemented yet!") + elif node.op == "call_function": + if "linear" in node.name: + self._assign_linear_index(node, idx) + elif "matmul" in node.name: + self._assign_matmul_index(node, idx) + elif "softmax" in node.name: + self._assign_softmax_index(node, idx) + elif any(n in node.name for n in ["mul", "add", "sigmoid", "relu"]): + self._assign_elementwise_index(node, idx) + elif "ones_like" in node.name: + self._assign_ones_like_index(node, idx) + elif "dropout" in node.name: + self._assign_dropout_index(node, idx) + elif "einsum" in node.name: + self._assign_einsum_index(node, idx) + elif "getattr" in node.name: + continue # get attr like shape + elif "getitem" in node.name: + continue # get item in list + else: + raise NotImplementedError( + node.name, "function not implemented yet!" 
+ ) + elif node.op == "call_module": + if any(n in node.name for n in ["layernorm", "norm"]): + self._assign_layernorm_index(node, idx) + else: + raise NotImplementedError(node.name, "module not implemented yet!") + elif node.op == "get_attr": + self._assign_all_index(node, idx) # get param + elif node.op == "output": + continue + else: + raise NotImplementedError(node.op, "op not implemented yet!") + # self._merge_equal_idx() + + def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node): + """ + Check 2 given index: one index should be source of the other + Args: + start_idx(int): start node chunk dim + start_node(node): start node + end_idx(int): end node chunk dim + end_node(node): end node + + Returns: + bool: True if check pass + """ + start_node_idx = find_idx_by_name(start_node.name, self.node_list) + end_node_trace = self._find_trace_from_node(end_node) + end_node_trace_source = end_node_trace["source"][end_dim] + sorted_source = sorted( + end_node_trace_source.items(), key=lambda d: d[0], reverse=True + ) + for node_idx, node_dim in sorted_source: + if node_idx == start_node_idx and start_dim in node_dim: + return True + # it means we meet a node outside the loop, and the node is not input node + if node_idx < start_idx: + return False + return False + + def check_index_compute(self, start_idx, end_dim, end_node, end_idx): + """ + Check 2 given index: check they haven't been computed in the source trace. 
+ Args: + start_idx(int): start node chunk dim + start_node(node): start node + end_idx(int): end node chunk dim + end_node(node): end node + + Returns: + bool: True if check pass + """ + end_node_trace = self._find_trace_from_node(end_node) + end_node_compute = end_node_trace["compute"][end_dim] + if any(start_idx <= i <= end_idx for i in end_node_compute): + return False + return True + + def get_node_chunk_dim(self, node_from, node_from_dim, node_to): + node_from_source = self._find_source_trace_from_node(node_from) + dim_source = node_from_source[node_from_dim] + node_to_idx = find_idx_by_name(node_to.name, self.node_list) + for k, v in dim_source.items(): + if k == node_to_idx: + return v + return None + + def _find_inherit_dim(self, input_node, input_dim, node): + input_node_idx = find_idx_by_name(input_node.name, self.node_list) + node_trace_source = self._find_source_trace_from_node(node) + for node_dim in range(len(get_node_shape(node))): + if ( + input_node_idx in node_trace_source[node_dim] + and input_dim[0] in node_trace_source[node_dim][input_node_idx] + ): + return node_dim + return None + + def check_index_duplicate(self, chunk_infos, return_dim=False): + input_dim_after_node = {} + for input_node_idx, input_node in enumerate(chunk_infos["inputs"]): + for k, v in chunk_infos["inputs_dim"][input_node_idx].items(): + inherit_dim = self._find_inherit_dim(input_node, v, self.node_list[k]) + if inherit_dim: + input_dim_after_node[k] = inherit_dim + + for node in self.node_list[ + chunk_infos["region"][0] : chunk_infos["region"][1] + 1 + ]: + if is_non_compute_node_except_placeholder(node): + continue + count = 0 + duplicate_dims = [] + node_trace_source = self._find_source_trace_from_node(node) + for node_dim in range(len(get_node_shape(node))): + duplicate_dim = [] + duplicate_flag = False + dim_source = node_trace_source[node_dim] + for k, v in dim_source.items(): + if chunk_infos["region"][0] <= k <= chunk_infos["region"][1]: + if k in 
input_dim_after_node and input_dim_after_node[k] in v: + duplicate_flag = True + duplicate_dim.append((k, v)) + duplicate_dims.append(duplicate_dim) + if duplicate_flag: + count += 1 + + if count > 1: + if return_dim: + return False, duplicate_dims + else: + return False + if return_dim: + return True, None + else: + return True + + def _assgin_single_node_flow( + self, + arg_node, + start_idx, + end_idx, + cur_node_dim, + cur_node_compute, + cur_node_source, + cur_node_fix_dim, + all_node_info, + next_node_list, + ): + arg_idx = find_idx_by_name(arg_node.name, self.node_list) + # arg in chunk range or be inputs + if not (start_idx <= arg_idx < end_idx): + return True + + # find arg dim + if cur_node_dim is not None: + # dim is computed + if arg_idx in cur_node_compute[cur_node_dim]: + return False + if arg_idx not in cur_node_source[cur_node_dim]: + arg_dim = None + else: + arg_dim = cur_node_source[cur_node_dim][arg_idx][0] + else: + arg_dim = None + + # get fix dim + arg_fix_dim = [] + if cur_node_dim is not None: + for i in cur_node_fix_dim: + fix_dim_source = cur_node_source[i] + if arg_idx in fix_dim_source: + arg_fix_dim.append(fix_dim_source[arg_idx][0]) + + # if already in node_info, arg dim must be same + if arg_node in all_node_info: + if all_node_info[arg_node]["chunk_dim"] != arg_dim: + return False + all_node_info[arg_node]["fix_dim"] = list( + set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim) + ) + # else add it to list + else: + all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim} + + next_node_list.append(arg_node) + return True + + def flow_search(self, start_idx, start_dim, end_idx, end_dim): + inputs, outputs = find_chunk_compute_input_and_output_nodes( + self.node_list[start_idx : end_idx + 1] + ) + # only single ouput + if len(outputs) > 1: + return None + + cur_node_list = [self.node_list[end_idx]] # start from the last node + all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}} + + while 
len(cur_node_list) > 0: + next_node_list = [] + + for cur_node in cur_node_list: + # get cur node info + cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"] + cur_node_fix_dim = all_node_info[cur_node]["fix_dim"] + cur_node_idx = find_idx_by_name(cur_node.name, self.node_list) + if cur_node_chunk_dim: + cur_node_compute = self._find_compute_trace_from_node(cur_node) + cur_node_source = self._find_source_trace_from_node(cur_node) + else: + cur_node_compute = cur_node_source = None + + # get all valid args + arg_list = [] + for arg in cur_node.args: + if type(arg) != type(cur_node): + continue + if is_non_compute_node(arg): + continue + arg_list.append(arg) + flow_flag = self._assgin_single_node_flow( + arg, + start_idx, + end_idx, + cur_node_chunk_dim, + cur_node_compute, + cur_node_source, + cur_node_fix_dim, + all_node_info, + next_node_list, + ) + if flow_flag == False: + return None + + if len(arg_list) == 2: + if any(i in cur_node.name for i in ["add", "mul"]): + for arg in arg_list: + if not ( + start_idx + <= find_idx_by_name(arg.name, self.node_list) + < end_idx + ): + continue + arg_chunk_dim = all_node_info[arg]["chunk_dim"] + arg_fix_dim = all_node_info[arg]["fix_dim"] + arg_shape = get_node_shape(arg) + # add all dim as fix dim except chunk dim + for i, shape in enumerate(arg_shape): + if shape != 1 and i != cur_node_chunk_dim: + if i == arg_chunk_dim: + return None + if i not in arg_fix_dim: + arg_fix_dim.append(i) + elif "einsum" in cur_node.name: + pass + elif "matmul" in cur_node.name: + pass + else: + raise NotImplementedError() + cur_node_list = next_node_list + + inputs_dim = [] + remove_inputs = [] + for input_node in inputs: + input_dict = {} + input_node_idx = find_idx_by_name(input_node.name, self.node_list) + for user in input_node.users.keys(): + if is_non_compute_node(user): + continue + user_idx = find_idx_by_name(user.name, self.node_list) + if start_idx <= user_idx <= end_idx: + chunk_dim = all_node_info[user]["chunk_dim"] + if 
chunk_dim is not None: + user_source = self._find_source_trace_from_node(user)[chunk_dim] + if input_node_idx in user_source: + input_dict[user_idx] = user_source[input_node_idx] + else: + return None + if len(input_dict) == 0: + remove_inputs.append(input_node) + else: + inputs_dim.append(input_dict) + for i in remove_inputs: + if i in inputs: + inputs.remove(i) + + chunk_info = { + "region": (start_idx, end_idx), + "inputs": inputs, + "inputs_non_chunk": [], + "inputs_dim": inputs_dim, + "outputs": outputs, + "outputs_dim": end_dim, + "node_chunk_dim": all_node_info, + "args": {}, + } + + # move useless nodes ahead of loop + # get all possible prepose nodes + maybe_prepose_nodes = [] + for node, node_info in all_node_info.items(): + if node_info["chunk_dim"] is None: + maybe_prepose_nodes.append(node) + maybe_prepose_nodes.sort( + key=lambda x: find_idx_by_name(x.name, self.node_list), + reverse=True, + ) # from last node to first node + prepose_nodes = [] + # set every node as root, search its args, if all legal, turn root and args as prepose nodes + while len(maybe_prepose_nodes) > 0: + tmp_cur_prepose_nodes = [maybe_prepose_nodes[0]] + tmp_cur_related_prepose_nodes = [] + prepose_flag = True + + # loop cur node's all arg until out of chunk + while len(tmp_cur_prepose_nodes) > 0: + if prepose_flag == False: + break + tmp_next_prepose_nodes = [] + tmp_cur_related_prepose_nodes.extend(tmp_cur_prepose_nodes) + for cur_prepose_node in tmp_cur_prepose_nodes: + if prepose_flag == False: + break + for cur_prepose_node_arg in cur_prepose_node.args: + if type(cur_prepose_node_arg) != type(cur_prepose_node): + continue + # out of loop + if not ( + start_idx + <= find_idx_by_name( + cur_prepose_node_arg.name, self.node_list + ) + < end_idx + ): + continue + # compute op in loop + elif cur_prepose_node_arg in all_node_info: + if all_node_info[cur_prepose_node_arg]["chunk_dim"] is None: + tmp_next_prepose_nodes.append(cur_prepose_node_arg) + else: + prepose_flag = False + 
break + # non compute op + else: + tmp_next_prepose_nodes.append(cur_prepose_node_arg) + tmp_cur_prepose_nodes = tmp_next_prepose_nodes + + if prepose_flag == False: + maybe_prepose_nodes.remove(maybe_prepose_nodes[0]) + continue + else: + for n in tmp_cur_related_prepose_nodes: + if n not in prepose_nodes: + prepose_nodes.append(n) + if n in maybe_prepose_nodes: + maybe_prepose_nodes.remove(n) + # sort by index + prepose_nodes.sort(key=lambda x: find_idx_by_name(x.name, self.node_list)) + chunk_info["args"]["prepose_nodes"] = prepose_nodes + + # we need to log input nodes to avoid deleteing them in the loop + chunk_node_list = self.node_list[start_idx : end_idx + 1] + # also need to get some prepose node's arg out of non_chunk_inputs + for n in prepose_nodes: + chunk_node_list.remove(n) + non_chunk_inputs = find_chunk_all_input_nodes(chunk_node_list) + for i in non_chunk_inputs: + if i not in chunk_info["inputs"]: + chunk_info["inputs_non_chunk"].append(i) + + # reassgin reshape size, some size may have changed due to chunk + chunk_info = self._reassgin_reshape_size(chunk_info) + + return chunk_info + + def _reassgin_reshape_size(self, chunk_info): + chunk_region = chunk_info["region"] + reshape_size = {} + chunk_shape = get_node_shape(chunk_info["outputs"][0])[ + chunk_info["outputs_dim"] + ] + for node in self.node_list[chunk_region[0] : chunk_region[1] + 1]: + if any(i in node.name for i in ["reshape", "view"]): + reshape_args = node.args[1:] + reshape_log = self.idx_view_list[node] + chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"] + reshape_size[node.name] = {} + for reshape_arg_dim, reshape_arg in enumerate(reshape_args): + if reshape_arg_dim in reshape_log["dim_to"]: + continue + if reshape_arg_dim == chunk_dim: + reshape_size[node.name][reshape_arg.name] = ( + "min(chunk_size, %d - chunk_idx)" % chunk_shape + ) + chunk_info["reshape_size"] = reshape_size + return chunk_info + + def _get_reorder_map(self, chunk_info): + reorder_map = {i: i for i 
in range(len(self.node_list))} + + chunk_region_start = chunk_info["region"][0] + chunk_region_end = chunk_info["region"][1] + chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"] + chunk_prepose_nodes_idx = [ + find_idx_by_name(i.name, self.node_list) for i in chunk_prepose_nodes + ] + # put prepose nodes ahead + for idx, n in enumerate(chunk_prepose_nodes): + n_idx = chunk_prepose_nodes_idx[idx] + reorder_map[n_idx] = chunk_region_start + idx + # put other nodes after prepose nodes + for n in self.node_list[chunk_region_start : chunk_region_end + 1]: + if n in chunk_prepose_nodes: + continue + n_idx = find_idx_by_name(n.name, self.node_list) + pos = sum([n_idx < i for i in chunk_prepose_nodes_idx]) + reorder_map[n_idx] = n_idx + pos + + return reorder_map + + def _reorder_chunk_info(self, chunk_info, reorder_map): + # update chunk info + chunk_info["region"] = ( + chunk_info["region"][0] + len(chunk_info["args"]["prepose_nodes"]), + chunk_info["region"][1], + ) + new_inputs_dim = [] + for idx, input_dim in enumerate(chunk_info["inputs_dim"]): + new_input_dim = {} + for k, v in input_dim.items(): + new_input_dim[reorder_map[k]] = v + new_inputs_dim.append(new_input_dim) + chunk_info["inputs_dim"] = new_inputs_dim + return chunk_info + + def _update_all_reorder_map(self, reorder_map): + for origin_idx, map_idx in self.all_reorder_map.items(): + self.all_reorder_map[origin_idx] = reorder_map[map_idx] + + def _reorder_self_node_list(self, reorder_map): + new_node_list = [None for _ in range(len(self.node_list))] + for old_idx, new_idx in reorder_map.items(): + new_node_list[new_idx] = self.node_list[old_idx] + self.node_list = new_node_list + + def _reorder_idx_trace(self, reorder_map): + # reorder list + new_idx_trace_list = [None for _ in range(len(self.idx_trace_list))] + for old_idx, new_idx in reorder_map.items(): + new_idx_trace_list[new_idx] = self.idx_trace_list[old_idx] + self.idx_trace_list = new_idx_trace_list + # update compute + for idx_trace in 
self.idx_trace_list: + compute = idx_trace["compute"] + for dim_compute in compute: + for idx, i in enumerate(dim_compute): + dim_compute[idx] = reorder_map[i] + # update source + for idx_trace in self.idx_trace_list: + source = idx_trace["source"] + for dim_idx, dim_source in enumerate(source): + new_dim_source = {} + for k, v in dim_source.items(): + new_dim_source[reorder_map[k]] = v + source[dim_idx] = new_dim_source + + def reorder_all(self, chunk_info): + if chunk_info is None: + return chunk_info + if len(chunk_info["args"]["prepose_nodes"]) == 0: + return chunk_info + reorder_map = self._get_reorder_map(chunk_info) + self._update_all_reorder_map(reorder_map) + self._reorder_idx_trace(reorder_map) + self._reorder_self_node_list(reorder_map) + chunk_info = self._reorder_chunk_info(chunk_info, reorder_map) + return chunk_info + + def reorder_node_list(self, node_list): + new_node_list = [None for _ in range(len(node_list))] + for old_idx, new_idx in self.all_reorder_map.items(): + new_node_list[new_idx] = node_list[old_idx] + return new_node_list + + def tmp_reorder(self, node_list, chunk_info): + if len(chunk_info["args"]["prepose_nodes"]) == 0: + return node_list, chunk_info + reorder_map = self._get_reorder_map(chunk_info) + + # new tmp node list + new_node_list = [None for _ in range(len(node_list))] + for old_idx, new_idx in reorder_map.items(): + new_node_list[new_idx] = node_list[old_idx] + + chunk_info = self._reorder_chunk_info(chunk_info, reorder_map) + return new_node_list, chunk_info diff --git a/colossalai/autochunk/memory_estiamtor.py b/colossalai/autochunk/memory_estiamtor.py new file mode 100644 index 000000000000..c3d8b1803ce9 --- /dev/null +++ b/colossalai/autochunk/memory_estiamtor.py @@ -0,0 +1,318 @@ +import copy +from typing import Any, Callable, Dict, Iterable, List, Tuple + +import torch +from torch.fx.node import Node, map_arg + +from colossalai.fx.profiler import activation_size, parameter_size + +from .index_tracer import IndexTracer 
+from .utils import ( + delete_free_var_from_last_use, + find_idx_by_name, + get_node_shape, + is_non_compute_node_except_placeholder, +) + + +class MemoryEstimator(object): + def __init__(self, index_tracer: IndexTracer) -> None: + pass + + def _get_meta_node_size(self, x): + x = x.meta["tensor_meta"] + x = x.numel * torch.tensor([], dtype=x.dtype).element_size() + return x + + def _get_output_node(self, n): + fwd_out = { + x.uuid: x + for x in n.meta["fwd_out"] + if isinstance(x, torch.Tensor) and hasattr(x, "uuid") + } + out_size = activation_size(fwd_out) + out_node = [n.name] if out_size > 0 else [] + # if any(i in n.name for i in ['transpose', 'permute', 'view']): + # out_size = 0 + return out_size, out_node + + def _get_output_node_size(self, n): + return self._get_output_node(n)[0] + + def _add_active_node(self, n, active_list): + new_active = self._get_output_node(n)[1] + if n.op == "placeholder": + new_active.append(n.name) + for i in new_active: + if i not in active_list: + active_list.append(i) + + def _get_delete_node(self, user, user_to_last_uses, to_keep=None): + delete_size = 0 + delete_node = [] + if user.op not in ("output",): + nodes_to_delete = user_to_last_uses.get(user, []) + if to_keep is not None: + keep_list = [] + for n in nodes_to_delete: + if n.name in to_keep: + keep_list.append(n) + for n in keep_list: + if n in nodes_to_delete: + nodes_to_delete.remove(n) + if len(nodes_to_delete): + out_node = [self._get_output_node(i) for i in nodes_to_delete] + delete_size = sum([i[0] for i in out_node]) + for i in range(len(out_node)): + if out_node[i][0] > 0: + delete_node.append(out_node[i][1][0]) + elif nodes_to_delete[i].op == "placeholder": + delete_node.append(nodes_to_delete[i].name) + # elif any(j in nodes_to_delete[i].name for j in ['transpose', 'permute', 'view']): + # delete_node.append(nodes_to_delete[i].name) + return delete_size, delete_node + + def _get_delete_node_size(self, user, user_to_last_uses, to_keep): + return 
self._get_delete_node(user, user_to_last_uses, to_keep)[0] + + def _remove_deactive_node(self, user, user_to_last_uses, active_list): + delete_node = self._get_delete_node(user, user_to_last_uses)[1] + for i in delete_node: + if i in active_list: + active_list.remove(i) + + def _get_chunk_inputs_size( + self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx + ): + nodes_to_delete = [] + for chunk_input in chunk_inputs + chunk_inputs_non_chunk: + chunk_input_users = chunk_input.users.keys() + chunk_input_users_idx = [ + find_idx_by_name(i.name, node_list) for i in chunk_input_users + ] + if all(i <= chunk_end_idx for i in chunk_input_users_idx): + if chunk_input not in nodes_to_delete: + nodes_to_delete.append(chunk_input) + out_node = [self._get_output_node(i) for i in nodes_to_delete] + delete_size = sum([i[0] for i in out_node]) + return delete_size + + def _get_last_usr(self, nodes): + node_to_last_use: Dict[Node, Node] = {} + user_to_last_uses: Dict[Node, List[Node]] = {} + + def register_last_uses(n: Node, user: Node): + if n not in node_to_last_use: + node_to_last_use[n] = user + user_to_last_uses.setdefault(user, []).append(n) + + for node in reversed(nodes): + map_arg(node.args, lambda n: register_last_uses(n, node)) + map_arg(node.kwargs, lambda n: register_last_uses(n, node)) + return user_to_last_uses + + def _get_contiguous_memory(self, node, not_contiguous_list, delete=False): + mem = 0 + not_contiguous_ops = ["permute"] + inherit_contiguous_ops = ["transpose", "view"] + + if node.op == "call_function" and any( + n in node.name for n in ["matmul", "reshape"] + ): + for n in node.args: + if n in not_contiguous_list: + # matmul won't change origin tensor, but create a tmp copy + mem += self._get_output_node_size(n) + elif node.op == "call_module": + for n in node.args: + if n in not_contiguous_list: + # module will just make origin tensor to contiguous + if delete: + not_contiguous_list.remove(n) + elif node.op == "call_method" and any( + i 
in node.name for i in not_contiguous_ops + ): + if node not in not_contiguous_list: + not_contiguous_list.append(node) + return mem + + def _get_chunk_ratio(self, node, chunk_node_dim, chunk_size): + if node not in chunk_node_dim: + return 1.0 + node_shape = get_node_shape(node) + chunk_dim = chunk_node_dim[node]["chunk_dim"] + if chunk_dim is None: + return 1.0 + else: + return float(chunk_size) / node_shape[chunk_dim] + + def _get_chunk_delete_node_size( + self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names + ): + # if any(j in user.name for j in ['transpose', 'permute', 'view']): + # return 0 + if user.op in ("placeholder", "output"): + return 0 + nodes_to_delete = user_to_last_uses.get(user, []) + delete_size = 0 + for n in nodes_to_delete: + if n.name in chunk_inputs_names: + continue + delete_size += self._get_output_node_size(n) * chunk_ratio + return delete_size + + def _print_mem_log(self, log, nodes, title=None): + if title: + print(title) + for idx, (l, n) in enumerate(zip(log, nodes)): + print("%s:%.2f \t" % (n.name, l), end="") + if (idx + 1) % 3 == 0: + print("") + print("\n") + + def _print_compute_op_mem_log(self, log, nodes, title=None): + if title: + print(title) + for idx, (l, n) in enumerate(zip(log, nodes)): + if n.op in ["placeholder", "get_attr", "output"]: + continue + if any(i in n.name for i in ["getitem", "getattr"]): + continue + print("%s:%.2f \t" % (n.name, l), end="") + if (idx + 1) % 3 == 0: + print("") + print("\n") + + def estimate_chunk_inference_mem( + self, + node_list, + chunk_infos=None, + print_mem=False, + ): + act_memory = 0.0 + act_memory_peak_log = [] + act_memory_after_node_log = [] + active_node_list = [] + active_node_list_log = [] + not_contiguous_list = [] + user_to_last_uses = self._get_last_usr(node_list) + user_to_last_uses_no_free_var = self._get_last_usr(node_list) + delete_free_var_from_last_use(user_to_last_uses_no_free_var) + + use_chunk = True if chunk_infos is not None else False + chunk_within = 
False + chunk_region_idx = None + chunk_ratio = 1 # use it to estimate chunk mem + chunk_inputs_names = [] + + if use_chunk: + chunk_regions = [i["region"] for i in chunk_infos] + chunk_starts = [i[0] for i in chunk_regions] + chunk_ends = [i[1] for i in chunk_regions] + chunk_inputs = [i["inputs"] for i in chunk_infos] + chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] + chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [ + j.name for i in chunk_inputs_non_chunk for j in i + ] + chunk_outputs = [i["outputs"][0] for i in chunk_infos] + chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos] + chunk_sizes = [ + i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos + ] + + for idx, node in enumerate(node_list): + # if node in chunk start nodes, change chunk ratio and add chunk_tensor + if use_chunk and idx in chunk_starts: + chunk_within = True + chunk_region_idx = chunk_starts.index(idx) + act_memory += self._get_output_node_size( + chunk_outputs[chunk_region_idx] + ) / (1024**2) + + # determine chunk ratio for current node + if chunk_within: + chunk_ratio = self._get_chunk_ratio( + node, + chunk_node_dim[chunk_region_idx], + chunk_sizes[chunk_region_idx], + ) + + # if node is placeholder, just add the size of the node + if node.op == "placeholder": + act_memory += self._get_meta_node_size(node) * chunk_ratio / (1024**2) + act_memory_peak_log.append(act_memory) + # skip output + elif node.op == "output": + continue + # no change for non compute node + elif is_non_compute_node_except_placeholder(node): + act_memory_peak_log.append(act_memory) + # node is a compute op + # calculate tmp, output node and delete node memory + else: + # forward memory + # TODO: contiguous_memory still not accurate for matmul, view, reshape and transpose + act_memory += ( + self._get_contiguous_memory(node, not_contiguous_list) + * chunk_ratio + / (1024**2) + ) + act_memory += ( + self._get_output_node_size(node) * chunk_ratio / (1024**2) 
+ ) + # record max act memory + act_memory_peak_log.append(act_memory) + # delete useless memory + act_memory -= ( + self._get_contiguous_memory(node, not_contiguous_list, delete=True) + * chunk_ratio + / (1024**2) + ) + # delete unused vars not in chunk_input_list + # we can't delete input nodes until chunk ends + if chunk_within: + act_memory -= self._get_chunk_delete_node_size( + node, + user_to_last_uses_no_free_var, + chunk_ratio, + chunk_inputs_names, + ) / (1024**2) + else: + act_memory -= self._get_delete_node_size( + node, user_to_last_uses_no_free_var, chunk_inputs_names + ) / (1024**2) + + # log active node, only effective without chunk + self._add_active_node(node, active_node_list) + self._remove_deactive_node(node, user_to_last_uses, active_node_list) + + # if node in chunk end nodes, restore chunk settings + if use_chunk and idx in chunk_ends: + act_memory -= ( + self._get_output_node_size(node) * chunk_ratio / (1024**2) + ) + act_memory -= self._get_chunk_inputs_size( + chunk_inputs[chunk_region_idx], + chunk_inputs_non_chunk[chunk_region_idx], + node_list, + chunk_regions[chunk_region_idx][1], + ) / (1024**2) + chunk_within = False + chunk_ratio = 1 + chunk_region_idx = None + + act_memory_after_node_log.append(act_memory) + active_node_list_log.append(copy.deepcopy(active_node_list)) + + if print_mem: + print("with chunk" if use_chunk else "without chunk") + # self._print_mem_log(act_memory_peak_log, node_list, "peak") + # self._print_mem_log(act_memory_after_node_log, node_list, "after") + self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak") + # self._print_compute_op_mem_log( + # act_memory_after_node_log, node_list, "after" + # ) + + # param_memory = parameter_size(gm) + # all_memory = act_memory + param_memory + return act_memory_peak_log, act_memory_after_node_log, active_node_list_log diff --git a/colossalai/autochunk/utils.py b/colossalai/autochunk/utils.py new file mode 100644 index 000000000000..b62a6600adc8 --- 
/dev/null +++ b/colossalai/autochunk/utils.py @@ -0,0 +1,95 @@ +from typing import Any, Callable, Dict, Iterable, List, Tuple + +from torch.fx.node import Node + + +def is_non_compute_node(node): + if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any( + i in node.name for i in ["getitem", "getattr"] + ): + return True + return False + + +def get_node_shape(node): + if hasattr(node.meta["tensor_meta"], "shape"): + return node.meta["tensor_meta"].shape + return None + + +def is_non_compute_node_except_placeholder(node): + if any(i in node.op for i in ["get_attr", "output"]) or any( + i in node.name for i in ["getitem", "getattr"] + ): + return True + return False + + +def is_non_compute_node_except_placeholder_output(node): + if any(i in node.op for i in ["get_attr"]) or any( + i in node.name for i in ["getitem", "getattr"] + ): + return True + return False + + +def find_idx_by_name(name, nodes_list): + for idx, node in enumerate(nodes_list): + if node.name == name: + return idx + raise RuntimeError("name %s not found in node list" % name) + + +def delete_free_var_from_last_use(user_to_last_uses): + for key, value in user_to_last_uses.items(): + for n in value: + if n.op == "placeholder": + user_to_last_uses[key].remove(n) + + +def find_chunk_all_input_nodes(nodes: List[Node]): + """ + Find non-compute input and output node names. + input nodes are nodes used in the list + output nodes are nodes will use nodes in the list + """ + input_nodes = [] + for node in nodes: + for input_node in node._input_nodes.keys(): + if input_node not in nodes and input_node not in input_nodes: + input_nodes.append(input_node) + return input_nodes + + +def find_chunk_compute_input_and_output_nodes(nodes: List[Node]): + """ + Find non-compute input and output node names. 
+ input nodes are nodes used in the list + output nodes are nodes will use nodes in the list + """ + input_nodes = [] + output_nodes = [] + + # if a node has an input node which is not in the node list + # we treat that input node as the input of the checkpoint function + for node in nodes: + for input_node in node._input_nodes.keys(): + if ( + input_node not in nodes + and input_node not in input_nodes + and not is_non_compute_node_except_placeholder(input_node) + ): + input_nodes.append(input_node) + + # if a node has a user node which is not in the node list + # we treat that user node as the node receiving the current node output + for node in nodes: + for output_node in node.users.keys(): + if ( + output_node not in nodes + and node not in output_nodes + and not is_non_compute_node_except_placeholder_output(output_node) + ): + output_nodes.append(node) + + return input_nodes, output_nodes diff --git a/tests/test_autochunk/benchmark_autochunk.py b/tests/test_autochunk/benchmark_autochunk.py index 8df6d9ff4564..702eb7026bb7 100644 --- a/tests/test_autochunk/benchmark_autochunk.py +++ b/tests/test_autochunk/benchmark_autochunk.py @@ -3,7 +3,7 @@ import torch import torch.fx -from colossalai.autochunk.chunk_codegen import ChunkCodeGen +from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen from colossalai.fx import ColoTracer from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp @@ -49,25 +49,29 @@ def _build_autochunk(model, max_memory, node, pair): "pair": pair.to(torch.device("meta")), }, ) + gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace interp = MetaInfoProp(gm_prop) interp.propagate( MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0") ) + # now run it twice to get meta info in graph module, not necessary gm = torch.fx.GraphModule(model, graph) interp = MetaInfoProp(gm) interp.propagate( MetaTensor(node, fake_device="cuda:0"), 
MetaTensor(pair, fake_device="cuda:0") ) + # set code_gen - codegen = ChunkCodeGen(gm_prop, max_memory) + codegen = AutoChunkCodeGen(gm_prop, max_memory) graph.set_codegen(codegen) gm = ColoGraphModule(model, graph) gm.recompile() + # print - code = graph.python_code("self").src - print(code) + # code = graph.python_code("self").src + # print(code) return gm diff --git a/tests/test_autochunk/test_autochunk.py b/tests/test_autochunk/test_autochunk.py index caa2d9a80254..85a162084cc9 100644 --- a/tests/test_autochunk/test_autochunk.py +++ b/tests/test_autochunk/test_autochunk.py @@ -4,7 +4,7 @@ import torch.multiprocessing as mp import colossalai -from colossalai.autochunk.chunk_codegen import ChunkCodeGen +from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen from colossalai.core import global_context as gpc from colossalai.fx import ColoTracer from colossalai.fx.graph_module import ColoGraphModule @@ -82,7 +82,7 @@ def _run_offload_codegen(rank): MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0") ) - codegen = ChunkCodeGen(gm_prop) + codegen = AutoChunkCodeGen(gm_prop) graph.set_codegen(codegen) gm = ColoGraphModule(model, graph) gm.recompile() From 8a634af2f5510954e7a992c0ee894d22cf9e26d2 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 6 Jan 2023 14:19:45 +0800 Subject: [PATCH 087/209] close mem and code print --- colossalai/autochunk/autochunk_codegen.py | 4 ++-- colossalai/autochunk/chunk_region_search.py | 11 +++++++---- tests/test_autochunk/benchmark_autochunk.py | 2 +- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py index 58a8c375136e..dcc6bba9ed0a 100644 --- a/colossalai/autochunk/autochunk_codegen.py +++ b/colossalai/autochunk/autochunk_codegen.py @@ -214,13 +214,13 @@ def emit_code_with_chunk( if CODEGEN_AVAILABLE: class AutoChunkCodeGen(CodeGen): - def __init__(self, meta_graph, max_memory=None): + def 
__init__(self, meta_graph, max_memory=None, print_mem=False): super().__init__() self.meta_graph = meta_graph self.max_memory = max_memory self.meta_node = list(meta_graph.graph.nodes) # find the chunk regions - self.chunk_region_search = ChunkRegionSearch(meta_graph, max_memory) + self.chunk_region_search = ChunkRegionSearch(meta_graph, max_memory, print_mem) self.chunk_infos = self.chunk_region_search.search_region() def _gen_python_code( diff --git a/colossalai/autochunk/chunk_region_search.py b/colossalai/autochunk/chunk_region_search.py index 0d0825f2584e..76b02cadeb3b 100644 --- a/colossalai/autochunk/chunk_region_search.py +++ b/colossalai/autochunk/chunk_region_search.py @@ -6,8 +6,9 @@ class ChunkRegionSearch(object): - def __init__(self, gm, max_memory=None) -> None: + def __init__(self, gm, max_memory=None, print_mem=False) -> None: self.gm = gm + self.print_mem = print_mem self.index_tracer = IndexTracer(list(gm.graph.nodes)) self.index_tracer.trace_index() self.memory_estimator = MemoryEstimator(self.index_tracer) @@ -204,8 +205,10 @@ def search_region(self): ) if self._stop_search(init_mem_peak, mem_peak): break - self.memory_estimator.estimate_chunk_inference_mem( - self.index_tracer.node_list, chunk_infos, print_mem=True - ) + if self.print_mem: + self.print_mem = False + self.memory_estimator.estimate_chunk_inference_mem( + self.index_tracer.node_list, chunk_infos, print_mem=True + ) return chunk_infos diff --git a/tests/test_autochunk/benchmark_autochunk.py b/tests/test_autochunk/benchmark_autochunk.py index 702eb7026bb7..9daaa364a710 100644 --- a/tests/test_autochunk/benchmark_autochunk.py +++ b/tests/test_autochunk/benchmark_autochunk.py @@ -64,7 +64,7 @@ def _build_autochunk(model, max_memory, node, pair): ) # set code_gen - codegen = AutoChunkCodeGen(gm_prop, max_memory) + codegen = AutoChunkCodeGen(gm_prop, max_memory, print_mem=False) graph.set_codegen(codegen) gm = ColoGraphModule(model, graph) gm.recompile() From 
2bde9d2b7fd43f3160088b820d926301f6527ebf Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 6 Jan 2023 14:21:49 +0800 Subject: [PATCH 088/209] code format --- colossalai/autochunk/autochunk_codegen.py | 4 +++- colossalai/autochunk/chunk_region_search.py | 14 +++++++++----- colossalai/autochunk/memory_estiamtor.py | 2 +- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py index dcc6bba9ed0a..fbd5d5e368dc 100644 --- a/colossalai/autochunk/autochunk_codegen.py +++ b/colossalai/autochunk/autochunk_codegen.py @@ -220,7 +220,9 @@ def __init__(self, meta_graph, max_memory=None, print_mem=False): self.max_memory = max_memory self.meta_node = list(meta_graph.graph.nodes) # find the chunk regions - self.chunk_region_search = ChunkRegionSearch(meta_graph, max_memory, print_mem) + self.chunk_region_search = ChunkRegionSearch( + meta_graph, max_memory, print_mem + ) self.chunk_infos = self.chunk_region_search.search_region() def _gen_python_code( diff --git a/colossalai/autochunk/chunk_region_search.py b/colossalai/autochunk/chunk_region_search.py index 76b02cadeb3b..7a0e8a36cd6c 100644 --- a/colossalai/autochunk/chunk_region_search.py +++ b/colossalai/autochunk/chunk_region_search.py @@ -1,8 +1,13 @@ +import copy + +from .chunk_selector import ChunkSelector from .index_tracer import IndexTracer from .memory_estiamtor import MemoryEstimator -from .chunk_selector import ChunkSelector -import copy -from .utils import is_non_compute_node, is_non_compute_node_except_placeholder, get_node_shape +from .utils import ( + get_node_shape, + is_non_compute_node, + is_non_compute_node_except_placeholder, +) class ChunkRegionSearch(object): @@ -11,7 +16,7 @@ def __init__(self, gm, max_memory=None, print_mem=False) -> None: self.print_mem = print_mem self.index_tracer = IndexTracer(list(gm.graph.nodes)) self.index_tracer.trace_index() - self.memory_estimator = MemoryEstimator(self.index_tracer) + 
self.memory_estimator = MemoryEstimator() self.chunk_selector = ChunkSelector( self.index_tracer, self.memory_estimator, max_memory=max_memory ) @@ -211,4 +216,3 @@ def search_region(self): self.index_tracer.node_list, chunk_infos, print_mem=True ) return chunk_infos - diff --git a/colossalai/autochunk/memory_estiamtor.py b/colossalai/autochunk/memory_estiamtor.py index c3d8b1803ce9..034f59e52858 100644 --- a/colossalai/autochunk/memory_estiamtor.py +++ b/colossalai/autochunk/memory_estiamtor.py @@ -16,7 +16,7 @@ class MemoryEstimator(object): - def __init__(self, index_tracer: IndexTracer) -> None: + def __init__(self) -> None: pass def _get_meta_node_size(self, x): From fd87d78a28a70fcb840c16d4084f67926ecc309c Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 6 Jan 2023 14:28:04 +0800 Subject: [PATCH 089/209] rename ambiguous variable --- colossalai/autochunk/chunk_selector.py | 14 +++++++------- tests/test_autochunk/evoformer/ops.py | 6 +++--- tests/test_autochunk/openfold/tensor_utils.py | 8 ++++---- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/colossalai/autochunk/chunk_selector.py b/colossalai/autochunk/chunk_selector.py index f84322082cc4..aeab66572099 100644 --- a/colossalai/autochunk/chunk_selector.py +++ b/colossalai/autochunk/chunk_selector.py @@ -126,14 +126,14 @@ def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos): ) return chunk_info - def _chunk_size_binary_search(self, l, r, chunk_region_dict, chunk_infos): - if l >= 16: + def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos): + if left >= 16: gap = 4 else: gap = 1 chunk_info = chunk_region_dict["reorder_chunk_info"] - while r >= l + gap: - mid = int((l + r) / 2 + 0.5) + while right >= left + gap: + mid = int((left + right) / 2 + 0.5) chunk_info["chunk_size"] = mid cur_chunk_infos = chunk_infos + [chunk_info] cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( @@ -143,10 +143,10 @@ def _chunk_size_binary_search(self, l, r, 
chunk_region_dict, chunk_infos): cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1] ) if cur_chunk_max_mem >= self.max_memory: - r = mid - gap + right = mid - gap else: - l = mid + gap - return l + left = mid + gap + return left def _get_compute_node_num(self, start, end): count = 0 diff --git a/tests/test_autochunk/evoformer/ops.py b/tests/test_autochunk/evoformer/ops.py index 611b7b0fe777..a56057522eaa 100755 --- a/tests/test_autochunk/evoformer/ops.py +++ b/tests/test_autochunk/evoformer/ops.py @@ -67,10 +67,10 @@ def forward(self, M): left_act = self.linear_a(M) right_act = self.linear_b(M) - O = torch.einsum('bsid,bsje->bijde', left_act, right_act).contiguous() + o = torch.einsum('bsid,bsje->bijde', left_act, right_act).contiguous() # O = rearrange(O, 'b i j d e -> b i j (d e)') - O = O.reshape(O.shape[0], O.shape[1], O.shape[2], -1) - Z = self.o_linear(O) + o = o.reshape(o.shape[0], o.shape[1], o.shape[2], -1) + Z = self.o_linear(o) return Z diff --git a/tests/test_autochunk/openfold/tensor_utils.py b/tests/test_autochunk/openfold/tensor_utils.py index 7e5e8e4b6b5e..384a71fb5ffd 100644 --- a/tests/test_autochunk/openfold/tensor_utils.py +++ b/tests/test_autochunk/openfold/tensor_utils.py @@ -157,12 +157,12 @@ def _get_minimal_slice_set( # start_edges and end_edges both indicate whether, starting from any given # dimension, the start/end index is at the top/bottom edge of the # corresponding tensor, modeled as a tree - def reduce_edge_list(l): + def reduce_edge_list(ll): tally = 1 - for i in range(len(l)): + for i in range(len(ll)): reversed_idx = -1 * (i + 1) - l[reversed_idx] *= tally - tally = l[reversed_idx] + ll[reversed_idx] *= tally + tally = ll[reversed_idx] if(start_edges is None): start_edges = [s == 0 for s in start] From ae27a8b26d7a36a3d9215fc6fd1db92982bdeef7 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 6 Jan 2023 14:57:33 +0800 Subject: [PATCH 090/209] seperate flow tracer --- colossalai/autochunk/index_tracer.py | 24 
+++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/colossalai/autochunk/index_tracer.py b/colossalai/autochunk/index_tracer.py index 7a86f3c998fb..0323e3a7e07d 100644 --- a/colossalai/autochunk/index_tracer.py +++ b/colossalai/autochunk/index_tracer.py @@ -745,14 +745,7 @@ def _assgin_single_node_flow( next_node_list.append(arg_node) return True - def flow_search(self, start_idx, start_dim, end_idx, end_dim): - inputs, outputs = find_chunk_compute_input_and_output_nodes( - self.node_list[start_idx : end_idx + 1] - ) - # only single ouput - if len(outputs) > 1: - return None - + def _get_all_node_info(self, end_dim, start_idx, end_idx): cur_node_list = [self.node_list[end_idx]] # start from the last node all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}} @@ -763,7 +756,6 @@ def flow_search(self, start_idx, start_dim, end_idx, end_dim): # get cur node info cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"] cur_node_fix_dim = all_node_info[cur_node]["fix_dim"] - cur_node_idx = find_idx_by_name(cur_node.name, self.node_list) if cur_node_chunk_dim: cur_node_compute = self._find_compute_trace_from_node(cur_node) cur_node_source = self._find_source_trace_from_node(cur_node) @@ -818,6 +810,20 @@ def flow_search(self, start_idx, start_dim, end_idx, end_dim): else: raise NotImplementedError() cur_node_list = next_node_list + return all_node_info + + def flow_search(self, start_idx, start_dim, end_idx, end_dim): + inputs, outputs = find_chunk_compute_input_and_output_nodes( + self.node_list[start_idx : end_idx + 1] + ) + # only single ouput + if len(outputs) > 1: + return None + + # get every node's chunk dim and fix dim + all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx) + if all_node_info is None: + return None inputs_dim = [] remove_inputs = [] From f4a1607e5645e3a537df6e88b67fb57a8fc6ed4f Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 6 Jan 2023 15:36:17 +0800 Subject: [PATCH 091/209] 
seperate input node dim search --- colossalai/autochunk/index_tracer.py | 35 +++++++++++++++++----------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/colossalai/autochunk/index_tracer.py b/colossalai/autochunk/index_tracer.py index 0323e3a7e07d..221217e2d101 100644 --- a/colossalai/autochunk/index_tracer.py +++ b/colossalai/autochunk/index_tracer.py @@ -812,19 +812,7 @@ def _get_all_node_info(self, end_dim, start_idx, end_idx): cur_node_list = next_node_list return all_node_info - def flow_search(self, start_idx, start_dim, end_idx, end_dim): - inputs, outputs = find_chunk_compute_input_and_output_nodes( - self.node_list[start_idx : end_idx + 1] - ) - # only single ouput - if len(outputs) > 1: - return None - - # get every node's chunk dim and fix dim - all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx) - if all_node_info is None: - return None - + def _get_input_nodes_dim(self, inputs, start_idx, end_idx, all_node_info): inputs_dim = [] remove_inputs = [] for input_node in inputs: @@ -841,7 +829,7 @@ def flow_search(self, start_idx, start_dim, end_idx, end_dim): if input_node_idx in user_source: input_dict[user_idx] = user_source[input_node_idx] else: - return None + return None, None if len(input_dict) == 0: remove_inputs.append(input_node) else: @@ -849,6 +837,25 @@ def flow_search(self, start_idx, start_dim, end_idx, end_dim): for i in remove_inputs: if i in inputs: inputs.remove(i) + return inputs, inputs_dim + + def flow_search(self, start_idx, start_dim, end_idx, end_dim): + inputs, outputs = find_chunk_compute_input_and_output_nodes( + self.node_list[start_idx : end_idx + 1] + ) + # only single ouput + if len(outputs) > 1: + return None + + # get every node's chunk dim and fix dim + all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx) + if all_node_info is None: + return None + + # get input nodes' chunk dim + inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info) + if 
inputs is None: + return None chunk_info = { "region": (start_idx, end_idx), From f856611d217e13c11ea382fe9d8f8af4cdeabb49 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 6 Jan 2023 15:47:17 +0800 Subject: [PATCH 092/209] seperate prepose_nodes --- colossalai/autochunk/index_tracer.py | 68 +++++++++++++++------------- 1 file changed, 36 insertions(+), 32 deletions(-) diff --git a/colossalai/autochunk/index_tracer.py b/colossalai/autochunk/index_tracer.py index 221217e2d101..206d2edbd5df 100644 --- a/colossalai/autochunk/index_tracer.py +++ b/colossalai/autochunk/index_tracer.py @@ -839,36 +839,7 @@ def _get_input_nodes_dim(self, inputs, start_idx, end_idx, all_node_info): inputs.remove(i) return inputs, inputs_dim - def flow_search(self, start_idx, start_dim, end_idx, end_dim): - inputs, outputs = find_chunk_compute_input_and_output_nodes( - self.node_list[start_idx : end_idx + 1] - ) - # only single ouput - if len(outputs) > 1: - return None - - # get every node's chunk dim and fix dim - all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx) - if all_node_info is None: - return None - - # get input nodes' chunk dim - inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info) - if inputs is None: - return None - - chunk_info = { - "region": (start_idx, end_idx), - "inputs": inputs, - "inputs_non_chunk": [], - "inputs_dim": inputs_dim, - "outputs": outputs, - "outputs_dim": end_dim, - "node_chunk_dim": all_node_info, - "args": {}, - } - - # move useless nodes ahead of loop + def _set_prepose_nodes(self, all_node_info, start_idx, end_idx): # get all possible prepose nodes maybe_prepose_nodes = [] for node, node_info in all_node_info.items(): @@ -929,12 +900,45 @@ def flow_search(self, start_idx, start_dim, end_idx, end_dim): maybe_prepose_nodes.remove(n) # sort by index prepose_nodes.sort(key=lambda x: find_idx_by_name(x.name, self.node_list)) - chunk_info["args"]["prepose_nodes"] = prepose_nodes + + return prepose_nodes 
+ + def flow_search(self, start_idx, start_dim, end_idx, end_dim): + inputs, outputs = find_chunk_compute_input_and_output_nodes( + self.node_list[start_idx : end_idx + 1] + ) + # only single ouput + if len(outputs) > 1: + return None + + # get every node's chunk dim and fix dim + all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx) + if all_node_info is None: + return None + + # get input nodes' chunk dim + inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info) + if inputs is None: + return None + + chunk_info = { + "region": (start_idx, end_idx), + "inputs": inputs, + "inputs_non_chunk": [], + "inputs_dim": inputs_dim, + "outputs": outputs, + "outputs_dim": end_dim, + "node_chunk_dim": all_node_info, + "args": {}, + } + + # move useless nodes ahead of loop + chunk_info["args"]["prepose_nodes"] = self._set_prepose_nodes(all_node_info, start_idx, end_idx) # we need to log input nodes to avoid deleteing them in the loop chunk_node_list = self.node_list[start_idx : end_idx + 1] # also need to get some prepose node's arg out of non_chunk_inputs - for n in prepose_nodes: + for n in chunk_info["args"]["prepose_nodes"]: chunk_node_list.remove(n) non_chunk_inputs = find_chunk_all_input_nodes(chunk_node_list) for i in non_chunk_inputs: From 6685a9d022a912ab3d0a57486b045b92b3f681ce Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 6 Jan 2023 15:53:24 +0800 Subject: [PATCH 093/209] seperate non chunk input --- colossalai/autochunk/index_tracer.py | 35 +++++++++++++++++----------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/colossalai/autochunk/index_tracer.py b/colossalai/autochunk/index_tracer.py index 206d2edbd5df..202044763b0f 100644 --- a/colossalai/autochunk/index_tracer.py +++ b/colossalai/autochunk/index_tracer.py @@ -839,7 +839,7 @@ def _get_input_nodes_dim(self, inputs, start_idx, end_idx, all_node_info): inputs.remove(i) return inputs, inputs_dim - def _set_prepose_nodes(self, all_node_info, 
start_idx, end_idx): + def _get_prepose_nodes(self, all_node_info, start_idx, end_idx): # get all possible prepose nodes maybe_prepose_nodes = [] for node, node_info in all_node_info.items(): @@ -902,7 +902,19 @@ def _set_prepose_nodes(self, all_node_info, start_idx, end_idx): prepose_nodes.sort(key=lambda x: find_idx_by_name(x.name, self.node_list)) return prepose_nodes - + + def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx): + # we need to log input nodes to avoid deleteing them in the loop + chunk_node_list = self.node_list[start_idx : end_idx + 1] + # also need to get some prepose node's arg out of non_chunk_inputs + for n in chunk_info["args"]["prepose_nodes"]: + chunk_node_list.remove(n) + non_chunk_inputs = find_chunk_all_input_nodes(chunk_node_list) + for i in non_chunk_inputs: + if i not in chunk_info["inputs"]: + chunk_info["inputs_non_chunk"].append(i) + return chunk_info + def flow_search(self, start_idx, start_dim, end_idx, end_dim): inputs, outputs = find_chunk_compute_input_and_output_nodes( self.node_list[start_idx : end_idx + 1] @@ -917,7 +929,9 @@ def flow_search(self, start_idx, start_dim, end_idx, end_dim): return None # get input nodes' chunk dim - inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info) + inputs, inputs_dim = self._get_input_nodes_dim( + inputs, start_idx, end_idx, all_node_info + ) if inputs is None: return None @@ -933,17 +947,12 @@ def flow_search(self, start_idx, start_dim, end_idx, end_dim): } # move useless nodes ahead of loop - chunk_info["args"]["prepose_nodes"] = self._set_prepose_nodes(all_node_info, start_idx, end_idx) + chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes( + all_node_info, start_idx, end_idx + ) - # we need to log input nodes to avoid deleteing them in the loop - chunk_node_list = self.node_list[start_idx : end_idx + 1] - # also need to get some prepose node's arg out of non_chunk_inputs - for n in chunk_info["args"]["prepose_nodes"]: - 
chunk_node_list.remove(n) - non_chunk_inputs = find_chunk_all_input_nodes(chunk_node_list) - for i in non_chunk_inputs: - if i not in chunk_info["inputs"]: - chunk_info["inputs_non_chunk"].append(i) + # find non chunk inputs + chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx) # reassgin reshape size, some size may have changed due to chunk chunk_info = self._reassgin_reshape_size(chunk_info) From c3d72f7db9e2fc28e9a3aa92749f08c7a7d51e42 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 6 Jan 2023 16:53:01 +0800 Subject: [PATCH 094/209] seperate reorder --- colossalai/autochunk/autochunk_codegen.py | 4 +-- colossalai/autochunk/chunk_region_search.py | 7 +++-- colossalai/autochunk/chunk_selector.py | 8 ++++-- colossalai/autochunk/index_tracer.py | 31 ++++++++++++--------- 4 files changed, 29 insertions(+), 21 deletions(-) diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py index fbd5d5e368dc..b4144196accc 100644 --- a/colossalai/autochunk/autochunk_codegen.py +++ b/colossalai/autochunk/autochunk_codegen.py @@ -103,7 +103,7 @@ def emit_code_with_chunk( nodes, emit_node_func, delete_unused_value_func, - chunk_region_search, + chunk_region_search: ChunkRegionSearch, chunk_infos, ): """Emit code with nested activation checkpoint @@ -133,7 +133,7 @@ def emit_code_with_chunk( chunk_outputs = [i["outputs"][0] for i in chunk_infos] chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos] - node_list = chunk_region_search.index_tracer.reorder_node_list(node_list) + node_list = chunk_region_search.reorder_graph.reorder_node_list(node_list) node_idx = 0 region_idx = 0 within_chunk_region = False diff --git a/colossalai/autochunk/chunk_region_search.py b/colossalai/autochunk/chunk_region_search.py index 7a0e8a36cd6c..47e2fe13ceb5 100644 --- a/colossalai/autochunk/chunk_region_search.py +++ b/colossalai/autochunk/chunk_region_search.py @@ -1,7 +1,7 @@ import copy from .chunk_selector import ChunkSelector -from 
.index_tracer import IndexTracer +from .index_tracer import IndexTracer, ReorderGraph from .memory_estiamtor import MemoryEstimator from .utils import ( get_node_shape, @@ -16,9 +16,10 @@ def __init__(self, gm, max_memory=None, print_mem=False) -> None: self.print_mem = print_mem self.index_tracer = IndexTracer(list(gm.graph.nodes)) self.index_tracer.trace_index() + self.reorder_graph = ReorderGraph(self.index_tracer) self.memory_estimator = MemoryEstimator() self.chunk_selector = ChunkSelector( - self.index_tracer, self.memory_estimator, max_memory=max_memory + self.index_tracer, self.memory_estimator, self.reorder_graph, max_memory=max_memory ) def _find_peak_node(self, mem_peak): @@ -175,7 +176,7 @@ def _step_search(self, mem_peak, active_node, chunk_regions): best_chunk_region = self.chunk_selector._select_best_chunk_region( possible_chunk_regions, chunk_regions, peak_node, max_chunk_region, mem_peak ) - best_chunk_region = self.index_tracer.reorder_all(best_chunk_region) + best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region) return best_chunk_region def _stop_search(self, init_mem_peak, mem_peak): diff --git a/colossalai/autochunk/chunk_selector.py b/colossalai/autochunk/chunk_selector.py index aeab66572099..119ff8aafdd0 100644 --- a/colossalai/autochunk/chunk_selector.py +++ b/colossalai/autochunk/chunk_selector.py @@ -1,4 +1,4 @@ -from .index_tracer import IndexTracer +from .index_tracer import IndexTracer, ReorderGraph from .memory_estiamtor import MemoryEstimator from .utils import is_non_compute_node @@ -8,10 +8,12 @@ def __init__( self, index_tracer: IndexTracer, memory_estimator: MemoryEstimator, + reorder_graph: ReorderGraph, max_memory=None, ): self.index_tracer = index_tracer self.memory_estimator = memory_estimator + self.reorder_graph = reorder_graph if max_memory is not None: self.stratge = "fit_memory" self.max_memory = max_memory # MB @@ -64,7 +66,7 @@ def _select_fit_memory_chunk_region( regions_dict = [] for region in 
possible_chunk_regions: cur_region = region.copy() - cur_node_list, cur_region = self.index_tracer.tmp_reorder( + cur_node_list, cur_region = self.reorder_graph.tmp_reorder( self.index_tracer.node_list, cur_region ) cur_chunk_infos = chunk_infos + [cur_region] @@ -174,7 +176,7 @@ def _select_min_memory_chunk_region( regions_dict = [] for region in possible_chunk_regions: cur_region = region.copy() - cur_node_list, cur_region = self.index_tracer.tmp_reorder( + cur_node_list, cur_region = self.reorder_graph.tmp_reorder( self.index_tracer.node_list, cur_region ) cur_chunk_infos = chunk_infos + [cur_region] diff --git a/colossalai/autochunk/index_tracer.py b/colossalai/autochunk/index_tracer.py index 202044763b0f..8b4d3aabd13a 100644 --- a/colossalai/autochunk/index_tracer.py +++ b/colossalai/autochunk/index_tracer.py @@ -17,7 +17,6 @@ def __init__(self, node_list) -> None: self.idx_trace_equal = [] self.idx_view_list = {} self.idx_count = -1 - self.all_reorder_map = {i: i for i in range(len(self.idx_trace_list))} def _init_idx_trace_list(self): idx_trace_list = [] @@ -981,24 +980,30 @@ def _reassgin_reshape_size(self, chunk_info): chunk_info["reshape_size"] = reshape_size return chunk_info + +class ReorderGraph(object): + def __init__(self, index_tracer: IndexTracer) -> None: + self.index_tracer = index_tracer + self.all_reorder_map = {i: i for i in range(len(self.index_tracer.idx_trace_list))} + def _get_reorder_map(self, chunk_info): - reorder_map = {i: i for i in range(len(self.node_list))} + reorder_map = {i: i for i in range(len(self.index_tracer.node_list))} chunk_region_start = chunk_info["region"][0] chunk_region_end = chunk_info["region"][1] chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"] chunk_prepose_nodes_idx = [ - find_idx_by_name(i.name, self.node_list) for i in chunk_prepose_nodes + find_idx_by_name(i.name, self.index_tracer.node_list) for i in chunk_prepose_nodes ] # put prepose nodes ahead for idx, n in enumerate(chunk_prepose_nodes): n_idx 
= chunk_prepose_nodes_idx[idx] reorder_map[n_idx] = chunk_region_start + idx # put other nodes after prepose nodes - for n in self.node_list[chunk_region_start : chunk_region_end + 1]: + for n in self.index_tracer.node_list[chunk_region_start : chunk_region_end + 1]: if n in chunk_prepose_nodes: continue - n_idx = find_idx_by_name(n.name, self.node_list) + n_idx = find_idx_by_name(n.name, self.index_tracer.node_list) pos = sum([n_idx < i for i in chunk_prepose_nodes_idx]) reorder_map[n_idx] = n_idx + pos @@ -1024,25 +1029,25 @@ def _update_all_reorder_map(self, reorder_map): self.all_reorder_map[origin_idx] = reorder_map[map_idx] def _reorder_self_node_list(self, reorder_map): - new_node_list = [None for _ in range(len(self.node_list))] + new_node_list = [None for _ in range(len(self.index_tracer.node_list))] for old_idx, new_idx in reorder_map.items(): - new_node_list[new_idx] = self.node_list[old_idx] - self.node_list = new_node_list + new_node_list[new_idx] = self.index_tracer.node_list[old_idx] + self.index_tracer.node_list = new_node_list def _reorder_idx_trace(self, reorder_map): # reorder list - new_idx_trace_list = [None for _ in range(len(self.idx_trace_list))] + new_idx_trace_list = [None for _ in range(len(self.index_tracer.idx_trace_list))] for old_idx, new_idx in reorder_map.items(): - new_idx_trace_list[new_idx] = self.idx_trace_list[old_idx] - self.idx_trace_list = new_idx_trace_list + new_idx_trace_list[new_idx] = self.index_tracer.idx_trace_list[old_idx] + self.index_tracer.idx_trace_list = new_idx_trace_list # update compute - for idx_trace in self.idx_trace_list: + for idx_trace in self.index_tracer.idx_trace_list: compute = idx_trace["compute"] for dim_compute in compute: for idx, i in enumerate(dim_compute): dim_compute[idx] = reorder_map[i] # update source - for idx_trace in self.idx_trace_list: + for idx_trace in self.index_tracer.idx_trace_list: source = idx_trace["source"] for dim_idx, dim_source in enumerate(source): new_dim_source = {} 
From da4076846d693be0153c8e89ee48ce25f56d09ce Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 6 Jan 2023 17:09:37 +0800 Subject: [PATCH 095/209] rename --- colossalai/autochunk/autochunk_codegen.py | 6 +++--- .../{memory_estiamtor.py => estiamte_memory.py} | 3 +-- .../{chunk_region_search.py => search_chunk.py} | 14 +++++++------- .../{chunk_selector.py => select_chunk.py} | 10 +++++----- .../autochunk/{index_tracer.py => trace_index.py} | 4 ++-- tests/test_autochunk/benchmark_autochunk.py | 2 +- 6 files changed, 19 insertions(+), 20 deletions(-) rename colossalai/autochunk/{memory_estiamtor.py => estiamte_memory.py} (99%) rename colossalai/autochunk/{chunk_region_search.py => search_chunk.py} (96%) rename colossalai/autochunk/{chunk_selector.py => select_chunk.py} (97%) rename colossalai/autochunk/{index_tracer.py => trace_index.py} (99%) diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py index b4144196accc..3bb2e83be242 100644 --- a/colossalai/autochunk/autochunk_codegen.py +++ b/colossalai/autochunk/autochunk_codegen.py @@ -17,7 +17,7 @@ import colossalai -from .chunk_region_search import ChunkRegionSearch +from .search_chunk import SearchChunk from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape CODEGEN_AVAILABLE = True @@ -103,7 +103,7 @@ def emit_code_with_chunk( nodes, emit_node_func, delete_unused_value_func, - chunk_region_search: ChunkRegionSearch, + chunk_region_search: SearchChunk, chunk_infos, ): """Emit code with nested activation checkpoint @@ -220,7 +220,7 @@ def __init__(self, meta_graph, max_memory=None, print_mem=False): self.max_memory = max_memory self.meta_node = list(meta_graph.graph.nodes) # find the chunk regions - self.chunk_region_search = ChunkRegionSearch( + self.chunk_region_search = SearchChunk( meta_graph, max_memory, print_mem ) self.chunk_infos = self.chunk_region_search.search_region() diff --git a/colossalai/autochunk/memory_estiamtor.py 
b/colossalai/autochunk/estiamte_memory.py similarity index 99% rename from colossalai/autochunk/memory_estiamtor.py rename to colossalai/autochunk/estiamte_memory.py index 034f59e52858..90cfd66a00d5 100644 --- a/colossalai/autochunk/memory_estiamtor.py +++ b/colossalai/autochunk/estiamte_memory.py @@ -6,7 +6,6 @@ from colossalai.fx.profiler import activation_size, parameter_size -from .index_tracer import IndexTracer from .utils import ( delete_free_var_from_last_use, find_idx_by_name, @@ -15,7 +14,7 @@ ) -class MemoryEstimator(object): +class EstimateMemory(object): def __init__(self) -> None: pass diff --git a/colossalai/autochunk/chunk_region_search.py b/colossalai/autochunk/search_chunk.py similarity index 96% rename from colossalai/autochunk/chunk_region_search.py rename to colossalai/autochunk/search_chunk.py index 47e2fe13ceb5..5c58bda0c393 100644 --- a/colossalai/autochunk/chunk_region_search.py +++ b/colossalai/autochunk/search_chunk.py @@ -1,8 +1,8 @@ import copy -from .chunk_selector import ChunkSelector -from .index_tracer import IndexTracer, ReorderGraph -from .memory_estiamtor import MemoryEstimator +from .select_chunk import SelectChunk +from .trace_index import TraceIndex, ReorderGraph +from .estiamte_memory import EstimateMemory from .utils import ( get_node_shape, is_non_compute_node, @@ -10,15 +10,15 @@ ) -class ChunkRegionSearch(object): +class SearchChunk(object): def __init__(self, gm, max_memory=None, print_mem=False) -> None: self.gm = gm self.print_mem = print_mem - self.index_tracer = IndexTracer(list(gm.graph.nodes)) + self.index_tracer = TraceIndex(list(gm.graph.nodes)) self.index_tracer.trace_index() self.reorder_graph = ReorderGraph(self.index_tracer) - self.memory_estimator = MemoryEstimator() - self.chunk_selector = ChunkSelector( + self.memory_estimator = EstimateMemory() + self.chunk_selector = SelectChunk( self.index_tracer, self.memory_estimator, self.reorder_graph, max_memory=max_memory ) diff --git 
a/colossalai/autochunk/chunk_selector.py b/colossalai/autochunk/select_chunk.py similarity index 97% rename from colossalai/autochunk/chunk_selector.py rename to colossalai/autochunk/select_chunk.py index 119ff8aafdd0..f0262f1e57eb 100644 --- a/colossalai/autochunk/chunk_selector.py +++ b/colossalai/autochunk/select_chunk.py @@ -1,13 +1,13 @@ -from .index_tracer import IndexTracer, ReorderGraph -from .memory_estiamtor import MemoryEstimator +from .trace_index import TraceIndex, ReorderGraph +from .estiamte_memory import EstimateMemory from .utils import is_non_compute_node -class ChunkSelector(object): +class SelectChunk(object): def __init__( self, - index_tracer: IndexTracer, - memory_estimator: MemoryEstimator, + index_tracer: TraceIndex, + memory_estimator: EstimateMemory, reorder_graph: ReorderGraph, max_memory=None, ): diff --git a/colossalai/autochunk/index_tracer.py b/colossalai/autochunk/trace_index.py similarity index 99% rename from colossalai/autochunk/index_tracer.py rename to colossalai/autochunk/trace_index.py index 8b4d3aabd13a..103a05dadbf5 100644 --- a/colossalai/autochunk/index_tracer.py +++ b/colossalai/autochunk/trace_index.py @@ -10,7 +10,7 @@ ) -class IndexTracer(object): +class TraceIndex(object): def __init__(self, node_list) -> None: self.node_list = node_list self.idx_trace_list = self._init_idx_trace_list() @@ -982,7 +982,7 @@ def _reassgin_reshape_size(self, chunk_info): class ReorderGraph(object): - def __init__(self, index_tracer: IndexTracer) -> None: + def __init__(self, index_tracer: TraceIndex) -> None: self.index_tracer = index_tracer self.all_reorder_map = {i: i for i in range(len(self.index_tracer.idx_trace_list))} diff --git a/tests/test_autochunk/benchmark_autochunk.py b/tests/test_autochunk/benchmark_autochunk.py index 9daaa364a710..081f01368a42 100644 --- a/tests/test_autochunk/benchmark_autochunk.py +++ b/tests/test_autochunk/benchmark_autochunk.py @@ -104,7 +104,7 @@ def benchmark_evoformer(): model = 
evoformer_base().cuda() # build autochunk model - # max_memory = 10000 # MB fit memory mode + # max_memory = 1000 # MB fit memory mode max_memory = None # min memory mode autochunk = _build_autochunk(evoformer_base().cuda(), max_memory, node, pair) From 4748967fb12747043c6688b3f13190203ade769f Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 6 Jan 2023 17:13:18 +0800 Subject: [PATCH 096/209] ad reorder graph --- colossalai/autochunk/reorder_graph.py | 108 ++++++++++++++++++++++++++ colossalai/autochunk/trace_index.py | 106 ------------------------- 2 files changed, 108 insertions(+), 106 deletions(-) create mode 100644 colossalai/autochunk/reorder_graph.py diff --git a/colossalai/autochunk/reorder_graph.py b/colossalai/autochunk/reorder_graph.py new file mode 100644 index 000000000000..7b9f4a20d6ab --- /dev/null +++ b/colossalai/autochunk/reorder_graph.py @@ -0,0 +1,108 @@ +from .trace_index import TraceIndex +from .utils import find_idx_by_name + + +class ReorderGraph(object): + def __init__(self, index_tracer: TraceIndex) -> None: + self.index_tracer = index_tracer + self.all_reorder_map = {i: i for i in range(len(self.index_tracer.idx_trace_list))} + + def _get_reorder_map(self, chunk_info): + reorder_map = {i: i for i in range(len(self.index_tracer.node_list))} + + chunk_region_start = chunk_info["region"][0] + chunk_region_end = chunk_info["region"][1] + chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"] + chunk_prepose_nodes_idx = [ + find_idx_by_name(i.name, self.index_tracer.node_list) for i in chunk_prepose_nodes + ] + # put prepose nodes ahead + for idx, n in enumerate(chunk_prepose_nodes): + n_idx = chunk_prepose_nodes_idx[idx] + reorder_map[n_idx] = chunk_region_start + idx + # put other nodes after prepose nodes + for n in self.index_tracer.node_list[chunk_region_start : chunk_region_end + 1]: + if n in chunk_prepose_nodes: + continue + n_idx = find_idx_by_name(n.name, self.index_tracer.node_list) + pos = sum([n_idx < i for i in 
chunk_prepose_nodes_idx]) + reorder_map[n_idx] = n_idx + pos + + return reorder_map + + def _reorder_chunk_info(self, chunk_info, reorder_map): + # update chunk info + chunk_info["region"] = ( + chunk_info["region"][0] + len(chunk_info["args"]["prepose_nodes"]), + chunk_info["region"][1], + ) + new_inputs_dim = [] + for idx, input_dim in enumerate(chunk_info["inputs_dim"]): + new_input_dim = {} + for k, v in input_dim.items(): + new_input_dim[reorder_map[k]] = v + new_inputs_dim.append(new_input_dim) + chunk_info["inputs_dim"] = new_inputs_dim + return chunk_info + + def _update_all_reorder_map(self, reorder_map): + for origin_idx, map_idx in self.all_reorder_map.items(): + self.all_reorder_map[origin_idx] = reorder_map[map_idx] + + def _reorder_self_node_list(self, reorder_map): + new_node_list = [None for _ in range(len(self.index_tracer.node_list))] + for old_idx, new_idx in reorder_map.items(): + new_node_list[new_idx] = self.index_tracer.node_list[old_idx] + self.index_tracer.node_list = new_node_list + + def _reorder_idx_trace(self, reorder_map): + # reorder list + new_idx_trace_list = [None for _ in range(len(self.index_tracer.idx_trace_list))] + for old_idx, new_idx in reorder_map.items(): + new_idx_trace_list[new_idx] = self.index_tracer.idx_trace_list[old_idx] + self.index_tracer.idx_trace_list = new_idx_trace_list + # update compute + for idx_trace in self.index_tracer.idx_trace_list: + compute = idx_trace["compute"] + for dim_compute in compute: + for idx, i in enumerate(dim_compute): + dim_compute[idx] = reorder_map[i] + # update source + for idx_trace in self.index_tracer.idx_trace_list: + source = idx_trace["source"] + for dim_idx, dim_source in enumerate(source): + new_dim_source = {} + for k, v in dim_source.items(): + new_dim_source[reorder_map[k]] = v + source[dim_idx] = new_dim_source + + def reorder_all(self, chunk_info): + if chunk_info is None: + return chunk_info + if len(chunk_info["args"]["prepose_nodes"]) == 0: + return chunk_info + 
reorder_map = self._get_reorder_map(chunk_info) + self._update_all_reorder_map(reorder_map) + self._reorder_idx_trace(reorder_map) + self._reorder_self_node_list(reorder_map) + chunk_info = self._reorder_chunk_info(chunk_info, reorder_map) + return chunk_info + + def reorder_node_list(self, node_list): + new_node_list = [None for _ in range(len(node_list))] + for old_idx, new_idx in self.all_reorder_map.items(): + new_node_list[new_idx] = node_list[old_idx] + return new_node_list + + def tmp_reorder(self, node_list, chunk_info): + if len(chunk_info["args"]["prepose_nodes"]) == 0: + return node_list, chunk_info + reorder_map = self._get_reorder_map(chunk_info) + + # new tmp node list + new_node_list = [None for _ in range(len(node_list))] + for old_idx, new_idx in reorder_map.items(): + new_node_list[new_idx] = node_list[old_idx] + + chunk_info = self._reorder_chunk_info(chunk_info, reorder_map) + return new_node_list, chunk_info diff --git a/colossalai/autochunk/trace_index.py b/colossalai/autochunk/trace_index.py index 103a05dadbf5..3ac0d7f84272 100644 --- a/colossalai/autochunk/trace_index.py +++ b/colossalai/autochunk/trace_index.py @@ -979,109 +979,3 @@ def _reassgin_reshape_size(self, chunk_info): ) chunk_info["reshape_size"] = reshape_size return chunk_info - - -class ReorderGraph(object): - def __init__(self, index_tracer: TraceIndex) -> None: - self.index_tracer = index_tracer - self.all_reorder_map = {i: i for i in range(len(self.index_tracer.idx_trace_list))} - - def _get_reorder_map(self, chunk_info): - reorder_map = {i: i for i in range(len(self.index_tracer.node_list))} - - chunk_region_start = chunk_info["region"][0] - chunk_region_end = chunk_info["region"][1] - chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"] - chunk_prepose_nodes_idx = [ - find_idx_by_name(i.name, self.index_tracer.node_list) for i in chunk_prepose_nodes - ] - # put prepose nodes ahead - for idx, n in enumerate(chunk_prepose_nodes): - n_idx = chunk_prepose_nodes_idx[idx] - 
reorder_map[n_idx] = chunk_region_start + idx - # put other nodes after prepose nodes - for n in self.index_tracer.node_list[chunk_region_start : chunk_region_end + 1]: - if n in chunk_prepose_nodes: - continue - n_idx = find_idx_by_name(n.name, self.index_tracer.node_list) - pos = sum([n_idx < i for i in chunk_prepose_nodes_idx]) - reorder_map[n_idx] = n_idx + pos - - return reorder_map - - def _reorder_chunk_info(self, chunk_info, reorder_map): - # update chunk info - chunk_info["region"] = ( - chunk_info["region"][0] + len(chunk_info["args"]["prepose_nodes"]), - chunk_info["region"][1], - ) - new_inputs_dim = [] - for idx, input_dim in enumerate(chunk_info["inputs_dim"]): - new_input_dim = {} - for k, v in input_dim.items(): - new_input_dim[reorder_map[k]] = v - new_inputs_dim.append(new_input_dim) - chunk_info["inputs_dim"] = new_inputs_dim - return chunk_info - - def _update_all_reorder_map(self, reorder_map): - for origin_idx, map_idx in self.all_reorder_map.items(): - self.all_reorder_map[origin_idx] = reorder_map[map_idx] - - def _reorder_self_node_list(self, reorder_map): - new_node_list = [None for _ in range(len(self.index_tracer.node_list))] - for old_idx, new_idx in reorder_map.items(): - new_node_list[new_idx] = self.index_tracer.node_list[old_idx] - self.index_tracer.node_list = new_node_list - - def _reorder_idx_trace(self, reorder_map): - # reorder list - new_idx_trace_list = [None for _ in range(len(self.index_tracer.idx_trace_list))] - for old_idx, new_idx in reorder_map.items(): - new_idx_trace_list[new_idx] = self.index_tracer.idx_trace_list[old_idx] - self.index_tracer.idx_trace_list = new_idx_trace_list - # update compute - for idx_trace in self.index_tracer.idx_trace_list: - compute = idx_trace["compute"] - for dim_compute in compute: - for idx, i in enumerate(dim_compute): - dim_compute[idx] = reorder_map[i] - # update source - for idx_trace in self.index_tracer.idx_trace_list: - source = idx_trace["source"] - for dim_idx, dim_source in 
enumerate(source): - new_dim_source = {} - for k, v in dim_source.items(): - new_dim_source[reorder_map[k]] = v - source[dim_idx] = new_dim_source - - def reorder_all(self, chunk_info): - if chunk_info is None: - return chunk_info - if len(chunk_info["args"]["prepose_nodes"]) == 0: - return chunk_info - reorder_map = self._get_reorder_map(chunk_info) - self._update_all_reorder_map(reorder_map) - self._reorder_idx_trace(reorder_map) - self._reorder_self_node_list(reorder_map) - chunk_info = self._reorder_chunk_info(chunk_info, reorder_map) - return chunk_info - - def reorder_node_list(self, node_list): - new_node_list = [None for _ in range(len(node_list))] - for old_idx, new_idx in self.all_reorder_map.items(): - new_node_list[new_idx] = node_list[old_idx] - return new_node_list - - def tmp_reorder(self, node_list, chunk_info): - if len(chunk_info["args"]["prepose_nodes"]) == 0: - return node_list, chunk_info - reorder_map = self._get_reorder_map(chunk_info) - - # new tmp node list - new_node_list = [None for _ in range(len(node_list))] - for old_idx, new_idx in reorder_map.items(): - new_node_list[new_idx] = node_list[old_idx] - - chunk_info = self._reorder_chunk_info(chunk_info, reorder_map) - return new_node_list, chunk_info From a6cdbf9161afc526d3a961708c0b202ca18c3e7e Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 6 Jan 2023 17:24:23 +0800 Subject: [PATCH 097/209] seperate trace flow --- colossalai/autochunk/autochunk_codegen.py | 2 +- colossalai/autochunk/search_chunk.py | 53 +-- colossalai/autochunk/select_chunk.py | 3 +- colossalai/autochunk/trace_flow.py | 414 ++++++++++++++++++++ colossalai/autochunk/trace_index.py | 395 ------------------- tests/test_autochunk/benchmark_autochunk.py | 4 +- 6 files changed, 447 insertions(+), 424 deletions(-) create mode 100644 colossalai/autochunk/trace_flow.py diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py index 3bb2e83be242..39728cb794f7 100644 --- 
a/colossalai/autochunk/autochunk_codegen.py +++ b/colossalai/autochunk/autochunk_codegen.py @@ -167,7 +167,7 @@ def emit_code_with_chunk( ) # ones like if "ones_like" in node.name: - meta_node = chunk_region_search.index_tracer.node_list[node_idx] + meta_node = chunk_region_search.trace_index.node_list[node_idx] chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node][ "chunk_dim" ] diff --git a/colossalai/autochunk/search_chunk.py b/colossalai/autochunk/search_chunk.py index 5c58bda0c393..030b13bdb9c4 100644 --- a/colossalai/autochunk/search_chunk.py +++ b/colossalai/autochunk/search_chunk.py @@ -1,8 +1,10 @@ import copy from .select_chunk import SelectChunk -from .trace_index import TraceIndex, ReorderGraph +from .trace_index import TraceIndex +from .reorder_graph import ReorderGraph from .estiamte_memory import EstimateMemory +from .trace_flow import TraceFlow from .utils import ( get_node_shape, is_non_compute_node, @@ -14,12 +16,13 @@ class SearchChunk(object): def __init__(self, gm, max_memory=None, print_mem=False) -> None: self.gm = gm self.print_mem = print_mem - self.index_tracer = TraceIndex(list(gm.graph.nodes)) - self.index_tracer.trace_index() - self.reorder_graph = ReorderGraph(self.index_tracer) - self.memory_estimator = EstimateMemory() - self.chunk_selector = SelectChunk( - self.index_tracer, self.memory_estimator, self.reorder_graph, max_memory=max_memory + self.trace_index = TraceIndex(list(gm.graph.nodes)) + self.trace_index.trace_index() + self.trace_flow = TraceFlow(self.trace_index) + self.reorder_graph = ReorderGraph(self.trace_index) + self.estimate_memory = EstimateMemory() + self.select_chunk = SelectChunk( + self.trace_index, self.estimate_memory, self.reorder_graph, max_memory=max_memory ) def _find_peak_node(self, mem_peak): @@ -29,7 +32,7 @@ def _find_peak_node(self, mem_peak): def _get_free_var(self): free_var_idx = [] - for idx, n in enumerate(self.index_tracer.node_list): + for idx, n in 
enumerate(self.trace_index.node_list): if n.op == "placeholder": free_var_idx.append(idx) return free_var_idx @@ -99,7 +102,7 @@ def _is_not_compute(self, trace, chunk_range, dim_idx): def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx): start_traces = input_trace[start_idx] end_trace = output_trace[end_idx] - end_node = self.index_tracer.node_list[end_idx] + end_node = self.trace_index.node_list[end_idx] chunk_infos = [] for end_dim, _ in enumerate(end_trace["idx"]): if len(start_traces) > 1: @@ -113,46 +116,46 @@ def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx): ): continue # check index source align - if not self.index_tracer.check_index_source( + if not self.trace_flow.check_index_source( start_dim, start_node, start_idx, end_dim, end_node ): continue # check index copmute - if not self.index_tracer.check_index_compute( + if not self.trace_flow.check_index_compute( start_idx, end_dim, end_node, end_idx ): continue # flow search - chunk_info = self.index_tracer.flow_search( + chunk_info = self.trace_flow.flow_search( start_idx, start_dim, end_idx, end_dim ) if chunk_info is None: continue # check index copmute - if not self.index_tracer.check_index_duplicate(chunk_info): + if not self.trace_flow.check_index_duplicate(chunk_info): continue chunk_infos.append(chunk_info) return chunk_infos def _search_possible_chunk_regions(self, max_chunk_region, peak_node): possible_chunk_region = [] - output_trace = copy.deepcopy(self.index_tracer.idx_trace_list) + output_trace = copy.deepcopy(self.trace_index.idx_trace_list) input_trace = [] # trace of a node's input nodes - for _, n in enumerate(self.index_tracer.node_list): + for _, n in enumerate(self.trace_index.node_list): cur_trace = {} for arg in n.args: if type(arg) == type(n) and not is_non_compute_node_except_placeholder( arg ): - cur_trace[arg] = self.index_tracer._find_trace_from_node(arg) + cur_trace[arg] = self.trace_index._find_trace_from_node(arg) 
input_trace.append(cur_trace) for start_idx in range(max_chunk_region[0], peak_node + 1): for end_idx in range(peak_node, max_chunk_region[1] + 1): # skip non compute nodes if is_non_compute_node( - self.index_tracer.node_list[start_idx] - ) or is_non_compute_node(self.index_tracer.node_list[end_idx]): + self.trace_index.node_list[start_idx] + ) or is_non_compute_node(self.trace_index.node_list[end_idx]): continue # select free dim @@ -173,7 +176,7 @@ def _step_search(self, mem_peak, active_node, chunk_regions): possible_chunk_regions = self._search_possible_chunk_regions( max_chunk_region, peak_node ) - best_chunk_region = self.chunk_selector._select_best_chunk_region( + best_chunk_region = self.select_chunk._select_best_chunk_region( possible_chunk_regions, chunk_regions, peak_node, max_chunk_region, mem_peak ) best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region) @@ -191,8 +194,8 @@ def search_region(self): init_mem_peak, _, active_node, - ) = self.memory_estimator.estimate_chunk_inference_mem( - self.index_tracer.node_list + ) = self.estimate_memory.estimate_chunk_inference_mem( + self.trace_index.node_list ) mem_peak = init_mem_peak @@ -206,14 +209,14 @@ def search_region(self): mem_peak, _, active_node, - ) = self.memory_estimator.estimate_chunk_inference_mem( - self.index_tracer.node_list, chunk_infos + ) = self.estimate_memory.estimate_chunk_inference_mem( + self.trace_index.node_list, chunk_infos ) if self._stop_search(init_mem_peak, mem_peak): break if self.print_mem: self.print_mem = False - self.memory_estimator.estimate_chunk_inference_mem( - self.index_tracer.node_list, chunk_infos, print_mem=True + self.estimate_memory.estimate_chunk_inference_mem( + self.trace_index.node_list, chunk_infos, print_mem=True ) return chunk_infos diff --git a/colossalai/autochunk/select_chunk.py b/colossalai/autochunk/select_chunk.py index f0262f1e57eb..30f4226f54ec 100644 --- a/colossalai/autochunk/select_chunk.py +++ 
b/colossalai/autochunk/select_chunk.py @@ -1,4 +1,5 @@ -from .trace_index import TraceIndex, ReorderGraph +from .trace_index import TraceIndex +from .reorder_graph import ReorderGraph from .estiamte_memory import EstimateMemory from .utils import is_non_compute_node diff --git a/colossalai/autochunk/trace_flow.py b/colossalai/autochunk/trace_flow.py new file mode 100644 index 000000000000..f372fa91335f --- /dev/null +++ b/colossalai/autochunk/trace_flow.py @@ -0,0 +1,414 @@ +from .trace_index import TraceIndex +from .utils import ( + find_chunk_all_input_nodes, + find_chunk_compute_input_and_output_nodes, + find_idx_by_name, + get_node_shape, + is_non_compute_node, + is_non_compute_node_except_placeholder, +) + + +class TraceFlow(object): + def __init__(self, trace_index: TraceIndex) -> None: + self.trace_index = trace_index + + def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node): + """ + Check 2 given index: one index should be source of the other + Args: + start_idx(int): start node chunk dim + start_node(node): start node + end_idx(int): end node chunk dim + end_node(node): end node + + Returns: + bool: True if check pass + """ + start_node_idx = find_idx_by_name(start_node.name, self.trace_index.node_list) + end_node_trace = self.trace_index._find_trace_from_node(end_node) + end_node_trace_source = end_node_trace["source"][end_dim] + sorted_source = sorted( + end_node_trace_source.items(), key=lambda d: d[0], reverse=True + ) + for node_idx, node_dim in sorted_source: + if node_idx == start_node_idx and start_dim in node_dim: + return True + # it means we meet a node outside the loop, and the node is not input node + if node_idx < start_idx: + return False + return False + + def check_index_compute(self, start_idx, end_dim, end_node, end_idx): + """ + Check 2 given index: check they haven't been computed in the source trace. 
+ Args: + start_idx(int): start node chunk dim + start_node(node): start node + end_idx(int): end node chunk dim + end_node(node): end node + + Returns: + bool: True if check pass + """ + end_node_trace = self.trace_index._find_trace_from_node(end_node) + end_node_compute = end_node_trace["compute"][end_dim] + if any(start_idx <= i <= end_idx for i in end_node_compute): + return False + return True + + def get_node_chunk_dim(self, node_from, node_from_dim, node_to): + node_from_source = self.trace_index._find_source_trace_from_node(node_from) + dim_source = node_from_source[node_from_dim] + node_to_idx = find_idx_by_name(node_to.name, self.trace_index.node_list) + for k, v in dim_source.items(): + if k == node_to_idx: + return v + return None + + def _find_inherit_dim(self, input_node, input_dim, node): + input_node_idx = find_idx_by_name(input_node.name, self.trace_index.node_list) + node_trace_source = self.trace_index._find_source_trace_from_node(node) + for node_dim in range(len(get_node_shape(node))): + if ( + input_node_idx in node_trace_source[node_dim] + and input_dim[0] in node_trace_source[node_dim][input_node_idx] + ): + return node_dim + return None + + def check_index_duplicate(self, chunk_infos, return_dim=False): + input_dim_after_node = {} + for input_node_idx, input_node in enumerate(chunk_infos["inputs"]): + for k, v in chunk_infos["inputs_dim"][input_node_idx].items(): + inherit_dim = self._find_inherit_dim(input_node, v, self.trace_index.node_list[k]) + if inherit_dim: + input_dim_after_node[k] = inherit_dim + + for node in self.trace_index.node_list[ + chunk_infos["region"][0] : chunk_infos["region"][1] + 1 + ]: + if is_non_compute_node_except_placeholder(node): + continue + count = 0 + duplicate_dims = [] + node_trace_source = self.trace_index._find_source_trace_from_node(node) + for node_dim in range(len(get_node_shape(node))): + duplicate_dim = [] + duplicate_flag = False + dim_source = node_trace_source[node_dim] + for k, v in 
dim_source.items(): + if chunk_infos["region"][0] <= k <= chunk_infos["region"][1]: + if k in input_dim_after_node and input_dim_after_node[k] in v: + duplicate_flag = True + duplicate_dim.append((k, v)) + duplicate_dims.append(duplicate_dim) + if duplicate_flag: + count += 1 + + if count > 1: + if return_dim: + return False, duplicate_dims + else: + return False + if return_dim: + return True, None + else: + return True + + def _assgin_single_node_flow( + self, + arg_node, + start_idx, + end_idx, + cur_node_dim, + cur_node_compute, + cur_node_source, + cur_node_fix_dim, + all_node_info, + next_node_list, + ): + arg_idx = find_idx_by_name(arg_node.name, self.trace_index.node_list) + # arg in chunk range or be inputs + if not (start_idx <= arg_idx < end_idx): + return True + + # find arg dim + if cur_node_dim is not None: + # dim is computed + if arg_idx in cur_node_compute[cur_node_dim]: + return False + if arg_idx not in cur_node_source[cur_node_dim]: + arg_dim = None + else: + arg_dim = cur_node_source[cur_node_dim][arg_idx][0] + else: + arg_dim = None + + # get fix dim + arg_fix_dim = [] + if cur_node_dim is not None: + for i in cur_node_fix_dim: + fix_dim_source = cur_node_source[i] + if arg_idx in fix_dim_source: + arg_fix_dim.append(fix_dim_source[arg_idx][0]) + + # if already in node_info, arg dim must be same + if arg_node in all_node_info: + if all_node_info[arg_node]["chunk_dim"] != arg_dim: + return False + all_node_info[arg_node]["fix_dim"] = list( + set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim) + ) + # else add it to list + else: + all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim} + + next_node_list.append(arg_node) + return True + + def _get_all_node_info(self, end_dim, start_idx, end_idx): + cur_node_list = [ + self.trace_index.node_list[end_idx] + ] # start from the last node + all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}} + + while len(cur_node_list) > 0: + next_node_list = [] + + for 
cur_node in cur_node_list: + # get cur node info + cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"] + cur_node_fix_dim = all_node_info[cur_node]["fix_dim"] + if cur_node_chunk_dim: + cur_node_compute = self.trace_index._find_compute_trace_from_node( + cur_node + ) + cur_node_source = self.trace_index._find_source_trace_from_node( + cur_node + ) + else: + cur_node_compute = cur_node_source = None + + # get all valid args + arg_list = [] + for arg in cur_node.args: + if type(arg) != type(cur_node): + continue + if is_non_compute_node(arg): + continue + arg_list.append(arg) + flow_flag = self._assgin_single_node_flow( + arg, + start_idx, + end_idx, + cur_node_chunk_dim, + cur_node_compute, + cur_node_source, + cur_node_fix_dim, + all_node_info, + next_node_list, + ) + if flow_flag == False: + return None + + if len(arg_list) == 2: + if any(i in cur_node.name for i in ["add", "mul"]): + for arg in arg_list: + if not ( + start_idx + <= find_idx_by_name(arg.name, self.trace_index.node_list) + < end_idx + ): + continue + arg_chunk_dim = all_node_info[arg]["chunk_dim"] + arg_fix_dim = all_node_info[arg]["fix_dim"] + arg_shape = get_node_shape(arg) + # add all dim as fix dim except chunk dim + for i, shape in enumerate(arg_shape): + if shape != 1 and i != cur_node_chunk_dim: + if i == arg_chunk_dim: + return None + if i not in arg_fix_dim: + arg_fix_dim.append(i) + elif "einsum" in cur_node.name: + pass + elif "matmul" in cur_node.name: + pass + else: + raise NotImplementedError() + cur_node_list = next_node_list + return all_node_info + + def _get_input_nodes_dim(self, inputs, start_idx, end_idx, all_node_info): + inputs_dim = [] + remove_inputs = [] + for input_node in inputs: + input_dict = {} + input_node_idx = find_idx_by_name( + input_node.name, self.trace_index.node_list + ) + for user in input_node.users.keys(): + if is_non_compute_node(user): + continue + user_idx = find_idx_by_name(user.name, self.trace_index.node_list) + if start_idx <= user_idx <= 
end_idx: + chunk_dim = all_node_info[user]["chunk_dim"] + if chunk_dim is not None: + user_source = self.trace_index._find_source_trace_from_node(user)[chunk_dim] + if input_node_idx in user_source: + input_dict[user_idx] = user_source[input_node_idx] + else: + return None, None + if len(input_dict) == 0: + remove_inputs.append(input_node) + else: + inputs_dim.append(input_dict) + for i in remove_inputs: + if i in inputs: + inputs.remove(i) + return inputs, inputs_dim + + def _get_prepose_nodes(self, all_node_info, start_idx, end_idx): + # get all possible prepose nodes + maybe_prepose_nodes = [] + for node, node_info in all_node_info.items(): + if node_info["chunk_dim"] is None: + maybe_prepose_nodes.append(node) + maybe_prepose_nodes.sort( + key=lambda x: find_idx_by_name(x.name, self.trace_index.node_list), + reverse=True, + ) # from last node to first node + prepose_nodes = [] + # set every node as root, search its args, if all legal, turn root and args as prepose nodes + while len(maybe_prepose_nodes) > 0: + tmp_cur_prepose_nodes = [maybe_prepose_nodes[0]] + tmp_cur_related_prepose_nodes = [] + prepose_flag = True + + # loop cur node's all arg until out of chunk + while len(tmp_cur_prepose_nodes) > 0: + if prepose_flag == False: + break + tmp_next_prepose_nodes = [] + tmp_cur_related_prepose_nodes.extend(tmp_cur_prepose_nodes) + for cur_prepose_node in tmp_cur_prepose_nodes: + if prepose_flag == False: + break + for cur_prepose_node_arg in cur_prepose_node.args: + if type(cur_prepose_node_arg) != type(cur_prepose_node): + continue + # out of loop + if not ( + start_idx + <= find_idx_by_name( + cur_prepose_node_arg.name, self.trace_index.node_list + ) + < end_idx + ): + continue + # compute op in loop + elif cur_prepose_node_arg in all_node_info: + if all_node_info[cur_prepose_node_arg]["chunk_dim"] is None: + tmp_next_prepose_nodes.append(cur_prepose_node_arg) + else: + prepose_flag = False + break + # non compute op + else: + 
tmp_next_prepose_nodes.append(cur_prepose_node_arg) + tmp_cur_prepose_nodes = tmp_next_prepose_nodes + + if prepose_flag == False: + maybe_prepose_nodes.remove(maybe_prepose_nodes[0]) + continue + else: + for n in tmp_cur_related_prepose_nodes: + if n not in prepose_nodes: + prepose_nodes.append(n) + if n in maybe_prepose_nodes: + maybe_prepose_nodes.remove(n) + # sort by index + prepose_nodes.sort( + key=lambda x: find_idx_by_name(x.name, self.trace_index.node_list) + ) + + return prepose_nodes + + def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx): + # we need to log input nodes to avoid deleteing them in the loop + chunk_node_list = self.trace_index.node_list[start_idx : end_idx + 1] + # also need to get some prepose node's arg out of non_chunk_inputs + for n in chunk_info["args"]["prepose_nodes"]: + chunk_node_list.remove(n) + non_chunk_inputs = find_chunk_all_input_nodes(chunk_node_list) + for i in non_chunk_inputs: + if i not in chunk_info["inputs"]: + chunk_info["inputs_non_chunk"].append(i) + return chunk_info + + def flow_search(self, start_idx, start_dim, end_idx, end_dim): + inputs, outputs = find_chunk_compute_input_and_output_nodes( + self.trace_index.node_list[start_idx : end_idx + 1] + ) + # only single ouput + if len(outputs) > 1: + return None + + # get every node's chunk dim and fix dim + all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx) + if all_node_info is None: + return None + + # get input nodes' chunk dim + inputs, inputs_dim = self._get_input_nodes_dim( + inputs, start_idx, end_idx, all_node_info + ) + if inputs is None: + return None + + chunk_info = { + "region": (start_idx, end_idx), + "inputs": inputs, + "inputs_non_chunk": [], + "inputs_dim": inputs_dim, + "outputs": outputs, + "outputs_dim": end_dim, + "node_chunk_dim": all_node_info, + "args": {}, + } + + # move useless nodes ahead of loop + chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes( + all_node_info, start_idx, end_idx + ) + + # 
find non chunk inputs + chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx) + + # reassgin reshape size, some size may have changed due to chunk + chunk_info = self._reassgin_reshape_size(chunk_info) + + return chunk_info + + def _reassgin_reshape_size(self, chunk_info): + chunk_region = chunk_info["region"] + reshape_size = {} + chunk_shape = get_node_shape(chunk_info["outputs"][0])[ + chunk_info["outputs_dim"] + ] + for node in self.trace_index.node_list[chunk_region[0] : chunk_region[1] + 1]: + if any(i in node.name for i in ["reshape", "view"]): + reshape_args = node.args[1:] + reshape_log = self.trace_index.idx_view_list[node] + chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"] + reshape_size[node.name] = {} + for reshape_arg_dim, reshape_arg in enumerate(reshape_args): + if reshape_arg_dim in reshape_log["dim_to"]: + continue + if reshape_arg_dim == chunk_dim: + reshape_size[node.name][reshape_arg.name] = ( + "min(chunk_size, %d - chunk_idx)" % chunk_shape + ) + chunk_info["reshape_size"] = reshape_size + return chunk_info diff --git a/colossalai/autochunk/trace_index.py b/colossalai/autochunk/trace_index.py index 3ac0d7f84272..1e8969d8796e 100644 --- a/colossalai/autochunk/trace_index.py +++ b/colossalai/autochunk/trace_index.py @@ -1,12 +1,8 @@ import copy from .utils import ( - find_chunk_all_input_nodes, - find_chunk_compute_input_and_output_nodes, find_idx_by_name, get_node_shape, - is_non_compute_node, - is_non_compute_node_except_placeholder, ) @@ -588,394 +584,3 @@ def trace_index(self): continue else: raise NotImplementedError(node.op, "op not implemented yet!") - # self._merge_equal_idx() - - def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node): - """ - Check 2 given index: one index should be source of the other - Args: - start_idx(int): start node chunk dim - start_node(node): start node - end_idx(int): end node chunk dim - end_node(node): end node - - Returns: - bool: True if check pass - 
""" - start_node_idx = find_idx_by_name(start_node.name, self.node_list) - end_node_trace = self._find_trace_from_node(end_node) - end_node_trace_source = end_node_trace["source"][end_dim] - sorted_source = sorted( - end_node_trace_source.items(), key=lambda d: d[0], reverse=True - ) - for node_idx, node_dim in sorted_source: - if node_idx == start_node_idx and start_dim in node_dim: - return True - # it means we meet a node outside the loop, and the node is not input node - if node_idx < start_idx: - return False - return False - - def check_index_compute(self, start_idx, end_dim, end_node, end_idx): - """ - Check 2 given index: check they haven't been computed in the source trace. - Args: - start_idx(int): start node chunk dim - start_node(node): start node - end_idx(int): end node chunk dim - end_node(node): end node - - Returns: - bool: True if check pass - """ - end_node_trace = self._find_trace_from_node(end_node) - end_node_compute = end_node_trace["compute"][end_dim] - if any(start_idx <= i <= end_idx for i in end_node_compute): - return False - return True - - def get_node_chunk_dim(self, node_from, node_from_dim, node_to): - node_from_source = self._find_source_trace_from_node(node_from) - dim_source = node_from_source[node_from_dim] - node_to_idx = find_idx_by_name(node_to.name, self.node_list) - for k, v in dim_source.items(): - if k == node_to_idx: - return v - return None - - def _find_inherit_dim(self, input_node, input_dim, node): - input_node_idx = find_idx_by_name(input_node.name, self.node_list) - node_trace_source = self._find_source_trace_from_node(node) - for node_dim in range(len(get_node_shape(node))): - if ( - input_node_idx in node_trace_source[node_dim] - and input_dim[0] in node_trace_source[node_dim][input_node_idx] - ): - return node_dim - return None - - def check_index_duplicate(self, chunk_infos, return_dim=False): - input_dim_after_node = {} - for input_node_idx, input_node in enumerate(chunk_infos["inputs"]): - for k, v in 
chunk_infos["inputs_dim"][input_node_idx].items(): - inherit_dim = self._find_inherit_dim(input_node, v, self.node_list[k]) - if inherit_dim: - input_dim_after_node[k] = inherit_dim - - for node in self.node_list[ - chunk_infos["region"][0] : chunk_infos["region"][1] + 1 - ]: - if is_non_compute_node_except_placeholder(node): - continue - count = 0 - duplicate_dims = [] - node_trace_source = self._find_source_trace_from_node(node) - for node_dim in range(len(get_node_shape(node))): - duplicate_dim = [] - duplicate_flag = False - dim_source = node_trace_source[node_dim] - for k, v in dim_source.items(): - if chunk_infos["region"][0] <= k <= chunk_infos["region"][1]: - if k in input_dim_after_node and input_dim_after_node[k] in v: - duplicate_flag = True - duplicate_dim.append((k, v)) - duplicate_dims.append(duplicate_dim) - if duplicate_flag: - count += 1 - - if count > 1: - if return_dim: - return False, duplicate_dims - else: - return False - if return_dim: - return True, None - else: - return True - - def _assgin_single_node_flow( - self, - arg_node, - start_idx, - end_idx, - cur_node_dim, - cur_node_compute, - cur_node_source, - cur_node_fix_dim, - all_node_info, - next_node_list, - ): - arg_idx = find_idx_by_name(arg_node.name, self.node_list) - # arg in chunk range or be inputs - if not (start_idx <= arg_idx < end_idx): - return True - - # find arg dim - if cur_node_dim is not None: - # dim is computed - if arg_idx in cur_node_compute[cur_node_dim]: - return False - if arg_idx not in cur_node_source[cur_node_dim]: - arg_dim = None - else: - arg_dim = cur_node_source[cur_node_dim][arg_idx][0] - else: - arg_dim = None - - # get fix dim - arg_fix_dim = [] - if cur_node_dim is not None: - for i in cur_node_fix_dim: - fix_dim_source = cur_node_source[i] - if arg_idx in fix_dim_source: - arg_fix_dim.append(fix_dim_source[arg_idx][0]) - - # if already in node_info, arg dim must be same - if arg_node in all_node_info: - if all_node_info[arg_node]["chunk_dim"] != 
arg_dim: - return False - all_node_info[arg_node]["fix_dim"] = list( - set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim) - ) - # else add it to list - else: - all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim} - - next_node_list.append(arg_node) - return True - - def _get_all_node_info(self, end_dim, start_idx, end_idx): - cur_node_list = [self.node_list[end_idx]] # start from the last node - all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}} - - while len(cur_node_list) > 0: - next_node_list = [] - - for cur_node in cur_node_list: - # get cur node info - cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"] - cur_node_fix_dim = all_node_info[cur_node]["fix_dim"] - if cur_node_chunk_dim: - cur_node_compute = self._find_compute_trace_from_node(cur_node) - cur_node_source = self._find_source_trace_from_node(cur_node) - else: - cur_node_compute = cur_node_source = None - - # get all valid args - arg_list = [] - for arg in cur_node.args: - if type(arg) != type(cur_node): - continue - if is_non_compute_node(arg): - continue - arg_list.append(arg) - flow_flag = self._assgin_single_node_flow( - arg, - start_idx, - end_idx, - cur_node_chunk_dim, - cur_node_compute, - cur_node_source, - cur_node_fix_dim, - all_node_info, - next_node_list, - ) - if flow_flag == False: - return None - - if len(arg_list) == 2: - if any(i in cur_node.name for i in ["add", "mul"]): - for arg in arg_list: - if not ( - start_idx - <= find_idx_by_name(arg.name, self.node_list) - < end_idx - ): - continue - arg_chunk_dim = all_node_info[arg]["chunk_dim"] - arg_fix_dim = all_node_info[arg]["fix_dim"] - arg_shape = get_node_shape(arg) - # add all dim as fix dim except chunk dim - for i, shape in enumerate(arg_shape): - if shape != 1 and i != cur_node_chunk_dim: - if i == arg_chunk_dim: - return None - if i not in arg_fix_dim: - arg_fix_dim.append(i) - elif "einsum" in cur_node.name: - pass - elif "matmul" in cur_node.name: - pass - else: - raise 
NotImplementedError() - cur_node_list = next_node_list - return all_node_info - - def _get_input_nodes_dim(self, inputs, start_idx, end_idx, all_node_info): - inputs_dim = [] - remove_inputs = [] - for input_node in inputs: - input_dict = {} - input_node_idx = find_idx_by_name(input_node.name, self.node_list) - for user in input_node.users.keys(): - if is_non_compute_node(user): - continue - user_idx = find_idx_by_name(user.name, self.node_list) - if start_idx <= user_idx <= end_idx: - chunk_dim = all_node_info[user]["chunk_dim"] - if chunk_dim is not None: - user_source = self._find_source_trace_from_node(user)[chunk_dim] - if input_node_idx in user_source: - input_dict[user_idx] = user_source[input_node_idx] - else: - return None, None - if len(input_dict) == 0: - remove_inputs.append(input_node) - else: - inputs_dim.append(input_dict) - for i in remove_inputs: - if i in inputs: - inputs.remove(i) - return inputs, inputs_dim - - def _get_prepose_nodes(self, all_node_info, start_idx, end_idx): - # get all possible prepose nodes - maybe_prepose_nodes = [] - for node, node_info in all_node_info.items(): - if node_info["chunk_dim"] is None: - maybe_prepose_nodes.append(node) - maybe_prepose_nodes.sort( - key=lambda x: find_idx_by_name(x.name, self.node_list), - reverse=True, - ) # from last node to first node - prepose_nodes = [] - # set every node as root, search its args, if all legal, turn root and args as prepose nodes - while len(maybe_prepose_nodes) > 0: - tmp_cur_prepose_nodes = [maybe_prepose_nodes[0]] - tmp_cur_related_prepose_nodes = [] - prepose_flag = True - - # loop cur node's all arg until out of chunk - while len(tmp_cur_prepose_nodes) > 0: - if prepose_flag == False: - break - tmp_next_prepose_nodes = [] - tmp_cur_related_prepose_nodes.extend(tmp_cur_prepose_nodes) - for cur_prepose_node in tmp_cur_prepose_nodes: - if prepose_flag == False: - break - for cur_prepose_node_arg in cur_prepose_node.args: - if type(cur_prepose_node_arg) != 
type(cur_prepose_node): - continue - # out of loop - if not ( - start_idx - <= find_idx_by_name( - cur_prepose_node_arg.name, self.node_list - ) - < end_idx - ): - continue - # compute op in loop - elif cur_prepose_node_arg in all_node_info: - if all_node_info[cur_prepose_node_arg]["chunk_dim"] is None: - tmp_next_prepose_nodes.append(cur_prepose_node_arg) - else: - prepose_flag = False - break - # non compute op - else: - tmp_next_prepose_nodes.append(cur_prepose_node_arg) - tmp_cur_prepose_nodes = tmp_next_prepose_nodes - - if prepose_flag == False: - maybe_prepose_nodes.remove(maybe_prepose_nodes[0]) - continue - else: - for n in tmp_cur_related_prepose_nodes: - if n not in prepose_nodes: - prepose_nodes.append(n) - if n in maybe_prepose_nodes: - maybe_prepose_nodes.remove(n) - # sort by index - prepose_nodes.sort(key=lambda x: find_idx_by_name(x.name, self.node_list)) - - return prepose_nodes - - def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx): - # we need to log input nodes to avoid deleteing them in the loop - chunk_node_list = self.node_list[start_idx : end_idx + 1] - # also need to get some prepose node's arg out of non_chunk_inputs - for n in chunk_info["args"]["prepose_nodes"]: - chunk_node_list.remove(n) - non_chunk_inputs = find_chunk_all_input_nodes(chunk_node_list) - for i in non_chunk_inputs: - if i not in chunk_info["inputs"]: - chunk_info["inputs_non_chunk"].append(i) - return chunk_info - - def flow_search(self, start_idx, start_dim, end_idx, end_dim): - inputs, outputs = find_chunk_compute_input_and_output_nodes( - self.node_list[start_idx : end_idx + 1] - ) - # only single ouput - if len(outputs) > 1: - return None - - # get every node's chunk dim and fix dim - all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx) - if all_node_info is None: - return None - - # get input nodes' chunk dim - inputs, inputs_dim = self._get_input_nodes_dim( - inputs, start_idx, end_idx, all_node_info - ) - if inputs is None: - return 
None - - chunk_info = { - "region": (start_idx, end_idx), - "inputs": inputs, - "inputs_non_chunk": [], - "inputs_dim": inputs_dim, - "outputs": outputs, - "outputs_dim": end_dim, - "node_chunk_dim": all_node_info, - "args": {}, - } - - # move useless nodes ahead of loop - chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes( - all_node_info, start_idx, end_idx - ) - - # find non chunk inputs - chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx) - - # reassgin reshape size, some size may have changed due to chunk - chunk_info = self._reassgin_reshape_size(chunk_info) - - return chunk_info - - def _reassgin_reshape_size(self, chunk_info): - chunk_region = chunk_info["region"] - reshape_size = {} - chunk_shape = get_node_shape(chunk_info["outputs"][0])[ - chunk_info["outputs_dim"] - ] - for node in self.node_list[chunk_region[0] : chunk_region[1] + 1]: - if any(i in node.name for i in ["reshape", "view"]): - reshape_args = node.args[1:] - reshape_log = self.idx_view_list[node] - chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"] - reshape_size[node.name] = {} - for reshape_arg_dim, reshape_arg in enumerate(reshape_args): - if reshape_arg_dim in reshape_log["dim_to"]: - continue - if reshape_arg_dim == chunk_dim: - reshape_size[node.name][reshape_arg.name] = ( - "min(chunk_size, %d - chunk_idx)" % chunk_shape - ) - chunk_info["reshape_size"] = reshape_size - return chunk_info diff --git a/tests/test_autochunk/benchmark_autochunk.py b/tests/test_autochunk/benchmark_autochunk.py index 081f01368a42..7a9d8cdeee03 100644 --- a/tests/test_autochunk/benchmark_autochunk.py +++ b/tests/test_autochunk/benchmark_autochunk.py @@ -104,8 +104,8 @@ def benchmark_evoformer(): model = evoformer_base().cuda() # build autochunk model - # max_memory = 1000 # MB fit memory mode - max_memory = None # min memory mode + max_memory = 1000 # MB fit memory mode + # max_memory = None # min memory mode autochunk = _build_autochunk(evoformer_base().cuda(), 
max_memory, node, pair) # build openfold From c3a2bf48b447a5e051bcae5d694ff5dd7beda54a Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 6 Jan 2023 17:31:59 +0800 Subject: [PATCH 098/209] code style --- colossalai/autochunk/autochunk_codegen.py | 14 +++++----- colossalai/autochunk/reorder_graph.py | 33 ++++++++++++----------- colossalai/autochunk/search_chunk.py | 11 +++++--- colossalai/autochunk/select_chunk.py | 12 ++++----- colossalai/autochunk/trace_flow.py | 12 ++++++--- 5 files changed, 46 insertions(+), 36 deletions(-) diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py index 39728cb794f7..891753faae6d 100644 --- a/colossalai/autochunk/autochunk_codegen.py +++ b/colossalai/autochunk/autochunk_codegen.py @@ -103,7 +103,7 @@ def emit_code_with_chunk( nodes, emit_node_func, delete_unused_value_func, - chunk_region_search: SearchChunk, + search_chunk: SearchChunk, chunk_infos, ): """Emit code with nested activation checkpoint @@ -133,7 +133,7 @@ def emit_code_with_chunk( chunk_outputs = [i["outputs"][0] for i in chunk_infos] chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos] - node_list = chunk_region_search.reorder_graph.reorder_node_list(node_list) + node_list = search_chunk.reorder_graph.reorder_node_list(node_list) node_idx = 0 region_idx = 0 within_chunk_region = False @@ -167,7 +167,7 @@ def emit_code_with_chunk( ) # ones like if "ones_like" in node.name: - meta_node = chunk_region_search.trace_index.node_list[node_idx] + meta_node = search_chunk.trace_index.node_list[node_idx] chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node][ "chunk_dim" ] @@ -220,10 +220,8 @@ def __init__(self, meta_graph, max_memory=None, print_mem=False): self.max_memory = max_memory self.meta_node = list(meta_graph.graph.nodes) # find the chunk regions - self.chunk_region_search = SearchChunk( - meta_graph, max_memory, print_mem - ) - self.chunk_infos = self.chunk_region_search.search_region() + self.search_chunk = 
SearchChunk(meta_graph, max_memory, print_mem) + self.chunk_infos = self.search_chunk.search_region() def _gen_python_code( self, nodes, root_module: str, namespace: _Namespace @@ -458,7 +456,7 @@ def emit_node(node: Node, body): nodes, emit_node, delete_unused_values, - self.chunk_region_search, + self.search_chunk, self.chunk_infos, ) diff --git a/colossalai/autochunk/reorder_graph.py b/colossalai/autochunk/reorder_graph.py index 7b9f4a20d6ab..bf4420eac7ee 100644 --- a/colossalai/autochunk/reorder_graph.py +++ b/colossalai/autochunk/reorder_graph.py @@ -3,28 +3,31 @@ class ReorderGraph(object): - def __init__(self, index_tracer: TraceIndex) -> None: - self.index_tracer = index_tracer - self.all_reorder_map = {i: i for i in range(len(self.index_tracer.idx_trace_list))} + def __init__(self, trace_index: TraceIndex) -> None: + self.trace_index = trace_index + self.all_reorder_map = { + i: i for i in range(len(self.trace_index.idx_trace_list)) + } def _get_reorder_map(self, chunk_info): - reorder_map = {i: i for i in range(len(self.index_tracer.node_list))} + reorder_map = {i: i for i in range(len(self.trace_index.node_list))} chunk_region_start = chunk_info["region"][0] chunk_region_end = chunk_info["region"][1] chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"] chunk_prepose_nodes_idx = [ - find_idx_by_name(i.name, self.index_tracer.node_list) for i in chunk_prepose_nodes + find_idx_by_name(i.name, self.trace_index.node_list) + for i in chunk_prepose_nodes ] # put prepose nodes ahead for idx, n in enumerate(chunk_prepose_nodes): n_idx = chunk_prepose_nodes_idx[idx] reorder_map[n_idx] = chunk_region_start + idx # put other nodes after prepose nodes - for n in self.index_tracer.node_list[chunk_region_start : chunk_region_end + 1]: + for n in self.trace_index.node_list[chunk_region_start : chunk_region_end + 1]: if n in chunk_prepose_nodes: continue - n_idx = find_idx_by_name(n.name, self.index_tracer.node_list) + n_idx = find_idx_by_name(n.name, 
self.trace_index.node_list) pos = sum([n_idx < i for i in chunk_prepose_nodes_idx]) reorder_map[n_idx] = n_idx + pos @@ -50,25 +53,25 @@ def _update_all_reorder_map(self, reorder_map): self.all_reorder_map[origin_idx] = reorder_map[map_idx] def _reorder_self_node_list(self, reorder_map): - new_node_list = [None for _ in range(len(self.index_tracer.node_list))] + new_node_list = [None for _ in range(len(self.trace_index.node_list))] for old_idx, new_idx in reorder_map.items(): - new_node_list[new_idx] = self.index_tracer.node_list[old_idx] - self.index_tracer.node_list = new_node_list + new_node_list[new_idx] = self.trace_index.node_list[old_idx] + self.trace_index.node_list = new_node_list def _reorder_idx_trace(self, reorder_map): # reorder list - new_idx_trace_list = [None for _ in range(len(self.index_tracer.idx_trace_list))] + new_idx_trace_list = [None for _ in range(len(self.trace_index.idx_trace_list))] for old_idx, new_idx in reorder_map.items(): - new_idx_trace_list[new_idx] = self.index_tracer.idx_trace_list[old_idx] - self.index_tracer.idx_trace_list = new_idx_trace_list + new_idx_trace_list[new_idx] = self.trace_index.idx_trace_list[old_idx] + self.trace_index.idx_trace_list = new_idx_trace_list # update compute - for idx_trace in self.index_tracer.idx_trace_list: + for idx_trace in self.trace_index.idx_trace_list: compute = idx_trace["compute"] for dim_compute in compute: for idx, i in enumerate(dim_compute): dim_compute[idx] = reorder_map[i] # update source - for idx_trace in self.index_tracer.idx_trace_list: + for idx_trace in self.trace_index.idx_trace_list: source = idx_trace["source"] for dim_idx, dim_source in enumerate(source): new_dim_source = {} diff --git a/colossalai/autochunk/search_chunk.py b/colossalai/autochunk/search_chunk.py index 030b13bdb9c4..e2c8de74e012 100644 --- a/colossalai/autochunk/search_chunk.py +++ b/colossalai/autochunk/search_chunk.py @@ -1,10 +1,10 @@ import copy -from .select_chunk import SelectChunk -from .trace_index 
import TraceIndex -from .reorder_graph import ReorderGraph from .estiamte_memory import EstimateMemory +from .reorder_graph import ReorderGraph +from .select_chunk import SelectChunk from .trace_flow import TraceFlow +from .trace_index import TraceIndex from .utils import ( get_node_shape, is_non_compute_node, @@ -22,7 +22,10 @@ def __init__(self, gm, max_memory=None, print_mem=False) -> None: self.reorder_graph = ReorderGraph(self.trace_index) self.estimate_memory = EstimateMemory() self.select_chunk = SelectChunk( - self.trace_index, self.estimate_memory, self.reorder_graph, max_memory=max_memory + self.trace_index, + self.estimate_memory, + self.reorder_graph, + max_memory=max_memory, ) def _find_peak_node(self, mem_peak): diff --git a/colossalai/autochunk/select_chunk.py b/colossalai/autochunk/select_chunk.py index 30f4226f54ec..bdc64528ef18 100644 --- a/colossalai/autochunk/select_chunk.py +++ b/colossalai/autochunk/select_chunk.py @@ -1,19 +1,19 @@ -from .trace_index import TraceIndex -from .reorder_graph import ReorderGraph from .estiamte_memory import EstimateMemory +from .reorder_graph import ReorderGraph +from .trace_index import TraceIndex from .utils import is_non_compute_node class SelectChunk(object): def __init__( self, - index_tracer: TraceIndex, - memory_estimator: EstimateMemory, + trace_index: TraceIndex, + estimate_memory: EstimateMemory, reorder_graph: ReorderGraph, max_memory=None, ): - self.index_tracer = index_tracer - self.memory_estimator = memory_estimator + self.index_tracer = trace_index + self.memory_estimator = estimate_memory self.reorder_graph = reorder_graph if max_memory is not None: self.stratge = "fit_memory" diff --git a/colossalai/autochunk/trace_flow.py b/colossalai/autochunk/trace_flow.py index f372fa91335f..7139e7e047ef 100644 --- a/colossalai/autochunk/trace_flow.py +++ b/colossalai/autochunk/trace_flow.py @@ -81,7 +81,9 @@ def check_index_duplicate(self, chunk_infos, return_dim=False): input_dim_after_node = {} for 
input_node_idx, input_node in enumerate(chunk_infos["inputs"]): for k, v in chunk_infos["inputs_dim"][input_node_idx].items(): - inherit_dim = self._find_inherit_dim(input_node, v, self.trace_index.node_list[k]) + inherit_dim = self._find_inherit_dim( + input_node, v, self.trace_index.node_list[k] + ) if inherit_dim: input_dim_after_node[k] = inherit_dim @@ -217,7 +219,9 @@ def _get_all_node_info(self, end_dim, start_idx, end_idx): for arg in arg_list: if not ( start_idx - <= find_idx_by_name(arg.name, self.trace_index.node_list) + <= find_idx_by_name( + arg.name, self.trace_index.node_list + ) < end_idx ): continue @@ -255,7 +259,9 @@ def _get_input_nodes_dim(self, inputs, start_idx, end_idx, all_node_info): if start_idx <= user_idx <= end_idx: chunk_dim = all_node_info[user]["chunk_dim"] if chunk_dim is not None: - user_source = self.trace_index._find_source_trace_from_node(user)[chunk_dim] + user_source = self.trace_index._find_source_trace_from_node( + user + )[chunk_dim] if input_node_idx in user_source: input_dict[user_idx] = user_source[input_node_idx] else: From 8a989a0d89418c308c1d97b4d692a4e753395732 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 6 Jan 2023 17:55:22 +0800 Subject: [PATCH 099/209] code style --- colossalai/autochunk/autochunk_codegen.py | 69 +++++++++++++---------- 1 file changed, 40 insertions(+), 29 deletions(-) diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py index 891753faae6d..0db2e59080dd 100644 --- a/colossalai/autochunk/autochunk_codegen.py +++ b/colossalai/autochunk/autochunk_codegen.py @@ -98,6 +98,39 @@ def _replace_reshape_size(context, node_name, reshape_size_dict): return context +def _replace_ones_like(search_chunk, chunk_infos, region_idx, node_idx, node, body): + if "ones_like" in node.name: + meta_node = search_chunk.trace_index.node_list[node_idx] + chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"] + if get_node_shape(meta_node)[chunk_dim] != 
1: + source_node = meta_node.args[0].args[0] + if ( + source_node not in chunk_infos[region_idx]["node_chunk_dim"] + or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] + is None + ): + chunk_slice = _gen_chunk_slice_dim( + chunk_dim, "chunk_idx", get_node_shape(node) + ) + body[-1] = _replace_name( + body[-1], node.args[0].name, node.args[0].name + chunk_slice + ) + return body + + +def _replace_input_var(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body): + for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]): + for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items(): + if idx == node_idx: + chunk_slice = _gen_chunk_slice_dim( + dim[0], "chunk_idx", get_node_shape(input_node) + ) + body[-1] = _replace_name( + body[-1], input_node.name, input_node.name + chunk_slice + ) + return body + + def emit_code_with_chunk( body, nodes, @@ -156,36 +189,14 @@ def emit_code_with_chunk( if within_chunk_region: emit_node_func(node, body) # replace input var with chunk var - for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]): - for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items(): - if idx == node_idx: - chunk_slice = _gen_chunk_slice_dim( - dim[0], "chunk_idx", get_node_shape(input_node) - ) - body[-1] = _replace_name( - body[-1], input_node.name, input_node.name + chunk_slice - ) + body = _replace_input_var( + chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body + ) # ones like - if "ones_like" in node.name: - meta_node = search_chunk.trace_index.node_list[node_idx] - chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node][ - "chunk_dim" - ] - if get_node_shape(meta_node)[chunk_dim] != 1: - source_node = meta_node.args[0].args[0] - if ( - source_node not in chunk_infos[region_idx]["node_chunk_dim"] - or chunk_infos[region_idx]["node_chunk_dim"][source_node][ - "chunk_dim" - ] - is None - ): - chunk_slice = _gen_chunk_slice_dim( - chunk_dim, "chunk_idx", 
get_node_shape(node) - ) - body[-1] = _replace_name( - body[-1], node.args[0].name, node.args[0].name + chunk_slice - ) + body = _replace_ones_like( + search_chunk, chunk_infos, region_idx, node_idx, node, body + ) + # reassgin reshape size body[-1] = _replace_reshape_size( body[-1], node.name, chunk_infos[region_idx]["reshape_size"] ) From 4d223e18a2600ca2467fb21ef4c18f0e9aa0d04c Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 9 Jan 2023 13:46:17 +0800 Subject: [PATCH 100/209] fix typo --- colossalai/autochunk/{estiamte_memory.py => estimate_memory.py} | 0 colossalai/autochunk/search_chunk.py | 2 +- colossalai/autochunk/select_chunk.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename colossalai/autochunk/{estiamte_memory.py => estimate_memory.py} (100%) diff --git a/colossalai/autochunk/estiamte_memory.py b/colossalai/autochunk/estimate_memory.py similarity index 100% rename from colossalai/autochunk/estiamte_memory.py rename to colossalai/autochunk/estimate_memory.py diff --git a/colossalai/autochunk/search_chunk.py b/colossalai/autochunk/search_chunk.py index e2c8de74e012..21b967497f1b 100644 --- a/colossalai/autochunk/search_chunk.py +++ b/colossalai/autochunk/search_chunk.py @@ -1,6 +1,6 @@ import copy -from .estiamte_memory import EstimateMemory +from .estimate_memory import EstimateMemory from .reorder_graph import ReorderGraph from .select_chunk import SelectChunk from .trace_flow import TraceFlow diff --git a/colossalai/autochunk/select_chunk.py b/colossalai/autochunk/select_chunk.py index bdc64528ef18..7127cfd64e69 100644 --- a/colossalai/autochunk/select_chunk.py +++ b/colossalai/autochunk/select_chunk.py @@ -1,4 +1,4 @@ -from .estiamte_memory import EstimateMemory +from .estimate_memory import EstimateMemory from .reorder_graph import ReorderGraph from .trace_index import TraceIndex from .utils import is_non_compute_node From cb68ee864a21e330e8061ee13811a7045f3d65f3 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 9 Jan 2023 14:20:41 
+0800 Subject: [PATCH 101/209] set benchmark --- tests/test_autochunk/benchmark_autochunk.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_autochunk/benchmark_autochunk.py b/tests/test_autochunk/benchmark_autochunk.py index 7a9d8cdeee03..6632ece61376 100644 --- a/tests/test_autochunk/benchmark_autochunk.py +++ b/tests/test_autochunk/benchmark_autochunk.py @@ -98,14 +98,14 @@ def _build_openfold(): def benchmark_evoformer(): # init data and model msa_len = 256 - pair_len = 256 + pair_len = 512 node = torch.randn(1, msa_len, pair_len, 256).cuda() pair = torch.randn(1, pair_len, pair_len, 128).cuda() model = evoformer_base().cuda() # build autochunk model - max_memory = 1000 # MB fit memory mode - # max_memory = None # min memory mode + # max_memory = 1000 # MB, fit memory mode + max_memory = None # min memory mode autochunk = _build_autochunk(evoformer_base().cuda(), max_memory, node, pair) # build openfold From 18a51c87fe0aa3a1210d7484fc09c16714e04bb7 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 9 Jan 2023 14:20:54 +0800 Subject: [PATCH 102/209] rename test --- .../{test_autochunk.py => test_autochunk_codegen.py} | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) rename tests/test_autochunk/{test_autochunk.py => test_autochunk_codegen.py} (97%) diff --git a/tests/test_autochunk/test_autochunk.py b/tests/test_autochunk/test_autochunk_codegen.py similarity index 97% rename from tests/test_autochunk/test_autochunk.py rename to tests/test_autochunk/test_autochunk_codegen.py index 85a162084cc9..1c5dd939d710 100644 --- a/tests/test_autochunk/test_autochunk.py +++ b/tests/test_autochunk/test_autochunk_codegen.py @@ -18,9 +18,7 @@ def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair): torch.cuda.reset_peak_memory_stats() now_mem = torch.cuda.memory_allocated() / 1024**2 with torch.no_grad(): - node1 = node.clone() - pair1 = pair.clone() - gm(node1, pair1) + gm(node.clone(), pair.clone()) new_now_mem = 
torch.cuda.memory_allocated() / 1024**2 new_max_mem = torch.cuda.max_memory_allocated() / 1024**2 print( From 74b81395a2edbce36896f3d184c6cfae327024b5 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 9 Jan 2023 14:26:22 +0800 Subject: [PATCH 103/209] update codegen test --- .../test_autochunk/test_autochunk_codegen.py | 42 +++++++++++-------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/tests/test_autochunk/test_autochunk_codegen.py b/tests/test_autochunk/test_autochunk_codegen.py index 1c5dd939d710..8246275eb08a 100644 --- a/tests/test_autochunk/test_autochunk_codegen.py +++ b/tests/test_autochunk/test_autochunk_codegen.py @@ -15,16 +15,19 @@ def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair): - torch.cuda.reset_peak_memory_stats() - now_mem = torch.cuda.memory_allocated() / 1024**2 - with torch.no_grad(): - gm(node.clone(), pair.clone()) - new_now_mem = torch.cuda.memory_allocated() / 1024**2 - new_max_mem = torch.cuda.max_memory_allocated() / 1024**2 - print( - "autochunk now mem:%.2f max mem:%.2f" - % (new_now_mem - now_mem, new_max_mem - now_mem) - ) + # for memory test + # torch.cuda.reset_peak_memory_stats() + # now_mem = torch.cuda.memory_allocated() / 1024**2 + # with torch.no_grad(): + # node1 = node.clone() + # pair1 = pair.clone() + # gm(node1, pair1) + # new_now_mem = torch.cuda.memory_allocated() / 1024**2 + # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2 + # print( + # "autochunk now mem:%.2f max mem:%.2f" + # % (new_now_mem - now_mem, new_max_mem - now_mem) + # ) # test forward with torch.no_grad(): @@ -43,7 +46,7 @@ def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair): ) -def _run_offload_codegen(rank): +def _test_autochunk_codegen(rank): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly colossalai.launch( config={}, @@ -56,8 +59,10 @@ def _run_offload_codegen(rank): # build model and input model = evoformer_base().cuda() - node = 
torch.randn(1, 100, 300, 256).cuda() - pair = torch.randn(1, 300, 300, 128).cuda() + msa_len = 32 + pair_len = 64 + node = torch.randn(1, msa_len, pair_len, 256).cuda() + pair = torch.randn(1, pair_len, pair_len, 128).cuda() # trace the module and replace codegen graph = ColoTracer().trace( @@ -85,17 +90,18 @@ def _run_offload_codegen(rank): gm = ColoGraphModule(model, graph) gm.recompile() - # assert we have all the components - # code = graph.python_code("self").src + # assert we have inserted chunk + code = graph.python_code("self").src + assert "chunk_size" in code # print(code) _test_fwd(model, gm, node, pair) gpc.destroy() -def test_autochunk(): - mp.spawn(_run_offload_codegen, nprocs=1) +def test_autochunk_codegen(): + mp.spawn(_test_autochunk_codegen, nprocs=1) if __name__ == "__main__": - _run_offload_codegen(0) + _test_autochunk_codegen(0) From 9880fd2cd8b3b24c28333926338656a06dd170f3 Mon Sep 17 00:00:00 2001 From: eric8607242 Date: Mon, 9 Jan 2023 14:35:14 +0800 Subject: [PATCH 104/209] Fix state_dict key missing issue of the ZeroDDP (#2363) * Fix state_dict output for ZeroDDP duplicated parameters * Rewrite state_dict based on get_static_torch_model * Modify get_static_torch_model to be compatible with the lower version (ZeroDDP) --- colossalai/nn/parallel/data_parallel.py | 37 +++++++++++++++++++++---- colossalai/nn/parallel/utils.py | 16 +++++------ 2 files changed, 39 insertions(+), 14 deletions(-) diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py index e3bb83347d21..8fd08db957b7 100644 --- a/colossalai/nn/parallel/data_parallel.py +++ b/colossalai/nn/parallel/data_parallel.py @@ -18,6 +18,7 @@ from colossalai.zero.utils.gemini_hook import GeminiZeROHook from .reducer import Reducer +from .utils import get_static_torch_model try: from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys @@ -251,6 +252,7 @@ def __init__(self, pin_memory=pin_memory) self.fp32_params.append(fp32_p) 
self.grads_device[p] = self.gemini_manager.default_device + self.chunk_manager.close_all_groups() self._cast_buffers() @@ -331,12 +333,11 @@ def set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None: for tensor in chunk.get_tensors(): self.grads_device[tensor] = device - def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True): - r"""Returns a dictionary containing a whole state of the module. - - Both parameters and persistent buffers (e.g. running averages) are - included. Keys are corresponding parameter and buffer names. - Parameters and buffers set to ``None`` are not included. + def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True, strict: bool = True): + r""" + Args: + strict (bool): whether to reture the whole model state + as the original pytorch state_dict() Returns: dict: @@ -346,7 +347,31 @@ def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: >>> module.state_dict().keys() ['bias', 'weight'] + """ + if strict: + return get_static_torch_model(zero_ddp_model=self, device=get_current_device(), + only_rank_0=only_rank_0).state_dict(destination=destination, + prefix=prefix, + keep_vars=keep_vars) + return self._non_strict_state_dict(destination=destination, + prefix=prefix, + keep_vars=keep_vars, + only_rank_0=only_rank_0) + + def _non_strict_state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True): + r"""Returns a dictionary containing a whole state of the module. + + Both parameters and persistent buffers (e.g. running averages) are + included. Keys are corresponding parameter and buffer names. + Parameters and buffers set to ``None`` are not included. + Warning: The non strict state dict would ignore the parameters if the + tensors of the parameters are shared with other parameters which + have been included in the dictionary. 
+ + Returns: + dict: + a dictionary containing a whole state of the module """ if destination is None: destination = OrderedDict() diff --git a/colossalai/nn/parallel/utils.py b/colossalai/nn/parallel/utils.py index 1205cbc3a658..988f978254a1 100644 --- a/colossalai/nn/parallel/utils.py +++ b/colossalai/nn/parallel/utils.py @@ -60,17 +60,17 @@ def _get_shallow_copy_model(model: nn.Module): return name_to_module[''] -def get_static_torch_model(gemini_ddp_model, +def get_static_torch_model(zero_ddp_model, device=torch.device("cpu"), dtype=torch.float32, only_rank_0=True) -> torch.nn.Module: - """Get a static torch.nn.Module model from the given GeminiDDP module. - You should notice that the original GeminiDDP model is not modified. + """Get a static torch.nn.Module model from the given ZeroDDP module. + You should notice that the original ZeroDDP model is not modified. Thus, you can use the original model in further training. But you should not use the returned torch model to train, this can cause unexpected errors. 
Args: - gemini_ddp_model (GeminiDDP): a gemini ddp model + zero_ddp_model (ZeroDDP): a zero ddp model device (torch.device): the device of the final torch model dtype (torch.dtype): the dtype of the final torch model only_rank_0 (bool): if True, only rank0 has the coverted torch model @@ -78,11 +78,11 @@ def get_static_torch_model(gemini_ddp_model, Returns: torch.nn.Module: a static torch model used for saving checkpoints or numeric checks """ - from colossalai.nn.parallel import GeminiDDP - assert isinstance(gemini_ddp_model, GeminiDDP) + from colossalai.nn.parallel import ZeroDDP + assert isinstance(zero_ddp_model, ZeroDDP) - state_dict = gemini_ddp_model.state_dict(only_rank_0=only_rank_0) - colo_model = gemini_ddp_model.module + state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0, strict=False) + colo_model = zero_ddp_model.module torch_model = _get_shallow_copy_model(colo_model) if not only_rank_0 or dist.get_rank() == 0: From 3abbaf8bc68c8a3366241a3dc2e97f6944605fb2 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 9 Jan 2023 14:53:04 +0800 Subject: [PATCH 105/209] update codegen test --- .../test_autochunk/test_autochunk_codegen.py | 23 +++++++++++++------ 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/tests/test_autochunk/test_autochunk_codegen.py b/tests/test_autochunk/test_autochunk_codegen.py index 8246275eb08a..c91148e11ff8 100644 --- a/tests/test_autochunk/test_autochunk_codegen.py +++ b/tests/test_autochunk/test_autochunk_codegen.py @@ -1,3 +1,5 @@ +from functools import partial + import pytest import torch import torch.fx @@ -46,7 +48,7 @@ def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair): ) -def _test_autochunk_codegen(rank): +def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly colossalai.launch( config={}, @@ -59,8 +61,6 @@ def _test_autochunk_codegen(rank): # build model and input model = 
evoformer_base().cuda() - msa_len = 32 - pair_len = 64 node = torch.randn(1, msa_len, pair_len, 256).cuda() pair = torch.randn(1, pair_len, pair_len, 128).cuda() @@ -85,7 +85,7 @@ def _test_autochunk_codegen(rank): MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0") ) - codegen = AutoChunkCodeGen(gm_prop) + codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory) graph.set_codegen(codegen) gm = ColoGraphModule(model, graph) gm.recompile() @@ -99,9 +99,18 @@ def _test_autochunk_codegen(rank): gpc.destroy() -def test_autochunk_codegen(): - mp.spawn(_test_autochunk_codegen, nprocs=1) +@pytest.mark.parametrize("max_memory", [None, 20, 24, 28, 32]) +@pytest.mark.parametrize("msa_len", [32]) +@pytest.mark.parametrize("pair_len", [64]) +def test_autochunk_codegen(msa_len, pair_len, max_memory): + run_func = partial( + _test_autochunk_codegen, + msa_len=msa_len, + pair_len=pair_len, + max_memory=max_memory, + ) + mp.spawn(run_func, nprocs=1) if __name__ == "__main__": - _test_autochunk_codegen(0) + _test_autochunk_codegen(0, 32, 64, None) From a005965d2d5f506aafe672575388501bfc5dc5d8 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 9 Jan 2023 14:57:47 +0800 Subject: [PATCH 106/209] update codegen test --- tests/test_autochunk/test_autochunk_codegen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_autochunk/test_autochunk_codegen.py b/tests/test_autochunk/test_autochunk_codegen.py index c91148e11ff8..62763a6d5e2a 100644 --- a/tests/test_autochunk/test_autochunk_codegen.py +++ b/tests/test_autochunk/test_autochunk_codegen.py @@ -99,7 +99,7 @@ def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory): gpc.destroy() -@pytest.mark.parametrize("max_memory", [None, 20, 24, 28, 32]) +@pytest.mark.parametrize("max_memory", [None, 20, 25, 30]) @pytest.mark.parametrize("msa_len", [32]) @pytest.mark.parametrize("pair_len", [64]) def test_autochunk_codegen(msa_len, pair_len, max_memory): From 
d106b271f8fa8968bfa7a5f7652448c41f26c260 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 9 Jan 2023 15:19:08 +0800 Subject: [PATCH 107/209] add chunk search test --- tests/test_autochunk/test_autochunk_search.py | 86 +++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 tests/test_autochunk/test_autochunk_search.py diff --git a/tests/test_autochunk/test_autochunk_search.py b/tests/test_autochunk/test_autochunk_search.py new file mode 100644 index 000000000000..c824a43ab612 --- /dev/null +++ b/tests/test_autochunk/test_autochunk_search.py @@ -0,0 +1,86 @@ +from functools import partial + +import pytest +import torch +import torch.fx +import torch.multiprocessing as mp + +import colossalai +from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen +from colossalai.core import global_context as gpc +from colossalai.fx import ColoTracer +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.fx.profiler import MetaTensor +from colossalai.utils import free_port +from tests.test_autochunk.evoformer.evoformer import evoformer_base + + +def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len): + found_regions = [i["region"] for i in chunk_infos] + + if msa_len == 32 and pair_len == 64: + if max_memory is None: + target_regions = [(142, 154), (366, 373), (233, 283), (301, 351), (127, 134), (204, 228), (167, 191), (161, 166), (198, 203), (6, 69)] + elif max_memory == 20: + target_regions = [(142, 154), (369, 373), (233, 269), (301, 351)] + elif max_memory == 25: + target_regions = [(144, 154), (369, 370)] + elif max_memory == 30: + target_regions = [(144, 154)] + else: + raise NotImplementedError() + else: + raise NotImplementedError() + + assert len(found_regions) == len(target_regions), "len of found regions %s doesn't equal len of target regions %s" % (str(found_regions), str(target_regions)) + for region in target_regions: + assert region in 
found_regions, "region:%s not in found regions for msa:%d, pair:%d, maxmem:%d" % (str(region), msa_len, pair_len, max_memory) + for region in found_regions: + assert region in target_regions, "region:%s should not be found for msa:%d, pair:%d, maxmem:%d" % (str(region), msa_len, pair_len, max_memory) + + +def _test_autochunk_search(rank, msa_len, pair_len, max_memory): + # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly + colossalai.launch( + config={}, + rank=rank, + world_size=1, + host="localhost", + port=free_port(), + backend="nccl", + ) + + # build model and input + model = evoformer_base().cuda() + node = torch.randn(1, msa_len, pair_len, 256).cuda() + pair = torch.randn(1, pair_len, pair_len, 128).cuda() + + gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace + interp = MetaInfoProp(gm_prop) + interp.propagate( + MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0") + ) + + codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory) + chunk_infos = codegen.chunk_infos + assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len) + + gpc.destroy() + + +@pytest.mark.parametrize("max_memory", [None, 20, 25, 30]) +@pytest.mark.parametrize("msa_len", [32]) +@pytest.mark.parametrize("pair_len", [64]) +def test_autochunk_search(msa_len, pair_len, max_memory): + run_func = partial( + _test_autochunk_search, + msa_len=msa_len, + pair_len=pair_len, + max_memory=max_memory, + ) + mp.spawn(run_func, nprocs=1) + + +if __name__ == "__main__": + _test_autochunk_search(0, 32, 64, 20) From d5c4f0bf954a5686777f652e34b5cd18df2a0d5a Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 9 Jan 2023 15:22:09 +0800 Subject: [PATCH 108/209] code style --- tests/test_autochunk/test_autochunk_search.py | 29 +++++++++++++++---- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/tests/test_autochunk/test_autochunk_search.py b/tests/test_autochunk/test_autochunk_search.py index 
c824a43ab612..6f7214633fa3 100644 --- a/tests/test_autochunk/test_autochunk_search.py +++ b/tests/test_autochunk/test_autochunk_search.py @@ -8,8 +8,6 @@ import colossalai from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen from colossalai.core import global_context as gpc -from colossalai.fx import ColoTracer -from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.profiler import MetaTensor from colossalai.utils import free_port @@ -32,12 +30,31 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len): raise NotImplementedError() else: raise NotImplementedError() - - assert len(found_regions) == len(target_regions), "len of found regions %s doesn't equal len of target regions %s" % (str(found_regions), str(target_regions)) + + assert len(found_regions) == len( + target_regions + ), "len of found regions %s doesn't equal len of target regions %s" % ( + str(found_regions), + str(target_regions), + ) for region in target_regions: - assert region in found_regions, "region:%s not in found regions for msa:%d, pair:%d, maxmem:%d" % (str(region), msa_len, pair_len, max_memory) + assert ( + region in found_regions + ), "region:%s not in found regions for msa:%d, pair:%d, maxmem:%d" % ( + str(region), + msa_len, + pair_len, + max_memory, + ) for region in found_regions: - assert region in target_regions, "region:%s should not be found for msa:%d, pair:%d, maxmem:%d" % (str(region), msa_len, pair_len, max_memory) + assert ( + region in target_regions + ), "region:%s should not be found for msa:%d, pair:%d, maxmem:%d" % ( + str(region), + msa_len, + pair_len, + max_memory, + ) def _test_autochunk_search(rank, msa_len, pair_len, max_memory): From aafc3516a5c07347f58bbc1a52410f74e51b685f Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 9 Jan 2023 15:32:19 +0800 Subject: [PATCH 109/209] add available --- tests/test_autochunk/test_autochunk_codegen.py | 2 ++ 
tests/test_autochunk/test_autochunk_search.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tests/test_autochunk/test_autochunk_codegen.py b/tests/test_autochunk/test_autochunk_codegen.py index 62763a6d5e2a..c4f5cda67204 100644 --- a/tests/test_autochunk/test_autochunk_codegen.py +++ b/tests/test_autochunk/test_autochunk_codegen.py @@ -9,6 +9,7 @@ from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen from colossalai.core import global_context as gpc from colossalai.fx import ColoTracer +from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.profiler import MetaTensor @@ -99,6 +100,7 @@ def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory): gpc.destroy() +@pytest.mark.skipif(not CODEGEN_AVAILABLE, reason='torch version is lower than 1.12.0') @pytest.mark.parametrize("max_memory", [None, 20, 25, 30]) @pytest.mark.parametrize("msa_len", [32]) @pytest.mark.parametrize("pair_len", [64]) diff --git a/tests/test_autochunk/test_autochunk_search.py b/tests/test_autochunk/test_autochunk_search.py index 6f7214633fa3..5026c3ad3b3d 100644 --- a/tests/test_autochunk/test_autochunk_search.py +++ b/tests/test_autochunk/test_autochunk_search.py @@ -8,6 +8,7 @@ import colossalai from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen from colossalai.core import global_context as gpc +from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.profiler import MetaTensor from colossalai.utils import free_port @@ -86,6 +87,7 @@ def _test_autochunk_search(rank, msa_len, pair_len, max_memory): gpc.destroy() +@pytest.mark.skipif(not CODEGEN_AVAILABLE, reason="torch version is lower than 1.12.0") @pytest.mark.parametrize("max_memory", [None, 20, 25, 30]) 
@pytest.mark.parametrize("msa_len", [32]) @pytest.mark.parametrize("pair_len", [64]) From 498b5ca993fb17eccdfbe7608f36444d5779f0c8 Mon Sep 17 00:00:00 2001 From: HELSON Date: Mon, 9 Jan 2023 15:52:17 +0800 Subject: [PATCH 110/209] [hotfix] fix gpt gemini example (#2404) * [hotfix] fix gpt gemini example * [example] add new assertions --- .../language/gpt/gemini/benchmark_gemini.sh | 30 ++++++++++--------- .../language/gpt/gemini/train_gpt_demo.py | 2 ++ 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/examples/language/gpt/gemini/benchmark_gemini.sh b/examples/language/gpt/gemini/benchmark_gemini.sh index 13086666eefd..464ea03da7eb 100644 --- a/examples/language/gpt/gemini/benchmark_gemini.sh +++ b/examples/language/gpt/gemini/benchmark_gemini.sh @@ -1,18 +1,20 @@ for MODEL_TYPE in "gpt2_medium"; do - for BATCH_SIZE in 16; do - for GPUNUM in 1 2 4 8; do - for TPDEGREE in 1 2 4 8; do - if [ ${TPDEGREE} -gt ${GPUNUM} ]; then - continue - fi - for PLACEMENT in "cpu" "auto"; do - echo "****************** Begin ***************************" - echo "* benchmrking MODEL_TYPE ${MODEL_TYPE} BS ${BATCH_SIZE} BS ${BS} GPUNUM ${GPUNUM} TPDEGREE ${TPDEGREE} PLACEMENT ${PLACEMENT}" - MODEL_TYPE=${MODEL_TYPE} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE} PLACEMENT=${PLACEMENT} \ - bash ./gemini/run_gemini.sh - echo "****************** Finished ***************************" - echo "" - echo "" + for DISPAN in "colossalai"; do + for BATCH_SIZE in 16; do + for GPUNUM in 1 2 4 8; do + for TPDEGREE in 1 2 4 8; do + if [ ${TPDEGREE} -gt ${GPUNUM} ]; then + continue + fi + for PLACEMENT in "cpu" "auto"; do + echo "****************** Begin ***************************" + echo "+ benchmrking MODEL ${MODEL_TYPE} DISPAN ${DISPAN} GPU ${GPUNUM} BS ${BATCH_SIZE} TP ${TPDEGREE} POLICY ${PLACEMENT}" + MODEL_TYPE=${MODEL_TYPE} DISPAN=${DISPAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE} PLACEMENT=${PLACEMENT} \ + bash ./run_gemini.sh + echo 
"****************** Finished ***************************" + echo "" + echo "" + done done done done diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py index 29f8c8ef1215..891b1de15af1 100644 --- a/examples/language/gpt/gemini/train_gpt_demo.py +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -270,6 +270,7 @@ def main(): tp_pg = ProcessGroup(tp_degree=args.tp_degree) # Tensor Parallelism (TP) + # You should notice that v0.1.10 is not compatible with TP degree > 1 tensor_parallelize(model, tp_pg) # build a Gemini model and a highly optimized cpu optimizer @@ -278,6 +279,7 @@ def main(): logger.info(get_mem_info(prefix='After init optim, '), ranks=[0]) else: + assert args.tp_degree == 1, "The degree of TP should be 1 for DDP examples." model = model_builder(args.model_type)(checkpoint=True).cuda() if args.distplan.startswith("torch"): From 19cc64b1d39529bde502f9507d20770430f6e3af Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 9 Jan 2023 16:06:58 +0800 Subject: [PATCH 111/209] remove autochunk_available --- colossalai/autochunk/autochunk_codegen.py | 490 +++++++++++----------- 1 file changed, 239 insertions(+), 251 deletions(-) diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py index 0db2e59080dd..9ec59477b426 100644 --- a/colossalai/autochunk/autochunk_codegen.py +++ b/colossalai/autochunk/autochunk_codegen.py @@ -16,13 +16,9 @@ from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg import colossalai - from .search_chunk import SearchChunk from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape -CODEGEN_AVAILABLE = True -__all__ = ["AutoChunkCodeGen"] - def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape): new_shape = "[" @@ -222,287 +218,279 @@ def emit_code_with_chunk( node_idx += 1 -if CODEGEN_AVAILABLE: - - class AutoChunkCodeGen(CodeGen): - def __init__(self, meta_graph, max_memory=None, 
print_mem=False): - super().__init__() - self.meta_graph = meta_graph - self.max_memory = max_memory - self.meta_node = list(meta_graph.graph.nodes) - # find the chunk regions - self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem) - self.chunk_infos = self.search_chunk.search_region() +class AutoChunkCodeGen(CodeGen): + def __init__(self, meta_graph, max_memory=None, print_mem=False): + super().__init__() + self.meta_graph = meta_graph + self.max_memory = max_memory + self.meta_node = list(meta_graph.graph.nodes) + # find the chunk regions + self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem) + self.chunk_infos = self.search_chunk.search_region() - def _gen_python_code( - self, nodes, root_module: str, namespace: _Namespace - ) -> PythonCode: - free_vars: List[str] = [] - body: List[str] = [] - globals_: Dict[str, Any] = {} - wrapped_fns: Dict[str, None] = {} + def _gen_python_code( + self, nodes, root_module: str, namespace: _Namespace + ) -> PythonCode: + free_vars: List[str] = [] + body: List[str] = [] + globals_: Dict[str, Any] = {} + wrapped_fns: Dict[str, None] = {} - # Wrap string in list to pass by reference - maybe_return_annotation: List[str] = [""] + # Wrap string in list to pass by reference + maybe_return_annotation: List[str] = [""] - def add_global(name_hint: str, obj: Any): - """Add an obj to be tracked as a global. + def add_global(name_hint: str, obj: Any): + """Add an obj to be tracked as a global. - We call this for names that reference objects external to the - Graph, like functions or types. + We call this for names that reference objects external to the + Graph, like functions or types. - Returns: the global name that should be used to reference 'obj' in generated source. - """ - if ( - _is_from_torch(obj) and obj != torch.device - ): # to support registering torch.device - # HACK: workaround for how torch custom ops are registered. 
We - # can't import them like normal modules so they must retain their - # fully qualified name. - return _get_qualified_name(obj) - - # normalize the name hint to get a proper identifier - global_name = namespace.create_name(name_hint, obj) - - if global_name in globals_: - assert globals_[global_name] is obj - return global_name - globals_[global_name] = obj + Returns: the global name that should be used to reference 'obj' in generated source. + """ + if ( + _is_from_torch(obj) and obj != torch.device + ): # to support registering torch.device + # HACK: workaround for how torch custom ops are registered. We + # can't import them like normal modules so they must retain their + # fully qualified name. + return _get_qualified_name(obj) + + # normalize the name hint to get a proper identifier + global_name = namespace.create_name(name_hint, obj) + + if global_name in globals_: + assert globals_[global_name] is obj return global_name + globals_[global_name] = obj + return global_name - # set _custom_builtins here so that we needn't import colossalai in forward - _custom_builtins["colossalai"] = _CustomBuiltin( - "import colossalai", colossalai - ) - - # Pre-fill the globals table with registered builtins. - for name, (_, obj) in _custom_builtins.items(): - add_global(name, obj) + # set _custom_builtins here so that we needn't import colossalai in forward + _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai) - def type_repr(o: Any): - if o == (): - # Empty tuple is used for empty tuple type annotation Tuple[()] - return "()" + # Pre-fill the globals table with registered builtins. + for name, (_, obj) in _custom_builtins.items(): + add_global(name, obj) - typename = _type_repr(o) + def type_repr(o: Any): + if o == (): + # Empty tuple is used for empty tuple type annotation Tuple[()] + return "()" - if hasattr(o, "__origin__"): - # This is a generic type, e.g. 
typing.List[torch.Tensor] - origin_type = _origin_type_map.get(o.__origin__, o.__origin__) - origin_typename = add_global(_type_repr(origin_type), origin_type) + typename = _type_repr(o) - if hasattr(o, "__args__"): - # Assign global names for each of the inner type variables. - args = [type_repr(arg) for arg in o.__args__] + if hasattr(o, "__origin__"): + # This is a generic type, e.g. typing.List[torch.Tensor] + origin_type = _origin_type_map.get(o.__origin__, o.__origin__) + origin_typename = add_global(_type_repr(origin_type), origin_type) - if len(args) == 0: - # Bare type, such as `typing.Tuple` with no subscript - # This code-path used in Python < 3.9 - return origin_typename + if hasattr(o, "__args__"): + # Assign global names for each of the inner type variables. + args = [type_repr(arg) for arg in o.__args__] - return f'{origin_typename}[{",".join(args)}]' - else: + if len(args) == 0: # Bare type, such as `typing.Tuple` with no subscript - # This code-path used in Python 3.9+ + # This code-path used in Python < 3.9 return origin_typename - # Common case: this is a regular module name like 'foo.bar.baz' - return add_global(typename, o) - - def _format_args( - args: Tuple[Argument, ...], kwargs: Dict[str, Argument] - ) -> str: - def _get_repr(arg): - # Handle NamedTuples (if it has `_fields`) via add_global. - if isinstance(arg, tuple) and hasattr(arg, "_fields"): - qualified_name = _get_qualified_name(type(arg)) - global_name = add_global(qualified_name, type(arg)) - return f"{global_name}{repr(tuple(arg))}" - return repr(arg) - - args_s = ", ".join(_get_repr(a) for a in args) - kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items()) - if args_s and kwargs_s: - return f"{args_s}, {kwargs_s}" - return args_s or kwargs_s - - # Run through reverse nodes and record the first instance of a use - # of a given node. 
This represents the *last* use of the node in the - # execution order of the program, which we will use to free unused - # values - node_to_last_use: Dict[Node, Node] = {} - user_to_last_uses: Dict[Node, List[Node]] = {} - - def register_last_uses(n: Node, user: Node): - if n not in node_to_last_use: - node_to_last_use[n] = user - user_to_last_uses.setdefault(user, []).append(n) - - for node in reversed(nodes): - map_arg(node.args, lambda n: register_last_uses(n, node)) - map_arg(node.kwargs, lambda n: register_last_uses(n, node)) - - delete_free_var_from_last_use(user_to_last_uses) - - # NOTE: we add a variable to distinguish body and ckpt_func - def delete_unused_values(user: Node, body, to_keep=[]): - """ - Delete values after their last use. This ensures that values that are - not used in the remainder of the code are freed and the memory usage - of the code is optimal. - """ - if user.op == "placeholder": - return - if user.op == "output": - body.append("\n") - return - nodes_to_delete = user_to_last_uses.get(user, []) - nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep] - if len(nodes_to_delete): - to_delete_str = " = ".join( - [repr(n) for n in nodes_to_delete] + ["None"] - ) - body.append(f"; {to_delete_str}\n") + return f'{origin_typename}[{",".join(args)}]' else: - body.append("\n") + # Bare type, such as `typing.Tuple` with no subscript + # This code-path used in Python 3.9+ + return origin_typename + + # Common case: this is a regular module name like 'foo.bar.baz' + return add_global(typename, o) + + def _format_args( + args: Tuple[Argument, ...], kwargs: Dict[str, Argument] + ) -> str: + def _get_repr(arg): + # Handle NamedTuples (if it has `_fields`) via add_global. 
+ if isinstance(arg, tuple) and hasattr(arg, "_fields"): + qualified_name = _get_qualified_name(type(arg)) + global_name = add_global(qualified_name, type(arg)) + return f"{global_name}{repr(tuple(arg))}" + return repr(arg) + + args_s = ", ".join(_get_repr(a) for a in args) + kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items()) + if args_s and kwargs_s: + return f"{args_s}, {kwargs_s}" + return args_s or kwargs_s + + # Run through reverse nodes and record the first instance of a use + # of a given node. This represents the *last* use of the node in the + # execution order of the program, which we will use to free unused + # values + node_to_last_use: Dict[Node, Node] = {} + user_to_last_uses: Dict[Node, List[Node]] = {} + + def register_last_uses(n: Node, user: Node): + if n not in node_to_last_use: + node_to_last_use[n] = user + user_to_last_uses.setdefault(user, []).append(n) + + for node in reversed(nodes): + map_arg(node.args, lambda n: register_last_uses(n, node)) + map_arg(node.kwargs, lambda n: register_last_uses(n, node)) + + delete_free_var_from_last_use(user_to_last_uses) + + # NOTE: we add a variable to distinguish body and ckpt_func + def delete_unused_values(user: Node, body, to_keep=[]): + """ + Delete values after their last use. This ensures that values that are + not used in the remainder of the code are freed and the memory usage + of the code is optimal. 
+ """ + if user.op == "placeholder": + return + if user.op == "output": + body.append("\n") + return + nodes_to_delete = user_to_last_uses.get(user, []) + nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep] + if len(nodes_to_delete): + to_delete_str = " = ".join( + [repr(n) for n in nodes_to_delete] + ["None"] + ) + body.append(f"; {to_delete_str}\n") + else: + body.append("\n") - # NOTE: we add a variable to distinguish body and ckpt_func - def emit_node(node: Node, body): - maybe_type_annotation = ( - "" if node.type is None else f" : {type_repr(node.type)}" + # NOTE: we add a variable to distinguish body and ckpt_func + def emit_node(node: Node, body): + maybe_type_annotation = ( + "" if node.type is None else f" : {type_repr(node.type)}" + ) + if node.op == "placeholder": + assert isinstance(node.target, str) + maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}" + free_vars.append( + f"{node.target}{maybe_type_annotation}{maybe_default_arg}" ) - if node.op == "placeholder": - assert isinstance(node.target, str) - maybe_default_arg = ( - "" if not node.args else f" = {repr(node.args[0])}" - ) - free_vars.append( - f"{node.target}{maybe_type_annotation}{maybe_default_arg}" - ) - raw_name = node.target.replace("*", "") - if raw_name != repr(node): - body.append(f"{repr(node)} = {raw_name}\n") - return - elif node.op == "call_method": - assert isinstance(node.target, str) - body.append( - f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}" - f"({_format_args(node.args[1:], node.kwargs)})" - ) - return - elif node.op == "call_function": - assert callable(node.target) - # pretty print operators - if ( - node.target.__module__ == "_operator" - and node.target.__name__ in magic_methods - ): - assert isinstance(node.args, tuple) - body.append( - f"{repr(node)}{maybe_type_annotation} = " - f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}" - ) - return - - # pretty 
print inplace operators; required for jit.script to work properly - # not currently supported in normal FX graphs, but generated by torchdynamo - if ( - node.target.__module__ == "_operator" - and node.target.__name__ in inplace_methods - ): - body.append( - f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; " - f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}" - ) - return - - qualified_name = _get_qualified_name(node.target) - global_name = add_global(qualified_name, node.target) - # special case for getattr: node.args could be 2-argument or 3-argument - # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value - if ( - global_name == "getattr" - and isinstance(node.args, tuple) - and isinstance(node.args[1], str) - and node.args[1].isidentifier() - and len(node.args) == 2 - ): - body.append( - f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}" - ) - return + raw_name = node.target.replace("*", "") + if raw_name != repr(node): + body.append(f"{repr(node)} = {raw_name}\n") + return + elif node.op == "call_method": + assert isinstance(node.target, str) + body.append( + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}" + f"({_format_args(node.args[1:], node.kwargs)})" + ) + return + elif node.op == "call_function": + assert callable(node.target) + # pretty print operators + if ( + node.target.__module__ == "_operator" + and node.target.__name__ in magic_methods + ): + assert isinstance(node.args, tuple) body.append( - f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})" + f"{repr(node)}{maybe_type_annotation} = " + f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}" ) - if node.meta.get("is_wrapped", False): - wrapped_fns.setdefault(global_name) return - elif node.op == "call_module": - assert isinstance(node.target, str) + + # pretty 
print inplace operators; required for jit.script to work properly + # not currently supported in normal FX graphs, but generated by torchdynamo + if ( + node.target.__module__ == "_operator" + and node.target.__name__ in inplace_methods + ): body.append( - f"{repr(node)}{maybe_type_annotation} = " - f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})" + f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; " + f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}" ) return - elif node.op == "get_attr": - assert isinstance(node.target, str) + + qualified_name = _get_qualified_name(node.target) + global_name = add_global(qualified_name, node.target) + # special case for getattr: node.args could be 2-argument or 3-argument + # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value + if ( + global_name == "getattr" + and isinstance(node.args, tuple) + and isinstance(node.args[1], str) + and node.args[1].isidentifier() + and len(node.args) == 2 + ): body.append( - f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}" + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}" ) return - elif node.op == "output": - if node.type is not None: - maybe_return_annotation[0] = f" -> {type_repr(node.type)}" - body.append(self.generate_output(node.args[0])) - return - raise NotImplementedError(f"node: {node.op} {node.target}") - - # Modified for activation checkpointing - ckpt_func = [] - - # if any node has a list of labels for activation_checkpoint, we - # will use nested type of activation checkpoint codegen - emit_code_with_chunk( - body, - nodes, - emit_node, - delete_unused_values, - self.search_chunk, - self.chunk_infos, - ) - - if len(body) == 0: - # If the Graph has no non-placeholder nodes, no lines for the body - # have been emitted. 
To continue to have valid Python code, emit a - # single pass statement - body.append("pass\n") - - if len(wrapped_fns) > 0: - wrap_name = add_global("wrap", torch.fx.wrap) - wrap_stmts = "\n".join( - [f'{wrap_name}("{name}")' for name in wrapped_fns] + body.append( + f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})" ) - else: - wrap_stmts = "" + if node.meta.get("is_wrapped", False): + wrapped_fns.setdefault(global_name) + return + elif node.op == "call_module": + assert isinstance(node.target, str) + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})" + ) + return + elif node.op == "get_attr": + assert isinstance(node.target, str) + body.append( + f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}" + ) + return + elif node.op == "output": + if node.type is not None: + maybe_return_annotation[0] = f" -> {type_repr(node.type)}" + body.append(self.generate_output(node.args[0])) + return + raise NotImplementedError(f"node: {node.op} {node.target}") + + # Modified for activation checkpointing + ckpt_func = [] + + # if any node has a list of labels for activation_checkpoint, we + # will use nested type of activation checkpoint codegen + emit_code_with_chunk( + body, + nodes, + emit_node, + delete_unused_values, + self.search_chunk, + self.chunk_infos, + ) + + if len(body) == 0: + # If the Graph has no non-placeholder nodes, no lines for the body + # have been emitted. 
To continue to have valid Python code, emit a + # single pass statement + body.append("pass\n") + + if len(wrapped_fns) > 0: + wrap_name = add_global("wrap", torch.fx.wrap) + wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns]) + else: + wrap_stmts = "" - if self._body_transformer: - body = self._body_transformer(body) + if self._body_transformer: + body = self._body_transformer(body) - for name, value in self.additional_globals(): - add_global(name, value) + for name, value in self.additional_globals(): + add_global(name, value) - # as we need colossalai.utils.checkpoint, we need to import colossalai - # in forward function - prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) - prologue = "".join(ckpt_func) + prologue - prologue = prologue + # as we need colossalai.utils.checkpoint, we need to import colossalai + # in forward function + prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) + prologue = "".join(ckpt_func) + prologue + prologue = prologue - code = "".join(body) - code = "\n".join(" " + line for line in code.split("\n")) - fn_code = f""" + code = "".join(body) + code = "\n".join(" " + line for line in code.split("\n")) + fn_code = f""" {wrap_stmts} {prologue} {code}""" - # print(fn_code) - return PythonCode(fn_code, globals_) + # print(fn_code) + return PythonCode(fn_code, globals_) From d3f5ce9efb35bf9e292aa041a3e98b737cbb68ee Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Mon, 9 Jan 2023 16:21:44 +0800 Subject: [PATCH 112/209] [workflow] added nightly release to pypi (#2403) --- .github/workflows/release_nightly.yml | 86 +++++++-------------------- setup.py | 30 ++++++++-- 2 files changed, 45 insertions(+), 71 deletions(-) diff --git a/.github/workflows/release_nightly.yml b/.github/workflows/release_nightly.yml index 6bc000d1f4f6..8aa48b8ed89e 100644 --- a/.github/workflows/release_nightly.yml +++ b/.github/workflows/release_nightly.yml @@ -1,73 +1,29 @@ -name: Release bdist wheel for Nightly 
versions +name: Publish Nightly Version to PyPI on: - schedule: - # run at 00:00 of every Sunday - - cron: '0 0 * * 6' workflow_dispatch: + schedule: + - cron: '0 0 * * 6' # release on every Sunday 00:00 UTC time jobs: - matrix_preparation: - name: Prepare Container List + build-n-publish: + if: github.event_name == 'workflow_dispatch' || github.repository == 'hpcaitech/ColossalAI' + name: Build and publish Python 🐍 distributions 📦 to PyPI runs-on: ubuntu-latest - outputs: - matrix: ${{ steps.set-matrix.outputs.matrix }} + timeout-minutes: 20 steps: - - id: set-matrix - run: | - matrix="[\"hpcaitech/cuda-conda:11.3\", \"hpcaitech/cuda-conda:10.2\"]" - echo $matrix - echo "::set-output name=matrix::{\"container\":$(echo $matrix)}" + - uses: actions/checkout@v2 - build: - name: Release bdist wheels - needs: matrix_preparation - if: github.repository == 'hpcaitech/ColossalAI' && contains(fromJson('["FrankLeeeee", "ver217", "feifeibear", "kurisusnowdeng"]'), github.actor) - runs-on: [self-hosted, gpu] - strategy: - fail-fast: false - matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} - container: - image: ${{ matrix.container }} - options: --gpus all --rm - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - # cub is for cuda 10.2 - - name: Copy scripts and checkout - run: | - cp -r ./.github/workflows/scripts/* ./ - ln -s /github/home/pip_wheels ./pip_wheels - wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip - unzip 1.8.0.zip - - name: Build bdist wheel - run: | - pip install beautifulsoup4 requests packaging - python ./build_colossalai_wheel.py --nightly - - name: 🚀 Deploy - uses: garygrossgarten/github-action-scp@release - with: - local: all_dist - remote: ${{ secrets.PRIVATE_PYPI_NIGHTLY_DIR }} - host: ${{ secrets.PRIVATE_PYPI_HOST }} - username: ${{ secrets.PRIVATE_PYPI_USER }} - password: ${{ secrets.PRIVATE_PYPI_PASSWD }} - remove_old_build: - name: Remove old nightly build - runs-on: ubuntu-latest - needs: build - steps: - 
- name: executing remote ssh commands using password - uses: appleboy/ssh-action@master - env: - BUILD_DIR: ${{ secrets.PRIVATE_PYPI_NIGHTLY_DIR }} - with: - host: ${{ secrets.PRIVATE_PYPI_HOST }} - username: ${{ secrets.PRIVATE_PYPI_USER }} - password: ${{ secrets.PRIVATE_PYPI_PASSWD }} - envs: BUILD_DIR - script: | - cd $BUILD_DIR - find . -type f -mtime +0 -exec rm -f {} + - script_stop: true + - uses: actions/setup-python@v2 + with: + python-version: '3.8.14' + + - run: NIGHTLY=1 python setup.py sdist build + + # publish to PyPI if executed on the main branch + - name: Publish package to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + user: __token__ + password: ${{ secrets.PYPI_API_TOKEN }} + verbose: true diff --git a/setup.py b/setup.py index 38d5fa91cecd..5128b80e880d 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,6 @@ import os import re +from datetime import datetime from setuptools import find_packages, setup @@ -20,18 +21,22 @@ TORCH_AVAILABLE = False CUDA_HOME = None - # ninja build does not work unless include_dirs are abs path this_dir = os.path.dirname(os.path.abspath(__file__)) build_cuda_ext = False ext_modules = [] +is_nightly = int(os.environ.get('NIGHTLY', '0')) == 1 if int(os.environ.get('CUDA_EXT', '0')) == 1: if not TORCH_AVAILABLE: - raise ModuleNotFoundError("PyTorch is not found while CUDA_EXT=1. You need to install PyTorch first in order to build CUDA extensions") + raise ModuleNotFoundError( + "PyTorch is not found while CUDA_EXT=1. You need to install PyTorch first in order to build CUDA extensions" + ) if not CUDA_HOME: - raise RuntimeError("CUDA_HOME is not found while CUDA_EXT=1. You need to export CUDA_HOME environment vairable or install CUDA Toolkit first in order to build CUDA extensions") + raise RuntimeError( + "CUDA_HOME is not found while CUDA_EXT=1. 
You need to export CUDA_HOME environment vairable or install CUDA Toolkit first in order to build CUDA extensions" + ) build_cuda_ext = True @@ -139,8 +144,16 @@ def get_version(): print(f'===== Building Extension {name} =====') ext_modules.append(builder_cls().builder()) -setup(name='colossalai', - version=get_version(), +if is_nightly: + # use date as the nightly version + version = datetime.today().strftime('%Y.%m.%d') + package_name = 'colossalai-nightly' +else: + version = get_version() + package_name = 'colossalai' + +setup(name=package_name, + version=version, packages=find_packages(exclude=( 'benchmark', 'docker', @@ -179,4 +192,9 @@ def get_version(): 'Topic :: Scientific/Engineering :: Artificial Intelligence', 'Topic :: System :: Distributed Computing', ], - package_data={'colossalai': ['_C/*.pyi', 'kernel/cuda_native/csrc/*', 'kernel/cuda_native/csrc/kernel/*', 'kernel/cuda_native/csrc/kernels/include/*']}) + package_data={ + 'colossalai': [ + '_C/*.pyi', 'kernel/cuda_native/csrc/*', 'kernel/cuda_native/csrc/kernel/*', + 'kernel/cuda_native/csrc/kernels/include/*' + ] + }) From 212b5b1b5f4f3debf983d8c47c58af507a554be4 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 9 Jan 2023 16:29:33 +0800 Subject: [PATCH 113/209] add comments --- colossalai/autochunk/autochunk_codegen.py | 35 +++++++++++-------- .../test_autochunk/test_autochunk_codegen.py | 2 +- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py index 9ec59477b426..5ef560ac209a 100644 --- a/colossalai/autochunk/autochunk_codegen.py +++ b/colossalai/autochunk/autochunk_codegen.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable, List, Tuple +from typing import Any, Dict, Iterable, List, Tuple import torch from torch.fx.graph import ( @@ -128,37 +128,42 @@ def _replace_input_var(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, bod def emit_code_with_chunk( - body, - nodes, + body: 
List[str], + nodes: Iterable[Node], emit_node_func, delete_unused_value_func, search_chunk: SearchChunk, - chunk_infos, + chunk_infos: List, ): - """Emit code with nested activation checkpoint - When we detect some of the node.activation_checkpoint is a List, we will use - this function to emit the activation checkpoint codes. + """ + Emit code with chunk according to chunk_infos. + + It will generate a for loop in chunk regions, and replace inputs + and outputs of regions with chunked variables. Args: body: forward code - ckpt_func: checkpoint functions code nodes: graph.nodes emit_node_func: function to emit node delete_unused_value_func: function to remove the unused value + search_chunk: the class to search all chunks + chunk_infos: store all information about all chunks. """ node_list = list(nodes) - chunk_regions = [i["region"] for i in chunk_infos] - chunk_starts = [i[0] for i in chunk_regions] - chunk_ends = [i[1] for i in chunk_regions] + # chunk region + chunk_starts = [i["region"][0] for i in chunk_infos] + chunk_ends = [i["region"][1] for i in chunk_infos] - chunk_inputs = [i["inputs"] for i in chunk_infos] - chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] - chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] + # chunk inputs + chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk + chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] # input without chunk + chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [ j.name for i in chunk_inputs_non_chunk for j in i ] + # chunk outputs chunk_outputs = [i["outputs"][0] for i in chunk_infos] chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos] @@ -170,6 +175,7 @@ def emit_code_with_chunk( while node_idx < len(node_list): node = node_list[node_idx] + # if is chunk start, generate for loop start if node_idx in chunk_starts: within_chunk_region = True 
region_idx = chunk_starts.index(node_idx) @@ -203,6 +209,7 @@ def emit_code_with_chunk( if node_idx not in chunk_inputs: delete_unused_value_func(node, body, chunk_inputs_names) + # generate chunk region end if node_idx in chunk_ends: body.append( _gen_loop_end( diff --git a/tests/test_autochunk/test_autochunk_codegen.py b/tests/test_autochunk/test_autochunk_codegen.py index c4f5cda67204..53f62077c07a 100644 --- a/tests/test_autochunk/test_autochunk_codegen.py +++ b/tests/test_autochunk/test_autochunk_codegen.py @@ -115,4 +115,4 @@ def test_autochunk_codegen(msa_len, pair_len, max_memory): if __name__ == "__main__": - _test_autochunk_codegen(0, 32, 64, None) + _test_autochunk_codegen(0, 32, 64, 25) From 1951f7fa87725b6cc719226d26e5734958adffac Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 9 Jan 2023 16:30:16 +0800 Subject: [PATCH 114/209] code style --- colossalai/autochunk/autochunk_codegen.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py index 5ef560ac209a..cc39e391e4be 100644 --- a/colossalai/autochunk/autochunk_codegen.py +++ b/colossalai/autochunk/autochunk_codegen.py @@ -137,9 +137,9 @@ def emit_code_with_chunk( ): """ Emit code with chunk according to chunk_infos. - - It will generate a for loop in chunk regions, and replace inputs - and outputs of regions with chunked variables. + + It will generate a for loop in chunk regions, and + replace inputs and outputs of regions with chunked variables. 
Args: body: forward code @@ -156,9 +156,11 @@ def emit_code_with_chunk( chunk_ends = [i["region"][1] for i in chunk_infos] # chunk inputs - chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk - chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] # input without chunk - chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim + chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk + chunk_inputs_non_chunk = [ + i["inputs_non_chunk"] for i in chunk_infos + ] # input without chunk + chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [ j.name for i in chunk_inputs_non_chunk for j in i ] From a68d240ed56dcd62a0726621c50233f733e79367 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 9 Jan 2023 16:54:08 +0800 Subject: [PATCH 115/209] add doc for search chunk --- colossalai/autochunk/search_chunk.py | 76 ++++++++++++++++++++-------- 1 file changed, 55 insertions(+), 21 deletions(-) diff --git a/colossalai/autochunk/search_chunk.py b/colossalai/autochunk/search_chunk.py index 21b967497f1b..613c28454df3 100644 --- a/colossalai/autochunk/search_chunk.py +++ b/colossalai/autochunk/search_chunk.py @@ -1,4 +1,7 @@ import copy +from typing import Any, Dict, Iterable, List, Tuple + +from torch.fx.node import Node from .estimate_memory import EstimateMemory from .reorder_graph import ReorderGraph @@ -13,6 +16,34 @@ class SearchChunk(object): + """ + This is the core class for AutoChunk. + + It defines the framework of the strategy of AutoChunk. + Chunks will be selected one by one utill search stops. + + The chunk search is as follows: + 1. find the peak memory node + 2. find the max chunk region according to the peak memory node + 3. find all possible chunk regions in the max chunk region + 4. find the best chunk region for current status + 5. 
goto 1 + + Attributes: + gm: graph model + print_mem (bool): print estimated memory + trace_index: trace the flow of every dim of every node to find all free dims + trace_flow: determine the region chunk strategy + reorder_graph: reorder nodes to improve chunk efficiency + estimate_memory: estimate memory with chunk + select_chunk: select the best chunk region + + Args: + gm: graph model + max_memory (int): max memory in MB + print_mem (bool): print estimated memory + """ + def __init__(self, gm, max_memory=None, print_mem=False) -> None: self.gm = gm self.print_mem = print_mem @@ -33,24 +64,37 @@ def _find_peak_node(self, mem_peak): max_idx = mem_peak.index(max_value) return max_idx - def _get_free_var(self): + def _get_free_var_idx(self) -> List: + """ + Get free var index + + Returns: + free_var_idx (List): all indexs of free vars + """ free_var_idx = [] for idx, n in enumerate(self.trace_index.node_list): if n.op == "placeholder": free_var_idx.append(idx) return free_var_idx - def _get_min_free_var(self, active_node_list, free_vars): - min_len = 999 - for idx, n in enumerate(active_node_list): - if idx in free_vars: - continue - if len(n) < min_len: - min_len = len(n) - return min_len + def _search_max_chunk_region( + self, active_node: List, peak_node: Node, chunk_regions: List + ) -> Tuple: + """ + Search max chunk region according to peak memory node + + Chunk region starts extending from the peak node, stops where free var num is min - def _search_max_chunk_region(self, active_node, peak_node, chunk_regions): - free_vars = self._get_free_var() + Args: + active_node (List): active node status for every node + peak_node (Node): peak memory node + chunk_regions (List): chunk region info + + Returns: + chunk_region_start (int) + chunk_region_end (int) + """ + free_vars = self._get_free_var_idx() free_var_num = len(free_vars) active_node_num = [len(i) for i in active_node] min_active_node_num = min(active_node_num[free_var_num:]) @@ -92,16 +136,6 @@ def 
_search_max_chunk_region(self, active_node, peak_node, chunk_regions): chunk_region_end = region[0] - 1 return chunk_region_start, chunk_region_end - def _is_not_compute(self, trace, chunk_range, dim_idx): - if trace["idx"][dim_idx] not in trace["compute"]: - return True - if trace["idx"][dim_idx] in trace["compute"] and all( - i < chunk_range[0] or i > chunk_range[1] - for i in trace["compute"][trace["idx"][dim_idx]] - ): - return True - return False - def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx): start_traces = input_trace[start_idx] end_trace = output_trace[end_idx] From 85e045b063a70cd36ccc0405acc245d86f2a1621 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Mon, 9 Jan 2023 17:08:55 +0800 Subject: [PATCH 116/209] [doc] updated readme regarding pypi installation (#2406) --- README-zh-Hans.md | 46 ++++++++++++++++++++++++++++++++++------------ README.md | 28 ++++++++++++++++++++++++---- 2 files changed, 58 insertions(+), 16 deletions(-) diff --git a/README-zh-Hans.md b/README-zh-Hans.md index 8edcff28bf04..b97b02f5ab84 100644 --- a/README-zh-Hans.md +++ b/README-zh-Hans.md @@ -5,10 +5,10 @@ Colossal-AI: 一个面向大模型时代的通用深度学习系统 -

论文 | - 文档 | - 例程 | - 论坛 | +

论文 | + 文档 | + 例程 | + 论坛 | 博客

[![Build](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml/badge.svg)](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml) @@ -35,7 +35,7 @@
  • 为何选择 Colossal-AI
  • 特点
  • - 并行训练样例展示 + 并行训练样例展示
  • - 单GPU训练样例展示 + 单GPU训练样例展示
  • - 推理 (Energon-AI) 样例展示 + 推理 (Energon-AI) 样例展示
  • - Colossal-AI 成功案例 + Colossal-AI 成功案例