From 84734b11747a2ecc4cc9f11860efd176b52e9dd5 Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Sat, 20 Aug 2022 16:38:31 +0800 Subject: [PATCH 01/13] [fx] fix defining ckpt functions inside forward --- .../codegen/activation_checkpoint_codegen.py | 47 ++++++++++++------- .../test_activation_checkpoint_codegen.py | 15 +++--- 2 files changed, 37 insertions(+), 25 deletions(-) diff --git a/colossalai/fx/codegen/activation_checkpoint_codegen.py b/colossalai/fx/codegen/activation_checkpoint_codegen.py index 53eb46529113..065b824d84e1 100644 --- a/colossalai/fx/codegen/activation_checkpoint_codegen.py +++ b/colossalai/fx/codegen/activation_checkpoint_codegen.py @@ -89,7 +89,7 @@ def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str: """ Generate the checkpoint function definition """ - return f"def checkpoint_{label}({', '.join(free_vars)}):" + return f"def checkpoint_{label}({', '.join(['self'] + free_vars)}):" def _gen_ckpt_output(output_vars: List[str]) -> str: @@ -105,10 +105,10 @@ def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reen """ outputs = ', '.join(output_vars) inputs = ', '.join(input_vars) - return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})' + return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_{label}, {activation_offload}, self, {inputs}, use_reentrant={use_reentrant})' -def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unused_value_func): +def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func): # find the activation checkpoint regions ckpt_regions = _find_ckpt_regions(nodes) start_idx = [item[0] for item in ckpt_regions] @@ -133,19 +133,19 @@ def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unu if idx in start_idx: label = start_idx.index(idx) ckpt_fn_def = _gen_ckpt_fn_def(label, input_vars[label]) - body.append(f'{ckpt_fn_def}\n') + ckpt_func.append(f'{ckpt_fn_def}\n') within_ckpt_region = True # NOTE: emit_node does not emit a string with newline. It depends # on delete_unused_values to append one - emit_node_func(node) - - # add indentation to the emmited node + # NOTE: currently we separate body and ckpt_func definition if within_ckpt_region: - body[-1] = ' ' + body[-1] - - # delete unused values - delete_unused_value_func(node) + emit_node_func(node, ckpt_func) + ckpt_func[-1] = ' ' + ckpt_func[-1] + delete_unused_value_func(node, ckpt_func) + else: + emit_node_func(node, body) + delete_unused_value_func(node, body) if idx in end_idx: # if this is the last node of the ckpt region @@ -153,7 +153,7 @@ def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unu label = end_idx.index(idx) return_statement = _gen_ckpt_output(output_vars[label]) return_statement = f' {return_statement}\n' - body.append(return_statement) + ckpt_func.append(return_statement) # we need to check if the checkpoint need to offload the input start_node_idx = start_idx[label] @@ -287,7 +287,8 @@ def register_last_uses(n: Node, user: Node): map_arg(node.args, lambda n: register_last_uses(n, node)) map_arg(node.kwargs, lambda n: register_last_uses(n, node)) - def delete_unused_values(user: Node): + # NOTE: we add a variable to distinguish body and ckpt_func + def delete_unused_values(user: Node, body): """ Delete values after their last use. This ensures that values that are not used in the remainder of the code are freed and the memory usage @@ -305,7 +306,8 @@ def delete_unused_values(user: Node): else: body.append('\n') - def emit_node(node: Node): + # NOTE: we add a variable to distinguish body and ckpt_func + def emit_node(node: Node, body): maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' if node.op == 'placeholder': assert isinstance(node.target, str) @@ -371,7 +373,8 @@ def emit_node(node: Node): raise NotImplementedError(f'node: {node.op} {node.target}') # Modified for activation checkpointing - emit_code_with_activation_checkpoint(body, nodes, emit_node, delete_unused_values) + ckpt_func = [] + emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values) if len(body) == 0: # If the Graph has no non-placeholder nodes, no lines for the body @@ -395,6 +398,7 @@ def emit_node(node: Node): # in forward function # TODO: Remove inline import prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) + prologue = ''.join(ckpt_func) + prologue prologue = prologue + "\n import colossalai" code = ''.join(body) @@ -404,6 +408,7 @@ def emit_node(node: Node): {prologue} {code}""" + print(fn_code) return PythonCode(fn_code, globals_) else: @@ -484,7 +489,8 @@ def register_last_uses(n: Node, user: Node): map_arg(node.args, lambda n: register_last_uses(n, node)) map_arg(node.kwargs, lambda n: register_last_uses(n, node)) - def delete_unused_values(user: Node): + # NOTE: we add a variable to distinguish body and ckpt_func + def delete_unused_values(user: Node, body): """ Delete values after their last use. This ensures that values that are not used in the remainder of the code are freed and the memory usage @@ -502,7 +508,8 @@ def delete_unused_values(user: Node): else: body.append('\n') - def emit_node(node: Node): + # NOTE: we add a variable to distinguish body and ckpt_func + def emit_node(node: Node, body): maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' if node.op == 'placeholder': assert isinstance(node.target, str) @@ -562,7 +569,8 @@ def emit_node(node: Node): raise NotImplementedError(f'node: {node.op} {node.target}') # Modified for activation checkpointing - emit_code_with_activation_checkpoint(body, self.nodes, emit_node, delete_unused_values) + ckpt_func = [] + emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values) if len(body) == 0: # If the Graph has no non-placeholder nodes, no lines for the body @@ -587,6 +595,8 @@ def emit_node(node: Node): else: wrap_stmts = '' + ckpt_func = ''.join(ckpt_func) + # If the original function didn't have self as its first argument, we # would have added it. if len(orig_args) == 0 or orig_args[0] != 'self': @@ -600,6 +610,7 @@ def emit_node(node: Node): fn_code = f""" {wrap_stmts} +{ckpt_func} def forward({', '.join(orig_args)}){maybe_return_annotation[0]}: import colossalai {code}""" diff --git a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py index 9c1bc57a3973..ee609a9539d8 100644 --- a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py +++ b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py @@ -46,7 +46,7 @@ def __init__(self): super().__init__() self.mlp1 = MLP() self.relu = relu() - self.linear3 = torch.nn.Linear(4, 4) + self.linear2 = torch.nn.Linear(4, 4) def forward(self, x): y1, y2 = checkpoint(self.mlp1, x) @@ -56,6 +56,7 @@ def ckpt2(x): return F.relu(x, inplace=True) y4 = checkpoint(ckpt2, x) + y4 = self.linear2(y4) return y1 + y2 + y3 + y4 @@ -97,9 +98,9 @@ def _run_act_ckpt_codegen(rank): # assert checkpoint function will be generated and # the offload option is correct code = graph.python_code('self').src - assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, True, x, use_reentrant=True)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, False, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_2, False, x, use_reentrant=False)' in code + assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, True, self, x, use_reentrant=True)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, False, self, x, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_2, False, self, x, use_reentrant=False)' in code # recompile and verify the outputs are consistent fx_out = gm(data) @@ -150,9 +151,9 @@ def _run_act_ckpt_python_code_torch11(rank): # assert checkpoint function will be generated and # the offload option is correct code = graph.python_code('self').src - assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, True, x, use_reentrant=True)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, False, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_2, False, x, use_reentrant=False)' in code + assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, True, self, x, use_reentrant=True)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, False, self, x, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_2, False, self, x, use_reentrant=False)' in code # recompile and verify the outputs are consistent fx_out = gm(data) From e96b8d4449aa282d30ab1308ac469cce5239bb65 Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Mon, 22 Aug 2022 14:31:23 +0800 Subject: [PATCH 02/13] [fx] Modify activation checkpoint codegen and add ColoGraphModule --- .../codegen/activation_checkpoint_codegen.py | 15 +- colossalai/fx/graph_module.py | 140 ++++++++++++++++++ .../test_ckpt_torchvision.py | 10 +- .../test_activation_checkpoint_codegen.py | 19 ++- 4 files changed, 164 insertions(+), 20 deletions(-) create mode 100644 colossalai/fx/graph_module.py diff --git a/colossalai/fx/codegen/activation_checkpoint_codegen.py b/colossalai/fx/codegen/activation_checkpoint_codegen.py index 065b824d84e1..749c42d1512f 100644 --- a/colossalai/fx/codegen/activation_checkpoint_codegen.py +++ b/colossalai/fx/codegen/activation_checkpoint_codegen.py @@ -1,14 +1,19 @@ +import colossalai import torch from typing import List, Callable, Any, Tuple, Dict try: from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name - from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods + from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin CODEGEN_AVAILABLE = True + # set _custom_builtins here so that we needn't import colossalai in forward + _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai) except: - from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args + from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args, _CustomBuiltin from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name CODEGEN_AVAILABLE = False + # set _custom_builtins here so that we needn't import colossalai in forward + _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai) if CODEGEN_AVAILABLE: __all__ = ['ActivationCheckpointCodeGen'] @@ -105,7 +110,7 @@ def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reen """ outputs = ', '.join(output_vars) inputs = ', '.join(input_vars) - return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_{label}, {activation_offload}, self, {inputs}, use_reentrant={use_reentrant})' + return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})' def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func): @@ -399,7 +404,7 @@ def emit_node(node: Node, body): # TODO: Remove inline import prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) prologue = ''.join(ckpt_func) + prologue - prologue = prologue + "\n import colossalai" + prologue = prologue code = ''.join(body) code = '\n'.join(' ' + line for line in code.split('\n')) @@ -408,7 +413,6 @@ def emit_node(node: Node, body): {prologue} {code}""" - print(fn_code) return PythonCode(fn_code, globals_) else: @@ -612,6 +616,5 @@ def emit_node(node: Node, body): {ckpt_func} def forward({', '.join(orig_args)}){maybe_return_annotation[0]}: - import colossalai {code}""" return PythonCode(fn_code, globals_) diff --git a/colossalai/fx/graph_module.py b/colossalai/fx/graph_module.py new file mode 100644 index 000000000000..5382f7fd4491 --- /dev/null +++ b/colossalai/fx/graph_module.py @@ -0,0 +1,140 @@ +import os +import warnings +import torch +import torch.nn as nn +from torch.nn.modules.module import _addindent +from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _WrappedCall, _exec_with_source, _forward_from_src +from typing import Type, Dict, List, Any, Union, Optional, Set +from torch.fx.graph import Graph, _PyTreeCodeGen, _is_from_torch, _custom_builtins, PythonCode +from pathlib import Path + + +class ColoGraphModule(GraphModule): + + def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'): + super().__init__(root, graph, class_name) + + def bind(self, ckpt_def, globals): + """Bind checkpoint functions to ColoGraphModule + We need to bind our checkpoint functions to the GraphModule so + that we could correctly use self.checkpoint for GraphModule forward + """ + ckpt_code = "\n".join(ckpt_def) + globals_copy = globals.copy() + _exec_with_source(ckpt_code, globals_copy) + func_list = [func for func in globals_copy.keys() if "checkpoint" in func] + for func in func_list: + tmp_func = globals_copy[func] + setattr(self, func, tmp_func.__get__(self, self.__class__)) + del globals_copy[func] + + def recompile(self) -> PythonCode: + """ + Recompile this GraphModule from its ``graph`` attribute. This should be + called after editing the contained ``graph``, otherwise the generated + code of this ``GraphModule`` will be out of date. + """ + if isinstance(self._graph._codegen, _PyTreeCodeGen): + self._in_spec = self._graph._codegen.pytree_info.in_spec + self._out_spec = self._graph._codegen.pytree_info.out_spec + python_code = self._graph.python_code(root_module='self') + self._code = python_code.src + + # To split ckpt functions code and forward code + _code_list = self._code.split("\n") + _fwd_def = [item for item in _code_list if "def forward" in item][0] + _fwd_idx = _code_list.index(_fwd_def) + ckpt_def = _code_list[:_fwd_idx] + self._code = "\n".join(_code_list[_fwd_idx:]) + + self.bind(ckpt_def, python_code.globals) + + cls = type(self) + cls.forward = _forward_from_src(self._code, python_code.globals) + + # Determine whether this class explicitly defines a __call__ implementation + # to wrap. If it does, save it in order to have wrapped_call invoke it. + # If it does not, wrapped_call can use a dynamic call to super() instead. + # In most cases, super().__call__ should be torch.nn.Module.__call__. + # We do not want to hold a reference to Module.__call__ here; doing so will + # bypass patching of torch.nn.Module.__call__ done while symbolic tracing. + cls_call = cls.__call__ if "__call__" in vars(cls) else None + + if '_wrapped_call' not in vars(cls): + cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined] + + def call_wrapped(self, *args, **kwargs): + return self._wrapped_call(self, *args, **kwargs) + + cls.__call__ = call_wrapped + + # reset self._code to original src, otherwise to_folder will be wrong + self._code = python_code.src + return python_code + + def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModule"): + """Dumps out module to ``folder`` with ``module_name`` so that it can be + imported with ``from import `` + + Args: + + folder (Union[str, os.PathLike]): The folder to write the code out to + + module_name (str): Top-level name to use for the ``Module`` while + writing out the code + """ + folder = Path(folder) + Path(folder).mkdir(exist_ok=True) + torch.save(self.state_dict(), folder / 'state_dict.pt') + tab = " " * 4 + + # we add import colossalai here + model_str = f""" +import torch +from torch.nn import * +import colossalai +class {module_name}(torch.nn.Module): + def __init__(self): + super().__init__( +""" + + def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: + safe_reprs = [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d] + if type(module) in safe_reprs: + return f"{module.__repr__()}" + else: + return None + + blobified_modules = [] + for module_name, module in self.named_children(): + module_str = _gen_model_repr(module_name, module) + if module_str is None: + module_file = folder / f'{module_name}.pt' + torch.save(module, module_file) + blobified_modules.append(module_name) + module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ') + module_str = f"torch.load(r'{module_file}') # {module_repr}" + model_str += f"{tab*2}self.{module_name} = {module_str}\n" + + for buffer_name, buffer in self._buffers.items(): + if buffer is None: + continue + model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n" + + for param_name, param in self._parameters.items(): + if param is None: + continue + model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n" + + model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n" + model_str += f"{_addindent(self.code, 4)}\n" + + module_file = folder / 'module.py' + module_file.write_text(model_str) + + init_file = folder / '__init__.py' + init_file.write_text('from .module import *') + + if len(blobified_modules) > 0: + warnings.warn("Was not able to save the following children modules as reprs -" + f"saved as pickled files instead: {blobified_modules}") diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py index e57fa5f12921..a2c2a6e71f2a 100644 --- a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py +++ b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py @@ -7,6 +7,7 @@ from torch.fx import GraphModule import colossalai from colossalai.fx import ColoTracer +from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.passes.algorithms import chen_greedy from colossalai.utils import free_port @@ -72,7 +73,7 @@ def _run_ckpt_solver(rank): for model_cls in MODEL_LIST: m = model_cls(num_classes=5) graph = tracer.trace(root=m) - gm = GraphModule(copy.deepcopy(m), graph, m.__class__.__name__) + gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__) MetaInfoProp(gm).run(data) codegen = ActivationCheckpointCodeGen() gm.graph.set_codegen(codegen) @@ -102,7 +103,7 @@ def _run_ckpt_solver_torch11(rank): for model_cls in MODEL_LIST: m = model_cls(num_classes=5) graph = tracer.trace(root=m) - gm = GraphModule(copy.deepcopy(m), graph, m.__class__.__name__) + gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__) MetaInfoProp(gm).run(data) gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph) gm = solver(gm) @@ -119,5 +120,6 @@ def test_ckpt_solver_torch11(): if __name__ == '__main__': - test_ckpt_solver() - test_ckpt_solver_torch11() + _run_ckpt_solver(rank=0) + # test_ckpt_solver() + # test_ckpt_solver_torch11() diff --git a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py index ee609a9539d8..d9e70e3183a7 100644 --- a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py +++ b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py @@ -9,6 +9,7 @@ import colossalai from colossalai.utils import free_port from colossalai.core import global_context as gpc +from colossalai.fx.graph_module import ColoGraphModule try: from colossalai.fx.codegen import ActivationCheckpointCodeGen @@ -92,15 +93,15 @@ def _run_act_ckpt_codegen(rank): if node.name in offload_starts: setattr(node, 'activation_offload', True) - gm = GraphModule(model, graph) + gm = ColoGraphModule(model, graph) gm.recompile() # assert checkpoint function will be generated and # the offload option is correct code = graph.python_code('self').src - assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, True, self, x, use_reentrant=True)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, False, self, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_2, False, self, x, use_reentrant=False)' in code + assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=True)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, x, use_reentrant=False)' in code # recompile and verify the outputs are consistent fx_out = gm(data) @@ -151,9 +152,9 @@ def _run_act_ckpt_python_code_torch11(rank): # assert checkpoint function will be generated and # the offload option is correct code = graph.python_code('self').src - assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, True, self, x, use_reentrant=True)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, False, self, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_2, False, self, x, use_reentrant=False)' in code + assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=True)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, x, use_reentrant=False)' in code # recompile and verify the outputs are consistent fx_out = gm(data) @@ -168,6 +169,4 @@ def test_act_ckpt_python_code_torch11(): if __name__ == '__main__': - - test_act_ckpt_codegen() - test_act_ckpt_python_code_torch11() + _run_act_ckpt_codegen(rank=0) From 08b4690b64987019d226f589b3e68cdec4ce47fd Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Mon, 22 Aug 2022 15:29:02 +0800 Subject: [PATCH 03/13] [fx] some modification --- .../codegen/activation_checkpoint_codegen.py | 10 +- colossalai/fx/graph_module.py | 283 ++++++++++-------- .../test_ckpt_torchvision.py | 1 + .../test_activation_checkpoint_codegen.py | 3 +- 4 files changed, 159 insertions(+), 138 deletions(-) diff --git a/colossalai/fx/codegen/activation_checkpoint_codegen.py b/colossalai/fx/codegen/activation_checkpoint_codegen.py index 749c42d1512f..eaa69535b693 100644 --- a/colossalai/fx/codegen/activation_checkpoint_codegen.py +++ b/colossalai/fx/codegen/activation_checkpoint_codegen.py @@ -6,14 +6,10 @@ from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin CODEGEN_AVAILABLE = True - # set _custom_builtins here so that we needn't import colossalai in forward - _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai) except: from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args, _CustomBuiltin from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name CODEGEN_AVAILABLE = False - # set _custom_builtins here so that we needn't import colossalai in forward - _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai) if CODEGEN_AVAILABLE: __all__ = ['ActivationCheckpointCodeGen'] @@ -226,6 +222,9 @@ def add_global(name_hint: str, obj: Any): globals_[global_name] = obj return global_name + # set _custom_builtins here so that we needn't import colossalai in forward + _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai) + # Pre-fill the globals table with registered builtins. for name, (_, obj) in _custom_builtins.items(): add_global(name, obj) @@ -453,6 +452,9 @@ def add_global(name_hint: str, obj: Any): globals_[global_name] = obj return global_name + # set _custom_builtins here so that we needn't import colossalai in forward + _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai) + # Pre-fill the globals table with registered builtins. for name, (_, obj) in _custom_builtins.items(): add_global(name, obj) diff --git a/colossalai/fx/graph_module.py b/colossalai/fx/graph_module.py index 5382f7fd4491..06604b39e00f 100644 --- a/colossalai/fx/graph_module.py +++ b/colossalai/fx/graph_module.py @@ -3,138 +3,155 @@ import torch import torch.nn as nn from torch.nn.modules.module import _addindent -from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _WrappedCall, _exec_with_source, _forward_from_src from typing import Type, Dict, List, Any, Union, Optional, Set -from torch.fx.graph import Graph, _PyTreeCodeGen, _is_from_torch, _custom_builtins, PythonCode from pathlib import Path - - -class ColoGraphModule(GraphModule): - - def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'): - super().__init__(root, graph, class_name) - - def bind(self, ckpt_def, globals): - """Bind checkpoint functions to ColoGraphModule - We need to bind our checkpoint functions to the GraphModule so - that we could correctly use self.checkpoint for GraphModule forward - """ - ckpt_code = "\n".join(ckpt_def) - globals_copy = globals.copy() - _exec_with_source(ckpt_code, globals_copy) - func_list = [func for func in globals_copy.keys() if "checkpoint" in func] - for func in func_list: - tmp_func = globals_copy[func] - setattr(self, func, tmp_func.__get__(self, self.__class__)) - del globals_copy[func] - - def recompile(self) -> PythonCode: - """ - Recompile this GraphModule from its ``graph`` attribute. This should be - called after editing the contained ``graph``, otherwise the generated - code of this ``GraphModule`` will be out of date. - """ - if isinstance(self._graph._codegen, _PyTreeCodeGen): - self._in_spec = self._graph._codegen.pytree_info.in_spec - self._out_spec = self._graph._codegen.pytree_info.out_spec - python_code = self._graph.python_code(root_module='self') - self._code = python_code.src - - # To split ckpt functions code and forward code - _code_list = self._code.split("\n") - _fwd_def = [item for item in _code_list if "def forward" in item][0] - _fwd_idx = _code_list.index(_fwd_def) - ckpt_def = _code_list[:_fwd_idx] - self._code = "\n".join(_code_list[_fwd_idx:]) - - self.bind(ckpt_def, python_code.globals) - - cls = type(self) - cls.forward = _forward_from_src(self._code, python_code.globals) - - # Determine whether this class explicitly defines a __call__ implementation - # to wrap. If it does, save it in order to have wrapped_call invoke it. - # If it does not, wrapped_call can use a dynamic call to super() instead. - # In most cases, super().__call__ should be torch.nn.Module.__call__. - # We do not want to hold a reference to Module.__call__ here; doing so will - # bypass patching of torch.nn.Module.__call__ done while symbolic tracing. - cls_call = cls.__call__ if "__call__" in vars(cls) else None - - if '_wrapped_call' not in vars(cls): - cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined] - - def call_wrapped(self, *args, **kwargs): - return self._wrapped_call(self, *args, **kwargs) - - cls.__call__ = call_wrapped - - # reset self._code to original src, otherwise to_folder will be wrong - self._code = python_code.src - return python_code - - def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModule"): - """Dumps out module to ``folder`` with ``module_name`` so that it can be - imported with ``from import `` - - Args: - - folder (Union[str, os.PathLike]): The folder to write the code out to - - module_name (str): Top-level name to use for the ``Module`` while - writing out the code - """ - folder = Path(folder) - Path(folder).mkdir(exist_ok=True) - torch.save(self.state_dict(), folder / 'state_dict.pt') - tab = " " * 4 - - # we add import colossalai here - model_str = f""" -import torch -from torch.nn import * -import colossalai -class {module_name}(torch.nn.Module): - def __init__(self): - super().__init__( -""" - - def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: - safe_reprs = [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d] - if type(module) in safe_reprs: - return f"{module.__repr__()}" - else: - return None - - blobified_modules = [] - for module_name, module in self.named_children(): - module_str = _gen_model_repr(module_name, module) - if module_str is None: - module_file = folder / f'{module_name}.pt' - torch.save(module, module_file) - blobified_modules.append(module_name) - module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ') - module_str = f"torch.load(r'{module_file}') # {module_repr}" - model_str += f"{tab*2}self.{module_name} = {module_str}\n" - - for buffer_name, buffer in self._buffers.items(): - if buffer is None: - continue - model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n" - - for param_name, param in self._parameters.items(): - if param is None: - continue - model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n" - - model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n" - model_str += f"{_addindent(self.code, 4)}\n" - - module_file = folder / 'module.py' - module_file.write_text(model_str) - - init_file = folder / '__init__.py' - init_file.write_text('from .module import *') - - if len(blobified_modules) > 0: - warnings.warn("Was not able to save the following children modules as reprs -" - f"saved as pickled files instead: {blobified_modules}") +try: + from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _WrappedCall, _exec_with_source, _forward_from_src + from torch.fx.graph import Graph, _PyTreeCodeGen, _is_from_torch, _custom_builtins, PythonCode + COLOGM = True +except: + from torch.fx.graph_module import GraphModule + from torch.fx.graph import Graph + pass + COLOGM = False + +if COLOGM: + + class ColoGraphModule(GraphModule): + + def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'): + super().__init__(root, graph, class_name) + + def bind(self, ckpt_def, globals): + """Bind checkpoint functions to ColoGraphModule + We need to bind our checkpoint functions to the GraphModule so + that we could correctly use self.checkpoint for GraphModule forward + """ + ckpt_code = "\n".join(ckpt_def) + globals_copy = globals.copy() + _exec_with_source(ckpt_code, globals_copy) + func_list = [func for func in globals_copy.keys() if "checkpoint" in func] + for func in func_list: + tmp_func = globals_copy[func] + setattr(self, func, tmp_func.__get__(self, self.__class__)) + del globals_copy[func] + + def recompile(self) -> PythonCode: + """ + Recompile this GraphModule from its ``graph`` attribute. This should be + called after editing the contained ``graph``, otherwise the generated + code of this ``GraphModule`` will be out of date. + """ + if isinstance(self._graph._codegen, _PyTreeCodeGen): + self._in_spec = self._graph._codegen.pytree_info.in_spec + self._out_spec = self._graph._codegen.pytree_info.out_spec + python_code = self._graph.python_code(root_module='self') + self._code = python_code.src + + # To split ckpt functions code and forward code + _code_list = self._code.split("\n") + _fwd_def = [item for item in _code_list if "def forward" in item][0] + _fwd_idx = _code_list.index(_fwd_def) + ckpt_def = _code_list[:_fwd_idx] + self._code = "\n".join(_code_list[_fwd_idx:]) + + self.bind(ckpt_def, python_code.globals) + + cls = type(self) + cls.forward = _forward_from_src(self._code, python_code.globals) + + # Determine whether this class explicitly defines a __call__ implementation + # to wrap. If it does, save it in order to have wrapped_call invoke it. + # If it does not, wrapped_call can use a dynamic call to super() instead. + # In most cases, super().__call__ should be torch.nn.Module.__call__. + # We do not want to hold a reference to Module.__call__ here; doing so will + # bypass patching of torch.nn.Module.__call__ done while symbolic tracing. + cls_call = cls.__call__ if "__call__" in vars(cls) else None + + if '_wrapped_call' not in vars(cls): + cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined] + + def call_wrapped(self, *args, **kwargs): + return self._wrapped_call(self, *args, **kwargs) + + cls.__call__ = call_wrapped + + # reset self._code to original src, otherwise to_folder will be wrong + self._code = python_code.src + return python_code + + def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModule"): + """Dumps out module to ``folder`` with ``module_name`` so that it can be + imported with ``from import `` + + Args: + + folder (Union[str, os.PathLike]): The folder to write the code out to + + module_name (str): Top-level name to use for the ``Module`` while + writing out the code + """ + folder = Path(folder) + Path(folder).mkdir(exist_ok=True) + torch.save(self.state_dict(), folder / 'state_dict.pt') + tab = " " * 4 + + # we add import colossalai here + model_str = f""" + import torch + from torch.nn import * + import colossalai + class {module_name}(torch.nn.Module): + def __init__(self): + super().__init__( + """ + + def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: + safe_reprs = [ + nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d + ] + if type(module) in safe_reprs: + return f"{module.__repr__()}" + else: + return None + + blobified_modules = [] + for module_name, module in self.named_children(): + module_str = _gen_model_repr(module_name, module) + if module_str is None: + module_file = folder / f'{module_name}.pt' + torch.save(module, module_file) + blobified_modules.append(module_name) + module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ') + module_str = f"torch.load(r'{module_file}') # {module_repr}" + model_str += f"{tab*2}self.{module_name} = {module_str}\n" + + for buffer_name, buffer in self._buffers.items(): + if buffer is None: + continue + model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n" + + for param_name, param in self._parameters.items(): + if param is None: + continue + model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n" + + model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n" + model_str += f"{_addindent(self.code, 4)}\n" + + module_file = folder / 'module.py' + module_file.write_text(model_str) + + init_file = folder / '__init__.py' + init_file.write_text('from .module import *') + + if len(blobified_modules) > 0: + warnings.warn("Was not able to save the following children modules as reprs -" + f"saved as pickled files instead: {blobified_modules}") + +else: + + class ColoGraphModule(GraphModule): + + def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'): + super().__init__(root, graph, class_name) diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py index a2c2a6e71f2a..b534b84b2563 100644 --- a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py +++ b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py @@ -115,6 +115,7 @@ def _run_ckpt_solver_torch11(rank): @pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0') +@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done") def test_ckpt_solver_torch11(): mp.spawn(_run_ckpt_solver_torch11, nprocs=1) diff --git a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py index d9e70e3183a7..368222dfea02 100644 --- a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py +++ b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py @@ -147,7 +147,7 @@ def _run_act_ckpt_python_code_torch11(rank): if node.name in offload_starts: setattr(node, 'activation_offload', True) - gm = GraphModule(model, graph) + gm = ColoGraphModule(model, graph) gm.recompile() # assert checkpoint function will be generated and # the offload option is correct @@ -164,6 +164,7 @@ def _run_act_ckpt_python_code_torch11(rank): @pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0') +@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done") def test_act_ckpt_python_code_torch11(): mp.spawn(_run_act_ckpt_python_code_torch11, nprocs=1) From b85459701fa743b532671e534da7f4d4a0e877d1 Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Mon, 22 Aug 2022 15:30:45 +0800 Subject: [PATCH 04/13] some modifications --- colossalai/fx/graph_module.py | 1 - 1 file changed, 1 deletion(-) diff --git a/colossalai/fx/graph_module.py b/colossalai/fx/graph_module.py index 06604b39e00f..251635640680 100644 --- a/colossalai/fx/graph_module.py +++ b/colossalai/fx/graph_module.py @@ -12,7 +12,6 @@ except: from torch.fx.graph_module import GraphModule from torch.fx.graph import Graph - pass COLOGM = False if COLOGM: From 0d329c37d8f5248e1f2ef275beea8b52e781a3fd Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Mon, 22 Aug 2022 16:04:14 +0800 Subject: [PATCH 05/13] some modifications --- colossalai/fx/graph_module.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/colossalai/fx/graph_module.py b/colossalai/fx/graph_module.py index 251635640680..d61db5a81dba 100644 --- a/colossalai/fx/graph_module.py +++ b/colossalai/fx/graph_module.py @@ -97,12 +97,12 @@ def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModul # we add import colossalai here model_str = f""" - import torch - from torch.nn import * - import colossalai - class {module_name}(torch.nn.Module): - def __init__(self): - super().__init__( +import torch +from torch.nn import * +import colossalai +class {module_name}(torch.nn.Module): + def __init__(self): + super().__init__( """ def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: From a6205df9ac59c81799b6437e3d3b87058b828c59 Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Mon, 22 Aug 2022 16:06:23 +0800 Subject: [PATCH 06/13] some modifications --- colossalai/fx/graph_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/fx/graph_module.py b/colossalai/fx/graph_module.py index d61db5a81dba..f117d4ea1654 100644 --- a/colossalai/fx/graph_module.py +++ b/colossalai/fx/graph_module.py @@ -102,7 +102,7 @@ def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModul import colossalai class {module_name}(torch.nn.Module): def __init__(self): - super().__init__( + super().__init__() """ def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: From 38853610edb32798fbbd7473de1ca6341ebe9807 Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Mon, 22 Aug 2022 16:10:21 +0800 Subject: [PATCH 07/13] some modifications --- colossalai/fx/graph_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/fx/graph_module.py b/colossalai/fx/graph_module.py index f117d4ea1654..06f5a038adbe 100644 --- a/colossalai/fx/graph_module.py +++ b/colossalai/fx/graph_module.py @@ -103,7 +103,7 @@ def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModul class {module_name}(torch.nn.Module): def __init__(self): super().__init__() - """ +""" def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: safe_reprs = [ From 3fa5b56d1cd926058b8876a26113475837089ccd Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Mon, 22 Aug 2022 16:27:14 +0800 Subject: [PATCH 08/13] some code modifications --- colossalai/fx/codegen/activation_checkpoint_codegen.py | 2 +- colossalai/fx/graph_module.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/colossalai/fx/codegen/activation_checkpoint_codegen.py b/colossalai/fx/codegen/activation_checkpoint_codegen.py index eaa69535b693..5978dd315f0e 100644 --- a/colossalai/fx/codegen/activation_checkpoint_codegen.py +++ b/colossalai/fx/codegen/activation_checkpoint_codegen.py @@ -153,7 +153,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, # generate return statement label = end_idx.index(idx) return_statement = _gen_ckpt_output(output_vars[label]) - return_statement = f' {return_statement}\n' + return_statement = f' {return_statement}\n\n' ckpt_func.append(return_statement) # we need to check if the checkpoint need to offload the input diff --git a/colossalai/fx/graph_module.py b/colossalai/fx/graph_module.py index 06f5a038adbe..78f719852f39 100644 --- a/colossalai/fx/graph_module.py +++ b/colossalai/fx/graph_module.py @@ -100,6 +100,8 @@ def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModul import torch from torch.nn import * import colossalai + + class {module_name}(torch.nn.Module): def __init__(self): super().__init__() From 09f79853a72e0e97c3980f3b67cabb5bd5fe8cf4 Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Thu, 25 Aug 2022 10:33:20 +0800 Subject: [PATCH 09/13] [automatic_parallel] ckpt solver rotor --- colossalai/fx/passes/algorithms/__init__.py | 1 + .../fx/passes/algorithms/ckpt_solver_rotor.py | 118 +++++++++ colossalai/fx/passes/algorithms/linearize.py | 49 ++++ colossalai/fx/passes/algorithms/utils.py | 229 ++++++++++++++++++ 4 files changed, 397 insertions(+) create mode 100644 colossalai/fx/passes/algorithms/ckpt_solver_rotor.py create mode 100644 colossalai/fx/passes/algorithms/linearize.py create mode 100644 colossalai/fx/passes/algorithms/utils.py diff --git a/colossalai/fx/passes/algorithms/__init__.py b/colossalai/fx/passes/algorithms/__init__.py index bf6f9eb28017..6d85d2784001 100644 --- a/colossalai/fx/passes/algorithms/__init__.py +++ b/colossalai/fx/passes/algorithms/__init__.py @@ -1 +1,2 @@ from .ckpt_solver_chen import chen_greedy +from .linearize import linearize diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py new file mode 100644 index 000000000000..a720ea025522 --- /dev/null +++ b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py @@ -0,0 +1,118 @@ +from typing import List, Set, Tuple, Dict +import torch +from torch.fx import GraphModule, Node +import math +from .linearize import linearize +from .utils import * + + +# this is the python compute table code from rotor +# https://gitlab.inria.fr/hiepacs/rotor +def _compute_table(chain: Chain, mmax) -> Tuple: + """Returns the optimal table: a tuple containing: + Opt[m][lmin][lmax] with lmin = 0...chain.length + and lmax = lmin...chain.length (lmax is not included) and m = 0...mmax + what[m][lmin][lmax] is (True,) if the optimal choice is a chain checkpoint + (False, j) if the optimal choice is a leaf checkpoint of length j + The computation uses dynamic programming""" + + fw = chain.fweigth + [0] ## forward time + bw = chain.bweigth ## backward time, not used + cw = chain.cweigth + [0] ## size of x (and of y) + cbw = chain.cbweigth + [0] ## size of xbar + fwd_tmp = chain.fwd_tmp + [0] + bwd_tmp = chain.bwd_tmp + [0] + + # Build table + opt = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)] + what = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)] + ## Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation + + # Initialize borders of the tables for lmax-lmin = 0 + for m in range(mmax + 1): + for i in range(chain.length + 1): + #lmax-lmin = 0 + limit = max(cw[i + 1] + cbw[i + 1] + fwd_tmp[i], cw[i] + cw[i + 1] + cbw[i + 1] + bwd_tmp[i]) + if m >= limit: ## Equation (1) + opt[m][i][i] = fw[i] + bw[i] + else: + opt[m][i][i] = float("inf") + + # Compute everything + for m in range(mmax + 1): + for d in range(1, chain.length + 1): + for i in range(chain.length + 1 - d): + # for l in range(i+1, chain.length + 1): + l = i + d + mmin = cw[l + 1] + cw[i + 1] + fwd_tmp[i] + if l > i + 1: + mmin = max(mmin, cw[l + 1] + max(cw[j] + cw[j + 1] + fwd_tmp[j] for j in range(i + 1, l))) + if m < mmin: + opt[m][i][l] = float("inf") + else: + leaf_checkpoints = [(j, sum(fw[i:j]) + opt[m - cw[j]][j][l] + opt[m][i][j - 1]) + for j in range(i + 1, l + 1) + if m >= cw[j]] + if leaf_checkpoints: + best_leaf = min(leaf_checkpoints, key=lambda t: t[1]) + else: + best_leaf = None + if m >= cbw[i + 1]: + chain_checkpoint = opt[m][i][i] + opt[m - cbw[i + 1]][i + 1][l] + else: + chain_checkpoint = float("inf") + if best_leaf and best_leaf[1] <= chain_checkpoint: + opt[m][i][l] = best_leaf[1] + what[m][i][l] = (False, best_leaf[0]) + else: + opt[m][i][l] = chain_checkpoint + what[m][i][l] = (True,) + return (opt, what) + + +def _rec(chain, lmin, lmax, cmem, opt_table): + """ chain : the class describing the AC graph + lmin : index of the first forward to execute + lmax : upper bound index of the last forward to execute (not included) + cmem : number of available memory slots + Return the optimal sequence of makespan Opt_hete[cmem][lmin][lmax-lmin]""" + if cmem <= 0: + raise ValueError("Can not process a chain with negative memory {cmem}".format(cmem=cmem)) + opt, what = opt_table + sequence = Sequence(Function("Persistent", lmax - lmin, cmem)) + if opt[cmem][lmin][lmax] == float("inf"): + raise ValueError("Can not process this chain from index {lmin} to {lmax} with memory {cmem}".format(lmin=lmin, + lmax=lmax, + cmem=cmem)) + if lmin == lmax: + if lmin == chain.length: + sequence.insert(Loss()) + else: + sequence.insert(ForwardEnable(lmin)) + sequence.insert(Backward(lmin)) + return sequence + + if what[cmem][lmin][lmax][0]: + sequence.insert(ForwardEnable(lmin)) + sequence.insert_sequence(_rec(chain, lmin + 1, lmax, cmem - chain.cbweigth[lmin + 1], opt_table)) + sequence.insert(Backward(lmin)) + else: + j = what[cmem][lmin][lmax][1] + sequence.insert(ForwardCheck(lmin)) + for k in range(lmin + 1, j): + sequence.insert(ForwardNograd(k)) + sequence.insert_sequence(_rec(chain, j, lmax, cmem - chain.cweigth[j], opt_table)) + sequence.insert_sequence(_rec(chain, lmin, j - 1, cmem, opt_table)) + return sequence + + +def _construct_chain(node_dict: Dict[int, Node], mem_unit: int) -> Chain: + pass + + +def rotor(gm: GraphModule, mem_limit: int, mem_slots: int = 500) -> GraphModule: + node_dict = linearize(gm) + mem_unit = mem_limit // mem_slots + chain: Chain = _construct_chain(node_dict, mem_unit) + opt_table = _compute_table(chain, mem_limit) + sequence = _rec(chain, 0, chain.length, mem_limit - chain.cweigth[0], opt_table) diff --git a/colossalai/fx/passes/algorithms/linearize.py b/colossalai/fx/passes/algorithms/linearize.py new file mode 100644 index 000000000000..7d9b5a2bc13a --- /dev/null +++ b/colossalai/fx/passes/algorithms/linearize.py @@ -0,0 +1,49 @@ +from torch.fx import GraphModule + + +def linearize(gm: GraphModule) -> dict: + status_dict = {} + node_dict = {} + node_idx = 0 + for node in gm.graph.nodes: + last_dict_len = len(status_dict) + # remove node from users list in status_dict + for item in status_dict.values(): + if node in item: + item.remove(node) + + # pop node from status_dict if it is fully used + for key in list(status_dict): + if len(status_dict[key]) == 0: + status_dict.pop(key) + + # first node in graph + if last_dict_len == 0: + node_dict[node_idx] = [node] + status_dict[node.name] = list(node.users) + + continue + + # boundary case + if len(status_dict) == 0: + # current node region end point = next node region start point + if last_dict_len == 1: + node_idx += 1 + node_dict[node_idx] = [node] + status_dict[node.name] = list(node.users) + + continue + else: + node_dict[node_idx].append(node) + status_dict[node.name] = list(node.users) + + continue + + else: + # in-node case + node_dict[node_idx].append(node) + status_dict[node.name] = list(node.users) + + continue + + return node_dict diff --git a/colossalai/fx/passes/algorithms/utils.py b/colossalai/fx/passes/algorithms/utils.py new file mode 100644 index 000000000000..914fc91db042 --- /dev/null +++ b/colossalai/fx/passes/algorithms/utils.py @@ -0,0 +1,229 @@ +class Chain: + + def __init__(self, fw, bw, cw, cbw, ftmp, btmp, check=True): + self.fweigth = fw + self.bweigth = bw + self.cweigth = cw + self.cbweigth = cbw + self.fwd_tmp = ftmp + self.bwd_tmp = btmp + self.length = len(fw) + if check and not self.check_lengths(): + raise AttributeError("In Chain, input lists do not have consistent lengths") + + def check_lengths(self): + return ((len(self.fweigth) == self.length) and (len(self.bweigth) == self.length + 1) + and (len(self.cweigth) == self.length + 1) and (len(self.fwd_tmp) == self.length) + and (len(self.bwd_tmp) == self.length + 1) and (len(self.cbweigth) == self.length + 1)) + + def __repr__(self): + l = [] + for i in range(self.length): + l.append( + (self.fweigth[i], self.bweigth[i], self.cweigth[i], self.cbweigth[i], self.fwd_tmp[i], self.bwd_tmp[i])) + i = self.length + l.append((None, self.bweigth[i], self.cweigth[i], self.cbweigth[i], None, self.bwd_tmp[i])) + return l.__repr__() + + +class Operation: + + def shift(self, value): + if type(self.index) is tuple: + self.index = tuple(x + value for x in self.index) + else: + self.index += value + + +class Forward(Operation): + + def __init__(self, index): + self.index = index + self.name = "F" + + def __repr__(self): + return "{n}_{i}".format(n=self.name, i=self.index) + + def cost(self, chain): + if chain is not None: + return chain.fweigth[self.index] + else: + return 1 + + +class ForwardEnable(Forward): + + def __init__(self, index): + super().__init__(index) + self.name = "Fe" + + +class ForwardNograd(Forward): + + def __init__(self, index): + super().__init__(index) + self.name = "Fn" + + +class ForwardCheck(Forward): + + def __init__(self, index): + super().__init__(index) + self.name = "CF" + + +class Forwards(Operation): + + def __init__(self, start, end): + self.index = (start, end) + + def __repr__(self): + return "F_{i}->{j}".format(i=self.index[0], j=self.index[1]) + + def cost(self, chain): + if chain is not None: + return sum(chain.fweigth[self.index[0]:self.index[1] + 1]) + else: + return (self.index[1] - self.index[0] + 1) + + +def isForward(op): + return type(op) is Forward or type(op) is Forwards + + +class Backward(Operation): + + def __init__(self, index): + self.index = index + + def __repr__(self): + return "B_{i}".format(i=self.index) + + def cost(self, chain): + if chain is not None: + return chain.bweigth[self.index] + else: + return 1 + + +class Loss(Operation): + + def __init__(self): + pass + + def __repr__(self): + return "L" + + def cost(self, chain): + return 0 + + +class MemoryAccess(Operation): + + def __init__(self, index): + self.index = index + + def __repr__(self): + return "{n}_{i}".format(n=self.name, i=self.index) + + def cost(self, chain): + return 0 + + +class WriteMemory(MemoryAccess): + + def __init__(self, index): + super().__init__(index) + self.name = "WM" + + +class ReadMemory(MemoryAccess): + + def __init__(self, index): + super().__init__(index) + self.name = "RM" + + +class DiscardMemory(MemoryAccess): + + def __init__(self, index): + super().__init__(index) + self.name = "DM" + + +class Function: + + def __init__(self, name, *args): + self.name = name + self.args = args + self.str_args = ','.join(str(v) for v in self.args) + + def __repr__(self): + return "{n}({args})".format(n=self.name, args=self.str_args) + + +class Sequence: + + def __init__(self, function): + self.sequence = [] #List of Operation and Sequence + self.function = function #Description the function (name and parameters) + + def __repr__(self): + return repr(self.list_operations()) + + def list_operations(self): + l = [] + for x in self.sequence: + if isinstance(x, Operation): + l.append(x) + else: + assert isinstance(x, Sequence) + l += x.list_operations() + return l + + def insert(self, operation): + self.sequence.append(operation) + + def remove(self, operation_index): + del self.sequence[operation_index] + + def insert_sequence(self, sequence): + self.sequence.append(sequence) + + def shift(self, value): + for x in self.sequence: + x.shift(value) + return self + + def remove_useless_write(self): + if self.sequence: + if isinstance(self.sequence[0], WriteMemory): + self.remove(0) + return self + + def get_makespan(self, chain): + return sum(op.cost(chain) for op in self.list_operations()) + + def withoutSuffix(self): + ops = self.list_operations() + endOfFirstPhase = [i for i in range(len(ops)) if type(ops[i]) is Loss][0] + try: + lastIndex = max(i for i in range(endOfFirstPhase) if not type(ops[i]) is ForwardEnable) + except ValueError: + lastIndex = -1 + if lastIndex == endOfFirstPhase - 1: + return (self, None) + chainLength = ops[endOfFirstPhase - + 1].index ## Some assumption here about the sequence (finishes with Forward_L + startOfFwdEnableChain = ops[lastIndex + 1].index ## And starts with B_L), but should be fine in practice + result = Sequence(Function("Strip", self.function.name, *self.function.args, startOfFwdEnableChain)) + for i in range(lastIndex + 1): + result.insert(ops[i]) + result.insert(Loss()) + for i in range(chainLength, startOfFwdEnableChain - 1, -1): + position = endOfFirstPhase + 1 + (chainLength - i) + assert type(ops[position]) is Backward + assert ops[position].index == i + for i in range(endOfFirstPhase + 1 + 1 + chainLength - startOfFwdEnableChain, len(ops)): + result.insert(ops[i]) + return (result, startOfFwdEnableChain) From 6246e0b5160054a051776638bb93e29ca8a79997 Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Thu, 25 Aug 2022 17:50:55 +0800 Subject: [PATCH 10/13] [fx] add ckpt_solver_rotor --- colossalai/fx/__init__.py | 1 + colossalai/fx/passes/algorithms/__init__.py | 1 + .../fx/passes/algorithms/ckpt_solver_rotor.py | 91 +++++++++++++++++-- colossalai/fx/passes/algorithms/linearize.py | 50 +++++++++- .../test_ckpt_torchvision.py | 14 ++- 5 files changed, 142 insertions(+), 15 deletions(-) diff --git a/colossalai/fx/__init__.py b/colossalai/fx/__init__.py index ec6508a3040e..6513f6d03180 100644 --- a/colossalai/fx/__init__.py +++ b/colossalai/fx/__init__.py @@ -1 +1,2 @@ from .tracer import ColoTracer +from .graph_module import ColoGraphModule diff --git a/colossalai/fx/passes/algorithms/__init__.py b/colossalai/fx/passes/algorithms/__init__.py index 6d85d2784001..465fef4325e9 100644 --- a/colossalai/fx/passes/algorithms/__init__.py +++ b/colossalai/fx/passes/algorithms/__init__.py @@ -1,2 +1,3 @@ from .ckpt_solver_chen import chen_greedy from .linearize import linearize +from .ckpt_solver_rotor import solver_rotor diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py index a720ea025522..805ded872f14 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py @@ -4,6 +4,8 @@ import math from .linearize import linearize from .utils import * +from colossalai.fx.profiler import profile_function, profile_module +from colossalai.fx.passes.meta_info_prop import MetaInfoProp # this is the python compute table code from rotor @@ -106,13 +108,90 @@ def _rec(chain, lmin, lmax, cmem, opt_table): return sequence -def _construct_chain(node_dict: Dict[int, Node], mem_unit: int) -> Chain: - pass +def _discretize(mem_unit, values): + return [math.ceil(value / mem_unit) for value in values] -def rotor(gm: GraphModule, mem_limit: int, mem_slots: int = 500) -> GraphModule: +def _construct_chain(node_dict: Dict[int, Node], data: torch.Tensor, mem_unit: int) -> Chain: + + fwd_time = [] + bwd_time = [] + xbar_sizes = [data.numel() * data.element_size()] + x_sizes = [data.numel() * data.element_size()] + + # currently we can't get the temp memory needed in fwd and bwd + tmp_fwd = [0] * len(node_dict) + tmp_bwd = [0] * (len(node_dict) + 1) + + for key in node_dict.keys(): + fwd_time.append(0) + bwd_time.append(0) + xbar_sizes.append(0) + x_sizes.append(node_dict[key][-1].meta['tensor_meta'].numel * \ + torch.tensor([], dtype=node_dict[key][-1].meta['tensor_meta'].dtype).element_size()) + for node in node_dict[key]: + fwd_time[-1] += node.__flops__ + + # currently we haven't patched the backward flops count + bwd_time[-1] += node.__flops__ * 2 + + xbar_sizes[-1] += node.__activation__ + + xbar_sizes[-1] = max(xbar_sizes[-1], x_sizes[-1]) + + bwd_time.append(0) + + fwd_time = _discretize(mem_unit, fwd_time) + bwd_time = _discretize(mem_unit, bwd_time) + xbar_sizes = _discretize(mem_unit, xbar_sizes) + x_sizes = _discretize(mem_unit, x_sizes) + tmp_fwd = _discretize(mem_unit, tmp_fwd) + tmp_bwd = _discretize(mem_unit, tmp_bwd) + + return Chain(fwd_time, bwd_time, x_sizes, xbar_sizes, tmp_fwd, tmp_bwd) + + +def _annotate_from_sequence(sequence: Sequence, node_dict: Dict[int, Node]) -> GraphModule: + op_list = sequence.list_operations() + loss_op = [op for op in op_list if isinstance(op, Loss)][0] + op_list = op_list[:op_list.index(loss_op)] + ckpt_idx = 0 + in_ckpt = False + ckpt_region = [] + for idx, op in enumerate(op_list, 1): + if in_ckpt: + if isinstance(op, ForwardNograd): + ckpt_region.append(idx) + + elif isinstance(op, ForwardEnable): + in_ckpt = False + for idx in ckpt_region: + for node in node_dict[idx]: + setattr(node, "activation_checkpoint", ckpt_idx) + + ckpt_idx += 1 + ckpt_region = [] + + elif isinstance(op, ForwardCheck): + for idx in ckpt_region: + for node in node_dict[idx]: + setattr(node, "activation_checkpoint", ckpt_idx) + + ckpt_idx += 1 + ckpt_region = [idx] + + else: + if isinstance(op, ForwardCheck): + in_ckpt = True + ckpt_region.append(idx) + + +def solver_rotor(gm: GraphModule, data: torch.Tensor, mem_limit: int, mem_slots: int = 500) -> GraphModule: node_dict = linearize(gm) mem_unit = mem_limit // mem_slots - chain: Chain = _construct_chain(node_dict, mem_unit) - opt_table = _compute_table(chain, mem_limit) - sequence = _rec(chain, 0, chain.length, mem_limit - chain.cweigth[0], opt_table) + MetaInfoProp(gm).run(data) + chain: Chain = _construct_chain(node_dict, data, mem_unit) + opt_table = _compute_table(chain, mem_slots) + sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweigth[0], opt_table) + _annotate_from_sequence(sequence, node_dict) + return gm diff --git a/colossalai/fx/passes/algorithms/linearize.py b/colossalai/fx/passes/algorithms/linearize.py index 7d9b5a2bc13a..19d84a046f3b 100644 --- a/colossalai/fx/passes/algorithms/linearize.py +++ b/colossalai/fx/passes/algorithms/linearize.py @@ -1,9 +1,12 @@ +from typing import OrderedDict from torch.fx import GraphModule +from collections import OrderedDict +import pdb def linearize(gm: GraphModule) -> dict: status_dict = {} - node_dict = {} + node_dict = OrderedDict() node_idx = 0 for node in gm.graph.nodes: last_dict_len = len(status_dict) @@ -17,33 +20,70 @@ def linearize(gm: GraphModule) -> dict: if len(status_dict[key]) == 0: status_dict.pop(key) - # first node in graph + # first node in graph, it should be in n0-n1 type, + # where n0 contains only input op, i.e. placeholder if last_dict_len == 0: node_dict[node_idx] = [node] status_dict[node.name] = list(node.users) + node_idx += 1 + node_dict[node_idx] = [] continue # boundary case if len(status_dict) == 0: # current node region end point = next node region start point + # i.e. n1-n2-n3-... type node, each node contains only one op if last_dict_len == 1: - node_idx += 1 - node_dict[node_idx] = [node] + if len(node_dict[node_idx]) > 0: + node_idx += 1 + node_dict[node_idx] = [] + node_dict[node_idx].append(node) status_dict[node.name] = list(node.users) continue + + # n1-n2-n3, if n1 has multiple ops, the last op in n1 will be + # the one who is able to clean all others in status_dict + # and as the last_dict_len > 1, there are multiple ops are used + # by this node, we view it as the end of one node and start a new node else: + node_dict[node_idx].append(node) status_dict[node.name] = list(node.users) + node_idx += 1 + node_dict[node_idx] = [] continue else: - # in-node case + # currently I will use bigger node structure + # if the following region is activated, the node will be smaller + ################################################# + # if last_dict_len == 1: + # if len(node_dict[node_idx]) > 0: + # node_idx += 1 + # node_dict[node_idx] = [node] + # status_dict[node.name] = list(node.users) + # + # continue + ################################################# + + # in-node case, as the current node can not clean status_dict + # we view it as in-node status, the node will be appended to the + # current node_idx node_dict[node_idx].append(node) status_dict[node.name] = list(node.users) continue + # If the output node use multiple nodes, there might be an + # empty node after the output node + if len(node_dict[node_idx]) == 0: + node_dict.pop[node_idx] + node_idx -= 1 + + # pop the last two nodes + node_dict.pop(0) + node_dict.pop(node_idx) return node_dict diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py index 31e54db36bd3..1d6352d07b7c 100644 --- a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py +++ b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py @@ -9,7 +9,7 @@ from colossalai.fx import ColoTracer from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.fx.passes.algorithms import chen_greedy +from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor from colossalai.utils import free_port from colossalai.core import global_context as gpc import pytest @@ -22,7 +22,7 @@ from colossalai.fx.codegen import python_code_with_activation_checkpoint with_codegen = False -SOLVERS = [chen_greedy] +SOLVERS = [chen_greedy, solver_rotor] def _is_activation_checkpoint_available(gm: GraphModule): @@ -77,7 +77,10 @@ def _run_ckpt_solver(rank): MetaInfoProp(gm).run(data) codegen = ActivationCheckpointCodeGen() gm.graph.set_codegen(codegen) - gm = solver(gm) + if solver == solver_rotor: + gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500) + else: + gm = solver(gm) assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner." assert _is_activation_checkpoint_available( gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" @@ -106,7 +109,10 @@ def _run_ckpt_solver_torch11(rank): gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__) MetaInfoProp(gm).run(data) gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph) - gm = solver(gm) + if solver == solver_rotor: + gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500) + else: + gm = solver(gm) assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner." assert _is_activation_checkpoint_available( gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" From 80dd40345f2a4c8c9c8a0cf00e9f62b98e48b6a8 Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Thu, 25 Aug 2022 17:56:15 +0800 Subject: [PATCH 11/13] [fx] modification --- colossalai/fx/passes/algorithms/ckpt_solver_rotor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py index 805ded872f14..adb4450a9377 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py @@ -10,6 +10,7 @@ # this is the python compute table code from rotor # https://gitlab.inria.fr/hiepacs/rotor +# paper link: https://hal.inria.fr/hal-02352969 def _compute_table(chain: Chain, mmax) -> Tuple: """Returns the optimal table: a tuple containing: Opt[m][lmin][lmax] with lmin = 0...chain.length From fd6472c26247b305e16b625869a7a7c3f63f0bda Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Thu, 25 Aug 2022 19:28:39 +0800 Subject: [PATCH 12/13] code refactor --- .../fx/passes/algorithms/ckpt_solver_rotor.py | 30 +++++++++---------- colossalai/fx/passes/algorithms/utils.py | 16 +++++----- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py index adb4450a9377..06fa18fcbb4d 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py @@ -45,31 +45,31 @@ def _compute_table(chain: Chain, mmax) -> Tuple: for m in range(mmax + 1): for d in range(1, chain.length + 1): for i in range(chain.length + 1 - d): - # for l in range(i+1, chain.length + 1): - l = i + d - mmin = cw[l + 1] + cw[i + 1] + fwd_tmp[i] - if l > i + 1: - mmin = max(mmin, cw[l + 1] + max(cw[j] + cw[j + 1] + fwd_tmp[j] for j in range(i + 1, l))) + # for idx in range(i+1, chain.length + 1): + idx = i + d + mmin = cw[idx + 1] + cw[i + 1] + fwd_tmp[i] + if idx > i + 1: + mmin = max(mmin, cw[idx + 1] + max(cw[j] + cw[j + 1] + fwd_tmp[j] for j in range(i + 1, idx))) if m < mmin: - opt[m][i][l] = float("inf") + opt[m][i][idx] = float("inf") else: - leaf_checkpoints = [(j, sum(fw[i:j]) + opt[m - cw[j]][j][l] + opt[m][i][j - 1]) - for j in range(i + 1, l + 1) + leaf_checkpoints = [(j, sum(fw[i:j]) + opt[m - cw[j]][j][idx] + opt[m][i][j - 1]) + for j in range(i + 1, idx + 1) if m >= cw[j]] if leaf_checkpoints: best_leaf = min(leaf_checkpoints, key=lambda t: t[1]) else: best_leaf = None if m >= cbw[i + 1]: - chain_checkpoint = opt[m][i][i] + opt[m - cbw[i + 1]][i + 1][l] + chain_checkpoint = opt[m][i][i] + opt[m - cbw[i + 1]][i + 1][idx] else: chain_checkpoint = float("inf") if best_leaf and best_leaf[1] <= chain_checkpoint: - opt[m][i][l] = best_leaf[1] - what[m][i][l] = (False, best_leaf[0]) + opt[m][i][idx] = best_leaf[1] + what[m][i][idx] = (False, best_leaf[0]) else: - opt[m][i][l] = chain_checkpoint - what[m][i][l] = (True,) + opt[m][i][idx] = chain_checkpoint + what[m][i][idx] = (True,) return (opt, what) @@ -128,8 +128,8 @@ def _construct_chain(node_dict: Dict[int, Node], data: torch.Tensor, mem_unit: i fwd_time.append(0) bwd_time.append(0) xbar_sizes.append(0) - x_sizes.append(node_dict[key][-1].meta['tensor_meta'].numel * \ - torch.tensor([], dtype=node_dict[key][-1].meta['tensor_meta'].dtype).element_size()) + x_sizes.append(node_dict[key][-1].meta['tensor_meta'].numel * + torch.tensor([], dtype=node_dict[key][-1].meta['tensor_meta'].dtype).element_size()) for node in node_dict[key]: fwd_time[-1] += node.__flops__ diff --git a/colossalai/fx/passes/algorithms/utils.py b/colossalai/fx/passes/algorithms/utils.py index 914fc91db042..480a8e0bc93d 100644 --- a/colossalai/fx/passes/algorithms/utils.py +++ b/colossalai/fx/passes/algorithms/utils.py @@ -17,13 +17,13 @@ def check_lengths(self): and (len(self.bwd_tmp) == self.length + 1) and (len(self.cbweigth) == self.length + 1)) def __repr__(self): - l = [] + chain_list = [] for i in range(self.length): - l.append( + chain_list.append( (self.fweigth[i], self.bweigth[i], self.cweigth[i], self.cbweigth[i], self.fwd_tmp[i], self.bwd_tmp[i])) i = self.length - l.append((None, self.bweigth[i], self.cweigth[i], self.cbweigth[i], None, self.bwd_tmp[i])) - return l.__repr__() + chain_list.append((None, self.bweigth[i], self.cweigth[i], self.cbweigth[i], None, self.bwd_tmp[i])) + return chain_list.__repr__() class Operation: @@ -172,14 +172,14 @@ def __repr__(self): return repr(self.list_operations()) def list_operations(self): - l = [] + op_list = [] for x in self.sequence: if isinstance(x, Operation): - l.append(x) + op_list.append(x) else: assert isinstance(x, Sequence) - l += x.list_operations() - return l + op_list += x.list_operations() + return op_list def insert(self, operation): self.sequence.append(operation) From e1d591e39fb7492645374f106f506025f6140610 Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Fri, 26 Aug 2022 09:54:55 +0800 Subject: [PATCH 13/13] code refactor --- .../fx/passes/algorithms/ckpt_solver_rotor.py | 10 ++-- colossalai/fx/passes/algorithms/utils.py | 46 +++++++++---------- 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py index 06fa18fcbb4d..396cf7b2936c 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py @@ -19,10 +19,10 @@ def _compute_table(chain: Chain, mmax) -> Tuple: (False, j) if the optimal choice is a leaf checkpoint of length j The computation uses dynamic programming""" - fw = chain.fweigth + [0] ## forward time - bw = chain.bweigth ## backward time, not used - cw = chain.cweigth + [0] ## size of x (and of y) - cbw = chain.cbweigth + [0] ## size of xbar + fw = chain.fweight + [0] ## forward time + bw = chain.bweight ## backward time, not used + cw = chain.cweight + [0] ## size of x (and of y) + cbw = chain.cbweight + [0] ## size of xbar fwd_tmp = chain.fwd_tmp + [0] bwd_tmp = chain.bwd_tmp + [0] @@ -193,6 +193,6 @@ def solver_rotor(gm: GraphModule, data: torch.Tensor, mem_limit: int, mem_slots: MetaInfoProp(gm).run(data) chain: Chain = _construct_chain(node_dict, data, mem_unit) opt_table = _compute_table(chain, mem_slots) - sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweigth[0], opt_table) + sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table) _annotate_from_sequence(sequence, node_dict) return gm diff --git a/colossalai/fx/passes/algorithms/utils.py b/colossalai/fx/passes/algorithms/utils.py index 480a8e0bc93d..88efe0a0c8b5 100644 --- a/colossalai/fx/passes/algorithms/utils.py +++ b/colossalai/fx/passes/algorithms/utils.py @@ -1,10 +1,10 @@ class Chain: def __init__(self, fw, bw, cw, cbw, ftmp, btmp, check=True): - self.fweigth = fw - self.bweigth = bw - self.cweigth = cw - self.cbweigth = cbw + self.fweight = fw + self.bweight = bw + self.cweight = cw + self.cbweight = cbw self.fwd_tmp = ftmp self.bwd_tmp = btmp self.length = len(fw) @@ -12,17 +12,17 @@ def __init__(self, fw, bw, cw, cbw, ftmp, btmp, check=True): raise AttributeError("In Chain, input lists do not have consistent lengths") def check_lengths(self): - return ((len(self.fweigth) == self.length) and (len(self.bweigth) == self.length + 1) - and (len(self.cweigth) == self.length + 1) and (len(self.fwd_tmp) == self.length) - and (len(self.bwd_tmp) == self.length + 1) and (len(self.cbweigth) == self.length + 1)) + return ((len(self.fweight) == self.length) and (len(self.bweight) == self.length + 1) + and (len(self.cweight) == self.length + 1) and (len(self.fwd_tmp) == self.length) + and (len(self.bwd_tmp) == self.length + 1) and (len(self.cbweight) == self.length + 1)) def __repr__(self): chain_list = [] for i in range(self.length): chain_list.append( - (self.fweigth[i], self.bweigth[i], self.cweigth[i], self.cbweigth[i], self.fwd_tmp[i], self.bwd_tmp[i])) + (self.fweight[i], self.bweight[i], self.cweight[i], self.cbweight[i], self.fwd_tmp[i], self.bwd_tmp[i])) i = self.length - chain_list.append((None, self.bweigth[i], self.cweigth[i], self.cbweigth[i], None, self.bwd_tmp[i])) + chain_list.append((None, self.bweight[i], self.cweight[i], self.cbweight[i], None, self.bwd_tmp[i])) return chain_list.__repr__() @@ -204,26 +204,26 @@ def remove_useless_write(self): def get_makespan(self, chain): return sum(op.cost(chain) for op in self.list_operations()) - def withoutSuffix(self): + def without_suffix(self): ops = self.list_operations() - endOfFirstPhase = [i for i in range(len(ops)) if type(ops[i]) is Loss][0] + end_of_first_phase = [i for i in range(len(ops)) if type(ops[i]) is Loss][0] try: - lastIndex = max(i for i in range(endOfFirstPhase) if not type(ops[i]) is ForwardEnable) + last_idx = max(i for i in range(end_of_first_phase) if not type(ops[i]) is ForwardEnable) except ValueError: - lastIndex = -1 - if lastIndex == endOfFirstPhase - 1: + last_idx = -1 + if last_idx == end_of_first_phase - 1: return (self, None) - chainLength = ops[endOfFirstPhase - - 1].index ## Some assumption here about the sequence (finishes with Forward_L - startOfFwdEnableChain = ops[lastIndex + 1].index ## And starts with B_L), but should be fine in practice - result = Sequence(Function("Strip", self.function.name, *self.function.args, startOfFwdEnableChain)) - for i in range(lastIndex + 1): + chain_length = ops[end_of_first_phase - + 1].index ## Some assumption here about the sequence (finishes with Forward_L + start_of_fwd_enable_chain = ops[last_idx + 1].index ## And starts with B_L), but should be fine in practice + result = Sequence(Function("Strip", self.function.name, *self.function.args, start_of_fwd_enable_chain)) + for i in range(last_idx + 1): result.insert(ops[i]) result.insert(Loss()) - for i in range(chainLength, startOfFwdEnableChain - 1, -1): - position = endOfFirstPhase + 1 + (chainLength - i) + for i in range(chain_length, start_of_fwd_enable_chain - 1, -1): + position = end_of_first_phase + 1 + (chain_length - i) assert type(ops[position]) is Backward assert ops[position].index == i - for i in range(endOfFirstPhase + 1 + 1 + chainLength - startOfFwdEnableChain, len(ops)): + for i in range(end_of_first_phase + 1 + 1 + chain_length - start_of_fwd_enable_chain, len(ops)): result.insert(ops[i]) - return (result, startOfFwdEnableChain) + return (result, start_of_fwd_enable_chain)